---
license: mit
tags:
- art-classification
- stable-diffusion
- wikiart
---

# ArtExtract Task 1 — Checkpoints

Trained classifiers for painting style, artist, and genre classification on the WikiArt dataset.
The ConvLSTM and MLP-probe models use Stable Diffusion U-Net activations as features; the
ResNet50 baseline works on raw paintings.

## Models

Three model types, each trained for three tasks (style / artist / genre):

**ConvLSTM** — spatial Conv-LSTM over the 16×16 activation grid from SD `down_blocks.2`.
It treats each spatial position as a token in a sequence; attention pooling then weights
the regions that matter most for the prediction.

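The token view and attention pooling described above can be sketched in PyTorch as follows. This is not the checkpoint's actual module, which lives in `src.models`. The class name `SpatialAttentionPool`, the single linear scoring layer, and the 1280-channel input are assumptions for illustration, and the ConvLSTM recurrence itself is omitted.

```python
import torch
import torch.nn as nn


class SpatialAttentionPool(nn.Module):
    """Hypothetical attention pooling over the spatial positions of a (B, C, H, W) map."""

    def __init__(self, channels: int):
        super().__init__()
        # One scalar relevance score per spatial token.
        self.score = nn.Linear(channels, 1)

    def forward(self, fmap: torch.Tensor):
        # (B, C, H, W) -> (B, H*W, C): one token per spatial position.
        tokens = fmap.flatten(2).transpose(1, 2)
        # Softmax over the H*W positions gives the attention weights.
        weights = torch.softmax(self.score(tokens).squeeze(-1), dim=1)  # (B, H*W)
        # Attention-weighted sum of tokens -> one C-dim vector per image.
        pooled = torch.einsum("bn,bnc->bc", weights, tokens)
        return pooled, weights


pool = SpatialAttentionPool(channels=1280)
x = torch.randn(2, 1280, 16, 16)  # dummy features on a 16×16 grid
pooled, weights = pool(x)
print(pooled.shape, weights.shape)  # torch.Size([2, 1280]) torch.Size([2, 256])
```

The returned `weights` can be reshaped to 16×16 to visualize which regions the pooling focuses on.
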
**MLP probe** — simple MLP on global-average-pooled activations (3840-d).
A fast probe used to check how much task-relevant information the SD features carry.

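The probe's input side can be sketched as follows, under stated assumptions. Each of the three hooked blocks is global-average-pooled and the results are concatenated. If each block outputs 1280 channels, as in the SD v1 U-Net, then 3 × 1280 = 3840, which matches the dimension above. The hidden width (512), the dropout rate, and the function and class names are hypothetical; the released checkpoints may use a different head.

```python
import torch
import torch.nn as nn


def pool_and_concat(feature_maps):
    """Global-average-pool each (B, C, H, W) map over H, W and concatenate along C."""
    return torch.cat([f.mean(dim=(2, 3)) for f in feature_maps], dim=1)


class MLPProbe(nn.Module):
    """Hypothetical MLP head; layer sizes are illustrative, not the trained config."""

    def __init__(self, in_dim: int = 3840, hidden: int = 512, num_classes: int = 27):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


# Dummy activations: spatial sizes are arbitrary here because pooling removes them.
maps = [torch.randn(4, 1280, 16, 16), torch.randn(4, 1280, 8, 8), torch.randn(4, 1280, 32, 32)]
feats = pool_and_concat(maps)
print(feats.shape)  # torch.Size([4, 3840])
logits = MLPProbe()(feats)
print(logits.shape)  # torch.Size([4, 27])
```
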
**ResNet50** — ImageNet-pretrained ResNet50 fine-tuned on raw paintings.
Baseline that doesn't use diffusion features.

## Checkpoint format

Each `best.pt` is a standard PyTorch checkpoint dict:

```python
{
    "epoch": int,
    "model_state": OrderedDict,  # load with model.load_state_dict()
    "opt_state": dict,           # optimizer.state_dict()
    "val_acc": float,            # validation accuracy of this checkpoint
}
```

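The training script that wrote these files is not shown here. The sketch below is a hypothetical helper that writes the same four keys and round-trips them through an in-memory buffer. The stand-in `nn.Linear` model and the AdamW optimizer are assumptions for illustration.

```python
import io

import torch
import torch.nn as nn


def save_checkpoint(f, epoch: int, model: nn.Module,
                    optimizer: torch.optim.Optimizer, val_acc: float) -> None:
    """Hypothetical helper that writes the four keys documented above."""
    torch.save(
        {
            "epoch": epoch,
            "model_state": model.state_dict(),
            "opt_state": optimizer.state_dict(),
            "val_acc": val_acc,
        },
        f,
    )


# Round-trip through an in-memory buffer with a stand-in model.
model = nn.Linear(8, 3)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
buf = io.BytesIO()
save_checkpoint(buf, epoch=5, model=model, optimizer=optimizer, val_acc=0.8123)
buf.seek(0)
# weights_only=True restricts unpickling to tensors and plain containers.
ckpt = torch.load(buf, map_location="cpu", weights_only=True)
print(sorted(ckpt))  # ['epoch', 'model_state', 'opt_state', 'val_acc']
```
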
## Loading example

```python
import torch
from src.models import build_model  # model factory from the project's src/ tree

# Load on CPU; move the model to a GPU afterwards if needed.
ckpt = torch.load("convlstm_style/best.pt", map_location="cpu")
model = build_model("convlstm", num_classes=27)  # 27 style classes
model.load_state_dict(ckpt["model_state"])
model.eval()  # disable dropout / batch-norm updates for inference
print(f"Loaded from epoch {ckpt['epoch']}, val_acc={ckpt['val_acc']:.4f}")
```

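The `opt_state` entry also makes it possible to resume training. The sketch below uses a hypothetical `resume` helper. `optimizer.load_state_dict` only works if the optimizer has the same type and parameter grouping as during training. This card does not say which optimizer was used, so the AdamW here is an assumption, and the stand-in `nn.Linear` replaces `build_model`.

```python
import io

import torch
import torch.nn as nn


def resume(f, model: nn.Module, optimizer: torch.optim.Optimizer) -> int:
    """Hypothetical helper: restore model + optimizer state, return the next epoch index."""
    ckpt = torch.load(f, map_location="cpu", weights_only=True)
    model.load_state_dict(ckpt["model_state"])
    optimizer.load_state_dict(ckpt["opt_state"])
    return ckpt["epoch"] + 1


# Build a small checkpoint in memory so the example is self-contained.
src = nn.Linear(8, 3)
src_opt = torch.optim.AdamW(src.parameters(), lr=1e-3)
src(torch.randn(4, 8)).sum().backward()
src_opt.step()  # populate optimizer state (moment estimates, step count)
buf = io.BytesIO()
torch.save({"epoch": 7, "model_state": src.state_dict(),
            "opt_state": src_opt.state_dict(), "val_acc": 0.5}, buf)
buf.seek(0)

# Fresh model/optimizer of the same type, then restore into them.
model = nn.Linear(8, 3)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
start_epoch = resume(buf, model, optimizer)
print(start_epoch)  # 8
```
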
## Training

SD features were extracted from the `valhalla/sd-wikiart-v2` U-Net at timestep t=200
by hooking `down_blocks.2`, `mid_block`, and `up_blocks.1`.

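The hooking step can be sketched with plain PyTorch forward hooks. This assumes a diffusers-style U-Net whose blocks are reachable by these dotted names through `nn.Module.get_submodule`. Several steps that would precede the U-Net call are not shown: VAE encoding, noising the latent to t=200, and text conditioning. The helper names are hypothetical, and `ToyUNet` only mimics the submodule names; it is not a diffusion model.

```python
import torch
import torch.nn as nn

HOOKED_BLOCKS = ("down_blocks.2", "mid_block", "up_blocks.1")


def capture_activations(unet: nn.Module, names=HOOKED_BLOCKS):
    """Register forward hooks that store each named block's output tensor.

    Returns (store, handles); call handle.remove() on each handle when done.
    """
    store = {}
    handles = []
    for name in names:
        def hook(module, inputs, output, name=name):
            # Some blocks (e.g. diffusers down blocks) return a tuple whose first
            # element is the hidden state; keep only that tensor.
            out = output[0] if isinstance(output, tuple) else output
            store[name] = out.detach()
        handles.append(unet.get_submodule(name).register_forward_hook(hook))
    return store, handles


class ToyUNet(nn.Module):
    """Stand-in with the same submodule names; NOT a real diffusion U-Net."""

    def __init__(self):
        super().__init__()
        self.down_blocks = nn.ModuleList([nn.Identity() for _ in range(3)])
        self.mid_block = nn.Identity()
        self.up_blocks = nn.ModuleList([nn.Identity() for _ in range(2)])

    def forward(self, x):
        for block in self.down_blocks:
            x = block(x)
        x = self.mid_block(x)
        for block in self.up_blocks:
            x = block(x)
        return x


unet = ToyUNet()
store, handles = capture_activations(unet)
with torch.no_grad():
    unet(torch.randn(1, 1280, 16, 16))
for h in handles:
    h.remove()
print(sorted(store))  # ['down_blocks.2', 'mid_block', 'up_blocks.1']
```

With a real U-Net, the captured tensors would then be pooled for the MLP probe or kept as spatial grids for the ConvLSTM.
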