| from __future__ import annotations
|
|
|
| import json
|
| import os
|
| from typing import Dict
|
|
|
| import torch
|
| import torch.nn as nn
|
| import torch.optim as optim
|
| from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
|
|
|
| from flexibrain.config import RunConfig
|
| from flexibrain.data import build_downstream_dataloaders
|
| from flexibrain.data.classification import prepare_batch_data
|
| from flexibrain.models import build_downstream_model
|
| from flexibrain.utils.logging import setup_logger
|
| from flexibrain.utils.seed import set_seed
|
|
|
|
|
class DownstreamTrainer:
    """Fine-tune a downstream classifier, track the best model and test it.

    Typical use is ``DownstreamTrainer(cfg).fit()``. ``fit`` builds the model,
    data loaders, optimizer, AMP scaler and LR scheduler, trains for
    ``cfg.training.epochs`` epochs, keeps the checkpoint with the highest
    validation macro-F1 and finally evaluates that checkpoint on the test
    split (when a test loader exists).

    Only the process whose ``cfg.training.local_rank`` is 0 logs metrics and
    writes checkpoints / ``test_metrics.json``.
    """

    def __init__(self, cfg: RunConfig):
        """Store the run configuration and resolve rank, device and logger.

        Args:
            cfg: Run configuration; ``cfg.training.local_rank`` selects the
                CUDA device index when CUDA is available.
        """
        self.cfg = cfg
        self.rank = cfg.training.local_rank
        self.device = torch.device(f"cuda:{self.rank}" if torch.cuda.is_available() else "cpu")
        self.logger = setup_logger("downstream", cfg.logging.log_dir, rank=self.rank)

    @staticmethod
    def _pick_lr(group_lr, default_lr):
        """Return ``group_lr`` unless it is ``None``, else ``default_lr``.

        An explicit ``0.0`` is honoured (``x or default`` would silently
        replace it with the default learning rate).
        """
        return default_lr if group_lr is None else group_lr

    @staticmethod
    def _lr_factor(step, warmup_steps, total_steps) -> float:
        """LR multiplier: linear warmup from 0 to 1, then linear decay to 0.

        Args:
            step: Number of scheduler steps taken so far.
            warmup_steps: Length of the warmup ramp (must be > 0).
            total_steps: Step at which the multiplier reaches 0.

        Returns:
            A multiplier in ``[0, 1]`` applied to each group's base LR.
        """
        if step < warmup_steps:
            return step / warmup_steps
        return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))

    def _build_optimizer(self) -> None:
        """Create ``self.optimizer`` (AdamW), optionally with per-group LRs.

        When either ``lr_backbone`` or ``lr_head`` is configured, parameters
        under ``model.backbone`` and all remaining parameters form two groups
        with their own learning rates; an unset one falls back to ``lr``.
        """
        train_cfg = self.cfg.training
        if train_cfg.lr_backbone is None and train_cfg.lr_head is None:
            self.optimizer = optim.AdamW(
                self.model.parameters(),
                lr=train_cfg.lr,
                weight_decay=train_cfg.weight_decay,
            )
            return

        backbone_params = list(self.model.backbone.parameters())
        head_params = [
            param
            for name, param in self.model.named_parameters()
            if not name.startswith("backbone.")
        ]
        self.optimizer = optim.AdamW(
            [
                {"params": backbone_params, "lr": self._pick_lr(train_cfg.lr_backbone, train_cfg.lr)},
                {"params": head_params, "lr": self._pick_lr(train_cfg.lr_head, train_cfg.lr)},
            ],
            weight_decay=train_cfg.weight_decay,
        )

    def _build_scheduler(self) -> None:
        """Create ``self.scheduler`` with a horizon measured in optimizer steps.

        The scheduler is stepped once per optimizer step (see
        ``_optimizer_step``), not once per batch, so the warmup length and
        total horizon are derived from the number of optimizer steps per
        epoch under gradient accumulation.
        """
        train_cfg = self.cfg.training
        accum = max(1, train_cfg.grad_accumulation_steps)
        # Ceil division: a trailing partial accumulation window still
        # triggers one optimizer step at the end of the epoch.
        steps_per_epoch = (len(self.train_loader) + accum - 1) // accum
        total_steps = max(1, steps_per_epoch * train_cfg.epochs)
        warmup_steps = max(1, steps_per_epoch * train_cfg.warmup_epochs)
        lr_factor = self._lr_factor

        # A plain function (not a callable object) keeps the lambda out of
        # the scheduler's state_dict, as with the usual LambdaLR usage.
        def lr_lambda(step):
            return lr_factor(step, warmup_steps, total_steps)

        self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda)

    def build(self):
        """Seed RNGs and construct model, loaders, optimizer, scaler, scheduler.

        Returns:
            ``self``, so calls can be chained.
        """
        set_seed(self.cfg.training.seed)
        self.model = build_downstream_model(
            self.cfg.model,
            self.device,
            logger=self.logger,
            checkpoint_path=self.cfg.pretrain_checkpoint,
            from_scratch=self.cfg.from_scratch,
            use_checkpoint_config=self.cfg.use_checkpoint_config,
        )
        self.train_loader, self.val_loader, self.test_loader = build_downstream_dataloaders(
            self.cfg.data,
            self.cfg.training,
            rank=self.rank,
            world_size=self.cfg.training.world_size,
        )
        self._build_optimizer()
        # Mixed precision is only enabled on CUDA devices.
        self.use_amp = bool(self.cfg.training.use_amp and self.device.type == "cuda")
        self.scaler = torch.cuda.amp.GradScaler(enabled=self.use_amp)
        self._build_scheduler()
        return self

    def _optimizer_step(self) -> None:
        """Clip (optionally), apply the accumulated gradients, advance the LR."""
        if self.cfg.training.grad_clip > 0:
            # Gradients must be unscaled before their norm is measured.
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.cfg.training.grad_clip)
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.scheduler.step()
        self.optimizer.zero_grad(set_to_none=True)

    def train_one_epoch(self, epoch: int) -> float:
        """Run one pass over the training loader.

        Args:
            epoch: Zero-based epoch index (logged as ``epoch + 1``).

        Returns:
            Mean per-batch cross-entropy loss (0.0 if the loader is empty).
        """
        self.model.train()
        criterion = nn.CrossEntropyLoss()
        accum = max(1, self.cfg.training.grad_accumulation_steps)
        log_interval = self.cfg.logging.log_interval
        num_total = len(self.train_loader)

        # Epoch-aware samplers (e.g. torch's DistributedSampler) need the
        # epoch index to reshuffle; without it every epoch replays the same
        # order. Samplers without set_epoch are left alone.
        sampler = getattr(self.train_loader, "sampler", None)
        if hasattr(sampler, "set_epoch"):
            sampler.set_epoch(epoch)

        total_loss = 0.0
        num_batches = 0
        self.optimizer.zero_grad(set_to_none=True)
        for batch_idx, batch in enumerate(self.train_loader):
            x, meta, orig_Ts, labels, affines = prepare_batch_data(batch, self.device)
            with torch.autocast(device_type=self.device.type, dtype=torch.float16, enabled=self.use_amp):
                logits = self.model(x, meta=meta, orig_Ts=orig_Ts, affines=affines)
                loss = criterion(logits, labels)
            # Scale down so gradients summed over the window average out.
            self.scaler.scale(loss / accum).backward()
            if (batch_idx + 1) % accum == 0:
                self._optimizer_step()

            loss_value = float(loss.item())  # read once: .item() syncs with the device
            total_loss += loss_value
            num_batches += 1
            if self.rank == 0 and (batch_idx + 1) % log_interval == 0:
                self.logger.info("Epoch %d [%d/%d] loss=%.6f", epoch + 1, batch_idx + 1, num_total, loss_value)

        # Flush gradients left over from an incomplete accumulation window.
        if num_batches % accum != 0:
            self._optimizer_step()
        return total_loss / max(1, num_batches)

    @torch.no_grad()
    def evaluate(self, loader, split_name: str) -> Dict[str, float]:
        """Compute loss and classification metrics over ``loader``.

        Args:
            loader: Iterable of batches accepted by ``prepare_batch_data``.
            split_name: Label used in the log line (e.g. ``"Validation"``).

        Returns:
            Dict with ``loss``, ``accuracy``, ``precision_macro``,
            ``recall_macro``, ``f1_macro`` and ``f1_weighted`` as builtin
            floats. Metrics cover only the batches this process iterated.
        """
        self.model.eval()
        criterion = nn.CrossEntropyLoss()
        preds, labels_all = [], []
        total_loss = 0.0
        num_batches = 0
        for batch in loader:
            x, meta, orig_Ts, labels, affines = prepare_batch_data(batch, self.device)
            with torch.autocast(device_type=self.device.type, dtype=torch.float16, enabled=self.use_amp):
                logits = self.model(x, meta=meta, orig_Ts=orig_Ts, affines=affines)
                loss = criterion(logits, labels)
            total_loss += float(loss.item())
            num_batches += 1
            preds.extend(torch.argmax(logits, dim=1).cpu().tolist())
            labels_all.extend(labels.cpu().tolist())

        # Cast to builtin float so the values are plain Python objects when
        # they are later written to JSON and into checkpoints.
        metrics = {
            "loss": total_loss / max(1, num_batches),
            "accuracy": float(accuracy_score(labels_all, preds)),
            "precision_macro": float(precision_score(labels_all, preds, average="macro", zero_division=0)),
            "recall_macro": float(recall_score(labels_all, preds, average="macro", zero_division=0)),
            "f1_macro": float(f1_score(labels_all, preds, average="macro", zero_division=0)),
            "f1_weighted": float(f1_score(labels_all, preds, average="weighted", zero_division=0)),
        }
        if self.rank == 0:
            self.logger.info("%s metrics: %s", split_name, metrics)
        return metrics

    def save(self, epoch: int, metrics: dict, is_best: bool):
        """Write ``downstream_latest.pt`` and, if ``is_best``, ``downstream_best.pt``.

        No-op on ranks other than 0.

        Args:
            epoch: Zero-based epoch index stored in the checkpoint.
            metrics: Metrics dict stored alongside the weights.
            is_best: Whether to also overwrite the best checkpoint.
        """
        if self.rank != 0:
            return
        ckpt_dir = self.cfg.logging.checkpoint_dir
        os.makedirs(ckpt_dir, exist_ok=True)
        payload = {
            "epoch": epoch,
            "model": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "scheduler": self.scheduler.state_dict(),
            "metrics": metrics,
            "config": vars(self.cfg.model),
        }
        torch.save(payload, os.path.join(ckpt_dir, "downstream_latest.pt"))
        if is_best:
            torch.save(payload, os.path.join(ckpt_dir, "downstream_best.pt"))

    def _load_best_for_test(self) -> None:
        """Restore the best checkpoint's weights, if one exists.

        When ``downstream_best.pt`` is missing the current (last-epoch)
        weights are kept and a warning is logged on rank 0.
        """
        best_path = os.path.join(self.cfg.logging.checkpoint_dir, "downstream_best.pt")
        if not os.path.exists(best_path):
            if self.rank == 0:
                self.logger.warning(
                    "Best checkpoint not found at %s; testing with current weights", best_path
                )
            return
        # NOTE(review): torch.load unpickles the file. The checkpoint is
        # written by save() above; recent PyTorch releases default to
        # weights_only=True, which requires the payload (including the
        # "config" entry) to contain only allow-listed types - confirm
        # against the PyTorch version in use.
        checkpoint = torch.load(best_path, map_location=self.device)
        self.model.load_state_dict(checkpoint["model"])

    def _save_test_metrics(self, metrics: Dict[str, float]) -> None:
        """Write ``metrics`` to ``test_metrics.json`` in the checkpoint dir (rank 0 only)."""
        if self.rank != 0:
            return
        os.makedirs(self.cfg.logging.checkpoint_dir, exist_ok=True)
        path = os.path.join(self.cfg.logging.checkpoint_dir, "test_metrics.json")
        with open(path, "w", encoding="utf-8") as f:
            json.dump(metrics, f, indent=2)

    def fit(self):
        """Build everything, train, checkpoint on best val macro-F1, then test."""
        self.build()
        if self.rank == 0:
            self.logger.info("Starting downstream on %s", self.device)
            self.logger.info(
                "Train size=%d Val size=%d",
                len(self.train_loader.dataset),
                len(self.val_loader.dataset),
            )

        best_f1 = -1.0  # below any attainable F1, so the first epoch always becomes "best"
        for epoch in range(self.cfg.training.epochs):
            train_loss = self.train_one_epoch(epoch)
            val_metrics = self.evaluate(self.val_loader, "Validation")
            metrics = {"val": val_metrics, "train_loss": train_loss}
            is_best = val_metrics["f1_macro"] > best_f1
            if is_best:
                best_f1 = val_metrics["f1_macro"]
            self.save(epoch, metrics, is_best=is_best)
            if self.rank == 0:
                self.logger.info("Epoch %d done train=%.6f best_f1=%.6f", epoch + 1, train_loss, best_f1)

        if self.test_loader is not None:
            self._load_best_for_test()
            test_metrics = self.evaluate(self.test_loader, "Test")
            self._save_test_metrics(test_metrics)
|
|
|