|
|
from data.dataset import CheXpertDataset
|
|
|
from loss.mae_loss import mae_loss
|
|
|
from models.mae import *
|
|
|
from torch.utils.data import DataLoader
|
|
|
import json
|
|
|
import os
|
|
|
import io
|
|
|
import sys
|
|
|
|
|
|
class TeeFile:
    """
    File-like object that writes to multiple streams (e.g., stdout and a file).

    String arguments are treated as paths: they are opened here in append
    mode (line-buffered) and are owned by this object, so ``close()`` closes
    them.  Stream objects supplied by the caller are written to but never
    closed.

    Usage:
        # Works with both file objects and paths
        tee = TeeFile(sys.stdout, "/path/to/log.txt")
        print("Hello", file=tee)  # Writes to both stdout and the file
    """

    def __init__(self, *file_objects_or_paths, encoding="utf-8"):
        """
        Args:
            *file_objects_or_paths: Mix of writable text streams (like
                ``sys.stdout``) and string paths to log files.
            encoding: Keyword-only text encoding for files opened from paths.
                UTF-8 by default so progress-bar glyphs and other non-ASCII
                text do not fail on platforms whose locale encoding cannot
                represent them.
        """
        self.files = []         # every destination, in the order given
        self.opened_files = []  # the subset opened (and owned) by this object
        try:
            for item in file_objects_or_paths:
                if isinstance(item, str):
                    # buffering=1 -> line-buffered text file
                    f = open(item, "a", buffering=1, encoding=encoding)
                    self.files.append(f)
                    self.opened_files.append(f)
                else:
                    self.files.append(item)
        except Exception:
            # Do not leak files already opened if a later path cannot be opened.
            self.close()
            raise

    def write(self, data):
        """Write ``data`` to every stream and return ``len(data)``.

        Each stream is flushed right away so log files stay current.  A
        stream that fails only produces a warning on ``sys.stderr``; the
        remaining streams are still written.
        """
        for f in self.files:
            try:
                f.write(data)
                f.flush()
            except Exception as e:
                print(f"Warning: Could not write to {f}: {e}", file=sys.stderr)
        # Text-stream ``write`` conventionally returns the number of characters written.
        return len(data)

    def flush(self):
        """Flush all streams; a stream that cannot be flushed is skipped."""
        for f in self.files:
            try:
                f.flush()
            except Exception:  # best-effort: one broken stream must not block the others
                pass

    def isatty(self):
        """Return True if any stream is a terminal (for tqdm compatibility)."""
        return any(getattr(f, "isatty", lambda: False)() for f in self.files)

    def fileno(self):
        """Return the file descriptor of the first stream that has one.

        Raises:
            io.UnsupportedOperation: if no stream exposes a usable descriptor.
        """
        for f in self.files:
            if hasattr(f, "fileno"):
                try:
                    return f.fileno()
                except Exception:
                    pass
        raise io.UnsupportedOperation("No fileno available")

    def close(self):
        """Close the files this object opened; caller-supplied streams stay open.

        The closed files are also removed from ``self.files`` so that later
        writes keep going to the remaining streams without closed-file warnings.
        """
        for f in self.opened_files:
            try:
                f.close()
            except Exception:  # best-effort cleanup
                pass
        owned = {id(f) for f in self.opened_files}
        self.files[:] = [f for f in self.files if id(f) not in owned]
        self.opened_files.clear()

    def __del__(self):
        """Cleanup on deletion."""
        if hasattr(self, "opened_files"):
            self.close()

    def __enter__(self):
        """Context manager support."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close owned files; exceptions are not suppressed."""
        self.close()
        return False
