from flask import Flask, request, jsonify
from flask_cors import CORS  # To handle CORS if needed
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

app = Flask(__name__)
CORS(app)  # Enable CORS for frontend communication

MODEL_NAME = "tanusrich/Mental_Health_Chatbot"
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
try:
    # Note: device_map="auto" conflicts with an explicit .to(device)
    # (accelerate-dispatched models cannot be moved), so the model is
    # placed manually here instead.
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,  # Half precision only on GPU
        low_cpu_mem_usage=True,
    ).to(device)
except Exception as e:
    print(f"Error loading model: {e}")
    raise SystemExit(1)


@app.route("/chat", methods=["POST"])
def chat():
    try:
        # get_json(silent=True) returns None instead of raising on a
        # missing/invalid JSON body, so malformed requests get a clean 400
        data = request.get_json(silent=True) or {}
        user_input = data.get("message", "").strip()
        if not user_input:
            return jsonify({"error": "Message is required"}), 400

        inputs = tokenizer(user_input, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_new_tokens=150,  # Cap generated tokens; max_length would count the prompt too
                pad_token_id=tokenizer.eos_token_id,
            )
        # Decode only the newly generated tokens so the prompt is not echoed back
        new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
        response_text = tokenizer.decode(new_tokens, skip_special_tokens=True)
        return jsonify({"response": response_text})
    except Exception as e:
        return jsonify({"error": f"Internal Server Error: {str(e)}"}), 500


if __name__ == "__main__":
    app.run(host="0.0.0.0", port=5000, debug=True)
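
# --- Usage sketch (assumption: the server is running locally on port 5000) ---
# A minimal way to exercise the /chat endpoint, e.g. with curl:
#
#   curl -X POST http://localhost:5000/chat \
#        -H "Content-Type: application/json" \
#        -d '{"message": "I have been feeling anxious lately."}'
#
# Or from Python, using the third-party `requests` library:
#
#   import requests
#   r = requests.post(
#       "http://localhost:5000/chat",
#       json={"message": "I have been feeling anxious lately."},
#   )
#   print(r.json())  # expected shape: {"response": "<generated text>"}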