# Uploaded by adelelsayed1991 via huggingface_hub (revision abd02e7, verified)
import gc
import io
import json
import os
import sys

import numpy as np
import torch
from sklearn.metrics import roc_auc_score, confusion_matrix
from torch.utils.data import DataLoader
from tqdm import tqdm

from data.dataset import CheXpertDataset
from loss.assymetric import AsymmetricLoss
from models.classifier import *
from models.densenet import *
from models.mae import *
class TeeFile:
    """File-like object that writes to multiple streams (e.g., stdout and a file).

    String arguments are treated as paths and opened in append mode;
    file-like arguments are used as-is and never closed by this object.

    Usage:
        tee = TeeFile(sys.stdout, "/path/to/log.txt")
        print("Hello", file=tee)  # writes to both stdout and the file
    """

    def __init__(self, *file_objects_or_paths):
        """
        Args:
            *file_objects_or_paths: mix of file objects (like sys.stdout)
                and string paths to log files.
        """
        self.files = []
        self.opened_files = []  # streams we opened ourselves, so close() can release them
        for item in file_objects_or_paths:
            if isinstance(item, str):
                # Path string: open line-buffered in append mode.
                f = open(item, 'a', buffering=1)
                self.files.append(f)
                self.opened_files.append(f)
            else:
                # Already a file-like object (e.g., sys.stdout).
                self.files.append(item)

    def write(self, data):
        """Write data to all streams, flushing each one immediately."""
        for f in self.files:
            try:
                f.write(data)
                f.flush()
            except Exception as e:
                # A closed/broken stream must not take down the others.
                print(f"Warning: Could not write to {f}: {e}", file=sys.stderr)

    def flush(self):
        """Flush all streams, ignoring streams that can no longer flush."""
        for f in self.files:
            try:
                f.flush()
            except Exception:  # FIX: narrowed from a bare except
                pass

    def isatty(self):
        """Return True if any underlying stream is a terminal (tqdm support)."""
        return any(getattr(f, "isatty", lambda: False)() for f in self.files)

    def fileno(self):
        """Return the first available file descriptor among the streams."""
        for f in self.files:
            if hasattr(f, "fileno"):
                try:
                    return f.fileno()
                except Exception:
                    pass
        raise io.UnsupportedOperation("No fileno available")

    def close(self):
        """Close only the files this object opened itself."""
        for f in self.opened_files:
            try:
                f.close()
            except Exception:  # FIX: narrowed from a bare except
                pass
        self.opened_files.clear()

    def __del__(self):
        """Best-effort cleanup when the object is garbage collected."""
        self.close()

    def __enter__(self):
        """Context-manager support: returns self."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close owned files on context exit; never suppresses exceptions."""
        self.close()
        return False
