Spaces:
Running
Running
| """ | |
| LFM2.5 FastAPI Backend - ONNX Runtime Edition | |
| ============================================== | |
| Lightweight, CPU-friendly FastAPI backend for LiquidAI LFM2.5-1.2B-Instruct. | |
| Uses official ONNX model for fast inference without heavy PyTorch dependencies. | |
| Features: | |
| - ONNX Runtime for fast CPU inference (no GPU required) | |
| - Q8 quantization for 95%+ accuracy retention | |
| - Streaming SSE responses | |
| - OpenAI-compatible API | |
| - Optimized for HuggingFace Spaces (2 vCPU, 16GB RAM) | |
| """ | |
| import asyncio | |
| import json | |
| import logging | |
| import time | |
| import uuid | |
| import threading | |
| import queue # Thread-safe queue for true streaming | |
| from contextlib import asynccontextmanager | |
| from typing import AsyncGenerator, Dict, List, Optional, Union | |
| from pathlib import Path | |
| import numpy as np | |
| import onnxruntime as ort | |
| from fastapi import FastAPI, HTTPException, Request, WebSocket, WebSocketDisconnect | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import JSONResponse | |
| from huggingface_hub import hf_hub_download, list_repo_files | |
| from pydantic import BaseModel, Field | |
| from sse_starlette.sse import EventSourceResponse | |
| from transformers import AutoTokenizer, PreTrainedTokenizerFast | |
| from config import settings | |
# Root logging setup: the level name is read from settings and resolved
# against the constants exposed by the logging module.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=getattr(logging, settings.log_level.upper()),
)
logger = logging.getLogger(__name__)
| # ============================================================================== | |
| # Pydantic Models for OpenAI-compatible API | |
| # ============================================================================== | |
# One conversation turn: who is speaking and what was said.
class ChatMessage(BaseModel):
    role: str = Field(
        ...,
        description="Role: 'system', 'user', or 'assistant'",
    )
    content: str = Field(
        ...,
        description="Message content",
    )
# Request body for chat completions. Sampling fields left as None are
# replaced by the server-side settings in the handlers.
class ChatCompletionRequest(BaseModel):
    model: str = Field("lfm", description="Model identifier")
    messages: List[ChatMessage] = Field(
        ...,
        description="Conversation messages",
    )
    temperature: Optional[float] = Field(None, ge=0.0, le=2.0)
    top_p: Optional[float] = Field(None, ge=0.0, le=1.0)
    top_k: Optional[int] = Field(None, ge=0)
    max_tokens: Optional[int] = Field(None, ge=1)
    stream: bool = Field(False, description="Enable streaming response")
    stop: Optional[Union[str, List[str]]] = Field(None)
# Request body for plain text completions. Sampling fields left as None are
# replaced by the server-side settings in the handler.
class CompletionRequest(BaseModel):
    model: str = Field("lfm", description="Model identifier")
    prompt: str = Field(..., description="Text prompt")
    temperature: Optional[float] = Field(None, ge=0.0, le=2.0)
    top_p: Optional[float] = Field(None, ge=0.0, le=1.0)
    top_k: Optional[int] = Field(None, ge=0)
    max_tokens: Optional[int] = Field(None, ge=1)
    stream: bool = Field(False, description="Enable streaming response")
# A single generated chat message together with the reason generation ended.
class ChatCompletionChoice(BaseModel):
    index: int
    message: ChatMessage
    finish_reason: Optional[str] = None
# Non-streaming chat completion response envelope.
class ChatCompletionResponse(BaseModel):
    id: str
    object: str = "chat.completion"
    created: int
    model: str
    choices: List[ChatCompletionChoice]
    usage: Dict[str, int]
# A single generated text continuation together with the reason it ended.
class CompletionChoice(BaseModel):
    index: int
    text: str
    finish_reason: Optional[str] = None
# Text completion response envelope.
class CompletionResponse(BaseModel):
    id: str
    object: str = "text_completion"
    created: int
    model: str
    choices: List[CompletionChoice]
    usage: Dict[str, int]
# One entry of the model listing.
class ModelInfo(BaseModel):
    id: str
    object: str = "model"
    created: int
    owned_by: str = "liquid-ai"
# Container returned by the model listing handler.
class ModelListResponse(BaseModel):
    object: str = "list"
    data: List[ModelInfo]
