# artydemo/scripts/export_hf_model_card.py
# Provenance (from scraped page header): author Pablo Dejuan,
# commit e99de5b — "generate cnn_rrn card".
#!/usr/bin/env python3
"""
Write a Hugging Face model card README (YAML + markdown) with evaluation metrics.
Runs the same test-set evaluation as eval_cnn.py and fills `model-index` metrics.
Usage:
python scripts/export_hf_model_card.py --arch cnn --output docs/hf_model_card_cnn_baseline.md
See: https://huggingface.co/docs/hub/model-cards#evaluation-results
"""
from __future__ import annotations
import argparse
import sys
from pathlib import Path
# Repo root = two levels up from this file (scripts/ -> repo).
ROOT = Path(__file__).resolve().parent.parent
# Make sibling scripts importable so we can reuse eval_cnn's metrics code
# without packaging; hence the E402 late import below.
sys.path.insert(0, str(ROOT / "scripts"))
from eval_cnn import EvalMetrics, compute_test_metrics  # noqa: E402
def render_model_card(
    m: EvalMetrics,
    *,
    arch: str,
    model_index_name: str,
    dataset_id: str,
    source_name: str,
    source_url: str,
    license_id: str,
    extra_body: str,
) -> str:
    """Build full README.md content for a Hub model repo.

    Emits a YAML front-matter header (license, tags, datasets, and a
    ``model-index`` section carrying the evaluation metrics) followed by a
    markdown body with an evaluation table.

    Args:
        m: Metrics mapping from ``compute_test_metrics``; keys read here:
            ``genre_top1``, ``style_top1``, ``artist_top1``, ``artist_top5``,
            ``epoch``, ``checkpoint_name``, ``arch``, ``test_n``.
        arch: ``"cnn"`` or ``"cnnrnn"``; selects the intro text and extra tags.
        model_index_name: model-index entry name, also the H1 title.
        dataset_id: Hub dataset id (``owner/name``).
        source_name: Human-readable name of the metrics source.
        source_url: URL backing the metrics source.
        license_id: License identifier for the YAML header.
        extra_body: Markdown appended after the evaluation table.

    Returns:
        Complete README.md text (YAML front matter + markdown).
    """
    g, s, a, a5 = m["genre_top1"], m["style_top1"], m["artist_top1"], m["artist_top5"]
    epoch = m["epoch"]
    # repr() keeps the checkpoint's epoch readable and distinguishes 0 from "unknown".
    epoch_str = repr(epoch) if epoch is not None else "unknown"
    if arch == "cnnrnn":
        # Same zero-indent list style as the other `tags:` entries so the
        # concatenated YAML sequence stays consistent.
        extra_tags = "\n- lstm\n- bilstm"
        intro = (
            f"Multi-head **ResNet-50** backbone with **column pooling**, a **bidirectional LSTM** over spatial strips, "
            f"and linear heads for **genre**, **style**, and **artist** on a WikiArt subset "
            f"([{dataset_id}](https://huggingface.co/datasets/{dataset_id}))."
        )
    else:
        extra_tags = ""
        intro = (
            f"Multi-head **ResNet-50** classifier: **genre**, **style**, and **artist** on a WikiArt subset "
            f"([{dataset_id}](https://huggingface.co/datasets/{dataset_id}))."
        )
    # model-index nesting follows the Hub model-card metadata spec:
    # results is a list; each result holds task / dataset / metrics / source.
    card = f"""---
license: {license_id}
language: en
pipeline_tag: image-classification
library_name: pytorch
datasets:
- {dataset_id}
tags:
- pytorch
- torchvision
- resnet-50
- image-classification
- multi-task
- wikiart
- art{extra_tags}
model-index:
- name: {model_index_name}
  results:
  - task:
      type: image-classification
    dataset:
      name: {dataset_id}
      type: {dataset_id}
    metrics:
    - name: Genre accuracy (top-1)
      type: accuracy
      value: {g:.6f}
    - name: Style accuracy (top-1)
      type: accuracy
      value: {s:.6f}
    - name: Artist accuracy (top-1)
      type: accuracy
      value: {a:.6f}
    - name: Artist accuracy (top-5)
      type: accuracy
      value: {a5:.6f}
    source:
      name: {source_name}
      url: {source_url}
---

# {model_index_name}

{intro}

## Evaluation (test split)

| Metric | Value |
|--------|------:|
| Genre (top-1) | {g:.2%} |
| Style (top-1) | {s:.2%} |
| Artist (top-1) | {a:.2%} |
| Artist (top-5) | {a5:.2%} |

- **Checkpoint**: `{m['checkpoint_name']}`
- **Arch**: `{m['arch']}`
- **Epoch** (from checkpoint): {epoch_str}
- **Test images**: {m['test_n']}

{extra_body}
"""
    return card
def main() -> None:
    """CLI entry point: run the test-set evaluation and write the model card."""
    parser = argparse.ArgumentParser(
        description="Export HF model card README with eval metrics."
    )
    parser.add_argument("--arch", choices=["cnn", "cnnrnn"], default="cnn")
    parser.add_argument(
        "--last", action="store_true", help="Use last.pt instead of best.pt"
    )
    parser.add_argument(
        "--output",
        type=Path,
        default=None,
        help="Where to write README (default: docs/hf_model_card_cnn_baseline.md or ..._arty_cnn_rnn.md by --arch)",
    )
    parser.add_argument(
        "--model-name", default="Arty CNN baseline", help="model-index name + H1 title"
    )
    parser.add_argument("--dataset", default="pdjota/artyset", help="Hub dataset id")
    parser.add_argument("--source-name", default="Local eval (scripts/eval_cnn.py)")
    parser.add_argument("--source-url", default="", help="Optional URL (e.g. GitHub repo)")
    parser.add_argument("--license", default="apache-2.0", dest="license_id")
    parser.add_argument(
        "--extra-body",
        default="## Limitations\n\nNot for production attribution or forensic ID; academic / demo use.",
        help="Markdown appended after the eval table",
    )
    args = parser.parse_args()

    # Resolve the output path: explicit --output wins, else an arch-specific
    # default under docs/.
    destination = args.output
    if destination is None:
        stem = "cnn_baseline" if args.arch == "cnn" else "arty_cnn_rnn"
        destination = ROOT / "docs" / f"hf_model_card_{stem}.md"

    # Missing checkpoint / data surfaces as FileNotFoundError from the
    # evaluator; report and exit non-zero rather than traceback.
    try:
        metrics = compute_test_metrics(arch=args.arch, last=args.last)
    except FileNotFoundError as err:
        print(f"ERROR: {err}", file=sys.stderr)
        sys.exit(1)

    # Fall back to the HF docs page when no source URL was supplied.
    effective_url = args.source_url.strip() or "https://huggingface.co/docs/hub/model-cards"
    readme = render_model_card(
        metrics,
        arch=args.arch,
        model_index_name=args.model_name,
        dataset_id=args.dataset,
        source_name=args.source_name,
        source_url=effective_url,
        license_id=args.license_id,
        extra_body=args.extra_body,
    )
    destination.parent.mkdir(parents=True, exist_ok=True)
    destination.write_text(readme, encoding="utf-8")
    print(f"Wrote {destination} ({len(readme)} bytes)")


if __name__ == "__main__":
    main()