"""
Unified LLM Client for Hugging Face Inference API.

Supports:
- Hugging Face Inference API (text-generation, image-text-to-text) via huggingface_hub
- Structured calls (``LLMClient.call``) request a JSON response and validate it
  with Pydantic schemas; ``LLMClient.call_raw`` returns the raw text.
"""

import base64
import json
import logging
import os
import re
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar

import yaml
from pydantic import BaseModel

logger = logging.getLogger("llm_client")
if not logger.handlers:
    # Attach a handler only once so re-importing the module does not
    # duplicate log lines.
    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter("%(asctime)s [%(name)s] %(levelname)s: %(message)s"))
    logger.addHandler(handler)
    logger.setLevel(logging.INFO)

T = TypeVar("T", bound=BaseModel)

# Model used when neither the HF_MODEL environment variable nor the config
# file's "default_model" key provides one.
_FALLBACK_MODEL = "meta-llama/Llama-3.3-70B-Instruct"


# =============================================================================
# Configuration Loader
# =============================================================================

def load_model_config() -> Dict[str, Any]:
    """Read ``config/models.yaml`` (relative to this file's parent package).

    Returns:
        The parsed YAML mapping; an empty dict when the file is empty.

    Raises:
        FileNotFoundError: if the config file does not exist.
        ValueError: if the YAML top level is not a mapping.
    """
    config_path = Path(__file__).parent.parent / "config" / "models.yaml"
    if not config_path.exists():
        raise FileNotFoundError(f"Model config not found: {config_path}")
    with open(config_path, "r", encoding="utf-8") as f:
        data = yaml.safe_load(f)
    # safe_load yields None for an empty document; normalise so callers can
    # always use dict methods on the result.
    if data is None:
        return {}
    if not isinstance(data, dict):
        raise ValueError(f"Model config must be a mapping: {config_path}")
    return data


_MODEL_CONFIG: Optional[Dict[str, Any]] = None


def get_model_config() -> Dict[str, Any]:
    """Return the model config, loading it from disk on first use."""
    global _MODEL_CONFIG
    if _MODEL_CONFIG is None:
        _MODEL_CONFIG = load_model_config()
    return _MODEL_CONFIG


def _config_section(key: str) -> Dict[str, Any]:
    """Return a mapping section of the config, treating missing/null as empty."""
    return get_model_config().get(key) or {}


def resolve_model_alias(model_key: str) -> str:
    """Map *model_key* through the config's ``aliases``; unknown keys pass through."""
    return _config_section("aliases").get(model_key, model_key)


def get_model_info(model_key: str) -> Dict[str, Any]:
    """Return the config entry for *model_key* (directly or via its alias).

    Unknown models do not raise: a warning is logged and a default
    Hugging Face text-generation entry without image support is returned.
    """
    models = _config_section("models")
    if model_key in models:
        return models[model_key]
    resolved = resolve_model_alias(model_key)
    if resolved in models:
        return models[resolved]
    # Return a sensible default for unknown models rather than raising
    logger.warning(
        "Unknown model '%s', using default huggingface text-generation config", model_key
    )
    return {"provider": "huggingface", "task": "text-generation", "supports_images": False}


def get_default_model() -> str:
    """Return the default model: ``HF_MODEL`` env var, then config, then fallback."""
    env_model = os.environ.get("HF_MODEL")
    if env_model:
        return env_model
    return get_model_config().get("default_model", _FALLBACK_MODEL)


def list_available_models() -> List[str]:
    """Return the model keys declared under ``models`` in the config."""
    return list(_config_section("models"))


def get_recommended_model(use_case: str = "default") -> str:
    """Return the first configured model whose ``recommended_for`` lists *use_case*.

    Falls back to the config's ``default_model`` (or the built-in fallback)
    when no model matches.
    """
    for model_key, model_info in _config_section("models").items():
        # Tolerate entries that are null or whose recommended_for is null.
        recommended = (model_info or {}).get("recommended_for") or []
        if use_case in recommended:
            return model_key
    return get_model_config().get("default_model", _FALLBACK_MODEL)


# =============================================================================
# JSON Extraction Helpers
# =============================================================================

_JSON_OPENER_RE = re.compile(r"[{\[]")


