| """ |
| Shared Hugging Face Space runtime for streaming chat inference. |
| |
| This module provides: |
| - one-time global model loading |
| - async request queue |
| - worker pool with semaphore-based concurrency limits |
| - per-request streamer/thread isolation |
| - SSE streaming responses |
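|
| Typical wiring from a Space's app.py (sketch only; the import path and model id
| below are illustrative, not part of this module):
|
|     from shared_runtime import RuntimeConfig, create_hf_space_app
|
|     app = create_hf_space_app(
|         RuntimeConfig(
|             model_name="org/chat-model",
|             title="Streaming Chat",
|             description="Shared HF Space runtime demo",
|         )
|     )
|     # run with, e.g.: uvicorn app:app --host 0.0.0.0 --port 7860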
| """ |
|
|
| from __future__ import annotations |
|
|
| import asyncio |
| import json |
| import logging |
| import os |
| import time |
| import uuid |
| from contextlib import asynccontextmanager |
| from dataclasses import dataclass, field |
| from queue import Empty as QueueEmpty |
| from threading import Event as ThreadEvent |
| from threading import Thread |
| from typing import Any, Dict, List, Optional |
|
|
| import torch |
| from fastapi import FastAPI, HTTPException |
| from fastapi.middleware.cors import CORSMiddleware |
| from fastapi.responses import FileResponse, StreamingResponse |
| from pydantic import BaseModel |
| from transformers import ( |
| AutoModelForCausalLM, |
| AutoTokenizer, |
| StoppingCriteria, |
| StoppingCriteriaList, |
| TextIteratorStreamer, |
| ) |
|
|
|
|
| class Message(BaseModel): |
| role: str |
| content: str |
|
|
|
|
| class ChatRequest(BaseModel): |
| messages: List[Message] |
| stream: bool = True |
| max_tokens: int = 8192 |
| temperature: Optional[float] = None |
| tools: Optional[List[Dict[str, Any]]] = None |
|
|
|
|
| @dataclass(frozen=True) |
| class RuntimeConfig: |
| model_name: str |
| title: str |
| description: str |
| version: str = "1.0.0" |
| max_input_tokens: int = 32768 |
| max_new_tokens: int = 131072 |
| top_p: float = 0.95 |
| top_k: Optional[int] = None |
| repetition_penalty: float = 1.0 |
| eos_token_id: Optional[int] = None |
| default_temperature: float = 0.6 |
| tokenizer_use_fast: Optional[bool] = None |
| logger_name: str = "hf_space" |
|
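| # Note: max_new_tokens above is the runtime ceiling (overridable via HF_MAX_NEW_TOKENS); |
| # each request's max_tokens is clamped to it inside run_generation(). |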
|
|
|
| @dataclass |
| class GenerationTask: |
| request_id: str |
| prompt: str |
| max_tokens: int |
| temperature: float |
| output_queue: asyncio.Queue[Optional[Dict[str, Any]]] |
| created_at: float = field(default_factory=time.time) |
| cancel_event: ThreadEvent = field(default_factory=ThreadEvent) |
| prompt_tokens: int = 0 |
| generated_tokens: int = 0 |
| first_token_latency: Optional[float] = None |
| start_time: Optional[float] = None |
| end_time: Optional[float] = None |
|
|
|
|
| class CancelAwareStoppingCriteria(StoppingCriteria): |
| """Stops generation when the request is cancelled/disconnected.""" |
|
|
| def __init__(self, cancel_event: ThreadEvent): |
| self.cancel_event = cancel_event |
|
|
| def __call__(self, input_ids, scores, **kwargs) -> bool: |
| return self.cancel_event.is_set() |
|
|
|
|
| def _is_truthy(value: str) -> bool: |
| return value.strip().lower() in {"1", "true", "yes", "on"} |
|
|
|
|
| def _format_sse_event(payload: Dict[str, Any]) -> str: |
| event_type = str(payload.get("type", "token")) |
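|     # Produces one SSE frame, e.g.: |
|     #   event: token |
|     #   data: {"type": "token", "content": "Hello"} |
|     # followed by a blank line, as required by the text/event-stream format. |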
| return f"event: {event_type}\ndata: {json.dumps(payload)}\n\n" |
|
|
|
|
| def _read_stream_item(stream_iter) -> tuple[bool, Optional[str]]: |
|     """Fetch one chunk from the streamer iterator, returning (finished, text) so that |
|     StopIteration never propagates across the asyncio.to_thread boundary.""" |
| try: |
| return False, next(stream_iter) |
| except StopIteration: |
| return True, None |
|
|
|
|
| def _detect_concurrency(device: str) -> int: |
| |
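|     # Sizing heuristic: an explicit HF_MAX_WORKERS override wins; otherwise scale with |
|     # GPU VRAM, and on CPU-only Spaces stay conservative (roughly one worker per 6 cores, |
|     # capped at 4), since generation is compute-bound. |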
| override = os.getenv("HF_MAX_WORKERS", "").strip() |
| if override: |
| try: |
| parsed = int(override) |
| if parsed > 0: |
| return parsed |
| except ValueError: |
| pass |
|
|
| if device == "cuda" and torch.cuda.is_available(): |
| total_vram_gb = torch.cuda.get_device_properties(0).total_memory / (1024 ** 3) |
| if total_vram_gb >= 20: |
| return 5 |
| if total_vram_gb >= 10: |
| return 4 |
| return 3 |
|
|
| cpu_count = os.cpu_count() or 1 |
| |
|     return max(1, min(4, cpu_count // 6)) |
|
|
|
|
| def create_hf_space_app(config: RuntimeConfig) -> FastAPI: |
| logger = logging.getLogger(config.logger_name) |
| logging.basicConfig(level=logging.INFO) |
|
|
| debug_token_logs = _is_truthy(os.getenv("HF_DEBUG_TOKEN_LOGS", "0")) |
| queue_max_size = int(os.getenv("HF_QUEUE_MAX_SIZE", "512")) |
| streamer_timeout = float(os.getenv("HF_STREAMER_TIMEOUT_SECONDS", "8")) |
| join_timeout = float(os.getenv("HF_GENERATION_JOIN_TIMEOUT_SECONDS", "180")) |
| max_input_tokens = int(os.getenv("HF_MAX_INPUT_TOKENS", str(config.max_input_tokens))) |
| max_new_tokens_limit = int(os.getenv("HF_MAX_NEW_TOKENS", str(config.max_new_tokens))) |
| model_load_retries = max(1, int(os.getenv("HF_MODEL_LOAD_RETRIES", "4"))) |
| model_load_retry_delay = max(1.0, float(os.getenv("HF_MODEL_LOAD_RETRY_DELAY_SECONDS", "8"))) |
| local_files_only = _is_truthy(os.getenv("HF_LOCAL_FILES_ONLY", "0")) |
|
|
| base_dir = os.path.dirname(os.path.abspath(__file__)) |
|
|
| model = None |
| tokenizer = None |
| device = "cuda" if torch.cuda.is_available() else "cpu" |
| max_workers = _detect_concurrency(device) |
|
|
| request_queue: asyncio.Queue[Optional[GenerationTask]] = asyncio.Queue(maxsize=queue_max_size) |
| worker_tasks: List[asyncio.Task] = [] |
| worker_semaphore = asyncio.Semaphore(max_workers) |
|
|
| active_workers = 0 |
| active_workers_lock = asyncio.Lock() |
|
|
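|     # Lock-guarded counter of in-flight generations; used only for logging and /health. |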
| async def set_active_workers(delta: int) -> int: |
| nonlocal active_workers |
| async with active_workers_lock: |
| active_workers += delta |
| if active_workers < 0: |
| active_workers = 0 |
| return active_workers |
|
|
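|     # Render the conversation with the tokenizer's chat template; tool schemas are |
|     # forwarded only when the request includes them. |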
|     def format_messages_proper(messages: List[Message], tools: Optional[List[Dict[str, Any]]] = None) -> str: |
|         message_dicts = [{"role": msg.role, "content": msg.content} for msg in messages] |
|         template_kwargs: Dict[str, Any] = {"add_generation_prompt": True, "tokenize": False} |
|         if tools: |
|             template_kwargs["tools"] = tools |
|         return tokenizer.apply_chat_template(message_dicts, **template_kwargs) |
|
|
| async def run_generation(task: GenerationTask, worker_id: int) -> None: |
| request_start = time.time() |
| task.start_time = request_start |
| await set_active_workers(+1) |
|
|
| try: |
| logger.info( |
| "[%s] worker=%d start queue_size=%d active_workers=%d", |
| task.request_id, |
| worker_id, |
| request_queue.qsize(), |
| active_workers, |
| ) |
|
|
| inputs = tokenizer( |
| task.prompt, |
| return_tensors="pt", |
| truncation=True, |
| max_length=max_input_tokens, |
| add_special_tokens=False, |
| ) |
|
|
| task.prompt_tokens = int(inputs.input_ids.shape[1]) |
|
|
| if device == "cuda": |
| inputs = inputs.to("cuda") |
|
|
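|             # Per-request streamer: generate() pushes decoded text into it from a |
|             # dedicated thread, and the timeout bounds how long a single read below |
|             # can block waiting for the next chunk. |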
| streamer = TextIteratorStreamer( |
| tokenizer, |
| skip_prompt=True, |
| skip_special_tokens=True, |
| timeout=streamer_timeout, |
| ) |
|
|
| stopping_criteria = StoppingCriteriaList( |
| [CancelAwareStoppingCriteria(task.cancel_event)] |
| ) |
|
|
| generation_kwargs: Dict[str, Any] = dict( |
| **inputs, |
| streamer=streamer, |
| max_new_tokens=min(task.max_tokens, max_new_tokens_limit), |
| temperature=task.temperature, |
| top_p=config.top_p, |
| repetition_penalty=config.repetition_penalty, |
| do_sample=task.temperature > 0, |
| eos_token_id=config.eos_token_id if config.eos_token_id is not None else tokenizer.eos_token_id, |
| pad_token_id=tokenizer.eos_token_id, |
| stopping_criteria=stopping_criteria, |
| ) |
| if config.top_k is not None: |
| generation_kwargs["top_k"] = config.top_k |
|
|
| generation_error: Dict[str, Exception] = {} |
| generation_done = ThreadEvent() |
|
|
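|             # The blocking generate() call runs in a daemon thread; any exception is |
|             # captured for the reader loop, and the streamer is always closed so the |
|             # loop can terminate. |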
| def generate_target() -> None: |
| try: |
| with torch.inference_mode(): |
| model.generate(**generation_kwargs) |
| except Exception as exc: |
| generation_error["error"] = exc |
| logger.error("[%s] generation thread error: %s", task.request_id, exc, exc_info=True) |
| finally: |
| generation_done.set() |
| try: |
| streamer.end() |
| except Exception: |
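|                         # Best effort: end() just unblocks the reader; ignore failures here. |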
| |
| pass |
|
|
| generation_thread = Thread( |
| target=generate_target, |
| name=f"gen-{task.request_id[:8]}", |
| daemon=True, |
| ) |
| generation_thread.start() |
|
|
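|             # Reader loop: pull chunks off the streamer in a worker thread so the event |
|             # loop stays responsive, re-checking cancellation between chunks. Each chunk |
|             # is counted as one "token" for the usage stats (approximate, since a chunk |
|             # may decode to more than one model token). |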
| stream_iter = iter(streamer) |
| while True: |
| if task.cancel_event.is_set(): |
| logger.info("[%s] cancellation requested", task.request_id) |
| break |
|
|
| try: |
| stream_finished, new_text = await asyncio.to_thread(_read_stream_item, stream_iter) |
| if stream_finished: |
| break |
| except QueueEmpty: |
| if generation_done.is_set(): |
| break |
| continue |
| except Exception as exc: |
| if generation_done.is_set(): |
| break |
| logger.error("[%s] streamer read error: %s", task.request_id, exc, exc_info=True) |
| generation_error["error"] = exc |
| break |
|
|
| if not new_text: |
| continue |
|
|
| task.generated_tokens += 1 |
| if task.first_token_latency is None: |
| task.first_token_latency = time.time() - request_start |
| logger.info( |
| "[%s] first_token=%.2fs worker=%d", |
| task.request_id, |
| task.first_token_latency, |
| worker_id, |
| ) |
|
|
| if debug_token_logs: |
| logger.info("[%s] token#%d: %r", task.request_id, task.generated_tokens, new_text) |
|
|
| await task.output_queue.put({"type": "token", "content": new_text}) |
| await asyncio.sleep(0) |
|
|
| |
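|             # Join the generation thread (bounded by join_timeout) so its work is finished |
|             # before the final event is emitted and the semaphore slot is released. |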
| try: |
| await asyncio.wait_for(asyncio.to_thread(generation_thread.join), timeout=join_timeout) |
| except asyncio.TimeoutError: |
| logger.error( |
| "[%s] generation thread still alive after %.1fs join timeout", |
| task.request_id, |
| join_timeout, |
| ) |
|
|
| if task.cancel_event.is_set(): |
| await task.output_queue.put({"type": "error", "content": "Generation interrupted. You can continue."}) |
| elif "error" in generation_error: |
| await task.output_queue.put({"type": "error", "content": str(generation_error["error"])}) |
| else: |
| await task.output_queue.put({"type": "done", "content": ""}) |
|
|
| except Exception as exc: |
| logger.error("[%s] worker failure: %s", task.request_id, exc, exc_info=True) |
| await task.output_queue.put({"type": "error", "content": str(exc)}) |
| finally: |
| task.end_time = time.time() |
| duration = max(1e-6, task.end_time - request_start) |
| tps = task.generated_tokens / duration |
| logger.info( |
| "[%s] worker=%d end tokens=%d duration=%.2fs tok_s=%.2f active_workers=%d queue_size=%d", |
| task.request_id, |
| worker_id, |
| task.generated_tokens, |
| duration, |
| tps, |
| active_workers, |
| request_queue.qsize(), |
| ) |
|
|
| await task.output_queue.put(None) |
| await set_active_workers(-1) |
|
|
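|     # Worker pool: each worker takes tasks from the shared queue; a None sentinel shuts |
|     # a worker down, and the semaphore caps how many generations run concurrently. |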
| async def worker_loop(worker_id: int) -> None: |
| logger.info("Worker-%d started", worker_id) |
| while True: |
| task = await request_queue.get() |
| if task is None: |
| request_queue.task_done() |
| logger.info("Worker-%d received shutdown signal", worker_id) |
| break |
|
|
| try: |
| if task.cancel_event.is_set(): |
| await task.output_queue.put({"type": "error", "content": "Request cancelled before execution."}) |
| await task.output_queue.put(None) |
| continue |
|
|
| async with worker_semaphore: |
| await run_generation(task, worker_id) |
| finally: |
| request_queue.task_done() |
|
|
| logger.info("Worker-%d stopped", worker_id) |
|
|
| @asynccontextmanager |
| async def lifespan(app: FastAPI): |
| nonlocal model, tokenizer, worker_tasks, max_workers, device |
|
|
| logger.info("Loading model %s on %s", config.model_name, device) |
| tokenizer_kwargs: Dict[str, Any] = { |
| "trust_remote_code": True, |
| "local_files_only": local_files_only, |
| } |
| if config.tokenizer_use_fast is not None: |
| tokenizer_kwargs["use_fast"] = config.tokenizer_use_fast |
| model_load_kwargs: Dict[str, Any] = { |
| "trust_remote_code": True, |
| "device_map": "auto" if device == "cuda" else None, |
| "local_files_only": local_files_only, |
| } |
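|         # On GPU let transformers infer the checkpoint dtype; on CPU force float32, the |
|         # safest default for inference without accelerators. |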
| if device == "cuda": |
| model_load_kwargs["dtype"] = "auto" |
| else: |
| model_load_kwargs["torch_dtype"] = torch.float32 |
|
|
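|         # Hub downloads on Spaces can fail transiently, so retry the load a few times |
|         # with a fixed delay before giving up. |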
| last_load_error: Optional[Exception] = None |
| for attempt in range(1, model_load_retries + 1): |
| try: |
| tokenizer = AutoTokenizer.from_pretrained(config.model_name, **tokenizer_kwargs) |
| try: |
| model = AutoModelForCausalLM.from_pretrained( |
| config.model_name, |
| **model_load_kwargs, |
| ) |
| except TypeError: |
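|                     # Older transformers releases accept torch_dtype rather than dtype; |
|                     # retry the same load with the renamed keyword. |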
| |
| if "dtype" in model_load_kwargs: |
| model_load_kwargs["torch_dtype"] = model_load_kwargs.pop("dtype") |
| model = AutoModelForCausalLM.from_pretrained( |
| config.model_name, |
| **model_load_kwargs, |
| ) |
| break |
| except Exception as exc: |
| last_load_error = exc |
| logger.warning( |
| "Model load attempt %d/%d failed: %s", |
| attempt, |
| model_load_retries, |
| str(exc), |
| ) |
| if attempt < model_load_retries: |
| await asyncio.sleep(model_load_retry_delay) |
| else: |
| logger.error( |
| "Model loading failed after %d attempts (local_files_only=%s)", |
| model_load_retries, |
| str(local_files_only), |
| ) |
| raise last_load_error |
|
|
| if device != "cuda": |
| model = model.to("cpu") |
|
|
| logger.info( |
| "Model loaded: %s | device=%s | max_workers=%d | queue_max_size=%d", |
| config.model_name, |
| device, |
| max_workers, |
| queue_max_size, |
| ) |
| logger.info( |
| "Runtime config: max_input_tokens=%d max_new_tokens_limit=%d top_p=%.3f top_k=%s rep_penalty=%.3f", |
| max_input_tokens, |
| max_new_tokens_limit, |
| config.top_p, |
| str(config.top_k), |
| config.repetition_penalty, |
| ) |
|
|
| worker_tasks = [ |
| asyncio.create_task(worker_loop(i + 1), name=f"generation-worker-{i + 1}") |
| for i in range(max_workers) |
| ] |
|
|
| try: |
| yield |
| finally: |
| logger.info("Shutting down workers...") |
| for _ in worker_tasks: |
| await request_queue.put(None) |
| await asyncio.gather(*worker_tasks, return_exceptions=True) |
|
|
| logger.info("Releasing model resources...") |
|             model = None |
|             tokenizer = None |
| if torch.cuda.is_available(): |
| torch.cuda.empty_cache() |
|
|
| app = FastAPI( |
| title=config.title, |
| description=config.description, |
| version=config.version, |
| lifespan=lifespan, |
| ) |
|
|
| app.add_middleware( |
| CORSMiddleware, |
| allow_origins=["*"], |
| allow_credentials=True, |
| allow_methods=["*"], |
| allow_headers=["*"], |
| ) |
|
|
| @app.get("/") |
| async def root(): |
| return { |
| "name": config.title, |
| "version": config.version, |
| "model": config.model_name, |
| "status": "running", |
| "device": device, |
| "max_workers": max_workers, |
| } |
|
|
| @app.get("/index", response_class=FileResponse) |
| async def serve_chat(): |
| return FileResponse(os.path.join(base_dir, "index.html")) |
|
|
| @app.get("/health") |
| async def health(): |
| return { |
| "status": "healthy", |
| "model_loaded": model is not None and tokenizer is not None, |
| "device": device, |
| "active_workers": active_workers, |
| "queue_size": request_queue.qsize(), |
| "max_workers": max_workers, |
| } |
|
|
| @app.post("/chat") |
| async def chat(request: ChatRequest): |
| if model is None or tokenizer is None: |
| raise HTTPException(status_code=503, detail="Model not loaded yet") |
|
|
| prompt = format_messages_proper(request.messages, request.tools) |
| task = GenerationTask( |
| request_id=uuid.uuid4().hex, |
| prompt=prompt, |
| max_tokens=request.max_tokens, |
| temperature=request.temperature if request.temperature is not None else config.default_temperature, |
| output_queue=asyncio.Queue(maxsize=2048), |
| ) |
|
|
| logger.info( |
| "[%s] queued request prompt_len=%d queue_size=%d", |
| task.request_id, |
| len(prompt), |
| request_queue.qsize(), |
| ) |
| await request_queue.put(task) |
|
|
| if request.stream: |
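|             # SSE path: if the client disconnects, Starlette cancels this generator; |
|             # setting cancel_event stops the generation thread via the stopping criteria. |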
| async def stream_events(): |
| try: |
| while True: |
| event = await task.output_queue.get() |
| if event is None: |
| break |
| yield _format_sse_event(event) |
| except asyncio.CancelledError: |
| task.cancel_event.set() |
| raise |
| finally: |
| task.cancel_event.set() |
|
|
| return StreamingResponse( |
| stream_events(), |
| media_type="text/event-stream", |
| headers={ |
| "Cache-Control": "no-cache, no-store, must-revalidate", |
| "Pragma": "no-cache", |
| "Expires": "0", |
| "Connection": "keep-alive", |
| "X-Accel-Buffering": "no", |
| "Transfer-Encoding": "chunked", |
| }, |
| ) |
|
|
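|         # Non-streaming path: drain the task's output queue and return the full |
|         # completion together with approximate token usage. |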
| chunks: List[str] = [] |
| error_message: Optional[str] = None |
| while True: |
| event = await task.output_queue.get() |
| if event is None: |
| break |
| event_type = event.get("type") |
| if event_type == "token": |
| chunks.append(str(event.get("content", ""))) |
| elif event_type == "error": |
| error_message = str(event.get("content", "Generation failed")) |
|
|
| if error_message: |
| raise HTTPException(status_code=500, detail=error_message) |
|
|
| response_text = "".join(chunks).strip() |
| return { |
| "content": response_text, |
| "usage": { |
| "prompt_tokens": task.prompt_tokens, |
| "completion_tokens": task.generated_tokens, |
| }, |
| } |
|
|
| return app |
|
|