|
|
import os |
|
|
import sys |
|
|
import argparse |
|
|
import yaml |
|
|
import datetime |
|
|
import numpy as np |
|
|
from pathlib import Path |
|
|
from sklearn.metrics import f1_score |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.distributed as dist |
|
|
from torch.nn.parallel import DistributedDataParallel as DDP |
|
|
from torch.utils.data import DataLoader, DistributedSampler |
|
|
from torch.cuda.amp import GradScaler, autocast |
|
|
|
|
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'hiera')) |
|
|
|
|
|
from hiera.hiera_mae import HieraClassifier |
|
|
from data.downstream_dataset import fMRITaskDataset, fMRITaskDataset1, EmoFMRIDataset, HCPtaskDataset |
|
|
from data.adni_dataset import ADNIDataset |
|
|
|
|
|
from utils.utils import MetricLogger, load_config, log_to_file, count_parameters, save_checkpoint, load_checkpoint, LabelScaler |
|
|
from utils.optim import create_optimizer, create_lr_scheduler |
|
|
from utils.ddp import setup_distributed, set_seed, cleanup_distributed |
|
|
|
|
|
|
|
|
def create_model(config):
    """Build a ``HieraClassifier`` and, when configured, load pretrained MAE weights.

    The architecture hyper-parameters come from ``config['model']`` unless
    ``config['experiment']['pretrained_checkpoint']`` is set and a
    ``config.yaml`` exists two directory levels above that checkpoint; in that
    case the ``'model'`` section of that file is used instead.

    Args:
        config: Parsed configuration mapping with at least the ``'task'``,
            ``'experiment'`` and ``'model'`` sections.

    Returns:
        The constructed ``HieraClassifier``. It is randomly initialised when no
        checkpoint is configured or the configured checkpoint file is missing
        (a warning is printed in both cases).
    """
    task_cfg = config['task']
    experiment_cfg = config['experiment']
    arch_cfg = config['model']
    pretrained_checkpoint_path = experiment_cfg.get('pretrained_checkpoint', None)

    # Prefer the architecture stored alongside the pretraining run, if present.
    if pretrained_checkpoint_path:
        pretrain_config_path = Path(pretrained_checkpoint_path).parent.parent / 'config.yaml'
        if not os.path.exists(pretrain_config_path):
            print(f"Warning: Pretrained config not found at {pretrain_config_path}. Using finetune config for model architecture.")
        else:
            print(f"Loading model architecture from pretrained config: {pretrain_config_path}")
            arch_cfg = load_config(pretrain_config_path)['model']

    # Sequence-valued settings are converted to tuples before being passed on.
    classifier_kwargs = {
        'num_classes': task_cfg['num_classes'],
        'task_type': task_cfg['task_type'],
        'input_size': tuple(arch_cfg['input_size']),
        'in_chans': arch_cfg['in_chans'],
        'patch_kernel': tuple(arch_cfg['patch_kernel']),
        'patch_stride': tuple(arch_cfg['patch_stride']),
        'patch_padding': tuple(arch_cfg['patch_padding']),
        'q_stride': tuple(arch_cfg['q_stride']),
        'mask_unit_size': tuple(arch_cfg['mask_unit_size']),
        'embed_dim': arch_cfg['embed_dim'],
        'num_heads': arch_cfg['num_heads'],
        'stages': tuple(arch_cfg['stages']),
        'q_pool': arch_cfg['q_pool'],
        'mlp_ratio': arch_cfg['mlp_ratio'],
    }
    model = HieraClassifier(**classifier_kwargs)

    if not pretrained_checkpoint_path:
        print("Warning: No pretrained checkpoint specified. Model is randomly initialized.")
    elif os.path.exists(pretrained_checkpoint_path):
        model.load_pretrained_mae(pretrained_checkpoint_path)
    else:
        print(f"Warning: Pretrained checkpoint not found at {pretrained_checkpoint_path}. Model is randomly initialized.")

    return model
