import torch
from flask import Flask, request, jsonify
from transformers import AutoTokenizer, AutoModelForCausalLM

# Mistral-7B fine-tuned for customer-support chat, pulled from the Hugging Face Hub.
MODEL_ID = "bitext/Mistral-7B-Customer-Support"

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

print("Loading model...")
# device_map="auto" (requires the accelerate package) places the weights on the
# available GPU(s); float16 halves memory on GPU, while CPU runs fall back to float32.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
)
model.eval()
print("Model loaded successfully.")

app = Flask(__name__)


def generate_response(user_message):
    # Wrap the raw message in the chat template the model was fine-tuned on.
    messages = [{"role": "user", "content": user_message}]
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)

    # Inference only: disable gradient tracking to save memory.
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=100)

    # Decode only the newly generated tokens, skipping the echoed prompt.
    response = tokenizer.decode(
        outputs[0][inputs["input_ids"].shape[-1]:],
        skip_special_tokens=True,
    )
    return response.strip()
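
# Note: generate() as called above uses the model's default decoding settings.
# Standard sampling kwargs can be passed instead, e.g. with illustrative values:
#
#     outputs = model.generate(
#         **inputs,
#         max_new_tokens=100,
#         do_sample=True,
#         temperature=0.7,
#         top_p=0.9,
#     )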


@app.route("/chat", methods=["POST"])
def chat():
    # Expects JSON {"message": "..."}; returns JSON {"response": "..."}.
    data = request.get_json()
    if not data or "message" not in data:
        return jsonify({"error": "Missing 'message' field"}), 400

    user_msg = data["message"]
    try:
        ai_response = generate_response(user_msg)
        return jsonify({"response": ai_response})
    except Exception as e:
        return jsonify({"error": str(e)}), 500


if __name__ == "__main__":
    # Flask's built-in development server; not intended for production traffic.
    app.run(host="0.0.0.0", port=5000)
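
# Usage sketch: with the server running, the endpoint can be exercised with
# curl (host/port match app.run() above; the message itself is illustrative):
#
#   curl -X POST http://localhost:5000/chat \
#        -H "Content-Type: application/json" \
#        -d '{"message": "How do I update my shipping address?"}'
#
# which should return JSON of the form {"response": "..."}.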