Spaces:
Sleeping
Sleeping
| """ | |
| Unified LLM Client for Hugging Face Inference API. | |
| Supports: | |
| - Hugging Face Inference API (text-generation, image-text-to-text) via huggingface_hub | |
| - All calls enforce JSON response format and validate with Pydantic schemas. | |
| """ | |
| import base64 | |
| import json | |
| import logging | |
| import os | |
| import re | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Optional, Type, TypeVar | |
| import yaml | |
| from pydantic import BaseModel | |
# Module-level logger. The handler is attached only when none exists yet so
# that re-importing this module does not produce duplicate log lines.
logger = logging.getLogger("llm_client")
if not logger.handlers:
    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter("%(asctime)s [%(name)s] %(levelname)s: %(message)s"))
    logger.addHandler(handler)
    logger.setLevel(logging.INFO)

# Generic type variable bound to pydantic models; LLMClient.call() returns an
# instance of whatever schema class the caller passes in.
T = TypeVar("T", bound=BaseModel)
| # ============================================================================= | |
| # Configuration Loader | |
| # ============================================================================= | |
def load_model_config() -> Dict[str, Any]:
    """Load and parse the model configuration file.

    The file is expected at ``<package root>/config/models.yaml``, resolved
    relative to this module's location.

    Returns:
        The parsed YAML document as a dict.

    Raises:
        FileNotFoundError: If the configuration file does not exist.
    """
    cfg_path = Path(__file__).parent.parent / "config" / "models.yaml"
    if not cfg_path.exists():
        raise FileNotFoundError(f"Model config not found: {cfg_path}")
    with open(cfg_path, "r", encoding="utf-8") as fh:
        return yaml.safe_load(fh)
# Process-wide cache for the parsed models.yaml; populated lazily.
_MODEL_CONFIG: Optional[Dict[str, Any]] = None


def get_model_config() -> Dict[str, Any]:
    """Return the model configuration, loading it from disk on first use."""
    global _MODEL_CONFIG
    if _MODEL_CONFIG is not None:
        return _MODEL_CONFIG
    _MODEL_CONFIG = load_model_config()
    return _MODEL_CONFIG
def resolve_model_alias(model_key: str) -> str:
    """Map a configured alias to its canonical model id.

    Keys without an alias entry are returned unchanged.
    """
    alias_map = get_model_config().get("aliases", {})
    return alias_map.get(model_key, model_key)
def get_model_info(model_key: str) -> Dict[str, Any]:
    """Look up configuration metadata for a model key or alias.

    The key is tried verbatim first, then after alias resolution. Unknown
    models fall back to a generic huggingface text-generation entry (with a
    warning) instead of raising, so callers never have to handle a miss.
    """
    models = get_model_config().get("models", {})
    for candidate in (model_key, resolve_model_alias(model_key)):
        if candidate in models:
            return models[candidate]
    # Return a sensible default for unknown models rather than raising
    logger.warning(f"Unknown model '{model_key}', using default huggingface text-generation config")
    return {"provider": "huggingface", "task": "text-generation", "supports_images": False}
def get_default_model() -> str:
    """Return the default model id.

    A non-empty ``HF_MODEL`` environment variable wins; otherwise the
    ``default_model`` entry from the configuration is used, with a
    hard-coded Llama fallback.
    """
    return os.environ.get("HF_MODEL") or get_model_config().get(
        "default_model", "meta-llama/Llama-3.3-70B-Instruct"
    )
def list_available_models() -> List[str]:
    """Return every model key declared in the configuration."""
    return [*get_model_config().get("models", {})]
def get_recommended_model(use_case: str = "default") -> str:
    """Return the first model whose ``recommended_for`` list names *use_case*.

    Falls back to the configured default model when no entry matches.
    Iteration order follows the config file, so earlier entries win.
    """
    config = get_model_config()
    matches = (
        key
        for key, info in config.get("models", {}).items()
        if use_case in info.get("recommended_for", [])
    )
    fallback = config.get("default_model", "meta-llama/Llama-3.3-70B-Instruct")
    return next(matches, fallback)
