# deploy_chatbot_api.py
"""Minimal Flask API serving a Mistral-7B customer-support chat model.

Loads the model once at startup and exposes a single POST /chat endpoint
that returns a one-turn assistant reply for a user message.
"""
import logging

import torch
from flask import Flask, jsonify, request
from transformers import AutoModelForCausalLM, AutoTokenizer

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Model configuration.
MODEL_ID = "bitext/Mistral-7B-Customer-Support"
MAX_NEW_TOKENS = 100  # generation budget per request

# Load tokenizer and model once, at import time, so every request reuses them.
logger.info("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

logger.info("Loading model...")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    # Half precision on GPU halves memory; CPU generally needs float32.
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
)
model.eval()
logger.info("Model loaded successfully.")

# Flask app
app = Flask(__name__)


def generate_response(user_message: str) -> str:
    """Generate a single-turn assistant reply for ``user_message``.

    Applies the model's chat template, generates up to MAX_NEW_TOKENS new
    tokens, and decodes only the newly generated portion (the prompt tokens
    are sliced off before decoding).
    """
    messages = [{"role": "user", "content": user_message}]
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)
    # inference_mode disables autograd bookkeeping for lower overhead/memory.
    with torch.inference_mode():
        outputs = model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS)
    # Decode only the tokens generated after the prompt.
    new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()


@app.route("/chat", methods=["POST"])
def chat():
    """POST /chat — body ``{"message": str}``; returns ``{"response": str}``.

    Returns 400 for a missing/invalid message and 500 (with a generic body)
    if generation fails.
    """
    # silent=True: malformed JSON / wrong Content-Type yields None instead of
    # raising inside Flask, so it flows into our own 400 response below.
    data = request.get_json(silent=True)
    if not data or "message" not in data:
        return jsonify({"error": "Missing 'message' field"}), 400
    user_msg = data["message"]
    # Reject non-string or empty/whitespace-only messages before they reach
    # the tokenizer, where they would fail with an opaque error.
    if not isinstance(user_msg, str) or not user_msg.strip():
        return jsonify({"error": "'message' must be a non-empty string"}), 400
    try:
        ai_response = generate_response(user_msg)
    except Exception:
        # Log the full traceback server-side; never leak internals to clients.
        logger.exception("Generation failed")
        return jsonify({"error": "Internal server error"}), 500
    return jsonify({"response": ai_response})


# Run the app
if __name__ == "__main__":
    # Development server only — use a production WSGI server
    # (e.g. gunicorn) behind a reverse proxy for real deployments.
    app.run(host="0.0.0.0", port=5000)