#!/usr/bin/env python3
"""
Experiment A: Missing-modality robustness for scene recognition (T1).

Train a late-fusion Transformer on all 5 modalities with random per-sample
modality dropout. At test time, systematically evaluate modality subsets
(single modalities, leave-one-out, and full set) by zeroing out the slices
of the concatenated input tensor that correspond to the dropped modalities.

Reuses: data.dataset.get_dataloaders, nets.models.build_model, and the
pretrained-backbone-transfer helper from tasks/train_exp1.py.
"""

import os
import sys
import json
import time
import random
import argparse
import itertools

import numpy as np
import torch
import torch.nn as nn
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from data.dataset import get_dataloaders, NUM_CLASSES
from nets.models import build_model
from tasks.train_exp1 import (
    set_seed,
    apply_augmentation,
    _load_and_freeze_backbone,
)


def modality_slices(modality_dims):
    """Return {mod_name: (start, end)} index ranges into the concatenated feature dim.

    Ranges are half-open ``[start, end)`` and are accumulated in the iteration
    order of ``modality_dims``.

    NOTE(review): this assumes the iteration order of ``modality_dims`` matches
    the order in which the dataset concatenates the modalities -- confirm
    against ``get_dataloaders``.
    """
    slices = {}
    off = 0
    for name, dim in modality_dims.items():
        slices[name] = (off, off + dim)
        off += dim
    return slices


def mask_modalities(x, slices, active_mods):
    """Zero out the slices of x corresponding to modalities NOT in active_mods.

    x: (B, T, F_total)

    Never mutates ``x`` in place. When no modality has to be dropped the input
    tensor itself is returned (no copy); otherwise a masked clone is returned.
    """
    active = set(active_mods)
    to_drop = [name for name in slices if name not in active]
    if not to_drop:
        return x
    masked = x.clone()
    for name in to_drop:
        start, end = slices[name]
        masked[..., start:end] = 0.0
    return masked


def _sample_dropped_modalities(mods, mod_dropout_p):
    """Draw the modalities to drop for one sample, always leaving one alive."""
    dropped = [m for m in mods if random.random() < mod_dropout_p]
    # ensure at least one modality survives
    if len(dropped) == len(mods):
        dropped = random.sample(dropped, len(dropped) - 1)
    return dropped


def train_one_epoch_with_dropout(model, loader, criterion, optimizer, device,
                                 slices, mod_dropout_p=0.0,
                                 augment=False, noise_std=0.1,
                                 time_mask_ratio=0.1):
    """Train one epoch with independent per-sample modality dropout.

    Strategy: for each sample, flip an independent Bernoulli(mod_dropout_p)
    per modality and zero the feature slices of the modalities that come up
    "dropped"; if ALL modalities would be dropped, keep one at random.
    Dropout is applied after the optional augmentation.

    Args:
        model: called as ``model(x, mask)`` and expected to return logits.
        loader: yields ``(x, y, mask, _)`` batches; ``x`` is indexed as
            ``x[i, :, start:end]``, i.e. (B, T, F_total).
        criterion: loss called as ``criterion(logits, y)``.
        optimizer: optimizer stepped once per batch.
        device: device the batch tensors are moved to.
        slices: ``{mod_name: (start, end)}`` as built by ``modality_slices``.
        mod_dropout_p: per-modality drop probability; ``<= 0`` disables dropout.
        augment: if true, run ``apply_augmentation`` on the batch first.
        noise_std, time_mask_ratio: forwarded to ``apply_augmentation``.

    Returns:
        ``(mean_loss, accuracy)`` over all samples seen, or ``(0.0, 0.0)``
        when the loader yields no samples.
    """
    model.train()
    mods = list(slices.keys())
    total_loss = 0.0
    all_preds, all_labels = [], []

    for x, y, mask, _ in loader:
        x, y, mask = x.to(device), y.to(device), mask.to(device)
        if augment:
            x = apply_augmentation(x, mask, noise_std, time_mask_ratio)

        if mod_dropout_p > 0:
            # Work on a copy: on CPU ``.to(device)`` hands back the loader's
            # own tensor, and zeroing it in place would modify the caller's
            # data (mask_modalities avoids in-place edits for the same reason).
            x = x.clone()
            for i in range(x.size(0)):
                for m in _sample_dropped_modalities(mods, mod_dropout_p):
                    s, e = slices[m]
                    x[i, :, s:e] = 0.0

        optimizer.zero_grad()
        logits = model(x, mask)
        loss = criterion(logits, y)
        loss.backward()
        # Clip only the parameters that are actually being trained.
        torch.nn.utils.clip_grad_norm_(
            [p for p in model.parameters() if p.requires_grad], 1.0
        )
        optimizer.step()

        # Weight the batch loss by batch size so the epoch mean is per-sample.
        total_loss += loss.item() * y.size(0)
        all_preds.extend(logits.argmax(dim=1).cpu().numpy())
        all_labels.extend(y.cpu().numpy())

    n = len(all_labels)
    if n == 0:
        # Empty loader: mirror evaluate_with_mask instead of dividing by zero.
        return 0.0, 0.0
    return total_loss / n, accuracy_score(all_labels, all_preds)


