# NOTE: hosting-page scrape artifact ("Spaces / Running on Zero") removed;
# the actual script starts at the shebang below.
| #!/usr/bin/env python3 | |
| """Build a LoRA catalog for a Hugging Face user. | |
| Usage: | |
| python scripts/update_loras_catalog.py --author artificialguybr --output loras.json | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import re | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Any | |
| import requests | |
# Hugging Face Hub REST endpoint that lists models (supports ?author= filtering).
HF_API_MODELS = "https://huggingface.co/api/models"
# File extensions accepted when choosing a repo cover image.
IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".webp")
@dataclass
class LoraEntry:
    """One catalog row describing a LoRA adapter repo on the Hugging Face Hub.

    Bug fix: the ``@dataclass`` decorator was missing, yet ``build_catalog``
    constructs instances with keyword arguments — without the decorator the
    class has no generated ``__init__`` and every construction raises.
    """

    title: str         # human-readable name derived from the repo id
    repo: str          # full repo id, e.g. "author/my-lora"
    trigger_word: str  # activation phrase, "" when unknown
    family: str        # model family bucket ("sdxl", "flux", ...)
    base_model: str    # base model the adapter targets, "" when unknown
    image: str         # cover image URL, "" when none found
    weight_name: str   # .safetensors filename inside the repo, "" when none

    def as_dict(self) -> dict[str, Any]:
        """Return a JSON-serializable dict in the catalog's field order."""
        return {
            "title": self.title,
            "repo": self.repo,
            "trigger_word": self.trigger_word,
            "family": self.family,
            "base_model": self.base_model,
            "image": self.image,
            "weight_name": self.weight_name,
        }
