import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from sklearn.metrics import accuracy_score, f1_score

from src.galileo import adjust_learning_rate

from .metrics import mean_iou


# Learning-rate grid swept when fitting probes ("LP" = linear probe).
PROBING_LRs = {
    "LP": [
        1e-4, 3e-4, 5e-4, 8e-4,
        1e-3, 3e-3, 5e-3, 8e-3,
        1e-2, 3e-2, 5e-2, 8e-2,
        1e-1, 3e-1, 5e-1, 8e-1,
    ],
}


def train_and_eval_probe_cls(lr, config, loaders, in_features, device):
    """Train a classification probe at one learning rate and report val/test scores."""
    probe = train_probe_cls(
        data_loader=loaders["train"],
        lr=lr,
        epochs=50,
        in_features=in_features,
        num_classes=config["num_classes"],
        is_multilabel=config["is_multilabel"],
        device=device,
    )
    val_acc = evaluate_probe_cls(
        data_loader=loaders["valid"],
        probe=probe,
        is_multilabel=config["is_multilabel"],
        device=device,
    )
    test_acc = evaluate_probe_cls(
        data_loader=loaders["test"],
        probe=probe,
        is_multilabel=config["is_multilabel"],
        device=device,
    )
    return val_acc, test_acc


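# Illustrative helper (not part of the original module) showing how the
# PROBING_LRs grid above might drive this function: sweep the "LP" learning
# rates, keep the probe with the best validation score, and report its test
# score. `config`, `loaders`, `in_features`, and `device` are assumed to be
# constructed by the caller, exactly as for train_and_eval_probe_cls.
def sweep_probe_cls_lrs(config, loaders, in_features, device):
    best_val, best_test = float("-inf"), None
    for lr in PROBING_LRs["LP"]:
        val_acc, test_acc = train_and_eval_probe_cls(
            lr=lr,
            config=config,
            loaders=loaders,
            in_features=in_features,
            device=device,
        )
        if val_acc > best_val:
            best_val, best_test = val_acc, test_acc
    return best_val, best_test

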
def train_and_eval_probe_seg(lr, config, loaders, in_features, grid_size, device):
    """Train a segmentation probe at one learning rate and report val/test mIoU."""
    # Each token decodes to one grid cell of the label map; ceil can overshoot,
    # which the interpolate step in train/evaluate_probe_seg corrects.
    output_patch_size = math.ceil(config["segmentation_map_height_width"] / grid_size)
    probe = train_probe_seg(
        data_loader=loaders["train"],
        lr=lr,
        epochs=50,
        in_features=in_features,
        num_classes=config["num_classes"],
        patch_size=output_patch_size,
        probe_type="LP",  # linear probe; the only grid defined in PROBING_LRs
        device=device,
    )
    val_miou = evaluate_probe_seg(
        data_loader=loaders["valid"],
        probe=probe,
        num_classes=config["num_classes"],
        patch_size=output_patch_size,
        device=device,
    )
    test_miou = evaluate_probe_seg(
        data_loader=loaders["test"],
        probe=probe,
        num_classes=config["num_classes"],
        patch_size=output_patch_size,
        device=device,
    )
    return val_miou, test_miou


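# Worked example for the patch-size computation above, under assumed toy sizes:
# a 96x96 label map decoded from an 8x8 token grid gives
# math.ceil(96 / 8) = 12, i.e. each token predicts a 12x12 patch of logits.
# For a 100x100 map the ceil rounds up to 13, and the bilinear interpolate in
# train/evaluate_probe_seg resizes the 104x104 output back to 100x100.

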
def train_probe_cls(
    data_loader,
    lr,
    epochs,
    in_features,
    num_classes,
    is_multilabel,
    device,
):
    """Fit a linear probe (BatchNorm + Linear) on frozen embeddings."""
    probe = nn.Sequential(nn.BatchNorm1d(in_features), nn.Linear(in_features, num_classes)).to(
        device
    )
    opt = torch.optim.AdamW(probe.parameters(), lr=lr)

    sched_config = {
        "lr": lr,
        "warmup_epochs": int(epochs * 0.1),
        "min_lr": 1.0e-5,
        "epochs": epochs,
    }
    probe = probe.train()

    if is_multilabel:
        loss_function = nn.MultiLabelSoftMarginLoss()
    else:
        loss_function = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        for i, batch in enumerate(data_loader):
            batch_emb, batch_labels = batch
            batch_emb = batch_emb.to(device)

            with torch.cuda.amp.autocast(dtype=torch.bfloat16):
                logits = probe(batch_emb)
                loss = loss_function(logits, batch_labels.to(device))

            loss.backward()
            # Per-iteration LR schedule (warmup, then decay to min_lr).
            adjust_learning_rate(
                optimizer=opt,
                epoch=epoch + (i / len(data_loader)),
                total_epochs=sched_config["epochs"],
                warmup_epochs=sched_config["warmup_epochs"],
                max_lr=sched_config["lr"],
                min_lr=sched_config["min_lr"],
            )

            opt.step()
            opt.zero_grad()

    return probe


def evaluate_probe_cls(data_loader, probe, is_multilabel, device):
    """Score a trained classification probe: micro-F1 if multilabel, else accuracy."""
    probe = probe.eval()

    all_logits = []
    all_labels = []
    with torch.no_grad():
        for batch in data_loader:
            batch_emb, batch_labels = batch
            batch_emb = batch_emb.to(device)

            with torch.cuda.amp.autocast(dtype=torch.bfloat16):
                batch_logits = probe(batch_emb)

            all_logits.append(batch_logits.float().cpu())
            all_labels.append(batch_labels)

    all_logits = torch.cat(all_logits, dim=0)
    all_labels = torch.cat(all_labels, dim=0)

    if is_multilabel:
        all_preds = torch.sigmoid(all_logits) > 0.5
        return f1_score(all_labels, all_preds, average="micro")
    else:
        all_preds = torch.argmax(all_logits, dim=-1)
        return accuracy_score(all_labels, all_preds)


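# Toy illustration (not used by the module) of the two metric paths above:
# multilabel logits are thresholded at 0.5 after a sigmoid and scored with
# micro-F1, while single-label logits are argmaxed and scored with accuracy.
def _demo_cls_metrics():
    logits = torch.tensor([[2.0, -1.0], [0.5, 3.0]])
    multilabel_preds = torch.sigmoid(logits) > 0.5  # [[True, False], [True, True]]
    single_label_preds = torch.argmax(logits, dim=-1)  # [0, 1]
    return multilabel_preds, single_label_preds

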
def train_probe_seg(
    data_loader,
    lr,
    epochs,
    in_features,
    num_classes,
    patch_size,
    probe_type,
    device,
):
    """Fit a per-token segmentation probe that decodes each token to a patch of logits."""
    logits_per_patch = int(num_classes * patch_size * patch_size)
    assert probe_type in ["LP", "MLP"]
    if probe_type == "LP":
        probe = nn.Sequential(nn.Linear(in_features, logits_per_patch)).to(device)
    else:
        probe = nn.Sequential(
            nn.Linear(in_features, 2048), nn.GELU(), nn.Linear(2048, logits_per_patch)
        ).to(device)

    opt = torch.optim.AdamW(probe.parameters(), lr=lr)

    sched_config = {
        "lr": lr,
        "warmup_epochs": int(epochs * 0.1),
        "min_lr": 1.0e-5,
        "epochs": epochs,
    }
    probe = probe.train()
    loss_function = nn.CrossEntropyLoss(ignore_index=-1)

    for epoch in range(epochs):
        for i, batch in enumerate(data_loader):
            batch_emb, batch_labels = batch
            spatial_patches_per_dim = int(batch_emb.shape[1] ** 0.5)
            batch_emb = batch_emb.to(device)

            with torch.cuda.amp.autocast(dtype=torch.bfloat16):
                logits = probe(batch_emb)

                if batch_labels.shape[1] == batch_labels.shape[2]:
                    # Square labels (b, H, W): unfold per-token logits into a
                    # dense (b, c, H, W) class map.
                    logits = rearrange(
                        logits,
                        "b (h w) (c i j) -> b c (h i) (w j)",
                        h=spatial_patches_per_dim,
                        w=spatial_patches_per_dim,
                        c=num_classes,
                        i=patch_size,
                        j=patch_size,
                    )
                    if logits.shape[-2] != batch_labels.shape[-2]:
                        logits = F.interpolate(
                            logits,
                            size=(batch_labels.shape[-2], batch_labels.shape[-1]),
                            mode="bilinear",
                            align_corners=True,
                        )
                else:
                    # Per-timestep labels (b, t, h*w): keep tokens along the
                    # time axis and flatten each patch.
                    logits = rearrange(
                        logits,
                        "b t (c i j) -> b c t (i j)",
                        c=num_classes,
                        i=patch_size,
                        j=patch_size,
                    )
                loss = loss_function(logits, batch_labels.to(device))

            loss.backward()
            adjust_learning_rate(
                optimizer=opt,
                epoch=epoch + (i / len(data_loader)),
                total_epochs=sched_config["epochs"],
                warmup_epochs=sched_config["warmup_epochs"],
                max_lr=sched_config["lr"],
                min_lr=sched_config["min_lr"],
            )

            opt.step()
            opt.zero_grad()

    return probe


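# Illustrative shape check (not used by the module) for the decoding above,
# with assumed toy sizes: batch 2, a 4x4 token grid, 3 classes, patch size 8.
# Each token carries 3*8*8 logits, and the rearrange tiles them into a dense
# 32x32 class map.
def _demo_seg_decode():
    b, h, w, c, p = 2, 4, 4, 3, 8
    token_logits = torch.randn(b, h * w, c * p * p)
    dense = rearrange(
        token_logits,
        "b (h w) (c i j) -> b c (h i) (w j)",
        h=h, w=w, c=c, i=p, j=p,
    )
    assert dense.shape == (b, c, h * p, w * p)  # (2, 3, 32, 32)
    return dense

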
def evaluate_probe_seg(
    data_loader,
    probe,
    num_classes,
    patch_size,
    device,
):
    """Score a trained segmentation probe by mean IoU over the whole split."""
    probe = probe.eval()

    all_preds = []
    all_labels = []
    with torch.no_grad():
        for batch in data_loader:
            batch_emb, batch_labels = batch
            spatial_patches_per_dim = int(batch_emb.shape[1] ** 0.5)
            batch_emb = batch_emb.to(device)

            with torch.cuda.amp.autocast(dtype=torch.bfloat16):
                logits = probe(batch_emb)
                logits = rearrange(
                    logits,
                    "b (h w) (c i j) -> b c (h i) (w j)",
                    h=spatial_patches_per_dim,
                    w=spatial_patches_per_dim,
                    c=num_classes,
                    i=patch_size,
                    j=patch_size,
                )
                if logits.shape[-2] != batch_labels.shape[-2]:
                    logits = F.interpolate(
                        logits,
                        size=(batch_labels.shape[-2], batch_labels.shape[-1]),
                        mode="bilinear",
                        align_corners=True,
                    )

            preds = torch.argmax(logits, dim=1).cpu()
            all_preds.append(preds)
            all_labels.append(batch_labels)

    all_preds = torch.cat(all_preds)
    all_labels = torch.cat(all_labels)
    return mean_iou(all_preds, all_labels, num_classes=num_classes, ignore_label=-1)


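# `mean_iou` is imported from .metrics; the sketch below is a minimal,
# assumption-laden reference for the semantics relied on here (per-class IoU
# over valid pixels, excluding `ignore_label`, averaged across classes that
# actually appear). The real implementation in .metrics may differ.
def _mean_iou_sketch(preds, labels, num_classes, ignore_label=-1):
    valid = labels != ignore_label
    ious = []
    for c in range(num_classes):
        pred_c = (preds == c) & valid
        label_c = (labels == c) & valid
        union = (pred_c | label_c).sum().item()
        if union == 0:
            continue  # class absent from preds and labels; skip it
        ious.append((pred_c & label_c).sum().item() / union)
    return sum(ious) / len(ious) if ious else 0.0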