---
license: mit
tags:
- art-classification
- stable-diffusion
- wikiart
---

# ArtExtract Task 1: Checkpoints

Trained classifiers for painting style, artist, and genre classification using Stable Diffusion U-Net activations as features (WikiArt dataset).

## Models

Three model types, each trained for three tasks (style / artist / genre):

**ConvLSTM**: spatial Conv-LSTM over the 16×16 activation grid from SD `down_blocks.2`. Treats each spatial position as a token in a sequence; attention pooling selects which regions matter for the prediction.

**MLP probe**: simple MLP on global-average-pooled activations (3840-d). A fast probe used to verify the quality of the SD features.

**ResNet50**: ImageNet-pretrained ResNet50 fine-tuned on raw paintings. A baseline that doesn't use diffusion features.

## Checkpoint format

Each `best.pt` is a standard PyTorch checkpoint dict:

```python
{
    "epoch": int,
    "model_state": OrderedDict,  # load with model.load_state_dict()
    "opt_state": OrderedDict,
    "val_acc": float,
}
```

## Loading example

```python
import torch

from src.models import build_model

# Load the checkpoint on CPU so no GPU is required.
ckpt = torch.load("convlstm_style/best.pt", map_location="cpu")

# Rebuild the architecture, then restore the trained weights.
model = build_model("convlstm", num_classes=27)
model.load_state_dict(ckpt["model_state"])
model.eval()  # inference mode

print(f"Loaded from epoch {ckpt['epoch']}, val_acc={ckpt['val_acc']:.4f}")
```

## Training

SD features are extracted from `valhalla/sd-wikiart-v2` at timestep t=200 by hooking `down_blocks.2`, `mid_block`, and `up_blocks.1`.
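The hooking step described under Training can be sketched with PyTorch forward hooks. This is a minimal illustration, not the extraction code used for these checkpoints: `TinyUNetStandIn` and `extract_pooled_features` are hypothetical names, the stand-in only mimics the three hooked module paths, and 1280 channels per block is an assumption chosen because it is consistent with the 3840-d pooled feature (3 × 1280). Noising the latents to t=200 and passing the timestep and text conditioning to the real U-Net are not shown.

```python
import torch
import torch.nn as nn


class TinyUNetStandIn(nn.Module):
    """Hypothetical stand-in for the SD U-Net; only the hooked module paths match."""

    def __init__(self, channels: int = 1280):  # 1280 per block is an assumption
        super().__init__()
        self.down_blocks = nn.ModuleList(
            [nn.Conv2d(4, channels, kernel_size=1)]
            + [nn.Conv2d(channels, channels, kernel_size=1) for _ in range(2)]
        )
        self.mid_block = nn.Conv2d(channels, channels, kernel_size=1)
        self.up_blocks = nn.ModuleList(
            [nn.Conv2d(channels, channels, kernel_size=1) for _ in range(2)]
        )

    def forward(self, x):
        for block in self.down_blocks:
            x = block(x)
        x = self.mid_block(x)
        for block in self.up_blocks:
            x = block(x)
        return x


def extract_pooled_features(model, layer_names, *inputs):
    """Run one forward pass and return concatenated global-average-pooled activations."""
    captured = {}
    handles = []
    for name in layer_names:
        def hook(_module, _inputs, output, name=name):
            # Some blocks return a tuple; keep only the first element (hidden states).
            out = output[0] if isinstance(output, tuple) else output
            captured[name] = out.detach()
        handles.append(model.get_submodule(name).register_forward_hook(hook))
    try:
        with torch.no_grad():
            model(*inputs)
    finally:
        for handle in handles:
            handle.remove()  # always detach hooks, even if the forward pass fails
    # Average over the spatial dimensions, then concatenate along channels.
    pooled = [captured[name].mean(dim=(2, 3)) for name in layer_names]
    return torch.cat(pooled, dim=1)


torch.manual_seed(0)
unet = TinyUNetStandIn()
latents = torch.randn(2, 4, 16, 16)  # dummy latents, batch of 2
features = extract_pooled_features(
    unet, ["down_blocks.2", "mid_block", "up_blocks.1"], latents
)
print(features.shape)  # torch.Size([2, 3840])
```

Removing the hooks in a `finally` block keeps the model clean for later forward passes, which matters when the same U-Net is reused across many batches.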
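The attention pooling mentioned for the ConvLSTM model under Models can be illustrated in a similar way. The sketch below is one common formulation (a learned scalar score per spatial position, softmax-normalized, then a weighted sum) and is an assumption about the general idea, not the layer stored in these checkpoints. `AttentionPool` is a hypothetical name, and the 64-channel input is only for illustration.

```python
import torch
import torch.nn as nn


class AttentionPool(nn.Module):
    """Hypothetical attention pooling over a sequence of spatial tokens."""

    def __init__(self, dim: int):
        super().__init__()
        self.score = nn.Linear(dim, 1)  # one scalar relevance score per token

    def forward(self, tokens):
        # tokens: (batch, num_positions, dim)
        weights = torch.softmax(self.score(tokens).squeeze(-1), dim=1)
        # Weighted sum over positions gives one vector per image.
        pooled = torch.einsum("bn,bnd->bd", weights, tokens)
        return pooled, weights


torch.manual_seed(0)
grid = torch.randn(2, 64, 16, 16)         # (batch, channels, H, W); 64 channels is illustrative
tokens = grid.flatten(2).transpose(1, 2)  # (2, 256, 64): one token per grid position
attn_pool = AttentionPool(dim=64)
pooled, weights = attn_pool(tokens)
print(pooled.shape, weights.shape)  # torch.Size([2, 64]) torch.Size([2, 256])
```

Reshaping `weights` back to 16×16 gives a map of which regions contributed most to the pooled vector.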