@torch.no_grad()
def evaluate_with_mask(model, loader, criterion, device, slices, active_mods):
    """Evaluate the model with only ``active_mods`` left un-zeroed.

    Every batch is passed through ``mask_modalities`` before the forward pass.

    Returns:
        ``(mean_loss, accuracy, macro_f1, confusion_matrix)``. The confusion
        matrix is ``NUM_CLASSES x NUM_CLASSES``. An empty loader yields zeros
        for all four values.
    """
    model.eval()
    active = set(active_mods)
    total_loss = 0.0
    all_preds, all_labels = [], []

    for x, y, mask, _ in loader:
        x, y, mask = x.to(device), y.to(device), mask.to(device)
        x = mask_modalities(x, slices, active)
        logits = model(x, mask)
        loss = criterion(logits, y)
        total_loss += loss.item() * y.size(0)
        all_preds.extend(logits.argmax(dim=1).cpu().numpy())
        all_labels.extend(y.cpu().numpy())

    n = len(all_labels)
    if n == 0:
        return 0.0, 0.0, 0.0, np.zeros((NUM_CLASSES, NUM_CLASSES), dtype=int)

    acc = accuracy_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds, average='macro', zero_division=0)
    cm = confusion_matrix(all_labels, all_preds, labels=list(range(NUM_CLASSES)))
    return total_loss / n, acc, f1, cm


