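"""Streaming LLM inference API for a Hugging Face Space.

Serves Qwen2.5-0.5B-Instruct behind FastAPI and streams generated tokens
to clients as server-sent events (SSE).
"""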
import asyncio
import threading
from contextlib import asynccontextmanager
from typing import AsyncGenerator

import torch
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field, field_validator
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
| MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct" | |
| # Loaded once at startup. | |
| tokenizer = None | |
| model = None | |

class GenerateRequest(BaseModel):
    prompt: str = Field(..., min_length=1, description="Input prompt text")
    max_tokens: int = Field(default=512, ge=1, le=2048)
    temperature: float = Field(default=0.7, ge=0.0, le=2.0)
    top_p: float = Field(default=0.9, gt=0.0, le=1.0)

    @field_validator("prompt")
    @classmethod
    def prompt_must_not_be_blank(cls, value: str) -> str:
        if not value.strip():
            raise ValueError("Prompt cannot be empty or whitespace")
        return value
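
# Example request body accepted by the schema above; every field except
# "prompt" is optional and falls back to the defaults shown there:
#   {"prompt": "Explain SSE in one sentence.", "max_tokens": 256,
#    "temperature": 0.7, "top_p": 0.9}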

@asynccontextmanager
async def lifespan(_: FastAPI):
    global tokenizer, model
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype="auto",
        device_map="auto",
        trust_remote_code=True,
    )
    model.eval()
    yield

app = FastAPI(
    title="Hugging Face Space Streaming LLM Inference API",
    description="Streaming token generation API using Qwen2.5-0.5B-Instruct",
    version="1.0.0",
    lifespan=lifespan,
)

@app.get("/")
async def health() -> dict:
    return {
        "status": "ok",
        "model": MODEL_ID,
        "endpoints": ["POST /generate_stream"],
    }

async def stream_generate(req: GenerateRequest) -> AsyncGenerator[str, None]:
    inputs = tokenizer(req.prompt, return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )
    generation_kwargs = {
        **inputs,
        "streamer": streamer,
        "max_new_tokens": req.max_tokens,
        "do_sample": req.temperature > 0,
        "temperature": req.temperature if req.temperature > 0 else None,
        "top_p": req.top_p,
        "pad_token_id": tokenizer.eos_token_id,
    }
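
    # model.generate() is blocking, so it runs in a background thread while
    # TextIteratorStreamer hands decoded text chunks back to this coroutine.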
    def run_generation() -> None:
        with torch.no_grad():
            model.generate(**generation_kwargs)

    thread = threading.Thread(target=run_generation, daemon=True)
    thread.start()

    for text in streamer:
        # SSE format: each event line starts with "data:"
        yield f"data: {text}\n\n"
        # Yield control so the event loop can service other requests between chunks.
        await asyncio.sleep(0)
    yield "data: [DONE]\n\n"

@app.post("/generate_stream")
async def generate_stream(req: GenerateRequest):
    # Reject here rather than inside the generator, so the client receives a
    # real 503 instead of an error after the streaming response has started.
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model is still loading")
    try:
        return StreamingResponse(stream_generate(req), media_type="text/event-stream")
    except HTTPException:
        raise
    except Exception as exc:  # pragma: no cover
        raise HTTPException(status_code=500, detail=f"Generation error: {exc}") from exc

if __name__ == "__main__":
    import uvicorn

    uvicorn.run("app:app", host="0.0.0.0", port=7860)
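
# A minimal client sketch for consuming the SSE stream, assuming the server is
# running locally on port 7860 and the `requests` package is installed
# (neither is guaranteed by this file):
#
#   import requests
#
#   payload = {"prompt": "Write a haiku about GPUs.", "max_tokens": 64}
#   with requests.post(
#       "http://localhost:7860/generate_stream", json=payload, stream=True
#   ) as resp:
#       for line in resp.iter_lines(decode_unicode=True):
#           if line and line.startswith("data: "):
#               chunk = line[len("data: "):]
#               if chunk == "[DONE]":
#                   break
#               print(chunk, end="", flush=True)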