mlforge / adapters /hf_adapter.py
senthil2421
fix: remove project dependencies and redundant imports to fix server startup
8302f42
"""
adapters/hf_adapter.py — Hugging Face Hub adapter.
Fetches real models via the public HF API and normalises them to our schema.
Rate-limits respected via polite delays. Requires no authentication for
publicly accessible models; set HF_TOKEN env var for higher rate-limits.
"""
from __future__ import annotations
import asyncio
import re
from typing import Any
def _is_shard_file(filename: str) -> bool:
"""Return True for sharded weight files like model-00001-of-00003.safetensors."""
return bool(re.search(r"-\d{5}-of-\d{5}\.", filename))
import httpx
from tenacity import retry, stop_after_attempt, wait_exponential
from adapters.base import BaseAdapter
from config import settings
from models.model import Model, ModelMetrics, ModelVersion
from observability.logger import get_logger
# Module-level logger named "hf_adapter"; call sites in this module pass an
# event name followed by keyword context (e.g. pipeline_tag=..., error=...).
log = get_logger("hf_adapter")
# ── Task mapping: HF pipeline_tag → our internal task ─────────────────────────
# More than one Hub pipeline tag may collapse onto the same internal task
# (both image-generation tags below become "generation").
HF_TASK_MAP: dict[str, str] = dict(
    [
        ("object-detection", "detection"),
        ("image-classification", "classification"),
        ("image-segmentation", "segmentation"),
        ("text-to-image", "generation"),
        ("image-to-image", "generation"),
        ("image-feature-extraction", "embedding"),
    ]
)

