Flask API server for the "tanusrich/Mental_Health_Chatbot" Hugging Face model.
NOTE(review): the original paste reported a runtime error — the chat handler
below was never registered as a route, and the model load mixed device_map="auto"
with .to(device); both are fixed in the code that follows.
| from flask import Flask, request, jsonify | |
| from flask_cors import CORS # To handle CORS if needed | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| import torch | |
app = Flask(__name__)
CORS(app)  # Enable CORS so the browser frontend can call this API.

MODEL_NAME = "tanusrich/Mental_Health_Chatbot"
# Inputs are moved to this device in chat(); model placement itself is
# handled by accelerate via device_map="auto".
device = "cuda" if torch.cuda.is_available() else "cpu"

try:
    # Tokenizer load can fail for the same reasons as the model load
    # (network, missing repo), so it belongs inside the same guard.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        device_map="auto",          # let accelerate place weights across devices
        torch_dtype=torch.float16,  # half precision for lower RAM usage
        low_cpu_mem_usage=True,
    )
    # NOTE(review): the original chained .to(device) here; calling .to() on a
    # model loaded with device_map="auto" conflicts with accelerate's dispatch
    # (and raises on recent versions), so it was removed.
except Exception as e:
    print(f"Error loading model: {e}")
    # raise SystemExit instead of exit(): exit() is a site-module convenience
    # that is not guaranteed to exist in every runtime.
    raise SystemExit(1)
@app.route("/chat", methods=["POST"])  # was missing: chat() was never reachable
def chat():
    """Handle POST /chat: generate a model reply to the JSON body's "message".

    Expects a JSON object {"message": "<user text>"}.
    Returns {"response": "<generated text>"} on success,
    {"error": ...} with HTTP 400 when the message is missing/empty,
    or HTTP 500 on any unexpected failure.
    """
    try:
        # silent=True: malformed / non-JSON bodies yield None instead of a
        # framework-raised 400/415, so we control the error response below.
        data = request.get_json(silent=True) or {}
        user_input = data.get("message", "").strip()
        if not user_input:
            return jsonify({"error": "Message is required"}), 400

        inputs = tokenizer(user_input, return_tensors="pt").to(device)
        # Inference only — skip autograd bookkeeping.
        with torch.no_grad():
            outputs = model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                # max_new_tokens bounds the *reply* length; the original
                # max_length=150 counted prompt tokens too, so long prompts
                # silently truncated the generated answer.
                max_new_tokens=150,
            )
        response_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return jsonify({"response": response_text})
    except Exception as e:
        # Top-level boundary: report the failure to the client rather than
        # letting the server return an opaque 500 page.
        return jsonify({"error": f"Internal Server Error: {str(e)}"}), 500
if __name__ == "__main__":
    # Bind on all interfaces for local/container use.
    # NOTE(review): debug=True enables the Werkzeug debugger — do not ship this
    # setting to production.
    bind_host, bind_port = "0.0.0.0", 5000
    app.run(host=bind_host, port=bind_port, debug=True)