| |
| """ |
| Experiment A: Missing-modality robustness for scene recognition (T1). |
| |
| Train a late-fusion Transformer on all 5 modalities with random per-sample |
| modality dropout. At test time, systematically evaluate a fixed family of |
| modality subsets (the full set, every leave-one-out set, and every single |
| modality) by zeroing out the slices of the concatenated input tensor that |
| correspond to the dropped modalities. |
| |
| Reuses: data.dataset.get_dataloaders, nets.models.build_model, and the |
| set_seed / apply_augmentation / pretrained-backbone-transfer helpers from |
| tasks/train_exp1.py. |
| """ |
|
|
| import os |
| import sys |
| import json |
| import time |
| import random |
| import argparse |
| import itertools |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| from sklearn.metrics import accuracy_score, f1_score, confusion_matrix |
|
|
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
| from data.dataset import get_dataloaders, NUM_CLASSES |
| from nets.models import build_model |
| from tasks.train_exp1 import ( |
| set_seed, apply_augmentation, _load_and_freeze_backbone, |
| ) |
|
|
|
|
def modality_slices(modality_dims):
    """Map each modality name to its (start, end) span on the feature axis.

    Modalities are laid out back-to-back in the iteration order of
    ``modality_dims``. The offsets are element indices (not bytes) and
    ``end`` is exclusive, so ``x[..., start:end]`` selects one modality.
    """
    spans = {}
    start = 0
    for mod_name, width in modality_dims.items():
        stop = start + width
        spans[mod_name] = (start, stop)
        # The next modality begins where this one ends.
        start = stop
    return spans
|
|
|
|
def mask_modalities(x, slices, active_mods):
    """Return x with every modality outside active_mods zeroed on the last axis.

    x: (B, T, F_total)
    slices: {mod_name: (start, end)} spans into the last axis of x.

    When active_mods names exactly the modalities in slices, x itself is
    returned unchanged. Otherwise the zeroing is applied to a clone, so the
    caller's tensor is never modified.
    """
    if set(active_mods) != set(slices.keys()):
        masked = x.clone()
        for mod_name, span in slices.items():
            if mod_name in active_mods:
                continue
            lo, hi = span
            masked[..., lo:hi] = 0.0
        return masked
    # Nothing to mask: hand back the input as-is (no copy).
    return x
|
|
|
|
def _sample_dropped_modalities(mods, mod_dropout_p):
    """Draw the modalities to drop for a single training sample.

    Each modality is dropped independently with probability mod_dropout_p.
    If that would drop every modality, one randomly chosen modality is kept
    so the sample never becomes all-zero. An empty ``mods`` yields ``[]``.
    """
    dropped = [m for m in mods if random.random() < mod_dropout_p]
    if dropped and len(dropped) == len(mods):
        dropped = random.sample(dropped, len(dropped) - 1)
    return dropped


