# Valtry-Bot / app.py
# Streaming LLM inference Space (uploaded in commit cf97964, "Upload 4 files").
import asyncio
import threading
from contextlib import asynccontextmanager
from typing import AsyncGenerator
import torch
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field, field_validator
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct"
# Loaded once at startup.
tokenizer = None
model = None
class GenerateRequest(BaseModel):
    """Request payload for ``POST /generate_stream``.

    Bounds on the sampling parameters are enforced by pydantic field
    constraints; the prompt is additionally rejected when it is blank.
    """

    prompt: str = Field(..., min_length=1, description="Input prompt text")
    max_tokens: int = Field(default=512, ge=1, le=2048)
    temperature: float = Field(default=0.7, ge=0.0, le=2.0)
    top_p: float = Field(default=0.9, gt=0.0, le=1.0)

    @field_validator("prompt")
    @classmethod
    def prompt_must_not_be_blank(cls, value: str) -> str:
        """Reject prompts that contain only whitespace."""
        if value.strip():
            return value
        raise ValueError("Prompt cannot be empty or whitespace")
@asynccontextmanager
async def lifespan(_: FastAPI):
    """Load the tokenizer and model once at startup.

    ``from_pretrained`` downloads and deserializes weights, which can block
    for a long time; it is run on a worker thread via ``asyncio.to_thread``
    so the event loop (and e.g. signal handling) stays responsive while the
    server starts up.
    """
    def _load() -> None:
        global tokenizer, model
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            torch_dtype="auto",     # pick dtype from the checkpoint config
            device_map="auto",      # place on GPU if available, else CPU
            trust_remote_code=True,
        )
        model.eval()                # inference mode: disable dropout etc.

    await asyncio.to_thread(_load)
    yield
# Application instance; `lifespan` loads the model before requests are served.
app = FastAPI(
    title="Hugging Face Space Streaming LLM Inference API",
    description="Streaming token generation API using Qwen2.5-0.5B-Instruct",
    version="1.0.0",
    lifespan=lifespan,
)
@app.get("/")
async def health() -> dict:
    """Liveness probe: report status, the configured model id, and endpoints."""
    info: dict = {"status": "ok"}
    info["model"] = MODEL_ID
    info["endpoints"] = ["POST /generate_stream"]
    return info
async def stream_generate(req: GenerateRequest) -> AsyncGenerator[str, None]:
    """Yield generated text as Server-Sent Events (``data: ...`` frames).

    Generation runs on a daemon thread; tokens are forwarded from the
    ``TextIteratorStreamer`` as they arrive.

    Raises:
        HTTPException: 503 when the model has not finished loading. NOTE:
            because this is an async generator, the exception fires on first
            iteration — callers should also guard before starting the
            response (see the /generate_stream endpoint).
    """
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model is still loading")
    inputs = tokenizer(req.prompt, return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )
    generation_kwargs = {
        **inputs,
        "streamer": streamer,
        "max_new_tokens": req.max_tokens,
        # temperature == 0 means greedy decoding; `temperature=None` lets
        # transformers fall back to its default (ignored when not sampling).
        "do_sample": req.temperature > 0,
        "temperature": req.temperature if req.temperature > 0 else None,
        "top_p": req.top_p,
        "pad_token_id": tokenizer.eos_token_id,
    }

    def run_generation() -> None:
        # Inference only — no_grad avoids building autograd state.
        with torch.no_grad():
            model.generate(**generation_kwargs)

    thread = threading.Thread(target=run_generation, daemon=True)
    thread.start()

    # TextIteratorStreamer is a *blocking* iterator (queue.get under the
    # hood). Iterating it directly here would stall the whole event loop
    # between tokens, so each chunk is pulled on a worker thread instead.
    token_iter = iter(streamer)
    sentinel = object()
    while True:
        text = await asyncio.to_thread(next, token_iter, sentinel)
        if text is sentinel:
            break
        # SSE framing is newline-delimited: a chunk containing "\n" must be
        # split into one "data:" line per text line within a single event,
        # otherwise the stream becomes unparseable for the client.
        event = "".join(f"data: {line}\n" for line in text.split("\n"))
        yield event + "\n"
    # Generation finished once the streamer is exhausted; reap the thread.
    thread.join(timeout=5)
    yield "data: [DONE]\n\n"
@app.post("/generate_stream")
async def generate_stream(req: GenerateRequest):
    """Stream generated text for ``req.prompt`` as an SSE response.

    Readiness is checked *here*, before the StreamingResponse is returned:
    once streaming begins the HTTP status line is already sent, so an
    HTTPException raised inside the generator could no longer become a
    proper 503. The previous try/except around the StreamingResponse
    constructor was dead code — constructing the response (and calling an
    async-generator function) executes no generation code and cannot raise.
    """
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model is still loading")
    return StreamingResponse(stream_generate(req), media_type="text/event-stream")
if __name__ == "__main__":
    # Local/dev entry point; Hugging Face Spaces expect the app on port 7860.
    import uvicorn
    uvicorn.run("app:app", host="0.0.0.0", port=7860)