# mae/trainer/utils.py
from data.dataset import CheXpertDataset
from loss.mae_loss import mae_loss
from models.mae import *
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch
import gc
import json
import os
import io
import sys
class TeeFile:
    """
    File-like object that writes to multiple streams (e.g., stdout and a file).
    String arguments are treated as paths and opened as files automatically.
    Usage:
        # Works with both file objects and paths
        tee = TeeFile(sys.stdout, "/path/to/log.txt")
        print("Hello", file=tee)  # Writes to both stdout and the file
    """
def __init__(self, *file_objects_or_paths):
"""
Args:
*file_objects_or_paths: Mix of file objects (like sys.stdout)
or string paths to log files
"""
self.files = []
self.opened_files = [] # Track files we opened so we can close them later
for item in file_objects_or_paths:
if isinstance(item, str):
# It's a path string - open it as a file
f = open(item, 'a', buffering=1) # Append mode, line buffered
self.files.append(f)
self.opened_files.append(f)
else:
# It's already a file-like object (e.g., sys.stdout)
self.files.append(item)
def write(self, data):
"""Write data to all streams"""
for f in self.files:
try:
f.write(data)
f.flush()
except Exception as e:
# Handle closed file gracefully
print(f"Warning: Could not write to {f}: {e}", file=sys.stderr)
def flush(self):
"""Flush all streams"""
for f in self.files:
try:
f.flush()
            except Exception:
                pass
def isatty(self):
"""Check if any stream is a terminal (for tqdm compatibility)"""
return any(getattr(f, "isatty", lambda: False)() for f in self.files)
def fileno(self):
"""Get file descriptor from any real file-like stream"""
for f in self.files:
if hasattr(f, "fileno"):
try:
return f.fileno()
except Exception:
pass
raise io.UnsupportedOperation("No fileno available")
def close(self):
"""Close any files we opened"""
for f in self.opened_files:
try:
f.close()
            except Exception:
                pass
self.opened_files.clear()
def __del__(self):
"""Cleanup on deletion"""
self.close()
def __enter__(self):
"""Context manager support"""
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""Context manager cleanup"""
self.close()
return False
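
# Note: TeeFile also works as the `file=` target of tqdm; this is how MAETrainer
# below mirrors its progress bars to both stdout and a per-phase log file, e.g.
#     looper = tqdm(range(10), file=TeeFile(sys.stdout, "progress.log"))
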
class MAETrainer:
def __init__(self,configs={}):
self.configs=configs
os.makedirs(configs["logdir"],exist_ok=True)
log_path_train = os.path.join(configs["logdir"], "training_log.txt")
log_path_val = os.path.join(configs["logdir"], "val_log.txt")
log_path_test = os.path.join(configs["logdir"], "test_log.txt")
#self.log_file = open(log_path, 'w', buffering=1)
self.traintee = TeeFile(sys.stdout, log_path_train)
self.valtee = TeeFile(sys.stdout, log_path_val)
self.testtee = TeeFile(sys.stdout, log_path_test)
        for d in self.configs["dirsToMake"]:
            os.makedirs(d, exist_ok=True)
self.model=MaskedAutoEncoder(
c=configs["channels"],
mask_ratio=configs["mask_ratio"],
dropout=configs["dropout"],
img_size=configs["img_size"],
encoder_dim=configs["encoder_dim"],
mlp_dim=configs["mlp_dim"],
decoder_dim=configs["decoder_dim"],
encoder_depth=configs["encoder_depth"],
encoder_head=configs["encoder_head"],
decoder_depth=configs["decoder_depth"],
decoder_head=configs["decoder_head"],
patch_size=configs["patch_size"]
).to(configs["device"])
self.criterion=mae_loss
self.optimizer=torch.optim.AdamW(self.model.parameters(),configs["lr"], weight_decay=configs["weight_decay"])
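        # LR schedule: linear warmup, then cosine annealing for the remaining epochs,
        # chained together with SequentialLR at the warmup boundary.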
self.schedular1=torch.optim.lr_scheduler.LinearLR(self.optimizer,start_factor=0.1,end_factor=1.0,total_iters=configs["warmup"])
self.schedular2=torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer,T_max=configs["num_epochs"]-configs["warmup"])
self.schedular=torch.optim.lr_scheduler.SequentialLR (self.optimizer,schedulers=[self.schedular1,self.schedular2],milestones=[configs["warmup"]])
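        # GradScaler pairs with the torch.autocast blocks in train_epoch/validate to
        # keep fp16 gradients from underflowing during mixed-precision training.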
self.scaler=torch.amp.GradScaler()
self.train_dataset= CheXpertDataset(zip_path=configs["zip_path"],csv_path=configs["train_csv"],root_dir=configs["datadir"],augment=True,use_frontal_only=True)
self.val_dataset= CheXpertDataset(zip_path=configs["zip_path"],csv_path=configs["val_csv"],root_dir=configs["datadir"],augment=False,use_frontal_only=True )
self.class_Weights=self.train_dataset.get_class_weights().to(self.configs["device"])
self.sample_Weights=self.train_dataset.get_sample_weights()
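        # Weighted sampling: batches are drawn according to the dataset's per-sample
        # weights (e.g., to compensate for class imbalance in CheXpert).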
self.sampler=torch.utils.data.WeightedRandomSampler(self.sample_Weights,num_samples=len(self.sample_Weights))
self.trainloader=DataLoader(self.train_dataset,batch_size=configs["batch_size"],sampler=self.sampler,num_workers=8,pin_memory=True,persistent_workers=True)
self.valloader=DataLoader(self.val_dataset,batch_size=configs["batch_size"],shuffle=False,num_workers=8,pin_memory=True,persistent_workers=True)
self.history={"train_loss":[],"val_loss":[]}
self.current_epoch=0
if os.path.exists(self.configs["resume"]):
loadedpickle=torch.load(self.configs["resume"],map_location=self.configs["device"])
self.model.load_state_dict(loadedpickle["model"],strict=False)
self.optimizer.load_state_dict(loadedpickle["optimizer"])
self.schedular.load_state_dict(loadedpickle["schedular"])
self.schedular1.load_state_dict(loadedpickle["schedular1"])
self.schedular2.load_state_dict(loadedpickle["schedular2"])
self.scaler.load_state_dict(loadedpickle["scaler"])
self.current_epoch=loadedpickle["epoch"]+1
self.test_dataset = None
self.testloader = None
if configs.get("test_csv"):
self.test_dataset = CheXpertDataset(
zip_path=configs["zip_path"],
csv_path=configs["test_csv"],
root_dir=configs["datadir"],
augment=False,
use_frontal_only=True
)
self.testloader = DataLoader(
self.test_dataset,
batch_size=configs["batch_size"],
shuffle=False,
num_workers=8,
pin_memory=True,
persistent_workers=True
)
print(f"Test loader ready – {len(self.test_dataset)} images")
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = True
# FIX: Set memory allocator settings
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
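        # NOTE: PYTORCH_CUDA_ALLOC_CONF is read when the CUDA caching allocator is
        # initialised, so this setting only takes effect if it is applied before the
        # first CUDA allocation (the model above has already been moved to the GPU).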
# FIX: Enable gradient checkpointing if model supports it
if hasattr(self.model, 'enable_gradient_checkpointing'):
self.model.enable_gradient_checkpointing()
@staticmethod
    def plot_training_metrics(metrics, epoch, figs_path):
        """
        Plot loss curves from training metrics.
        Args:
            metrics (dict): Dictionary containing lists for each metric key:
                {
                    "train_loss": [...],
                    "val_loss": [...]
                }
            epoch (int): Current epoch number (used for the output sub-directory)
            figs_path (str): Directory under which the figure is saved
        """
        import matplotlib.pyplot as plt
#Compute the common length across all series
keys = ["train_loss","val_loss"]
lengths = [len(metrics[k]) for k in keys if k in metrics]
if not lengths:
return
n = min(lengths)
# Slice everything to the same length
m = {k: metrics[k][:n] for k in keys if k in metrics}
epochs = list(range(1, n + 1))
        plt.figure(figsize=(7, 6))
        # ---- Loss plot ----
        plt.plot(epochs, m["train_loss"], label="Train Loss", marker='o')
        plt.plot(epochs, m["val_loss"], label="Val Loss", marker='s')
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Training & Validation Loss")
plt.legend()
plt.grid(True, linestyle='--', alpha=0.6)
plt.tight_layout()
os.makedirs(os.path.join(figs_path,str(epoch)),exist_ok=True)
plt.savefig(os.path.join(figs_path,str(epoch),"metrics.png"))
plt.show()
def train_epoch(self, epoch, looper):
self.model.train()
        running_loss = 0.0
        current_loss = 0.0
        total_batches = len(self.trainloader)
        for batch_idx, data in looper:
            image = data['image'].to(self.configs["device"], non_blocking=True)
with torch.autocast(device_type=self.configs["device"].type, dtype=torch.float16):
img,preds,mask = self.model(image)
loss = self.criterion(img,preds,mask)
loss_back = loss / self.configs["accumulation"]
running_loss += loss.item()
if torch.isfinite(loss):
#loss_back.backward()
self.scaler.scale(loss_back).backward()
else:
self.optimizer.zero_grad(set_to_none=True)
continue
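            # Optimizer step every `accumulation` micro-batches (and on the final
            # batch): unscale gradients so clipping sees true magnitudes, clip, step,
            # update the scaler, then advance the LR schedule.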
if (batch_idx + 1) % self.configs["accumulation"] == 0 or batch_idx == total_batches - 1:
self.scaler.unscale_(self.optimizer)
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
self.scaler.step(self.optimizer)
self.scaler.update()
#self.optimizer.step()
self.schedular.step()
self.optimizer.zero_grad(set_to_none=True)
# === LIVE METRICS (every batch) ===
current_loss = running_loss / (batch_idx + 1)
if (batch_idx + 1) % 10 == 0:
current_lr = self.optimizer.param_groups[0]['lr']
looper.set_postfix({
"lr": f"{current_lr:.2e}","batch":f"{batch_idx}/{total_batches}",
"epoch": f"{epoch}/{self.configs['num_epochs']}",
"loss": f"{current_loss:.3f}",
})
return current_loss
def validate(self, epoch, looper):
self.model.eval()
        val_loss = 0.0
        current_loss = 0.0
        lenloader = len(self.valloader)
        with torch.no_grad():
            for batch_idx, data in looper:
                image = data["image"].to(self.configs["device"], non_blocking=True)
with torch.autocast(device_type=self.configs["device"].type, dtype=torch.float16):
img,preds,mask = self.model(image)
loss = self.criterion(img,preds,mask)
val_loss += loss.item()
# === LIVE METRICS ===
current_loss = val_loss / (batch_idx + 1)
if (batch_idx + 1) % 10 == 0 :
looper.set_postfix({
"epoch": f"{epoch}/{self.configs['num_epochs']}",
"batch":f"{batch_idx}/{lenloader}",
"loss": f"{current_loss:.3f}",
})
return current_loss
def train(self):
for epoch in range(self.current_epoch,self.configs["num_epochs"]):
            trainlooper = tqdm(enumerate(self.trainloader), total=len(self.trainloader), desc="training: ", leave=False, file=self.traintee)
            vallooper = tqdm(enumerate(self.valloader), total=len(self.valloader), desc="validating: ", leave=False, file=self.valtee)
self.model.train()
self.optimizer.zero_grad(set_to_none=True)
running_loss=self.train_epoch(epoch,trainlooper)
torch.cuda.synchronize()
torch.cuda.empty_cache()
val_loss=self.validate(epoch,vallooper)
torch.cuda.synchronize()
torch.cuda.empty_cache()
gc.collect()
            if (not self.history["val_loss"]) or (val_loss < min(self.history["val_loss"])):
checkpoint={"model":self.model.state_dict(),"optimizer":self.optimizer.state_dict(),"schedular":self.schedular.state_dict(),"schedular1":self.schedular1.state_dict(),"schedular2":self.schedular2.state_dict(),"scaler":self.scaler.state_dict(),"epoch":epoch}
torch.save(checkpoint, self.configs["resume"])
            print(f"epoch {epoch}: train loss {running_loss:.4f} val loss {val_loss:.4f}")
self.history["train_loss"].append(float(running_loss))
self.history["val_loss"].append(float(val_loss))
            if epoch % 10 == 0:
                historyfile = os.path.join(self.configs["logdir"], "history.json")
                history = dict(self.history)
                if os.path.exists(historyfile):
                    with open(historyfile, "r") as f:
                        previous = json.load(f)
                    history["train_loss"] = previous["train_loss"] + self.history["train_loss"]
                    history["val_loss"] = previous["val_loss"] + self.history["val_loss"]
                with open(historyfile, "w") as f:
                    json.dump(history, f)
MAETrainer.plot_training_metrics(self.history,epoch+1,self.configs["logdir"])
self.current_epoch=epoch
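

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the training pipeline). Every path and
# hyperparameter value below is a placeholder chosen for demonstration only;
# the keys match what MAETrainer reads from `configs`, but substitute your own
# CheXpert locations and the settings you actually train with.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    example_configs = {
        # bookkeeping
        "logdir": "./logs/mae_run",
        "dirsToMake": ["./checkpoints"],
        "resume": "./checkpoints/mae_best.pt",
        # data (placeholder paths)
        "zip_path": "./data/chexpert.zip",
        "datadir": "./data",
        "train_csv": "./data/train.csv",
        "val_csv": "./data/valid.csv",
        "test_csv": None,
        # model (illustrative values)
        "channels": 1,
        "img_size": 224,
        "patch_size": 16,
        "mask_ratio": 0.75,
        "dropout": 0.1,
        "encoder_dim": 768,
        "mlp_dim": 3072,
        "decoder_dim": 512,
        "encoder_depth": 12,
        "encoder_head": 12,
        "decoder_depth": 8,
        "decoder_head": 16,
        # optimisation (illustrative values)
        "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        "lr": 1.5e-4,
        "weight_decay": 0.05,
        "warmup": 5,
        "num_epochs": 100,
        "batch_size": 64,
        "accumulation": 4,
    }
    trainer = MAETrainer(example_configs)
    trainer.train()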