qwenapi / app.py
junaid17's picture
Create app.py
faeb2df verified
Raw
History Blame Contribute Delete
2.23 kB
import logging
from threading import Thread

import torch
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
)
# ============================================
# MODEL
# ============================================
MODEL_NAME = "junaid17/qwen-0.5b-16bit_merged"

# Tokenizer and model are loaded once at import time and shared by every
# request handled by this process.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)

# Weights are loaded as float16; device placement is delegated to
# device_map="auto".
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",
    torch_dtype=torch.float16,
    trust_remote_code=True,
)

# ============================================
# FASTAPI
# ============================================
app = FastAPI()
# ============================================
# REQUEST SCHEMA
# ============================================
class ChatRequest(BaseModel):
    """Request body accepted by the ``/chat`` endpoint.

    Numeric fields are range-checked during validation so that values the
    generator cannot use are rejected up front instead of failing later
    inside the background generation thread.
    """

    # User message that is sent to the model.
    query: str
    # Maximum number of tokens to generate; must be at least 1.
    max_new_tokens: int = Field(256, ge=1)
    # Sampling temperature; negative values are rejected.
    temperature: float = Field(0.7, ge=0.0)
# ============================================
# STREAM CHAT
# ============================================
@app.post("/chat")
async def chat(request: ChatRequest):
    """Stream a chat completion for ``request.query`` as plain text.

    The query is wrapped in a fixed system/user conversation, rendered with
    the tokenizer's chat template, and passed to ``model.generate`` on a
    background thread. Decoded text chunks are streamed to the client as the
    streamer produces them.

    Args:
        request: Query text plus ``max_new_tokens`` and ``temperature``.
            A ``temperature`` of zero or below selects greedy decoding
            instead of sampling.

    Returns:
        A ``StreamingResponse`` with media type ``text/plain`` that yields
        the generated text.
    """
    messages = [
        {
            "role": "system",
            "content": "You are a helpful AI assistant."
        },
        {
            "role": "user",
            "content": request.query
        }
    ]
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    inputs = tokenizer(
        prompt,
        return_tensors="pt"
    ).to(model.device)
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True
    )
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=request.max_new_tokens,
        pad_token_id=tokenizer.eos_token_id
    )
    if request.temperature > 0:
        generation_kwargs.update(
            do_sample=True,
            temperature=request.temperature
        )
    else:
        # Sampling requires a strictly positive temperature; treat zero (or
        # below) as a request for deterministic, greedy decoding.
        generation_kwargs["do_sample"] = False

    def run_generation():
        try:
            model.generate(**generation_kwargs)
        except Exception:
            logging.getLogger(__name__).exception("Generation failed")
            # Without an explicit end signal the consumer loop below would
            # wait on the streamer forever and the request would hang.
            streamer.end()

    thread = Thread(target=run_generation)
    thread.start()

    def generate_tokens():
        for token in streamer:
            yield token

    return StreamingResponse(
        generate_tokens(),
        media_type="text/plain"
    )