# qwen_api/app.py
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import os
import uvicorn

app = FastAPI()

# Load the model and tokenizer once at startup
# MODEL_NAME = "Qwen/Qwen1.5-1.8B-Chat"
MODEL_NAME = "Qwen/Qwen1.5-4B-Chat"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float32,  # float16 is poorly supported on CPU; float32 avoids runtime errors there
    device_map="cpu",
    trust_remote_code=True,
)
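
# Rough sizing note: a 4B-parameter model in float32 needs on the order of
# 16 GB of RAM (4 bytes per parameter), so the host needs that much memory.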
model.config.use_cache = True

# Fallback in case the tokenizer ships without a chat_template
if not tokenizer.chat_template:
    tokenizer.chat_template = """{% for message in messages %}
{{ message['role'] }}: {{ message['content'] }}
{% endfor %}
assistant:"""
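# For illustration: with this fallback, [{"role": "user", "content": "Hello"}]
# renders roughly as "user: Hello\nassistant:" before generation.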

# Request schema
class Message(BaseModel):
    role: str
    content: str


class ChatRequest(BaseModel):
    messages: list[Message]
    max_new_tokens: int = 128
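
# Example payload matching this schema (max_new_tokens is optional, default 128):
#   {"messages": [{"role": "user", "content": "Hi"}]}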

@app.post("/chat")
def chat(req: ChatRequest):
    # Render the chat history into a single prompt string
    text = tokenizer.apply_chat_template(
        [m.dict() for m in req.messages],  # use m.model_dump() on Pydantic v2
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tokenizer(text, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=req.max_new_tokens,
        do_sample=True,
        top_p=0.9,
        temperature=0.7,
    )
    # Decode only the newly generated tokens, skipping the prompt
    response = tokenizer.decode(
        outputs[0][inputs["input_ids"].shape[1]:],
        skip_special_tokens=True,
    )
    return {"response": response}

@app.get("/")
def root():
    return {"message": "Qwen FastAPI running 🚀"}

if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7860))
    uvicorn.run("app:app", host="0.0.0.0", port=port)
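
# Example call once the server is running (values are illustrative):
#   curl -X POST http://localhost:7860/chat \
#     -H "Content-Type: application/json" \
#     -d '{"messages": [{"role": "user", "content": "Hello"}], "max_new_tokens": 64}'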