---
license: mit
tags:
- art-classification
- stable-diffusion
- wikiart
---
# ArtExtract Task 1 — Checkpoints
Trained classifiers for painting style, artist, and genre on the WikiArt dataset.
Two of the three model types use Stable Diffusion U-Net activations as features;
the third is a raw-image baseline.
## Models
Three model types, each trained for three tasks (style / artist / genre):

**ConvLSTM**: spatial Conv-LSTM over the 16×16 activation grid from SD `down_blocks.2`.
It treats each spatial position as a token in a sequence, and attention pooling
weights the regions that matter for the prediction.

**MLP probe**: simple MLP on global-average-pooled activations (3840-d).
A fast, lightweight probe used to check the quality of the SD features.

**ResNet50**: ImageNet-pretrained ResNet50 fine-tuned on raw paintings.
Baseline that does not use diffusion features.
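
The 3840-d input of the MLP probe is consistent with concatenating three globally average-pooled 1280-channel maps, one per hooked block listed under Training. That reading is an assumption, not something the checkpoints state. The sketch below illustrates it with random stand-in tensors. The function name `pool_and_concat`, the spatial sizes, the hidden width of 512, and the 27-class head are all hypothetical.

```python
import torch


def pool_and_concat(activations):
    """Global-average-pool each (B, C, H, W) map over H and W, then concatenate channels."""
    pooled = [a.mean(dim=(2, 3)) for a in activations]  # each entry becomes (B, C)
    return torch.cat(pooled, dim=1)


# Random stand-ins for hooked activations (batch of 2, 1280 channels each).
# The spatial sizes are illustrative and may differ from the real U-Net outputs.
acts = [
    torch.randn(2, 1280, 16, 16),  # stand-in for down_blocks.2
    torch.randn(2, 1280, 8, 8),    # stand-in for mid_block
    torch.randn(2, 1280, 32, 32),  # stand-in for up_blocks.1
]
features = pool_and_concat(acts)
print(features.shape)  # torch.Size([2, 3840])

# Hypothetical probe head: hidden width and class count are assumptions.
probe = torch.nn.Sequential(
    torch.nn.Linear(3840, 512),
    torch.nn.ReLU(),
    torch.nn.Linear(512, 27),
)
logits = probe(features)
```

Pooling removes the spatial layout, so the probe sees only which channels are active on average. That is the main difference from the ConvLSTM, which keeps the activation grid.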
## Checkpoint format
Each `best.pt` is a standard PyTorch checkpoint dict:
```python
{
    "epoch": int,
    "model_state": OrderedDict,  # load with model.load_state_dict()
    "opt_state": OrderedDict,
    "val_acc": float,
}
```
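
As a rough illustration of this layout, the sketch below writes and re-reads a checkpoint with the same four keys. It uses a tiny stand-in model and an in-memory buffer; the model, optimizer, and values are placeholders and are not taken from the actual training code.

```python
import io

import torch

# Tiny stand-in model and optimizer; the real architectures live in the training code.
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

checkpoint = {
    "epoch": 3,                            # illustrative value
    "model_state": model.state_dict(),
    "opt_state": optimizer.state_dict(),
    "val_acc": 0.5,                        # illustrative value
}

buffer = io.BytesIO()  # in-memory file so the sketch leaves nothing on disk
torch.save(checkpoint, buffer)
buffer.seek(0)

restored = torch.load(buffer, map_location="cpu")
print(sorted(restored.keys()))  # ['epoch', 'model_state', 'opt_state', 'val_acc']

# Resuming training would restore both the weights and the optimizer state.
model.load_state_dict(restored["model_state"])
optimizer.load_state_dict(restored["opt_state"])
```

For inference only `model_state` is needed; `opt_state` matters when training is resumed.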
## Loading example
```python
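import torch
from src.models import build_model  # requires the project's `src` package on the import path

# Load on CPU so the checkpoint opens without a GPU.
ckpt = torch.load("convlstm_style/best.pt", map_location="cpu")
# num_classes must match the head the checkpoint was trained with.
model = build_model("convlstm", num_classes=27)
model.load_state_dict(ckpt["model_state"])
model.eval()  # inference mode: disables dropout and batch-norm statistic updates
print(f"Loaded from epoch {ckpt['epoch']}, val_acc={ckpt['val_acc']:.4f}")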
```
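
Once a model is loaded, predictions can be read off the logits. The helper below is a minimal sketch of that step, not part of this repository. `top_k_predictions` and the placeholder class names are hypothetical, and a small linear layer stands in for the loaded classifier because the real input format depends on the model type.

```python
import torch


@torch.no_grad()
def top_k_predictions(model, inputs, class_names, k=3):
    """Return the k most probable (label, probability) pairs for each input row."""
    model.eval()
    probs = torch.softmax(model(inputs), dim=1)
    top_p, top_i = probs.topk(k, dim=1)  # sorted from most to least probable
    return [
        [(class_names[i], p) for p, i in zip(row_p, row_i)]
        for row_p, row_i in zip(top_p.tolist(), top_i.tolist())
    ]


# Stand-in classifier and inputs; swap in the loaded model and real features.
stand_in = torch.nn.Linear(8, 27)
class_names = [f"class_{i}" for i in range(27)]  # placeholder labels
preds = top_k_predictions(stand_in, torch.randn(2, 8), class_names, k=3)
for label, prob in preds[0]:
    print(f"{label}: {prob:.3f}")
```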
## Training
SD features were extracted from `valhalla/sd-wikiart-v2` at timestep t=200
by hooking the U-Net blocks `down_blocks.2`, `mid_block`, and `up_blocks.1`.
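
The hooking step can be sketched with PyTorch forward hooks. The example below runs on a toy module that only mirrors the three submodule names, so it needs no model download. `ToyUNet` and `register_feature_hooks` are hypothetical names and this is not the extraction script used for these checkpoints. With the real U-Net, the input would be a latent noised to t=200, passed together with the timestep and conditioning.

```python
import torch
from torch import nn


class ToyUNet(nn.Module):
    """Tiny stand-in that mirrors the hooked submodule names; not the real SD U-Net."""

    def __init__(self):
        super().__init__()
        self.down_blocks = nn.ModuleList(
            [nn.Conv2d(4, 8, 3, padding=1)]
            + [nn.Conv2d(8, 8, 3, padding=1) for _ in range(2)]
        )
        self.mid_block = nn.Conv2d(8, 8, 3, padding=1)
        self.up_blocks = nn.ModuleList(
            [nn.Conv2d(8, 8, 3, padding=1) for _ in range(2)]
        )

    def forward(self, x):
        for block in self.down_blocks:
            x = block(x)
        x = self.mid_block(x)
        for block in self.up_blocks:
            x = block(x)
        return x


def register_feature_hooks(model, layer_names):
    """Attach forward hooks that store each named submodule's output in a dict."""
    features, handles = {}, []
    for name in layer_names:
        module = model.get_submodule(name)

        def hook(_module, _inputs, output, name=name):
            # Some blocks return a tuple; keep the first element in that case.
            out = output[0] if isinstance(output, tuple) else output
            features[name] = out.detach()

        handles.append(module.register_forward_hook(hook))
    return features, handles


model = ToyUNet().eval()
layers = ["down_blocks.2", "mid_block", "up_blocks.1"]
features, handles = register_feature_hooks(model, layers)

with torch.no_grad():
    model(torch.randn(1, 4, 16, 16))  # one forward pass fills the dict

for handle in handles:
    handle.remove()  # detach hooks so later passes are unaffected

print({name: tuple(feat.shape) for name, feat in features.items()})
```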