def train_one_epoch_with_dropout(model, loader, criterion, optimizer, device,
                                 slices, mod_dropout_p=0.0,
                                 augment=False, noise_std=0.1, time_mask_ratio=0.1):
    """Train for one epoch with per-sample random modality dropout.

    For every sample, each modality is dropped independently with
    probability ``mod_dropout_p`` (its feature slice is zeroed); if all
    modalities would be dropped, one is kept at random. Draws come from the
    Python ``random`` module, one per (sample, modality) in batch order.

    Args:
        model: called as ``model(x, mask)``; must return class logits.
        loader: iterable of ``(x, y, mask, _)`` batches; ``x`` is indexed as
            ``x[i, :, start:end]``, i.e. (batch, time, features).
        criterion: loss called as ``criterion(logits, y)``.
        optimizer: optimizer stepping the model's trainable parameters.
        device: device the batch tensors are moved to.
        slices: ``{mod_name: (start, end)}`` spans on the feature axis.
        mod_dropout_p: per-modality drop probability; ``0`` disables dropout.
        augment: if true, run ``apply_augmentation`` before modality dropout.
        noise_std, time_mask_ratio: forwarded to ``apply_augmentation``.

    Returns:
        ``(mean_loss, accuracy)`` over all samples seen, or ``(0.0, 0.0)``
        when the loader yields no samples.
    """
    model.train()
    mods = list(slices.keys())
    total_loss = 0.0
    all_preds, all_labels = [], []

    for x, y, mask, _ in loader:
        x, y, mask = x.to(device), y.to(device), mask.to(device)
        if augment:
            x = apply_augmentation(x, mask, noise_std, time_mask_ratio)

        if mod_dropout_p > 0:
            per_sample_drops = [
                _sample_dropped_modalities(mods, mod_dropout_p)
                for _ in range(x.size(0))
            ]
            if any(per_sample_drops):
                # Tensor.to() can return the very tensor the loader yielded
                # (already on the target device), so zero a private copy
                # instead of writing into the loader's data.
                x = x.clone()
                for i, dropped in enumerate(per_sample_drops):
                    for m in dropped:
                        s, e = slices[m]
                        x[i, :, s:e] = 0.0

        optimizer.zero_grad()
        logits = model(x, mask)
        loss = criterion(logits, y)
        loss.backward()
        # Clip only parameters that are actually being trained (frozen
        # backbones have requires_grad=False).
        torch.nn.utils.clip_grad_norm_(
            [p for p in model.parameters() if p.requires_grad], 1.0
        )
        optimizer.step()

        # criterion returns a batch mean; weight by batch size so the epoch
        # average is per-sample.
        total_loss += loss.item() * y.size(0)
        all_preds.extend(logits.argmax(dim=1).cpu().numpy())
        all_labels.extend(y.cpu().numpy())

    n = len(all_labels)
    if n == 0:
        # Mirror evaluate_with_mask: an empty loader yields zeros rather
        # than a ZeroDivisionError.
        return 0.0, 0.0
    return total_loss / n, accuracy_score(all_labels, all_preds)
|
|
|
|
@torch.no_grad()
def evaluate_with_mask(model, loader, criterion, device, slices, active_mods):
    """Evaluate model on loader with only active_mods left visible.

    Each batch goes through mask_modalities, which zeroes the feature slices
    of every modality not listed in active_mods, before the forward pass.

    Returns (mean_loss, accuracy, macro_f1, confusion_matrix). The loss is
    averaged per sample and the confusion matrix covers labels
    0..NUM_CLASSES-1. An empty loader yields zeros for all four values.
    """
    model.eval()
    loss_sum = 0.0
    preds, targets = [], []

    for x, y, mask, _ in loader:
        x = x.to(device)
        y = y.to(device)
        mask = mask.to(device)
        logits = model(mask_modalities(x, slices, set(active_mods)), mask)
        batch_loss = criterion(logits, y)
        # Weight the batch-mean loss by batch size for a per-sample average.
        loss_sum += batch_loss.item() * y.size(0)
        preds.extend(logits.argmax(dim=1).cpu().numpy())
        targets.extend(y.cpu().numpy())

    count = len(targets)
    if count == 0:
        return 0.0, 0.0, 0.0, np.zeros((NUM_CLASSES, NUM_CLASSES), dtype=int)

    accuracy = accuracy_score(targets, preds)
    macro_f1 = f1_score(targets, preds, average='macro', zero_division=0)
    # Fix the label set so the matrix shape does not depend on which classes
    # happen to appear in this split.
    conf_mat = confusion_matrix(targets, preds, labels=list(range(NUM_CLASSES)))
    return loss_sum / count, accuracy, macro_f1, conf_mat