| # ============================================================================== | |
| # ONNX Model Manager | |
| # ============================================================================== | |
# Maps the type strings reported by ONNX Runtime inputs to NumPy dtypes.
# Callers fall back to float32 for any type string not listed here.
ONNX_DTYPE = {
    "tensor(float)": np.float32,
    "tensor(float16)": np.float16,
    "tensor(int64)": np.int64,
}
| class ONNXModelManager: | |
| """Manages ONNX model with KV cache for efficient generation.""" | |
| def __init__(self): | |
| self._session = None | |
| self._tokenizer = None | |
| self._cache_template = None | |
| self._use_position_ids = False | |
| self._lock = threading.Lock() | |
| def is_loaded(self) -> bool: | |
| return self._session is not None | |
| def download_model(self) -> str: | |
| """Download ONNX model files from HuggingFace.""" | |
| model_id = settings.model_id | |
| variant = settings.model_variant | |
| logger.info(f"Downloading model: {model_id} (variant: {variant})") | |
| # Download main model file | |
| model_filename = f"onnx/model_{variant}.onnx" | |
| model_path = hf_hub_download(model_id, model_filename) | |
| # Download all data files for this variant | |
| for f in list_repo_files(model_id): | |
| if f.startswith(f"onnx/model_{variant}.onnx_data"): | |
| logger.info(f"Downloading: {f}") | |
| hf_hub_download(model_id, f) | |
| return model_path | |
| def load_model(self) -> None: | |
| """Load the ONNX model and tokenizer.""" | |
| with self._lock: | |
| if self._session is not None: | |
| return | |
| logger.info("=" * 60) | |
| logger.info("Loading LFM2.5-1.2B-Instruct ONNX model...") | |
| logger.info(f"Model: {settings.model_id}") | |
| logger.info(f"Variant: {settings.model_variant} (Q8 = ~95% accuracy)") | |
| logger.info("=" * 60) | |
| start_time = time.time() | |
| # Download model | |
| model_path = self.download_model() | |
| # Configure ONNX Runtime for CPU | |
| sess_options = ort.SessionOptions() | |
| sess_options.intra_op_num_threads = settings.num_threads | |
| sess_options.inter_op_num_threads = settings.num_threads | |
| sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL | |
| # Load ONNX session | |
| self._session = ort.InferenceSession( | |
| model_path, | |
| sess_options=sess_options, | |
| providers=['CPUExecutionProvider'] | |
| ) | |
| # Load tokenizer with fallback for models with invalid tokenizer_class | |
| try: | |
| self._tokenizer = AutoTokenizer.from_pretrained( | |
| settings.model_id, | |
| trust_remote_code=True | |
| ) | |
| except ValueError as e: | |
| if "TokenizersBackend" in str(e): | |
| # LFM models incorrectly specify TokenizersBackend as tokenizer_class | |
| # Fallback to PreTrainedTokenizerFast which works with tokenizers backend | |
| logger.warning( | |
| "AutoTokenizer failed with TokenizersBackend error. " | |
| "Falling back to PreTrainedTokenizerFast..." | |
| ) | |
| self._tokenizer = PreTrainedTokenizerFast.from_pretrained( | |
| settings.model_id, | |
| trust_remote_code=True | |
| ) | |
| else: | |
| raise | |
| # Initialize cache template | |
| self._init_cache_template() | |
| # Check if model uses position_ids | |
| input_names = {inp.name for inp in self._session.get_inputs()} | |
| self._use_position_ids = "position_ids" in input_names | |
| load_time = time.time() - start_time | |
| logger.info("=" * 60) | |
| logger.info(f"✓ Model loaded in {load_time:.2f}s") | |
| logger.info(f" Threads: {settings.num_threads}") | |
| logger.info(f" Provider: CPU") | |
| logger.info("=" * 60) | |
| def _init_cache_template(self) -> None: | |
| """Initialize KV cache template.""" | |
| self._cache_template = {} | |
| for inp in self._session.get_inputs(): | |
| if inp.name in {"input_ids", "attention_mask", "position_ids"}: | |
| continue | |
| shape = [d if isinstance(d, int) else 1 for d in inp.shape] | |
| for i, d in enumerate(inp.shape): | |
| if isinstance(d, str) and "sequence" in d.lower(): | |
| shape[i] = 0 | |
| dtype = ONNX_DTYPE.get(inp.type, np.float32) | |
| self._cache_template[inp.name] = (shape, dtype) | |
| def _create_empty_cache(self) -> Dict[str, np.ndarray]: | |
| """Create a new empty KV cache.""" | |
| return { | |
| name: np.zeros(shape, dtype=dtype) | |
| for name, (shape, dtype) in self._cache_template.items() | |
| } | |
| def session(self): | |
| if self._session is None: | |
| raise RuntimeError("Model not loaded") | |
| return self._session | |
| def tokenizer(self): | |
| if self._tokenizer is None: | |
| raise RuntimeError("Tokenizer not loaded") | |
| return self._tokenizer | |
| def generate( | |
| self, | |
| input_ids: np.ndarray, | |
| max_tokens: int = 512, | |
| temperature: float = 0.1, | |
| top_k: int = 50, | |
| top_p: float = 0.1, | |
| stop_tokens: Optional[List[int]] = None | |
| ) -> List[int]: | |
| """Generate tokens using ONNX model.""" | |
| if stop_tokens is None: | |
| stop_tokens = [self._tokenizer.eos_token_id] | |
| cache = self._create_empty_cache() | |
| seq_len = input_ids.shape[1] | |
| generated_tokens = [] | |
| for step in range(max_tokens): | |
| if step == 0: | |
| ids = input_ids | |
| pos = np.arange(seq_len, dtype=np.int64).reshape(1, -1) | |
| else: | |
| ids = np.array([[generated_tokens[-1]]], dtype=np.int64) | |
| pos = np.array([[seq_len + len(generated_tokens) - 1]], dtype=np.int64) | |
| attn_mask = np.ones((1, seq_len + len(generated_tokens)), dtype=np.int64) | |
| feed = {"input_ids": ids, "attention_mask": attn_mask, **cache} | |
| if self._use_position_ids: | |
| feed["position_ids"] = pos | |
| outputs = self._session.run(None, feed) | |
| # Get logits and apply temperature | |
| logits = outputs[0][0, -1] | |
| if temperature > 0: | |
| logits = logits / temperature | |
| # Apply top-k | |
| if top_k > 0: | |
| indices_to_remove = np.argsort(logits)[:-top_k] | |
| logits[indices_to_remove] = -np.inf | |
| # Apply top-p (nucleus sampling) | |
| if top_p < 1.0: | |
| sorted_indices = np.argsort(logits)[::-1] | |
| sorted_logits = logits[sorted_indices] | |
| probs = np.exp(sorted_logits - np.max(sorted_logits)) | |
| probs = probs / probs.sum() | |
| cumulative_probs = np.cumsum(probs) | |
| sorted_indices_to_remove = cumulative_probs > top_p | |
| sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].copy() | |
| sorted_indices_to_remove[0] = False | |
| indices_to_remove = sorted_indices[sorted_indices_to_remove] | |
| logits[indices_to_remove] = -np.inf | |
| # Sample | |
| probs = np.exp(logits - np.max(logits)) | |
| probs = probs / probs.sum() | |
| next_token = int(np.random.choice(len(probs), p=probs)) | |
| else: | |
| next_token = int(np.argmax(logits)) | |
| generated_tokens.append(next_token) | |
| # Update cache | |
| for i, out in enumerate(self._session.get_outputs()[1:], 1): | |
| name = out.name.replace("present_conv", "past_conv").replace("present.", "past_key_values.") | |
| if name in cache: | |
| cache[name] = outputs[i] | |
| if next_token in stop_tokens: | |
| break | |
| return generated_tokens | |
| def generate_stream( | |
| self, | |
| input_ids: np.ndarray, | |
| max_tokens: int = 2000, | |
| temperature: float = 0.1, | |
| top_k: int = 50, | |
| top_p: float = 0.1, | |
| stop_tokens: Optional[List[int]] = None | |
| ): | |
| """Fixed and optimized streaming generation.""" | |
| if stop_tokens is None: | |
| stop_tokens = [self._tokenizer.eos_token_id] | |
| cache = self._create_empty_cache() | |
| seq_len = input_ids.shape[1] | |
| # Pre-allocate inputs | |
| max_possible_len = seq_len + max_tokens | |
| attn_mask = np.ones((1, max_possible_len), dtype=np.int64) | |
| # Pre-compute flags | |
| use_temp = temperature > 0 | |
| use_top_k = top_k > 0 | |
| use_top_p = top_p < 1.0 | |
| # Reuse this dict to avoid garbage collection overhead | |
| feed = {} | |
| # Initialize token storage | |
| generated_tokens = [] | |
| for step in range(max_tokens): | |
| current_len = seq_len + step | |
| # Input Preparation | |
| if step == 0: | |
| ids = input_ids | |
| if self._use_position_ids: | |
| pos = np.arange(seq_len, dtype=np.int64).reshape(1, -1) | |
| else: | |
| # FIX: Access list directly. O(1) speed, no UnboundLocalError. | |
| ids = np.array([[generated_tokens[-1]]], dtype=np.int64) | |
| if self._use_position_ids: | |
| pos = np.array([[current_len - 1]], dtype=np.int64) | |
| # Update Feed Dict (In-place update is faster than creating new dict) | |
| feed.clear() | |
| feed["input_ids"] = ids | |
| feed["attention_mask"] = attn_mask[:, :current_len] | |
| if self._use_position_ids: | |
| feed["position_ids"] = pos | |
| feed.update(cache) # Merging cache is unavoidable | |
| # Inference | |
| outputs = self._session.run(None, feed) | |
| logits = outputs[0][0, -1] | |
| # --- Ultra-Fast Sampling --- | |
| if use_temp: | |
| logits /= temperature | |
| # 1. Top-K Selection (Partitioning is O(N)) | |
| if use_top_k and top_k < len(logits): | |
| # Moves largest k elements to the right; unordered | |
| top_k_idx = np.argpartition(logits, -top_k)[-top_k:] | |
| # Mask everything else | |
| mask = np.ones(logits.shape, dtype=bool) | |
| mask[top_k_idx] = False | |
| logits[mask] = -np.inf | |
| # 2. Top-P (Nucleus) | |
| if use_top_p: | |
| valid_mask = logits > -np.inf | |
| if valid_mask.any(): | |
| valid_logits = logits[valid_mask] | |
| valid_indices = np.where(valid_mask)[0] | |
| # Sort only the valid candidates (small N) | |
| sorted_indices = np.argsort(valid_logits)[::-1] | |
| sorted_logits = valid_logits[sorted_indices] | |
| # Softmax on valid set | |
| exp_logits = np.exp(sorted_logits - np.max(sorted_logits)) | |
| probs = exp_logits / exp_logits.sum() | |
| cumulative = np.cumsum(probs) | |
| # Find cutoff | |
| cutoff = np.searchsorted(cumulative, top_p) | |
| # Ensure we keep at least one token | |
| cutoff = min(cutoff + 1, len(sorted_logits)) | |
| # Filter indices | |
| accepted_indices = sorted_indices[:cutoff] | |
| accepted_probs = probs[:cutoff] | |
| accepted_probs /= accepted_probs.sum() # Re-normalize | |
| # Fast Weighted Sample: Use searchsorted instead of np.random.choice | |
| # This avoids Python overhead in np.random.choice | |
| sample_idx = np.searchsorted(np.cumsum(accepted_probs), np.random.rand()) | |
| next_token = int(valid_indices[accepted_indices[sample_idx]]) | |
| else: | |
| next_token = int(np.argmax(logits)) | |
| else: | |
| # Fallback if only Top-K was used | |
| valid_mask = logits > -np.inf | |
| valid_logits = logits[valid_mask] | |
| valid_indices = np.where(valid_mask)[0] | |
| exp_logits = np.exp(valid_logits - np.max(valid_logits)) | |
| probs = exp_logits / exp_logits.sum() | |
| sample_idx = np.searchsorted(np.cumsum(probs), np.random.rand()) | |
| next_token = int(valid_indices[sample_idx]) | |
| else: | |
| next_token = int(np.argmax(logits)) | |
| # Storage | |
| generated_tokens.append(next_token) | |
| yield next_token | |
| if next_token in stop_tokens: | |
| break | |
| # Update Cache | |
| for i, out in enumerate(self._session.get_outputs()[1:], 1): | |
| name = out.name.replace("present_conv", "past_conv").replace("present.", "past_key_values.") | |
| if name in cache: | |
| cache[name] = outputs[i] | |
| def unload(self) -> None: | |
| """Unload model from memory.""" | |
| with self._lock: | |
| if self._session is not None: | |
| del self._session | |
| del self._tokenizer | |
| self._session = None | |
| self._tokenizer = None | |
| logger.info("Model unloaded") | |
# Module-wide manager instance shared by the lifespan hook and the handlers.
model_manager = ONNXModelManager()
| # ============================================================================== | |
| # Application Lifecycle | |
| # ============================================================================== | |
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model before serving and unload it on shutdown.

    Loading runs in the default executor so the event loop is not blocked
    while the model is downloaded and initialised. The unload sits in a
    ``finally`` block so it also runs when the serving phase ends with an
    error.
    """
    logger.info("Starting LFM2.5 API Server (ONNX Runtime)...")
    loop = asyncio.get_running_loop()
    await loop.run_in_executor(None, model_manager.load_model)
    try:
        yield
    finally:
        logger.info("Shutting down...")
        model_manager.unload()
| # ============================================================================== | |
| # FastAPI Application | |
| # ============================================================================== | |
app = FastAPI(
    title=settings.app_name,
    version=settings.app_version,
    description="Fast CPU inference for LiquidAI LFM2.5-1.2B-Instruct using ONNX Runtime",
    lifespan=lifespan,
    docs_url="/docs",
    redoc_url="/redoc",
)

# Wide-open CORS: any origin, method and header is accepted and every
# response header is exposed (needed for SSE clients). Credentials stay
# disabled because they cannot be combined with a wildcard origin.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["*"],
    allow_credentials=False,
)
# Custom middleware to handle null origin (file:// protocol).
# Registered explicitly: without the decorator the function is never invoked.
@app.middleware("http")
async def add_cors_for_null_origin(request: Request, call_next):
    """Add permissive CORS headers when the Origin header is "null" or absent.

    Pages opened from ``file://`` send ``Origin: null``. For those requests,
    and for requests without any Origin header, the wildcard CORS headers are
    set directly on the response.
    """
    origin = request.headers.get("origin", "")
    response = await call_next(request)
    if not origin or origin == "null":
        response.headers["Access-Control-Allow-Origin"] = "*"
        response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, DELETE, OPTIONS"
        response.headers["Access-Control-Allow-Headers"] = "*"
        response.headers["Access-Control-Expose-Headers"] = "*"
    return response
| # ============================================================================== | |
| # Helper Functions | |
| # ============================================================================== | |
def generate_id() -> str:
    """Return a fresh id: "chatcmpl-" followed by 12 hex characters."""
    suffix = uuid.uuid4().hex[:12]
    return "chatcmpl-" + suffix
async def stream_chat_completion(request: ChatCompletionRequest) -> AsyncGenerator[Dict[str, str], None]:
    """Stream a chat completion as SSE payload dicts (``{"data": ...}``).

    Generation runs in a background thread that pushes token ids onto an
    asyncio queue via ``loop.call_soon_threadsafe``; this coroutine awaits
    the queue and yields OpenAI-style ``chat.completion.chunk`` payloads,
    followed by a final chunk carrying ``finish_reason`` and ``[DONE]``.

    Notes:
        - Text is produced by decoding the whole token list so far and
          emitting only the new suffix. This keeps characters that span
          several tokens intact; a suffix ending in U+FFFD is held back
          until the following token completes it.
        - When the consumer stops (client disconnect, cancellation, error)
          a flag tells the worker thread to stop generating.
        - Each ``stop`` string is reduced to its FIRST token id, so
          multi-token stop sequences trigger on their first token.
    """
    request_id = generate_id()
    created = int(time.time())
    loop = asyncio.get_running_loop()
    events: asyncio.Queue = asyncio.Queue()
    cancelled = threading.Event()
    tokenizer = model_manager.tokenizer

    # Prepare inputs
    messages = [{"role": m.role, "content": m.content} for m in request.messages]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    input_ids = np.array([tokenizer.encode(prompt, add_special_tokens=False)], dtype=np.int64)

    # Request values win; None falls back to the server settings.
    max_tokens = request.max_tokens or settings.max_tokens
    temperature = request.temperature if request.temperature is not None else settings.temperature
    top_k = request.top_k if request.top_k is not None else settings.top_k
    top_p = request.top_p if request.top_p is not None else settings.top_p

    # Normalise `stop` (str | list | None) to a list of strings.
    stop_strings = [request.stop] if isinstance(request.stop, str) else (request.stop or [])
    stop_tokens = [tokenizer.eos_token_id]
    for stop_str in stop_strings:
        encoded = tokenizer.encode(stop_str, add_special_tokens=False)
        if encoded:
            stop_tokens.append(encoded[0])

    def post(kind, payload):
        """Hand one event to the event loop from the worker thread."""
        try:
            loop.call_soon_threadsafe(events.put_nowait, (kind, payload))
        except RuntimeError:
            # The loop is already closed; nobody is listening any more.
            cancelled.set()

    def generate_tokens():
        """Worker thread: run generation and forward every token id."""
        try:
            for token in model_manager.generate_stream(
                input_ids,
                max_tokens=max_tokens,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                stop_tokens=stop_tokens
            ):
                if cancelled.is_set():
                    return
                post("token", token)
        except Exception as e:
            logger.error(f"Stream generation error: {e}")
            post("error", str(e))
        finally:
            post("done", None)

    def make_chunk(delta, finish_reason=None):
        """Build one SSE payload in chat.completion.chunk format."""
        return {
            "data": json.dumps({
                "id": request_id,
                "object": "chat.completion.chunk",
                "created": created,
                "model": request.model,
                "choices": [{
                    "index": 0,
                    "delta": delta,
                    "finish_reason": finish_reason
                }]
            })
        }

    threading.Thread(target=generate_tokens, daemon=True).start()

    token_ids: List[int] = []
    emitted = ""  # text already sent to the client
    try:
        while True:
            msg_type, data = await events.get()
            if msg_type == "token":
                token_ids.append(data)
                decoded = tokenizer.decode(token_ids, skip_special_tokens=True)
                if decoded.endswith("\ufffd"):
                    # Incomplete multi-token character; wait for the rest.
                    continue
                text = decoded[len(emitted):]
                emitted = decoded
                if text:
                    yield make_chunk({"content": text})
            elif msg_type == "done":
                # Flush anything that was held back.
                decoded = tokenizer.decode(token_ids, skip_special_tokens=True)
                tail = decoded[len(emitted):]
                if tail:
                    yield make_chunk({"content": tail})
                # The generator yields the stop token last when it stops on
                # one; otherwise the token budget was exhausted.
                stopped = bool(token_ids) and token_ids[-1] in stop_tokens
                yield make_chunk({}, "stop" if stopped else "length")
                yield {"data": "[DONE]"}
                break
            elif msg_type == "error":
                logger.error(f"Stream error: {data}")
                yield {"data": json.dumps({"error": {"message": data}})}
                break
    except asyncio.CancelledError:
        logger.info(f"Stream cancelled for request {request_id[:8]}")
        raise
    except Exception as e:
        logger.error(f"Streaming error: {e}")
        yield {"data": json.dumps({"error": {"message": str(e)}})}
    finally:
        # Stop the worker whether we finished, failed or were cancelled.
        cancelled.set()
| # ============================================================================== | |
| # API Endpoints | |
| # ============================================================================== | |
async def health_check():
    """Report readiness plus model and server metadata."""
    status = "ready" if model_manager.is_loaded else "loading"
    model_info = {
        "id": settings.model_id,
        "variant": settings.model_variant,
        "loaded": model_manager.is_loaded,
        "backend": "ONNX Runtime",
    }
    server_info = {
        "name": settings.app_name,
        "version": settings.app_version,
        "port": settings.port,
    }
    return {"status": status, "model": model_info, "server": server_info}
async def health():
    """Return a healthy status, or raise HTTP 503 while the model is not loaded."""
    if model_manager.is_loaded:
        return {"status": "healthy"}
    raise HTTPException(status_code=503, detail="Model not loaded")
async def list_models():
    """List the two model ids this server answers to."""
    model_ids = ("lfm", "lfm-2.5-1.2b-instruct-onnx")
    entries = [
        ModelInfo(id=model_id, created=int(time.time()))
        for model_id in model_ids
    ]
    return ModelListResponse(data=entries)
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm how it is registered on `app`.
async def chat_completions(request: ChatCompletionRequest):
    """OpenAI-compatible chat completion (streaming and non-streaming).

    Raises:
        HTTPException: 503 while the model is not loaded, 500 when
            generation fails.

    Notes:
        - ``request.stop`` is honoured on the non-streaming path as well;
          each stop string is reduced to its first token id, as in the
          streaming path.
        - ``finish_reason`` is "stop" when generation ended on a stop token
          and "length" when the token budget ran out.
    """
    if not model_manager.is_loaded:
        raise HTTPException(status_code=503, detail="Model not loaded")

    if request.stream:
        return EventSourceResponse(
            stream_chat_completion(request),
            media_type="text/event-stream",
            ping=30,  # keep-alive interval in seconds
            ping_message_factory=lambda: '{"type": "ping"}'
        )

    try:
        tokenizer = model_manager.tokenizer
        messages = [{"role": m.role, "content": m.content} for m in request.messages]
        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        input_ids = np.array([tokenizer.encode(prompt, add_special_tokens=False)], dtype=np.int64)

        # Request values win; None falls back to the server settings.
        max_tokens = request.max_tokens or settings.max_tokens
        temperature = request.temperature if request.temperature is not None else settings.temperature
        top_k = request.top_k if request.top_k is not None else settings.top_k
        top_p = request.top_p if request.top_p is not None else settings.top_p

        # Normalise `stop` (str | list | None) to a list of strings.
        stop_strings = [request.stop] if isinstance(request.stop, str) else (request.stop or [])
        stop_tokens = [tokenizer.eos_token_id]
        for stop_str in stop_strings:
            encoded = tokenizer.encode(stop_str, add_special_tokens=False)
            if encoded:
                stop_tokens.append(encoded[0])

        start_time = time.time()
        # Inference is CPU-bound; keep it off the event loop.
        loop = asyncio.get_running_loop()
        tokens = await loop.run_in_executor(
            None,
            lambda: model_manager.generate(
                input_ids,
                max_tokens=max_tokens,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                stop_tokens=stop_tokens
            )
        )
        response_text = tokenizer.decode(tokens, skip_special_tokens=True)
        gen_time = time.time() - start_time
        logger.debug(f"Generated {len(tokens)} tokens in {gen_time:.2f}s")

        # generate() returns the stop token as the last element when it
        # stopped on one; otherwise max_tokens was reached.
        stopped = bool(tokens) and tokens[-1] in stop_tokens
        prompt_tokens = int(input_ids.shape[1])

        return ChatCompletionResponse(
            id=generate_id(),
            created=int(time.time()),
            model=request.model,
            choices=[
                ChatCompletionChoice(
                    index=0,
                    message=ChatMessage(role="assistant", content=response_text),
                    finish_reason="stop" if stopped else "length"
                )
            ],
            usage={
                "prompt_tokens": prompt_tokens,
                "completion_tokens": len(tokens),
                "total_tokens": prompt_tokens + len(tokens)
            }
        )
    except Exception as e:
        logger.error(f"Chat completion error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm how it is registered on `app`.
async def completions(request: CompletionRequest):
    """OpenAI-compatible text completion (non-streaming).

    ``request.stream`` is not consulted here: the response is always
    returned in one piece.

    Raises:
        HTTPException: 503 while the model is not loaded, 500 when
            generation fails.
    """
    if not model_manager.is_loaded:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        tokenizer = model_manager.tokenizer
        input_ids = np.array([tokenizer.encode(request.prompt)], dtype=np.int64)

        # Request values win; None falls back to the server settings.
        max_tokens = request.max_tokens or settings.max_tokens
        temperature = request.temperature if request.temperature is not None else settings.temperature
        top_k = request.top_k if request.top_k is not None else settings.top_k
        top_p = request.top_p if request.top_p is not None else settings.top_p

        # Inference is CPU-bound; keep it off the event loop.
        loop = asyncio.get_running_loop()
        tokens = await loop.run_in_executor(
            None,
            lambda: model_manager.generate(
                input_ids,
                max_tokens=max_tokens,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p
            )
        )
        response_text = tokenizer.decode(tokens, skip_special_tokens=True)

        # Without explicit stop tokens generate() stops on the EOS id and
        # returns it as the last element; otherwise max_tokens was reached.
        stopped = bool(tokens) and tokens[-1] == tokenizer.eos_token_id
        prompt_tokens = int(input_ids.shape[1])

        return CompletionResponse(
            id=generate_id(),
            created=int(time.time()),
            model=request.model,
            choices=[
                CompletionChoice(
                    index=0,
                    text=response_text,
                    finish_reason="stop" if stopped else "length"
                )
            ],
            usage={
                "prompt_tokens": prompt_tokens,
                "completion_tokens": len(tokens),
                "total_tokens": prompt_tokens + len(tokens)
            }
        )
    except Exception as e:
        logger.error(f"Completion error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
| # ============================================================================== | |
| # WebSocket Autocomplete Endpoint | |
| # ============================================================================== | |
# NOTE(review): no WebSocket route decorator is visible on this handler in
# this view; confirm how it is registered on `app`.
async def ws_autocomplete(websocket: WebSocket):
    """
    Persistent WebSocket endpoint for inline text predictions.

    Protocol:
        Client sends: { "context": "last ~300 chars before cursor" }
        Server sends: { "suggestion": "predicted next words" }
        Client sends: { "type": "ping" } → Server sends: { "type": "pong" }
        Malformed input (invalid JSON, or JSON that is not an object) gets
        an { "error": ... } reply and the connection stays open.

    Behaviour:
        - The connection is kept open across predictions.
        - The context is wrapped in the chat template with a system prompt
          asking for a continuation.
        - Sampling uses temperature 0.4, top_k 40, top_p 0.9.
        - "max_tokens" defaults to 20 and is clamped to 1..30; non-integer
          values fall back to the default.
        - Generation stops at EOS or at the token(s) that "\\n" encodes to.
        - Prompts longer than 512 tokens are cut to their last 512 tokens.
          NOTE(review): this cuts from the front of the templated prompt,
          so the system message is what gets dropped first; confirm intended.
        - Any failure while predicting yields an empty suggestion.
    """
    await websocket.accept()
    logger.info("[ws/autocomplete] Client connected")

    system_prompt = (
        "You are a writing assistant. The user will give you text from a document. "
        "Your job is to predict the next few words or sentence that naturally continues the text. "
        "ONLY output the continuation — do NOT repeat any of the given text. "
        "Keep it concise (1-2 short sentences max). "
        "Match the tone, style, and language of the existing text."
    )

    try:
        while True:
            raw = await websocket.receive_text()
            try:
                data = json.loads(raw)
            except json.JSONDecodeError:
                await websocket.send_text(json.dumps({"error": "Invalid JSON"}))
                continue
            if not isinstance(data, dict):
                # Valid JSON but not an object (e.g. a list or a number).
                await websocket.send_text(json.dumps({"error": "Expected a JSON object"}))
                continue

            # Heartbeat: respond to pings immediately.
            if data.get("type") == "ping":
                await websocket.send_text(json.dumps({"type": "pong"}))
                continue

            raw_context = data.get("context", "")
            context = raw_context.strip() if isinstance(raw_context, str) else ""
            if not context or not model_manager.is_loaded:
                await websocket.send_text(json.dumps({"suggestion": ""}))
                continue

            try:
                tokenizer = model_manager.tokenizer

                requested = data.get("max_tokens", 20)
                if isinstance(requested, bool) or not isinstance(requested, int):
                    requested = 20
                max_tokens = max(1, min(requested, 30))

                messages = [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": f"Continue this text:\n\n{context}"},
                ]
                prompt = tokenizer.apply_chat_template(
                    messages, tokenize=False, add_generation_prompt=True
                )
                input_ids = np.array(
                    [tokenizer.encode(prompt, add_special_tokens=False)],
                    dtype=np.int64
                )
                # Keep only the last 512 tokens to bound inference time.
                if input_ids.shape[1] > 512:
                    input_ids = input_ids[:, -512:]

                stop_tokens = [
                    tokenizer.eos_token_id,
                    # Stop at paragraph boundary
                    *tokenizer.encode("\n", add_special_tokens=False),
                ]

                # Inference is CPU-bound; keep it off the event loop.
                loop = asyncio.get_running_loop()
                tokens = await loop.run_in_executor(
                    None,
                    lambda: model_manager.generate(
                        input_ids,
                        max_tokens=max_tokens,
                        temperature=0.4,
                        top_k=40,
                        top_p=0.9,
                        stop_tokens=stop_tokens
                    )
                )
                suggestion = tokenizer.decode(tokens, skip_special_tokens=True).strip()

                # If the suggestion starts by repeating the end of the
                # context, drop the longest such overlap (up to 30 chars).
                if suggestion:
                    for overlap_len in range(min(len(suggestion), 30), 0, -1):
                        if context.endswith(suggestion[:overlap_len]):
                            suggestion = suggestion[overlap_len:].strip()
                            break

                await websocket.send_text(json.dumps({"suggestion": suggestion}))
            except Exception as e:
                logger.error(f"[ws/autocomplete] Prediction error: {e}")
                await websocket.send_text(json.dumps({"suggestion": ""}))
    except WebSocketDisconnect:
        logger.info("[ws/autocomplete] Client disconnected")
    except Exception as e:
        logger.error(f"[ws/autocomplete] Connection error: {e}")
        try:
            await websocket.close(code=1011, reason="Internal error")
        except Exception:
            # Best effort: the socket may already be closed.
            pass
# Registered explicitly: without the decorator the handler is never invoked.
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """Log an unhandled exception and answer with a generic HTTP 500 body.

    The exception text goes to the log (with traceback) but is not sent to
    the client.
    """
    logger.error(f"Unhandled exception: {exc}", exc_info=True)
    return JSONResponse(
        status_code=500,
        content={"error": {"message": "Internal server error", "type": "server_error"}}
    )
| # ============================================================================== | |
| # Main Entry Point | |
| # ============================================================================== | |
if __name__ == "__main__":
    import uvicorn

    print(f"""
╔═══════════════════════════════════════════════════════════════╗
║ LFM2.5 FastAPI Backend (ONNX Runtime) ║
╠═══════════════════════════════════════════════════════════════╣
║ Model: LiquidAI/LFM2.5-1.2B-Instruct-ONNX ║
║ Variant: Q8 (~95% accuracy, fast CPU inference) ║
║ Host: {settings.host}:{settings.port} ║
║ Docs: http://{settings.host}:{settings.port}/docs ║
╚═══════════════════════════════════════════════════════════════╝
""")
    uvicorn.run(
        "app:app",
        host=settings.host,
        port=settings.port,
        # uvicorn looks string levels up by lowercase name, while the
        # logging setup in this module upper-cases the same setting;
        # normalise so either spelling works in both places.
        log_level=settings.log_level.lower(),
        workers=1,
    )