# Source: api_light_hf / src/clients/llm_client.py
# (Hugging Face Space deploy "api_light_hf", 2026-03-12 12:47:03, commit cf7f643)
"""
Unified LLM Client for Hugging Face Inference API.
Supports:
- Hugging Face Inference API (text-generation, image-text-to-text) via huggingface_hub
- All calls enforce JSON response format and validate with Pydantic schemas.
"""
import base64
import json
import logging
import os
import re
from pathlib import Path
from typing import Any, Dict, List, Optional, Type, TypeVar
import yaml
from pydantic import BaseModel
# Module-level logger. The handler is attached only when none exists yet,
# so re-importing the module does not produce duplicate log lines.
logger = logging.getLogger("llm_client")
if not logger.handlers:
    _handler = logging.StreamHandler()
    _handler.setFormatter(
        logging.Formatter("%(asctime)s [%(name)s] %(levelname)s: %(message)s")
    )
    logger.addHandler(_handler)
    logger.setLevel(logging.INFO)

# Generic type variable for schema-validated LLM responses.
T = TypeVar("T", bound=BaseModel)
# =============================================================================
# Configuration Loader
# =============================================================================
def load_model_config() -> Dict[str, Any]:
    """Load the model configuration from ``src/config/models.yaml``.

    Returns:
        The parsed YAML mapping. An empty or blank YAML file yields ``{}``
        rather than ``None`` so callers can safely chain ``.get(...)``.

    Raises:
        FileNotFoundError: If the config file does not exist.
    """
    config_path = Path(__file__).parent.parent / "config" / "models.yaml"
    if not config_path.exists():
        raise FileNotFoundError(f"Model config not found: {config_path}")
    with open(config_path, "r", encoding="utf-8") as f:
        data = yaml.safe_load(f)
    # safe_load returns None for an empty document; normalize to a dict so
    # get_model_config()/resolve_model_alias() never hit AttributeError.
    return data or {}
# Process-wide cache for the parsed model config (populated lazily).
_MODEL_CONFIG: Optional[Dict[str, Any]] = None


def get_model_config() -> Dict[str, Any]:
    """Return the model configuration, loading and caching it on first use."""
    global _MODEL_CONFIG
    if _MODEL_CONFIG is not None:
        return _MODEL_CONFIG
    _MODEL_CONFIG = load_model_config()
    return _MODEL_CONFIG
def resolve_model_alias(model_key: str) -> str:
    """Map an alias to its canonical model key; unknown keys pass through."""
    alias_table = get_model_config().get("aliases", {})
    return alias_table.get(model_key, model_key)
def get_model_info(model_key: str) -> Dict[str, Any]:
    """Look up a model's config entry, trying the key directly, then its alias.

    Unknown models get a default huggingface text-generation entry (with a
    warning) instead of raising, so callers never have to handle a miss.
    """
    models = get_model_config().get("models", {})
    direct = models.get(model_key)
    if direct is not None:
        return direct
    aliased = models.get(resolve_model_alias(model_key))
    if aliased is not None:
        return aliased
    logger.warning(f"Unknown model '{model_key}', using default huggingface text-generation config")
    return {"provider": "huggingface", "task": "text-generation", "supports_images": False}
def get_default_model() -> str:
    """Return the default model key: HF_MODEL env var wins over the config."""
    return os.environ.get("HF_MODEL") or get_model_config().get(
        "default_model", "meta-llama/Llama-3.3-70B-Instruct"
    )
def list_available_models() -> List[str]:
    """Return the keys of all models declared in the config."""
    return [key for key in get_model_config().get("models", {})]
def get_recommended_model(use_case: str = "default") -> str:
    """Return the first model whose ``recommended_for`` list contains *use_case*.

    Falls back to the configured default model when no entry matches.
    """
    config = get_model_config()
    matches = (
        key
        for key, info in config.get("models", {}).items()
        if use_case in info.get("recommended_for", [])
    )
    return next(matches, config.get("default_model", "meta-llama/Llama-3.3-70B-Instruct"))