|
|
|
|
def _build_eval_configs(modalities):
    """Return the (name, active_modalities) pairs for the robustness sweep.

    Order: the full set, then one leave-one-out config per modality, then
    one single-modality config per modality.
    """
    configs = [('full', modalities)]
    configs.extend(
        (f'drop_{m}', [other for other in modalities if other != m])
        for m in modalities
    )
    configs.extend((f'only_{m}', [m]) for m in modalities)
    return configs


def run_experiment(args):
    """Train with modality dropout, then evaluate robustness to missing modalities.

    Steps: seed, build loaders and model, optionally load and freeze
    pretrained per-modality backbones, train with early stopping on the
    validation loss (checkpointing the best epoch), reload that checkpoint,
    and evaluate on the test set for the full / leave-one-out /
    single-modality configurations.

    Args:
        args: parsed CLI namespace. Fields read here: seed, modalities
            (comma-separated), model, fusion, late_agg, mod_dropout_p,
            pretrained_dir, epochs, batch_size, lr, weight_decay,
            hidden_dim, proj_dim, downsample, patience, label_smoothing,
            augment, output_dir, tag.

    Returns:
        The results dict, which is also written to
        ``<output_dir>/<exp_name>/results.json``.
    """
    set_seed(args.seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")

    modalities = args.modalities.split(',')
    print(f"Model: {args.model} | Fusion: {args.fusion} | Modalities: {modalities}")
    print(f"Training dropout p={args.mod_dropout_p}")

    train_loader, val_loader, test_loader, info = get_dataloaders(
        modalities, batch_size=args.batch_size, downsample=args.downsample
    )
    if info['val_size'] == 0:
        # Without a validation split, the test set drives LR scheduling,
        # early stopping and checkpoint selection; say so loudly because
        # the reported test numbers are then optimistically biased.
        print(" WARN: no validation split; using the test set for model selection")
        val_loader = test_loader
    print(f"Train: {info['train_size']}, Test: {info['test_size']}")
    print(f"Feature dim: {info['feat_dim']}, Modality dims: {info['modality_dims']}")

    slices = modality_slices(info['modality_dims'])
    print(f"Modality slices: {slices}")

    model = build_model(
        args.model, args.fusion, info['feat_dim'],
        info['modality_dims'], info['num_classes'],
        hidden_dim=args.hidden_dim, proj_dim=args.proj_dim,
        late_agg=args.late_agg,
    ).to(device)

    # Optionally initialise each modality branch from a single-modality
    # checkpoint; the helper receives the modality's position in the list.
    if args.pretrained_dir:
        for i, mod in enumerate(modalities):
            pt_path = os.path.join(args.pretrained_dir,
                                   f"transformer_{mod}_early", "model_best.pt")
            if os.path.exists(pt_path):
                _load_and_freeze_backbone(model, pt_path, i, args.fusion)
            else:
                print(f" WARN: no pretrained ckpt for {mod} at {pt_path}")

    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Params: {trainable:,}/{total:,}")

    class_weights = info['class_weights'].to(device)
    criterion = nn.CrossEntropyLoss(weight=class_weights,
                                    label_smoothing=args.label_smoothing)

    # Only hand trainable parameters to the optimizer (backbones may be frozen).
    optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=args.lr, weight_decay=args.weight_decay,
    )
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=7, min_lr=1e-6,
    )

    mod_str = '-'.join(modalities)
    exp_name = f"{args.model}_{mod_str}_{args.fusion}_drop{args.mod_dropout_p}_seed{args.seed}"
    if args.tag:
        exp_name += f"_{args.tag}"
    out_dir = os.path.join(args.output_dir, exp_name)
    os.makedirs(out_dir, exist_ok=True)
    ckpt_path = os.path.join(out_dir, 'model_best.pt')

    best_val_loss = float('inf')
    best_epoch = 0
    patience_counter = 0

    for epoch in range(1, args.epochs + 1):
        t0 = time.time()
        train_loss, train_acc = train_one_epoch_with_dropout(
            model, train_loader, criterion, optimizer, device,
            slices=slices, mod_dropout_p=args.mod_dropout_p,
            augment=args.augment,
        )
        # Validation always sees all modalities.
        val_loss, val_acc, val_f1, _ = evaluate_with_mask(
            model, val_loader, criterion, device, slices, modalities,
        )
        scheduler.step(val_loss)
        print(f" E{epoch:3d} | tr_loss {train_loss:.4f} tr_acc {train_acc:.4f} | "
              f"va_loss {val_loss:.4f} va_acc {val_acc:.4f} va_f1 {val_f1:.4f} | "
              f"{time.time()-t0:.1f}s")
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_epoch = epoch
            patience_counter = 0
            torch.save(model.state_dict(), ckpt_path)
        else:
            patience_counter += 1
            if patience_counter >= args.patience:
                print(f" Early stop at epoch {epoch} (best {best_epoch})")
                break

    # Reload the best checkpoint only if THIS run wrote one. best_epoch stays
    # 0 when no epoch improved on inf (e.g. NaN validation loss, or
    # epochs == 0); loading then would either crash on a missing file or
    # silently pick up a stale checkpoint left by an earlier run.
    if best_epoch > 0:
        model.load_state_dict(torch.load(ckpt_path, weights_only=True))
    else:
        print(" WARN: no checkpoint was saved during training; "
              "evaluating the current model weights")

    print("\n=== Robustness Evaluation ===")
    results_matrix = {}
    for name, active in _build_eval_configs(modalities):
        _, acc, f1, _ = evaluate_with_mask(
            model, test_loader, criterion, device, slices, active,
        )
        results_matrix[name] = {'active': active, 'acc': float(acc), 'f1': float(f1)}
        print(f" {name:<15s} mods={active} | acc {acc:.4f} f1 {f1:.4f}")

    results = {
        'experiment': exp_name,
        'training_dropout_p': args.mod_dropout_p,
        'seed': args.seed,
        'best_epoch': best_epoch,
        'eval_configs': results_matrix,
        'train_size': info['train_size'],
        'test_size': info['test_size'],
        'modality_dims': info['modality_dims'],
        'args': vars(args),
    }
    with open(os.path.join(out_dir, 'results.json'), 'w') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    print(f"Saved: {out_dir}/results.json")
    return results