def parse_args() -> argparse.Namespace:
    """Parse the command-line options for the catalog builder."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--author", required=True, help="HF username/org")
    cli.add_argument("--output", default="loras.json", help="Output JSON path")
    return cli.parse_args()
def load_existing_triggers(path: Path) -> dict[str, str]:
    """Read a previously written catalog and return ``{repo: trigger_word}``.

    Used to preserve already-known trigger words across runs. Best-effort by
    design: a missing, unparsable, or structurally unexpected file yields {}.

    Robustness fix: the original guarded only ``json.loads`` — a JSON file
    holding a dict/scalar, or a list containing non-dict items, crashed the
    loop with AttributeError. Such shapes are now skipped.
    """
    if not path.exists():
        return {}
    try:
        content = json.loads(path.read_text(encoding="utf-8"))
    except Exception:
        return {}
    if not isinstance(content, list):
        # Hand-edited or stale file; ignore rather than crash.
        return {}
    triggers: dict[str, str] = {}
    for item in content:
        if not isinstance(item, dict):
            continue
        repo = str(item.get("repo", "")).strip()
        trigger = str(item.get("trigger_word", "")).strip()
        if repo and trigger:
            triggers[repo] = trigger
    return triggers
def paginated_models(author: str) -> list[dict[str, Any]]:
    """Fetch every model record owned by *author* from the HF Hub API.

    Follows RFC 5988 ``Link: <url>; rel="next"`` pagination headers until
    exhausted. Raises ``requests.HTTPError`` on a non-2xx response.

    Fix: the original took the FIRST link in the header unconditionally
    (``split(";")[0]``); if the server ever sends several comma-separated
    links (e.g. rel="prev" first), pagination followed the wrong URL.
    """
    models: list[dict[str, Any]] = []
    url = HF_API_MODELS
    # Query params apply only to the first request; the "next" URL returned
    # by the server already embeds them.
    params: dict[str, Any] | None = {"author": author, "full": "true", "limit": 100}
    while True:
        response = requests.get(url, params=params, timeout=60)
        response.raise_for_status()
        models.extend(response.json())
        next_url = _next_page_url(response.headers.get("Link", ""))
        if next_url is None:
            break
        url = next_url
        params = None
    return models


def _next_page_url(link_header: str) -> str | None:
    """Return the URL of the ``rel="next"`` link in a Link header, if any."""
    for part in link_header.split(","):
        segments = part.split(";")
        if any('rel="next"' in seg for seg in segments[1:]):
            return segments[0].strip().strip("<>")
    return None
def extract_base_model(tags: list[str]) -> str:
    """Pull the base-model repo id out of HF tag strings.

    Adapter-specific tags ("base_model:adapter:<repo>") take precedence over
    plain "base_model:<repo>" tags; returns "" when neither is present.
    """
    for prefix in ("base_model:adapter:", "base_model:"):
        for candidate in tags:
            if candidate.startswith(prefix):
                return candidate[len(prefix):]
    return ""
# Family classification rules, in precedence order. Each needle is matched
# as a lowercase substring.
_FAMILY_RULES: tuple[tuple[str, tuple[str, ...]], ...] = (
    ("sdxl", ("stable-diffusion-xl", "sdxl")),
    ("sd15", ("stable-diffusion-v1-5", "sd 1.5", "sd1.5", "sd-1-5")),
    ("qwen-image", ("qwen-image", "qwen image")),
    ("z-image", ("z-image", "zimage")),
    ("flux", ("flux",)),
)


def detect_family(base_model: str, repo_id: str, tags: list[str]) -> str:
    """Classify a LoRA into a model family.

    The declared base model is the most reliable signal and is checked
    first; the repo id plus tags are only a fallback. Returns "other" when
    nothing matches.

    Refactor: the original duplicated the same five-family check verbatim
    for both haystacks; the shared rules table keeps the exact precedence
    (all families against base first, then against repo-id/tags text).
    """
    text = " ".join([repo_id.lower(), *[t.lower() for t in tags]])
    for haystack in (base_model.lower(), text):
        for family, needles in _FAMILY_RULES:
            if any(needle in haystack for needle in needles):
                return family
    return "other"
def is_t2i_lora(model: dict[str, Any]) -> bool:
    """Decide whether a Hub model record looks like a text-to-image LoRA.

    Requires the "text-to-image" pipeline tag, plus either any tag
    containing "lora" or an adapter-style base_model tag.
    """
    if model.get("pipeline_tag") != "text-to-image":
        return False
    lowered = [str(t).lower() for t in model.get("tags", [])]
    has_lora_tag = any("lora" in t for t in lowered)
    return has_lora_tag or "base_model:adapter:" in " ".join(lowered)
def infer_title(repo_id: str) -> str:
    """Turn a repo id into a human-friendly title.

    Drops the namespace, treats "_" and "-" as word separators, and
    capitalizes each word ("user/my_cool-lora" -> "My Cool Lora").
    """
    name = repo_id.split("/", 1)[-1]
    words = name.replace("_", " ").replace("-", " ").split()
    return " ".join(word.capitalize() for word in words)
def pick_cover_image(repo_id: str, siblings: list[dict[str, Any]]) -> str:
    """Return a resolve URL for the first non-hidden image file in *siblings*.

    *siblings* are HF API file records carrying an "rfilename" key. Returns
    "" when the repo contains no usable image.

    Bug fix: the URL previously ended in a literal placeholder instead of
    the matched filename, so every cover link was broken.
    """
    for item in siblings:
        filename = str(item.get("rfilename", ""))
        lower = filename.lower()
        if lower.endswith(IMAGE_EXTENSIONS) and not lower.startswith("."):
            return f"https://huggingface.co/{repo_id}/resolve/main/{filename}"
    return ""
def pick_weight_name(siblings: list[dict[str, Any]]) -> str:
    """Choose the LoRA weight file to load from a repo's file listing.

    Root-level files and "adapter_model*" files are preferred; nested
    .safetensors files act as a fallback, except ComfyUI exports which are
    skipped entirely. Ties break alphabetically. Returns "" when no
    .safetensors file qualifies.
    """
    preferred: list[str] = []
    fallback: list[str] = []
    for record in siblings:
        name = str(record.get("rfilename", ""))
        lowered = name.lower()
        if not lowered.endswith(".safetensors") or "comfyui/" in lowered:
            continue
        if lowered.startswith("adapter_model") or "/" not in name:
            preferred.append(name)
        else:
            fallback.append(name)
    for bucket in (preferred, fallback):
        if bucket:
            return min(bucket)
    return ""
def normalize_trigger(text: str) -> str:
    """Clean a trigger-word candidate scraped from a README.

    Strips surrounding quotes, collapses internal whitespace, trims
    trailing punctuation, and maps placeholder values ("-", "none", "n/a")
    to "".

    Fix: the placeholder check was case-sensitive, so README values like
    "None" or "N/A" leaked through as real trigger words; compare
    case-insensitively.
    """
    cleaned = text.strip().strip("\"'").strip()
    cleaned = re.sub(r"\s+", " ", cleaned)
    cleaned = cleaned.strip(" ,;.")
    if cleaned.lower() in {"-", "none", "n/a"}:
        return ""
    return cleaned
def extract_trigger_from_readme(readme: str) -> str:
    """Scrape a trigger word from a model-card README.

    YAML frontmatter keys are consulted first (most reliable), then
    free-text phrasings across the whole document. Returns "" when nothing
    usable is found.
    """
    frontmatter = readme
    if readme.startswith("---"):
        pieces = readme.split("---", 2)
        if len(pieces) >= 3:
            frontmatter = pieces[1]
    # (haystack, pattern) pairs in priority order.
    searches = [
        (frontmatter, r"(?im)^\s*instance_prompt\s*:\s*(.+?)\s*$"),
        (frontmatter, r"(?im)^\s*trigger_word\s*:\s*(.+?)\s*$"),
        (frontmatter, r"(?im)^\s*activation[_ ]token\s*:\s*(.+?)\s*$"),
        (frontmatter, r"(?im)^\s*trigger[_ ]phrase\s*:\s*(.+?)\s*$"),
        (frontmatter, r"(?im)^\s*token\s*:\s*(.+?)\s*$"),
        (readme, r"(?im)trigger word\s*[:\-]\s*`?([^`\n]+)`?"),
        (readme, r"(?im)activation token\s*[:\-]\s*`?([^`\n]+)`?"),
        (readme, r"(?im)use\s+`([^`]+)`\s+in your prompt"),
        (readme, r"(?im)you can use\s+([^.\n]+)"),
    ]
    for haystack, pattern in searches:
        if (match := re.search(pattern, haystack)) is not None:
            if trigger := normalize_trigger(match.group(1)):
                return trigger
    return ""
def fetch_trigger_word(repo_id: str, session: requests.Session) -> str:
    """Download a repo's README and scrape a trigger word from it.

    Best-effort: any HTTP failure or non-200 status yields "".
    """
    url = f"https://huggingface.co/{repo_id}/raw/main/README.md"
    try:
        resp = session.get(url, timeout=30)
        if resp.status_code != 200:
            return ""
        return extract_trigger_from_readme(resp.text)
    except Exception:
        return ""
def build_catalog(
    models: list[dict[str, Any]], existing_triggers: dict[str, str]
) -> list[dict[str, Any]]:
    """Convert raw Hub model records into sorted catalog rows.

    Models that do not look like text-to-image LoRAs are dropped. A freshly
    scraped trigger word wins; *existing_triggers* fills the gaps. Rows are
    sorted by family, then case-insensitively by title.
    """
    session = requests.Session()
    entries: list[LoraEntry] = []
    for record in models:
        if not is_t2i_lora(record):
            continue
        repo_id = record["id"]
        tag_list = [str(t) for t in record.get("tags", [])]
        base = extract_base_model(tag_list)
        files = record.get("siblings") or []
        trigger = fetch_trigger_word(repo_id, session) or existing_triggers.get(repo_id, "")
        entries.append(
            LoraEntry(
                title=infer_title(repo_id),
                repo=repo_id,
                trigger_word=trigger,
                family=detect_family(base, repo_id, tag_list),
                base_model=base,
                image=pick_cover_image(repo_id, files),
                weight_name=pick_weight_name(files),
            )
        )
    entries.sort(key=lambda e: (e.family, e.title.lower()))
    return [e.as_dict() for e in entries]
def main() -> None:
    """CLI entry point: fetch models, build the catalog, write JSON, report."""
    args = parse_args()
    output_path = Path(args.output)
    existing_triggers = load_existing_triggers(output_path)
    catalog = build_catalog(paginated_models(args.author), existing_triggers)
    payload = json.dumps(catalog, indent=2, ensure_ascii=False) + "\n"
    output_path.write_text(payload, encoding="utf-8")
    # Summary statistics for the console report.
    family_counts: dict[str, int] = {}
    for row in catalog:
        family_counts[row["family"]] = family_counts.get(row["family"], 0) + 1
    filled = sum(1 for row in catalog if row.get("trigger_word"))
    print(f"Saved {len(catalog)} LoRAs to {output_path}")
    print(f"Trigger words filled: {filled}")
    print("Family counts:")
    for fam in sorted(family_counts):
        print(f" - {fam}: {family_counts[fam]}")


if __name__ == "__main__":
    main()