from flask import Flask, request, jsonify
from flask_cors import CORS # To handle CORS if needed
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
app = Flask(__name__)
CORS(app) # Enable CORS for frontend communication
MODEL_NAME = "tanusrich/Mental_Health_Chatbot"
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load the tokenizer and model once at startup so each request is served quickly.
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,  # half precision on GPU for lower memory use
        low_cpu_mem_usage=True,
    ).to(device)
    model.eval()
except Exception as e:
    print(f"Error loading model: {e}")
    raise SystemExit(1)
@app.route("/chat", methods=["POST"])
def chat():
    try:
        data = request.get_json(silent=True) or {}
        user_input = data.get("message", "").strip()
        if not user_input:
            return jsonify({"error": "Message is required"}), 400

        inputs = tokenizer(user_input, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_new_tokens=150,  # cap the reply length independently of the prompt length
                pad_token_id=tokenizer.eos_token_id,
            )
        # Decode only the newly generated tokens so the prompt is not echoed back.
        new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
        response_text = tokenizer.decode(new_tokens, skip_special_tokens=True)
        return jsonify({"response": response_text})
    except Exception as e:
        return jsonify({"error": f"Internal Server Error: {str(e)}"}), 500
if __name__ == "__main__":
    # debug=True is convenient during development; disable it in production.
    app.run(host="0.0.0.0", port=5000, debug=True)
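
# Example client call (a minimal sketch, assuming the server is running
# locally on port 5000; the `requests` package is not a dependency of this app):
#
#   import requests
#   r = requests.post(
#       "http://localhost:5000/chat",
#       json={"message": "I've been feeling anxious lately."},
#   )
#   print(r.json())  # -> {"response": "..."}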