|
|
|
|
|
|
|
|
|
|
|
def _make_task_dataset(data_config, task_config, split_suffixes):
    """Instantiate ``fMRITaskDataset`` for one split; only the split suffixes vary."""
    return fMRITaskDataset(
        data_root=data_config['data_root'],
        datasets=data_config['datasets'],
        split_suffixes=split_suffixes,
        crop_length=data_config['input_seq_len'],
        label_csv_path=task_config['csv'],
        task_type=task_config['task_type'],
    )


def _make_loader(dataset, data_config, sampler, *, shuffle, drop_last):
    """Wrap ``dataset`` in a ``DataLoader`` using the shared settings from ``data_config``."""
    num_workers = data_config['num_workers']
    loader_kwargs = {
        'batch_size': data_config['batch_size'],
        'sampler': sampler,
        'shuffle': shuffle,
        'num_workers': num_workers,
        'pin_memory': data_config['pin_memory'],
        'drop_last': drop_last,
    }
    # DataLoader only accepts prefetch_factor together with worker processes;
    # passing it with num_workers == 0 raises ValueError on recent PyTorch.
    if num_workers > 0:
        loader_kwargs['prefetch_factor'] = data_config.get('prefetch_factor', 2)
    return DataLoader(dataset, **loader_kwargs)


def create_dataloaders(config, is_distributed, rank, world_size):
    """Create train, validation, and test dataloaders.

    Args:
        config: Parsed configuration mapping; the ``'data'`` and ``'task'``
            sections are read, plus ``config['experiment']['seed']`` when
            running distributed.
        is_distributed: Whether to shard every split with a
            ``DistributedSampler``.
        rank: Rank of the current process (used only for the samplers).
        world_size: Number of processes (used only for the samplers).

    Returns:
        ``(train_loader, val_loader, test_loader, train_sampler)``.
        ``train_sampler`` is ``None`` when ``is_distributed`` is false; it is
        returned so the caller can call ``set_epoch`` on it.
    """
    data_config = config['data']
    task_config = config['task']

    train_dataset = _make_task_dataset(data_config, task_config, data_config['train_split_suffixes'])
    val_dataset = _make_task_dataset(data_config, task_config, data_config['val_split_suffixes'])
    # The test split falls back to the 'test' suffix when not configured.
    test_dataset = _make_task_dataset(
        data_config, task_config, data_config.get('test_split_suffixes', ['test'])
    )

    if is_distributed:
        train_sampler = DistributedSampler(
            train_dataset,
            num_replicas=world_size,
            rank=rank,
            shuffle=True,
            seed=config['experiment']['seed'],
        )
        val_sampler = DistributedSampler(val_dataset, num_replicas=world_size, rank=rank, shuffle=False)
        test_sampler = DistributedSampler(test_dataset, num_replicas=world_size, rank=rank, shuffle=False)
    else:
        train_sampler = None
        val_sampler = None
        test_sampler = None

    # DataLoader forbids shuffle=True together with a sampler, so shuffling is
    # delegated to the DistributedSampler whenever one is in use.
    train_loader = _make_loader(
        train_dataset, data_config, train_sampler,
        shuffle=(train_sampler is None), drop_last=True,
    )
    val_loader = _make_loader(val_dataset, data_config, val_sampler, shuffle=False, drop_last=False)
    test_loader = _make_loader(test_dataset, data_config, test_sampler, shuffle=False, drop_last=False)

    return train_loader, val_loader, test_loader, train_sampler