def run_experiment(args):
    """Train with modality dropout, then evaluate robustness to missing modalities.

    Steps: build loaders and model, optionally load per-modality pretrained
    backbones, train with early stopping on the full-modality validation loss,
    restore the best checkpoint, evaluate the 'full', 'drop_<m>' and
    'only_<m>' configurations on the test set, and write ``results.json``
    into ``<output_dir>/<exp_name>/``.

    Args:
        args: parsed command-line namespace (see ``main``).

    Returns:
        The results dict that was written to ``results.json``.

    Raises:
        ValueError: if ``args.modalities`` contains no modality name.
    """
    set_seed(args.seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")

    # Tolerate "a, b" style lists and stray commas.
    modalities = [m.strip() for m in args.modalities.split(',') if m.strip()]
    if not modalities:
        raise ValueError("--modalities must name at least one modality")
    print(f"Model: {args.model} | Fusion: {args.fusion} | Modalities: {modalities}")
    print(f"Training dropout p={args.mod_dropout_p}")

    train_loader, val_loader, test_loader, info = get_dataloaders(
        modalities, batch_size=args.batch_size, downsample=args.downsample
    )
    if info['val_size'] == 0:
        # Without a validation split, LR scheduling, early stopping and
        # checkpoint selection are all driven by the test set, so the
        # reported test metrics are optimistically biased.
        print(" WARN: no validation split; using the test set for model selection")
        val_loader = test_loader

    print(f"Train: {info['train_size']}, Test: {info['test_size']}")
    print(f"Feature dim: {info['feat_dim']}, Modality dims: {info['modality_dims']}")

    slices = modality_slices(info['modality_dims'])
    print(f"Modality slices: {slices}")

    model = build_model(
        args.model, args.fusion, info['feat_dim'],
        info['modality_dims'], info['num_classes'],
        hidden_dim=args.hidden_dim, proj_dim=args.proj_dim,
        late_agg=args.late_agg,
    ).to(device)

    # Optional pretrained backbone loading (per-modality)
    if args.pretrained_dir:
        for i, mod in enumerate(modalities):
            pt_path = os.path.join(args.pretrained_dir,
                                   f"transformer_{mod}_early", "model_best.pt")
            if os.path.exists(pt_path):
                _load_and_freeze_backbone(model, pt_path, i, args.fusion)
            else:
                print(f" WARN: no pretrained ckpt for {mod} at {pt_path}")

    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Params: {trainable:,}/{total:,}")

    class_weights = info['class_weights'].to(device)
    criterion = nn.CrossEntropyLoss(weight=class_weights,
                                    label_smoothing=args.label_smoothing)
    # Only parameters that still require gradients are handed to the optimizer.
    optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=args.lr, weight_decay=args.weight_decay,
    )
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=7, min_lr=1e-6,
    )

    mod_str = '-'.join(modalities)
    exp_name = f"{args.model}_{mod_str}_{args.fusion}_drop{args.mod_dropout_p}_seed{args.seed}"
    if args.tag:
        exp_name += f"_{args.tag}"
    out_dir = os.path.join(args.output_dir, exp_name)
    os.makedirs(out_dir, exist_ok=True)
    best_path = os.path.join(out_dir, 'model_best.pt')

    best_val_loss = float('inf')
    best_epoch = 0
    patience_counter = 0

    for epoch in range(1, args.epochs + 1):
        t0 = time.time()
        train_loss, train_acc = train_one_epoch_with_dropout(
            model, train_loader, criterion, optimizer, device,
            slices=slices, mod_dropout_p=args.mod_dropout_p,
            augment=args.augment,
        )
        # Validate on FULL modalities (baseline performance)
        val_loss, val_acc, val_f1, _ = evaluate_with_mask(
            model, val_loader, criterion, device, slices, modalities,
        )
        scheduler.step(val_loss)

        print(f" E{epoch:3d} | tr_loss {train_loss:.4f} tr_acc {train_acc:.4f} | "
              f"va_loss {val_loss:.4f} va_acc {val_acc:.4f} va_f1 {val_f1:.4f} | "
              f"{time.time()-t0:.1f}s")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_epoch = epoch
            patience_counter = 0
            torch.save(model.state_dict(), best_path)
        else:
            patience_counter += 1
            if patience_counter >= args.patience:
                print(f" Early stop at epoch {epoch} (best {best_epoch})")
                break

    # Restore best model. Only load a checkpoint written by THIS run: with
    # zero epochs or a never-improving (e.g. NaN) validation loss nothing is
    # saved, and a model_best.pt left over from an earlier run in the same
    # directory must not be picked up by accident.
    if best_epoch > 0:
        model.load_state_dict(
            torch.load(best_path, map_location=device, weights_only=True)
        )
    else:
        print(" WARN: no checkpoint was saved in this run; evaluating current weights")

    # Systematic evaluation: full, leave-one-out, and all singletons
    print("\n=== Robustness Evaluation ===")
    eval_configs = [('full', modalities)]
    eval_configs.extend(
        (f'drop_{m}', [other for other in modalities if other != m])
        for m in modalities
    )
    eval_configs.extend((f'only_{m}', [m]) for m in modalities)

    results_matrix = {}
    for name, active in eval_configs:
        _, acc, f1, _ = evaluate_with_mask(
            model, test_loader, criterion, device, slices, active,
        )
        results_matrix[name] = {'active': active, 'acc': float(acc), 'f1': float(f1)}
        print(f" {name:<15s} mods={active} | acc {acc:.4f} f1 {f1:.4f}")

    results = {
        'experiment': exp_name,
        'training_dropout_p': args.mod_dropout_p,
        'seed': args.seed,
        'best_epoch': best_epoch,
        'eval_configs': results_matrix,
        'train_size': info['train_size'],
        'test_size': info['test_size'],
        'modality_dims': info['modality_dims'],
        'args': vars(args),
    }
    with open(os.path.join(out_dir, 'results.json'), 'w') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    print(f"Saved: {out_dir}/results.json")
    return results


def main():
    """Parse command-line arguments and run the experiment."""
    p = argparse.ArgumentParser()
    p.add_argument('--model', type=str, default='transformer')
    p.add_argument('--modalities', type=str,
                   default='mocap,emg,eyetrack,imu,pressure')
    p.add_argument('--fusion', type=str, default='late')
    p.add_argument('--late_agg', type=str, default='mean')
    p.add_argument('--mod_dropout_p', type=float, default=0.3,
                   help='Per-modality independent dropout prob at training time')
    p.add_argument('--pretrained_dir', type=str, default='',
                   help='Directory with pretrained single-modality ckpts')
    p.add_argument('--epochs', type=int, default=100)
    p.add_argument('--batch_size', type=int, default=16)
    p.add_argument('--lr', type=float, default=1e-3)
    p.add_argument('--weight_decay', type=float, default=1e-4)
    p.add_argument('--hidden_dim', type=int, default=128)
    p.add_argument('--proj_dim', type=int, default=0)
    p.add_argument('--downsample', type=int, default=5)
    p.add_argument('--patience', type=int, default=15)
    p.add_argument('--label_smoothing', type=float, default=0.1)
    p.add_argument('--augment', action='store_true')
    p.add_argument('--seed', type=int, default=42)
    p.add_argument('--output_dir', type=str, required=True)
    p.add_argument('--tag', type=str, default='')
    args = p.parse_args()
    run_experiment(args)


if __name__ == '__main__':
    main()