# qwen_api / app.py
# author: aryo100 — "update app" (commit 6e0aaee)
import logging
import os
import threading

import torch
import uvicorn
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
app = FastAPI()
# Load model & tokenizer sekali saat startup
# MODEL_NAME = "Qwen/Qwen1.5-1.8B-Chat"
MODEL_NAME = "Qwen/Qwen1.5-4B-Chat"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
torch_dtype=torch.bfloat16,
device_map="cpu",
trust_remote_code=True,
)
model.config.use_cache = True
# fallback kalau chat_template kosong
if not tokenizer.chat_template:
tokenizer.chat_template = """{% for message in messages %}
{{ message['role'] }}: {{ message['content'] }}
{% endfor %}
assistant:"""
# Request schema
class Message(BaseModel):
role: str
content: str
class ChatRequest(BaseModel):
messages: list[Message]
max_new_tokens: int = 128
# Generator untuk streaming token
def generate_stream(prompt, max_new_tokens=128):
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
# streamer = tokenizer.as_target_tokenizer()
# # pakai generate incremental
# with torch.no_grad():
# output_ids = model.generate(
# **inputs,
# max_new_tokens=max_new_tokens,
# do_sample=True,
# top_p=0.9,
# temperature=0.7
# )[0]
# # Ambil hasil tanpa input
# generated_tokens = output_ids[inputs["input_ids"].shape[1]:]
# for tok in generated_tokens:
# text = tokenizer.decode(tok, skip_special_tokens=True)
# if text.strip():
# yield text
# streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
# generation_kwargs = dict(
# **inputs,
# max_new_tokens=max_new_tokens,
# eos_token_id=tokenizer.eos_token_id,
# do_sample=True,
# temperature=0.7,
# streamer=streamer
# )
# thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
# thread.start()
inputs = {k: v.to(model.device) for k, v in inputs.items()}
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
def run_generation():
try:
model.generate(
**inputs,
max_new_tokens=max_new_tokens,
# min_new_tokens=16,
do_sample=True,
temperature=0.7,
top_p=0.9,
streamer=streamer,
# early_stopping=True,
# eos_token_id=tokenizer.eos_token_id,
# pad_token_id=tokenizer.pad_token_id,
use_cache=True,
)
except Exception as e:
# simpan error agar bisa dikembalikan ke client setelah streamer selesai
# error_container.append(str(e))
pass
thread = threading.Thread(target=run_generation, daemon=True)
thread.start()
for token in streamer:
yield token
# streamer = tokenizer.as_target_tokenizer()
# with torch.no_grad():
# output_ids = model.generate(
# **inputs,
# max_new_tokens=128, # batasi jawaban
# min_new_tokens=16, # biar ga berhenti terlalu cepat
# temperature=0.7, # lebih to the point
# top_p=0.9,
# do_sample=True,
# early_stopping=True,
# eos_token_id=tokenizer.eos_token_id,
# pad_token_id=tokenizer.pad_token_id,
# )
# decoded = tokenizer.decode(output_ids[0], skip_special_tokens=True)
# if "Assistant:" in decoded:
# answer = decoded.split("Assistant:")[-1].strip()
# else:
# answer = decoded
# # stream potongan kalimat (kata demi kata)
# for word in answer.split():
# yield word + " "
@app.post("/stream")
async def chat(req: ChatRequest):
# Format prompt sesuai chat template
text = tokenizer.apply_chat_template(
req.messages,
tokenize=False,
add_generation_prompt=True
)
generator = generate_stream(text, req.max_new_tokens)
return StreamingResponse(generator, media_type="text/plain")
@app.post("/chat")
def chat(req: ChatRequest):
text = tokenizer.apply_chat_template(
[m.model_dump() for m in req.messages],
tokenize=False,
add_generation_prompt=True
)
inputs = tokenizer(text, return_tensors="pt").to(model.device)
outputs = model.generate(
**inputs,
max_new_tokens=req.max_new_tokens,
do_sample=True,
top_p=0.9,
temperature=0.7
)
response = tokenizer.decode(
# outputs[0][inputs["input_ids"]:].tolist(),
outputs[0][inputs["input_ids"].shape[1]:],
skip_special_tokens=True
)
return {"response": response}
@app.get("/")
def root():
return {"message": "Qwen FastAPI running 🚀"}
if __name__ == "__main__":
port = int(os.environ.get("PORT", 7860))
uvicorn.run("app:app", host="0.0.0.0", port=port)