import asyncio
import threading
from contextlib import asynccontextmanager
from typing import AsyncGenerator

import torch
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field, field_validator
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct"

# Loaded once at startup (see lifespan below).
tokenizer = None
model = None


class GenerateRequest(BaseModel):
    prompt: str = Field(..., min_length=1, description="Input prompt text")
    max_tokens: int = Field(default=512, ge=1, le=2048)
    temperature: float = Field(default=0.7, ge=0.0, le=2.0)
    top_p: float = Field(default=0.9, gt=0.0, le=1.0)

    @field_validator("prompt")
    @classmethod
    def prompt_must_not_be_blank(cls, value: str) -> str:
        if not value.strip():
            raise ValueError("Prompt cannot be empty or whitespace")
        return value


@asynccontextmanager
async def lifespan(_: FastAPI):
    global tokenizer, model
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype="auto",
        device_map="auto",
        trust_remote_code=True,
    )
    model.eval()
    yield


app = FastAPI(
    title="Hugging Face Space Streaming LLM Inference API",
    description="Streaming token generation API using Qwen2.5-0.5B-Instruct",
    version="1.0.0",
    lifespan=lifespan,
)


@app.get("/")
async def health() -> dict:
    return {
        "status": "ok",
        "model": MODEL_ID,
        "endpoints": ["POST /generate_stream"],
    }


async def stream_generate(req: GenerateRequest) -> AsyncGenerator[str, None]:
    inputs = tokenizer(req.prompt, return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    # The streamer receives decoded text from the generation thread and exposes
    # it as a blocking iterator on this side.
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )

    do_sample = req.temperature > 0
    generation_kwargs = {
        **inputs,
        "streamer": streamer,
        "max_new_tokens": req.max_tokens,
        "do_sample": do_sample,
        "pad_token_id": tokenizer.eos_token_id,
    }
    if do_sample:
        # Sampling parameters are only meaningful when do_sample is True;
        # temperature == 0 falls back to greedy decoding.
        generation_kwargs["temperature"] = req.temperature
        generation_kwargs["top_p"] = req.top_p

    def run_generation() -> None:
        with torch.no_grad():
            model.generate(**generation_kwargs)

    # generate() is blocking, so it runs in a background thread while this
    # coroutine drains the streamer.
    thread = threading.Thread(target=run_generation, daemon=True)
    thread.start()

    loop = asyncio.get_running_loop()
    sentinel = object()
    while True:
        # next() blocks on the streamer's internal queue; run it in the default
        # executor so the event loop stays responsive between tokens.
        text = await loop.run_in_executor(None, next, streamer, sentinel)
        if text is sentinel:
            break
        # SSE format: each event line starts with "data:".
        yield f"data: {text}\n\n"

    yield "data: [DONE]\n\n"


@app.post("/generate_stream")
async def generate_stream(req: GenerateRequest) -> StreamingResponse:
    # Check readiness here, before the response starts: an HTTPException raised
    # inside the streaming generator would only surface after the headers have
    # already been sent and could no longer produce a proper 503.
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model is still loading")
    return StreamingResponse(stream_generate(req), media_type="text/event-stream")


if __name__ == "__main__":
    import uvicorn

    uvicorn.run("app:app", host="0.0.0.0", port=7860)
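
# ---------------------------------------------------------------------------
# Example client (illustrative sketch only, not part of the app): one way to
# consume the SSE stream, assuming the server is reachable on localhost:7860
# and the `requests` package is installed.
#
#   import requests
#
#   with requests.post(
#       "http://localhost:7860/generate_stream",
#       json={"prompt": "Write a haiku about autumn.", "max_tokens": 64},
#       stream=True,
#   ) as resp:
#       for line in resp.iter_lines(decode_unicode=True):
#           if line.startswith("data: "):
#               chunk = line[len("data: "):]
#               if chunk == "[DONE]":
#                   break
#               print(chunk, end="", flush=True)
# ---------------------------------------------------------------------------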