# File size: 2,229 bytes
from threading import Thread

import torch
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
)
# ============================================
# MODEL
# ============================================
# Checkpoint identifier passed to `from_pretrained` (a Hub repo id here).
MODEL_NAME = "junaid17/qwen-0.5b-16bit_merged"

# Tokenizer and model are loaded once, at import time, and shared by every
# request handled by this process.
# NOTE(review): `trust_remote_code=True` allows code shipped in the model
# repo to execute on load — only keep this for repos you trust.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",          # let transformers/accelerate place the weights
    torch_dtype=torch.float16,  # load weights in half precision
    trust_remote_code=True,
)
# ============================================
# FASTAPI
# ============================================
app = FastAPI()

# ============================================
# REQUEST SCHEMA
# ============================================
class ChatRequest(BaseModel):
    """JSON body accepted by ``POST /chat``.

    Numeric fields are bounds-checked during validation, so out-of-range
    values are rejected with a 422 response instead of failing later inside
    the background generation thread.
    """

    # User message appended after the fixed system prompt.
    query: str
    # Number of tokens to generate. Must be at least 1; capped so a single
    # request cannot monopolise the shared model indefinitely.
    max_new_tokens: int = Field(default=256, ge=1, le=4096)
    # Sampling temperature. The /chat endpoint always samples
    # (do_sample=True), and transformers rejects a non-positive temperature
    # in that mode, so only strictly positive values are accepted.
    temperature: float = Field(default=0.7, gt=0.0)
# ============================================
# STREAM CHAT
# ============================================
@app.post("/chat")
async def chat(request: ChatRequest):
    """Stream the model's reply to ``request.query`` as ``text/plain``.

    The prompt is rendered with the tokenizer's chat template: a fixed system
    message followed by the user's query. ``model.generate`` runs on a
    background thread and pushes decoded text into a ``TextIteratorStreamer``.
    The response body is produced by iterating that streamer.

    If generation raises, the streamer is closed explicitly so the response
    does not hang. After any text already produced has been sent, the error
    is re-raised from the body iterator so the failure is visible to the
    server and the client.
    """
    messages = [
        {"role": "system", "content": "You are a helpful AI assistant."},
        {"role": "user", "content": request.query},
    ]
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # skip_prompt: emit only newly generated text, not the echoed prompt.
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=request.max_new_tokens,
        temperature=request.temperature,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Exception raised by the worker thread, handed to the response iterator.
    errors = []

    def run_generation():
        """Run generation; on failure, record the error and close the stream."""
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:
            errors.append(exc)
            # generate() only signals end-of-stream when it finishes normally.
            # Without this call, iterating the streamer would block forever.
            streamer.end()

    thread = Thread(target=run_generation)
    thread.start()

    def generate_tokens():
        """Yield decoded text chunks, then surface any generation error."""
        yield from streamer
        # The stream has ended, so the worker is finishing; joining also
        # guarantees `errors` is fully populated before we inspect it.
        thread.join()
        if errors:
            raise RuntimeError("Text generation failed") from errors[0]

    return StreamingResponse(generate_tokens(), media_type="text/plain")