WikiArtClassifySD / README.md
metadata
license: mit
tags:
  - art-classification
  - stable-diffusion
  - wikiart

ArtExtract Task 1 — Checkpoints

Trained classifiers for painting style, artist, and genre classification using Stable Diffusion U-Net activations as features (WikiArt dataset).

Models

Three model types, each trained for three tasks (style / artist / genre):

ConvLSTM: spatial Conv-LSTM over the 16×16 activation grid from SD down_blocks.2. It treats each spatial position as a token in a sequence, and attention pooling weights the regions that matter most for the prediction.

MLP probe: simple MLP on global-average-pooled activations (3840-d). A fast probe used to verify the quality of the SD features.

ResNet50: ImageNet-pretrained ResNet50 fine-tuned on raw paintings. A baseline that does not use diffusion features.
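The attention pooling mentioned for the ConvLSTM model can be illustrated with a minimal sketch. This is not the repository's implementation: the class name `AttentionPool`, the single linear scoring layer, and the 1280-channel input are assumptions chosen for illustration. It only shows the general idea of scoring each position of a 16×16 grid and taking a softmax-weighted average.

```python
import torch
import torch.nn as nn


class AttentionPool(nn.Module):
    """Softmax-weighted average over spatial positions (illustrative only)."""

    def __init__(self, channels: int):
        super().__init__()
        # One scalar score per position, computed from that position's features.
        self.score = nn.Linear(channels, 1)

    def forward(self, feats: torch.Tensor):
        # feats: (batch, channels, height, width)
        b, c, h, w = feats.shape
        tokens = feats.flatten(2).transpose(1, 2)            # (b, h*w, c)
        weights = torch.softmax(self.score(tokens), dim=1)   # (b, h*w, 1), sums to 1 over positions
        pooled = (weights * tokens).sum(dim=1)               # (b, c)
        return pooled, weights.reshape(b, h, w)


# Hypothetical sizes: the 1280-channel width is an assumption, not taken from the repo.
feats = torch.randn(2, 1280, 16, 16)
pool = AttentionPool(1280)
pooled, attn = pool(feats)
print(pooled.shape, attn.shape)
```

The returned `attn` map can be upsampled and overlaid on the painting to inspect which regions received the most weight.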

Checkpoint format

Each best.pt is a standard PyTorch checkpoint dict:

{
    "epoch":       int,
    "model_state": OrderedDict,   # load with model.load_state_dict()
    "opt_state":   OrderedDict,
    "val_acc":     float,
}
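
Before loading weights, it can help to confirm that a file actually follows this layout. The helper below is a hypothetical sketch (the name `check_checkpoint` is not part of the repository). It checks only the four keys and value types listed above, and accepts any mapping for the two state dicts.

```python
from collections.abc import Mapping

# Keys and value types as listed in the checkpoint description above.
EXPECTED_FIELDS = {
    "epoch": int,
    "model_state": Mapping,
    "opt_state": Mapping,
    "val_acc": float,
}


def check_checkpoint(ckpt) -> list:
    """Return a list of problems; an empty list means the layout matches."""
    if not isinstance(ckpt, Mapping):
        return [f"expected a dict, got {type(ckpt).__name__}"]
    problems = []
    for key, expected in EXPECTED_FIELDS.items():
        if key not in ckpt:
            problems.append(f"missing key: {key}")
        elif not isinstance(ckpt[key], expected):
            problems.append(
                f"{key}: expected {expected.__name__}, got {type(ckpt[key]).__name__}"
            )
    return problems
```

A typical use would be `problems = check_checkpoint(torch.load(path, map_location="cpu"))`, followed by raising an error if the list is non-empty.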

Loading example

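import torch
from src.models import build_model  # project-local import

# Load on CPU so the checkpoint opens without a GPU.
ckpt = torch.load("convlstm_style/best.pt", map_location="cpu")
model = build_model("convlstm", num_classes=27)  # style checkpoint: 27 classes
model.load_state_dict(ckpt["model_state"])
model.eval()  # switch dropout and batch norm to inference behavior
print(f"Loaded from epoch {ckpt['epoch']}, val_acc={ckpt['val_acc']:.4f}")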

Training

SD features were extracted from valhalla/sd-wikiart-v2 at timestep t=200 by hooking the U-Net's down_blocks.2, mid_block, and up_blocks.1.
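
The hooking step can be sketched with PyTorch forward hooks. To stay runnable without downloading weights, the example uses a tiny stand-in module (`TinyUNet`, hypothetical) that only mimics the submodule names. The real U-Net additionally needs latents noised to t=200 and text conditioning, both omitted here. The helper name `extract_pooled_features` and the global-average pooling are assumptions, not the repository's code. If each hooked block emits 1280 channels, as in SD 1.x U-Nets, concatenating the three pooled outputs would be consistent with the 3840-d features mentioned above.

```python
import torch
import torch.nn as nn


class TinyUNet(nn.Module):
    """Stand-in with the same submodule names as a diffusers U-Net; NOT the real model."""

    def __init__(self, ch: int = 8):
        super().__init__()
        self.down_blocks = nn.ModuleList([nn.Conv2d(ch, ch, 3, padding=1) for _ in range(3)])
        self.mid_block = nn.Conv2d(ch, ch, 3, padding=1)
        self.up_blocks = nn.ModuleList([nn.Conv2d(ch, ch, 3, padding=1) for _ in range(2)])

    def forward(self, x):
        for block in self.down_blocks:
            x = block(x)
        x = self.mid_block(x)
        for block in self.up_blocks:
            x = block(x)
        return x


def extract_pooled_features(unet: nn.Module, x: torch.Tensor, layer_names):
    """Run one forward pass and return concatenated global-average-pooled activations."""
    modules = dict(unet.named_modules())
    captured = {}

    def make_hook(name):
        def hook(_module, _inputs, output):
            # diffusers down blocks return a tuple; keep the hidden states.
            if isinstance(output, tuple):
                output = output[0]
            captured[name] = output.mean(dim=(2, 3))  # (batch, channels)
        return hook

    handles = [modules[n].register_forward_hook(make_hook(n)) for n in layer_names]
    try:
        with torch.no_grad():
            unet(x)
    finally:
        for handle in handles:
            handle.remove()  # always detach hooks so later passes are unaffected
    return torch.cat([captured[n] for n in layer_names], dim=1)


layers = ["down_blocks.2", "mid_block", "up_blocks.1"]
feats = extract_pooled_features(TinyUNet(ch=8), torch.randn(4, 8, 16, 16), layers)
print(feats.shape)  # torch.Size([4, 24])
```

For the ConvLSTM input, the hook would store the unpooled activation map instead of its spatial mean.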