def _extract_json_from_text(text: str) -> str:
    """Extract JSON object or array from a text response (strips markdown fences etc.).

    The first decodable JSON object/array is returned without any
    surrounding prose. If nothing decodable is found the cleaned text is
    returned unchanged so the caller's ``json.loads`` reports the failure.
    """
    text = text.strip()
    # Strip ```json ... ``` or ``` ... ``` fences
    text = re.sub(r"^```(?:json)?\s*", "", text)
    text = re.sub(r"\s*```$", "", text)
    text = text.strip()

    decoder = json.JSONDecoder()

    if text.startswith(("{", "[")):
        try:
            _, end = decoder.raw_decode(text)
        except json.JSONDecodeError:
            # Likely truncated/invalid JSON: hand it back whole rather than
            # returning a nested fragment that happens to decode.
            return text
        # Drop any commentary the model appended after the JSON value.
        return text[:end]

    # JSON embedded in prose: take the first opener that starts a valid value.
    for opener in _JSON_OPENER_RE.finditer(text):
        start = opener.start()
        try:
            _, end = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            continue
        return text[start:end]

    return text


# =============================================================================
# LLM Client
# =============================================================================

class LLMClient:
    """
    Unified LLM client targeting the Hugging Face Inference API.

    Usage:
        from src.clients import LLMClient
        from pydantic import BaseModel

        class MyResponse(BaseModel):
            answer: str
            score: int

        client = LLMClient()
        result = client.call(
            prompt="Analyze this text...",
            schema=MyResponse,
        )
        print(result.answer, result.score)
    """

    def __init__(
        self,
        hf_token: Optional[str] = None,
        default_model: Optional[str] = None,
        # Legacy compat args (ignored but accepted so old api code keeps working)
        openai_key: Optional[str] = None,
        google_api_key: Optional[str] = None,
    ):
        """Store the token and default model.

        Args:
            hf_token: Hugging Face token; falls back to the ``HF_TOKEN`` env var.
            default_model: Model key/alias used when a call does not name one;
                falls back to ``get_default_model()``.
            openai_key: Accepted for backward compatibility; ignored.
            google_api_key: Accepted for backward compatibility; ignored.
        """
        self._hf_token = hf_token or os.environ.get("HF_TOKEN")
        self._default_model = default_model or get_default_model()

    def _get_client(self):
        """Build a ``huggingface_hub.InferenceClient`` for the stored token.

        Raises:
            ValueError: if no token was provided or found in the environment.
        """
        # Imported lazily so the module can be imported without huggingface_hub.
        from huggingface_hub import InferenceClient

        token = self._hf_token
        if not token:
            raise ValueError(
                "HF_TOKEN is required. Set the HF_TOKEN environment variable "
                "or pass hf_token= to LLMClient()."
            )
        return InferenceClient(token=token)

    def _prepare_request(
        self,
        model: Optional[str],
        images: Optional[List[str]],
    ) -> Tuple[str, Optional[List[str]]]:
        """Resolve the model alias and drop images the model cannot accept."""
        resolved_model = resolve_model_alias(model or self._default_model)
        supports_images = get_model_info(resolved_model).get("supports_images", False)
        if images and not supports_images:
            logger.warning(
                "Model '%s' does not support images. Images will be ignored.", resolved_model
            )
            images = None
        return resolved_model, images

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def call(
        self,
        prompt: str,
        schema: Type[T],
        model: Optional[str] = None,
        system_prompt: Optional[str] = None,
        images: Optional[List[str]] = None,
        temperature: float = 0.7,
        max_tokens: int = 8192,
    ) -> T:
        """
        Call LLM and return a validated Pydantic model instance.

        Args:
            prompt: User prompt text.
            schema: Pydantic model class used for response validation.
            model: Model key or alias (uses default if not specified).
            system_prompt: Optional system prompt.
            images: Optional list of base64-encoded images (for vision models).
            temperature: Sampling temperature.
            max_tokens: Maximum tokens to generate.

        Returns:
            Validated Pydantic model instance.

        Raises:
            ValueError: if the response contains no parseable JSON, or if no
                HF token is configured.
        """
        resolved_model, images = self._prepare_request(model, images)

        # Build JSON-aware prompt
        schema_str = json.dumps(schema.model_json_schema(), ensure_ascii=False, indent=2)
        json_instruction = (
            f"\n\nRespond ONLY with a valid JSON object matching this schema "
            f"(do not include any text outside the JSON):\n{schema_str}"
        )

        text = self._call_hf(
            prompt=prompt + json_instruction,
            model=resolved_model,
            system_prompt=system_prompt,
            images=images,
            temperature=temperature,
            max_tokens=max_tokens,
        )

        json_text = _extract_json_from_text(text)
        try:
            data = json.loads(json_text)
        except json.JSONDecodeError as exc:
            raise ValueError(
                f"Model '{resolved_model}' returned non-JSON response: {text[:300]}"
            ) from exc
        return schema.model_validate(data)

    def call_raw(
        self,
        prompt: str,
        model: Optional[str] = None,
        system_prompt: Optional[str] = None,
        images: Optional[List[str]] = None,
        temperature: float = 0.7,
        max_tokens: int = 8192,
    ) -> str:
        """Call LLM and return the raw text response."""
        resolved_model, images = self._prepare_request(model, images)
        return self._call_hf(
            prompt=prompt,
            model=resolved_model,
            system_prompt=system_prompt,
            images=images,
            temperature=temperature,
            max_tokens=max_tokens,
        )

    # ------------------------------------------------------------------
    # Internal HF caller
    # ------------------------------------------------------------------

    @staticmethod
    def _image_url(image: str) -> Optional[str]:
        """Turn a base64 string or data URI into a data URI for the request.

        Bare base64 is labelled ``image/png``. A ``data:image/...`` URI is
        passed through untouched so its declared type is preserved. Returns
        ``None`` for a ``data:`` string with no comma or no payload.
        """
        if not image.startswith("data:"):
            return f"data:image/png;base64,{image}"
        header, sep, payload = image.partition(",")
        if not sep or not payload:
            return None
        if header.startswith("data:image/"):
            return image
        # Non-image media type: keep the payload, label it PNG as before.
        return f"data:image/png;base64,{payload}"

    def _call_hf(
        self,
        prompt: str,
        model: str,
        system_prompt: Optional[str] = None,
        images: Optional[List[str]] = None,
        temperature: float = 0.7,
        max_tokens: int = 8192,
    ) -> str:
        """Send a chat-completions request to the HF Inference API.

        Returns the first choice's message content ("" when it is empty).
        Token usage is logged when the response carries a ``usage`` attribute.
        """
        client = self._get_client()

        messages: List[Dict[str, Any]] = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})

        # Build user message content
        if images:
            content: List[Dict[str, Any]] = []
            for img_b64 in images:
                if not img_b64:
                    continue
                url = self._image_url(img_b64)
                if url is None:
                    logger.warning("[%s] Skipping malformed data-URI image", model)
                    continue
                content.append({"type": "image_url", "image_url": {"url": url}})
            content.append({"type": "text", "text": prompt})
            messages.append({"role": "user", "content": content})
        else:
            messages.append({"role": "user", "content": prompt})

        logger.info("[%s] Calling HF Inference API (messages=%d)", model, len(messages))

        response = client.chat.completions.create(
            model=model,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
        )

        content_str = response.choices[0].message.content or ""

        usage = getattr(response, "usage", None)
        if usage:
            logger.info(
                "[%s] tokens=%s (prompt=%s, completion=%s)",
                model,
                getattr(usage, "total_tokens", "?"),
                getattr(usage, "prompt_tokens", "?"),
                getattr(usage, "completion_tokens", "?"),
            )

        return content_str


# =============================================================================
# Convenience helpers
# =============================================================================

_DEFAULT_CLIENT: Optional[LLMClient] = None


def get_llm_client(
    hf_token: Optional[str] = None,
    # Legacy compat
    openai_key: Optional[str] = None,
    google_api_key: Optional[str] = None,
) -> LLMClient:
    """Get or create a shared LLMClient instance.

    The shared client is (re)built when none exists yet, or when *hf_token*
    is given and differs from the token the current client holds. Passing
    the same token again reuses the existing instance. ``openai_key`` and
    ``google_api_key`` are accepted for backward compatibility and ignored.
    """
    global _DEFAULT_CLIENT
    token_changed = (
        _DEFAULT_CLIENT is not None
        and hf_token is not None
        and hf_token != _DEFAULT_CLIENT._hf_token
    )
    if _DEFAULT_CLIENT is None or token_changed:
        _DEFAULT_CLIENT = LLMClient(hf_token=hf_token)
    return _DEFAULT_CLIENT