# =============================================================================
# JSON Extraction Helpers
# =============================================================================
def _extract_json_from_text(text: str) -> str:
"""Extract JSON object or array from a text response (strips markdown fences etc.)."""
text = text.strip()
# Strip ```json ... ``` or ``` ... ``` fences
text = re.sub(r"^```(?:json)?\s*", "", text)
text = re.sub(r"\s*```$", "", text)
text = text.strip()
# If the text already starts with { or [ it is likely clean JSON
if text.startswith("{") or text.startswith("["):
return text
# Try to find first JSON object/array in the text
match = re.search(r"(\{[\s\S]*\}|\[[\s\S]*\])", text)
if match:
return match.group(1)
return text
# =============================================================================
# LLM Client
# =============================================================================
class LLMClient:
    """
    Unified LLM client targeting the Hugging Face Inference API.

    Structured calls (``call``) append the Pydantic JSON schema to the prompt,
    extract JSON from the response, and validate it against the schema.

    Usage:
        from src.clients import LLMClient
        from pydantic import BaseModel

        class MyResponse(BaseModel):
            answer: str
            score: int

        client = LLMClient()
        result = client.call(
            prompt="Analyze this text...",
            schema=MyResponse,
        )
        print(result.answer, result.score)
    """

    def __init__(
        self,
        hf_token: Optional[str] = None,
        default_model: Optional[str] = None,
        # Legacy compat args (ignored but accepted so old api code keeps working)
        openai_key: Optional[str] = None,
        google_api_key: Optional[str] = None,
    ):
        """
        Args:
            hf_token: Hugging Face API token; falls back to the HF_TOKEN env var.
            default_model: Model key/alias used when a call omits ``model``.
            openai_key: Ignored (legacy compatibility).
            google_api_key: Ignored (legacy compatibility).
        """
        self._hf_token = hf_token or os.environ.get("HF_TOKEN")
        self._default_model = default_model or get_default_model()

    def _get_client(self):
        """Build an ``InferenceClient``, validating that a token is available.

        Raises:
            ValueError: If no token was passed and HF_TOKEN is not set.
        """
        from huggingface_hub import InferenceClient

        token = self._hf_token
        if not token:
            raise ValueError(
                "HF_TOKEN is required. Set the HF_TOKEN environment variable "
                "or pass hf_token= to LLMClient()."
            )
        return InferenceClient(token=token)

    def _filter_images(
        self, resolved_model: str, images: Optional[List[str]]
    ) -> Optional[List[str]]:
        """Drop images (with a warning) when the model lacks vision support.

        Shared by ``call`` and ``call_raw`` so the guard stays in one place.
        """
        if images and not get_model_info(resolved_model).get("supports_images", False):
            logger.warning(
                f"Model '{resolved_model}' does not support images. Images will be ignored."
            )
            return None
        return images

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def call(
        self,
        prompt: str,
        schema: Type[T],
        model: Optional[str] = None,
        system_prompt: Optional[str] = None,
        images: Optional[List[str]] = None,
        temperature: float = 0.7,
        max_tokens: int = 8192,
    ) -> T:
        """
        Call LLM and return a validated Pydantic model instance.

        Args:
            prompt: User prompt text.
            schema: Pydantic model class used for response validation.
            model: Model key or alias (uses default if not specified).
            system_prompt: Optional system prompt.
            images: Optional list of base64-encoded images (for vision models).
            temperature: Sampling temperature.
            max_tokens: Maximum tokens to generate.

        Returns:
            Validated Pydantic model instance.

        Raises:
            ValueError: If the model's response cannot be parsed as JSON.
        """
        resolved_model = resolve_model_alias(model or self._default_model)
        images = self._filter_images(resolved_model, images)

        # Append the JSON schema so the model knows the exact output contract.
        json_schema = schema.model_json_schema()
        schema_str = json.dumps(json_schema, ensure_ascii=False, indent=2)
        json_instruction = (
            f"\n\nRespond ONLY with a valid JSON object matching this schema "
            f"(do not include any text outside the JSON):\n{schema_str}"
        )
        full_prompt = prompt + json_instruction

        text = self._call_hf(
            prompt=full_prompt,
            model=resolved_model,
            system_prompt=system_prompt,
            images=images,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        json_text = _extract_json_from_text(text)
        try:
            data = json.loads(json_text)
        except json.JSONDecodeError as exc:
            raise ValueError(
                f"Model '{resolved_model}' returned non-JSON response: {text[:300]}"
            ) from exc
        return schema.model_validate(data)

    def call_raw(
        self,
        prompt: str,
        model: Optional[str] = None,
        system_prompt: Optional[str] = None,
        images: Optional[List[str]] = None,
        temperature: float = 0.7,
        max_tokens: int = 8192,
    ) -> str:
        """Call LLM and return the raw text response."""
        resolved_model = resolve_model_alias(model or self._default_model)
        images = self._filter_images(resolved_model, images)
        return self._call_hf(
            prompt=prompt,
            model=resolved_model,
            system_prompt=system_prompt,
            images=images,
            temperature=temperature,
            max_tokens=max_tokens,
        )

    # ------------------------------------------------------------------
    # Internal HF caller
    # ------------------------------------------------------------------
    def _call_hf(
        self,
        prompt: str,
        model: str,
        system_prompt: Optional[str] = None,
        images: Optional[List[str]] = None,
        temperature: float = 0.7,
        max_tokens: int = 8192,
    ) -> str:
        """Send a chat-completions request to the HF Inference API.

        Images (base64 strings, with or without a data-URI prefix) are sent
        as ``image_url`` content parts ahead of the text part; callers are
        expected to have filtered images for non-vision models already.
        """
        client = self._get_client()
        messages: List[Dict[str, Any]] = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        # Build user message content
        if images:
            content: List[Dict[str, Any]] = []
            for img_b64 in images:
                if not img_b64:
                    continue
                # Strip data-URI prefix if present
                if img_b64.startswith("data:"):
                    img_b64 = img_b64.split(",", 1)[1]
                content.append({
                    "type": "image_url",
                    "image_url": {"url": f"data:image/png;base64,{img_b64}"},
                })
            content.append({"type": "text", "text": prompt})
            messages.append({"role": "user", "content": content})
        else:
            messages.append({"role": "user", "content": prompt})
        logger.info(f"[{model}] Calling HF Inference API (messages={len(messages)})")
        response = client.chat.completions.create(
            model=model,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        content_str = response.choices[0].message.content or ""
        usage = getattr(response, "usage", None)
        if usage:
            logger.info(
                f"[{model}] tokens={getattr(usage, 'total_tokens', '?')} "
                f"(prompt={getattr(usage, 'prompt_tokens', '?')}, "
                f"completion={getattr(usage, 'completion_tokens', '?')})"
            )
        return content_str
# =============================================================================
# Convenience helpers
# =============================================================================
# Lazily-created shared client; rebuilt whenever an explicit token is passed.
_DEFAULT_CLIENT: Optional[LLMClient] = None


def get_llm_client(
    hf_token: Optional[str] = None,
    # Legacy compat
    openai_key: Optional[str] = None,
    google_api_key: Optional[str] = None,
) -> LLMClient:
    """Get or create a shared LLMClient instance.

    Passing ``hf_token`` always rebuilds the shared client with that token;
    the legacy key arguments are accepted and ignored.
    """
    global _DEFAULT_CLIENT
    needs_rebuild = _DEFAULT_CLIENT is None or hf_token is not None
    if needs_rebuild:
        _DEFAULT_CLIENT = LLMClient(hf_token=hf_token)
    return _DEFAULT_CLIENT