# deploy_kcv/app.py
import threading

import torch
from fastapi import FastAPI
from fastapi.responses import FileResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

app = FastAPI()
MODEL_NAME = "Qwen/Qwen3.5-0.8B"

# Load the system prompt once at startup.
with open("prompts/system.txt", "r") as f:
    SYSTEM_PROMPT = f.read().strip()

# Half precision on GPU, full precision on CPU; device_map="auto"
# lets transformers place the weights automatically.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto",
)

app.mount("/static", StaticFiles(directory="static"), name="static")

class ChatRequest(BaseModel):
    message: str
    max_new_tokens: int = 512
    temperature: float = 0.7
    top_p: float = 0.9
    repetition_penalty: float = 1.3
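
# Example JSON body for POST /chat; only "message" is required, the other
# fields fall back to the ChatRequest defaults above (values illustrative):
# {
#     "message": "Hello!",
#     "max_new_tokens": 256,
#     "temperature": 0.7,
#     "top_p": 0.9,
#     "repetition_penalty": 1.3
# }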
@app.get("/")
def root():
return FileResponse("static/index.html")
@app.get("/health")
def health():
return {"status": "healthy"}
@app.get("/info")
def info():
return {"status": "running", "model": MODEL_NAME}
@app.post("/chat")
def chat(request: ChatRequest):
messages = [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": request.message},
]
text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
enable_thinking=False,
)
inputs = tokenizer([text], return_tensors="pt").to(model.device)
streamer = TextIteratorStreamer(
tokenizer,
skip_prompt=True,
skip_special_tokens=True
)
generation_kwargs = dict(
**inputs,
streamer=streamer,
max_new_tokens=request.max_new_tokens,
temperature=request.temperature,
repetition_penalty=request.repetition_penalty,
top_p=request.top_p,
do_sample=True,
)
thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
def token_stream():
for token in streamer:
yield token
thread.join()
return StreamingResponse(token_stream(), media_type="text/plain")
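
# A minimal sketch of a local entrypoint, assuming uvicorn serves the app
# (the deployment may instead launch it externally, e.g. `uvicorn app:app`);
# the host and port below are illustrative, not taken from the repo.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)

# Example streaming request against the assumed port (illustrative):
#   curl -N -X POST http://localhost:7860/chat \
#       -H "Content-Type: application/json" \
#       -d '{"message": "Hello!"}'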