from data.dataset import CheXpertDataset
from loss.mae_loss import mae_loss
from models.mae import *
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch
import gc
import json
import os
import io
import sys


class TeeFile:
    """
    File-like object that writes to multiple streams (e.g., stdout and a file).
    Automatically handles string paths by opening them as files.

    Usage:
        # Works with both file objects and paths
        tee = TeeFile(sys.stdout, "/path/to/log.txt")
        print("Hello", file=tee)  # Writes to both stdout and the file
    """

    def __init__(self, *file_objects_or_paths):
        """
        Args:
            *file_objects_or_paths: Mix of file objects (like sys.stdout)
                or string paths to log files.
        """
        self.files = []
        self.opened_files = []  # Track files we opened so we can close them later
        for item in file_objects_or_paths:
            if isinstance(item, str):
                # It's a path string - open it as a file
                f = open(item, 'a', buffering=1)  # Append mode, line buffered
                self.files.append(f)
                self.opened_files.append(f)
            else:
                # It's already a file-like object (e.g., sys.stdout)
                self.files.append(item)

    def write(self, data):
        """Write data to all streams."""
        for f in self.files:
            try:
                f.write(data)
                f.flush()
            except Exception as e:
                # Handle closed files gracefully
                print(f"Warning: Could not write to {f}: {e}", file=sys.stderr)

    def flush(self):
        """Flush all streams."""
        for f in self.files:
            try:
                f.flush()
            except Exception:
                pass

    def isatty(self):
        """Check if any stream is a terminal (for tqdm compatibility)."""
        return any(getattr(f, "isatty", lambda: False)() for f in self.files)

    def fileno(self):
        """Get a file descriptor from any real file-like stream."""
        for f in self.files:
            if hasattr(f, "fileno"):
                try:
                    return f.fileno()
                except Exception:
                    pass
        raise io.UnsupportedOperation("No fileno available")

    def close(self):
        """Close any files we opened."""
        for f in self.opened_files:
            try:
                f.close()
            except Exception:
                pass
        self.opened_files.clear()

    def __del__(self):
        """Cleanup on deletion."""
        self.close()

    def __enter__(self):
        """Context manager support."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager cleanup."""
        self.close()
        return False


class MAETrainer:
    def __init__(self, configs=None):
        # None sentinel instead of a mutable default argument; a populated
        # config dict is required for anything below to work.
        configs = configs or {}
        self.configs = configs
        os.makedirs(configs["logdir"], exist_ok=True)
        log_path_train = os.path.join(configs["logdir"], "training_log.txt")
        log_path_val = os.path.join(configs["logdir"], "val_log.txt")
        log_path_test = os.path.join(configs["logdir"], "test_log.txt")
        self.traintee = TeeFile(sys.stdout, log_path_train)
        self.valtee = TeeFile(sys.stdout, log_path_val)
        self.testtee = TeeFile(sys.stdout, log_path_test)
        for d in self.configs["dirsToMake"]:
            os.makedirs(d, exist_ok=True)
        self.model = MaskedAutoEncoder(
            c=configs["channels"],
            mask_ratio=configs["mask_ratio"],
            dropout=configs["dropout"],
            img_size=configs["img_size"],
            encoder_dim=configs["encoder_dim"],
            mlp_dim=configs["mlp_dim"],
            decoder_dim=configs["decoder_dim"],
            encoder_depth=configs["encoder_depth"],
            encoder_head=configs["encoder_head"],
            decoder_depth=configs["decoder_depth"],
            decoder_head=configs["decoder_head"],
            patch_size=configs["patch_size"],
        ).to(configs["device"])
        self.criterion = mae_loss
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=configs["lr"],
            weight_decay=configs["weight_decay"],
        )
        self.schedular1 = torch.optim.lr_scheduler.LinearLR(
            self.optimizer, start_factor=0.1, end_factor=1.0,
            total_iters=configs["warmup"],
        )
        self.schedular2 = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, T_max=configs["num_epochs"] - configs["warmup"],
        )
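        # Scheduler composition: SequentialLR below runs the linear warmup
        # for the first `warmup` epochs, then hands off to cosine annealing
        # for the remaining num_epochs - warmup epochs. Illustrative numbers
        # (not from the original): with lr=1.5e-4 and warmup=5, the LR ramps
        # from 1.5e-5 up to 1.5e-4 over epochs 0-4, then decays toward 0.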
        self.schedular = torch.optim.lr_scheduler.SequentialLR(
            self.optimizer,
            schedulers=[self.schedular1, self.schedular2],
            milestones=[configs["warmup"]],
        )
        self.scaler = torch.amp.GradScaler()
        self.train_dataset = CheXpertDataset(
            zip_path=configs["zip_path"],
            csv_path=configs["train_csv"],
            root_dir=configs["datadir"],
            augment=True,
            use_frontal_only=True,
        )
        self.val_dataset = CheXpertDataset(
            zip_path=configs["zip_path"],
            csv_path=configs["val_csv"],
            root_dir=configs["datadir"],
            augment=False,
            use_frontal_only=True,
        )
        self.class_Weights = self.train_dataset.get_class_weights().to(self.configs["device"])
        self.sample_Weights = self.train_dataset.get_sample_weights()
        self.sampler = torch.utils.data.WeightedRandomSampler(
            self.sample_Weights, num_samples=len(self.sample_Weights)
        )
        self.trainloader = DataLoader(
            self.train_dataset,
            batch_size=configs["batch_size"],
            sampler=self.sampler,
            num_workers=8,
            pin_memory=True,
            persistent_workers=True,
        )
        self.valloader = DataLoader(
            self.val_dataset,
            batch_size=configs["batch_size"],
            shuffle=False,
            num_workers=8,
            pin_memory=True,
            persistent_workers=True,
        )
        self.history = {"train_loss": [], "val_loss": []}
        self.current_epoch = 0
        if os.path.exists(self.configs["resume"]):
            loadedpickle = torch.load(self.configs["resume"], map_location=self.configs["device"])
            self.model.load_state_dict(loadedpickle["model"], strict=False)
            self.optimizer.load_state_dict(loadedpickle["optimizer"])
            self.schedular.load_state_dict(loadedpickle["schedular"])
            self.schedular1.load_state_dict(loadedpickle["schedular1"])
            self.schedular2.load_state_dict(loadedpickle["schedular2"])
            self.scaler.load_state_dict(loadedpickle["scaler"])
            self.current_epoch = loadedpickle["epoch"] + 1
        self.test_dataset = None
        self.testloader = None
        if configs.get("test_csv"):
            self.test_dataset = CheXpertDataset(
                zip_path=configs["zip_path"],
                csv_path=configs["test_csv"],
                root_dir=configs["datadir"],
                augment=False,
                use_frontal_only=True,
            )
            self.testloader = DataLoader(
                self.test_dataset,
                batch_size=configs["batch_size"],
                shuffle=False,
                num_workers=8,
                pin_memory=True,
                persistent_workers=True,
            )
            print(f"Test loader ready – {len(self.test_dataset)} images")
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.enabled = True
        # Memory allocator settings; note this only takes effect if set
        # before the first CUDA allocation, so ideally export it in the
        # shell environment before launching Python.
        os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
        # Enable gradient checkpointing if the model supports it
        if hasattr(self.model, 'enable_gradient_checkpointing'):
            self.model.enable_gradient_checkpointing()

    @staticmethod
    def plot_training_metrics(metrics, epoch, figs_path):
        """
        Plot loss curves from training metrics.

        Args:
            metrics (dict): Dictionary containing a list per metric key:
                {"train_loss": [...], "val_loss": [...]}
            epoch (int): Current epoch number (used for the output directory)
            figs_path (str): Root directory for the saved figures
        """
        import matplotlib.pyplot as plt
        # Slice every series to a common length so the curves stay aligned
        keys = ["train_loss", "val_loss"]
        lengths = [len(metrics[k]) for k in keys if k in metrics]
        if not lengths:
            return
        n = min(lengths)
        m = {k: metrics[k][:n] for k in keys if k in metrics}
        epochs = list(range(1, n + 1))

        plt.figure(figsize=(14, 6))
        # ---- Loss subplot ----
        plt.subplot(1, 2, 1)
        plt.plot(epochs, m["train_loss"], label="Train Loss", marker='o')
        plt.plot(epochs, m["val_loss"], label="Val Loss", marker='s')
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.title("Training & Validation Loss")
        plt.legend()
        plt.grid(True, linestyle='--', alpha=0.6)
        plt.tight_layout()
        os.makedirs(os.path.join(figs_path, str(epoch)), exist_ok=True)
        plt.savefig(os.path.join(figs_path, str(epoch), "metrics.png"))
        plt.show()
        plt.close()

    def train_epoch(self, epoch, looper):
        self.model.train()
        running_loss = 0.0
        current_loss = 0.0
        total_batches = len(self.trainloader)
        for batch_idx, data in looper:
            # MAE pretraining is self-supervised: only the image is needed,
            # so the labels are never moved to the GPU here.
            image = data['image'].to(self.configs["device"], non_blocking=True)
            with torch.autocast(device_type=self.configs["device"].type, dtype=torch.float16):
                img, preds, mask = self.model(image)
                loss = self.criterion(img, preds, mask)
                loss_back = loss / self.configs["accumulation"]

            # Skip non-finite losses before they pollute the running average
            if not torch.isfinite(loss):
                self.optimizer.zero_grad(set_to_none=True)
                continue
            running_loss += loss.item()
            self.scaler.scale(loss_back).backward()

            if (batch_idx + 1) % self.configs["accumulation"] == 0 or batch_idx == total_batches - 1:
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                self.scaler.step(self.optimizer)
                self.scaler.update()
                self.optimizer.zero_grad(set_to_none=True)
                # Scheduler stepping moved to train(): the warmup/cosine
                # schedulers are defined in epochs, not optimizer steps.

            # === LIVE METRICS (every batch) ===
            current_loss = running_loss / (batch_idx + 1)
            if (batch_idx + 1) % 10 == 0:
                current_lr = self.optimizer.param_groups[0]['lr']
                looper.set_postfix({
                    "lr": f"{current_lr:.2e}",
                    "batch": f"{batch_idx}/{total_batches}",
                    "epoch": f"{epoch}/{self.configs['num_epochs']}",
                    "loss": f"{current_loss:.3f}",
                })
        return current_loss

    def validate(self, epoch, looper):
        self.model.eval()
        val_loss = 0.0
        lenloader = len(self.valloader)
        current_loss = 0.0
        with torch.no_grad():
            for batch_idx, data in looper:
                image = data["image"].to(self.configs["device"], non_blocking=True)
                with torch.autocast(device_type=self.configs["device"].type, dtype=torch.float16):
                    img, preds, mask = self.model(image)
                    loss = self.criterion(img, preds, mask)
                val_loss += loss.item()

                # === LIVE METRICS ===
                current_loss = val_loss / (batch_idx + 1)
                if (batch_idx + 1) % 10 == 0:
                    looper.set_postfix({
                        "epoch": f"{epoch}/{self.configs['num_epochs']}",
                        "batch": f"{batch_idx}/{lenloader}",
                        "loss": f"{current_loss:.3f}",
                    })
        return current_loss

    def train(self):
        for epoch in range(self.current_epoch, self.configs["num_epochs"]):
            trainlooper = tqdm(enumerate(self.trainloader), total=len(self.trainloader),
                               desc="training: ", leave=False, file=self.traintee)
            vallooper = tqdm(enumerate(self.valloader), total=len(self.valloader),
                             desc="validating: ", leave=False, file=self.valtee)
            self.model.train()
            self.optimizer.zero_grad(set_to_none=True)
            running_loss = self.train_epoch(epoch, trainlooper)
            torch.cuda.synchronize()
            torch.cuda.empty_cache()
            val_loss = self.validate(epoch, vallooper)
            torch.cuda.synchronize()
            torch.cuda.empty_cache()
            gc.collect()

            # Epoch-based schedulers (see __init__): step once per epoch.
            # The original stepped inside train_epoch after every optimizer
            # update, which would exhaust the warmup within a few batches.
            self.schedular.step()

            # NOTE: the original source is truncated past the next condition;
            # the remainder of the loop is a minimal reconstruction that logs
            # improvements, records history, and checkpoints with the same
            # keys the resume logic in __init__ expects. The "figs"
            # subdirectory name is an assumption.
            if (self.history["val_loss"]
                    and val_loss < min(self.history["val_loss"])):
                print(f"Epoch {epoch}: val loss improved to {val_loss:.4f}",
                      file=self.valtee)
            self.history["train_loss"].append(running_loss)
            self.history["val_loss"].append(val_loss)
            torch.save(
                {
                    "model": self.model.state_dict(),
                    "optimizer": self.optimizer.state_dict(),
                    "schedular": self.schedular.state_dict(),
                    "schedular1": self.schedular1.state_dict(),
                    "schedular2": self.schedular2.state_dict(),
                    "scaler": self.scaler.state_dict(),
                    "epoch": epoch,
                },
                self.configs["resume"],
            )
            self.plot_training_metrics(
                self.history, epoch + 1,
                os.path.join(self.configs["logdir"], "figs"),
            )
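

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original script). The config keys
# match exactly what MAETrainer reads above, but every value shown - paths,
# model sizes, hyperparameters - is an illustrative assumption, not a
# recommendation from the original author.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    example_configs = {
        "logdir": "./logs/mae",
        "dirsToMake": ["./checkpoints/mae"],
        "resume": "./checkpoints/mae/last.pt",
        # A torch.device is required (the training loop reads .type for autocast)
        "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        # Model shape (hypothetical values for 224x224 grayscale X-rays)
        "channels": 1,
        "img_size": 224,
        "patch_size": 16,
        "mask_ratio": 0.75,
        "dropout": 0.0,
        "encoder_dim": 768,
        "mlp_dim": 3072,
        "decoder_dim": 512,
        "encoder_depth": 12,
        "encoder_head": 12,
        "decoder_depth": 4,
        "decoder_head": 8,
        # Optimization
        "lr": 1.5e-4,
        "weight_decay": 0.05,
        "warmup": 5,
        "num_epochs": 100,
        "batch_size": 64,
        "accumulation": 4,
        # Data (hypothetical paths)
        "zip_path": "./CheXpert.zip",
        "datadir": "./data",
        "train_csv": "./data/train.csv",
        "val_csv": "./data/valid.csv",
        "test_csv": None,  # optional; the test loader is skipped when falsy
    }
    MAETrainer(example_configs).train()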