Inference and Hub UX: shared predict_topk, atomic checkpoints, upload .env
- Add _ThreeHeadPredictMixin (predict_topk / predict_topk_from_path) on both
ResNet heads; Gradio uses it with predict_format top-k JSON and config
ImageNet normalize; device selection includes MPS.
- Add spot_check_excluded_post_impressionism.py (train transforms, --top-k)
and README spot-check table with GAP/Lin+He refs and BiLSTM wording.
- train_cnn: atomic torch.save via temp+replace; weights_only=False on loads;
clear error when last.pt is corrupt (suggest cp from best.pt).
- upload_model_to_hf: load .env with override=True; required --checkpoint;
default export to data/label_maps; drop --token; python-dotenv in
requirements; Dockerfile omits redundant --export-labels-dir.
- Tests for predict_topk, spot_check helpers, atomic save, dotenv override.
Made-with: Cursor
- Dockerfile +1 -1
- README.md +50 -2
- gradio/app.py +19 -14
- requirements.txt +1 -0
- scripts/spot_check_excluded_post_impressionism.py +137 -0
- scripts/train_cnn.py +25 -5
- scripts/upload_model_to_hf.py +41 -16
- src/model.py +64 -2
- src/predict_format.py +9 -0
- tests/test_model_architectures.py +48 -0
- tests/test_predict_format.py +12 -0
- tests/test_spot_check_excluded_post_impressionism.py +47 -0
- tests/test_train_cnn_atomic_save.py +33 -0
- tests/test_upload_model_to_hf.py +19 -0
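The atomic-checkpoint change above can be sketched without PyTorch; `atomic_save` here is a hypothetical stand-in that pickles instead of `torch.save`, but the mechanism (write a sibling temp file, then `os.replace` it into place) is the same one the commit adds to `train_cnn.py`:

```python
import os
import pickle
from pathlib import Path


def atomic_save(obj: object, path: Path) -> None:
    """Serialize to a sibling temp file, then os.replace() it into place.

    os.replace() swaps the directory entry atomically, so a process killed
    mid-write leaves either the previous checkpoint or the new one on disk,
    never a truncated file (the corrupt-last.pt failure mode this commit
    guards against).
    """
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    tmp = path.with_suffix(path.suffix + ".tmp")
    with open(tmp, "wb") as f:
        pickle.dump(obj, f)
    os.replace(tmp, path)
```

A crash between `pickle.dump` and `os.replace` leaves only the `.tmp` file behind; the real checkpoint path is never half-written.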
Dockerfile
@@ -39,7 +39,7 @@ else \
     echo \"[train] No last.pt found; training from scratch for ${EPOCHS} epochs\"; \
     python scripts/train_cnn_safe.py --arch \"$ARCH\" --epochs \"$EPOCHS\" --batch-size-primary \"$BATCH_SIZE_PRIMARY\" --batch-size-fallback \"$BATCH_SIZE_FALLBACK\"; \
   fi; \
-  python scripts/upload_model_to_hf.py --repo-id \"$MODEL_REPO_ID\" --checkpoint \"checkpoints/$ARCH/best.pt\"
+  python scripts/upload_model_to_hf.py --repo-id \"$MODEL_REPO_ID\" --checkpoint \"checkpoints/$ARCH/best.pt\"; \
   kill $SERVER_PID >/dev/null 2>&1 || true \
 "]
 
README.md
@@ -12,14 +12,62 @@ pinned: false
 
 # Arty
 
-WikiArt **genre
+**Arty** is a multi-task WikiArt classifier: **genre**, **style**, and **artist** in one model with two architectures — a **CNN baseline** (ResNet-50 + global pooling + three heads) and a **CNN–RNN** (same backbone, **bidirectional long short-term memory (BiLSTM)** over spatial features + three heads). This Hugging Face **Space** runs the **Gradio** app in [`gradio/app.py`](gradio/app.py); **weights** load from Hub model repos and **architecture** from [`src/model.py`](src/model.py). Each architecture corresponds to a model: [pdjota/cnn-baseline](https://huggingface.co/pdjota/cnn-baseline) or [pdjota/arty-cnn-rnn](https://huggingface.co/pdjota/arty-cnn-rnn).
 
 **Training** (Docker GPU job) lives in [`Dockerfile`](Dockerfile) — use a **separate** Space pointed at the same repo if you only want training, or run locally. Do not set `sdk: docker` on this Space if you want the Gradio UI.
 
+## About this project
+
+### Classification with three labels together
+
+**Genre** and **style** are relatively **generic**: many paintings share the same movement or subject category, and the model learns broad visual patterns that match those labels. **Artist** is **more specific** — we usually think of *who* painted it **within** a stylistic movement. Distinguishing painters who share a movement, like Sisley and Monet within Impressionism, or identifying Picasso across his distinct Symbolist and Cubist bodies of work, becomes challenging. The network must pick up **fine-grained** cues (palette, brushwork, recurring motifs) that sit on top of the same broad visual cues that support style and genre.
+
+We use a **shared convolutional trunk** (one set of image features) and **three separate heads** (genre, style, artist). The trunk carries **generic** painting features; the heads split **coarse** (genre, style) vs **fine** (artist) decisions so the artist task can specialize without forcing a single output to encode everything at once.
+
+### ResNet-50 and short fine-tuning
+
+**ResNet-50** is a deep convolutional network built from **residual (skip) connections** so very deep stacks train without vanishing gradients ([He et al., 2016](https://arxiv.org/abs/1512.03385); trained on **ImageNet** in that work). **Transfer learning** is standard: features from early conv layers tend to transfer across related visual domains better than random initialization ([Yosinski et al., 2014](https://arxiv.org/abs/1411.1792)). For **paintings**, [Zhao et al. (2021)](https://doi.org/10.1371/journal.pone.0248414) compare the same model families **with and without** ImageNet-based transfer on WikiArt (genre / style / artist) and report strong results with pretraining — so we **fine-tune** the backbone and heads for a **limited number of epochs** instead of training from scratch at ImageNet-scale cost.
+We use an **ArtGAN-aligned** WikiArt-style index: images catalogued with consistent **genre, style, and
+artist** labels; a few broken paths are excluded. A curated dataset is on the Hub (e.g. [`pdjota/artyset`]
+(https://huggingface.co/datasets/pdjota/artyset)) for reproducible training.
+
+The resulting classification is already good for some examples. **Spot check (CNN baseline `best.pt` with [`scripts/spot_check_excluded_post_impressionism.py`](scripts/spot_check_excluded_post_impressionism.py)):** five **Post_Impressionism** images under `data/wikiart_excluded/Post_Impressionism/` (not in the training index) — style top-1 and a short note:
+
+| Painting | Style (top-1) | Comment |
+| -------- | ------------- | ------- |
+| `henri-matisse_a-vase-with-oranges.jpg` | Post_Impressionism (~80%) | Still life; confident style match. |
+| `henri-de-toulouse-lautrec_portrait-of-vincent-van-gogh-1887.jpg` | Impressionism (~99%) | Sketchy handling reads as Impressionist; Post_Impressionism far behind. |
+| `pablo-picasso_seated-monkey-1905.jpg` | Post_Impressionism (~42%) | Close with Expressionism; **artist** top-1 Picasso (~93%). |
+| `paul-gauguin_a-seashore-1887.jpg` | Impressionism (~75%) | Post_Impressionism second (~16%). |
+| `a.y.-jackson_the-edge-of-the-maple-wood-1910.jpg` | Impressionism (~94%) | Landscape; artist head has no A.Y. Jackson class (23 ArtGAN artists). |
+
+### Bidirectional long short-term memory (BiLSTM) on top of the CNN
+
+Zhao et al. note that their setup uses **colour** information heavily and that **spatial** information could still improve classification. Standard **global average pooling (GAP)** after the last conv map **throws away layout**: it averages each channel over all spatial positions, so the classifier sees a **single vector per channel** with **no remaining (x, y) structure** ([Lin et al., 2014](https://arxiv.org/abs/1312.4400); ResNet-50 uses this pattern before its **fully connected (FC)** layer, [He et al., 2016](https://arxiv.org/abs/1512.03385)). That answers “what is present” but not “how it is arranged.” We keep the same ResNet backbone, then **turn the spatial grid into a sequence** (e.g. column-wise strips), run a **bidirectional long short-term memory (BiLSTM)**, then classify. The **reasoning** is: composition, figure–ground balance, and brushstroke patterns often have **left–right (or strip-wise) structure**; a sequence model can integrate **context** along that axis **bidirectionally**, which GAP does not model. The CNN–RNN is a **minimal, comparable** upgrade: same heads and training loop, different pooling.
+
+### Data: ArtGAN-aligned index and Hugging Face
+
+We align with the **ArtGAN / WikiArt** lineage so labels are **catalogue-consistent** for genre, style, and artist. We **trim** the index (drop broken/missing files, validate paths) and publish a curated dataset on the Hub (e.g. [`pdjota/artyset`](https://huggingface.co/datasets/pdjota/artyset)) so **training and demos are reproducible**.
+
+`scripts/train_cnn.py` uses **70% / 15% / 15%** train / val / test, **stratified by `artist_id`**. **Reasoning:** if we split randomly by image, we might put **almost all works of a rare artist** in one fold; **artist** is also the label that would most easily “leak” structurally (same brushwork in train vs test). Stratifying by artist keeps **each split’s artist mix** more representative, so validation/test **accuracy and loss** are comparable across runs and less dominated by **which artists** landed in which fold.
+
+### Training artifacts and this Space
+
+Runs save **PyTorch** checkpoints (`best.pt`, `last.pt`), **CSV** logs (`train_log.csv`, `results_summary.csv`), and we upload reference models with **`id2label` JSON** to Hub model repos. Training works on **CPU**, **Apple Silicon (MPS)**, or a **GPU Space**.
+
+**Gradio (this Space):** upload a painting and compare **CNN baseline** vs **CNN–RNN** top-k predictions. Env vars below select which Hub checkpoints to load.
+
 ## Env (optional)
 
 - `BASELINE_MODEL_REPO_ID` — default `pdjota/cnn-baseline`
 - `CNNRNN_MODEL_REPO_ID` — default `pdjota/arty-cnn-rnn`
 - `HF_TOKEN` — if model repos are gated
 
-More detail: [`gradio/README.md`](gradio/README.md), [`docs/monorepo_gradio_space.md`](docs/monorepo_gradio_space.md).
+More detail: [`gradio/README.md`](gradio/README.md), [`docs/monorepo_gradio_space.md`](docs/monorepo_gradio_space.md), research plan [`plan.md`](plan.md).
+
+### References (ResNet / transfer / WikiArt)
+
+1. He, K., Zhang, X., Ren, S., & Sun, J. (2016). *Deep residual learning for image recognition.* CVPR. [arXiv:1512.03385](https://arxiv.org/abs/1512.03385)
+2. Yosinski, J., Clune, J., Bengio, Y., & Lipson, H. (2014). *How transferable are features in deep neural networks?* NeurIPS. [arXiv:1411.1792](https://arxiv.org/abs/1411.1792)
+3. Zhao, W., Zhou, D., Qiu, X., & Jiang, W. (2021). *Compare the performance of the models in art classification.* PLOS ONE 16(3): e0248414. [DOI:10.1371/journal.pone.0248414](https://doi.org/10.1371/journal.pone.0248414)
+4. Lin, M., Chen, Q., & Yan, S. (2014). *Network in network.* ICLR. [arXiv:1312.4400](https://arxiv.org/abs/1312.4400) — global average pooling to aggregate conv feature maps before classification.
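The GAP-versus-sequence argument in the README change above can be made concrete with a shape-level sketch. This is pure Python on nested lists (the real model does the equivalent on a ResNet feature-map tensor; `gap` and `column_pool` are hypothetical names): GAP collapses a `(C, H, W)` map to one `C`-vector, while column pooling keeps a length-`W` sequence that a BiLSTM can read left to right.

```python
def gap(fmap):
    """Global average pooling: (C, H, W) -> (C,).

    Every spatial position is averaged away, so left-right layout is lost;
    the classifier only sees one number per channel.
    """
    return [
        sum(sum(row) for row in channel) / (len(channel) * len(channel[0]))
        for channel in fmap
    ]


def column_pool(fmap):
    """Column-wise pooling: (C, H, W) -> W vectors of size C.

    Each column of the feature map becomes one sequence step, preserving
    left-right order for a (Bi)LSTM to consume.
    """
    C, H, W = len(fmap), len(fmap[0]), len(fmap[0][0])
    return [
        [sum(fmap[c][y][x] for y in range(H)) / H for c in range(C)]
        for x in range(W)
    ]
```

`gap` returns a single vector with no positional index left, while `column_pool` returns `W` steps, which is exactly the structure a sequence model needs.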
gradio/app.py
@@ -18,7 +18,6 @@ from typing import Any, Dict, List, Optional, Tuple
 
 import gradio as gr
 import torch
-import torch.nn.functional as F
 from huggingface_hub import hf_hub_download
 from PIL import Image
 from torchvision import transforms as T
@@ -32,12 +31,19 @@ if not _SRC.exists():
     )
 sys.path.insert(0, str(REPO_ROOT / "src"))
 
+from config import IMAGENET_MEAN, IMAGENET_STD  # type: ignore
 from model import ResNet50BiLSTMThreeHeads  # type: ignore
 from model import ResNet50ThreeHeads  # type: ignore
+from predict_format import topk_tuples_to_ui_items  # type: ignore
 
 BASELINE_REPO = os.environ.get("BASELINE_MODEL_REPO_ID", "pdjota/cnn-baseline")
 CNNRNN_REPO = os.environ.get("CNNRNN_MODEL_REPO_ID", "pdjota/arty-cnn-rnn")
-
+if torch.cuda.is_available():
+    DEVICE = torch.device("cuda")
+elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
+    DEVICE = torch.device("mps")
+else:
+    DEVICE = torch.device("cpu")
 HF_TOKEN = os.environ.get("HF_TOKEN")
 
 transform = T.Compose(
@@ -45,7 +51,7 @@ transform = T.Compose(
     T.Resize(256),
     T.CenterCrop(224),
     T.ToTensor(),
-    T.Normalize(mean=
+    T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
     ]
 )
 
@@ -109,11 +115,6 @@ def _load(repo_id: str) -> Dict[str, Any]:
 
 # --- prediction helpers ----------------------------------------------------
 
-def _topk(logits: torch.Tensor, id2label: Dict[int, str], k: int = 3) -> List[Dict[str, Any]]:
-    probs = F.softmax(logits, dim=-1)[0]
-    vals, idxs = probs.topk(k)
-    return [{"label": id2label.get(int(i), str(int(i))), "prob": round(float(v), 4)} for v, i in zip(vals, idxs)]
-
 
 def _bucket(pct: float) -> str:
     if pct >= 80:
@@ -151,12 +152,16 @@ def predict(model_choice: str, image: Optional[Image.Image]) -> Tuple[str, str]:
     model = assets["model"]
 
     x = transform(image).unsqueeze(0).to(DEVICE)
-
-
-
-
-
-
+    g_t, s_t, a_t = model.predict_topk(
+        x,
+        genre_id2label=assets["genre"],
+        style_id2label=assets["style"],
+        artist_id2label=assets["artist"],
+        k=3,
+    )
+    g3 = topk_tuples_to_ui_items(g_t)
+    s3 = topk_tuples_to_ui_items(s_t)
+    a3 = topk_tuples_to_ui_items(a_t)
 
     summary = "\n".join([
         f"**Genre**: {_summarize(g3)}",
requirements.txt
@@ -9,6 +9,7 @@ scikit-learn>=1.2
 matplotlib>=3.7
 tqdm>=4.65
 huggingface_hub>=0.25.0,<1.0  # Gradio 5.x needs HfFolder; removed in hub 1.0
+python-dotenv>=1.0  # optional: HF_TOKEN from repo .env for upload scripts
 pytest>=7.0  # tests
 pytest-cov>=4.0  # coverage
 
scripts/spot_check_excluded_post_impressionism.py
@@ -0,0 +1,137 @@
+"""
+Spot-check the CNN baseline on fixed excluded Post_Impressionism images (README table).
+
+Images live under data/wikiart_excluded/Post_Impressionism/ (not in the training index).
+Uses the same eval transforms as training (`train_cnn.get_transforms(train=False)`).
+
+Usage (from repo root):
+
+    python scripts/spot_check_excluded_post_impressionism.py
+    python scripts/spot_check_excluded_post_impressionism.py --cpu
+    python scripts/spot_check_excluded_post_impressionism.py --top-k 5
+"""
+from __future__ import annotations
+
+import argparse
+import importlib.util
+import json
+import sys
+from pathlib import Path
+
+import torch
+
+ROOT = Path(__file__).resolve().parent.parent
+sys.path.insert(0, str(ROOT / "src"))
+
+from config import checkpoint_dir_for_arch
+from model import ResNet50ThreeHeads
+
+DEFAULT_REL_PATHS: tuple[str, ...] = (
+    "henri-matisse_a-vase-with-oranges.jpg",
+    "henri-de-toulouse-lautrec_portrait-of-vincent-van-gogh-1887.jpg",
+    "pablo-picasso_seated-monkey-1905.jpg",
+    "paul-gauguin_a-seashore-1887.jpg",
+    "a.y.-jackson_the-edge-of-the-maple-wood-1910.jpg",
+)
+
+LABEL_MAPS_DIR = ROOT / "data" / "label_maps"
+EXCLUDED_STYLE_DIR = ROOT / "data" / "wikiart_excluded" / "Post_Impressionism"
+
+
+def _load_train_cnn():
+    spec = importlib.util.spec_from_file_location("train_cnn", ROOT / "scripts" / "train_cnn.py")
+    mod = importlib.util.module_from_spec(spec)
+    assert spec.loader is not None
+    spec.loader.exec_module(mod)
+    return mod
+
+
+def load_id2label(path: Path) -> dict[int, str]:
+    with open(path, encoding="utf-8") as f:
+        return {int(k): v for k, v in json.load(f).items()}
+
+
+def load_label_maps() -> tuple[dict[int, str], dict[int, str], dict[int, str]]:
+    return (
+        load_id2label(LABEL_MAPS_DIR / "genre_id2label.json"),
+        load_id2label(LABEL_MAPS_DIR / "style_id2label.json"),
+        load_id2label(LABEL_MAPS_DIR / "artist_id2label.json"),
+    )
+
+
+def resolve_device(*, force_cpu: bool) -> torch.device:
+    if force_cpu:
+        return torch.device("cpu")
+    if torch.cuda.is_available():
+        return torch.device("cuda")
+    if getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
+        return torch.device("mps")
+    return torch.device("cpu")
+
+
+def main() -> None:
+    p = argparse.ArgumentParser(description="CNN spot-check on excluded Post_Impressionism examples.")
+    p.add_argument("--cpu", action="store_true", help="Force CPU")
+    p.add_argument(
+        "--checkpoint",
+        type=Path,
+        default=None,
+        help="Path to best.pt (default: checkpoints/<cnn>/best.pt from config)",
+    )
+    p.add_argument("--top-k", type=int, default=3, metavar="K", help="Top-k per head (default: 3)")
+    args = p.parse_args()
+
+    if args.top_k < 1:
+        print("ERROR: --top-k must be >= 1", file=sys.stderr)
+        sys.exit(1)
+
+    device = resolve_device(force_cpu=args.cpu)
+    ckpt_path = args.checkpoint if args.checkpoint is not None else checkpoint_dir_for_arch("cnn") / "best.pt"
+    if not ckpt_path.exists():
+        print(f"ERROR: checkpoint not found: {ckpt_path}", file=sys.stderr)
+        sys.exit(1)
+
+    genre_map, style_map, artist_map = load_label_maps()
+
+    ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
+    n_genre = ckpt["n_genre"]
+    n_style = ckpt["n_style"]
+    n_artist = ckpt["n_artist"]
+    model = ResNet50ThreeHeads(n_genre=n_genre, n_style=n_style, n_artist=n_artist, weights=None)
+    model.load_state_dict(ckpt["model_state_dict"])
+    model.to(device)
+
+    train_cnn = _load_train_cnn()
+    transform = train_cnn.get_transforms(train=False)
+
+    paths = [EXCLUDED_STYLE_DIR / name for name in DEFAULT_REL_PATHS]
+
+    print(f"Checkpoint: {ckpt_path}")
+    print(f"Device: {device}")
+    print(f"Top-k: {args.top_k}")
+    post_ids = [k for k, v in style_map.items() if v == "Post_Impressionism"]
+    print(f"Post_Impressionism style id(s): {post_ids}")
+    print()
+
+    for path in paths:
+        if not path.exists():
+            print(f"MISSING: {path}")
+            continue
+        g, s, a = model.predict_topk_from_path(
+            path,
+            transform,
+            device,
+            genre_id2label=genre_map,
+            style_id2label=style_map,
+            artist_id2label=artist_map,
+            k=args.top_k,
+        )
+        print("=" * 72)
+        print(path.name)
+        print("  genre (top-%d):" % args.top_k, g)
+        print("  style (top-%d):" % args.top_k, s)
+        print("  artist (top-%d):" % args.top_k, a)
+
+
+if __name__ == "__main__":
+    main()
scripts/train_cnn.py
@@ -26,6 +26,15 @@ from torchvision import transforms as T
 from sklearn.model_selection import train_test_split
 
 ROOT = Path(__file__).resolve().parent.parent
+
+
+def _atomic_torch_save(obj: object, path: Path) -> None:
+    """Write `path` via a temp file + `os.replace` so a kill mid-write does not truncate `last.pt` / `best.pt`."""
+    path = Path(path)
+    path.parent.mkdir(parents=True, exist_ok=True)
+    tmp = path.with_suffix(path.suffix + ".tmp")
+    torch.save(obj, tmp)
+    os.replace(tmp, path)
 sys.path.insert(0, str(ROOT / "src"))
 
 from config import (
@@ -134,7 +143,18 @@ def main() -> None:
     resume_path = ckpt_dir / "last.pt"
 
     if args.resume and resume_path.exists():
-
+        try:
+            ckpt = torch.load(resume_path, map_location=device, weights_only=False)
+        except Exception as e:
+            print(
+                f"[{now_ts()}] ERROR: Cannot load {resume_path} (often a truncated file if the process was killed "
+                f"during `torch.save`).\n"
+                f"  {e}\n"
+                f"  Fix: if best.pt is intact, copy it over last.pt and resume, e.g.\n"
+                f"    cp {ckpt_dir / 'best.pt'} {resume_path}",
+                file=sys.stderr,
+            )
+            sys.exit(1)
         start_epoch = ckpt["epoch"] + 1
         best_val_loss = ckpt.get("val_loss", float("inf"))
         extra = args.epochs if args.epochs is not None else 10
@@ -275,9 +295,9 @@ def main() -> None:
             "n_style": N_STYLE,
             "n_artist": N_ARTIST,
         }
-
+        _atomic_torch_save(ckpt, ckpt_dir / "last.pt")
         if is_best:
-
+            _atomic_torch_save(ckpt, ckpt_dir / "best.pt")
 
         log_row = {
             "epoch": epoch,
@@ -315,7 +335,7 @@ def main() -> None:
             "batch_in_epoch": current_batch_in_epoch,
             "num_batches_in_epoch": current_num_batches_in_epoch,
         }
-
+        _atomic_torch_save(interrupted_ckpt, ckpt_dir / "last.pt")
         print(
             "\n"
             f"[{now_ts()}] Stopped by user (Ctrl+C). Saved resumable checkpoint to "
@@ -327,7 +347,7 @@ def main() -> None:
     # Save best-val results summary
     best_ckpt_path = ckpt_dir / "best.pt"
     if best_ckpt_path.exists():
-        best_ckpt = torch.load(best_ckpt_path, map_location="cpu")
+        best_ckpt = torch.load(best_ckpt_path, map_location="cpu", weights_only=False)
         summary = {
             "best_epoch": best_ckpt.get("epoch"),
             "val_loss": best_ckpt.get("val_loss"),
scripts/upload_model_to_hf.py
@@ -1,28 +1,41 @@
 """
-Upload a trained checkpoint to
-export id→label JSON files locally.
+Upload a trained checkpoint and id→label JSONs to a Hugging Face model repo (for Spaces / demos).
 
-
+Usage (repo root):
 
-
-
-
+    python scripts/upload_model_to_hf.py --repo-id USER/reponame --checkpoint PATH/TO/best.pt
+
+`HF_TOKEN`: repo `.env` (python-dotenv) wins over an existing shell `HF_TOKEN` when the key appears in `.env`
+(`load_dotenv(..., override=True)`). Use a token with **write** access to the model repo. Local labels: `data/label_maps/`.
 """
 
 import argparse
 import json
+import os
 import sys
 from pathlib import Path
 
 import torch
 
 ROOT = Path(__file__).resolve().parent.parent
-
-
+
+
+def _load_dotenv_from_repo() -> None:
+    """Load repo `.env` into os.environ. Keys in `.env` override the same keys already in the environment
+    (fixes stale HF_TOKEN from the shell or IDE masking a valid token in `.env`)."""
+    env_path = ROOT / ".env"
+    if not env_path.is_file():
+        return
+    try:
+        from dotenv import load_dotenv
+    except ImportError:
+        return
+    load_dotenv(env_path, override=True)
+
 
 DATA_DIR = ROOT / "data"
 INDEX_SELECTED = DATA_DIR / "wikiart_index_selected.csv"
-
+LABEL_EXPORT_DEFAULT = DATA_DIR / "label_maps"
 
 
 def build_id2label_from_selected_index(index_path: Path) -> tuple[dict[str, str], dict[str, str], dict[str, str]]:
@@ -101,22 +114,34 @@ def upload_checkpoint_and_labels(
 
 
 def main() -> None:
-
+    _load_dotenv_from_repo()
+
+    p = argparse.ArgumentParser(
+        description="Upload model checkpoint + id2label JSONs to Hugging Face Hub. "
+        "Loads repo-root .env by default (HF_TOKEN) when python-dotenv is installed."
+    )
     p.add_argument("--repo-id", required=True, help="Model repo id, e.g. username/arty-cnn-baseline")
     p.add_argument(
         "--checkpoint",
         type=Path,
-
-        help=
+        required=True,
+        help="Checkpoint file to upload (e.g. checkpoints/cnn_baseline/best.pt or checkpoints/cnnrnn/best.pt)",
     )
-    p.add_argument("--token", default=None, help="HF token (default: HF_TOKEN env)")
     p.add_argument("--index", type=Path, default=INDEX_SELECTED, help="Selected index CSV (default: data/wikiart_index_selected.csv)")
-    p.add_argument(
+    p.add_argument(
+        "--export-labels-dir",
+        type=Path,
+        default=LABEL_EXPORT_DEFAULT,
+        help=f"Write *_id2label.json here (default: {LABEL_EXPORT_DEFAULT})",
+    )
     args = p.parse_args()
 
-    token =
+    token = os.environ.get("HF_TOKEN")
     if not token:
-        print(
+        print(
+            "Missing token: add HF_TOKEN to repo-root .env",
+            file=sys.stderr,
+        )
         sys.exit(1)
 
     # quick sanity load of checkpoint format
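The `override=True` precedence that `_load_dotenv_from_repo` relies on can be illustrated without python-dotenv. `load_env_file` below is a hypothetical minimal parser, shown only to demonstrate the rule: with override off, an existing environment key (e.g. a stale shell `HF_TOKEN`) shadows the file; with override on, the file wins.

```python
def load_env_file(path, environ, override=False):
    """Minimal .env loader mirroring python-dotenv's override flag.

    override=False: an existing key in `environ` shadows the file.
    override=True: the file's value replaces it, which is the behavior
    the upload script wants for a repo-root .env.
    """
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            # skip blanks, comments, and lines without KEY=VALUE shape
            if not line or line.startswith("#") or "=" not in line:
                continue
            key, _, value = line.partition("=")
            key, value = key.strip(), value.strip()
            if override or key not in environ:
                environ[key] = value
```

The real script passes the process environment (`os.environ`) instead of a plain dict; the precedence logic is the same.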
src/model.py
@@ -1,10 +1,72 @@
 """ResNet-50 backbone variants for multi-task classification."""
+from __future__ import annotations
+
+from pathlib import Path
+
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from torchvision.models import ResNet50_Weights, resnet50


-class ResNet50ThreeHeads(nn.Module):
+def _topk_from_logits(
+    logits: torch.Tensor, id2label: dict[int, str], k: int
+) -> list[tuple[str, float]]:
+    """Map top-k softmax probabilities to label names (batch size 1)."""
+    n = logits.size(-1)
+    k = min(k, n)
+    probs = F.softmax(logits, dim=-1)[0]
+    top = probs.topk(k)
+    return [(id2label[int(i)], float(p)) for i, p in zip(top.indices.tolist(), top.values.tolist())]
+
+
+class _ThreeHeadPredictMixin:
+    """Top-k decoding for genre / style / artist heads (shared by CNN and CNN–RNN)."""
+
+    def predict_topk(
+        self,
+        x: torch.Tensor,
+        *,
+        genre_id2label: dict[int, str],
+        style_id2label: dict[int, str],
+        artist_id2label: dict[int, str],
+        k: int = 3,
+    ) -> tuple[list[tuple[str, float]], list[tuple[str, float]], list[tuple[str, float]]]:
+        self.eval()
+        with torch.no_grad():
+            lg, ls, la = self(x)
+        return (
+            _topk_from_logits(lg, genre_id2label, k),
+            _topk_from_logits(ls, style_id2label, k),
+            _topk_from_logits(la, artist_id2label, k),
+        )
+
+    def predict_topk_from_path(
+        self,
+        path: Path | str,
+        transform: torch.nn.Module,
+        device: torch.device,
+        *,
+        genre_id2label: dict[int, str],
+        style_id2label: dict[int, str],
+        artist_id2label: dict[int, str],
+        k: int = 3,
+    ) -> tuple[list[tuple[str, float]], list[tuple[str, float]], list[tuple[str, float]]]:
+        from PIL import Image
+
+        p = Path(path)
+        img = Image.open(p).convert("RGB")
+        x = transform(img).unsqueeze(0).to(device)
+        return self.predict_topk(
+            x,
+            genre_id2label=genre_id2label,
+            style_id2label=style_id2label,
+            artist_id2label=artist_id2label,
+            k=k,
+        )
+
+
+class ResNet50ThreeHeads(_ThreeHeadPredictMixin, nn.Module):
     """ResNet-50 (ImageNet pretrained), GAP, then three linear heads: genre, style, artist."""

     def __init__(
@@ -58,7 +120,7 @@ class ResNet50ThreeHeads(nn.Module):
         )


-class ResNet50BiLSTMThreeHeads(nn.Module):
+class ResNet50BiLSTMThreeHeads(_ThreeHeadPredictMixin, nn.Module):
     """
     ResNet-50 (ImageNet pretrained) feature map -> column pooling -> BiLSTM -> mean pool -> three heads.
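`_topk_from_logits` above clamps `k` to the number of classes, softmaxes once, and maps indices through `id2label`. The same decoding contract can be mirrored without torch — `topk_from_logits` below is a pure-Python sketch with made-up labels, not the module's function:

```python
import math


def topk_from_logits(logits: list[float], id2label: dict[int, str], k: int) -> list[tuple[str, float]]:
    """Pure-Python mirror of _topk_from_logits: softmax, then the k best labels."""
    k = min(k, len(logits))                   # clamp k to the number of classes
    m = max(logits)                           # subtract max for numerical stability
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    order = sorted(range(len(probs)), key=probs.__getitem__, reverse=True)
    return [(id2label[i], probs[i]) for i in order[:k]]


labels = {0: "portrait", 1: "landscape", 2: "still_life"}  # illustrative labels
top = topk_from_logits([2.0, 1.0, 0.1], labels, k=5)       # k=5 clamps to 3 classes
```

The clamp is what lets callers pass a fixed `k=3` even against a head with fewer classes, which is the edge the tests below rely on.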
src/predict_format.py (new file)
@@ -0,0 +1,9 @@
+"""Format `model.predict_topk` tuples for Gradio summaries and JSON."""
+from __future__ import annotations
+
+from typing import Any
+
+
+def topk_tuples_to_ui_items(rows: list[tuple[str, float]]) -> list[dict[str, Any]]:
+    """Match legacy Gradio `_topk` shape: `label` + `prob` rounded to 4 decimals."""
+    return [{"label": label, "prob": round(float(prob), 4)} for label, prob in rows]
tests/test_model_architectures.py
@@ -2,6 +2,8 @@ import sys
 from pathlib import Path

 import torch
+from PIL import Image
+from torchvision import transforms as T


 def test_resnet50_three_heads_forward_shapes_no_weights() -> None:
@@ -17,6 +19,52 @@ def test_resnet50_three_heads_forward_shapes_no_weights() -> None:
     assert a.shape == (2, 23)


+def test_resnet50_three_heads_predict_topk() -> None:
+    root = Path(__file__).resolve().parent.parent
+    sys.path.insert(0, str(root / "src"))
+    from model import ResNet50ThreeHeads
+
+    model = ResNet50ThreeHeads(n_genre=10, n_style=27, n_artist=23, weights=None)
+    x = torch.randn(1, 3, 224, 224)
+    gmap = {i: f"g{i}" for i in range(10)}
+    smap = {i: f"s{i}" for i in range(27)}
+    amap = {i: f"a{i}" for i in range(23)}
+    g, s, a = model.predict_topk(
+        x,
+        genre_id2label=gmap,
+        style_id2label=smap,
+        artist_id2label=amap,
+        k=3,
+    )
+    assert len(g) == len(s) == len(a) == 3
+    assert all(isinstance(name, str) and 0.0 <= p <= 1.0 for name, p in g + s + a)
+
+
+def test_resnet50_three_heads_predict_topk_from_path(tmp_path: Path) -> None:
+    root = Path(__file__).resolve().parent.parent
+    sys.path.insert(0, str(root / "src"))
+    from model import ResNet50ThreeHeads
+
+    img_path = tmp_path / "x.jpg"
+    Image.new("RGB", (256, 256), color=(120, 80, 40)).save(img_path, format="JPEG")
+
+    model = ResNet50ThreeHeads(n_genre=10, n_style=27, n_artist=23, weights=None)
+    device = torch.device("cpu")
+    transform = T.Compose([T.Resize(256), T.CenterCrop(224), T.ToTensor()])
+    gmap = {i: f"g{i}" for i in range(10)}
+    smap = {i: f"s{i}" for i in range(27)}
+    amap = {i: f"a{i}" for i in range(23)}
+    g, s, a = model.predict_topk_from_path(
+        img_path,
+        transform,
+        device,
+        genre_id2label=gmap,
+        style_id2label=smap,
+        artist_id2label=amap,
+    )
+    assert len(g) == len(s) == len(a) == 3
+
+
 def test_resnet50_bilstm_three_heads_forward_shapes_no_weights() -> None:
     root = Path(__file__).resolve().parent.parent
     sys.path.insert(0, str(root / "src"))
tests/test_predict_format.py (new file)
@@ -0,0 +1,12 @@
+from pathlib import Path
+import sys
+
+ROOT = Path(__file__).resolve().parent.parent
+
+
+def test_topk_tuples_to_ui_items_rounding() -> None:
+    sys.path.insert(0, str(ROOT / "src"))
+    from predict_format import topk_tuples_to_ui_items
+
+    out = topk_tuples_to_ui_items([("a", 0.123456789), ("b", 0.5)])
+    assert out == [{"label": "a", "prob": 0.1235}, {"label": "b", "prob": 0.5}]
tests/test_spot_check_excluded_post_impressionism.py (new file)
@@ -0,0 +1,47 @@
+"""Tests for scripts/spot_check_excluded_post_impressionism.py helpers."""
+from __future__ import annotations
+
+import importlib.util
+import json
+import sys
+from pathlib import Path
+
+import torch
+
+ROOT = Path(__file__).resolve().parent.parent
+SCRIPT = ROOT / "scripts" / "spot_check_excluded_post_impressionism.py"
+
+
+def _load_module():
+    spec = importlib.util.spec_from_file_location("spot_check_excluded_post_impressionism", SCRIPT)
+    mod = importlib.util.module_from_spec(spec)
+    assert spec.loader is not None
+    sys.modules["spot_check_excluded_post_impressionism"] = mod
+    spec.loader.exec_module(mod)
+    return mod
+
+
+def test_load_id2label_roundtrip(tmp_path: Path) -> None:
+    mod = _load_module()
+    p = tmp_path / "m.json"
+    p.write_text(json.dumps({"0": "foo", "1": "bar"}), encoding="utf-8")
+    assert mod.load_id2label(p) == {0: "foo", 1: "bar"}
+
+
+def test_default_rel_paths_count() -> None:
+    mod = _load_module()
+    assert len(mod.DEFAULT_REL_PATHS) == 5
+
+
+def test_resolve_device_force_cpu() -> None:
+    mod = _load_module()
+    assert mod.resolve_device(force_cpu=True).type == "cpu"
+
+
+def test_load_label_maps_reads_repo_json() -> None:
+    mod = _load_module()
+    if not (mod.LABEL_MAPS_DIR / "genre_id2label.json").exists():
+        return
+    g, s, a = mod.load_label_maps()
+    assert "Post_Impressionism" in s.values()
+    assert len(g) >= 1 and len(a) >= 1
tests/test_train_cnn_atomic_save.py (new file)
@@ -0,0 +1,33 @@
+"""Tests for scripts/train_cnn.py checkpoint atomic save helper."""
+
+from __future__ import annotations
+
+import importlib.util
+import sys
+from pathlib import Path
+
+import torch
+
+ROOT = Path(__file__).resolve().parent.parent
+SCRIPT = ROOT / "scripts" / "train_cnn.py"
+
+
+def _load_train_cnn():
+    spec = importlib.util.spec_from_file_location("train_cnn", SCRIPT)
+    mod = importlib.util.module_from_spec(spec)
+    assert spec.loader is not None
+    sys.modules["train_cnn"] = mod
+    spec.loader.exec_module(mod)
+    return mod
+
+
+def test_atomic_torch_save_roundtrip(tmp_path: Path) -> None:
+    mod = _load_train_cnn()
+    path = tmp_path / "ckpt.pt"
+    mod._atomic_torch_save({"k": 42, "t": torch.zeros(2)}, path)
+    assert path.is_file()
+    assert not path.with_suffix(".pt.tmp").exists()
+
+    data = torch.load(path, map_location="cpu", weights_only=False)
+    assert data["k"] == 42
+    assert list(data["t"].shape) == [2]
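The roundtrip test above exercises `train_cnn._atomic_torch_save`, which per the changelog writes through a temp file and then replaces. A minimal sketch of that temp+`os.replace` pattern — `atomic_save` is an illustrative name, and pickle stands in for `torch.save` so the example carries no torch dependency:

```python
import os
import pickle
import tempfile
from pathlib import Path


def atomic_save(obj, path: Path) -> None:
    """Write obj to <path>.tmp, then rename onto path.

    A crash mid-write leaves at worst a stray .tmp file, never a
    truncated checkpoint at the final path.
    """
    tmp = path.with_suffix(path.suffix + ".tmp")  # ckpt.pt -> ckpt.pt.tmp
    tmp.write_bytes(pickle.dumps(obj))            # torch.save(obj, tmp) in the real helper
    os.replace(tmp, path)                         # swaps the file in atomically where the OS supports it


# Demo (temp directory is left for the OS to clean up):
ckpt = Path(tempfile.mkdtemp()) / "ckpt.pt"
atomic_save({"epoch": 3, "best_acc": 0.81}, ckpt)
```

This is why the test can assert both that `best.pt` exists and that no `.pt.tmp` sibling is left behind: `os.replace` consumes the temp file on success.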
tests/test_upload_model_to_hf.py
@@ -1,6 +1,7 @@
 """Tests for scripts/upload_model_to_hf.py"""

 import importlib.util
+import os
 import sys
 from pathlib import Path
 from unittest.mock import MagicMock, patch
@@ -16,6 +17,24 @@ sys.modules["upload_model_to_hf"] = mod
 spec.loader.exec_module(mod)


+def test_load_dotenv_from_repo_sets_hf_token(tmp_path: Path, monkeypatch) -> None:
+    monkeypatch.delenv("HF_TOKEN", raising=False)
+    (tmp_path / ".env").write_text("HF_TOKEN=fake_from_dotenv\n", encoding="utf-8")
+    with patch.object(mod, "ROOT", tmp_path):
+        mod._load_dotenv_from_repo()
+
+    assert os.environ.get("HF_TOKEN") == "fake_from_dotenv"
+
+
+def test_load_dotenv_from_repo_overrides_stale_hf_token(tmp_path: Path, monkeypatch) -> None:
+    monkeypatch.setenv("HF_TOKEN", "stale_wrong_token")
+    (tmp_path / ".env").write_text("HF_TOKEN=good_from_dotenv\n", encoding="utf-8")
+    with patch.object(mod, "ROOT", tmp_path):
+        mod._load_dotenv_from_repo()
+
+    assert os.environ.get("HF_TOKEN") == "good_from_dotenv"
+
+
 def test_build_id2label_from_selected_index(tmp_path: Path) -> None:
     index = tmp_path / "index.csv"
     pd.DataFrame(