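"""Streaming text-generation API for a Hugging Face Space.

Loads Qwen2.5-0.5B-Instruct once at startup and exposes a Server-Sent
Events endpoint, POST /generate_stream, that streams decoded tokens as
they are produced. Spaces route traffic to port 7860 by default, which is
the port uvicorn binds at the bottom of this file.

Example request (a sketch, assuming the file is saved as app.py and the
server is reachable on localhost):

    curl -N -X POST http://localhost:7860/generate_stream \
        -H "Content-Type: application/json" \
        -d '{"prompt": "Write a haiku about GPUs.", "max_tokens": 64}'
"""
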
import asyncio
import threading
from contextlib import asynccontextmanager
from typing import AsyncGenerator

import torch
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field, field_validator
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct"

# Loaded once at startup.
tokenizer = None
model = None


class GenerateRequest(BaseModel):
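    """Validated request body for POST /generate_stream."""
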
    prompt: str = Field(..., min_length=1, description="Input prompt text")
    max_tokens: int = Field(default=512, ge=1, le=2048)
    temperature: float = Field(default=0.7, ge=0.0, le=2.0)
    top_p: float = Field(default=0.9, gt=0.0, le=1.0)

    @field_validator("prompt")
    @classmethod
    def prompt_must_not_be_blank(cls, value: str) -> str:
        if not value.strip():
            raise ValueError("Prompt cannot be empty or whitespace")
        return value


@asynccontextmanager
async def lifespan(_: FastAPI):
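    """Load the tokenizer and model once per process; requests reuse them."""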
    global tokenizer, model
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype="auto",
        device_map="auto",
        trust_remote_code=True,
    )
    model.eval()
    yield


app = FastAPI(
    title="Hugging Face Space Streaming LLM Inference API",
    description="Streaming token generation API using Qwen2.5-0.5B-Instruct",
    version="1.0.0",
    lifespan=lifespan,
)


@app.get("/")
async def health() -> dict:
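    """Liveness check; also reports the model and available endpoints."""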
    return {
        "status": "ok",
        "model": MODEL_ID,
        "endpoints": ["POST /generate_stream"],
    }


async def stream_generate(req: GenerateRequest) -> AsyncGenerator[str, None]:
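    """Yield generated text as Server-Sent Events.

    generate() runs in a background thread and pushes decoded chunks into
    TextIteratorStreamer's internal queue; this coroutine drains that queue
    and re-emits each chunk as a "data:" event.
    """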
    # Readiness is checked by the endpoint; by the time this generator runs
    # the response has already started, so a 503 could not be returned here.

    inputs = tokenizer(req.prompt, return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )

    generation_kwargs = {
        **inputs,
        "streamer": streamer,
        "max_new_tokens": req.max_tokens,
        "do_sample": req.temperature > 0,
        "temperature": req.temperature if req.temperature > 0 else None,
        "top_p": req.top_p,
        "pad_token_id": tokenizer.eos_token_id,
    }

    def run_generation() -> None:
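        # Inference only; no_grad avoids building an autograd graph.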
        with torch.no_grad():
            model.generate(**generation_kwargs)

    thread = threading.Thread(target=run_generation, daemon=True)
    thread.start()

    # TextIteratorStreamer blocks while waiting for the next chunk, so pull
    # from it in a worker thread instead of stalling the event loop.
    iterator = iter(streamer)
    sentinel = object()
    while True:
        text = await asyncio.to_thread(next, iterator, sentinel)
        if text is sentinel:
            break
        # SSE framing: a "data:" line terminated by a blank line.
        yield f"data: {text}\n\n"

    yield "data: [DONE]\n\n"


@app.post("/generate_stream")
async def generate_stream(req: GenerateRequest):
    # Check readiness here, before the response starts: once the body is
    # streaming, exceptions can no longer become HTTP errors such as 503;
    # the client would just see the stream end early.
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model is still loading")
    return StreamingResponse(stream_generate(req), media_type="text/event-stream")


if __name__ == "__main__":
    import uvicorn

    uvicorn.run("app:app", host="0.0.0.0", port=7860)