# customerservice / deploy_chatbot_api.py
# jpatel20's picture
# Upload deploy_chatbot_api.py
# f30f16f verified
# deploy_chatbot_api.py
from transformers import AutoTokenizer, AutoModelForCausalLM
from flask import Flask, request, jsonify
import torch
# Identifier handed to from_pretrained() for both the tokenizer and the model
MODEL_ID = "bitext/Mistral-7B-Customer-Support"

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

print("Loading model...")
# float32 when CUDA is unavailable, float16 otherwise.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float32 if not torch.cuda.is_available() else torch.float16,
    device_map="auto",
)
# Switch the module to evaluation mode before serving requests.
model.eval()
print("Model loaded successfully.")

# Flask application object; the route below is registered on it
app = Flask(__name__)
# Inference function
def generate_response(user_message, max_new_tokens=100):
    """Generate the model's reply to a single user message.

    The message is wrapped in a one-turn chat, rendered through the
    tokenizer's chat template (with the generation prompt appended), moved
    to the model's device, and passed to ``model.generate``.

    Args:
        user_message: Text used as the content of the single "user" turn.
        max_new_tokens: Upper bound on the number of tokens generated for
            the reply. Defaults to 100, the value previously hard-coded.

    Returns:
        The decoded reply with special tokens skipped and surrounding
        whitespace stripped. Only the newly generated tokens are decoded;
        the prompt tokens are sliced off first.
    """
    messages = [{"role": "user", "content": user_message}]
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)

    # generate() returns prompt + continuation; remember where the prompt
    # ends so only the continuation is decoded.
    prompt_length = inputs["input_ids"].shape[-1]
    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)

    generated_tokens = outputs[0][prompt_length:]
    response = tokenizer.decode(generated_tokens, skip_special_tokens=True)
    return response.strip()
# API endpoint
@app.route("/chat", methods=["POST"])
def chat():
    """Handle POST /chat: validate the JSON payload and return a reply.

    Expects a JSON object with a ``"message"`` key holding a non-empty
    string.

    Returns:
        ``{"response": <text>}`` on success (HTTP 200).
        ``{"error": <reason>}`` with HTTP 400 when the body is not a JSON
        object, lacks ``"message"``, or ``"message"`` is not a non-empty
        string.
        ``{"error": "Internal server error"}`` with HTTP 500 when
        generation fails; the exception is logged server-side rather than
        echoed to the client.
    """
    # silent=True makes malformed JSON / a non-JSON content type yield None
    # instead of Flask aborting with its own error page, so every failure
    # path below answers with a JSON body.
    data = request.get_json(silent=True)

    # A JSON list or string can also "contain" 'message'; require an object
    # so the subscript below cannot raise TypeError.
    if not isinstance(data, dict) or "message" not in data:
        return jsonify({"error": "Missing 'message' field"}), 400

    user_msg = data["message"]
    if not isinstance(user_msg, str) or not user_msg.strip():
        return jsonify({"error": "'message' must be a non-empty string"}), 400

    try:
        ai_response = generate_response(user_msg)
    except Exception:
        # Top-level request boundary: record the traceback in the app log
        # and avoid leaking internal error details to the caller.
        app.logger.exception("Failed to generate chatbot response")
        return jsonify({"error": "Internal server error"}), 500

    return jsonify({"response": ai_response})
# Start the server only when this file is executed directly
if __name__ == "__main__":
    # Listen on every network interface, port 5000.
    server_options = {"host": "0.0.0.0", "port": 5000}
    app.run(**server_options)