|
|
|
|
|
|
|
|
def train_one_epoch(model, train_loader, criterion, optimizer, scheduler, scaler, epoch, config,
                    rank, world_size, label_scaler=None, log_file=None):
    """Train for one epoch with optional AMP and gradient accumulation.

    Args:
        model: Network to train (possibly DDP-wrapped).
        train_loader: Iterable yielding ``(samples, labels)`` batches.
        criterion: Loss callable applied to ``(outputs, targets)``.
        optimizer: Optimizer stepped once every ``accum_iter`` batches.
        scheduler: LR scheduler; ``step()`` is called at the start of every
            accumulation window, i.e. before the matching optimizer step.
        scaler: ``GradScaler`` used when ``config['training']['use_amp']`` is
            true; ignored otherwise.
        epoch: Epoch index, used only for the log header.
        config: Parsed configuration; reads the ``'training'``, ``'logging'``
            and ``'task'`` sections.
        rank: CUDA device index batches are moved to.
        world_size: Unused here; kept for call compatibility.
        label_scaler: Optional object whose ``transform`` is applied to
            regression targets before the loss.
        log_file: Unused here; kept for call compatibility.

    Returns:
        Dict mapping each logged meter name (``loss``, ``lr`` and, for
        classification, ``acc``) to its ``global_avg``.

    Raises:
        SystemExit: With status 1 when a batch yields a non-finite loss.
    """
    model.train()

    metric_logger = MetricLogger(delimiter=" ")
    header = f'Epoch: [{epoch}]'

    train_config = config['training']
    log_config = config['logging']
    is_classification = config['task']['task_type'] == 'classification'

    accum_iter = train_config['accum_iter']
    use_amp = train_config['use_amp']
    clip_grad = train_config.get('clip_grad', None)

    optimizer.zero_grad()

    batches = metric_logger.log_every(train_loader, log_config['print_freq'], header)
    for data_iter_step, (samples, labels) in enumerate(batches):
        if data_iter_step % accum_iter == 0:
            scheduler.step()

        samples = samples.cuda(rank, non_blocking=True)
        labels = labels.cuda(rank, non_blocking=True)

        # Forward pass and loss both run under autocast, as in the PyTorch AMP recipe.
        with autocast(enabled=use_amp):
            outputs = model(samples)

            if is_classification:
                # Flatten to (batch,) integer class ids, matching evaluate().
                # reshape(-1) keeps a batch of one 1-D, where squeeze() would
                # collapse it to a 0-dim tensor and break labels.size(0).
                labels = labels.reshape(-1).long()
                loss = criterion(outputs, labels)

                _, predicted = outputs.max(1)
                correct = predicted.eq(labels).sum().item()
                accuracy = correct / labels.size(0)
            else:
                if label_scaler is not None:
                    target_for_loss = label_scaler.transform(labels)
                else:
                    target_for_loss = labels
                loss = criterion(outputs.squeeze(), target_for_loss.squeeze())
                accuracy = 0.0

        # Abort before backward so a NaN/inf loss never reaches the weights.
        loss_value = loss.item()
        if not np.isfinite(loss_value):
            print(f"Loss is {loss_value}, stopping training")
            sys.exit(1)

        # Scale so the gradients summed over the window average the batches.
        loss = loss / accum_iter
        is_update_step = (data_iter_step + 1) % accum_iter == 0

        if use_amp:
            scaler.scale(loss).backward()
            if is_update_step:
                if clip_grad is not None:
                    # Gradients must be unscaled before clipping by norm.
                    scaler.unscale_(optimizer)
                    nn.utils.clip_grad_norm_(model.parameters(), clip_grad)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
        else:
            loss.backward()
            if is_update_step:
                if clip_grad is not None:
                    nn.utils.clip_grad_norm_(model.parameters(), clip_grad)
                optimizer.step()
                optimizer.zero_grad()

        metric_logger.update(loss=loss_value)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
        if is_classification:
            metric_logger.update(acc=accuracy)

    metric_logger.synchronize_between_processes()
    print(f"Averaged stats: {metric_logger}")

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
|
|
|
|
|
|
|
|
@torch.no_grad()
def evaluate(model, data_loader, criterion, config, rank, epoch=None, label_scaler=None, mode='val'):
    """Evaluate ``model`` on ``data_loader`` and return averaged metrics.

    Args:
        model: Network to evaluate (possibly DDP-wrapped); put in eval mode.
        data_loader: Iterable yielding ``(samples, labels)`` batches.
        criterion: Loss callable applied to ``(outputs, targets)``.
        config: Parsed configuration; only ``config['task']['task_type']`` is
            read.
        rank: CUDA device index batches are moved to; rank 0 also prints the
            summary line.
        epoch: Optional epoch index shown in the log header.
        label_scaler: Optional object whose ``transform`` is applied to
            regression targets. When ``None`` the raw labels are used.
        mode: Label for the log output (e.g. ``'val'`` or ``'test'``).

    Returns:
        Dict mapping each logged meter name to its ``global_avg``: ``loss``,
        ``acc`` and ``f1`` for classification; ``loss``, ``mse``, ``mae``,
        ``r2`` and ``corr`` otherwise. The regression metrics are computed
        against the (scaled, if a scaler is given) targets. ``f1`` and the
        regression metrics are computed from this process's own batches before
        the meters are synchronized.
    """
    model.eval()
    metric_logger = MetricLogger(delimiter=" ")
    header = f'{mode.capitalize()} Epoch: [{epoch}]' if epoch is not None else f'{mode.capitalize()}:'

    is_classification = config['task']['task_type'] == 'classification'

    batch_preds, batch_targets = [], []

    for samples, labels in metric_logger.log_every(data_loader, 50, header):
        samples = samples.cuda(rank, non_blocking=True)
        labels = labels.cuda(rank, non_blocking=True)

        outputs = model(samples)

        if is_classification:
            # reshape(-1) keeps a final batch of one 1-D; squeeze() would make
            # it 0-dim and torch.cat below cannot concatenate 0-dim tensors.
            labels = labels.reshape(-1).long()
            loss = criterion(outputs, labels)

            preds = outputs.argmax(1)
            acc = (preds == labels).float().mean().item()
            metric_logger.update(loss=loss.item(), acc=acc)

            batch_preds.append(preds.cpu())
            batch_targets.append(labels.cpu())
        else:
            # Fall back to the raw labels when no scaler is supplied, the same
            # way train_one_epoch does, instead of leaving the target unbound.
            if label_scaler is not None:
                target_norm = label_scaler.transform(labels)
            else:
                target_norm = labels
            loss = criterion(outputs.reshape(-1), target_norm.reshape(-1))

            metric_logger.update(loss=loss.item())
            batch_preds.append(outputs.detach().cpu().reshape(-1))
            batch_targets.append(target_norm.detach().cpu().reshape(-1))

    if batch_preds:
        all_preds = torch.cat(batch_preds)
        all_targets = torch.cat(batch_targets)

        if is_classification:
            f1 = f1_score(all_targets.numpy(), all_preds.numpy(), average='weighted')
            metric_logger.update(f1=f1)
        else:
            residual = all_preds - all_targets
            mse = torch.mean(residual ** 2).item()
            mae = torch.mean(torch.abs(residual)).item()

            # Coefficient of determination; epsilon guards constant targets.
            ss_res = torch.sum((all_targets - all_preds) ** 2)
            ss_tot = torch.sum((all_targets - all_targets.mean()) ** 2)
            r2 = (1 - ss_res / (ss_tot + 1e-8)).item()

            # Pearson correlation between predictions and targets.
            vx = all_preds - all_preds.mean()
            vy = all_targets - all_targets.mean()
            corr = (torch.sum(vx * vy) / (torch.sqrt(torch.sum(vx**2)) * torch.sqrt(torch.sum(vy**2)) + 1e-8)).item()

            metric_logger.update(mse=mse, mae=mae, r2=r2, corr=corr)

    metric_logger.synchronize_between_processes()

    if rank == 0:
        print(f"[{mode.upper()}] Global stats: {metric_logger}")

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
|
|
|
|
|
|
|
|
def _freeze_encoder(model, rank):
    """Disable gradients for every parameter whose name does not contain 'head'."""
    if rank == 0:
        print("Freezing encoder weights. Only the head will be trained.")
    for name, param in model.named_parameters():
        if 'head' not in name:
            param.requires_grad = False


