|
|
from data.dataset import CheXpertDataset
|
|
|
from loss.assymetric import AsymmetricLoss
|
|
|
from models.mae import *
|
|
|
from models.densenet import *
|
|
|
from models.classifier import *
|
|
|
from torch.utils.data import DataLoader
|
|
|
import json
|
|
|
import os
|
|
|
import io
|
|
|
import sys
|
|
|
from sklearn.metrics import roc_auc_score,confusion_matrix
|
|
|
|
|
|
class TeeFile:
    """File-like object that duplicates writes across multiple streams.

    String arguments are treated as paths and opened in append mode with
    line buffering; file-like arguments (e.g. sys.stdout) are used as-is.
    Only streams that this object opened itself are closed by close().

    Usage:
        # Works with both file objects and paths
        tee = TeeFile(sys.stdout, "/path/to/log.txt")
        print("Hello", file=tee)  # Writes to both stdout and the file
    """

    def __init__(self, *file_objects_or_paths):
        """
        Args:
            *file_objects_or_paths: Mix of file objects (like sys.stdout)
                or string paths to log files.
        """
        self.files = []         # every stream we write to
        self.opened_files = []  # subset we own and are responsible for closing

        for item in file_objects_or_paths:
            if isinstance(item, str):
                # Append mode + line buffering so log lines land promptly.
                f = open(item, 'a', buffering=1)
                self.files.append(f)
                self.opened_files.append(f)
            else:
                self.files.append(item)

    def write(self, data):
        """Write data to all streams; a failing stream only emits a warning."""
        for f in self.files:
            try:
                f.write(data)
                f.flush()
            except Exception as e:
                # Keep the remaining streams alive even if one is broken.
                print(f"Warning: Could not write to {f}: {e}", file=sys.stderr)

    def flush(self):
        """Flush all streams, ignoring streams that cannot flush."""
        for f in self.files:
            try:
                f.flush()
            except Exception:
                # Narrowed from a bare `except:` so KeyboardInterrupt /
                # SystemExit still propagate.
                pass

    def isatty(self):
        """Check if any stream is a terminal (for tqdm compatibility)"""
        return any(getattr(f, "isatty", lambda: False)() for f in self.files)

    def fileno(self):
        """Get file descriptor from any real file-like stream"""
        for f in self.files:
            if hasattr(f, "fileno"):
                try:
                    return f.fileno()
                except Exception:
                    pass
        raise io.UnsupportedOperation("No fileno available")

    def close(self):
        """Close any files we opened (never streams passed in by the caller)."""
        for f in self.opened_files:
            try:
                f.close()
            except Exception:
                # Narrowed from a bare `except:`; best-effort close.
                pass
        self.opened_files.clear()

    def __del__(self):
        """Cleanup on deletion"""
        self.close()

    def __enter__(self):
        """Context manager support"""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager cleanup; does not suppress exceptions."""
        self.close()
        return False
|
|
|
|
|
|
class MAETrainer:
|
|
|
def __init__(self,configs={}):
|
|
|
|
|
|
self.configs=configs
|
|
|
os.makedirs(configs["logdir"],exist_ok=True)
|
|
|
log_path_train = os.path.join(configs["logdir"], "training_log.txt")
|
|
|
log_path_val = os.path.join(configs["logdir"], "val_log.txt")
|
|
|
log_path_test = os.path.join(configs["logdir"], "test_log.txt")
|
|
|
|
|
|
self.traintee = TeeFile(sys.stdout, log_path_train)
|
|
|
self.valtee = TeeFile(sys.stdout, log_path_val)
|
|
|
self.testtee = TeeFile(sys.stdout, log_path_test)
|
|
|
|
|
|
for dir in self.configs["dirsToMake"]: os.makedirs(dir,exist_ok=True)
|
|
|
|
|
|
self.model=MaskedAutoEncoder(
|
|
|
c=configs["channels"],
|
|
|
mask_ratio=configs["mask_ratio"],
|
|
|
dropout=configs["dropout"],
|
|
|
img_size=configs["img_size"],
|
|
|
encoder_dim=configs["encoder_dim"],
|
|
|
mlp_dim=configs["mlp_dim"],
|
|
|
decoder_dim=configs["decoder_dim"],
|
|
|
encoder_depth=configs["encoder_depth"],
|
|
|
encoder_head=configs["encoder_head"],
|
|
|
decoder_depth=configs["decoder_depth"],
|
|
|
decoder_head=configs["decoder_head"],
|
|
|
patch_size=configs["patch_size"]
|
|
|
).to(configs["device"])
|
|
|
|
|
|
self.criterion=mae_loss
|
|
|
|
|
|
self.optimizer=torch.optim.AdamW(self.model.parameters(),configs["lr"], weight_decay=configs["weight_decay"])
|
|
|
self.schedular1=torch.optim.lr_scheduler.LinearLR(self.optimizer,start_factor=0.1,end_factor=1.0,total_iters=configs["warmup"])
|
|
|
self.schedular2=torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer,T_max=configs["num_epochs"]-configs["warmup"])
|
|
|
self.schedular=torch.optim.lr_scheduler.SequentialLR (self.optimizer,schedulers=[self.schedular1,self.schedular2],milestones=[configs["warmup"]])
|
|
|
self.scaler=torch.amp.GradScaler()
|
|
|
|
|
|
self.train_dataset= CheXpertDataset(zip_path=configs["zip_path"],csv_path=configs["train_csv"],root_dir=configs["datadir"],augment=True,use_frontal_only=True)
|
|
|
self.val_dataset= CheXpertDataset(zip_path=configs["zip_path"],csv_path=configs["val_csv"],root_dir=configs["datadir"],augment=False,use_frontal_only=True )
|
|
|
self.class_Weights=self.train_dataset.get_class_weights().to(self.configs["device"])
|
|
|
self.sample_Weights=self.train_dataset.get_sample_weights()
|
|
|
self.sampler=torch.utils.data.WeightedRandomSampler(self.sample_Weights,num_samples=len(self.sample_Weights))
|
|
|
self.trainloader=DataLoader(self.train_dataset,batch_size=configs["batch_size"],sampler=self.sampler,num_workers=8,pin_memory=True,persistent_workers=True)
|
|
|
self.valloader=DataLoader(self.val_dataset,batch_size=configs["batch_size"],shuffle=False,num_workers=8,pin_memory=True,persistent_workers=True)
|
|
|
self.history={"train_loss":[],"val_loss":[]}
|
|
|
|
|
|
self.current_epoch=0
|
|
|
|
|
|
if os.path.exists(self.configs["resume"]):
|
|
|
loadedpickle=torch.load(self.configs["resume"],map_location=self.configs["device"])
|
|
|
self.model.load_state_dict(loadedpickle["model"],strict=False)
|
|
|
self.optimizer.load_state_dict(loadedpickle["optimizer"])
|
|
|
self.schedular.load_state_dict(loadedpickle["schedular"])
|
|
|
self.schedular1.load_state_dict(loadedpickle["schedular1"])
|
|
|
self.schedular2.load_state_dict(loadedpickle["schedular2"])
|
|
|
self.scaler.load_state_dict(loadedpickle["scaler"])
|
|
|
self.current_epoch=loadedpickle["epoch"]+1
|
|
|
|
|
|
|
|
|
|
|
|
self.test_dataset = None
|
|
|
self.testloader = None
|
|
|
if configs.get("test_csv"):
|
|
|
self.test_dataset = CheXpertDataset(
|
|
|
zip_path=configs["zip_path"],
|
|
|
csv_path=configs["test_csv"],
|
|
|
root_dir=configs["datadir"],
|
|
|
augment=False,
|
|
|
use_frontal_only=True
|
|
|
)
|
|
|
self.testloader = DataLoader(
|
|
|
self.test_dataset,
|
|
|
batch_size=configs["batch_size"],
|
|
|
shuffle=False,
|
|
|
num_workers=8,
|
|
|
pin_memory=True,
|
|
|
persistent_workers=True
|
|
|
)
|
|
|
print(f"Test loader ready – {len(self.test_dataset)} images")
|
|
|
|
|
|
torch.backends.cudnn.benchmark = True
|
|
|
torch.backends.cudnn.enabled = True
|
|
|
|
|
|
|
|
|
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
|
|
|
|
|
|
|
|
|
if hasattr(self.model, 'enable_gradient_checkpointing'):
|
|
|
self.model.enable_gradient_checkpointing()
|
|
|
@staticmethod
|
|
|
def plot_training_metrics(metrics, epoch,figs_path):
|
|
|
import matplotlib.pyplot as plt
|
|
|
"""
|
|
|
Plot loss and AUC curves from training metrics.
|
|
|
|
|
|
Args:
|
|
|
metrics (dict): Dictionary containing lists for each metric key:
|
|
|
{
|
|
|
"train_loss": [...],
|
|
|
"val_loss": [...]
|
|
|
}
|
|
|
epoch (int): Current epoch number (used for title or axis scaling)
|
|
|
"""
|
|
|
epochs = list(range(1, epoch + 1))
|
|
|
|
|
|
|
|
|
keys = ["train_loss","val_loss"]
|
|
|
lengths = [len(metrics[k]) for k in keys if k in metrics]
|
|
|
if not lengths:
|
|
|
return
|
|
|
n = min(lengths)
|
|
|
|
|
|
|
|
|
m = {k: metrics[k][:n] for k in keys if k in metrics}
|
|
|
epochs = list(range(1, n + 1))
|
|
|
|
|
|
plt.figure(figsize=(14, 6))
|
|
|
|
|
|
|
|
|
|
|
|
plt.subplot(1, 2, 1)
|
|
|
plt.plot(epochs, metrics["train_loss"], label="Train Loss", marker='o')
|
|
|
plt.plot(epochs, metrics["val_loss"], label="Val Loss", marker='s')
|
|
|
plt.xlabel("Epoch")
|
|
|
plt.ylabel("Loss")
|
|
|
plt.title("Training & Validation Loss")
|
|
|
plt.legend()
|
|
|
plt.grid(True, linestyle='--', alpha=0.6)
|
|
|
|
|
|
|
|
|
plt.tight_layout()
|
|
|
os.makedirs(os.path.join(figs_path,str(epoch)),exist_ok=True)
|
|
|
plt.savefig(os.path.join(figs_path,str(epoch),"metrics.png"))
|
|
|
plt.show()
|
|
|
|
|
|
def train_epoch(self, epoch, looper):
|
|
|
self.model.train()
|
|
|
running_loss = 0.0
|
|
|
all_preds = []
|
|
|
all_targets = []
|
|
|
current_loss=0
|
|
|
total_batches = len(self.trainloader)
|
|
|
|
|
|
for batch_idx, data in looper:
|
|
|
image = data['image'].to(self.configs["device"], non_blocking=True)
|
|
|
target = data['labels'].to(self.configs["device"], non_blocking=True)
|
|
|
|
|
|
with torch.autocast(device_type=self.configs["device"].type, dtype=torch.float16):
|
|
|
img,preds,mask = self.model(image)
|
|
|
loss = self.criterion(img,preds,mask)
|
|
|
|
|
|
loss_back = loss / self.configs["accumulation"]
|
|
|
running_loss += loss.item()
|
|
|
|
|
|
if torch.isfinite(loss):
|
|
|
|
|
|
self.scaler.scale(loss_back).backward()
|
|
|
else:
|
|
|
self.optimizer.zero_grad(set_to_none=True)
|
|
|
continue
|
|
|
|
|
|
if (batch_idx + 1) % self.configs["accumulation"] == 0 or batch_idx == total_batches - 1:
|
|
|
self.scaler.unscale_(self.optimizer)
|
|
|
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
|
|
|
self.scaler.step(self.optimizer)
|
|
|
self.scaler.update()
|
|
|
|
|
|
self.schedular.step()
|
|
|
self.optimizer.zero_grad(set_to_none=True)
|
|
|
|
|
|
|
|
|
|
|
|
current_loss = running_loss / (batch_idx + 1)
|
|
|
if (batch_idx + 1) % 10 == 0:
|
|
|
current_lr = self.optimizer.param_groups[0]['lr']
|
|
|
looper.set_postfix({
|
|
|
"lr": f"{current_lr:.2e}","batch":f"{batch_idx}/{total_batches}",
|
|
|
"epoch": f"{epoch}/{self.configs['num_epochs']}",
|
|
|
"loss": f"{current_loss:.3f}",
|
|
|
})
|
|
|
|
|
|
return current_loss
|
|
|
def validate(self, epoch, looper):
|
|
|
self.model.eval()
|
|
|
val_loss = 0.0
|
|
|
all_preds = []
|
|
|
all_targets = []
|
|
|
lenloader=len(self.valloader)
|
|
|
current_loss=0
|
|
|
with torch.no_grad():
|
|
|
for batch_idx, data in looper:
|
|
|
image = data["image"].to(self.configs["device"], non_blocking=True)
|
|
|
target = data["labels"].to(self.configs["device"], non_blocking=True)
|
|
|
|
|
|
with torch.autocast(device_type=self.configs["device"].type, dtype=torch.float16):
|
|
|
img,preds,mask = self.model(image)
|
|
|
loss = self.criterion(img,preds,mask)
|
|
|
|
|
|
val_loss += loss.item()
|
|
|
|
|
|
|
|
|
current_loss = val_loss / (batch_idx + 1)
|
|
|
if (batch_idx + 1) % 10 == 0 :
|
|
|
|
|
|
looper.set_postfix({
|
|
|
"epoch": f"{epoch}/{self.configs['num_epochs']}",
|
|
|
"batch":f"{batch_idx}/{lenloader}",
|
|
|
"loss": f"{current_loss:.3f}",
|
|
|
})
|
|
|
|
|
|
return current_loss
|
|
|
def train(self):
|
|
|
|
|
|
for epoch in range(self.current_epoch,self.configs["num_epochs"]):
|
|
|
trainlooper=tqdm(enumerate(self.trainloader),desc="training: ", leave=False,file=self.traintee)
|
|
|
vallooper=tqdm(enumerate(self.valloader),desc="validating: ",leave=False,file=self.valtee)
|
|
|
|
|
|
|
|
|
self.model.train()
|
|
|
self.optimizer.zero_grad(set_to_none=True)
|
|
|
|
|
|
running_loss=self.train_epoch(epoch,trainlooper)
|
|
|
|
|
|
torch.cuda.synchronize()
|
|
|
torch.cuda.empty_cache()
|
|
|
|
|
|
val_loss=self.validate(epoch,vallooper)
|
|
|
|
|
|
torch.cuda.synchronize()
|
|
|
torch.cuda.empty_cache()
|
|
|
|
|
|
gc.collect()
|
|
|
|
|
|
if (self.history["val_loss"] and (val_loss<min(self.history["val_loss"]))) :
|
|
|
checkpoint={"model":self.model.state_dict(),"optimizer":self.optimizer.state_dict(),"schedular":self.schedular.state_dict(),"schedular1":self.schedular1.state_dict(),"schedular2":self.schedular2.state_dict(),"scaler":self.scaler.state_dict(),"epoch":epoch}
|
|
|
torch.save(checkpoint, self.configs["resume"])
|
|
|
|
|
|
print(f"train loss {running_loss} val loss {val_loss}")
|
|
|
|
|
|
self.history["train_loss"].append(float(running_loss))
|
|
|
self.history["val_loss"].append(float(val_loss))
|
|
|
|
|
|
if epoch%10==0:
|
|
|
historyfile=os.path.join(self.configs["logdir"],"history.json")
|
|
|
if os.path.exists(historyfile):
|
|
|
with open(historyfile,"r") as f:
|
|
|
history=json.load(f)
|
|
|
history["train_loss"]+=self.history["train_loss"]
|
|
|
history["val_loss"]+=self.history["val_loss"]
|
|
|
with open(historyfile,"w") as f:
|
|
|
json.dump(self.history,f)
|
|
|
f.close()
|
|
|
MAETrainer.plot_training_metrics(self.history,epoch+1,self.configs["logdir"])
|
|
|
|
|
|
self.current_epoch=epoch
|
|
|
|
|
|
class Trainer:
|
|
|
def __init__(self,configs={}):
|
|
|
|
|
|
self.configs=configs
|
|
|
os.makedirs(configs["logdir"],exist_ok=True)
|
|
|
log_path_train = os.path.join(configs["logdir"], "training_log.txt")
|
|
|
log_path_val = os.path.join(configs["logdir"], "val_log.txt")
|
|
|
log_path_test = os.path.join(configs["logdir"], "test_log.txt")
|
|
|
|
|
|
self.traintee = TeeFile(sys.stdout, log_path_train)
|
|
|
self.valtee = TeeFile(sys.stdout, log_path_val)
|
|
|
self.testtee = TeeFile(sys.stdout, log_path_test)
|
|
|
|
|
|
for dir in self.configs["dirsToMake"]: os.makedirs(dir,exist_ok=True)
|
|
|
|
|
|
self.model=XRAYClassifier(
|
|
|
c=configs["channels"],
|
|
|
num_classes=configs["num_classes"],
|
|
|
mask_ratio=configs["mask_ratio"],
|
|
|
dropout=configs["dropout"],
|
|
|
img_size=configs["img_size"],
|
|
|
encoder_dim=configs["encoder_dim"],
|
|
|
mlp_dim=configs["mlp_dim"],
|
|
|
decoder_dim=configs["decoder_dim"],
|
|
|
encoder_depth=configs["encoder_depth"],
|
|
|
encoder_head=configs["encoder_head"],
|
|
|
decoder_depth=configs["decoder_depth"],
|
|
|
decoder_head=configs["decoder_head"],
|
|
|
patch_size=configs["patch_size"]
|
|
|
).to(configs["device"])
|
|
|
|
|
|
|
|
|
|
|
|
self.optimizer=torch.optim.AdamW(self.model.parameters(),configs["lr"], weight_decay=configs["weight_decay"])
|
|
|
self.schedular1=torch.optim.lr_scheduler.LinearLR(self.optimizer,start_factor=0.1,end_factor=1.0,total_iters=configs["warmup"])
|
|
|
self.schedular2=torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer,T_max=configs["num_epochs"]-configs["warmup"])
|
|
|
self.schedular=torch.optim.lr_scheduler.SequentialLR (self.optimizer,schedulers=[self.schedular1,self.schedular2],milestones=[configs["warmup"]])
|
|
|
self.scaler=torch.amp.GradScaler()
|
|
|
|
|
|
self.train_dataset= CheXpertDataset(zip_path=configs["zip_path"],csv_path=configs["train_csv"],root_dir=configs["datadir"],augment=True,use_frontal_only=True,mask_dir=configs["maskdir"])
|
|
|
self.val_dataset= CheXpertDataset(zip_path=configs["zip_path"],csv_path=configs["val_csv"],root_dir=configs["datadir"],augment=False,use_frontal_only=True,mask_dir=configs["maskdir"] )
|
|
|
|
|
|
self.class_Weights=self.train_dataset.get_class_weights().to(self.configs["device"])
|
|
|
self.sample_Weights=self.train_dataset.get_sample_weights()
|
|
|
self.sampler=torch.utils.data.WeightedRandomSampler(self.sample_Weights,num_samples=len(self.sample_Weights))
|
|
|
self.trainloader=DataLoader(self.train_dataset,batch_size=configs["batch_size"],sampler=self.sampler,num_workers=0,pin_memory=True,persistent_workers=False)
|
|
|
self.valloader=DataLoader(self.val_dataset,batch_size=configs["batch_size"],shuffle=False,num_workers=0,pin_memory=True,persistent_workers=False)
|
|
|
self.criterion=AsymmetricLoss(class_weights=self.class_Weights).to(self.configs["device"])
|
|
|
self.history={"train_loss":[],"val_loss":[],"train_macro_auc":[],"val_macro_auc":[],"train_micro_auc":[],"val_micro_auc":[]}
|
|
|
if os.path.exists(os.path.join(self.configs["logdir"],"history.json")):
|
|
|
with open(os.path.join(self.configs["logdir"],"history.json"),'r') as hf:
|
|
|
self.history=json.load(hf)
|
|
|
hf.close()
|
|
|
self.current_epoch=0
|
|
|
|
|
|
self.optimal_thresholds =[0.5]*14
|
|
|
|
|
|
if os.path.exists(self.configs["resume"]):
|
|
|
ckpt = torch.load(self.configs["resume"], map_location=self.configs["device"],weights_only=False)
|
|
|
self.model.load_state_dict(ckpt["model"], strict=False)
|
|
|
self.optimizer.load_state_dict(ckpt["optimizer"])
|
|
|
self.schedular.load_state_dict(ckpt["schedular"])
|
|
|
self.schedular1.load_state_dict(ckpt["schedular1"])
|
|
|
self.schedular2.load_state_dict(ckpt["schedular2"])
|
|
|
self.scaler.load_state_dict(ckpt["scaler"])
|
|
|
self.current_epoch = ckpt.get("epoch", -1) + 1
|
|
|
self.optimal_thresholds =ckpt.get("thresholds")
|
|
|
else:
|
|
|
|
|
|
bb = torch.load(self.configs["backbone"], map_location=self.configs["device"],weights_only=False)
|
|
|
|
|
|
|
|
|
state = bb["model"]
|
|
|
if any(k.startswith("module.") for k in state.keys()):
|
|
|
from collections import OrderedDict
|
|
|
state = OrderedDict((k.replace("module.", "", 1), v) for k, v in state.items())
|
|
|
|
|
|
missing, unexpected = self.model.mae.load_state_dict(state, strict=False)
|
|
|
print("loaded backbone")
|
|
|
if missing: print(f"Missing keys: {len(missing)} (showing first 5): {missing[:5]}")
|
|
|
if unexpected: print(f"Unexpected keys: {len(unexpected)} (first 5): {unexpected[:5]}")
|
|
|
|
|
|
|
|
|
for p in self.model.mae.parameters():
|
|
|
p.requires_grad = False
|
|
|
if os.path.exists(self.configs["densebackbone"]):
|
|
|
densebb=torch.load(self.configs["densebackbone"], map_location=self.configs["device"])
|
|
|
densestate = densebb["model"]
|
|
|
if any(k.startswith("module.") for k in state.keys()):
|
|
|
from collections import OrderedDict
|
|
|
state = OrderedDict((k.replace("module.", "", 1), v) for k, v in densestate.items())
|
|
|
densemissing, denseunexpected = self.model.dense.load_state_dict(densestate, strict=False)
|
|
|
print("loaded dense backbone")
|
|
|
if densemissing: print(f"Missing keys: {len(densemissing)} (showing first 5): {densemissing[:5]}")
|
|
|
if denseunexpected: print(f"Unexpected keys: {len(denseunexpected)} (first 5): {denseunexpected[:5]}")
|
|
|
|
|
|
self.test_dataset = None
|
|
|
self.testloader = None
|
|
|
if configs.get("test_csv"):
|
|
|
self.test_dataset = CheXpertDataset(
|
|
|
zip_path=configs["zip_path"],
|
|
|
csv_path=configs["test_csv"],
|
|
|
root_dir=configs["datadir"],
|
|
|
augment=False,
|
|
|
use_frontal_only=True
|
|
|
)
|
|
|
self.testloader = DataLoader(
|
|
|
self.test_dataset,
|
|
|
batch_size=configs["batch_size"],
|
|
|
shuffle=False,
|
|
|
num_workers=0,
|
|
|
pin_memory=True,
|
|
|
persistent_workers=False
|
|
|
)
|
|
|
print(f"Test loader ready – {len(self.test_dataset)} images")
|
|
|
|
|
|
torch.backends.cudnn.benchmark = True
|
|
|
torch.backends.cudnn.enabled = True
|
|
|
|
|
|
|
|
|
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
|
|
|
|
|
|
|
|
|
if hasattr(self.model, 'enable_gradient_checkpointing'):
|
|
|
self.model.enable_gradient_checkpointing()
|
|
|
@staticmethod
|
|
|
def plot_training_metrics(metrics, epoch,figs_path):
|
|
|
import matplotlib.pyplot as plt
|
|
|
"""
|
|
|
Plot loss and AUC curves from training metrics.
|
|
|
|
|
|
Args:
|
|
|
metrics (dict): Dictionary containing lists for each metric key:
|
|
|
{
|
|
|
"train_loss": [...],
|
|
|
"val_loss": [...],
|
|
|
"train_macro_auc": [...],
|
|
|
"val_macro_auc": [...],
|
|
|
"train_micro_auc": [...],
|
|
|
"val_micro_auc": [...]
|
|
|
}
|
|
|
epoch (int): Current epoch number (used for title or axis scaling)
|
|
|
"""
|
|
|
epochs = list(range(1, epoch + 1))
|
|
|
|
|
|
|
|
|
keys = ["train_loss","val_loss","train_macro_auc","val_macro_auc","train_micro_auc","val_micro_auc"]
|
|
|
lengths = [len(metrics[k]) for k in keys if k in metrics]
|
|
|
if not lengths:
|
|
|
return
|
|
|
n = min(lengths)
|
|
|
|
|
|
|
|
|
m = {k: metrics[k][:n] for k in keys if k in metrics}
|
|
|
epochs = list(range(1, n + 1))
|
|
|
|
|
|
plt.figure(figsize=(14, 6))
|
|
|
|
|
|
|
|
|
|
|
|
plt.subplot(1, 2, 1)
|
|
|
plt.plot(epochs, metrics["train_loss"], label="Train Loss", marker='o')
|
|
|
plt.plot(epochs, metrics["val_loss"], label="Val Loss", marker='s')
|
|
|
plt.xlabel("Epoch")
|
|
|
plt.ylabel("Loss")
|
|
|
plt.title("Training & Validation Loss")
|
|
|
plt.legend()
|
|
|
plt.grid(True, linestyle='--', alpha=0.6)
|
|
|
|
|
|
|
|
|
plt.subplot(1, 2, 2)
|
|
|
plt.plot(epochs, metrics["train_macro_auc"], label="Train Macro AUC", marker='o')
|
|
|
plt.plot(epochs, metrics["val_macro_auc"], label="Val Macro AUC", marker='s')
|
|
|
plt.plot(epochs, metrics["train_micro_auc"], label="Train Micro AUC", marker='^')
|
|
|
plt.plot(epochs, metrics["val_micro_auc"], label="Val Micro AUC", marker='v')
|
|
|
plt.xlabel("Epoch")
|
|
|
plt.ylabel("AUC")
|
|
|
plt.title("Training & Validation AUC (Macro/Micro)")
|
|
|
plt.legend()
|
|
|
plt.grid(True, linestyle='--', alpha=0.6)
|
|
|
|
|
|
plt.tight_layout()
|
|
|
os.makedirs(os.path.join(figs_path,str(epoch)),exist_ok=True)
|
|
|
plt.savefig(os.path.join(figs_path,str(epoch),"metrics.png"))
|
|
|
plt.show()
|
|
|
|
|
|
|
|
|
|
|
|
def train_epoch(self, epoch, looper):
|
|
|
self.model.train()
|
|
|
running_loss = 0.0
|
|
|
all_preds = []
|
|
|
all_targets = []
|
|
|
|
|
|
total_batches = len(self.trainloader)
|
|
|
|
|
|
for batch_idx, data in looper:
|
|
|
image = data['image'].to(self.configs["device"], non_blocking=True)
|
|
|
target = data['labels'].to(self.configs["device"], non_blocking=True)
|
|
|
|
|
|
|
|
|
logits = self.model(image)
|
|
|
|
|
|
|
|
|
preds = torch.sigmoid(logits.float())
|
|
|
loss = self.criterion(preds, target)
|
|
|
|
|
|
loss_back = loss / self.configs["accumulation"]
|
|
|
running_loss += loss.item()
|
|
|
|
|
|
if torch.isfinite(loss):
|
|
|
loss_back.backward()
|
|
|
|
|
|
else:
|
|
|
self.optimizer.zero_grad(set_to_none=True)
|
|
|
continue
|
|
|
|
|
|
if (batch_idx + 1) % self.configs["accumulation"] == 0 or batch_idx == total_batches - 1:
|
|
|
|
|
|
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
|
|
|
|
|
|
|
|
|
self.optimizer.step()
|
|
|
|
|
|
self.optimizer.zero_grad(set_to_none=True)
|
|
|
|
|
|
|
|
|
all_preds.append(preds.detach().cpu())
|
|
|
all_targets.append(target.detach().cpu())
|
|
|
|
|
|
|
|
|
current_loss = running_loss / (batch_idx + 1)
|
|
|
if (batch_idx + 1) % 500 == 0 and len(all_preds) > 0:
|
|
|
preds_np = torch.cat(all_preds).numpy()
|
|
|
targets_np = torch.cat(all_targets).numpy()
|
|
|
macro_auc = roc_auc_score(targets_np, preds_np, average='macro')
|
|
|
micro_auc = roc_auc_score(targets_np, preds_np, average='micro')
|
|
|
|
|
|
current_lr = self.optimizer.param_groups[0]['lr']
|
|
|
looper.set_postfix({
|
|
|
"lr": f"{current_lr:.2e}","batch":f"{batch_idx}/{total_batches}",
|
|
|
"epoch": f"{epoch}/{self.configs['num_epochs']}",
|
|
|
"loss": f"{current_loss:.3f}",
|
|
|
"macro": f"{macro_auc:.3f}",
|
|
|
"micro": f"{micro_auc:.3f}"
|
|
|
})
|
|
|
|
|
|
|
|
|
preds_full = torch.cat(all_preds).numpy()
|
|
|
targets_full = torch.cat(all_targets).numpy()
|
|
|
final_loss = running_loss / total_batches
|
|
|
final_macro_auc = roc_auc_score(targets_full, preds_full, average='macro')
|
|
|
final_micro_auc = roc_auc_score(targets_full, preds_full, average='micro')
|
|
|
|
|
|
del all_preds, all_targets, preds_full, targets_full
|
|
|
|
|
|
return final_loss, final_macro_auc, final_micro_auc
|
|
|
|
|
|
def validate(self, epoch, looper):
|
|
|
self.model.eval()
|
|
|
val_loss = 0.0
|
|
|
all_preds = []
|
|
|
all_targets = []
|
|
|
lenloader=len(self.valloader)
|
|
|
|
|
|
with torch.no_grad():
|
|
|
for batch_idx, data in looper:
|
|
|
image = data["image"].to(self.configs["device"], non_blocking=True)
|
|
|
target = data["labels"].to(self.configs["device"], non_blocking=True)
|
|
|
|
|
|
logits = self.model(image)
|
|
|
|
|
|
preds = torch.sigmoid(logits.float())
|
|
|
loss = self.criterion(preds, target)
|
|
|
|
|
|
val_loss += loss.item()
|
|
|
|
|
|
all_preds.append(preds.detach().cpu())
|
|
|
all_targets.append(target.detach().cpu())
|
|
|
|
|
|
|
|
|
|
|
|
current_loss = val_loss / (batch_idx + 1)
|
|
|
if (batch_idx + 1) % 200 == 0 and len(all_preds) > 0:
|
|
|
preds_np = torch.cat(all_preds).numpy()
|
|
|
targets_np = torch.cat(all_targets).numpy()
|
|
|
macro_auc = roc_auc_score(targets_np, preds_np, average='macro')
|
|
|
micro_auc = roc_auc_score(targets_np, preds_np, average='micro')
|
|
|
looper.set_postfix({
|
|
|
"epoch": f"{epoch}/{self.configs['num_epochs']}",
|
|
|
"batch":f"{batch_idx}/{lenloader}",
|
|
|
"loss": f"{current_loss:.3f}",
|
|
|
"macro": f"{macro_auc:.3f}",
|
|
|
"micro": f"{micro_auc:.3f}"
|
|
|
})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
preds_full = torch.cat(all_preds).numpy()
|
|
|
targets_full = torch.cat(all_targets).numpy()
|
|
|
num_classes = 14
|
|
|
new_thresholds = [0.5] * num_classes
|
|
|
|
|
|
for class_idx in range(num_classes):
|
|
|
if targets_full[:, class_idx].sum() == 0:
|
|
|
|
|
|
continue
|
|
|
|
|
|
thresholds = np.arange(0.1, 0.9, 0.02)
|
|
|
best_score = -1
|
|
|
best_threshold = 0.5
|
|
|
|
|
|
for threshold in thresholds:
|
|
|
preds_bin = (preds_full[:, class_idx] >= threshold).astype(int)
|
|
|
tn, fp, fn, tp = confusion_matrix(
|
|
|
targets_full[:, class_idx].astype(int),
|
|
|
preds_bin
|
|
|
).ravel()
|
|
|
sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
|
|
|
specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
|
|
|
score = sensitivity + specificity - 1
|
|
|
|
|
|
if score > best_score:
|
|
|
best_score = score
|
|
|
best_threshold = threshold
|
|
|
|
|
|
new_thresholds[class_idx] = best_threshold
|
|
|
|
|
|
|
|
|
self.optimal_thresholds = new_thresholds
|
|
|
|
|
|
|
|
|
final_loss = val_loss / lenloader
|
|
|
final_macro_auc = roc_auc_score(targets_full, preds_full, average='macro')
|
|
|
final_micro_auc = roc_auc_score(targets_full, preds_full, average='micro')
|
|
|
|
|
|
del all_preds, all_targets, preds_full, targets_full
|
|
|
|
|
|
return final_loss, final_macro_auc, final_micro_auc
|
|
|
def train(self):
|
|
|
|
|
|
for epoch in range(self.current_epoch,self.configs["num_epochs"]):
|
|
|
trainlooper=tqdm(enumerate(self.trainloader),desc="training: ", leave=True,file=self.traintee)
|
|
|
vallooper=tqdm(enumerate(self.valloader),desc="validating: ",leave=True,file=self.valtee)
|
|
|
|
|
|
|
|
|
self.model.train()
|
|
|
self.schedular.step()
|
|
|
self.optimizer.zero_grad(set_to_none=True)
|
|
|
|
|
|
running_loss,macro_auc,micro_auc=self.train_epoch(epoch,trainlooper)
|
|
|
|
|
|
torch.cuda.synchronize()
|
|
|
torch.cuda.empty_cache()
|
|
|
|
|
|
val_loss,val_macro_auc,val_micro_auc=self.validate(epoch,vallooper)
|
|
|
|
|
|
torch.cuda.synchronize()
|
|
|
torch.cuda.empty_cache()
|
|
|
|
|
|
gc.collect()
|
|
|
|
|
|
if (self.history["val_macro_auc"] and (val_macro_auc>max(self.history["val_macro_auc"]))) or (self.history["val_micro_auc"] and val_micro_auc>max(self.history["val_micro_auc"])):
|
|
|
checkpoint={"model":self.model.state_dict(),"optimizer":self.optimizer.state_dict(),"schedular":self.schedular.state_dict(),
|
|
|
"schedular1":self.schedular1.state_dict(),"schedular2":self.schedular2.state_dict(),"scaler":self.scaler.state_dict(),"epoch":epoch
|
|
|
,"thresholds":self.optimal_thresholds }
|
|
|
torch.save(checkpoint, self.configs["resume"])
|
|
|
|
|
|
print(f"epoch {epoch} train loss {running_loss} val loss {val_loss} val_macro_auc {val_macro_auc} val_micro_auc {val_micro_auc} train_macro_auc {macro_auc} train_micro_auc {micro_auc}")
|
|
|
|
|
|
self.history["train_loss"].append(float(running_loss))
|
|
|
self.history["val_loss"].append(float(val_loss))
|
|
|
self.history["train_macro_auc"].append(float(macro_auc))
|
|
|
self.history["val_macro_auc"].append(float(val_macro_auc))
|
|
|
self.history["train_micro_auc"].append(float(micro_auc))
|
|
|
self.history["val_micro_auc"].append(float(val_micro_auc))
|
|
|
|
|
|
|
|
|
historyfile=os.path.join(self.configs["logdir"],"history.json")
|
|
|
if os.path.exists(historyfile):
|
|
|
with open(historyfile,"r") as f:
|
|
|
history=json.load(f)
|
|
|
history["train_loss"]+=self.history["train_loss"]
|
|
|
history["val_loss"]+=self.history["val_loss"]
|
|
|
history["train_macro_auc"]+=self.history["train_macro_auc"]
|
|
|
history["val_macro_auc"]+=self.history["val_macro_auc"]
|
|
|
with open(historyfile,"w") as f:
|
|
|
json.dump(self.history,f)
|
|
|
f.close()
|
|
|
|
|
|
if epoch%10==0:Trainer.plot_training_metrics(self.history,epoch+1,self.configs["logdir"])
|
|
|
|
|
|
self.current_epoch=epoch
|
|
|
def test(self, model_path=None, return_preds=False):
|
|
|
"""
|
|
|
Run a complete test evaluation.
|
|
|
If `model_path` is given, load that checkpoint first.
|
|
|
Returns (macro_auc, micro_auc, per_class_auc_dict) or predictions if requested.
|
|
|
"""
|
|
|
if model_path:
|
|
|
ckpt = torch.load(model_path, map_location=self.configs["device"])
|
|
|
self.model.load_state_dict(ckpt["model"])
|
|
|
print(f"Loaded checkpoint {model_path}")
|
|
|
|
|
|
if self.testloader is None:
|
|
|
raise RuntimeError("No test loader – provide `test_csv` in config")
|
|
|
|
|
|
self.model.eval()
|
|
|
all_preds, all_targets = [], []
|
|
|
|
|
|
test_loss = 0.0
|
|
|
looper = tqdm(enumerate(self.testloader), total=len(self.testloader),
|
|
|
desc="Testing ",file=self.testtee)
|
|
|
|
|
|
with torch.inference_mode():
|
|
|
for batch_idx, data in looper:
|
|
|
img = data['image'].to(self.configs["device"], non_blocking=True)
|
|
|
tgt = data['labels'].to(self.configs["device"], non_blocking=True)
|
|
|
|
|
|
|
|
|
logits = self.model(img)
|
|
|
if self.optimal_thresholds:
|
|
|
|
|
|
|
|
|
taus = torch.tensor(self.optimal_thresholds, device=logits.device).view(1, -1)
|
|
|
|
|
|
|
|
|
margins = torch.log(taus / (1 - taus))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
logits = logits - margins
|
|
|
probs = torch.sigmoid(logits)
|
|
|
loss = self.criterion(probs, tgt)
|
|
|
test_loss += loss.item()
|
|
|
|
|
|
all_preds.append(probs.cpu())
|
|
|
all_targets.append(tgt.cpu())
|
|
|
|
|
|
|
|
|
cur_loss = test_loss / (batch_idx + 1)
|
|
|
if all_preds:
|
|
|
p = torch.cat(all_preds).numpy()
|
|
|
t = torch.cat(all_targets).numpy()
|
|
|
macro = roc_auc_score(t, p, average='macro')
|
|
|
micro = roc_auc_score(t, p, average='micro')
|
|
|
else:
|
|
|
macro = micro = 0.0
|
|
|
looper.set_postfix(loss=f"{cur_loss:.4f}",
|
|
|
macro=f"{macro:.4f}",
|
|
|
micro=f"{micro:.4f}")
|
|
|
|
|
|
|
|
|
preds = torch.cat(all_preds).numpy()
|
|
|
targets = torch.cat(all_targets).numpy()
|
|
|
final_loss = test_loss / len(self.testloader)
|
|
|
macro_auc = roc_auc_score(targets, preds, average='macro')
|
|
|
micro_auc = roc_auc_score(targets, preds, average='micro')
|
|
|
|
|
|
|
|
|
per_class = {}
|
|
|
for i, name in enumerate(self.train_dataset.get_label_names()):
|
|
|
if targets[:, i].sum() > 0:
|
|
|
per_class[name] = roc_auc_score(targets[:, i], preds[:, i])
|
|
|
else:
|
|
|
per_class[name] = float('nan')
|
|
|
|
|
|
|
|
|
print("\n" + "="*80)
|
|
|
print(f"TEST RESULTS (loss={final_loss:.4f})")
|
|
|
print("="*80)
|
|
|
print(f"{'Pathology':<30} {'AUC':>8}")
|
|
|
print("-"*40)
|
|
|
for name, auc in per_class.items():
|
|
|
print(f"{name:<30} {auc:>8.4f}" if not np.isnan(auc) else f"{name:<30} {'N/A':>8}")
|
|
|
print("-"*40)
|
|
|
print(f"{'Macro AUC':<30} {macro_auc:>8.4f}")
|
|
|
print(f"{'Micro AUC':<30} {micro_auc:>8.4f}")
|
|
|
print("="*80)
|
|
|
|
|
|
if return_preds:
|
|
|
return macro_auc, micro_auc, per_class, (preds, targets)
|
|
|
return macro_auc, micro_auc, per_class |