|
|
|
|
| def main(): |
| p = argparse.ArgumentParser() |
| p.add_argument('--model', type=str, default='transformer') |
| p.add_argument('--modalities', type=str, default='mocap,emg,eyetrack,imu,pressure') |
| p.add_argument('--fusion', type=str, default='late') |
| p.add_argument('--late_agg', type=str, default='mean') |
| p.add_argument('--mod_dropout_p', type=float, default=0.3, |
| help='Per-modality independent dropout prob at training time') |
| p.add_argument('--pretrained_dir', type=str, default='', |
| help='Directory with pretrained single-modality ckpts') |
| p.add_argument('--epochs', type=int, default=100) |
| p.add_argument('--batch_size', type=int, default=16) |
| p.add_argument('--lr', type=float, default=1e-3) |
| p.add_argument('--weight_decay', type=float, default=1e-4) |
| p.add_argument('--hidden_dim', type=int, default=128) |
| p.add_argument('--proj_dim', type=int, default=0) |
| p.add_argument('--downsample', type=int, default=5) |
| p.add_argument('--patience', type=int, default=15) |
| p.add_argument('--label_smoothing', type=float, default=0.1) |
| p.add_argument('--augment', action='store_true') |
| p.add_argument('--seed', type=int, default=42) |
| p.add_argument('--output_dir', type=str, required=True) |
| p.add_argument('--tag', type=str, default='') |
| args = p.parse_args() |
| run_experiment(args) |
|
|
|
|
# Script entry point: parse CLI options and run the experiment only when this
# file is executed directly, not when it is imported.
if __name__ == '__main__':
    main()
|
|