def _build_checkpoint_state(epoch, model, optimizer, scheduler, scaler,
                            best_metric, best_loss, config, **extra):
    """Assemble the checkpoint dict; ``epoch`` is stored as the next epoch to run."""
    state = {
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'best_metric': best_metric,
        'best_loss': best_loss,
        'config': config,
    }
    state.update(extra)
    if scaler is not None:
        state['scaler_state_dict'] = scaler.state_dict()
    return state


def _log_epoch_stats(log_file, epoch, phase, stats):
    """Print one 'Epoch N <phase> - k: v | ...' line and append it to the log file."""
    log_msg = f"Epoch {epoch} {phase} - "
    log_msg += " | ".join([f"{k}: {v:.4f}" for k, v in stats.items()])
    print(log_msg)
    log_to_file(log_file, log_msg)


def main():
    """Main fine-tuning function.

    Parses the command line, loads the config, sets up (optionally
    distributed) training, then runs the train / validate / test / checkpoint
    loop. Only rank 0 creates the output directory, writes logs and saves
    checkpoints.
    """
    parser = argparse.ArgumentParser(description='Hiera MAE 4D fMRI Downstream Fine-tuning')
    parser.add_argument('--config', type=str, default='configs/finetune_config.yaml',
                        help='Path to config file')
    parser.add_argument('--resume', type=str, default=None,
                        help='Path to checkpoint to resume from')
    parser.add_argument('--output_dir', type=str, default=None,
                        help='Output directory (overrides config)')
    args = parser.parse_args()

    config = load_config(args.config)

    # Command-line values take precedence over the config file.
    if args.resume is not None:
        config['experiment']['resume'] = args.resume
    if args.output_dir is not None:
        config['experiment']['output_dir'] = args.output_dir
    resume_path = config['experiment'].get('resume', None)

    is_distributed, rank, world_size, gpu = setup_distributed()

    set_seed(config['experiment']['seed'], rank)

    if rank == 0:
        output_dir = Path(config['experiment']['output_dir'])
        checkpoint_dir = output_dir / 'checkpoints'
        log_dir = output_dir / 'logs'
        for directory in (output_dir, checkpoint_dir, log_dir):
            directory.mkdir(parents=True, exist_ok=True)

        with open(output_dir / 'config.yaml', 'w') as f:
            yaml.dump(config, f, default_flow_style=False)

        log_file = output_dir / 'training_log.txt'
        # Append when resuming so the earlier training history is not wiped.
        with open(log_file, 'a' if resume_path is not None else 'w') as f:
            f.write(f"Fine-tuning started at {datetime.datetime.now()}\n")
            f.write("="*80 + "\n")
            f.write(f"Config: {args.config}\n")
            f.write(f"Output directory: {config['experiment']['output_dir']}\n")
            f.write(f"Task type: {config['task']['task_type']}\n")
            f.write("="*80 + "\n\n")
    else:
        checkpoint_dir = None
        log_file = None

    # Make the other ranks wait until rank 0 has created the output tree.
    if is_distributed:
        dist.barrier()

    model = create_model(config)
    model = model.cuda(gpu)

    if rank == 0:
        print("\nAnalyzing model architecture...")
        count_parameters(model, verbose=True)

    # Freeze before wrapping: DDP registers its gradient hooks at construction
    # time, so requires_grad should not be changed afterwards.
    if config['training'].get('freeze_encoder', False):
        _freeze_encoder(model, rank)

    if is_distributed:
        model = DDP(model, device_ids=[gpu], find_unused_parameters=True)

    model_without_ddp = model.module if is_distributed else model

    if rank == 0:
        print("Creating dataloaders...")
    train_loader, val_loader, test_loader, train_sampler = create_dataloaders(
        config, is_distributed, rank, world_size
    )

    task_config = config['task']
    is_classification = task_config['task_type'] == 'classification'

    label_scaler = None
    if task_config['task_type'] == 'regression':
        # Every rank loads the same config, so each one builds the scaler
        # locally; reading these only on rank 0 left them undefined elsewhere.
        mean_val = task_config['mean']
        scale_val = task_config['std']
        if rank == 0:
            print(f"StandardScaler fit complete. Mean: {mean_val:.4f}, Std: {scale_val:.4f}")

        norm_mean = torch.tensor(mean_val, device=gpu, dtype=torch.float32)
        norm_std = torch.tensor(scale_val, device=gpu, dtype=torch.float32)
        label_scaler = LabelScaler(norm_mean, norm_std)

    if rank == 0:
        print(f"Training samples: {len(train_loader.dataset)}")
        print(f"Validation samples: {len(val_loader.dataset)}")
        print(f"Test samples: {len(test_loader.dataset)}")
        print(f"Batches per epoch: {len(train_loader)}")

    if is_classification:
        criterion = nn.CrossEntropyLoss(label_smoothing=0.0)
    else:
        criterion = nn.MSELoss()

    if rank == 0:
        print("Trainable parameters:")
        for name, param in model_without_ddp.named_parameters():
            if param.requires_grad:
                print(name)

    optimizer = create_optimizer(model_without_ddp, config)
    scheduler = create_lr_scheduler(optimizer, config, len(train_loader))

    scaler = GradScaler() if config['training']['use_amp'] else None

    start_epoch = 0
    best_loss = float('inf')
    # Classification tracks accuracy (higher is better); other tasks select on loss.
    best_metric = 0.0 if is_classification else float('inf')

    if resume_path is not None:
        start_epoch, best_metric, best_loss = load_checkpoint(
            resume_path,
            model_without_ddp,
            optimizer,
            scheduler,
            scaler
        )
        if rank == 0:
            print(f"Resumed from epoch {start_epoch}. Best metric: {best_metric:.4f}, Best loss: {best_loss:.4f}")

    total_epochs = config['training']['epochs']

    if rank == 0:
        print("Starting fine-tuning...")
        print(f"Training from epoch {start_epoch} to {total_epochs}")

    for epoch in range(start_epoch, total_epochs):
        if is_distributed and train_sampler is not None:
            # Gives the DistributedSampler a different shuffle each epoch.
            train_sampler.set_epoch(epoch)

        train_stats = train_one_epoch(
            model, train_loader, criterion, optimizer, scheduler, scaler,
            epoch, config, rank, world_size, label_scaler, log_file
        )

        if rank == 0:
            _log_epoch_stats(log_file, epoch, 'Training', train_stats)

        checkpoint_name = f'checkpoint_epoch_{epoch}.pth'
        saved_this_epoch = False

        if epoch % config['validation']['val_freq'] == 0 or epoch == total_epochs - 1:
            # Every rank must take part: evaluate() synchronizes its meters.
            val_stats = evaluate(
                model, val_loader, criterion, config, rank, epoch, label_scaler, 'val'
            )
            test_stats = evaluate(model, test_loader, criterion, config, rank, epoch, label_scaler, 'test')

            if rank == 0:
                _log_epoch_stats(log_file, epoch, 'Validation', val_stats)
                _log_epoch_stats(log_file, epoch, 'Test', test_stats)

                if is_classification:
                    current_metric = val_stats.get('acc', 0.0)
                    is_best = current_metric > best_metric
                    if is_best:
                        best_metric = current_metric
                        best_loss = val_stats['loss']
                else:
                    is_best = val_stats['loss'] < best_loss
                    if is_best:
                        best_loss = val_stats['loss']
                        best_metric = -best_loss

                checkpoint_state = _build_checkpoint_state(
                    epoch, model_without_ddp, optimizer, scheduler, scaler,
                    best_metric, best_loss, config,
                    train_stats=train_stats, val_stats=val_stats,
                )
                save_checkpoint(
                    checkpoint_state,
                    is_best,
                    checkpoint_dir,
                    filename=checkpoint_name
                )
                saved_this_epoch = True

                checkpoint_msg = f"Checkpoint saved at epoch {epoch}"
                print(checkpoint_msg)
                log_to_file(log_file, checkpoint_msg)

                if is_best:
                    if is_classification:
                        best_msg = f"New best validation accuracy: {best_metric:.4f}"
                    else:
                        best_msg = f"New best validation loss: {best_loss:.4f}"
                    print(best_msg)
                    log_to_file(log_file, best_msg)

        # Periodic checkpoint. Skipped when the validation step already wrote
        # this epoch's file, so the richer state (with stats) is not overwritten.
        if (rank == 0 and not saved_this_epoch
                and (epoch + 1) % config['logging']['save_freq'] == 0):
            checkpoint_state = _build_checkpoint_state(
                epoch, model_without_ddp, optimizer, scheduler, scaler,
                best_metric, best_loss, config,
            )
            save_checkpoint(
                checkpoint_state,
                False,
                checkpoint_dir,
                filename=checkpoint_name
            )

    cleanup_distributed()
|
|
|
|
|
|
|
|
# Script entry point: start fine-tuning only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|
|