mini / app2.py
Asad235's picture
Update app2.py
43d498c verified
Raw
History Blame Contribute Delete
1.62 kB
from flask import Flask, request, jsonify
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# Flask application object; the route decorators in this file register on it.
app = Flask(__name__)

# Hugging Face Hub identifier of the checkpoint served by this API.
MODEL_NAME = "openbmb/MiniCPM-2B-sft-bf16"

print("Loading model...")

# trust_remote_code=True allows transformers to run the custom tokenizer/model
# code that ships inside the checkpoint repository.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)

# Load the weights as float32 and place the whole model on the CPU.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="cpu",
    torch_dtype=torch.float32,
    trust_remote_code=True,
)

# Put the module into evaluation mode before serving requests.
model.eval()
# 🔥 chat function (same logic)
def chat(message):
    """Generate one assistant reply for ``message``.

    The message is interpolated into a ``User: ...\\nAssistant:`` prompt,
    tokenized (truncated to at most 1024 tokens) and handed to the
    module-level ``model`` for sampled generation of up to 150 new tokens.

    Args:
        message: The user's text; it is formatted into the prompt as-is.

    Returns:
        The decoded text of the newly generated tokens only, with special
        tokens skipped and surrounding whitespace stripped.
    """
    encoded = tokenizer(
        f"User: {message}\nAssistant:",
        return_tensors="pt",
        truncation=True,
        max_length=1024,
    )
    prompt_length = encoded.input_ids.shape[1]

    sampling = {"do_sample": True, "temperature": 0.7, "top_p": 0.9}
    # Inference only: no gradients need to be tracked.
    with torch.no_grad():
        generated = model.generate(**encoded, max_new_tokens=150, **sampling)

    # Skip the first ``prompt_length`` positions so that only the tokens
    # produced after the prompt are decoded.
    new_tokens = generated[0][prompt_length:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
# 🌐 Home route
@app.route("/", methods=["GET"])
def home():
    """Status endpoint: return a small JSON document saying the API is up."""
    return jsonify(
        status="running",
        message="MiniCPM Flask API is live",
    )
# 🤖 Chat endpoint
@app.route("/chat", methods=["POST"])
def chat_api():
    """Handle ``POST /chat``: run the model on the posted message.

    Expects a JSON object body containing a ``"message"`` key.

    Returns:
        ``{"response": <reply>}`` on success, or
        ``({"error": "Send JSON with 'message'"}, 400)`` when the body is
        missing, is not valid JSON, is not a JSON object, or has no
        ``"message"`` key.
    """
    # silent=True makes get_json() return None for a wrong content type or
    # malformed JSON instead of raising, so those requests receive this
    # endpoint's own JSON error below rather than Flask's default error page.
    data = request.get_json(silent=True)

    # Require a JSON object: for a JSON string, ``in`` would be a substring
    # test, and for a JSON list a membership test, after which
    # ``data["message"]`` would raise TypeError and surface as a 500.
    if not isinstance(data, dict) or "message" not in data:
        return jsonify({"error": "Send JSON with 'message'"}), 400

    reply = chat(data["message"])
    return jsonify({"response": reply})
# 🚀 run server
if __name__ == "__main__":
    # Only when executed as a script: start Flask's built-in server,
    # listening on all network interfaces at port 7860.
    app.run(port=7860, host="0.0.0.0")