class MAETrainer:
    """Self-supervised pre-training driver for the MaskedAutoEncoder.

    Builds the model, data loaders and optimization state from a config
    dict, then runs AMP training with gradient accumulation, linear
    warmup + cosine LR scheduling, best-checkpoint saving and periodic
    history/plot dumps.
    """

    def __init__(self, configs=None):
        """Build model, loaders and optimization state from ``configs``.

        Args:
            configs (dict): experiment configuration; required keys include
                "logdir", "dirsToMake", the MAE architecture dims, "lr",
                "weight_decay", "warmup", "num_epochs", dataset paths,
                "batch_size", "device", "resume" and "accumulation".
        """
        # FIX: avoid a shared mutable default argument ({}).
        if configs is None:
            configs = {}
        self.configs = configs
        os.makedirs(configs["logdir"], exist_ok=True)
        log_path_train = os.path.join(configs["logdir"], "training_log.txt")
        log_path_val = os.path.join(configs["logdir"], "val_log.txt")
        log_path_test = os.path.join(configs["logdir"], "test_log.txt")
        # Mirror progress output to stdout and per-phase log files.
        self.traintee = TeeFile(sys.stdout, log_path_train)
        self.valtee = TeeFile(sys.stdout, log_path_val)
        self.testtee = TeeFile(sys.stdout, log_path_test)
        for d in self.configs["dirsToMake"]:
            os.makedirs(d, exist_ok=True)
        self.model = MaskedAutoEncoder(
            c=configs["channels"],
            mask_ratio=configs["mask_ratio"],
            dropout=configs["dropout"],
            img_size=configs["img_size"],
            encoder_dim=configs["encoder_dim"],
            mlp_dim=configs["mlp_dim"],
            decoder_dim=configs["decoder_dim"],
            encoder_depth=configs["encoder_depth"],
            encoder_head=configs["encoder_head"],
            decoder_depth=configs["decoder_depth"],
            decoder_head=configs["decoder_head"],
            patch_size=configs["patch_size"],
        ).to(configs["device"])
        self.criterion = mae_loss
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(), configs["lr"], weight_decay=configs["weight_decay"]
        )
        # Linear warmup followed by cosine annealing.
        # NOTE(review): train_epoch() steps this scheduler once per optimizer
        # update, so "warmup" / "num_epochs" are interpreted in update units
        # here, not epochs — confirm that is intended.
        self.schedular1 = torch.optim.lr_scheduler.LinearLR(
            self.optimizer, start_factor=0.1, end_factor=1.0, total_iters=configs["warmup"]
        )
        self.schedular2 = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, T_max=configs["num_epochs"] - configs["warmup"]
        )
        self.schedular = torch.optim.lr_scheduler.SequentialLR(
            self.optimizer, schedulers=[self.schedular1, self.schedular2],
            milestones=[configs["warmup"]]
        )
        self.scaler = torch.amp.GradScaler()
        self.train_dataset = CheXpertDataset(
            zip_path=configs["zip_path"], csv_path=configs["train_csv"],
            root_dir=configs["datadir"], augment=True, use_frontal_only=True
        )
        self.val_dataset = CheXpertDataset(
            zip_path=configs["zip_path"], csv_path=configs["val_csv"],
            root_dir=configs["datadir"], augment=False, use_frontal_only=True
        )
        self.class_Weights = self.train_dataset.get_class_weights().to(self.configs["device"])
        self.sample_Weights = self.train_dataset.get_sample_weights()
        # Oversample rare classes during training.
        self.sampler = torch.utils.data.WeightedRandomSampler(
            self.sample_Weights, num_samples=len(self.sample_Weights)
        )
        self.trainloader = DataLoader(
            self.train_dataset, batch_size=configs["batch_size"], sampler=self.sampler,
            num_workers=8, pin_memory=True, persistent_workers=True
        )
        self.valloader = DataLoader(
            self.val_dataset, batch_size=configs["batch_size"], shuffle=False,
            num_workers=8, pin_memory=True, persistent_workers=True
        )
        self.history = {"train_loss": [], "val_loss": []}
        self.current_epoch = 0
        # Resume full training state if a checkpoint exists.
        if os.path.exists(self.configs["resume"]):
            loadedpickle = torch.load(self.configs["resume"], map_location=self.configs["device"])
            self.model.load_state_dict(loadedpickle["model"], strict=False)
            self.optimizer.load_state_dict(loadedpickle["optimizer"])
            self.schedular.load_state_dict(loadedpickle["schedular"])
            self.schedular1.load_state_dict(loadedpickle["schedular1"])
            self.schedular2.load_state_dict(loadedpickle["schedular2"])
            self.scaler.load_state_dict(loadedpickle["scaler"])
            self.current_epoch = loadedpickle["epoch"] + 1
        self.test_dataset = None
        self.testloader = None
        if configs.get("test_csv"):
            self.test_dataset = CheXpertDataset(
                zip_path=configs["zip_path"],
                csv_path=configs["test_csv"],
                root_dir=configs["datadir"],
                augment=False,
                use_frontal_only=True
            )
            self.testloader = DataLoader(
                self.test_dataset,
                batch_size=configs["batch_size"],
                shuffle=False,
                num_workers=8,
                pin_memory=True,
                persistent_workers=True
            )
            print(f"Test loader ready – {len(self.test_dataset)} images")
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.enabled = True
        # Reduce CUDA fragmentation on long runs.
        os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
        if hasattr(self.model, 'enable_gradient_checkpointing'):
            self.model.enable_gradient_checkpointing()

    @staticmethod
    def plot_training_metrics(metrics, epoch, figs_path):
        """Plot train/val loss curves and save to <figs_path>/<epoch>/metrics.png.

        Args:
            metrics (dict): {"train_loss": [...], "val_loss": [...]}.
            epoch (int): current epoch (names the output sub-directory).
            figs_path (str): root directory for figures.
        """
        import matplotlib.pyplot as plt
        keys = ["train_loss", "val_loss"]
        lengths = [len(metrics[k]) for k in keys if k in metrics]
        if not lengths:
            return
        # Truncate all series to a common length so they share one x-axis.
        n = min(lengths)
        m = {k: metrics[k][:n] for k in keys if k in metrics}
        epochs = list(range(1, n + 1))
        plt.figure(figsize=(14, 6))
        plt.subplot(1, 2, 1)
        # FIX: plot the length-aligned copies (m), not the raw metrics lists,
        # which could have mismatched lengths and crash matplotlib.
        plt.plot(epochs, m["train_loss"], label="Train Loss", marker='o')
        plt.plot(epochs, m["val_loss"], label="Val Loss", marker='s')
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.title("Training & Validation Loss")
        plt.legend()
        plt.grid(True, linestyle='--', alpha=0.6)
        plt.tight_layout()
        os.makedirs(os.path.join(figs_path, str(epoch)), exist_ok=True)
        plt.savefig(os.path.join(figs_path, str(epoch), "metrics.png"))
        plt.show()

    def train_epoch(self, epoch, looper):
        """Run one AMP training epoch with gradient accumulation.

        Args:
            epoch (int): current epoch index (for progress display).
            looper: tqdm-wrapped ``enumerate(self.trainloader)``.

        Returns:
            float: running mean reconstruction loss over the epoch.
        """
        self.model.train()
        running_loss = 0.0
        current_loss = 0
        total_batches = len(self.trainloader)
        for batch_idx, data in looper:
            image = data['image'].to(self.configs["device"], non_blocking=True)
            with torch.autocast(device_type=self.configs["device"].type, dtype=torch.float16):
                img, preds, mask = self.model(image)
                loss = self.criterion(img, preds, mask)
            if torch.isfinite(loss):
                # FIX: only accumulate finite losses; the original added
                # loss.item() before the finiteness check, so a single NaN
                # poisoned the running average for the rest of the epoch.
                running_loss += loss.item()
                # Scale down so gradients accumulate to a full-batch average.
                self.scaler.scale(loss / self.configs["accumulation"]).backward()
            else:
                self.optimizer.zero_grad(set_to_none=True)
                continue
            if (batch_idx + 1) % self.configs["accumulation"] == 0 or batch_idx == total_batches - 1:
                # Unscale before clipping so the max_norm applies to true grads.
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                self.scaler.step(self.optimizer)
                self.scaler.update()
                self.schedular.step()
                self.optimizer.zero_grad(set_to_none=True)
            current_loss = running_loss / (batch_idx + 1)
            if (batch_idx + 1) % 10 == 0:
                current_lr = self.optimizer.param_groups[0]['lr']
                looper.set_postfix({
                    "lr": f"{current_lr:.2e}", "batch": f"{batch_idx}/{total_batches}",
                    "epoch": f"{epoch}/{self.configs['num_epochs']}",
                    "loss": f"{current_loss:.3f}",
                })
        return current_loss

    def validate(self, epoch, looper):
        """Run one validation pass.

        Args:
            epoch (int): current epoch index (for progress display).
            looper: tqdm-wrapped ``enumerate(self.valloader)``.

        Returns:
            float: mean reconstruction loss over the validation set.
        """
        self.model.eval()
        val_loss = 0.0
        lenloader = len(self.valloader)
        current_loss = 0
        with torch.no_grad():
            for batch_idx, data in looper:
                image = data["image"].to(self.configs["device"], non_blocking=True)
                with torch.autocast(device_type=self.configs["device"].type, dtype=torch.float16):
                    img, preds, mask = self.model(image)
                    loss = self.criterion(img, preds, mask)
                val_loss += loss.item()
                current_loss = val_loss / (batch_idx + 1)
                if (batch_idx + 1) % 10 == 0:
                    looper.set_postfix({
                        "epoch": f"{epoch}/{self.configs['num_epochs']}",
                        "batch": f"{batch_idx}/{lenloader}",
                        "loss": f"{current_loss:.3f}",
                    })
        return current_loss

    def train(self):
        """Main loop: train/validate each epoch, checkpoint on best val loss,
        and dump history + plots every 10 epochs."""
        for epoch in range(self.current_epoch, self.configs["num_epochs"]):
            trainlooper = tqdm(enumerate(self.trainloader), desc="training: ", leave=False, file=self.traintee)
            vallooper = tqdm(enumerate(self.valloader), desc="validating: ", leave=False, file=self.valtee)
            self.model.train()
            self.optimizer.zero_grad(set_to_none=True)
            running_loss = self.train_epoch(epoch, trainlooper)
            # FIX: guard CUDA-only calls so CPU runs do not crash.
            if torch.cuda.is_available():
                torch.cuda.synchronize()
                torch.cuda.empty_cache()
            val_loss = self.validate(epoch, vallooper)
            if torch.cuda.is_available():
                torch.cuda.synchronize()
                torch.cuda.empty_cache()
            gc.collect()
            # FIX: also checkpoint when history is empty; the original's
            # `self.history["val_loss"] and ...` condition meant a fresh run
            # never saved its first (so far best) epoch.
            if not self.history["val_loss"] or val_loss < min(self.history["val_loss"]):
                checkpoint = {
                    "model": self.model.state_dict(),
                    "optimizer": self.optimizer.state_dict(),
                    "schedular": self.schedular.state_dict(),
                    "schedular1": self.schedular1.state_dict(),
                    "schedular2": self.schedular2.state_dict(),
                    "scaler": self.scaler.state_dict(),
                    "epoch": epoch,
                }
                torch.save(checkpoint, self.configs["resume"])
            print(f"train loss {running_loss} val loss {val_loss}")
            self.history["train_loss"].append(float(running_loss))
            self.history["val_loss"].append(float(val_loss))
            if epoch % 10 == 0:
                historyfile = os.path.join(self.configs["logdir"], "history.json")
                # FIX: the old read-merge-then-discard of the on-disk history
                # was dead code (the merged dict was never written);
                # self.history already holds this run's complete record.
                with open(historyfile, "w") as f:
                    json.dump(self.history, f)
                MAETrainer.plot_training_metrics(self.history, epoch + 1, self.configs["logdir"])
            self.current_epoch = epoch