| # ============================================================================= | |
| # JSON Extraction Helpers | |
| # ============================================================================= | |
| def _extract_json_from_text(text: str) -> str: | |
| """Extract JSON object or array from a text response (strips markdown fences etc.).""" | |
| text = text.strip() | |
| # Strip ```json ... ``` or ``` ... ``` fences | |
| text = re.sub(r"^```(?:json)?\s*", "", text) | |
| text = re.sub(r"\s*```$", "", text) | |
| text = text.strip() | |
| # If the text already starts with { or [ it is likely clean JSON | |
| if text.startswith("{") or text.startswith("["): | |
| return text | |
| # Try to find first JSON object/array in the text | |
| match = re.search(r"(\{[\s\S]*\}|\[[\s\S]*\])", text) | |
| if match: | |
| return match.group(1) | |
| return text | |
| # ============================================================================= | |
| # LLM Client | |
| # ============================================================================= | |
class LLMClient:
    """
    Unified LLM client targeting the Hugging Face Inference API.

    Every request goes through the chat-completions endpoint. ``call``
    appends a JSON-schema instruction to the prompt and validates the reply
    against a Pydantic schema; ``call_raw`` returns the untouched text.

    Usage:
        from src.clients import LLMClient
        from pydantic import BaseModel

        class MyResponse(BaseModel):
            answer: str
            score: int

        client = LLMClient()
        result = client.call(
            prompt="Analyze this text...",
            schema=MyResponse,
        )
        print(result.answer, result.score)
    """

    def __init__(
        self,
        hf_token: Optional[str] = None,
        default_model: Optional[str] = None,
        # Legacy compat args (ignored but accepted so old api code keeps working)
        openai_key: Optional[str] = None,
        google_api_key: Optional[str] = None,
    ):
        """
        Args:
            hf_token: Hugging Face token; falls back to the HF_TOKEN env var.
                Presence is only validated at the first API call.
            default_model: Model key or alias used when a call passes none.
            openai_key: Ignored (legacy compatibility).
            google_api_key: Ignored (legacy compatibility).
        """
        self._hf_token = hf_token or os.environ.get("HF_TOKEN")
        self._default_model = default_model or get_default_model()

    def _get_client(self):
        """Build an InferenceClient, raising ValueError when no token is set.

        huggingface_hub is imported lazily so that merely importing this
        module does not require the dependency to be installed.
        """
        from huggingface_hub import InferenceClient

        token = self._hf_token
        if not token:
            raise ValueError(
                "HF_TOKEN is required. Set the HF_TOKEN environment variable "
                "or pass hf_token= to LLMClient()."
            )
        return InferenceClient(token=token)

    def _resolve_model_and_images(
        self,
        model: Optional[str],
        images: Optional[List[str]],
    ):
        """Resolve the model alias and drop images for non-vision models.

        This gating logic was previously duplicated verbatim in call() and
        call_raw(); it is shared here so the two paths cannot drift apart.

        Returns:
            Tuple of (resolved_model, images-or-None).
        """
        resolved_model = resolve_model_alias(model or self._default_model)
        model_info = get_model_info(resolved_model)
        if images and not model_info.get("supports_images", False):
            logger.warning(
                f"Model '{resolved_model}' does not support images. Images will be ignored."
            )
            images = None
        return resolved_model, images

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def call(
        self,
        prompt: str,
        schema: Type[T],
        model: Optional[str] = None,
        system_prompt: Optional[str] = None,
        images: Optional[List[str]] = None,
        temperature: float = 0.7,
        max_tokens: int = 8192,
    ) -> T:
        """
        Call LLM and return a validated Pydantic model instance.

        Args:
            prompt: User prompt text.
            schema: Pydantic model class used for response validation.
            model: Model key or alias (uses default if not specified).
            system_prompt: Optional system prompt.
            images: Optional list of base64-encoded images (for vision models).
            temperature: Sampling temperature.
            max_tokens: Maximum tokens to generate.

        Returns:
            Validated Pydantic model instance.

        Raises:
            ValueError: If the model's reply cannot be parsed as JSON.
        """
        resolved_model, images = self._resolve_model_and_images(model, images)
        # Embed the target JSON schema in the prompt so the model knows the
        # exact shape to produce.
        json_schema = schema.model_json_schema()
        schema_str = json.dumps(json_schema, ensure_ascii=False, indent=2)
        json_instruction = (
            f"\n\nRespond ONLY with a valid JSON object matching this schema "
            f"(do not include any text outside the JSON):\n{schema_str}"
        )
        full_prompt = prompt + json_instruction
        text = self._call_hf(
            prompt=full_prompt,
            model=resolved_model,
            system_prompt=system_prompt,
            images=images,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        json_text = _extract_json_from_text(text)
        try:
            data = json.loads(json_text)
        except json.JSONDecodeError as exc:
            raise ValueError(
                f"Model '{resolved_model}' returned non-JSON response: {text[:300]}"
            ) from exc
        return schema.model_validate(data)

    def call_raw(
        self,
        prompt: str,
        model: Optional[str] = None,
        system_prompt: Optional[str] = None,
        images: Optional[List[str]] = None,
        temperature: float = 0.7,
        max_tokens: int = 8192,
    ) -> str:
        """Call LLM and return the raw text response."""
        resolved_model, images = self._resolve_model_and_images(model, images)
        return self._call_hf(
            prompt=prompt,
            model=resolved_model,
            system_prompt=system_prompt,
            images=images,
            temperature=temperature,
            max_tokens=max_tokens,
        )

    # ------------------------------------------------------------------
    # Internal HF caller
    # ------------------------------------------------------------------
    def _call_hf(
        self,
        prompt: str,
        model: str,
        system_prompt: Optional[str] = None,
        images: Optional[List[str]] = None,
        temperature: float = 0.7,
        max_tokens: int = 8192,
    ) -> str:
        """Send a chat-completions request to the HF Inference API.

        Images are attached as data-URI ``image_url`` parts ahead of the
        text part. Token usage is logged when the API reports it.
        """
        client = self._get_client()
        messages: List[Dict[str, Any]] = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        # Build user message content
        if images:
            content: List[Dict[str, Any]] = []
            for img_b64 in images:
                if not img_b64:
                    continue
                # Strip a data-URI prefix if present. partition() tolerates a
                # malformed URI with no comma (split(",", 1)[1] used to raise
                # IndexError there); such a value is passed through unchanged.
                if img_b64.startswith("data:"):
                    _, sep, payload = img_b64.partition(",")
                    if sep:
                        img_b64 = payload
                content.append({
                    "type": "image_url",
                    "image_url": {"url": f"data:image/png;base64,{img_b64}"},
                })
            content.append({"type": "text", "text": prompt})
            messages.append({"role": "user", "content": content})
        else:
            messages.append({"role": "user", "content": prompt})
        logger.info(f"[{model}] Calling HF Inference API (messages={len(messages)})")
        response = client.chat.completions.create(
            model=model,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        content_str = response.choices[0].message.content or ""
        usage = getattr(response, "usage", None)
        if usage:
            logger.info(
                f"[{model}] tokens={getattr(usage, 'total_tokens', '?')} "
                f"(prompt={getattr(usage, 'prompt_tokens', '?')}, "
                f"completion={getattr(usage, 'completion_tokens', '?')})"
            )
        return content_str
| # ============================================================================= | |
| # Convenience helpers | |
| # ============================================================================= | |
# Lazily-created shared client instance.
_DEFAULT_CLIENT: Optional[LLMClient] = None


def get_llm_client(
    hf_token: Optional[str] = None,
    # Legacy compat
    openai_key: Optional[str] = None,
    google_api_key: Optional[str] = None,
) -> LLMClient:
    """Get or create a shared LLMClient instance.

    A fresh client replaces the cached one whenever an explicit hf_token is
    supplied; openai_key and google_api_key are accepted only for backward
    compatibility and are ignored.
    """
    global _DEFAULT_CLIENT
    must_create = _DEFAULT_CLIENT is None or hf_token is not None
    if must_create:
        _DEFAULT_CLIENT = LLMClient(hf_token=hf_token)
    return _DEFAULT_CLIENT