|
|
|
|
|
|
class MAETrainer:
    """Masked-autoencoder pre-training loop for CheXpert images.

    ``configs`` is a plain dict.  Keys read by this class: ``logdir``,
    ``dirsToMake``, ``device``, ``channels``, ``mask_ratio``, ``dropout``,
    ``img_size``, ``encoder_dim``, ``mlp_dim``, ``decoder_dim``,
    ``encoder_depth``, ``encoder_head``, ``decoder_depth``,
    ``decoder_head``, ``patch_size``, ``lr``, ``weight_decay``, ``warmup``,
    ``num_epochs``, ``zip_path``, ``train_csv``, ``val_csv``, ``datadir``,
    ``batch_size``, ``accumulation``, ``resume``; optional ``test_csv`` and
    ``num_workers`` (default 8).

    ``warmup`` and ``num_epochs`` are both epoch counts (the cosine phase
    lasts ``num_epochs - warmup``), so the LR scheduler is stepped once per
    epoch.
    """

    def __init__(self, configs=None):
        # ``None`` instead of a shared mutable ``{}`` default.
        configs = {} if configs is None else configs
        self.configs = configs

        # The CUDA caching allocator reads this variable when it first
        # allocates, so it must be set before the model is moved to the GPU.
        # setdefault keeps a value the user exported explicitly.
        os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.enabled = True

        # torch.autocast needs ``device.type``; this also accepts "cuda"/"cpu" strings.
        self.device = torch.device(configs["device"])

        os.makedirs(configs["logdir"], exist_ok=True)
        log_path_train = os.path.join(configs["logdir"], "training_log.txt")
        log_path_val = os.path.join(configs["logdir"], "val_log.txt")
        log_path_test = os.path.join(configs["logdir"], "test_log.txt")
        self.traintee = TeeFile(sys.stdout, log_path_train)
        self.valtee = TeeFile(sys.stdout, log_path_val)
        self.testtee = TeeFile(sys.stdout, log_path_test)

        for directory in configs["dirsToMake"]:
            os.makedirs(directory, exist_ok=True)

        self.model = MaskedAutoEncoder(
            c=configs["channels"],
            mask_ratio=configs["mask_ratio"],
            dropout=configs["dropout"],
            img_size=configs["img_size"],
            encoder_dim=configs["encoder_dim"],
            mlp_dim=configs["mlp_dim"],
            decoder_dim=configs["decoder_dim"],
            encoder_depth=configs["encoder_depth"],
            encoder_head=configs["encoder_head"],
            decoder_depth=configs["decoder_depth"],
            decoder_head=configs["decoder_head"],
            patch_size=configs["patch_size"],
        ).to(self.device)

        self.criterion = mae_loss

        self.optimizer = torch.optim.AdamW(
            self.model.parameters(), configs["lr"], weight_decay=configs["weight_decay"]
        )
        # Linear warm-up for ``warmup`` epochs, then cosine decay for the rest.
        self.schedular1 = torch.optim.lr_scheduler.LinearLR(
            self.optimizer, start_factor=0.1, end_factor=1.0, total_iters=configs["warmup"]
        )
        self.schedular2 = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, T_max=configs["num_epochs"] - configs["warmup"]
        )
        self.schedular = torch.optim.lr_scheduler.SequentialLR(
            self.optimizer,
            schedulers=[self.schedular1, self.schedular2],
            milestones=[configs["warmup"]],
        )
        self.scaler = torch.amp.GradScaler()

        self.train_dataset = CheXpertDataset(
            zip_path=configs["zip_path"],
            csv_path=configs["train_csv"],
            root_dir=configs["datadir"],
            augment=True,
            use_frontal_only=True,
        )
        self.val_dataset = CheXpertDataset(
            zip_path=configs["zip_path"],
            csv_path=configs["val_csv"],
            root_dir=configs["datadir"],
            augment=False,
            use_frontal_only=True,
        )
        self.class_Weights = self.train_dataset.get_class_weights().to(self.device)
        self.sample_Weights = self.train_dataset.get_sample_weights()
        self.sampler = torch.utils.data.WeightedRandomSampler(
            self.sample_Weights, num_samples=len(self.sample_Weights)
        )
        self.trainloader = self._make_loader(self.train_dataset, sampler=self.sampler)
        self.valloader = self._make_loader(self.val_dataset, shuffle=False)

        self.history = {"train_loss": [], "val_loss": []}
        self.best_val_loss = float("inf")
        self.current_epoch = 0  # index of the next epoch train() will run
        if os.path.exists(configs["resume"]):
            self._load_checkpoint(configs["resume"])

        self.test_dataset = None
        self.testloader = None
        if configs.get("test_csv"):
            self.test_dataset = CheXpertDataset(
                zip_path=configs["zip_path"],
                csv_path=configs["test_csv"],
                root_dir=configs["datadir"],
                augment=False,
                use_frontal_only=True,
            )
            self.testloader = self._make_loader(self.test_dataset, shuffle=False)
            print(f"Test loader ready – {len(self.test_dataset)} images")

        if hasattr(self.model, "enable_gradient_checkpointing"):
            self.model.enable_gradient_checkpointing()

    def _make_loader(self, dataset, **kwargs):
        """Build a DataLoader with the shared batch-size / worker settings."""
        num_workers = self.configs.get("num_workers", 8)
        return DataLoader(
            dataset,
            batch_size=self.configs["batch_size"],
            num_workers=num_workers,
            pin_memory=True,
            # DataLoader rejects persistent_workers=True when num_workers == 0.
            persistent_workers=num_workers > 0,
            **kwargs,
        )

    def _load_checkpoint(self, path):
        """Restore model/optimizer/scheduler/scaler state and progress from ``path``."""
        ckpt = torch.load(path, map_location=self.device)
        self.model.load_state_dict(ckpt["model"], strict=False)
        self.optimizer.load_state_dict(ckpt["optimizer"])
        # NOTE(review): checkpoints from the earlier per-batch scheduler stepping
        # store a step count, not an epoch count, in these scheduler states.
        self.schedular.load_state_dict(ckpt["schedular"])
        self.schedular1.load_state_dict(ckpt["schedular1"])
        self.schedular2.load_state_dict(ckpt["schedular2"])
        self.scaler.load_state_dict(ckpt["scaler"])
        self.current_epoch = ckpt["epoch"] + 1
        # "history" / "best_val_loss" are absent from older checkpoints.
        history = ckpt.get("history")
        if history:
            self.history = {k: list(history.get(k, [])) for k in ("train_loss", "val_loss")}
        self.best_val_loss = ckpt.get(
            "best_val_loss", min(self.history["val_loss"], default=float("inf"))
        )

    def _save_checkpoint(self, epoch):
        """Write the full training state for ``epoch`` to ``configs["resume"]``.

        The file is written to a temporary name first and then renamed, so an
        interrupted save cannot corrupt the previous best checkpoint.
        """
        checkpoint = {
            "model": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "schedular": self.schedular.state_dict(),
            "schedular1": self.schedular1.state_dict(),
            "schedular2": self.schedular2.state_dict(),
            "scaler": self.scaler.state_dict(),
            "epoch": epoch,
            "best_val_loss": self.best_val_loss,
            "history": {k: list(v) for k, v in self.history.items()},
        }
        path = self.configs["resume"]
        tmp_path = path + ".tmp"
        torch.save(checkpoint, tmp_path)
        os.replace(tmp_path, path)

    def _save_history(self):
        """Overwrite ``<logdir>/history.json`` with ``self.history``.

        ``self.history`` is the complete record for this trainer (restored from
        the checkpoint on resume when the checkpoint carries it), so no merge
        with an existing file is needed.
        """
        historyfile = os.path.join(self.configs["logdir"], "history.json")
        with open(historyfile, "w") as f:
            json.dump(self.history, f)

    @staticmethod
    def plot_training_metrics(metrics, epoch, figs_path):
        """
        Plot train/validation loss curves and save them as
        ``<figs_path>/<epoch>/metrics.png``.

        Args:
            metrics (dict): May contain "train_loss" and/or "val_loss" lists,
                one value per epoch.  Series of different lengths are
                truncated to the shortest so they share the x-axis.
            epoch (int): Used only to name the output sub-directory.
            figs_path (str): Root directory for saved figures.

        Nothing is drawn (and matplotlib is not imported) when there is no
        data to plot.
        """
        styles = {"train_loss": ("Train Loss", "o"), "val_loss": ("Val Loss", "s")}
        keys = [k for k in styles if k in metrics]
        n = min((len(metrics[k]) for k in keys), default=0)
        if n == 0:
            return

        import matplotlib.pyplot as plt  # deferred until there is something to draw

        epochs = list(range(1, n + 1))
        fig, ax = plt.subplots(figsize=(10, 6))
        for key in keys:
            label, marker = styles[key]
            ax.plot(epochs, metrics[key][:n], label=label, marker=marker)
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Loss")
        ax.set_title("Training & Validation Loss")
        ax.legend()
        ax.grid(True, linestyle='--', alpha=0.6)
        fig.tight_layout()

        out_dir = os.path.join(figs_path, str(epoch))
        os.makedirs(out_dir, exist_ok=True)
        fig.savefig(os.path.join(out_dir, "metrics.png"))
        plt.show()
        # Release the figure; otherwise one figure accumulates per call.
        plt.close(fig)

    def train_epoch(self, epoch, looper):
        """Run one optimisation pass over the training loader.

        Args:
            epoch: Current epoch index (only shown in the progress bar).
            looper: Iterable of ``(batch_idx, batch)`` pairs (a tqdm over
                ``enumerate(self.trainloader)``); each batch is a dict with an
                ``"image"`` entry.

        Returns:
            Mean loss over the batches whose loss was finite (0 if none).
        """
        import math

        self.model.train()
        device = self.device
        accumulation = self.configs["accumulation"]
        total_batches = len(self.trainloader)
        running_loss = 0.0
        finite_batches = 0
        current_loss = 0

        for batch_idx, data in looper:
            image = data['image'].to(device, non_blocking=True)

            with torch.autocast(device_type=device.type, dtype=torch.float16):
                img, preds, mask = self.model(image)
                loss = self.criterion(img, preds, mask)

            loss_value = loss.item()
            if not math.isfinite(loss_value):
                # Skip the batch so a NaN/Inf neither reaches the weights nor
                # poisons the running average.
                self.optimizer.zero_grad(set_to_none=True)
                continue
            running_loss += loss_value
            finite_batches += 1

            self.scaler.scale(loss / accumulation).backward()

            if (batch_idx + 1) % accumulation == 0 or batch_idx == total_batches - 1:
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                self.scaler.step(self.optimizer)
                self.scaler.update()
                # The LR scheduler is epoch-based and is stepped in train().
                self.optimizer.zero_grad(set_to_none=True)

            current_loss = running_loss / finite_batches
            if (batch_idx + 1) % 10 == 0:
                current_lr = self.optimizer.param_groups[0]['lr']
                looper.set_postfix({
                    "lr": f"{current_lr:.2e}", "batch": f"{batch_idx}/{total_batches}",
                    "epoch": f"{epoch}/{self.configs['num_epochs']}",
                    "loss": f"{current_loss:.3f}",
                })

        return current_loss

    def validate(self, epoch, looper):
        """Compute the mean validation loss without updating weights.

        Args:
            epoch: Current epoch index (only shown in the progress bar).
            looper: Iterable of ``(batch_idx, batch)`` pairs over the
                validation loader.

        Returns:
            Mean loss over all validation batches (0 if the loader is empty).
        """
        self.model.eval()
        device = self.device
        lenloader = len(self.valloader)
        val_loss = 0.0
        current_loss = 0
        with torch.no_grad():
            for batch_idx, data in looper:
                image = data["image"].to(device, non_blocking=True)

                with torch.autocast(device_type=device.type, dtype=torch.float16):
                    img, preds, mask = self.model(image)
                    loss = self.criterion(img, preds, mask)

                val_loss += loss.item()
                current_loss = val_loss / (batch_idx + 1)
                if (batch_idx + 1) % 10 == 0:
                    looper.set_postfix({
                        "epoch": f"{epoch}/{self.configs['num_epochs']}",
                        "batch": f"{batch_idx}/{lenloader}",
                        "loss": f"{current_loss:.3f}",
                    })

        return current_loss

    def train(self):
        """Train from ``self.current_epoch`` up to ``configs["num_epochs"]``.

        After each epoch: the LR schedule advances one step, both losses are
        appended to ``self.history``, and the full state is checkpointed to
        ``configs["resume"]`` whenever the validation loss beats the best seen
        so far.  ``history.json`` and the loss plot are refreshed every 10
        epochs and after the final epoch.
        """
        import gc

        num_epochs = self.configs["num_epochs"]
        use_cuda = self.device.type == "cuda"

        for epoch in range(self.current_epoch, num_epochs):
            trainlooper = tqdm(enumerate(self.trainloader), total=len(self.trainloader),
                               desc="training: ", leave=False, file=self.traintee)
            self.optimizer.zero_grad(set_to_none=True)
            running_loss = self.train_epoch(epoch, trainlooper)
            # warmup and T_max are epoch counts, so the schedule advances per epoch.
            self.schedular.step()

            if use_cuda:
                torch.cuda.synchronize()
                torch.cuda.empty_cache()

            # Created only now so validation workers/bar don't start during training.
            vallooper = tqdm(enumerate(self.valloader), total=len(self.valloader),
                             desc="validating: ", leave=False, file=self.valtee)
            val_loss = self.validate(epoch, vallooper)

            if use_cuda:
                torch.cuda.synchronize()
                torch.cuda.empty_cache()
            gc.collect()

            is_best = val_loss < self.best_val_loss  # False for NaN
            self.history["train_loss"].append(float(running_loss))
            self.history["val_loss"].append(float(val_loss))
            if is_best:
                self.best_val_loss = float(val_loss)
                self._save_checkpoint(epoch)

            print(f"train loss {running_loss} val loss {val_loss}")

            if epoch % 10 == 0 or epoch == num_epochs - 1:
                self._save_history()
                MAETrainer.plot_training_metrics(self.history, epoch + 1, self.configs["logdir"])

            # Same meaning as in __init__: the next epoch to run.
            self.current_epoch = epoch + 1