class Trainer:
    """Fine-tuning driver for the XRAYClassifier on CheXpert.

    Loads a pretrained MAE backbone (and optionally a DenseNet backbone),
    trains with asymmetric loss and gradient accumulation, tracks
    macro/micro AUC, tunes per-class decision thresholds (Youden's J) on
    validation, and checkpoints on AUC improvement.
    """

    def __init__(self, configs=None):
        """Build model, loaders, loss and optimization state from ``configs``.

        Args:
            configs (dict): experiment configuration; in addition to the keys
                used by MAETrainer it needs "num_classes", "maskdir",
                "backbone" and "densebackbone".
        """
        # FIX: avoid a shared mutable default argument ({}).
        if configs is None:
            configs = {}
        self.configs = configs
        os.makedirs(configs["logdir"], exist_ok=True)
        log_path_train = os.path.join(configs["logdir"], "training_log.txt")
        log_path_val = os.path.join(configs["logdir"], "val_log.txt")
        log_path_test = os.path.join(configs["logdir"], "test_log.txt")
        # Mirror progress output to stdout and per-phase log files.
        self.traintee = TeeFile(sys.stdout, log_path_train)
        self.valtee = TeeFile(sys.stdout, log_path_val)
        self.testtee = TeeFile(sys.stdout, log_path_test)
        for d in self.configs["dirsToMake"]:
            os.makedirs(d, exist_ok=True)
        self.model = XRAYClassifier(
            c=configs["channels"],
            num_classes=configs["num_classes"],
            mask_ratio=configs["mask_ratio"],
            dropout=configs["dropout"],
            img_size=configs["img_size"],
            encoder_dim=configs["encoder_dim"],
            mlp_dim=configs["mlp_dim"],
            decoder_dim=configs["decoder_dim"],
            encoder_depth=configs["encoder_depth"],
            encoder_head=configs["encoder_head"],
            decoder_depth=configs["decoder_depth"],
            decoder_head=configs["decoder_head"],
            patch_size=configs["patch_size"],
        ).to(configs["device"])
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(), configs["lr"], weight_decay=configs["weight_decay"]
        )
        # Linear warmup followed by cosine annealing, stepped once per epoch.
        self.schedular1 = torch.optim.lr_scheduler.LinearLR(
            self.optimizer, start_factor=0.1, end_factor=1.0, total_iters=configs["warmup"]
        )
        self.schedular2 = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, T_max=configs["num_epochs"] - configs["warmup"]
        )
        self.schedular = torch.optim.lr_scheduler.SequentialLR(
            self.optimizer, schedulers=[self.schedular1, self.schedular2],
            milestones=[configs["warmup"]]
        )
        self.scaler = torch.amp.GradScaler()
        self.train_dataset = CheXpertDataset(
            zip_path=configs["zip_path"], csv_path=configs["train_csv"], root_dir=configs["datadir"],
            augment=True, use_frontal_only=True, mask_dir=configs["maskdir"]
        )
        self.val_dataset = CheXpertDataset(
            zip_path=configs["zip_path"], csv_path=configs["val_csv"], root_dir=configs["datadir"],
            augment=False, use_frontal_only=True, mask_dir=configs["maskdir"]
        )
        self.class_Weights = self.train_dataset.get_class_weights().to(self.configs["device"])
        self.sample_Weights = self.train_dataset.get_sample_weights()
        # Oversample rare classes during training.
        self.sampler = torch.utils.data.WeightedRandomSampler(
            self.sample_Weights, num_samples=len(self.sample_Weights)
        )
        self.trainloader = DataLoader(
            self.train_dataset, batch_size=configs["batch_size"], sampler=self.sampler,
            num_workers=0, pin_memory=True, persistent_workers=False
        )
        self.valloader = DataLoader(
            self.val_dataset, batch_size=configs["batch_size"], shuffle=False,
            num_workers=0, pin_memory=True, persistent_workers=False
        )
        self.criterion = AsymmetricLoss(class_weights=self.class_Weights).to(self.configs["device"])
        self.history = {"train_loss": [], "val_loss": [], "train_macro_auc": [], "val_macro_auc": [],
                        "train_micro_auc": [], "val_micro_auc": []}
        historyfile = os.path.join(self.configs["logdir"], "history.json")
        if os.path.exists(historyfile):
            # Seed history from disk so "best so far" survives restarts.
            with open(historyfile, 'r') as hf:
                self.history = json.load(hf)
        self.current_epoch = 0
        self.optimal_thresholds = [0.5] * 14  # per-class decision thresholds
        if os.path.exists(self.configs["resume"]):
            # Resume full fine-tuning state.
            ckpt = torch.load(self.configs["resume"], map_location=self.configs["device"], weights_only=False)
            self.model.load_state_dict(ckpt["model"], strict=False)
            self.optimizer.load_state_dict(ckpt["optimizer"])
            self.schedular.load_state_dict(ckpt["schedular"])
            self.schedular1.load_state_dict(ckpt["schedular1"])
            self.schedular2.load_state_dict(ckpt["schedular2"])
            self.scaler.load_state_dict(ckpt["scaler"])
            self.current_epoch = ckpt.get("epoch", -1) + 1
            # FIX: fall back to defaults when the checkpoint predates
            # threshold saving (the original could set this to None).
            self.optimal_thresholds = ckpt.get("thresholds") or [0.5] * 14
        else:
            # Fresh run: load the pretrained MAE backbone only.
            # NOTE(review): assumes configs["backbone"] exists — confirm.
            bb = torch.load(self.configs["backbone"], map_location=self.configs["device"], weights_only=False)
            state = bb["model"]
            # Strip a 'module.' prefix left by DataParallel, if present.
            if any(k.startswith("module.") for k in state.keys()):
                from collections import OrderedDict
                state = OrderedDict((k.replace("module.", "", 1), v) for k, v in state.items())
            missing, unexpected = self.model.mae.load_state_dict(state, strict=False)
            print("loaded backbone")
            if missing: print(f"Missing keys: {len(missing)} (showing first 5): {missing[:5]}")
            if unexpected: print(f"Unexpected keys: {len(unexpected)} (first 5): {unexpected[:5]}")
            # Freeze the backbone for warmup.
            for p in self.model.mae.parameters():
                p.requires_grad = False
            if os.path.exists(self.configs["densebackbone"]):
                densebb = torch.load(self.configs["densebackbone"], map_location=self.configs["device"])
                densestate = densebb["model"]
                # FIX: the original checked and stripped 'module.' on the MAE
                # state dict ('state') instead of the DenseNet one, so a
                # DataParallel-saved dense backbone never loaded correctly.
                if any(k.startswith("module.") for k in densestate.keys()):
                    from collections import OrderedDict
                    densestate = OrderedDict((k.replace("module.", "", 1), v) for k, v in densestate.items())
                densemissing, denseunexpected = self.model.dense.load_state_dict(densestate, strict=False)
                print("loaded dense backbone")
                if densemissing: print(f"Missing keys: {len(densemissing)} (showing first 5): {densemissing[:5]}")
                if denseunexpected: print(f"Unexpected keys: {len(denseunexpected)} (first 5): {denseunexpected[:5]}")
        self.test_dataset = None
        self.testloader = None
        if configs.get("test_csv"):
            self.test_dataset = CheXpertDataset(
                zip_path=configs["zip_path"],
                csv_path=configs["test_csv"],
                root_dir=configs["datadir"],
                augment=False,
                use_frontal_only=True
            )
            self.testloader = DataLoader(
                self.test_dataset,
                batch_size=configs["batch_size"],
                shuffle=False,
                num_workers=0,
                pin_memory=True,
                persistent_workers=False
            )
            print(f"Test loader ready – {len(self.test_dataset)} images")
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.enabled = True
        # Reduce CUDA fragmentation on long runs.
        os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
        if hasattr(self.model, 'enable_gradient_checkpointing'):
            self.model.enable_gradient_checkpointing()

    @staticmethod
    def plot_training_metrics(metrics, epoch, figs_path):
        """Plot loss and AUC curves and save to <figs_path>/<epoch>/metrics.png.

        Args:
            metrics (dict): lists keyed by "train_loss", "val_loss",
                "train_macro_auc", "val_macro_auc", "train_micro_auc",
                "val_micro_auc".
            epoch (int): current epoch (names the output sub-directory).
            figs_path (str): root directory for figures.
        """
        import matplotlib.pyplot as plt
        keys = ["train_loss", "val_loss", "train_macro_auc", "val_macro_auc",
                "train_micro_auc", "val_micro_auc"]
        lengths = [len(metrics[k]) for k in keys if k in metrics]
        if not lengths:
            return
        # Truncate all series to a common length so they share one x-axis.
        n = min(lengths)
        m = {k: metrics[k][:n] for k in keys if k in metrics}
        epochs = list(range(1, n + 1))
        plt.figure(figsize=(14, 6))
        plt.subplot(1, 2, 1)
        # FIX: plot the length-aligned copies (m), not the raw metrics lists,
        # which could have mismatched lengths and crash matplotlib.
        plt.plot(epochs, m["train_loss"], label="Train Loss", marker='o')
        plt.plot(epochs, m["val_loss"], label="Val Loss", marker='s')
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.title("Training & Validation Loss")
        plt.legend()
        plt.grid(True, linestyle='--', alpha=0.6)
        plt.subplot(1, 2, 2)
        plt.plot(epochs, m["train_macro_auc"], label="Train Macro AUC", marker='o')
        plt.plot(epochs, m["val_macro_auc"], label="Val Macro AUC", marker='s')
        plt.plot(epochs, m["train_micro_auc"], label="Train Micro AUC", marker='^')
        plt.plot(epochs, m["val_micro_auc"], label="Val Micro AUC", marker='v')
        plt.xlabel("Epoch")
        plt.ylabel("AUC")
        plt.title("Training & Validation AUC (Macro/Micro)")
        plt.legend()
        plt.grid(True, linestyle='--', alpha=0.6)
        plt.tight_layout()
        os.makedirs(os.path.join(figs_path, str(epoch)), exist_ok=True)
        plt.savefig(os.path.join(figs_path, str(epoch), "metrics.png"))
        plt.show()

    def train_epoch(self, epoch, looper):
        """Run one training epoch with gradient accumulation.

        Args:
            epoch (int): current epoch index (for progress display).
            looper: tqdm-wrapped ``enumerate(self.trainloader)``.

        Returns:
            tuple: (mean_loss, macro_auc, micro_auc) over the full epoch.
        """
        self.model.train()
        running_loss = 0.0
        all_preds = []
        all_targets = []
        total_batches = len(self.trainloader)
        for batch_idx, data in looper:
            image = data['image'].to(self.configs["device"], non_blocking=True)
            target = data['labels'].to(self.configs["device"], non_blocking=True)
            logits = self.model(image)
            # The criterion here consumes probabilities, not logits.
            preds = torch.sigmoid(logits.float())
            loss = self.criterion(preds, target)
            if torch.isfinite(loss):
                # FIX: only accumulate finite losses; the original added
                # loss.item() before the finiteness check, so one NaN
                # poisoned the running average for the rest of the epoch.
                running_loss += loss.item()
                # Scale down so gradients accumulate to a full-batch average.
                (loss / self.configs["accumulation"]).backward()
            else:
                self.optimizer.zero_grad(set_to_none=True)
                continue
            if (batch_idx + 1) % self.configs["accumulation"] == 0 or batch_idx == total_batches - 1:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                self.optimizer.step()
                self.optimizer.zero_grad(set_to_none=True)
            # Keep predictions on CPU for epoch-level AUC.
            all_preds.append(preds.detach().cpu())
            all_targets.append(target.detach().cpu())
            current_loss = running_loss / (batch_idx + 1)
            if (batch_idx + 1) % 500 == 0 and len(all_preds) > 0:
                preds_np = torch.cat(all_preds).numpy()
                targets_np = torch.cat(all_targets).numpy()
                macro_auc = roc_auc_score(targets_np, preds_np, average='macro')
                micro_auc = roc_auc_score(targets_np, preds_np, average='micro')
                current_lr = self.optimizer.param_groups[0]['lr']
                looper.set_postfix({
                    "lr": f"{current_lr:.2e}", "batch": f"{batch_idx}/{total_batches}",
                    "epoch": f"{epoch}/{self.configs['num_epochs']}",
                    "loss": f"{current_loss:.3f}",
                    "macro": f"{macro_auc:.3f}",
                    "micro": f"{micro_auc:.3f}"
                })
        # Final full-epoch metrics.
        preds_full = torch.cat(all_preds).numpy()
        targets_full = torch.cat(all_targets).numpy()
        final_loss = running_loss / total_batches
        final_macro_auc = roc_auc_score(targets_full, preds_full, average='macro')
        final_micro_auc = roc_auc_score(targets_full, preds_full, average='micro')
        del all_preds, all_targets, preds_full, targets_full
        return final_loss, final_macro_auc, final_micro_auc

    def validate(self, epoch, looper):
        """Validate and retune per-class thresholds via Youden's J statistic.

        Side effect: updates ``self.optimal_thresholds``.

        Args:
            epoch (int): current epoch index (for progress display).
            looper: tqdm-wrapped ``enumerate(self.valloader)``.

        Returns:
            tuple: (mean_loss, macro_auc, micro_auc) over the validation set.
        """
        self.model.eval()
        val_loss = 0.0
        all_preds = []
        all_targets = []
        lenloader = len(self.valloader)
        with torch.no_grad():
            for batch_idx, data in looper:
                image = data["image"].to(self.configs["device"], non_blocking=True)
                target = data["labels"].to(self.configs["device"], non_blocking=True)
                logits = self.model(image)
                preds = torch.sigmoid(logits.float())
                loss = self.criterion(preds, target)
                val_loss += loss.item()
                all_preds.append(preds.detach().cpu())
                all_targets.append(target.detach().cpu())
                current_loss = val_loss / (batch_idx + 1)
                if (batch_idx + 1) % 200 == 0 and len(all_preds) > 0:
                    preds_np = torch.cat(all_preds).numpy()
                    targets_np = torch.cat(all_targets).numpy()
                    macro_auc = roc_auc_score(targets_np, preds_np, average='macro')
                    micro_auc = roc_auc_score(targets_np, preds_np, average='micro')
                    looper.set_postfix({
                        "epoch": f"{epoch}/{self.configs['num_epochs']}",
                        "batch": f"{batch_idx}/{lenloader}",
                        "loss": f"{current_loss:.3f}",
                        "macro": f"{macro_auc:.3f}",
                        "micro": f"{micro_auc:.3f}"
                    })
        preds_full = torch.cat(all_preds).numpy()
        targets_full = torch.cat(all_targets).numpy()
        # Per-class threshold search maximizing Youden's J (sens + spec - 1).
        num_classes = 14
        new_thresholds = [0.5] * num_classes
        for class_idx in range(num_classes):
            if targets_full[:, class_idx].sum() == 0:
                continue  # no positive samples: keep the 0.5 default
            best_score = -1
            best_threshold = 0.5
            for threshold in np.arange(0.1, 0.9, 0.02):
                preds_bin = (preds_full[:, class_idx] >= threshold).astype(int)
                # FIX: pin labels=[0, 1] so ravel() always yields 4 values,
                # even when predictions are single-class at this threshold.
                tn, fp, fn, tp = confusion_matrix(
                    targets_full[:, class_idx].astype(int),
                    preds_bin,
                    labels=[0, 1]
                ).ravel()
                sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
                specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
                score = sensitivity + specificity - 1
                if score > best_score:
                    best_score = score
                    best_threshold = threshold
            new_thresholds[class_idx] = best_threshold
        self.optimal_thresholds = new_thresholds
        final_loss = val_loss / lenloader
        final_macro_auc = roc_auc_score(targets_full, preds_full, average='macro')
        final_micro_auc = roc_auc_score(targets_full, preds_full, average='micro')
        del all_preds, all_targets, preds_full, targets_full
        return final_loss, final_macro_auc, final_micro_auc

    def train(self):
        """Main loop: train/validate each epoch, checkpoint on AUC improvement,
        persist history every epoch and plot every 10 epochs."""
        for epoch in range(self.current_epoch, self.configs["num_epochs"]):
            trainlooper = tqdm(enumerate(self.trainloader), desc="training: ", leave=True, file=self.traintee)
            vallooper = tqdm(enumerate(self.valloader), desc="validating: ", leave=True, file=self.valtee)
            self.model.train()
            self.optimizer.zero_grad(set_to_none=True)
            running_loss, macro_auc, micro_auc = self.train_epoch(epoch, trainlooper)
            # FIX: step the per-epoch scheduler after the optimizer has run;
            # stepping before the first optimizer.step() skipped the initial
            # warmup LR and triggers PyTorch's ordering warning.
            self.schedular.step()
            # FIX: guard CUDA-only calls so CPU runs do not crash.
            if torch.cuda.is_available():
                torch.cuda.synchronize()
                torch.cuda.empty_cache()
            val_loss, val_macro_auc, val_micro_auc = self.validate(epoch, vallooper)
            if torch.cuda.is_available():
                torch.cuda.synchronize()
                torch.cuda.empty_cache()
            gc.collect()
            # FIX: also checkpoint when history is empty; the original's
            # `history and improvement` condition never saved a fresh run's
            # first (so far best) epoch.
            improved_macro = (not self.history["val_macro_auc"]) or val_macro_auc > max(self.history["val_macro_auc"])
            improved_micro = (not self.history["val_micro_auc"]) or val_micro_auc > max(self.history["val_micro_auc"])
            if improved_macro or improved_micro:
                checkpoint = {"model": self.model.state_dict(), "optimizer": self.optimizer.state_dict(),
                              "schedular": self.schedular.state_dict(),
                              "schedular1": self.schedular1.state_dict(),
                              "schedular2": self.schedular2.state_dict(),
                              "scaler": self.scaler.state_dict(), "epoch": epoch,
                              "thresholds": self.optimal_thresholds}
                torch.save(checkpoint, self.configs["resume"])
            print(f"epoch {epoch} train loss {running_loss} val loss {val_loss} val_macro_auc {val_macro_auc} val_micro_auc {val_micro_auc} train_macro_auc {macro_auc} train_micro_auc {micro_auc}")
            self.history["train_loss"].append(float(running_loss))
            self.history["val_loss"].append(float(val_loss))
            self.history["train_macro_auc"].append(float(macro_auc))
            self.history["val_macro_auc"].append(float(val_macro_auc))
            self.history["train_micro_auc"].append(float(micro_auc))
            self.history["val_micro_auc"].append(float(val_micro_auc))
            historyfile = os.path.join(self.configs["logdir"], "history.json")
            # FIX: the old read-merge-then-discard of the on-disk file was
            # dead code and would have double-counted entries; self.history
            # (seeded from the file in __init__) is already the full record.
            with open(historyfile, "w") as f:
                json.dump(self.history, f)
            if epoch % 10 == 0:
                Trainer.plot_training_metrics(self.history, epoch + 1, self.configs["logdir"])
            self.current_epoch = epoch

    def test(self, model_path=None, return_preds=False):
        """Run a complete test evaluation.

        Args:
            model_path (str | None): optional checkpoint to load first.
            return_preds (bool): also return the raw (preds, targets) arrays.

        Returns:
            (macro_auc, micro_auc, per_class_auc_dict), plus (preds, targets)
            when ``return_preds`` is True.

        Raises:
            RuntimeError: if no test loader was configured (`test_csv`).
        """
        if model_path:
            ckpt = torch.load(model_path, map_location=self.configs["device"])
            self.model.load_state_dict(ckpt["model"])
            print(f"Loaded checkpoint {model_path}")
        if self.testloader is None:
            raise RuntimeError("No test loader – provide `test_csv` in config")
        self.model.eval()
        all_preds, all_targets = [], []
        test_loss = 0.0
        looper = tqdm(enumerate(self.testloader), total=len(self.testloader),
                      desc="Testing ", file=self.testtee)
        with torch.inference_mode():
            for batch_idx, data in looper:
                img = data['image'].to(self.configs["device"], non_blocking=True)
                tgt = data['labels'].to(self.configs["device"], non_blocking=True)
                logits = self.model(img)
                if self.optimal_thresholds:
                    # Shift logits so that sigmoid(logit) >= tau_c becomes
                    # >= 0.5: subtract each class threshold in logit space.
                    taus = torch.tensor(self.optimal_thresholds, device=logits.device).view(1, -1)
                    margins = torch.log(taus / (1 - taus))  # prob -> logit
                    logits = logits - margins
                probs = torch.sigmoid(logits)
                # NOTE(review): loss is computed on threshold-shifted
                # probabilities, so it is not comparable with the validation
                # loss — confirm this is intended.
                loss = self.criterion(probs, tgt)
                test_loss += loss.item()
                all_preds.append(probs.cpu())
                all_targets.append(tgt.cpu())
                cur_loss = test_loss / (batch_idx + 1)
                # FIX: recomputing AUC over all accumulated predictions on
                # every batch was O(n^2); refresh it every 50 batches.
                if (batch_idx + 1) % 50 == 0:
                    p = torch.cat(all_preds).numpy()
                    t = torch.cat(all_targets).numpy()
                    macro = roc_auc_score(t, p, average='macro')
                    micro = roc_auc_score(t, p, average='micro')
                    looper.set_postfix(loss=f"{cur_loss:.4f}",
                                       macro=f"{macro:.4f}",
                                       micro=f"{micro:.4f}")
                else:
                    looper.set_postfix(loss=f"{cur_loss:.4f}")
        # ---- final metrics ----
        preds = torch.cat(all_preds).numpy()
        targets = torch.cat(all_targets).numpy()
        final_loss = test_loss / len(self.testloader)
        macro_auc = roc_auc_score(targets, preds, average='macro')
        micro_auc = roc_auc_score(targets, preds, average='micro')
        # Per-class AUC (NaN for classes with no positive samples).
        per_class = {}
        for i, name in enumerate(self.train_dataset.get_label_names()):
            if targets[:, i].sum() > 0:
                per_class[name] = roc_auc_score(targets[:, i], preds[:, i])
            else:
                per_class[name] = float('nan')
        # ---- pretty table ----
        print("\n" + "=" * 80)
        print(f"TEST RESULTS (loss={final_loss:.4f})")
        print("=" * 80)
        print(f"{'Pathology':<30} {'AUC':>8}")
        print("-" * 40)
        for name, auc in per_class.items():
            print(f"{name:<30} {auc:>8.4f}" if not np.isnan(auc) else f"{name:<30} {'N/A':>8}")
        print("-" * 40)
        print(f"{'Macro AUC':<30} {macro_auc:>8.4f}")
        print(f"{'Micro AUC':<30} {micro_auc:>8.4f}")
        print("=" * 80)
        if return_preds:
            return macro_auc, micro_auc, per_class, (preds, targets)
        return macro_auc, micro_auc, per_class