Pablo Dejuan committed on
Commit
179dfc2
·
1 Parent(s): 6c8d2bc

Inference and Hub UX: shared predict_topk, atomic checkpoints, upload .env


- Add _ThreeHeadPredictMixin (predict_topk / predict_topk_from_path) on both
ResNet heads; Gradio uses it with predict_format top-k JSON and config
ImageNet normalize; device selection includes MPS.
- Add spot_check_excluded_post_impressionism.py (train transforms, --top-k)
and README spot-check table with GAP/Lin+He refs and BiLSTM wording.
- train_cnn: atomic torch.save via temp+replace; weights_only=False on loads;
clear error when last.pt is corrupt (suggest cp from best.pt).
- upload_model_to_hf: load .env with override=True; required --checkpoint;
default export to data/label_maps; drop --token; python-dotenv in
requirements; Dockerfile omits redundant --export-labels-dir.
- Tests for predict_topk, spot_check helpers, atomic save, dotenv override.

Made-with: Cursor

Dockerfile CHANGED
@@ -39,7 +39,7 @@ else \
39
  echo \"[train] No last.pt found; training from scratch for ${EPOCHS} epochs\"; \
40
  python scripts/train_cnn_safe.py --arch \"$ARCH\" --epochs \"$EPOCHS\" --batch-size-primary \"$BATCH_SIZE_PRIMARY\" --batch-size-fallback \"$BATCH_SIZE_FALLBACK\"; \
41
  fi; \
42
- python scripts/upload_model_to_hf.py --repo-id \"$MODEL_REPO_ID\" --checkpoint \"checkpoints/$ARCH/best.pt\" --export-labels-dir data/label_maps; \
43
  kill $SERVER_PID >/dev/null 2>&1 || true \
44
  "]
45
 
 
39
  echo \"[train] No last.pt found; training from scratch for ${EPOCHS} epochs\"; \
40
  python scripts/train_cnn_safe.py --arch \"$ARCH\" --epochs \"$EPOCHS\" --batch-size-primary \"$BATCH_SIZE_PRIMARY\" --batch-size-fallback \"$BATCH_SIZE_FALLBACK\"; \
41
  fi; \
42
+ python scripts/upload_model_to_hf.py --repo-id \"$MODEL_REPO_ID\" --checkpoint \"checkpoints/$ARCH/best.pt\"; \
43
  kill $SERVER_PID >/dev/null 2>&1 || true \
44
  "]
45
 
README.md CHANGED
@@ -12,14 +12,62 @@ pinned: false
12
 
13
  # Arty
14
 
15
- WikiArt **genre / style / artist** classifiers (CNN baseline and CNN-RNN). This Hugging Face **Space** runs the **Gradio** app under [`gradio/app.py`](gradio/app.py); weights load from Hub model repos and architecture from [`src/model.py`](src/model.py).
16
 
17
  **Training** (Docker GPU job) lives in [`Dockerfile`](Dockerfile) — use a **separate** Space pointed at the same repo if you only want training, or run locally. Do not set `sdk: docker` on this Space if you want the Gradio UI.
18
 
19
  ## Env (optional)
20
 
21
  - `BASELINE_MODEL_REPO_ID` — default `pdjota/cnn-baseline`
22
  - `CNNRNN_MODEL_REPO_ID` — default `pdjota/arty-cnn-rnn`
23
  - `HF_TOKEN` — if model repos are gated
24
 
25
- More detail: [`gradio/README.md`](gradio/README.md), [`docs/monorepo_gradio_space.md`](docs/monorepo_gradio_space.md).
12
 
13
  # Arty
14
 
15
+ **Arty** is a multi-task WikiArt classifier: **genre**, **style**, and **artist** in one model with two architectures — a **CNN baseline** (ResNet-50 + global pooling + three heads) and a **CNNRNN** (same backbone, **bidirectional long short-term memory (BiLSTM)** over spatial features + three heads). This Hugging Face **Space** runs the **Gradio** app in [`gradio/app.py`](gradio/app.py); **weights** load from Hub model repos and **architecture** from [`src/model.py`](src/model.py). Each architecture corresponds to a model repo: [pdjota/cnn-baseline](https://huggingface.co/pdjota/cnn-baseline) or [pdjota/arty-cnn-rnn](https://huggingface.co/pdjota/arty-cnn-rnn).
16
 
17
  **Training** (Docker GPU job) lives in [`Dockerfile`](Dockerfile) — use a **separate** Space pointed at the same repo if you only want training, or run locally. Do not set `sdk: docker` on this Space if you want the Gradio UI.
18
 
19
+ ## About this project
20
+
21
+ ### Classification with three labels together
22
+
23
+ **Genre** and **style** are relatively **generic**: many paintings share the same movement or subject category, and the model learns broad visual patterns that match those labels. **Artist** is **more specific** — we usually think of *who* painted it **within** a stylistic movement. Distinguishing painters who share a movement (like Sisley and Monet within Impressionism), or identifying Picasso, whose work spans Symbolism and Cubism, becomes challenging. The network must pick up **fine-grained** cues (palette, brushwork, recurring motifs) that sit on top of the same broad visual cues that support style and genre.
24
+
25
+ We use a **shared convolutional trunk** (one set of image features) and **three separate heads** (genre, style, artist). The trunk carries **generic** painting features; the heads split **coarse** (genre, style) vs **fine** (artist) decisions so the artist task can specialize without forcing a single output to encode everything at once.
26
+
27
+ ### ResNet-50 and short fine-tuning
28
+
29
+ **ResNet-50** is a deep convolutional network built from **residual (skip) connections** so very deep stacks train without vanishing gradients ([He et al., 2016](https://arxiv.org/abs/1512.03385); trained on **ImageNet** in that work). **Transfer learning** is standard: features from early conv layers tend to transfer across related visual domains better than random initialization ([Yosinski et al., 2014](https://arxiv.org/abs/1411.1792)). For **paintings**, [Zhao et al. (2021)](https://doi.org/10.1371/journal.pone.0248414) compare the same model families **with and without** ImageNet-based transfer on WikiArt (genre / style / artist) and report strong results with pretraining — so we **fine-tune** the backbone and heads for a **limited number of epochs** instead of training from scratch at ImageNet-scale cost.
30
+ We use an **ArtGAN-aligned** WikiArt-style index: images catalogued with consistent **genre, style, and
31
+ artist** labels; a few broken paths are excluded. A curated dataset is on the Hub (e.g.
32
+ [`pdjota/artyset`](https://huggingface.co/datasets/pdjota/artyset)) for reproducible training.
33
+
34
+ The resulting model already classifies some held-out examples well. **Spot check (CNN baseline `best.pt` with [`scripts/spot_check_excluded_post_impressionism.py`](scripts/spot_check_excluded_post_impressionism.py)):** five **Post_Impressionism** images under `data/wikiart_excluded/Post_Impressionism/` (not in the training index) — style top-1 and a short note:
35
+
36
+ | Painting | Style (top-1) | Comment |
37
+ | -------- | ------------- | ------- |
38
+ | `henri-matisse_a-vase-with-oranges.jpg` | Post_Impressionism (~80%) | Still life; confident style match. |
39
+ | `henri-de-toulouse-lautrec_portrait-of-vincent-van-gogh-1887.jpg` | Impressionism (~99%) | Sketchy handling reads as Impressionist; Post_Impressionism far behind. |
40
+ | `pablo-picasso_seated-monkey-1905.jpg` | Post_Impressionism (~42%) | Close with Expressionism; **artist** top-1 Picasso (~93%). |
41
+ | `paul-gauguin_a-seashore-1887.jpg` | Impressionism (~75%) | Post_Impressionism second (~16%). |
42
+ | `a.y.-jackson_the-edge-of-the-maple-wood-1910.jpg` | Impressionism (~94%) | Landscape; artist head has no A.Y. Jackson class (23 ArtGAN artists). |
43
+
44
+ ### Bidirectional long short-term memory (BiLSTM) on top of the CNN
45
+
46
+ Zhao et al. note that their setup uses **colour** information heavily and that **spatial** information could still improve classification. Standard **global average pooling (GAP)** after the last conv map **throws away layout**: it averages each channel over all spatial positions, so the classifier sees a **single vector per channel** with **no remaining (x, y) structure** ([Lin et al., 2014](https://arxiv.org/abs/1312.4400); ResNet-50 uses this pattern before its **fully connected (FC)** layer [He et al., 2016](https://arxiv.org/abs/1512.03385)). That answers “what is present” but not “how it is arranged.” We keep the same ResNet backbone, then **turn the spatial grid into a sequence** (e.g. column-wise strips), run a **bidirectional long short-term memory (BiLSTM)**, then classify. The **reasoning** is: composition, figure–ground balance, and brushstroke patterns often have **left–right (or strip-wise) structure**; a sequence model can integrate **context** along that axis **bidirectionally**, which GAP does not model. The CNN–RNN is a **minimal, comparable** upgrade: same heads and training loop, different pooling.
47
+
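The grid-to-sequence step can be sketched as below, assuming the 7×7 feature map that ResNet-50 produces for a 224×224 input and an arbitrary hidden size of 256 (shapes illustrative, not the repo's exact configuration):

```python
import torch
import torch.nn as nn

# ResNet-50 on a 224x224 input yields a (B, 2048, 7, 7) feature map before GAP.
fmap = torch.randn(2, 2048, 7, 7)

# Column pooling: average over height so each of the 7 columns becomes one step.
seq = fmap.mean(dim=2)       # (B, 2048, 7)
seq = seq.permute(0, 2, 1)   # (B, 7, 2048): left-to-right sequence of column features

lstm = nn.LSTM(input_size=2048, hidden_size=256, batch_first=True, bidirectional=True)
out, _ = lstm(seq)           # (B, 7, 512): bidirectional context along the x-axis
pooled = out.mean(dim=1)     # (B, 512): mean pool, then the three heads classify
print(pooled.shape)
```

Unlike GAP, the BiLSTM sees the columns in order, so strip-wise layout survives into the pooled vector.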
48
+ ### Data: ArtGAN-aligned index and Hugging Face
49
+
50
+ We align with the **ArtGAN / WikiArt** lineage so labels are **catalogue-consistent** for genre, style, and artist. We **trim** the index (drop broken/missing files, validate paths) and publish a curated dataset on the Hub (e.g. [`pdjota/artyset`](https://huggingface.co/datasets/pdjota/artyset)) so **training and demos are reproducible**.
51
+
52
+ `scripts/train_cnn.py` uses **70% / 15% / 15%** train / val / test, **stratified by `artist_id`**. **Reasoning:** if we split randomly by image, we might put **almost all works of a rare artist** in one fold; **artist** is also the label that would most easily “leak” structurally (same brushwork in train vs test). Stratifying by artist keeps **each split’s artist mix** more representative, so validation/test **accuracy and loss** are comparable across runs and less dominated by **which artists** landed in which fold.
53
+
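The 70% / 15% / 15% stratified split can be sketched with two `train_test_split` calls (stand-in data; the real script stratifies the index rows by `artist_id`):

```python
from sklearn.model_selection import train_test_split

rows = list(range(100))              # stand-in for index rows
artist_ids = [i % 5 for i in rows]   # stand-in labels; real code uses the CSV's artist_id

# First carve off 30% for val+test, then split that 30% half-and-half.
train, rest, y_train, y_rest = train_test_split(
    rows, artist_ids, test_size=0.30, stratify=artist_ids, random_state=42
)
val, test, _, _ = train_test_split(
    rest, y_rest, test_size=0.50, stratify=y_rest, random_state=42
)
print(len(train), len(val), len(test))  # 70 15 15
```

Stratifying both calls keeps each fold's artist mix close to the full index, so rare artists are not concentrated in one split.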
54
+ ### Training artifacts and this Space
55
+
56
+ Runs save **PyTorch** checkpoints (`best.pt`, `last.pt`), **CSV** logs (`train_log.csv`, `results_summary.csv`), and we upload reference models with **`id2label` JSON** to Hub model repos. Training works on **CPU**, **Apple Silicon (MPS)**, or a **GPU Space**.
57
+
58
+ **Gradio (this Space):** upload a painting and compare **CNN baseline** vs **CNN–RNN** top-k predictions. Env vars below select which Hub checkpoints to load.
59
+
60
  ## Env (optional)
61
 
62
  - `BASELINE_MODEL_REPO_ID` — default `pdjota/cnn-baseline`
63
  - `CNNRNN_MODEL_REPO_ID` — default `pdjota/arty-cnn-rnn`
64
  - `HF_TOKEN` — if model repos are gated
65
 
66
+ More detail: [`gradio/README.md`](gradio/README.md), [`docs/monorepo_gradio_space.md`](docs/monorepo_gradio_space.md), research plan [`plan.md`](plan.md).
67
+
68
+ ### References (ResNet / transfer / WikiArt)
69
+
70
+ 1. He, K., Zhang, X., Ren, S., & Sun, J. (2016). *Deep residual learning for image recognition.* CVPR. [arXiv:1512.03385](https://arxiv.org/abs/1512.03385)
71
+ 2. Yosinski, J., Clune, J., Bengio, Y., & Lipson, H. (2014). *How transferable are features in deep neural networks?* NeurIPS. [arXiv:1411.1792](https://arxiv.org/abs/1411.1792)
72
+ 3. Zhao, W., Zhou, D., Qiu, X., & Jiang, W. (2021). *Compare the performance of the models in art classification.* PLOS ONE 16(3): e0248414. [DOI:10.1371/journal.pone.0248414](https://doi.org/10.1371/journal.pone.0248414)
73
+ 4. Lin, M., Chen, Q., & Yan, S. (2014). *Network in network.* ICLR. [arXiv:1312.4400](https://arxiv.org/abs/1312.4400) — global average pooling to aggregate conv feature maps before classification.
gradio/app.py CHANGED
@@ -18,7 +18,6 @@ from typing import Any, Dict, List, Optional, Tuple
18
 
19
  import gradio as gr
20
  import torch
21
- import torch.nn.functional as F
22
  from huggingface_hub import hf_hub_download
23
  from PIL import Image
24
  from torchvision import transforms as T
@@ -32,12 +31,19 @@ if not _SRC.exists():
32
  )
33
  sys.path.insert(0, str(REPO_ROOT / "src"))
34
 
 
35
  from model import ResNet50BiLSTMThreeHeads # type: ignore
36
  from model import ResNet50ThreeHeads # type: ignore
 
37
 
38
  BASELINE_REPO = os.environ.get("BASELINE_MODEL_REPO_ID", "pdjota/cnn-baseline")
39
  CNNRNN_REPO = os.environ.get("CNNRNN_MODEL_REPO_ID", "pdjota/arty-cnn-rnn")
40
- DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
41
  HF_TOKEN = os.environ.get("HF_TOKEN")
42
 
43
  transform = T.Compose(
@@ -45,7 +51,7 @@ transform = T.Compose(
45
  T.Resize(256),
46
  T.CenterCrop(224),
47
  T.ToTensor(),
48
- T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
49
  ]
50
  )
51
 
@@ -109,11 +115,6 @@ def _load(repo_id: str) -> Dict[str, Any]:
109
 
110
  # --- prediction helpers ----------------------------------------------------
111
 
112
- def _topk(logits: torch.Tensor, id2label: Dict[int, str], k: int = 3) -> List[Dict[str, Any]]:
113
- probs = F.softmax(logits, dim=-1)[0]
114
- vals, idxs = probs.topk(k)
115
- return [{"label": id2label.get(int(i), str(int(i))), "prob": round(float(v), 4)} for v, i in zip(vals, idxs)]
116
-
117
 
118
  def _bucket(pct: float) -> str:
119
  if pct >= 80:
@@ -151,12 +152,16 @@ def predict(model_choice: str, image: Optional[Image.Image]) -> Tuple[str, str]:
151
  model = assets["model"]
152
 
153
  x = transform(image).unsqueeze(0).to(DEVICE)
154
- with torch.no_grad():
155
- lg, ls, la = model(x)
156
-
157
- g3 = _topk(lg, assets["genre"])
158
- s3 = _topk(ls, assets["style"])
159
- a3 = _topk(la, assets["artist"])
160
 
161
  summary = "\n".join([
162
  f"**Genre**: {_summarize(g3)}",
 
18
 
19
  import gradio as gr
20
  import torch
 
21
  from huggingface_hub import hf_hub_download
22
  from PIL import Image
23
  from torchvision import transforms as T
 
31
  )
32
  sys.path.insert(0, str(REPO_ROOT / "src"))
33
 
34
+ from config import IMAGENET_MEAN, IMAGENET_STD # type: ignore
35
  from model import ResNet50BiLSTMThreeHeads # type: ignore
36
  from model import ResNet50ThreeHeads # type: ignore
37
+ from predict_format import topk_tuples_to_ui_items # type: ignore
38
 
39
  BASELINE_REPO = os.environ.get("BASELINE_MODEL_REPO_ID", "pdjota/cnn-baseline")
40
  CNNRNN_REPO = os.environ.get("CNNRNN_MODEL_REPO_ID", "pdjota/arty-cnn-rnn")
41
+ if torch.cuda.is_available():
42
+ DEVICE = torch.device("cuda")
43
+ elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
44
+ DEVICE = torch.device("mps")
45
+ else:
46
+ DEVICE = torch.device("cpu")
47
  HF_TOKEN = os.environ.get("HF_TOKEN")
48
 
49
  transform = T.Compose(
 
51
  T.Resize(256),
52
  T.CenterCrop(224),
53
  T.ToTensor(),
54
+ T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
55
  ]
56
  )
57
 
 
115
 
116
  # --- prediction helpers ----------------------------------------------------
117
 
118
 
119
  def _bucket(pct: float) -> str:
120
  if pct >= 80:
 
152
  model = assets["model"]
153
 
154
  x = transform(image).unsqueeze(0).to(DEVICE)
155
+ g_t, s_t, a_t = model.predict_topk(
156
+ x,
157
+ genre_id2label=assets["genre"],
158
+ style_id2label=assets["style"],
159
+ artist_id2label=assets["artist"],
160
+ k=3,
161
+ )
162
+ g3 = topk_tuples_to_ui_items(g_t)
163
+ s3 = topk_tuples_to_ui_items(s_t)
164
+ a3 = topk_tuples_to_ui_items(a_t)
165
 
166
  summary = "\n".join([
167
  f"**Genre**: {_summarize(g3)}",
requirements.txt CHANGED
@@ -9,6 +9,7 @@ scikit-learn>=1.2
9
  matplotlib>=3.7
10
  tqdm>=4.65
11
  huggingface_hub>=0.25.0,<1.0 # Gradio 5.x needs HfFolder; removed in hub 1.0
 
12
  pytest>=7.0 # tests
13
  pytest-cov>=4.0 # coverage
14
 
 
9
  matplotlib>=3.7
10
  tqdm>=4.65
11
  huggingface_hub>=0.25.0,<1.0 # Gradio 5.x needs HfFolder; removed in hub 1.0
12
+ python-dotenv>=1.0 # optional: HF_TOKEN from repo .env for upload scripts
13
  pytest>=7.0 # tests
14
  pytest-cov>=4.0 # coverage
15
 
scripts/spot_check_excluded_post_impressionism.py ADDED
@@ -0,0 +1,137 @@
1
+ """
2
+ Spot-check the CNN baseline on fixed excluded Post_Impressionism images (README table).
3
+
4
+ Images live under data/wikiart_excluded/Post_Impressionism/ (not in the training index).
5
+ Uses the same eval transforms as training (`train_cnn.get_transforms(train=False)`).
6
+
7
+ Usage (from repo root):
8
+
9
+ python scripts/spot_check_excluded_post_impressionism.py
10
+ python scripts/spot_check_excluded_post_impressionism.py --cpu
11
+ python scripts/spot_check_excluded_post_impressionism.py --top-k 5
12
+ """
13
+ from __future__ import annotations
14
+
15
+ import argparse
16
+ import importlib.util
17
+ import json
18
+ import sys
19
+ from pathlib import Path
20
+
21
+ import torch
22
+
23
+ ROOT = Path(__file__).resolve().parent.parent
24
+ sys.path.insert(0, str(ROOT / "src"))
25
+
26
+ from config import checkpoint_dir_for_arch
27
+ from model import ResNet50ThreeHeads
28
+
29
+ DEFAULT_REL_PATHS: tuple[str, ...] = (
30
+ "henri-matisse_a-vase-with-oranges.jpg",
31
+ "henri-de-toulouse-lautrec_portrait-of-vincent-van-gogh-1887.jpg",
32
+ "pablo-picasso_seated-monkey-1905.jpg",
33
+ "paul-gauguin_a-seashore-1887.jpg",
34
+ "a.y.-jackson_the-edge-of-the-maple-wood-1910.jpg",
35
+ )
36
+
37
+ LABEL_MAPS_DIR = ROOT / "data" / "label_maps"
38
+ EXCLUDED_STYLE_DIR = ROOT / "data" / "wikiart_excluded" / "Post_Impressionism"
39
+
40
+
41
+ def _load_train_cnn():
42
+ spec = importlib.util.spec_from_file_location("train_cnn", ROOT / "scripts" / "train_cnn.py")
43
+ mod = importlib.util.module_from_spec(spec)
44
+ assert spec.loader is not None
45
+ spec.loader.exec_module(mod)
46
+ return mod
47
+
48
+
49
+ def load_id2label(path: Path) -> dict[int, str]:
50
+ with open(path, encoding="utf-8") as f:
51
+ return {int(k): v for k, v in json.load(f).items()}
52
+
53
+
54
+ def load_label_maps() -> tuple[dict[int, str], dict[int, str], dict[int, str]]:
55
+ return (
56
+ load_id2label(LABEL_MAPS_DIR / "genre_id2label.json"),
57
+ load_id2label(LABEL_MAPS_DIR / "style_id2label.json"),
58
+ load_id2label(LABEL_MAPS_DIR / "artist_id2label.json"),
59
+ )
60
+
61
+
62
+ def resolve_device(*, force_cpu: bool) -> torch.device:
63
+ if force_cpu:
64
+ return torch.device("cpu")
65
+ if torch.cuda.is_available():
66
+ return torch.device("cuda")
67
+ if getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
68
+ return torch.device("mps")
69
+ return torch.device("cpu")
70
+
71
+
72
+ def main() -> None:
73
+ p = argparse.ArgumentParser(description="CNN spot-check on excluded Post_Impressionism examples.")
74
+ p.add_argument("--cpu", action="store_true", help="Force CPU")
75
+ p.add_argument(
76
+ "--checkpoint",
77
+ type=Path,
78
+ default=None,
79
+ help="Path to best.pt (default: checkpoints/<cnn>/best.pt from config)",
80
+ )
81
+ p.add_argument("--top-k", type=int, default=3, metavar="K", help="Top-k per head (default: 3)")
82
+ args = p.parse_args()
83
+
84
+ if args.top_k < 1:
85
+ print("ERROR: --top-k must be >= 1", file=sys.stderr)
86
+ sys.exit(1)
87
+
88
+ device = resolve_device(force_cpu=args.cpu)
89
+ ckpt_path = args.checkpoint if args.checkpoint is not None else checkpoint_dir_for_arch("cnn") / "best.pt"
90
+ if not ckpt_path.exists():
91
+ print(f"ERROR: checkpoint not found: {ckpt_path}", file=sys.stderr)
92
+ sys.exit(1)
93
+
94
+ genre_map, style_map, artist_map = load_label_maps()
95
+
96
+ ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
97
+ n_genre = ckpt["n_genre"]
98
+ n_style = ckpt["n_style"]
99
+ n_artist = ckpt["n_artist"]
100
+ model = ResNet50ThreeHeads(n_genre=n_genre, n_style=n_style, n_artist=n_artist, weights=None)
101
+ model.load_state_dict(ckpt["model_state_dict"])
102
+ model.to(device)
103
+
104
+ train_cnn = _load_train_cnn()
105
+ transform = train_cnn.get_transforms(train=False)
106
+
107
+ paths = [EXCLUDED_STYLE_DIR / name for name in DEFAULT_REL_PATHS]
108
+
109
+ print(f"Checkpoint: {ckpt_path}")
110
+ print(f"Device: {device}")
111
+ print(f"Top-k: {args.top_k}")
112
+ post_ids = [k for k, v in style_map.items() if v == "Post_Impressionism"]
113
+ print(f"Post_Impressionism style id(s): {post_ids}")
114
+ print()
115
+
116
+ for path in paths:
117
+ if not path.exists():
118
+ print(f"MISSING: {path}")
119
+ continue
120
+ g, s, a = model.predict_topk_from_path(
121
+ path,
122
+ transform,
123
+ device,
124
+ genre_id2label=genre_map,
125
+ style_id2label=style_map,
126
+ artist_id2label=artist_map,
127
+ k=args.top_k,
128
+ )
129
+ print("=" * 72)
130
+ print(path.name)
131
+ print(" genre (top-%d):" % args.top_k, g)
132
+ print(" style (top-%d):" % args.top_k, s)
133
+ print(" artist (top-%d):" % args.top_k, a)
134
+
135
+
136
+ if __name__ == "__main__":
137
+ main()
scripts/train_cnn.py CHANGED
@@ -26,6 +26,15 @@ from torchvision import transforms as T
26
  from sklearn.model_selection import train_test_split
27
 
28
  ROOT = Path(__file__).resolve().parent.parent
29
  sys.path.insert(0, str(ROOT / "src"))
30
 
31
  from config import (
@@ -134,7 +143,18 @@ def main() -> None:
134
  resume_path = ckpt_dir / "last.pt"
135
 
136
  if args.resume and resume_path.exists():
137
- ckpt = torch.load(resume_path, map_location=device)
138
  start_epoch = ckpt["epoch"] + 1
139
  best_val_loss = ckpt.get("val_loss", float("inf"))
140
  extra = args.epochs if args.epochs is not None else 10
@@ -275,9 +295,9 @@ def main() -> None:
275
  "n_style": N_STYLE,
276
  "n_artist": N_ARTIST,
277
  }
278
- torch.save(ckpt, ckpt_dir / "last.pt")
279
  if is_best:
280
- torch.save(ckpt, ckpt_dir / "best.pt")
281
 
282
  log_row = {
283
  "epoch": epoch,
@@ -315,7 +335,7 @@ def main() -> None:
315
  "batch_in_epoch": current_batch_in_epoch,
316
  "num_batches_in_epoch": current_num_batches_in_epoch,
317
  }
318
- torch.save(interrupted_ckpt, ckpt_dir / "last.pt")
319
  print(
320
  "\n"
321
  f"[{now_ts()}] Stopped by user (Ctrl+C). Saved resumable checkpoint to "
@@ -327,7 +347,7 @@ def main() -> None:
327
  # Save best-val results summary
328
  best_ckpt_path = ckpt_dir / "best.pt"
329
  if best_ckpt_path.exists():
330
- best_ckpt = torch.load(best_ckpt_path, map_location="cpu")
331
  summary = {
332
  "best_epoch": best_ckpt.get("epoch"),
333
  "val_loss": best_ckpt.get("val_loss"),
 
26
  from sklearn.model_selection import train_test_split
27
 
28
  ROOT = Path(__file__).resolve().parent.parent
29
+
30
+
31
+ def _atomic_torch_save(obj: object, path: Path) -> None:
32
+ """Write `path` via a temp file + `os.replace` so a kill mid-write does not truncate `last.pt` / `best.pt`."""
33
+ path = Path(path)
34
+ path.parent.mkdir(parents=True, exist_ok=True)
35
+ tmp = path.with_suffix(path.suffix + ".tmp")
36
+ torch.save(obj, tmp)
37
+ os.replace(tmp, path)
38
  sys.path.insert(0, str(ROOT / "src"))
39
 
40
  from config import (
 
143
  resume_path = ckpt_dir / "last.pt"
144
 
145
  if args.resume and resume_path.exists():
146
+ try:
147
+ ckpt = torch.load(resume_path, map_location=device, weights_only=False)
148
+ except Exception as e:
149
+ print(
150
+ f"[{now_ts()}] ERROR: Cannot load {resume_path} (often a truncated file if the process was killed "
151
+ f"during `torch.save`).\n"
152
+ f" {e}\n"
153
+ f" Fix: if best.pt is intact, copy it over last.pt and resume, e.g.\n"
154
+ f" cp {ckpt_dir / 'best.pt'} {resume_path}",
155
+ file=sys.stderr,
156
+ )
157
+ sys.exit(1)
158
  start_epoch = ckpt["epoch"] + 1
159
  best_val_loss = ckpt.get("val_loss", float("inf"))
160
  extra = args.epochs if args.epochs is not None else 10
 
295
  "n_style": N_STYLE,
296
  "n_artist": N_ARTIST,
297
  }
298
+ _atomic_torch_save(ckpt, ckpt_dir / "last.pt")
299
  if is_best:
300
+ _atomic_torch_save(ckpt, ckpt_dir / "best.pt")
301
 
302
  log_row = {
303
  "epoch": epoch,
 
335
  "batch_in_epoch": current_batch_in_epoch,
336
  "num_batches_in_epoch": current_num_batches_in_epoch,
337
  }
338
+ _atomic_torch_save(interrupted_ckpt, ckpt_dir / "last.pt")
339
  print(
340
  "\n"
341
  f"[{now_ts()}] Stopped by user (Ctrl+C). Saved resumable checkpoint to "
 
347
  # Save best-val results summary
348
  best_ckpt_path = ckpt_dir / "best.pt"
349
  if best_ckpt_path.exists():
350
+ best_ckpt = torch.load(best_ckpt_path, map_location="cpu", weights_only=False)
351
  summary = {
352
  "best_epoch": best_ckpt.get("epoch"),
353
  "val_loss": best_ckpt.get("val_loss"),
scripts/upload_model_to_hf.py CHANGED
@@ -1,28 +1,41 @@
1
  """
2
- Upload a trained checkpoint to the Hugging Face Hub (model repo) and (optionally)
3
- export id→label JSON files locally.
4
 
5
- The Space/demo code uses these JSONs to turn predicted class indices into readable labels.
6
 
7
- Usage:
8
- HF_TOKEN=... python scripts/upload_model_to_hf.py --repo-id USERNAME/arty-cnn-baseline
9
- HF_TOKEN=... python scripts/upload_model_to_hf.py --repo-id USERNAME/arty-cnn-baseline --export-labels-dir data/label_maps
 
10
  """
11
 
12
  import argparse
13
  import json
 
14
  import sys
15
  from pathlib import Path
16
 
17
  import torch
18
 
19
  ROOT = Path(__file__).resolve().parent.parent
20
- sys.path.insert(0, str(ROOT / "src"))
21
- from config import checkpoint_dir_for_arch # noqa: E402
22
 
23
  DATA_DIR = ROOT / "data"
24
  INDEX_SELECTED = DATA_DIR / "wikiart_index_selected.csv"
25
- CHECKPOINT_DEFAULT = checkpoint_dir_for_arch("cnn") / "best.pt"
26
 
27
 
28
  def build_id2label_from_selected_index(index_path: Path) -> tuple[dict[str, str], dict[str, str], dict[str, str]]:
@@ -101,22 +114,34 @@ def upload_checkpoint_and_labels(
101
 
102
 
103
  def main() -> None:
104
- p = argparse.ArgumentParser(description="Upload model checkpoint + id2label JSONs to Hugging Face Hub")
105
  p.add_argument("--repo-id", required=True, help="Model repo id, e.g. username/arty-cnn-baseline")
106
  p.add_argument(
107
  "--checkpoint",
108
  type=Path,
109
- default=CHECKPOINT_DEFAULT,
110
- help=f"Checkpoint path (default: {CHECKPOINT_DEFAULT})",
111
  )
112
- p.add_argument("--token", default=None, help="HF token (default: HF_TOKEN env)")
113
  p.add_argument("--index", type=Path, default=INDEX_SELECTED, help="Selected index CSV (default: data/wikiart_index_selected.csv)")
114
- p.add_argument("--export-labels-dir", type=Path, default=None, help="Optional dir to write *_id2label.json locally")
115
  args = p.parse_args()
116
 
117
- token = args.token or __import__("os").environ.get("HF_TOKEN")
118
  if not token:
119
- print("Set HF_TOKEN or pass --token", file=sys.stderr)
 
 
 
120
  sys.exit(1)
121
 
122
  # quick sanity load of checkpoint format
 
1
  """
2
+ Upload a trained checkpoint and id→label JSONs to a Hugging Face model repo (for Spaces / demos).
 
3
 
4
+ Usage (repo root):
5
 
6
+ python scripts/upload_model_to_hf.py --repo-id USER/reponame --checkpoint PATH/TO/best.pt
7
+
8
+ `HF_TOKEN`: repo `.env` (python-dotenv) wins over an existing shell `HF_TOKEN` when the key appears in `.env`
9
+ (`load_dotenv(..., override=True)`). Use a token with **write** access to the model repo. Local labels: `data/label_maps/`.
10
  """
11
 
12
  import argparse
13
  import json
14
+ import os
15
  import sys
16
  from pathlib import Path
17
 
18
  import torch
19
 
20
  ROOT = Path(__file__).resolve().parent.parent
21
+
22
+
23
+ def _load_dotenv_from_repo() -> None:
24
+ """Load repo `.env` into os.environ. Keys in `.env` override the same keys already in the environment
25
+ (fixes stale HF_TOKEN from the shell or IDE masking a valid token in `.env`)."""
26
+ env_path = ROOT / ".env"
27
+ if not env_path.is_file():
28
+ return
29
+ try:
30
+ from dotenv import load_dotenv
31
+ except ImportError:
32
+ return
33
+ load_dotenv(env_path, override=True)
34
+
35
 
36
  DATA_DIR = ROOT / "data"
37
  INDEX_SELECTED = DATA_DIR / "wikiart_index_selected.csv"
38
+ LABEL_EXPORT_DEFAULT = DATA_DIR / "label_maps"
39
 
40
 
41
  def build_id2label_from_selected_index(index_path: Path) -> tuple[dict[str, str], dict[str, str], dict[str, str]]:
 
114
 
115
 
116
  def main() -> None:
117
+ _load_dotenv_from_repo()
118
+
119
+ p = argparse.ArgumentParser(
120
+ description="Upload model checkpoint + id2label JSONs to Hugging Face Hub. "
121
+ "Loads repo-root .env by default (HF_TOKEN) when python-dotenv is installed."
122
+ )
123
  p.add_argument("--repo-id", required=True, help="Model repo id, e.g. username/arty-cnn-baseline")
124
  p.add_argument(
125
  "--checkpoint",
126
  type=Path,
127
+ required=True,
128
+ help="Checkpoint file to upload (e.g. checkpoints/cnn_baseline/best.pt or checkpoints/cnnrnn/best.pt)",
129
  )
 
130
  p.add_argument("--index", type=Path, default=INDEX_SELECTED, help="Selected index CSV (default: data/wikiart_index_selected.csv)")
131
+ p.add_argument(
132
+ "--export-labels-dir",
133
+ type=Path,
134
+ default=LABEL_EXPORT_DEFAULT,
135
+ help=f"Write *_id2label.json here (default: {LABEL_EXPORT_DEFAULT})",
136
+ )
137
  args = p.parse_args()
138
 
139
+ token = os.environ.get("HF_TOKEN")
140
  if not token:
141
+ print(
142
+ "Missing token: add HF_TOKEN to repo-root .env",
143
+ file=sys.stderr,
144
+ )
145
  sys.exit(1)
146
 
147
  # quick sanity load of checkpoint format
src/model.py CHANGED
@@ -1,10 +1,72 @@
1
  """ResNet-50 backbone variants for multi-task classification."""
2
  import torch
3
  import torch.nn as nn
 
4
  from torchvision.models import ResNet50_Weights, resnet50
5
 
6
 
7
- class ResNet50ThreeHeads(nn.Module):
8
  """ResNet-50 (ImageNet pretrained), GAP, then three linear heads: genre, style, artist."""
9
 
10
  def __init__(
@@ -58,7 +120,7 @@ class ResNet50ThreeHeads(nn.Module):
58
  )
59
 
60
 
61
- class ResNet50BiLSTMThreeHeads(nn.Module):
62
  """
63
  ResNet-50 (ImageNet pretrained) feature map -> column pooling -> BiLSTM -> mean pool -> three heads.
64
 
 
1
  """ResNet-50 backbone variants for multi-task classification."""
2
+ from __future__ import annotations
3
+
4
+ from pathlib import Path
5
+
6
  import torch
7
  import torch.nn as nn
8
+ import torch.nn.functional as F
9
  from torchvision.models import ResNet50_Weights, resnet50
10
 
11
 
12
+ def _topk_from_logits(
13
+ logits: torch.Tensor, id2label: dict[int, str], k: int
14
+ ) -> list[tuple[str, float]]:
15
+ """Map top-k softmax probabilities to label names (batch size 1)."""
16
+ n = logits.size(-1)
17
+ k = min(k, n)
18
+ probs = F.softmax(logits, dim=-1)[0]
19
+ top = probs.topk(k)
20
+ return [(id2label[int(i)], float(p)) for i, p in zip(top.indices.tolist(), top.values.tolist())]
21
+
22
+
23
+ class _ThreeHeadPredictMixin:
24
+ """Top-k decoding for genre / style / artist heads (shared by CNN and CNN–RNN)."""
25
+
26
+ def predict_topk(
27
+ self,
28
+ x: torch.Tensor,
29
+ *,
30
+ genre_id2label: dict[int, str],
31
+ style_id2label: dict[int, str],
32
+ artist_id2label: dict[int, str],
33
+ k: int = 3,
34
+ ) -> tuple[list[tuple[str, float]], list[tuple[str, float]], list[tuple[str, float]]]:
35
+ self.eval()
36
+ with torch.no_grad():
37
+ lg, ls, la = self(x)
38
+ return (
39
+ _topk_from_logits(lg, genre_id2label, k),
40
+ _topk_from_logits(ls, style_id2label, k),
41
+ _topk_from_logits(la, artist_id2label, k),
42
+ )
43
+
44
+ def predict_topk_from_path(
45
+ self,
46
+ path: Path | str,
47
+ transform: torch.nn.Module,
48
+ device: torch.device,
49
+ *,
50
+ genre_id2label: dict[int, str],
51
+ style_id2label: dict[int, str],
52
+ artist_id2label: dict[int, str],
53
+ k: int = 3,
54
+ ) -> tuple[list[tuple[str, float]], list[tuple[str, float]], list[tuple[str, float]]]:
55
+ from PIL import Image
56
+
57
+ p = Path(path)
58
+ img = Image.open(p).convert("RGB")
59
+ x = transform(img).unsqueeze(0).to(device)
60
+ return self.predict_topk(
61
+ x,
62
+ genre_id2label=genre_id2label,
63
+ style_id2label=style_id2label,
64
+ artist_id2label=artist_id2label,
65
+ k=k,
66
+ )
67
+
68
+
69
+ class ResNet50ThreeHeads(_ThreeHeadPredictMixin, nn.Module):
70
  """ResNet-50 (ImageNet pretrained), GAP, then three linear heads: genre, style, artist."""
71
 
72
  def __init__(
 
120
  )
121
 
122
 
123
+ class ResNet50BiLSTMThreeHeads(_ThreeHeadPredictMixin, nn.Module):
124
  """
125
  ResNet-50 (ImageNet pretrained) feature map -> column pooling -> BiLSTM -> mean pool -> three heads.
126
 
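For reference, the decoding that `_topk_from_logits` performs can be sketched without torch — a dependency-free illustration of the same logic (stable softmax, then top-k label lookup), not the repo's code; the label map here is a placeholder:

```python
import math


def topk_from_logits(logits: list[float], id2label: dict[int, str], k: int) -> list[tuple[str, float]]:
    """Pure-Python sketch of the top-k softmax decoding above (batch size 1)."""
    k = min(k, len(logits))  # clamp k to the number of classes, as the mixin does
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]  # subtract max for numerical stability
    total = sum(exps)
    probs = [e / total for e in exps]
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(id2label[i], probs[i]) for i in ranked[:k]]


# highest logit wins: "g0" ranks first, then "g1"
print(topk_from_logits([2.0, 1.0, 0.1], {0: "g0", 1: "g1", 2: "g2"}, k=2))
```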
src/predict_format.py ADDED
@@ -0,0 +1,9 @@
1
+ """Format `model.predict_topk` tuples for Gradio summaries and JSON."""
2
+ from __future__ import annotations
3
+
4
+ from typing import Any
5
+
6
+
7
+ def topk_tuples_to_ui_items(rows: list[tuple[str, float]]) -> list[dict[str, Any]]:
8
+ """Match legacy Gradio `_topk` shape: `label` + `prob` rounded to 4 decimals."""
9
+ return [{"label": label, "prob": round(float(prob), 4)} for label, prob in rows]
tests/test_model_architectures.py CHANGED
@@ -2,6 +2,8 @@ import sys
2
  from pathlib import Path
3
 
4
  import torch
5
 
6
 
7
  def test_resnet50_three_heads_forward_shapes_no_weights() -> None:
@@ -17,6 +19,52 @@ def test_resnet50_three_heads_forward_shapes_no_weights() -> None:
17
  assert a.shape == (2, 23)
18
 
19
 
20
  def test_resnet50_bilstm_three_heads_forward_shapes_no_weights() -> None:
21
  root = Path(__file__).resolve().parent.parent
22
  sys.path.insert(0, str(root / "src"))
 
2
  from pathlib import Path
3
 
4
  import torch
5
+ from PIL import Image
6
+ from torchvision import transforms as T
7
 
8
 
9
  def test_resnet50_three_heads_forward_shapes_no_weights() -> None:
 
19
  assert a.shape == (2, 23)
20
 
21
 
22
+ def test_resnet50_three_heads_predict_topk() -> None:
23
+ root = Path(__file__).resolve().parent.parent
24
+ sys.path.insert(0, str(root / "src"))
25
+ from model import ResNet50ThreeHeads
26
+
27
+ model = ResNet50ThreeHeads(n_genre=10, n_style=27, n_artist=23, weights=None)
28
+ x = torch.randn(1, 3, 224, 224)
29
+ gmap = {i: f"g{i}" for i in range(10)}
30
+ smap = {i: f"s{i}" for i in range(27)}
31
+ amap = {i: f"a{i}" for i in range(23)}
32
+ g, s, a = model.predict_topk(
33
+ x,
34
+ genre_id2label=gmap,
35
+ style_id2label=smap,
36
+ artist_id2label=amap,
37
+ k=3,
38
+ )
39
+ assert len(g) == len(s) == len(a) == 3
40
+ assert all(isinstance(name, str) and 0.0 <= p <= 1.0 for name, p in g + s + a)
41
+
42
+
43
+ def test_resnet50_three_heads_predict_topk_from_path(tmp_path: Path) -> None:
44
+ root = Path(__file__).resolve().parent.parent
45
+ sys.path.insert(0, str(root / "src"))
46
+ from model import ResNet50ThreeHeads
47
+
48
+ img_path = tmp_path / "x.jpg"
49
+ Image.new("RGB", (256, 256), color=(120, 80, 40)).save(img_path, format="JPEG")
50
+
51
+ model = ResNet50ThreeHeads(n_genre=10, n_style=27, n_artist=23, weights=None)
52
+ device = torch.device("cpu")
53
+ transform = T.Compose([T.Resize(256), T.CenterCrop(224), T.ToTensor()])
54
+ gmap = {i: f"g{i}" for i in range(10)}
55
+ smap = {i: f"s{i}" for i in range(27)}
56
+ amap = {i: f"a{i}" for i in range(23)}
57
+ g, s, a = model.predict_topk_from_path(
58
+ img_path,
59
+ transform,
60
+ device,
61
+ genre_id2label=gmap,
62
+ style_id2label=smap,
63
+ artist_id2label=amap,
64
+ )
65
+ assert len(g) == len(s) == len(a) == 3
66
+
67
+
68
  def test_resnet50_bilstm_three_heads_forward_shapes_no_weights() -> None:
69
  root = Path(__file__).resolve().parent.parent
70
  sys.path.insert(0, str(root / "src"))
tests/test_predict_format.py ADDED
@@ -0,0 +1,12 @@
1
+ from pathlib import Path
2
+ import sys
3
+
4
+ ROOT = Path(__file__).resolve().parent.parent
5
+
6
+
7
+ def test_topk_tuples_to_ui_items_rounding() -> None:
8
+ sys.path.insert(0, str(ROOT / "src"))
9
+ from predict_format import topk_tuples_to_ui_items
10
+
11
+ out = topk_tuples_to_ui_items([("a", 0.123456789), ("b", 0.5)])
12
+ assert out == [{"label": "a", "prob": 0.1235}, {"label": "b", "prob": 0.5}]
tests/test_spot_check_excluded_post_impressionism.py ADDED
@@ -0,0 +1,47 @@
1
+ """Tests for scripts/spot_check_excluded_post_impressionism.py helpers."""
2
+ from __future__ import annotations
3
+
4
+ import importlib.util
5
+ import json
6
+ import sys
7
+ from pathlib import Path
8
+
9
+ import torch
10
+
11
+ ROOT = Path(__file__).resolve().parent.parent
12
+ SCRIPT = ROOT / "scripts" / "spot_check_excluded_post_impressionism.py"
13
+
14
+
15
+ def _load_module():
16
+ spec = importlib.util.spec_from_file_location("spot_check_excluded_post_impressionism", SCRIPT)
17
+ mod = importlib.util.module_from_spec(spec)
18
+ assert spec.loader is not None
19
+ sys.modules["spot_check_excluded_post_impressionism"] = mod
20
+ spec.loader.exec_module(mod)
21
+ return mod
22
+
23
+
24
+ def test_load_id2label_roundtrip(tmp_path: Path) -> None:
25
+ mod = _load_module()
26
+ p = tmp_path / "m.json"
27
+ p.write_text(json.dumps({"0": "foo", "1": "bar"}), encoding="utf-8")
28
+ assert mod.load_id2label(p) == {0: "foo", 1: "bar"}
29
+
30
+
31
+ def test_default_rel_paths_count() -> None:
32
+ mod = _load_module()
33
+ assert len(mod.DEFAULT_REL_PATHS) == 5
34
+
35
+
36
+ def test_resolve_device_force_cpu() -> None:
37
+ mod = _load_module()
38
+ assert mod.resolve_device(force_cpu=True).type == "cpu"
39
+
40
+
41
+ def test_load_label_maps_reads_repo_json() -> None:
42
+ mod = _load_module()
43
+ if not (mod.LABEL_MAPS_DIR / "genre_id2label.json").exists():
44
+ return  # label maps not exported in this checkout; nothing to assert
45
+ g, s, a = mod.load_label_maps()
46
+ assert "Post_Impressionism" in s.values()
47
+ assert len(g) >= 1 and len(a) >= 1
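The roundtrip test above pins down `load_id2label`'s contract (label-map JSON stores ids as string keys); a minimal equivalent — an assumption about the script's body, not a copy of it — amounts to:

```python
import json
from pathlib import Path


def load_id2label(path: Path) -> dict[int, str]:
    """Parse a label-map JSON ({"0": "foo", ...}) and coerce keys to int ids."""
    raw = json.loads(Path(path).read_text(encoding="utf-8"))
    return {int(k): v for k, v in raw.items()}
```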
tests/test_train_cnn_atomic_save.py ADDED
@@ -0,0 +1,33 @@
1
+ """Tests for scripts/train_cnn.py checkpoint atomic save helper."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import importlib.util
6
+ import sys
7
+ from pathlib import Path
8
+
9
+ import torch
10
+
11
+ ROOT = Path(__file__).resolve().parent.parent
12
+ SCRIPT = ROOT / "scripts" / "train_cnn.py"
13
+
14
+
15
+ def _load_train_cnn():
16
+ spec = importlib.util.spec_from_file_location("train_cnn", SCRIPT)
17
+ mod = importlib.util.module_from_spec(spec)
18
+ assert spec.loader is not None
19
+ sys.modules["train_cnn"] = mod
20
+ spec.loader.exec_module(mod)
21
+ return mod
22
+
23
+
24
+ def test_atomic_torch_save_roundtrip(tmp_path: Path) -> None:
25
+ mod = _load_train_cnn()
26
+ path = tmp_path / "ckpt.pt"
27
+ mod._atomic_torch_save({"k": 42, "t": torch.zeros(2)}, path)
28
+ assert path.is_file()
29
+ assert not path.with_suffix(".pt.tmp").exists()
30
+
31
+ data = torch.load(path, map_location="cpu", weights_only=False)
32
+ assert data["k"] == 42
33
+ assert list(data["t"].shape) == [2]
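The temp+replace pattern the commit applies to `torch.save` can be shown with plain stdlib writes — a sketch of the same guarantee, independent of torch (the `.tmp` naming mirrors what the test above checks):

```python
import os
from pathlib import Path


def atomic_write_bytes(data: bytes, path: Path) -> None:
    """Write to a sibling temp file, then os.replace() — the same temp+replace
    pattern _atomic_torch_save applies to torch.save checkpoints."""
    tmp = path.with_suffix(path.suffix + ".tmp")  # e.g. ckpt.pt -> ckpt.pt.tmp
    tmp.write_bytes(data)   # a crash mid-write corrupts only the temp file
    os.replace(tmp, path)   # atomic rename: the final path is never half-written
```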
tests/test_upload_model_to_hf.py CHANGED
@@ -1,6 +1,7 @@
1
  """Tests for scripts/upload_model_to_hf.py"""
2
 
3
  import importlib.util
 
4
  import sys
5
  from pathlib import Path
6
  from unittest.mock import MagicMock, patch
@@ -16,6 +17,24 @@ sys.modules["upload_model_to_hf"] = mod
16
  spec.loader.exec_module(mod)
17
 
18
 
19
  def test_build_id2label_from_selected_index(tmp_path: Path) -> None:
20
  index = tmp_path / "index.csv"
21
  pd.DataFrame(
 
1
  """Tests for scripts/upload_model_to_hf.py"""
2
 
3
  import importlib.util
4
+ import os
5
  import sys
6
  from pathlib import Path
7
  from unittest.mock import MagicMock, patch
 
17
  spec.loader.exec_module(mod)
18
 
19
 
20
+ def test_load_dotenv_from_repo_sets_hf_token(tmp_path: Path, monkeypatch) -> None:
21
+ monkeypatch.delenv("HF_TOKEN", raising=False)
22
+ (tmp_path / ".env").write_text("HF_TOKEN=fake_from_dotenv\n", encoding="utf-8")
23
+ with patch.object(mod, "ROOT", tmp_path):
24
+ mod._load_dotenv_from_repo()
25
+
26
+ assert os.environ.get("HF_TOKEN") == "fake_from_dotenv"
27
+
28
+
29
+ def test_load_dotenv_from_repo_overrides_stale_hf_token(tmp_path: Path, monkeypatch) -> None:
30
+ monkeypatch.setenv("HF_TOKEN", "stale_wrong_token")
31
+ (tmp_path / ".env").write_text("HF_TOKEN=good_from_dotenv\n", encoding="utf-8")
32
+ with patch.object(mod, "ROOT", tmp_path):
33
+ mod._load_dotenv_from_repo()
34
+
35
+ assert os.environ.get("HF_TOKEN") == "good_from_dotenv"
36
+
37
+
38
  def test_build_id2label_from_selected_index(tmp_path: Path) -> None:
39
  index = tmp_path / "index.csv"
40
  pd.DataFrame(
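The override semantics these dotenv tests pin down (the script itself uses python-dotenv's `load_dotenv(..., override=True)`) can be sketched with stdlib parsing alone — an illustration of the behavior, not the script's implementation:

```python
import os
from pathlib import Path


def load_env_file(path: Path, *, override: bool = True) -> None:
    """Parse KEY=VALUE lines; with override=True, a stale value already
    exported in the shell (e.g. HF_TOKEN) loses to the .env value."""
    for line in path.read_text(encoding="utf-8").splitlines():
        line = line.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue  # skip blanks, comments, and malformed lines
        key, _, value = line.partition("=")
        key, value = key.strip(), value.strip()
        if override or key not in os.environ:
            os.environ[key] = value
```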