jpatel20 committed on
Commit
f30f16f
·
verified ·
1 Parent(s): 04d745b

Upload deploy_chatbot_api.py

Browse files
Files changed (1) hide show
  1. deploy_chatbot_api.py +58 -0
deploy_chatbot_api.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # deploy_chatbot_api.py
3
+
4
+ from transformers import AutoTokenizer, AutoModelForCausalLM
5
+ from flask import Flask, request, jsonify
6
+ import torch
7
+
8
# Identifier handed to both from_pretrained() calls below, so the tokenizer
# and the weights always come from the same checkpoint.
MODEL_ID = "bitext/Mistral-7B-Customer-Support"

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

print("Loading model...")
# Use half precision only when CUDA is available; otherwise stay in float32.
if torch.cuda.is_available():
    _weights_dtype = torch.float16
else:
    _weights_dtype = torch.float32

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=_weights_dtype,
    device_map="auto",
)
# The service only runs inference, so put the module in evaluation mode.
model.eval()

print("Model loaded successfully.")

# Application object that the route decorator and the __main__ guard use.
app = Flask(__name__)
26
+
27
+ # Inference function
28
def generate_response(user_message, max_new_tokens=100):
    """Generate the model's reply to a single user message.

    The message is wrapped in a one-turn conversation, rendered through the
    tokenizer's chat template, and passed to ``model.generate``.

    Args:
        user_message: Text used as the content of the single ``user`` turn.
        max_new_tokens: Upper bound on the number of tokens generated for
            the reply. Defaults to 100, the value that used to be
            hard-coded, so existing callers behave exactly as before.

    Returns:
        The decoded reply with special tokens skipped and leading/trailing
        whitespace stripped.
    """
    messages = [{"role": "user", "content": user_message}]
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)  # keep the input tensors on the same device as the model

    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)

    # Skip the first `prompt_length` tokens of the output so that only the
    # text produced after the prompt is decoded.
    prompt_length = inputs["input_ids"].shape[-1]
    response = tokenizer.decode(
        outputs[0][prompt_length:],
        skip_special_tokens=True,
    )
    return response.strip()
41
+
42
+ # API endpoint
43
@app.route("/chat", methods=["POST"])
def chat():
    """Handle ``POST /chat``: answer a user message with a model reply.

    Expects a JSON object body of the form ``{"message": "<text>"}``.

    Returns:
        * ``200`` with ``{"response": "<reply>"}`` on success.
        * ``400`` with ``{"error": ...}`` when the body is not a JSON object,
          lacks ``"message"``, or ``"message"`` is not a non-empty string.
        * ``500`` with ``{"error": ...}`` when generation fails; the details
          are written to the application log instead of the response.
    """
    # silent=True makes get_json() return None for a missing, non-JSON or
    # malformed body, so the client gets this endpoint's JSON error instead
    # of Flask's default HTML error page.
    data = request.get_json(silent=True)

    # The body must be a JSON object: a JSON string or list could satisfy an
    # `in` test and then fail on data["message"] with an unhandled TypeError.
    if not isinstance(data, dict) or "message" not in data:
        return jsonify({"error": "Missing 'message' field"}), 400

    user_msg = data["message"]
    if not isinstance(user_msg, str) or not user_msg.strip():
        return jsonify({"error": "'message' must be a non-empty string"}), 400

    try:
        ai_response = generate_response(user_msg)
    except Exception:
        # Top-level request boundary: log the full traceback server-side and
        # keep internal details (paths, device errors, ...) out of the reply.
        app.logger.exception("Failed to generate chat response")
        return jsonify({"error": "Internal server error"}), 500

    return jsonify({"response": ai_response})
55
+
56
+ # Run the app
57
# Start the Flask server only when this file is executed as a script, not
# when it is imported. Host 0.0.0.0 listens on all network interfaces.
if __name__ == "__main__":
    app.run(port=5000, host="0.0.0.0")