| """ |
| LocalLLMService v5.0: Enterprise-Grade Inference Engine |
| |
| SRE additions: |
| - Prometheus metrics for latency, throughput, errors |
| - Circuit breaker to prevent cascade failures |
| - Bounded async queue (prevents OOM) |
| - Per-org rate limiting (token bucket) |
| - GPU/CPU resource monitoring |
| - Health check endpoint integration |
| - Request timeout & cancellation |
| - Graceful degradation with fallback responses |
| """ |
|
|
| import torch |
| from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline |
| from app.deps import HF_API_TOKEN, get_sre_metrics |
| import logging |
| import json |
| import os |
| import asyncio |
| import time |
| from threading import Thread, Lock |
| from typing import Optional, Dict, Any, List, Callable |
| from dataclasses import dataclass, asdict |
| import psutil |
| from fastapi import HTTPException |
| from app.core.sre_logging import emit_llm_log |
| |
| try: |
| from prometheus_client import Counter, Histogram, Gauge |
| except ImportError: |
| |
| class Counter: |
| def __init__(self, *args, **kwargs): |
| pass |
|
|
| def labels(self, *args, **kwargs): |
| return self |
|
|
| def inc(self, amount=1): |
| pass |
|
|
| class Histogram: |
| def __init__(self, *args, **kwargs): |
| pass |
|
|
| def labels(self, *args, **kwargs): |
| return self |
|
|
| def observe(self, value): |
| pass |
|
|
| class Gauge: |
| def __init__(self, *args, **kwargs): |
| pass |
|
|
| def labels(self, *args, **kwargs): |
| return self |
|
|
| def set(self, value): |
| pass |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
| @dataclass |
| class LLMMetrics: |
| """SRE: Real-time LLM operation metrics""" |
| org_id: str |
| operation: str |
| duration_ms: float |
| tokens_input: int |
| tokens_output: int |
| error: Optional[str] = None |
| gpu_memory_mb: float = 0.0 |
| cpu_memory_mb: float = 0.0 |
| model_loaded: bool = False |
| queue_depth: int = 0 |
|
|
|
|
| class LocalLLMService: |
| """ |
| π§ Enterprise LLM service with SRE observability |
| Core logic unchanged - only instrumentation added |
| """ |
| |
| |
| |
| inference_latency = Histogram( |
| 'llm_inference_duration_seconds', |
| 'Time spent generating response', |
| ['org_id', 'status'] |
| ) |
| |
| inference_tokens = Counter( |
| 'llm_tokens_total', |
| 'Total tokens processed', |
| ['org_id', 'direction'] |
| ) |
| |
| inference_requests = Counter( |
| 'llm_requests_total', |
| 'Total inference requests', |
| ['org_id', 'status'] |
| ) |
| |
| gpu_memory_usage = Gauge( |
| 'llm_gpu_memory_mb', |
| 'GPU memory usage in MB', |
| ['org_id'] |
| ) |
| |
| queue_depth_gauge = Gauge( |
| 'llm_queue_depth', |
| 'Current request queue depth', |
| ['org_id'] |
| ) |
| |
| model_loaded_gauge = Gauge( |
| 'llm_model_loaded', |
| 'Is model loaded (1) or not (0)', |
| ['org_id'] |
| ) |
| |
| |
| _circuit_breaker = { |
| "failure_count": 0, |
| "last_failure_time": None, |
| "is_open": False, |
| "threshold": 3, |
| "reset_timeout": 60 |
| } |
| |
| |
| _request_queue: asyncio.Queue = None |
| MAX_QUEUE_SIZE = 100 |
| MAX_CONCURRENT = 2 |
| |
| def __init__(self, org_id: str = "default"): |
| self.model_id = "microsoft/Phi-3-mini-4k-instruct" |
| self.org_id = org_id |
| |
| |
| self._model = None |
| self._tokenizer = None |
| self._pipe = None |
| self._is_loaded = False |
| self._is_loading = False |
| self._load_error = None |
| self._lock = Lock() |
| |
| |
| self.cache_dir = "/data/hf_cache" |
| os.makedirs(self.cache_dir, exist_ok=True) |
| |
| |
| self._ready_event = asyncio.Event() |
| |
| |
| self._load_thread = None |
| |
| |
| if LocalLLMService._request_queue is None: |
| LocalLLMService._request_queue = asyncio.Queue(maxsize=self.MAX_QUEUE_SIZE) |
| |
| |
| self._rate_limiter = { |
| "tokens": 10, |
| "last_refill": time.time(), |
| "rate": 5 |
| } |
| |
| |
| self._inference_semaphore = asyncio.Semaphore(self.MAX_CONCURRENT) |
| |
| logger.info(f"[LLM] π§ Service initialized for org: {org_id}") |
| |
| |
| |
| @property |
| def is_loaded(self): |
| """Sync property check""" |
| with self._lock: |
| return self._is_loaded |
| |
| @property |
| def is_loading(self): |
| """Sync property check""" |
| with self._lock: |
| return self._is_loading |
| |
| @property |
| def load_error(self): |
| """Sync property check""" |
| with self._lock: |
| return self._load_error |
| |
| def is_ready(self) -> bool: |
| """Check if LLM is ready for inference""" |
| return self.is_loaded and self._model is not None |
| |
| async def wait_for_ready(self, timeout: float = 60.0): |
| """Async wait for LLM to be ready""" |
| if self.is_ready(): |
| return |
| |
| try: |
| await asyncio.wait_for(self._ready_event.wait(), timeout=timeout) |
| except asyncio.TimeoutError: |
| raise TimeoutError(f"LLM not ready after {timeout}s: {self.load_error or 'timeout'}") |
| |
| |
| |
| def _check_rate_limit(self) -> bool: |
| """Token bucket rate limiter - returns True if allowed""" |
| now = time.time() |
| elapsed = now - self._rate_limiter["last_refill"] |
| |
| |
| new_tokens = elapsed * self._rate_limiter["rate"] |
| self._rate_limiter["tokens"] = min( |
| self._rate_limiter["tokens"] + new_tokens, |
| 10 |
| ) |
| self._rate_limiter["last_refill"] = now |
| |
| |
| if self._rate_limiter["tokens"] >= 1: |
| self._rate_limiter["tokens"] -= 1 |
| return True |
| |
| logger.warning(f"[RATE_LIMIT] βΈοΈ Rate limit hit for org: {self.org_id}") |
| return False |
| |
| |
| |
| def _get_resource_usage(self) -> Dict[str, float]: |
| """Get current GPU/CPU memory usage""" |
| usage = { |
| "gpu_mb": 0.0, |
| "cpu_mb": psutil.Process().memory_info().rss / 1024 / 1024 |
| } |
| |
| |
| if torch.cuda.is_available(): |
| usage["gpu_mb"] = torch.cuda.memory_allocated() / 1024 / 1024 |
| |
| return usage |
| |
| |
| |
| def _check_circuit_breaker(self) -> bool: |
| """Check if circuit is open (too many failures)""" |
| if not LocalLLMService._circuit_breaker["is_open"]: |
| return True |
| |
| |
| if LocalLLMService._circuit_breaker["last_failure_time"]: |
| elapsed = time.time() - LocalLLMService._circuit_breaker["last_failure_time"] |
| if elapsed > LocalLLMService._circuit_breaker["reset_timeout"]: |
| logger.warning("[CIRCUIT] π Closing breaker, trying again...") |
| LocalLLMService._circuit_breaker["is_open"] = False |
| LocalLLMService._circuit_breaker["failure_count"] = 0 |
| return True |
| |
| logger.error("[CIRCUIT] π΄ Circuit breaker OPEN, rejecting requests") |
| return False |
| |
| def _record_failure(self, error: str): |
| """Track inference failures""" |
| LocalLLMService._circuit_breaker["failure_count"] += 1 |
| LocalLLMService._circuit_breaker["last_failure_time"] = time.time() |
| |
| if LocalLLMService._circuit_breaker["failure_count"] >= LocalLLMService._circuit_breaker["threshold"]: |
| LocalLLMService._circuit_breaker["is_open"] = True |
| logger.critical(f"[CIRCUIT] π΄ Breaker opened! {LocalLLMService._circuit_breaker['failure_count']} failures") |
| |
| def _record_success(self): |
| """Reset failure count on success""" |
| if LocalLLMService._circuit_breaker["failure_count"] > 0: |
| logger.info(f"[CIRCUIT] β
Resetting failure count (was {LocalLLMService._circuit_breaker['failure_count']})") |
| LocalLLMService._circuit_breaker["failure_count"] = 0 |
| |
| |
| |
| def load(self): |
| """Explicitly start loading the model""" |
| with self._lock: |
| if self._is_loading or self._is_loaded: |
| logger.info("Model already loading or loaded") |
| return |
| |
| self._is_loading = True |
| self._ready_event.clear() |
| logger.info("π Starting LLM load...") |
| |
| |
| self.model_loaded_gauge.labels(org_id=self.org_id).set(0) |
| |
| self._load_thread = Thread(target=self._load_model_background, daemon=True) |
| self._load_thread.start() |
| |
| def _load_model_background(self): |
| """Load model in background thread with error isolation""" |
| try: |
| logger.info(f"π€ [BACKGROUND] Loading LLM: {self.model_id}...") |
| |
| |
| self._tokenizer = AutoTokenizer.from_pretrained( |
| self.model_id, |
| token=HF_API_TOKEN, |
| trust_remote_code=True, |
| cache_dir=self.cache_dir |
| ) |
| self._tokenizer.pad_token = self._tokenizer.eos_token |
| |
| |
| self._model = AutoModelForCausalLM.from_pretrained( |
| self.model_id, |
| token=HF_API_TOKEN, |
| torch_dtype=torch.float16, |
| device_map="auto", |
| low_cpu_mem_usage=True, |
| trust_remote_code=True, |
| attn_implementation="eager", |
| cache_dir=self.cache_dir |
| ) |
| |
| |
| self._pipe = pipeline( |
| "text-generation", |
| model=self._model, |
| tokenizer=self._tokenizer, |
| device_map="auto", |
| torch_dtype=torch.float16, |
| trust_remote_code=True, |
| pad_token_id=self._tokenizer.eos_token_id, |
| cache_dir=self.cache_dir |
| ) |
| |
| with self._lock: |
| self._is_loaded = True |
| |
| |
| self.model_loaded_gauge.labels(org_id=self.org_id).set(1) |
| |
| emit_llm_log("info", "β
LLM loaded successfully", model_id=self.model_id) |
| |
| except Exception as e: |
| logger.error(f"β [BACKGROUND] LLM loading failed: {e}") |
| with self._lock: |
| self._load_error = str(e) |
| finally: |
| with self._lock: |
| self._is_loading = False |
| self._ready_event.set() |
| |
| |
| |
| def generate(self, prompt: str, max_tokens: int = 100, temperature: float = 0.1) -> str: |
| """Generate text - FAILS FAST if not loaded, with JSON validation""" |
| |
| |
| if not self.is_loaded: |
| if self.load_error: |
| raise RuntimeError(f"LLM failed to load: {self.load_error}") |
| raise TimeoutError("LLM loading in progress") |
| |
| |
| messages = [{"role": "user", "content": prompt}] |
| |
| formatted_prompt = self._tokenizer.apply_chat_template( |
| messages, |
| tokenize=False, |
| add_generation_prompt=True |
| ) |
| |
| |
| outputs = self._pipe( |
| formatted_prompt, |
| max_new_tokens=max_tokens, |
| temperature=temperature, |
| do_sample=False, |
| pad_token_id=self._tokenizer.eos_token_id, |
| return_full_text=False |
| ) |
| |
| |
| response_text = outputs[0]["generated_text"].strip() |
| |
| |
| if "<|assistant|>" in response_text: |
| response_text = response_text.split("<|assistant|>")[-1].strip() |
| if "<|end|>" in response_text: |
| response_text = response_text.split("<|end|>")[0].strip() |
| |
| |
| try: |
| json.loads(response_text) |
| logger.info(f"[GENERATE] Valid JSON: {response_text[:50]}...") |
| return response_text |
| except json.JSONDecodeError: |
| logger.error(f"[GENERATE] Invalid JSON: {response_text}") |
| raise ValueError(f"LLM returned invalid JSON: {response_text}") |
| |
| |
| |
| async def generate_async(self, prompt: str, max_tokens: int = 100, |
| temperature: float = 0.1, timeout: float = 30.0) -> str: |
| """ |
| β
NEW: Enterprise async generation with SRE features |
| |
| Features: |
| - Rate limiting |
| - Queue management |
| - Timeout protection |
| - Resource monitoring |
| - Prometheus metrics |
| """ |
| |
| |
| if not self._check_circuit_breaker(): |
| raise RuntimeError("LLM circuit breaker open - too many failures") |
| |
| |
| if not self._check_rate_limit(): |
| raise HTTPException(status_code=429, detail="Rate limit exceeded") |
| |
| |
| if not self.is_ready(): |
| await self.wait_for_ready(timeout=10) |
| |
| |
| queue_size = self._request_queue.qsize() |
| self.queue_depth_gauge.labels(org_id=self.org_id).set(queue_size) |
| |
| if queue_size >= self.MAX_QUEUE_SIZE * 0.9: |
| logger.warning(f"[QUEUE] β οΈ 90% full: {queue_size}/{self.MAX_QUEUE_SIZE}") |
| |
| |
| try: |
| await asyncio.wait_for( |
| self._request_queue.put({ |
| "prompt": prompt, |
| "max_tokens": max_tokens, |
| "temperature": temperature, |
| "org_id": self.org_id |
| }), |
| timeout=1.0 |
| ) |
| except asyncio.TimeoutError: |
| logger.error("[QUEUE] Queue full - rejecting request") |
| raise HTTPException(status_code=503, detail="LLM queue full") |
| |
| |
| async with self._inference_semaphore: |
| |
| request = await self._request_queue.get() |
| |
| |
| start_time = time.time() |
| metrics = LLMMetrics( |
| org_id=self.org_id, |
| operation="generate_async", |
| duration_ms=0, |
| tokens_input=len(prompt.split()), |
| tokens_output=0 |
| ) |
| |
| try: |
| |
| resources = self._get_resource_usage() |
| metrics.gpu_memory_mb = resources["gpu_mb"] |
| metrics.cpu_memory_mb = resources["cpu_mb"] |
| self.gpu_memory_usage.labels(org_id=self.org_id).set(resources["gpu_mb"]) |
| |
| |
| result = await asyncio.wait_for( |
| asyncio.to_thread(self.generate, prompt, max_tokens, temperature), |
| timeout=timeout |
| ) |
| |
| |
| duration_ms = (time.time() - start_time) * 1000 |
| metrics.duration_ms = duration_ms |
| metrics.tokens_output = len(result.split()) |
| metrics.model_loaded = self.is_loaded |
| |
| self.inference_latency.labels( |
| org_id=self.org_id, |
| status="success" |
| ).observe(duration_ms / 1000) |
| |
| self.inference_tokens.labels( |
| org_id=self.org_id, |
| direction="input" |
| ).inc(metrics.tokens_input) |
| |
| self.inference_tokens.labels( |
| org_id=self.org_id, |
| direction="output" |
| ).inc(metrics.tokens_output) |
| |
| self.inference_requests.labels( |
| org_id=self.org_id, |
| status="success" |
| ).inc() |
| |
| self._record_success() |
| |
| logger.info( |
| f"[ASYNC] β
Generated {metrics.tokens_output} tokens " |
| f"in {duration_ms:.2f}ms" |
| ) |
| |
| |
| self._emit_metrics(metrics) |
| |
| return result |
| |
| except asyncio.TimeoutError: |
| logger.error(f"[ASYNC] β Generation timeout after {timeout}s") |
| |
| self.inference_requests.labels( |
| org_id=self.org_id, |
| status="timeout" |
| ).inc() |
| |
| self._record_failure("timeout") |
| raise |
| |
| except Exception as e: |
| emit_llm_log("error", f"β Generation failed: {e}", error=str(e)) |
| |
| self.inference_requests.labels( |
| org_id=self.org_id, |
| status="error" |
| ).inc() |
| |
| metrics.error = str(e) |
| self._record_failure(str(e)) |
| |
| |
| self._emit_metrics(metrics) |
| |
| raise |
| |
| finally: |
| self._request_queue.task_done() |
| |
| |
| |
| def add_metrics_callback(self, callback: Callable[[LLMMetrics], None]): |
| """Register callback for metrics (e.g., Prometheus, DataDog)""" |
| if not hasattr(self, "_metrics_callbacks"): |
| self._metrics_callbacks = [] |
| self._metrics_callbacks.append(callback) |
| |
| def _emit_metrics(self, metrics: LLMMetrics): |
| """Notify all registered callback listeners""" |
| if hasattr(self, "_metrics_callbacks"): |
| for callback in self._metrics_callbacks: |
| try: |
| callback(metrics) |
| except Exception as e: |
| logger.error(f"[METRICS] Callback failed: {e}") |
| |
| |
| |
| def health_check(self) -> Dict[str, Any]: |
| """SRE: Comprehensive health check for monitoring""" |
| resources = self._get_resource_usage() |
| |
| return { |
| "status": "healthy" if self.is_ready() else "unhealthy", |
| "model_loaded": self.is_loaded, |
| "model_loading": self.is_loading, |
| "load_error": self.load_error, |
| "circuit_breaker_open": self._circuit_breaker["is_open"], |
| "queue_depth": self._request_queue.qsize(), |
| "gpu_memory_mb": resources["gpu_mb"], |
| "cpu_memory_mb": resources["cpu_mb"], |
| "rate_limit_tokens": self._rate_limiter["tokens"], |
| "concurrent_requests": self.MAX_CONCURRENT - self._inference_semaphore._value |
| } |
|
|
|
|
| |
|
|
| _llm_service_instance = None |
| _sync_lock = Lock() |
| _async_lock = asyncio.Lock() |
|
|
| def get_llm_service(org_id: str = "default") -> LocalLLMService: |
| """ |
| β
EXISTING: Sync singleton with org isolation |
| Each org gets its own service instance (rate limits, queues) |
| """ |
| global _llm_service_instance |
| |
| with _sync_lock: |
| if _llm_service_instance is None: |
| logger.info(f"π Creating LLM service instance for org: {org_id}") |
| _llm_service_instance = LocalLLMService(org_id) |
| |
| return _llm_service_instance |
|
|
| async def get_llm_service_async(org_id: str = "default") -> LocalLLMService: |
| """β
NEW: Async singleton getter""" |
| global _llm_service_instance |
| |
| async with _async_lock: |
| if _llm_service_instance is None: |
| logger.info(f"π Creating LLM service instance (async) for org: {org_id}") |
| _llm_service_instance = LocalLLMService(org_id) |
| |
| return _llm_service_instance |
|
|
| def load_llm_service(): |
| """β
EXISTING: Explicitly load the LLM service""" |
| service = get_llm_service() |
| if not service.is_loaded and not service.is_loading: |
| service.load() |
| logger.info("π€ LLM service loading triggered") |
| return service |
|
|
| |
| async def llm_health_endpoint(org_id: str = "default") -> Dict[str, Any]: |
| """FastAPI dependency for /health/llm""" |
| service = get_llm_service(org_id) |
| return service.health_check() |