# Pipeline tags that are actually requested from the Hub, in mapping order.
FETCH_TASKS: list[str] = [*HF_TASK_MAP]
# ── Framework detection ────────────────────────────────────────────────────────
def _detect_framework(tags: list[str], model_id: str) -> str:
tag_str = " ".join(tags + [model_id]).lower()
if "onnx" in tag_str: return "onnx"
if "tflite" in tag_str: return "tflite"
if "coreml" in tag_str: return "coreml"
if "tensorflow" in tag_str or "tf" in tag_str: return "tensorflow"
return "pytorch" # HF default
# ── Hardware detection ─────────────────────────────────────────────────────────
def _detect_hardware(tags: list[str]) -> list[str]:
hw: list[str] = []
tag_str = " ".join(tags).lower()
if any(k in tag_str for k in ("cuda", "gpu")): hw.append("gpu")
if "edge" in tag_str or "mobile" in tag_str: hw.append("edge")
if "cpu" in tag_str: hw.append("cpu")
if not hw: hw.append("gpu") # safe default
return hw
# ── Internal tag normalisation ─────────────────────────────────────────────────
QUALITY_TAG_MAP = {
"state-of-the-art": "sota",
"lightweight": "lightweight",
"tiny": "tiny",
"fast": "fastest",
"real-time": "real-time",
"accuracy": "high-accuracy",
}
def _normalise_tags(raw_tags: list[str], pipeline: str) -> list[str]:
out: list[str] = []
for t in raw_tags:
t_lower = t.lower()
for keyword, mapped in QUALITY_TAG_MAP.items():
if keyword in t_lower:
out.append(mapped)
# keep relevant library / dataset tags
if any(t_lower.startswith(p) for p in ("dataset:", "license:", "language:")):
continue
out.append(t_lower)
# add pipeline as tag
if pipeline:
out.append(pipeline.replace("-", "_"))
return list(dict.fromkeys(out)) # deduplicate, preserve order
class HFAdapter(BaseAdapter):
    """Adapter that lists models from the Hugging Face Hub and maps them to ``Model``.

    A single ``httpx.AsyncClient`` is created in ``__init__`` and closed in
    ``__aexit__``, so the adapter is intended to be used with ``async with``.
    """

    source_name = "hf"

    # File extensions treated as model weights, both when summing the repo
    # size and when building the version list.
    _WEIGHT_EXTS = (".pt", ".pth", ".safetensors", ".bin", ".onnx", ".tflite", ".mlmodel")

    # Only the first N entries of each task listing get a per-model detail call.
    _DETAIL_ENRICH_LIMIT = 10

    # Only the first N models kept per internal task get their README fetched.
    _README_FETCH_LIMIT = 5

    # Filename markers → variant label; the first matching row wins. "XLarge"
    # is listed before "Large" because "xlarge" contains "large" and would
    # otherwise never be reached.
    _VARIANT_MARKERS = (
        ("Nano", ("-n.", "_n.", "nano")),
        ("Small", ("-s.", "_s.", "small")),
        ("Medium", ("-m.", "_m.", "medium")),
        ("XLarge", ("-x.", "_x.", "xlarge", "huge")),
        ("Large", ("-l.", "_l.", "large")),
    )

    def __init__(self) -> None:
        """Create the HTTP client, adding a bearer token when one is configured."""
        headers = {"Accept": "application/json"}
        if settings.hf_token:
            headers["Authorization"] = f"Bearer {settings.hf_token}"
        self._client = httpx.AsyncClient(
            base_url=settings.hf_api_base,
            headers=headers,
            timeout=30,
        )

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=2, max=10),
        reraise=True,
    )
    async def _fetch_task_page(
        self, pipeline_tag: str, limit: int = 100
    ) -> list[dict[str, Any]]:
        """Fetch one page of models for *pipeline_tag*, sorted by downloads (descending).

        Args:
            pipeline_tag: Hub pipeline tag to filter on.
            limit: Maximum number of models requested.

        Returns:
            The decoded JSON body of ``GET /models``.

        Raises:
            httpx.HTTPStatusError: On a non-success status, once the three
                retry attempts are exhausted (``reraise=True``).
        """
        params = {
            "pipeline_tag": pipeline_tag,
            "sort": "downloads",
            "direction": -1,  # descending
            "limit": limit,
            "full": "True",
        }
        log.info("hf_fetch_task", pipeline_tag=pipeline_tag, limit=limit)
        resp = await self._client.get("/models", params=params)
        resp.raise_for_status()
        return resp.json()

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=2, max=10),
        reraise=True,
    )
    async def _fetch_model_detail(self, model_id: str) -> dict[str, Any]:
        """Fetch the full record for one model and back-fill missing file sizes.

        When none of the returned ``siblings`` carries a truthy ``size``, the
        repository tree for ``main`` is fetched and its sizes are copied onto
        the siblings with a matching path. That back-fill is best-effort: a
        failure is logged and the record is returned unpatched.

        Raises:
            httpx.HTTPStatusError: If the detail request itself fails after
                the retry attempts are exhausted.
        """
        resp = await self._client.get(f"/models/{model_id}", params={"full": "True"})
        resp.raise_for_status()
        raw = resp.json()

        siblings: list[dict[str, Any]] = raw.get("siblings") or []
        has_any_size = any(isinstance(s, dict) and s.get("size") for s in siblings)
        if has_any_size:
            return raw

        try:
            tree = await self._fetch_model_tree(model_id, revision="main")
            size_by_path: dict[str, int] = {
                (entry.get("path") or ""): int(entry.get("size") or 0)
                for entry in (tree or [])
                if isinstance(entry, dict)
            }
            patched: list[dict[str, Any]] = []
            for sibling in siblings:
                if not isinstance(sibling, dict):
                    continue
                filename = sibling.get("rfilename") or sibling.get("path") or ""
                if filename and not sibling.get("size") and filename in size_by_path:
                    # Copy rather than mutate the decoded response entry.
                    sibling = {**sibling, "size": size_by_path[filename]}
                patched.append(sibling)
            raw["siblings"] = patched
        except Exception as exc:
            # Sizes are a nice-to-have; keep the record and report the failure.
            log.warning("hf_model_tree_failed", model_id=model_id, error=str(exc))
        return raw

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=2, max=10),
        reraise=True,
    )
    async def _fetch_model_tree(
        self, model_id: str, *, revision: str = "main"
    ) -> list[dict[str, Any]]:
        """Return the file tree of *model_id* at *revision*.

        Returns:
            The decoded JSON when it is a list, otherwise an empty list.

        Raises:
            httpx.HTTPStatusError: On a non-success status after retries.
        """
        resp = await self._client.get(f"/models/{model_id}/tree/{revision}")
        resp.raise_for_status()
        data = resp.json()
        return data if isinstance(data, list) else []

    def _parse_safe_tensors_size(self, siblings: list[dict]) -> int:
        """Estimate the model size by summing the sizes of its weight files.

        Args:
            siblings: File entries of the repo; ``None`` or empty is accepted.

        Returns:
            Total size of all entries whose name ends with a weight extension,
            or ``0`` when no size is known (the caller applies a name-based
            fallback in that case).
        """
        total = 0
        for entry in siblings or []:
            if not isinstance(entry, dict):
                continue
            # Both keys may be present with a None value, not just absent.
            filename = (entry.get("rfilename") or "").lower()
            if filename.endswith(self._WEIGHT_EXTS):
                total += entry.get("size") or 0
        return total

    async def _fetch_model_card(self, model_id: str) -> str:
        """Fetch the model card (README.md) text for *model_id*.

        This is best-effort and never raises: any error is logged and an empty
        string is returned, as it is for a non-200 response. Because nothing
        propagates out of this method, a retry decorator would never trigger,
        so none is applied.
        """
        url = f"{settings.hf_hub_url}/{model_id}/raw/main/README.md"
        try:
            resp = await self._client.get(url)
            if resp.status_code == 200:
                return resp.text
        except Exception as exc:
            log.warning("hf_model_card_failed", model_id=model_id, error=str(exc))
        return ""

    def _extract_description(self, readme: str, raw: dict[str, Any]) -> str:
        """Extract a short description from the README or, failing that, the card data.

        From the README, the first non-empty line that is neither inside the
        leading ``---`` front-matter block nor a ``#`` heading is returned,
        truncated to 500 characters. A ``---`` appearing after real content is
        treated as a horizontal rule rather than as the start of front matter.

        If the README yields nothing, the first non-empty value among
        ``cardData.summary``, ``cardData.description`` and ``description`` is
        returned (stripped); the result may be an empty string.
        """
        if readme:
            in_frontmatter = False
            seen_content = False
            for line in readme.split("\n"):
                stripped = line.strip()
                if stripped == "---":
                    if in_frontmatter:
                        in_frontmatter = False
                    elif not seen_content:
                        # Front matter can only open before any real content.
                        in_frontmatter = True
                    continue
                if in_frontmatter or not stripped:
                    continue
                seen_content = True
                if not line.startswith("#"):
                    return stripped[:500]

        card_data = raw.get("cardData") or {}
        description: str = (
            (card_data.get("summary") or "")
            or (card_data.get("description") or "")
            or (raw.get("description") or "")
        ).strip()
        return description

    def _estimate_metrics(self, model_id: str, task: str) -> ModelMetrics:
        """Fill a ``ModelMetrics`` with hard-coded estimates keyed on the model name.

        The numbers are fixed constants chosen by substring matches on the
        lower-cased *model_id* (architecture family, then size variant); they
        are not measured. When no architecture matched (``latency_ms`` still
        ``None``), a generic per-task fallback is applied for classification,
        detection, embedding and generation.
        """
        metrics = ModelMetrics()
        m_id = model_id.lower()

        # Base latency/vram estimates by architecture
        if "vit" in m_id or "dinov2" in m_id:
            if "base" in m_id:
                metrics.latency_ms, metrics.vram_gb = 45.5, 1.2
            elif "large" in m_id:
                metrics.latency_ms, metrics.vram_gb = 85.2, 2.4
            else:
                metrics.latency_ms, metrics.vram_gb = 25.0, 0.8
            metrics.accuracy = 82.4 if "base" in m_id else 84.5
        elif "segformer" in m_id:
            # b0, b1, b2, b3, b4, b5
            if "b0" in m_id:
                metrics.latency_ms, metrics.vram_gb, metrics.accuracy = 12.0, 0.4, 35.0
            elif "b1" in m_id:
                metrics.latency_ms, metrics.vram_gb, metrics.accuracy = 18.0, 0.6, 40.0
            elif "b5" in m_id:
                metrics.latency_ms, metrics.vram_gb, metrics.accuracy = 45.0, 1.8, 50.0
            else:
                metrics.latency_ms, metrics.vram_gb, metrics.accuracy = 25.0, 1.0, 42.0
        elif "convnext" in m_id:
            if "tiny" in m_id:
                metrics.latency_ms, metrics.vram_gb, metrics.accuracy = 15.0, 0.5, 81.0
            else:
                metrics.latency_ms, metrics.vram_gb, metrics.accuracy = 30.0, 1.2, 83.5
        elif "yolo" in m_id:
            # n, s, m, l, x
            if "yolov8n" in m_id:
                metrics.latency_ms, metrics.vram_gb, metrics.mAP = 1.5, 0.2, 37.3
            elif "yolov8s" in m_id:
                metrics.latency_ms, metrics.vram_gb, metrics.mAP = 2.8, 0.4, 44.9
            elif "yolov8m" in m_id:
                metrics.latency_ms, metrics.vram_gb, metrics.mAP = 6.2, 0.9, 50.2
            else:
                metrics.latency_ms, metrics.vram_gb, metrics.mAP = 10.0, 1.5, 52.0

        # Generic task-based fallbacks if still empty
        if metrics.latency_ms is None:
            if task == "classification":
                metrics.latency_ms, metrics.accuracy = 20.0, 75.0
            elif task == "detection":
                metrics.latency_ms, metrics.mAP = 35.0, 45.0
            elif task == "embedding":
                metrics.latency_ms = 40.0
            elif task == "generation":
                metrics.latency_ms = 1500.0
        return metrics

    def _variant_label(self, filename: str) -> str:
        """Map a weight filename to a size-variant label, defaulting to ``"Stable"``."""
        lowered = filename.lower()
        for label, markers in self._VARIANT_MARKERS:
            if any(marker in lowered for marker in markers):
                return label
        return "Stable"

    def _make_model(self, raw: dict[str, Any], pipeline_tag: str) -> Model | None:
        """Convert one raw Hub record into a ``Model``.

        Args:
            raw: Model record as returned by the Hub API.
            pipeline_tag: The pipeline tag the record was listed under.

        Returns:
            The populated ``Model``, or ``None`` when the record has no id or
            *pipeline_tag* is not present in ``HF_TASK_MAP``.
        """
        model_id: str = raw.get("id") or raw.get("modelId", "")
        if not model_id:
            return None
        task = HF_TASK_MAP.get(pipeline_tag)
        if not task:
            return None

        tags_raw: list[str] = raw.get("tags") or []
        framework = _detect_framework(tags_raw, model_id)
        hardware = _detect_hardware(tags_raw)
        tags = _normalise_tags(tags_raw, pipeline_tag)

        # Size
        siblings: list[dict] = raw.get("siblings") or []
        size = self._parse_safe_tensors_size(siblings)
        if size == 0:
            # Fallback based on size words in the model id when no file size is known
            id_lower = model_id.lower()
            if "large" in id_lower:
                size = 1_200_000_000
            elif "base" in id_lower:
                size = 500_000_000
            elif "small" in id_lower or "tiny" in id_lower:
                size = 150_000_000
            else:
                size = 450_000_000

        # Provider — author part of model_id
        provider = model_id.split("/")[0] if "/" in model_id else "community"
        # Display name: repo part of the id, with runs of "-"/"_" collapsed to one "-"
        name = model_id.split("/")[-1] if "/" in model_id else model_id
        name = re.sub(r"[-_]+", "-", name).strip("-")

        downloads = raw.get("downloads") or 0
        likes = raw.get("likes") or 0

        # Fabricate a sensible version from last modified
        last_mod: str = raw.get("lastModified") or raw.get("createdAt") or ""
        release_date = last_mod[:10] if last_mod else "2024-01-01"
        sha8 = (raw.get("sha") or "main")[:8]

        # Build versions from weight files in the repo (one per distinct weight
        # file); shards of a split checkpoint are not listed individually.
        weight_files = [
            s for s in siblings
            if isinstance(s, dict)
            and (s.get("rfilename") or "").lower().endswith(self._WEIGHT_EXTS)
            and not _is_shard_file(s.get("rfilename") or "")
        ]
        if len(weight_files) > 1:
            versions = [
                ModelVersion(
                    version=s["rfilename"].replace(".", "_"),
                    label=self._variant_label(s["rfilename"]),
                    description=f"Model variant: {s['rfilename']}",
                    releaseDate=release_date,
                    changelog=None,
                )
                for s in weight_files[:15]  # cap the number of listed variants
            ]
        else:
            versions = [
                ModelVersion(
                    version=sha8,
                    label="Latest",
                    description="Primary model weight file.",
                    releaseDate=release_date,
                    changelog=None,
                )
            ]

        # Description from card data
        description = self._extract_description("", raw)
        if not description:
            description = f"{task.capitalize()} model by {provider}."

        # Metrics Estimation
        metrics = self._estimate_metrics(model_id, task)

        return Model(
            id=model_id.replace("/", "_").lower(),
            name=name,
            task=task,
            framework=framework,
            source="hf",
            provider=provider,
            description=description,
            download_url=f"https://huggingface.co/{model_id}",
            size=size,
            size_label=self._format_size(size),
            tags=tags,
            hardware=hardware,
            status="available",
            downloaded=False,
            downloads=downloads,
            # Rating starts at 3.5 and gains one point per 200 likes, capped at 5.0.
            rating=min(5.0, (likes / 200) + 3.5) if likes else None,
            liked=False,
            metrics=metrics,
            versions=versions,
        )

    async def fetch_models(self) -> list[Model]:
        """Fetch and normalise models for every pipeline tag in ``FETCH_TASKS``.

        For each tag one listing page is fetched. The first
        ``_DETAIL_ENRICH_LIMIT`` entries are replaced by their full detail
        record, and the first ``_README_FETCH_LIMIT`` models kept per internal
        task get a README-based description when one can be extracted. Models
        are de-duplicated by ``Model.id`` across tags.

        A failure while processing one tag is logged and does not stop the
        remaining tags; this method itself does not raise for such failures.

        Returns:
            All models collected, in fetch order.
        """
        models: list[Model] = []
        seen_ids: set[str] = set()
        kept_per_task: dict[str, int] = {}

        for pipeline_tag in FETCH_TASKS:
            try:
                raw_list = await self._fetch_task_page(
                    pipeline_tag, limit=settings.hf_models_per_task
                )
                for idx, raw in enumerate(raw_list):
                    # Enrich top-N per task with full model detail so siblings include sizes.
                    if idx < self._DETAIL_ENRICH_LIMIT:
                        original_id = raw.get("id") or raw.get("modelId")
                        if original_id:
                            try:
                                raw = await self._fetch_model_detail(original_id)
                            except Exception as exc:
                                # Keep the listing record when enrichment fails.
                                log.warning(
                                    "hf_model_detail_failed",
                                    model_id=original_id,
                                    error=str(exc),
                                )

                    m = self._make_model(raw, pipeline_tag)
                    if not m or m.id in seen_ids:
                        continue

                    # Try to fetch a README description for the first few models of each task
                    if kept_per_task.get(m.task, 0) < self._README_FETCH_LIMIT:
                        original_id = raw.get("id") or raw.get("modelId")
                        if original_id:
                            readme = await self._fetch_model_card(original_id)
                            if readme:
                                extracted = self._extract_description(readme, raw)
                                # Never replace the existing description with an empty one.
                                if extracted:
                                    m.description = extracted

                    seen_ids.add(m.id)
                    models.append(m)
                    kept_per_task[m.task] = kept_per_task.get(m.task, 0) + 1

                # Be polite to HF API: short pause before the next pipeline tag
                await asyncio.sleep(0.3)
            except Exception as exc:
                log.warning(
                    "hf_fetch_task_failed",
                    pipeline_tag=pipeline_tag,
                    error=str(exc),
                )

        log.info("hf_fetch_complete", total=len(models))
        return models

    async def __aenter__(self) -> "HFAdapter":
        """Enter the async context; returns the adapter itself."""
        return self

    async def __aexit__(self, *_: Any) -> None:
        """Close the underlying HTTP client on context exit."""
        await self._client.aclose()