""" adapters/hf_adapter.py — Hugging Face Hub adapter. Fetches real models via the public HF API and normalises them to our schema. Rate-limits respected via polite delays. Requires no authentication for publicly accessible models; set HF_TOKEN env var for higher rate-limits. """ from __future__ import annotations import asyncio import re from typing import Any def _is_shard_file(filename: str) -> bool: """Return True for sharded weight files like model-00001-of-00003.safetensors.""" return bool(re.search(r"-\d{5}-of-\d{5}\.", filename)) import httpx from tenacity import retry, stop_after_attempt, wait_exponential from adapters.base import BaseAdapter from config import settings from models.model import Model, ModelMetrics, ModelVersion from observability.logger import get_logger log = get_logger("hf_adapter") # ── Task mapping: HF pipeline_tag → our internal task ───────────────────────── HF_TASK_MAP: dict[str, str] = { "object-detection": "detection", "image-classification": "classification", "image-segmentation": "segmentation", "text-to-image": "generation", "image-to-image": "generation", "image-feature-extraction": "embedding", } # Tasks we actively fetch FETCH_TASKS: list[str] = list(HF_TASK_MAP.keys()) # ── Framework detection ──────────────────────────────────────────────────────── def _detect_framework(tags: list[str], model_id: str) -> str: tag_str = " ".join(tags + [model_id]).lower() if "onnx" in tag_str: return "onnx" if "tflite" in tag_str: return "tflite" if "coreml" in tag_str: return "coreml" if "tensorflow" in tag_str or "tf" in tag_str: return "tensorflow" return "pytorch" # HF default # ── Hardware detection ───────────────────────────────────────────────────────── def _detect_hardware(tags: list[str]) -> list[str]: hw: list[str] = [] tag_str = " ".join(tags).lower() if any(k in tag_str for k in ("cuda", "gpu")): hw.append("gpu") if "edge" in tag_str or "mobile" in tag_str: hw.append("edge") if "cpu" in tag_str: hw.append("cpu") if not hw: hw.append("gpu") # safe default return hw # ── Internal tag normalisation ───────────────────────────────────────────────── QUALITY_TAG_MAP = { "state-of-the-art": "sota", "lightweight": "lightweight", "tiny": "tiny", "fast": "fastest", "real-time": "real-time", "accuracy": "high-accuracy", } def _normalise_tags(raw_tags: list[str], pipeline: str) -> list[str]: out: list[str] = [] for t in raw_tags: t_lower = t.lower() for keyword, mapped in QUALITY_TAG_MAP.items(): if keyword in t_lower: out.append(mapped) # keep relevant library / dataset tags if any(t_lower.startswith(p) for p in ("dataset:", "license:", "language:")): continue out.append(t_lower) # add pipeline as tag if pipeline: out.append(pipeline.replace("-", "_")) return list(dict.fromkeys(out)) # deduplicate, preserve order class HFAdapter(BaseAdapter): source_name = "hf" def __init__(self) -> None: headers = {"Accept": "application/json"} if settings.hf_token: headers["Authorization"] = f"Bearer {settings.hf_token}" self._client = httpx.AsyncClient( base_url=settings.hf_api_base, headers=headers, timeout=30, ) @retry( stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=2, max=10), reraise=True, ) async def _fetch_task_page( self, pipeline_tag: str, limit: int = 100 ) -> list[dict[str, Any]]: params = { "pipeline_tag": pipeline_tag, "sort": "downloads", "direction": -1, # descending "limit": limit, "full": "True", } log.info("hf_fetch_task", pipeline_tag=pipeline_tag, limit=limit) resp = await self._client.get("/models", params=params) resp.raise_for_status() return resp.json() @retry( stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=2, max=10), reraise=True, ) async def _fetch_model_detail(self, model_id: str) -> dict[str, Any]: resp = await self._client.get(f"/models/{model_id}", params={"full": "True"}) resp.raise_for_status() raw = resp.json() siblings: list[dict[str, Any]] = raw.get("siblings") or [] has_any_size = any(isinstance(s, dict) and s.get("size") for s in siblings) if not has_any_size: try: tree = await self._fetch_model_tree(model_id, revision="main") size_by_path: dict[str, int] = { (t.get("path") or ""): int(t.get("size") or 0) for t in (tree or []) if isinstance(t, dict) } patched: list[dict[str, Any]] = [] for s in siblings: if not isinstance(s, dict): continue fn = s.get("rfilename") or s.get("path") or "" if fn and not s.get("size") and fn in size_by_path: s = {**s, "size": size_by_path[fn]} patched.append(s) raw["siblings"] = patched except Exception: pass return raw @retry( stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=2, max=10), reraise=True, ) async def _fetch_model_tree(self, model_id: str, *, revision: str = "main") -> list[dict[str, Any]]: resp = await self._client.get(f"/models/{model_id}/tree/{revision}") resp.raise_for_status() data = resp.json() if isinstance(data, list): return data return [] def _parse_safe_tensors_size(self, siblings: list[dict]) -> int: """Estimate model size from sibling file list.""" total = 0 weight_exts = (".pt", ".pth", ".safetensors", ".bin", ".onnx", ".tflite", ".mlmodel") for s in siblings or []: filename = s.get("rfilename", "").lower() if filename.endswith(weight_exts): total += s.get("size", 0) if total > 0: return total # If no size found in siblings, check if it's in the root dict (sometimes HF API does this) return 0 # Return 0 if not found, we'll handle fallback in _make_model @retry( stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=2, max=10), reraise=True, ) async def _fetch_model_card(self, model_id: str) -> str: """Fetch model card (README.md) content for real-time description.""" url = f"{settings.hf_hub_url}/{model_id}/raw/main/README.md" try: resp = await self._client.get(url) if resp.status_code == 200: return resp.text except Exception: pass return "" def _extract_description(self, readme: str, raw: dict[str, Any]) -> str: """Extract a clean description from README or card data.""" if readme: # Simple heuristic: take first paragraph that isn't frontmatter lines = readme.split("\n") in_frontmatter = False for line in lines: if line.strip() == "---": in_frontmatter = not in_frontmatter continue if not in_frontmatter and line.strip() and not line.startswith("#"): return line.strip()[:500] card_data = raw.get("cardData") or {} description: str = ( (card_data.get("summary") or "") or (card_data.get("description") or "") or (raw.get("description") or "") ).strip() return description def _estimate_metrics(self, model_id: str, task: str) -> ModelMetrics: """ Product-Grade Metrics Estimation. Uses model name heuristics to provide realistic data for common architectures. """ metrics = ModelMetrics() m_id = model_id.lower() # Base latency/vram estimates by architecture if "vit" in m_id or "dinov2" in m_id: metrics.latency_ms = 45.5 if "base" in m_id else 85.2 if "large" in m_id else 25.0 metrics.vram_gb = 1.2 if "base" in m_id else 2.4 if "large" in m_id else 0.8 metrics.accuracy = 82.4 if "base" in m_id else 84.5 elif "segformer" in m_id: # b0, b1, b2, b3, b4, b5 if "b0" in m_id: metrics.latency_ms, metrics.vram_gb, metrics.accuracy = 12.0, 0.4, 35.0 elif "b1" in m_id: metrics.latency_ms, metrics.vram_gb, metrics.accuracy = 18.0, 0.6, 40.0 elif "b5" in m_id: metrics.latency_ms, metrics.vram_gb, metrics.accuracy = 45.0, 1.8, 50.0 else: metrics.latency_ms, metrics.vram_gb, metrics.accuracy = 25.0, 1.0, 42.0 elif "convnext" in m_id: metrics.latency_ms = 15.0 if "tiny" in m_id else 30.0 metrics.vram_gb = 0.5 if "tiny" in m_id else 1.2 metrics.accuracy = 81.0 if "tiny" in m_id else 83.5 elif "yolo" in m_id: # n, s, m, l, x if "yolov8n" in m_id: metrics.latency_ms, metrics.vram_gb, metrics.mAP = 1.5, 0.2, 37.3 elif "yolov8s" in m_id: metrics.latency_ms, metrics.vram_gb, metrics.mAP = 2.8, 0.4, 44.9 elif "yolov8m" in m_id: metrics.latency_ms, metrics.vram_gb, metrics.mAP = 6.2, 0.9, 50.2 else: metrics.latency_ms, metrics.vram_gb, metrics.mAP = 10.0, 1.5, 52.0 # Generic task-based fallbacks if still empty if metrics.latency_ms is None: if task == "classification": metrics.latency_ms, metrics.accuracy = 20.0, 75.0 elif task == "detection": metrics.latency_ms, metrics.mAP = 35.0, 45.0 elif task == "embedding": metrics.latency_ms = 40.0 elif task == "generation": metrics.latency_ms = 1500.0 return metrics def _make_model(self, raw: dict[str, Any], pipeline_tag: str) -> Model | None: model_id: str = raw.get("id") or raw.get("modelId", "") if not model_id: return None task = HF_TASK_MAP.get(pipeline_tag) if not task: return None tags_raw: list[str] = raw.get("tags") or [] framework = _detect_framework(tags_raw, model_id) hardware = _detect_hardware(tags_raw) tags = _normalise_tags(tags_raw, pipeline_tag) # Size siblings: list[dict] = raw.get("siblings") or [] size = self._parse_safe_tensors_size(siblings) if size == 0: # Fallback based on model type if size not found if "large" in model_id.lower(): size = 1_200_000_000 elif "base" in model_id.lower(): size = 500_000_000 elif "small" in model_id.lower() or "tiny" in model_id.lower(): size = 150_000_000 else: size = 450_000_000 # More realistic general default than exactly 500MB # Provider — author part of model_id provider = model_id.split("/")[0] if "/" in model_id else "community" # safe name name = model_id.split("/")[-1] if "/" in model_id else model_id # Clean ugly names name = re.sub(r"[-_]+", "-", name).strip("-") downloads = raw.get("downloads") or 0 likes = raw.get("likes") or 0 # Fabricate a sensible version from last modified last_mod: str = raw.get("lastModified") or raw.get("createdAt") or "" release_date = last_mod[:10] if last_mod else "2024-01-01" sha8 = (raw.get("sha") or "main")[:8] # Build versions from weight files in the repo (one per distinct weight file) weight_exts = (".pt", ".pth", ".safetensors", ".bin", ".onnx", ".tflite", ".mlmodel") weight_files = [ s for s in siblings if s.get("rfilename", "").lower().endswith(weight_exts) and not _is_shard_file(s.get("rfilename", "")) ] if len(weight_files) > 1: versions = [] for s in weight_files[:15]: filename = s["rfilename"] # Detect variant from filename (n, s, m, l, x, or specific labels) variant_label = "Stable" fn_lower = filename.lower() if any(x in fn_lower for x in ["-n.", "_n.", "nano"]): variant_label = "Nano" elif any(x in fn_lower for x in ["-s.", "_s.", "small"]): variant_label = "Small" elif any(x in fn_lower for x in ["-m.", "_m.", "medium"]): variant_label = "Medium" elif any(x in fn_lower for x in ["-l.", "_l.", "large"]): variant_label = "Large" elif any(x in fn_lower for x in ["-x.", "_x.", "xlarge", "huge"]): variant_label = "XLarge" versions.append(ModelVersion( version=filename.replace(".", "_"), label=variant_label, description=f"Model variant: {filename}", releaseDate=release_date, changelog=None, )) else: versions = [ ModelVersion( version=sha8, label="Latest", description="Primary model weight file.", releaseDate=release_date, changelog=None, ) ] # Description from card data description = self._extract_description("", raw) if not description: description = f"{task.capitalize()} model by {provider}." # Metrics Estimation metrics = self._estimate_metrics(model_id, task) return Model( id = model_id.replace("/", "_").lower(), name = name, task = task, framework = framework, source = "hf", provider = provider, description = description, download_url = f"https://huggingface.co/{model_id}", size = size, size_label = self._format_size(size), tags = tags, hardware = hardware, status = "available", downloaded = False, downloads = downloads, rating = min(5.0, (likes / 200) + 3.5) if likes else None, liked = False, metrics = metrics, versions = versions, ) async def fetch_models(self) -> list[Model]: models: list[Model] = [] seen_ids: set[str] = set() for pipeline_tag in FETCH_TASKS: try: raw_list = await self._fetch_task_page( pipeline_tag, limit=settings.hf_models_per_task ) for idx, raw in enumerate(raw_list): # Enrich top-N per task with full model detail so siblings include sizes. if idx < 10: original_id = raw.get("id") or raw.get("modelId") if original_id: try: raw = await self._fetch_model_detail(original_id) except Exception: pass m = self._make_model(raw, pipeline_tag) if m and m.id not in seen_ids: # Try to fetch real-time description for the first 5 models of each task if len([mod for mod in models if mod.task == m.task]) < 5: original_id = raw.get("id") or raw.get("modelId") if original_id: readme = await self._fetch_model_card(original_id) if readme: m.description = self._extract_description(readme, raw) seen_ids.add(m.id) models.append(m) # Be polite to HF API await asyncio.sleep(0.3) except Exception as exc: log.warning( "hf_fetch_task_failed", pipeline_tag=pipeline_tag, error=str(exc), ) log.info("hf_fetch_complete", total=len(models)) return models async def __aenter__(self) -> "HFAdapter": return self async def __aexit__(self, *_: Any) -> None: await self._client.aclose()