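"""Streaming LLM inference API for a Hugging Face Space.

Serves Qwen/Qwen2.5-0.5B-Instruct with FastAPI and streams generated tokens
to clients as Server-Sent Events (SSE).
"""
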
import asyncio
import threading
from contextlib import asynccontextmanager
from typing import AsyncGenerator
import torch
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field, field_validator
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct"

# Loaded once at startup by the lifespan handler below.
tokenizer = None
model = None


class GenerateRequest(BaseModel):
    prompt: str = Field(..., min_length=1, description="Input prompt text")
    max_tokens: int = Field(default=512, ge=1, le=2048)
    temperature: float = Field(default=0.7, ge=0.0, le=2.0)
    top_p: float = Field(default=0.9, gt=0.0, le=1.0)

    @field_validator("prompt")
    @classmethod
    def prompt_must_not_be_blank(cls, value: str) -> str:
        if not value.strip():
            raise ValueError("Prompt cannot be empty or whitespace")
        return value
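
# Example request body accepted by GenerateRequest (illustrative values):
#   {"prompt": "Explain SSE briefly.", "max_tokens": 128, "temperature": 0.7, "top_p": 0.9}
# A blank or whitespace-only prompt fails Pydantic validation, which FastAPI
# surfaces as HTTP 422.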


@asynccontextmanager
async def lifespan(_: FastAPI):
    # Load the tokenizer and model once, before the app starts serving.
    global tokenizer, model
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype="auto",
        device_map="auto",
        trust_remote_code=True,
    )
    model.eval()
    yield


app = FastAPI(
    title="Hugging Face Space Streaming LLM Inference API",
    description="Streaming token generation API using Qwen2.5-0.5B-Instruct",
    version="1.0.0",
    lifespan=lifespan,
)
@app.get("/")
async def health() -> dict:
return {
"status": "ok",
"model": MODEL_ID,
"endpoints": ["POST /generate_stream"],
}


async def stream_generate(req: GenerateRequest) -> AsyncGenerator[str, None]:
    inputs = tokenizer(req.prompt, return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )
    do_sample = req.temperature > 0
    generation_kwargs = {
        **inputs,
        "streamer": streamer,
        "max_new_tokens": req.max_tokens,
        "do_sample": do_sample,
        "pad_token_id": tokenizer.eos_token_id,
    }
    if do_sample:
        # Sampling parameters only apply when do_sample=True; temperature=0
        # requests greedy decoding instead.
        generation_kwargs["temperature"] = req.temperature
        generation_kwargs["top_p"] = req.top_p

    def run_generation() -> None:
        with torch.no_grad():
            model.generate(**generation_kwargs)

    # generate() blocks, so it runs on a daemon thread while this coroutine
    # drains the streamer's queue and forwards chunks to the client.
    thread = threading.Thread(target=run_generation, daemon=True)
    thread.start()
    while True:
        # next() blocks on the streamer's internal queue; run it off the
        # event loop so other requests stay responsive between chunks.
        text = await asyncio.to_thread(next, streamer, None)
        if text is None:
            break
        # SSE format: every line of an event must start with "data:", so
        # newlines inside a chunk are re-prefixed before framing.
        yield "data: " + text.replace("\n", "\ndata: ") + "\n\n"
    yield "data: [DONE]\n\n"


@app.post("/generate_stream")
async def generate_stream(req: GenerateRequest):
    # Check readiness before the stream opens: once StreamingResponse starts
    # sending the body, an HTTPException can no longer set the status code.
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model is still loading")
    return StreamingResponse(stream_generate(req), media_type="text/event-stream")


if __name__ == "__main__":
    import uvicorn

    # Port 7860 is the default port exposed by Hugging Face Spaces.
    uvicorn.run("app:app", host="0.0.0.0", port=7860)
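
# A minimal client sketch for consuming the stream (assumes the server is
# reachable at http://localhost:7860 and that httpx is installed; adapt the
# URL for a deployed Space):
#
#   import httpx
#
#   with httpx.stream(
#       "POST",
#       "http://localhost:7860/generate_stream",
#       json={"prompt": "Explain SSE in one sentence.", "max_tokens": 64},
#       timeout=None,
#   ) as response:
#       for line in response.iter_lines():
#           if line.startswith("data: ") and line != "data: [DONE]":
#               print(line[len("data: "):], end="", flush=True)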