moeid526 committed on
Commit
8acce75
·
verified ·
1 Parent(s): 778b9f2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -16
app.py CHANGED
@@ -1,9 +1,11 @@
1
- import os
2
- os.system("pip install numpy<2 --quiet")
3
-
4
  from transformers import AutoModelForCausalLM, AutoTokenizer
5
  import torch
6
 
 
 
 
7
  MODEL_NAME = "tanusrich/Mental_Health_Chatbot"
8
  device = "cuda" if torch.cuda.is_available() else "cpu"
9
 
@@ -22,22 +24,28 @@ except Exception as e:
22
 
23
  @app.route("/chat", methods=["POST"])
24
  def chat():
25
- data = request.json
26
- user_input = data.get("message", "")
 
 
 
 
27
 
28
- if not user_input:
29
- return jsonify({"error": "Message is required"}), 400
30
 
31
- inputs = tokenizer(user_input, return_tensors="pt").to(device)
 
 
 
 
 
32
 
33
- with torch.no_grad():
34
- outputs = model.generate(
35
- input_ids=inputs["input_ids"],
36
- attention_mask=inputs["attention_mask"],
37
- max_length=150
38
- )
39
 
40
- response_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
41
 
42
- return jsonify({"response": response_text})
 
43
 
 
 
 
1
# Flask web service that wraps a Hugging Face causal-LM chatbot model.
from flask import Flask, request, jsonify
from flask_cors import CORS  # To handle CORS if needed
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

app = Flask(__name__)
CORS(app)  # Enable CORS for frontend communication

# Hugging Face Hub model id the service loads (loading code is outside this hunk).
MODEL_NAME = "tanusrich/Mental_Health_Chatbot"
# Run inference on GPU when one is available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
11
 
 
24
 
25
@app.route("/chat", methods=["POST"])
def chat():
    """Generate a chatbot reply for a POSTed JSON payload ``{"message": str}``.

    Returns:
        200 with ``{"response": str}`` on success,
        400 with ``{"error": str}`` when the message is missing/empty,
        500 with ``{"error": str}`` on any unexpected failure.
    """
    try:
        # BUG FIX: request.json is None (or raises) when the body is missing
        # or not valid JSON, so data.get(...) crashed with AttributeError and
        # the handler answered 500. get_json(silent=True) yields None quietly;
        # fall back to {} so absent input is reported as a proper 400.
        data = request.get_json(silent=True) or {}
        user_input = data.get("message", "").strip()

        if not user_input:
            return jsonify({"error": "Message is required"}), 400

        # Tokenize on the same device the model lives on.
        inputs = tokenizer(user_input, return_tensors="pt").to(device)

        # Inference only — disable autograd bookkeeping.
        with torch.no_grad():
            outputs = model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_length=150
            )

        response_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

        return jsonify({"response": response_text})

    except Exception as e:
        # Top-level service boundary: report the failure as JSON.
        # NOTE(review): str(e) can leak internal details to clients —
        # consider logging the exception and returning a generic message.
        return jsonify({"error": f"Internal Server Error: {str(e)}"}), 500
49
 
50
if __name__ == "__main__":
    # SECURITY FIX: debug=True enables the Werkzeug interactive debugger,
    # which permits arbitrary code execution if the port is reachable —
    # and this binds to 0.0.0.0 (all interfaces). Keep debug off here;
    # enable it locally via `flask run --debug` when developing.
    app.run(host="0.0.0.0", port=5000, debug=False)