Spaces:
Runtime error
Runtime error
File size: 1,553 Bytes
8acce75 4acf6b0 f955d67 5befe93 8acce75 def5cc8 f955d67 def5cc8 4acf6b0 def5cc8 4acf6b0 def5cc8 f955d67 5befe93 8acce75 f955d67 8acce75 f955d67 8acce75 f955d67 8acce75 def5cc8 8acce75 def5cc8 8acce75 5befe93 8acce75 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 |
from flask import Flask, request, jsonify
from flask_cors import CORS # To handle CORS if needed
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
# Flask application object; routes are registered on it below.
app = Flask(__name__)
CORS(app)  # Enable CORS for frontend communication

MODEL_NAME = "tanusrich/Mental_Health_Chatbot"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Half precision is only used on GPU; float16 inference on CPU is poorly
# supported by PyTorch, so fall back to float32 there.
_model_dtype = torch.float16 if device == "cuda" else torch.float32

try:
    # Load tokenizer and model together so that a failure in either one
    # produces the same clean startup error and exit code.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    # device_map="auto" is deliberately not passed: a model dispatched by
    # accelerate should not also be moved with .to(); the explicit .to(device)
    # below is the single source of truth for placement.
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=_model_dtype,
        low_cpu_mem_usage=True,
    ).to(device)
except Exception as e:
    print(f"Error loading model: {e}")
    # SystemExit instead of the site-provided exit() helper, which is not
    # guaranteed to exist in every interpreter configuration.
    raise SystemExit(1)
@app.route("/chat", methods=["POST"])
def chat():
    """Generate a model reply for a posted chat message.

    Expects a JSON object body containing a non-empty string under the
    ``"message"`` key.

    Returns:
        ``{"response": <decoded text>}`` on success;
        ``{"error": ...}`` with status 400 when the body is not a JSON object
        or the message is missing, empty, or not a string;
        ``{"error": ...}`` with status 500 when tokenization or generation
        fails.
    """
    # silent=True yields None for a missing/invalid JSON body instead of
    # raising, so client mistakes are reported as 400 rather than 500.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    message = data.get("message", "")
    if not isinstance(message, str):
        return jsonify({"error": "Message is required"}), 400
    user_input = message.strip()
    if not user_input:
        return jsonify({"error": "Message is required"}), 400

    try:
        inputs = tokenizer(user_input, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                # max_new_tokens bounds only the generated continuation;
                # max_length would also count the prompt tokens and leave
                # little or no room for a reply to a long message.
                max_new_tokens=150,
            )
        # The whole output sequence is decoded, as before; for a causal LM
        # this text begins with the user's prompt followed by the reply.
        response_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception:
        # Record the traceback server-side; do not leak internals to clients.
        app.logger.exception("Chat generation failed")
        return jsonify({"error": "Internal Server Error"}), 500

    return jsonify({"response": response_text})
if __name__ == "__main__":
    # Debug mode is off: with host="0.0.0.0" the interactive Werkzeug
    # debugger would be reachable from the network and allows arbitrary code
    # execution. The debug reloader would also re-import this module and
    # load the model a second time.
    app.run(host="0.0.0.0", port=5000, debug=False)
|