import logging
import math
import os
import random
import time
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import Subset, random_split
from tqdm.auto import tqdm
from transformers import AutoTokenizer

from .models.diffusion import DiffusionModel
from .models.text_encoder import MedicalTextEncoder
from .models.unet import DiffusionUNet
from .models.vae import MedicalVAE
from .utils.dataset import ChestXrayDataset
from .utils.processing import set_seed, get_device, log_gpu_memory, create_transforms

logger = logging.getLogger(__name__)


class EarlyStopping:
    """Early stopping implementation."""

    def __init__(self, patience=7, verbose=True, delta=0, path='checkpoint.pt'):
        """Initialize early stopping."""
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = float('inf')
        self.delta = delta
        self.path = path

    def __call__(self, val_loss, model=None):
        """Return True when validation loss has stopped improving for `patience` epochs."""
        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.verbose:
                logger.info(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
                return True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

        return False

    def save_checkpoint(self, val_loss, model):
        """Save a model checkpoint when validation loss improves."""
        if self.verbose:
            logger.info(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model...')
        if model is not None:
            torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss
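
# Illustrative use of EarlyStopping (a sketch; the loss values are made up):
#
#     stopper = EarlyStopping(patience=2, verbose=False, path='tmp_best.pt')
#     for val_loss in [1.00, 0.90, 0.95, 0.97]:
#         if stopper(val_loss):  # True once 2 epochs pass without improvement
#             break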


def create_lr_scheduler(optimizer, num_warmup_steps, num_training_steps, min_lr_ratio=0.1):
    """Create a learning rate scheduler with linear warmup and cosine decay."""
    def lr_lambda(current_step):
        # Linear warmup from 0 to the base learning rate.
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))

        # Cosine decay from the base learning rate down to min_lr_ratio of it.
        progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
        return max(min_lr_ratio, 0.5 * (1.0 + math.cos(math.pi * progress)))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
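
# Multiplier produced by lr_lambda (worked example with num_warmup_steps=100,
# num_training_steps=1000, min_lr_ratio=0.1):
#   step 0    -> 0.0   (warmup start)
#   step 50   -> 0.5   (halfway through warmup)
#   step 100  -> 1.0   (warmup done; cosine decay begins)
#   step 550  -> 0.5   (midpoint of the decay)
#   step 1000 -> 0.1   (clamped at min_lr_ratio)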


def save_checkpoint(model, optimizer, scheduler, epoch, global_step, best_metrics, checkpoint_dir, is_best=False):
    """Save a periodic checkpoint (every checkpoint_freq epochs) plus the best model."""
    os.makedirs(checkpoint_dir, exist_ok=True)

    if isinstance(model, dict):
        # VAE-only training: `model` is {'vae': MedicalVAE}.
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model['vae'].state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'best_metrics': best_metrics,
            'global_step': global_step
        }
    else:
        # Full diffusion model: save each component plus its configuration.
        checkpoint = {
            'epoch': epoch,
            'vae_state_dict': model.vae.state_dict(),
            'unet_state_dict': model.unet.state_dict(),
            'text_encoder_state_dict': model.text_encoder.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'best_metrics': best_metrics,
            'global_step': global_step,
            'config': {
                'latent_channels': model.vae.latent_channels,
                'model_channels': model.unet.model_channels,
                'scheduler_type': model.scheduler_type,
                'beta_schedule': model.beta_schedule,
                'prediction_type': model.prediction_type,
                'guidance_scale': model.guidance_scale,
                'num_train_timesteps': model.num_train_timesteps
            }
        }

    if scheduler is not None:
        checkpoint['scheduler_state_dict'] = scheduler.state_dict()

    if is_best:
        checkpoint_path = os.path.join(checkpoint_dir, "best_model.pt")
    else:
        checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch}.pt")

    torch.save(checkpoint, checkpoint_path)
    logger.info(f"Checkpoint saved to {checkpoint_path}")

    # Periodic checkpoints pile up; keep only the most recent few.
    if not is_best:
        cleanup_old_checkpoints(checkpoint_dir, keep_last_n=5)


def cleanup_old_checkpoints(checkpoint_dir, keep_last_n):
    """Remove old checkpoints, keeping only the most recent `keep_last_n`."""
    checkpoints = [f for f in os.listdir(checkpoint_dir) if f.startswith("checkpoint_epoch_")]

    if len(checkpoints) <= keep_last_n:
        return

    # Sort by the epoch number embedded in the filename, oldest first.
    checkpoints.sort(key=lambda x: int(x.split("_epoch_")[1].split(".")[0]))

    for old_ckpt in checkpoints[:-keep_last_n]:
        old_path = os.path.join(checkpoint_dir, old_ckpt)
        try:
            os.remove(old_path)
            logger.info(f"Removed old checkpoint: {old_path}")
        except Exception as e:
            logger.error(f"Failed to remove old checkpoint {old_path}: {e}")


def load_checkpoint(model, optimizer, scheduler, path):
    """Load a checkpoint and return (epoch, global_step, best_metrics) to resume training."""
    if not os.path.exists(path):
        logger.info(f"No checkpoint found at {path}")
        return 0, 0, {'val_loss': float('inf')}

    logger.info(f"Loading checkpoint from {path}")
    checkpoint = torch.load(path, map_location='cuda' if torch.cuda.is_available() else 'cpu')

    if isinstance(model, dict):
        # VAE-only checkpoint: `model` is {'vae': MedicalVAE}.
        model['vae'].load_state_dict(checkpoint['model_state_dict'])
    else:
        # Full diffusion checkpoint: restore each component.
        model.vae.load_state_dict(checkpoint['vae_state_dict'])
        model.unet.load_state_dict(checkpoint['unet_state_dict'])
        model.text_encoder.load_state_dict(checkpoint['text_encoder_state_dict'])

    if optimizer is not None and 'optimizer_state_dict' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    if scheduler is not None and 'scheduler_state_dict' in checkpoint:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

    epoch = checkpoint.get('epoch', 0)
    global_step = checkpoint.get('global_step', 0)
    best_metrics = checkpoint.get('best_metrics', {'val_loss': float('inf')})

    logger.info(f"Loaded checkpoint from epoch {epoch}")

    return epoch, global_step, best_metrics
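
# Typical resume flow (a sketch; assumes `model`, `optimizer` and `scheduler`
# were already built with the same architecture as the checkpoint):
#
#     start_epoch, global_step, best_metrics = load_checkpoint(
#         model, optimizer, scheduler, "./outputs/checkpoints/diffusion/best_model.pt"
#     )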


def visualize_epoch_results(epoch, model, tokenizer, val_loader, output_dir):
    """Generate and save visualization samples after each epoch."""
    from PIL import Image

    samples_dir = os.path.join(output_dir, "visualizations", f"epoch_{epoch+1}")
    os.makedirs(samples_dir, exist_ok=True)

    # Save a couple of real images, their VAE reconstructions, and the paired reports.
    try:
        val_batch = next(iter(val_loader))

        batch_size = min(2, len(val_batch['image']))
        indices = random.sample(range(len(val_batch['image'])), batch_size)

        for i, idx in enumerate(indices):
            img = val_batch['image'][idx].unsqueeze(0)
            if isinstance(model, dict):
                device = next(model['vae'].parameters()).device
                img = img.to(device)
                vae = model['vae']
            else:
                img = img.to(model.device)
                vae = model.vae

            report = val_batch['report'][idx]

            # Undo the [-1, 1] normalization and save the real image.
            img_np = img.squeeze(0).cpu().numpy().transpose(1, 2, 0)
            img_np = (img_np * 0.5 + 0.5) * 255
            if img_np.shape[-1] == 1:
                img_np = img_np.squeeze(-1)
            img_path = os.path.join(samples_dir, f"real_{i+1}.png")
            Image.fromarray(img_np.astype(np.uint8)).save(img_path)

            # Reconstruct through the VAE and save the result.
            with torch.no_grad():
                recon, _, _ = vae(img)

            recon_np = recon.squeeze(0).cpu().numpy().transpose(1, 2, 0)
            recon_np = (recon_np * 0.5 + 0.5) * 255
            if recon_np.shape[-1] == 1:
                recon_np = recon_np.squeeze(-1)
            recon_path = os.path.join(samples_dir, f"recon_{i+1}.png")
            Image.fromarray(recon_np.astype(np.uint8)).save(recon_path)

            # Save the paired report text.
            report_path = os.path.join(samples_dir, f"report_{i+1}.txt")
            with open(report_path, "w") as f:
                f.write(report)
    except Exception as e:
        logger.error(f"Error generating real samples: {e}")

    # For the full diffusion model, also generate images from text prompts.
    if not isinstance(model, dict) and tokenizer is not None:
        try:
            sample_prompts = [
                "Normal chest X-ray with clear lungs and no abnormalities.",
                "Right lower lobe pneumonia with focal consolidation."
            ]

            model.vae.eval()
            model.text_encoder.eval()
            model.unet.eval()

            with torch.no_grad():
                for i, prompt in enumerate(sample_prompts):
                    results = model.sample(
                        prompt,
                        height=256,
                        width=256,
                        num_inference_steps=30,
                        tokenizer=tokenizer
                    )

                    # Sampler output is scaled straight to 8-bit (assumed to be in [0, 1]).
                    img = results['images'][0]
                    img_np = img.cpu().numpy().transpose(1, 2, 0)
                    img_np = img_np * 255
                    if img_np.shape[-1] == 1:
                        img_np = img_np.squeeze(-1)
                    img_path = os.path.join(samples_dir, f"gen_{i+1}.png")
                    Image.fromarray(img_np.astype(np.uint8)).save(img_path)

                    prompt_path = os.path.join(samples_dir, f"prompt_{i+1}.txt")
                    with open(prompt_path, "w") as f:
                        f.write(prompt)
        except Exception as e:
            logger.error(f"Error generating samples from prompts: {e}")

    logger.info(f"Saved visualization for epoch {epoch+1} to {samples_dir}")


def create_quick_test_dataset(dataset, percentage=0.01):
    """Create a small random subset of a dataset for quick testing."""
    indices = random.sample(range(len(dataset)), int(len(dataset) * percentage))
    logger.info(f"Using {len(indices)} samples out of {len(dataset)} ({percentage*100:.1f}%)")
    return Subset(dataset, indices)
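
# e.g. create_quick_test_dataset(train_set, percentage=0.05) keeps a random ~5%
# of samples; with the default 0.01 a 10,000-image dataset shrinks to 100.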


def train(
    config: Dict,
    dataset_path: str,
    reports_csv: str,
    projections_csv: str,
    output_dir: str = "./outputs",
    resume_from: Optional[str] = None,
    train_vae_only: bool = False,
    seed: int = 42,
    quick_test: bool = False
):
    """
    Train the chest X-ray diffusion model.

    Args:
        config: Configuration dictionary with model and training parameters
        dataset_path: Path to the X-ray image directory
        reports_csv: Path to the reports CSV file
        projections_csv: Path to the projections CSV file
        output_dir: Path to save outputs
        resume_from: Path to a checkpoint to resume training from
        train_vae_only: Whether to train only the VAE component
        seed: Random seed for reproducibility
        quick_test: Whether to run a quick test with reduced settings
    """
    if quick_test:
        logger.warning("⚠️ RUNNING IN TEST MODE - QUICK TEST WITH 1% OF DATA AND REDUCED SETTINGS ⚠️")

        quick_config = config.copy()
        quick_config["batch_size"] = min(config.get("batch_size", 4), 2)
        quick_config["epochs"] = min(config.get("epochs", 100), 2)
        quick_config["num_workers"] = 0
        config = quick_config

    # Pull training hyperparameters from the config, with defaults.
    batch_size = config.get('batch_size', 4)
    num_workers = config.get('num_workers', 0)
    epochs = config.get('epochs', 100)
    learning_rate = config.get('learning_rate', 1e-4)
    latent_channels = config.get('latent_channels', 8)
    model_channels = config.get('model_channels', 48)
    image_size = config.get('image_size', 256)
    gradient_accumulation_steps = config.get('gradient_accumulation_steps', 4)
    use_amp = config.get('use_amp', True)
    checkpoint_freq = config.get('checkpoint_freq', 5)
    tokenizer_name = config.get('tokenizer_name', "dmis-lab/biobert-base-cased-v1.1")

    set_seed(seed)
    device = get_device()

    os.makedirs(output_dir, exist_ok=True)

    if train_vae_only:
        checkpoint_dir = os.path.join(output_dir, "checkpoints", "vae")
    else:
        checkpoint_dir = os.path.join(output_dir, "checkpoints", "diffusion")

    os.makedirs(checkpoint_dir, exist_ok=True)

    transforms = create_transforms(image_size)
    logger.info(f"Creating dataset from {dataset_path}")

    dataset = ChestXrayDataset(
        reports_csv=reports_csv,
        projections_csv=projections_csv,
        image_folder=dataset_path,
        transform=None,
        target_size=(image_size, image_size),
        filter_frontal=True,
        tokenizer_name=tokenizer_name,
        max_length=256,
        use_clahe=True
    )

    if quick_test:
        dataset = create_quick_test_dataset(dataset, percentage=0.01)

    # 80/10/10 train/val/test split.
    dataset_size = len(dataset)
    val_size = int(0.1 * dataset_size)
    test_size = int(0.1 * dataset_size)
    train_size = dataset_size - val_size - test_size

    generator = torch.Generator().manual_seed(seed)
    train_dataset, val_dataset, test_dataset = random_split(
        dataset, [train_size, val_size, test_size], generator=generator
    )

    train_transform, val_transform = transforms

    # NOTE: monkey-patching `subset.__getitem__` does not work here, because
    # Python looks dunder methods up on the class rather than the instance, so
    # a patched `__getitem__` on a `Subset` object is never called by the
    # DataLoader. Wrap each split in a small Dataset instead.
    from torch.utils.data import Dataset

    class TransformSubset(Dataset):
        """Apply a transform to the images of an underlying dataset split."""

        def __init__(self, subset, transform):
            self.subset = subset
            self.transform = transform

        def __getitem__(self, idx):
            item = self.subset[idx]
            if self.transform and 'image' in item and item['image'] is not None:
                item = dict(item)
                item['image'] = self.transform(item['image'])
            return item

        def __len__(self):
            return len(self.subset)

    train_dataset = TransformSubset(train_dataset, train_transform)
    val_dataset = TransformSubset(val_dataset, val_transform)
    test_dataset = TransformSubset(test_dataset, val_transform)

    from torch.utils.data import DataLoader
    from .utils.processing import custom_collate_fn

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=True,
        worker_init_fn=lambda worker_id: np.random.seed(seed + worker_id),
        collate_fn=custom_collate_fn
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=False,
        collate_fn=custom_collate_fn
    )

    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=False,
        collate_fn=custom_collate_fn
    )

    logger.info("Initializing models")

    vae = MedicalVAE(
        in_channels=1,
        out_channels=1,
        latent_channels=latent_channels,
        hidden_dims=[model_channels, model_channels*2, model_channels*4, model_channels*8]
    ).to(device)

    if train_vae_only:
        optimizer = AdamW(vae.parameters(), lr=learning_rate, weight_decay=1e-6)

        start_epoch = 0
        global_step = 0
        best_metrics = {'val_loss': float('inf')}

        if resume_from and os.path.exists(resume_from):
            start_epoch, global_step, best_metrics = load_checkpoint(
                {'vae': vae}, optimizer, None, resume_from
            )
            logger.info(f"Resumed VAE training from epoch {start_epoch}")

        # The optimizer steps once per accumulation window, so the scheduler
        # horizon is the number of optimizer steps, not the number of batches.
        total_steps = len(train_loader) * epochs // gradient_accumulation_steps
        warmup_steps = int(0.1 * total_steps)
        scheduler = create_lr_scheduler(optimizer, warmup_steps, total_steps)

        vae_trainer = VAETrainer(
            model=vae,
            train_loader=train_loader,
            val_loader=val_loader,
            optimizer=optimizer,
            scheduler=scheduler,
            device=device,
            config=config
        )

        best_model = vae_trainer.train(
            num_epochs=epochs,
            checkpoint_dir=checkpoint_dir,
            start_epoch=start_epoch,
            global_step=global_step,
            best_metrics=best_metrics
        )

        logger.info("VAE training complete")
        return best_model

    else:
        text_encoder = MedicalTextEncoder(
            model_name=tokenizer_name,
            projection_dim=768,
            freeze_base=True
        ).to(device)

        unet = DiffusionUNet(
            in_channels=latent_channels,
            model_channels=model_channels,
            out_channels=latent_channels,
            num_res_blocks=2,
            attention_resolutions=(8, 16, 32),
            dropout=0.1,
            channel_mult=(1, 2, 4, 8),
            context_dim=768
        ).to(device)

        diffusion_model = DiffusionModel(
            vae=vae,
            unet=unet,
            text_encoder=text_encoder,
            scheduler_type=config.get('scheduler_type', "ddim"),
            num_train_timesteps=config.get('num_train_timesteps', 1000),
            beta_schedule=config.get('beta_schedule', "linear"),
            prediction_type=config.get('prediction_type', "epsilon"),
            guidance_scale=config.get('guidance_scale', 7.5),
            device=device
        )

        # By default only the UNet is optimized; the VAE and text encoder stay frozen.
        train_unet_only = config.get('train_unet_only', True)

        if train_unet_only:
            optimizer = AdamW(unet.parameters(), lr=learning_rate, weight_decay=1e-6)
        else:
            parameters = list(unet.parameters())
            parameters.extend(vae.parameters())
            parameters.extend(text_encoder.parameters())
            optimizer = AdamW(parameters, lr=learning_rate, weight_decay=1e-6)

        start_epoch = 0
        global_step = 0
        best_metrics = {'val_loss': float('inf')}

        if resume_from and os.path.exists(resume_from):
            start_epoch, global_step, best_metrics = load_checkpoint(
                diffusion_model, optimizer, None, resume_from
            )
            logger.info(f"Resumed diffusion training from epoch {start_epoch}")

        try:
            tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
            logger.info(f"Loaded tokenizer: {tokenizer_name}")
        except Exception as e:
            logger.error(f"Error loading tokenizer: {e}")
            logger.warning("Will not generate samples during training")
            tokenizer = None

        total_steps = len(train_loader) * epochs
        warmup_steps = int(0.1 * total_steps)
        scheduler = create_lr_scheduler(optimizer, warmup_steps, total_steps)

        diffusion_trainer = DiffusionTrainer(
            model=diffusion_model,
            train_loader=train_loader,
            val_loader=val_loader,
            optimizer=optimizer,
            scheduler=scheduler,
            tokenizer=tokenizer,
            device=device,
            config=config
        )

        trained_model = diffusion_trainer.train(
            num_epochs=epochs,
            checkpoint_dir=checkpoint_dir,
            train_unet_only=train_unet_only,
            start_epoch=start_epoch,
            global_step=global_step,
            best_metrics=best_metrics
        )

        logger.info("Diffusion model training complete")
        return trained_model
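
# Minimal invocation sketch (illustrative; the paths below are placeholders and
# the config keys mirror the defaults read above):
#
#     if __name__ == "__main__":
#         logging.basicConfig(level=logging.INFO)
#         train(
#             config={"batch_size": 4, "epochs": 100, "learning_rate": 1e-4},
#             dataset_path="./data/images",
#             reports_csv="./data/reports.csv",
#             projections_csv="./data/projections.csv",
#             output_dir="./outputs",
#             train_vae_only=True,  # train the VAE first, then the diffusion model
#         )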


class VAETrainer:
    """Trainer for the VAE model."""

    def __init__(
        self,
        model,
        train_loader,
        val_loader,
        optimizer,
        scheduler=None,
        device=None,
        config=None
    ):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.device = device if device is not None else torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.config = config if config is not None else {}

        self.use_amp = self.config.get('use_amp', True)
        self.gradient_accumulation_steps = self.config.get('gradient_accumulation_steps', 4)
        self.checkpoint_freq = self.config.get('checkpoint_freq', 5)

        # Gradient scaler for mixed-precision training (CUDA only).
        self.scaler = torch.cuda.amp.GradScaler() if self.use_amp and torch.cuda.is_available() else None

    def vae_loss_fn(self, recon_x, x, mu, logvar, kld_weight=1e-4):
        """VAE loss: reconstruction MSE plus a KL-divergence regularizer."""
        # Reconstruction term.
        recon_loss = F.mse_loss(recon_x, x, reduction='mean')

        # Closed-form KL divergence between N(mu, exp(logvar)) and N(0, I).
        kld_loss = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())

        # The KL term is down-weighted so reconstruction quality dominates.
        loss = recon_loss + kld_weight * kld_loss

        return loss, recon_loss, kld_loss
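
    # Sanity check on the KL term (worked example): with mu = 0 and logvar = 0,
    # i.e. a posterior of exactly N(0, I), the term is
    # -0.5 * mean(1 + 0 - 0 - exp(0)) = -0.5 * mean(0) = 0,
    # so only the reconstruction loss contributes.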

    def train(
        self,
        num_epochs,
        checkpoint_dir,
        start_epoch=0,
        global_step=0,
        best_metrics=None
    ):
        """Train the VAE model and return the best state dict."""
        logger.info("Starting VAE training")

        best_loss = best_metrics.get('val_loss', float('inf')) if best_metrics else float('inf')
        best_model_state = None

        early_stopping_path = os.path.join(checkpoint_dir, "best_vae.pt")
        early_stopping = EarlyStopping(
            patience=5,
            verbose=True,
            path=early_stopping_path
        )

        for epoch in range(start_epoch, num_epochs):
            logger.info(f"Starting VAE epoch {epoch+1}/{num_epochs}")

            # --- Training phase ---
            self.model.train()
            train_loss = 0.0
            train_recon_loss = 0.0
            train_kld_loss = 0.0

            self.optimizer.zero_grad()

            progress_bar = tqdm(self.train_loader, desc=f"Epoch {epoch+1}/{num_epochs} (VAE Training)")
            for batch_idx, batch in enumerate(progress_bar):
                try:
                    images = batch['image'].to(self.device)

                    # Skip degenerate batches (e.g. BatchNorm needs more than one sample).
                    if images.shape[0] < 2:
                        logger.warning(f"Skipping batch with only {images.shape[0]} samples")
                        continue

                    if self.use_amp and torch.cuda.is_available():
                        with torch.cuda.amp.autocast():
                            recon, mu, logvar = self.model(images)
                            loss, recon_loss, kld_loss = self.vae_loss_fn(recon, images, mu, logvar)
                            # Scale the loss for gradient accumulation.
                            loss = loss / self.gradient_accumulation_steps

                        self.scaler.scale(loss).backward()

                        # Step the optimizer once per accumulation window.
                        if (batch_idx + 1) % self.gradient_accumulation_steps == 0 or batch_idx + 1 == len(self.train_loader):
                            self.scaler.unscale_(self.optimizer)
                            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                            self.scaler.step(self.optimizer)
                            self.scaler.update()
                            self.optimizer.zero_grad()

                            if self.scheduler:
                                self.scheduler.step()
                            global_step += 1
                    else:
                        recon, mu, logvar = self.model(images)
                        loss, recon_loss, kld_loss = self.vae_loss_fn(recon, images, mu, logvar)
                        # Scale the loss for gradient accumulation.
                        loss = loss / self.gradient_accumulation_steps

                        loss.backward()

                        # Step the optimizer once per accumulation window.
                        if (batch_idx + 1) % self.gradient_accumulation_steps == 0 or batch_idx + 1 == len(self.train_loader):
                            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                            self.optimizer.step()
                            self.optimizer.zero_grad()

                            if self.scheduler:
                                self.scheduler.step()
                            global_step += 1

                    # Undo the accumulation scaling when logging metrics.
                    train_loss += loss.item() * self.gradient_accumulation_steps
                    train_recon_loss += recon_loss.item()
                    train_kld_loss += kld_loss.item()

                    progress_bar.set_postfix({
                        'loss': f"{loss.item() * self.gradient_accumulation_steps:.4f}",
                        'recon': f"{recon_loss.item():.4f}",
                        'kld': f"{kld_loss.item():.4f}"
                    })

                except Exception as e:
                    logger.error(f"Error in VAE training batch {batch_idx}: {e}")
                    import traceback
                    logger.error(traceback.format_exc())
                    continue

            train_loss /= max(1, len(self.train_loader))
            train_recon_loss /= max(1, len(self.train_loader))
            train_kld_loss /= max(1, len(self.train_loader))

            # --- Validation phase ---
            self.model.eval()
            val_loss = 0.0
            val_recon_loss = 0.0
            val_kld_loss = 0.0

            with torch.no_grad():
                val_progress = tqdm(self.val_loader, desc=f"Epoch {epoch+1}/{num_epochs} (VAE Validation)")
                for batch_idx, batch in enumerate(val_progress):
                    try:
                        images = batch['image'].to(self.device)

                        if images.shape[0] < 2:
                            continue

                        recon, mu, logvar = self.model(images)
                        loss, recon_loss, kld_loss = self.vae_loss_fn(recon, images, mu, logvar)

                        val_loss += loss.item()
                        val_recon_loss += recon_loss.item()
                        val_kld_loss += kld_loss.item()

                    except Exception as e:
                        logger.error(f"Error in VAE validation: {e}")
                        continue

            val_loss /= max(1, len(self.val_loader))
            val_recon_loss /= max(1, len(self.val_loader))
            val_kld_loss /= max(1, len(self.val_loader))

            logger.info(f"VAE Epoch {epoch+1}/{num_epochs} | "
                        f"Train Loss: {train_loss:.4f} (Recon: {train_recon_loss:.4f}, KLD: {train_kld_loss:.4f}) | "
                        f"Val Loss: {val_loss:.4f} (Recon: {val_recon_loss:.4f}, KLD: {val_kld_loss:.4f})")

            # Track the best model both in memory and on disk.
            if val_loss < best_loss:
                best_loss = val_loss
                best_model_state = self.model.state_dict().copy()

                save_checkpoint(
                    {'vae': self.model},
                    self.optimizer,
                    self.scheduler,
                    epoch + 1,
                    global_step,
                    {'val_loss': val_loss},
                    checkpoint_dir,
                    is_best=True
                )

            # Periodic checkpoint.
            if (epoch + 1) % self.checkpoint_freq == 0:
                save_checkpoint(
                    {'vae': self.model},
                    self.optimizer,
                    self.scheduler,
                    epoch + 1,
                    global_step,
                    {'val_loss': val_loss},
                    checkpoint_dir,
                    is_best=False
                )

            if early_stopping(val_loss, self.model):
                logger.info(f"Early stopping triggered at epoch {epoch+1}")
                break

            # Save real-vs-reconstruction visualizations for this epoch.
            if Path(checkpoint_dir).exists():
                visualize_epoch_results(
                    epoch,
                    {"vae": self.model},
                    None,
                    self.val_loader,
                    checkpoint_dir
                )

        if best_model_state is not None:
            logger.info(f"VAE training complete. Best validation loss: {best_loss:.4f}")
            return best_model_state
        else:
            logger.warning("VAE training complete, but no best model state was saved.")
            return self.model.state_dict()


class DiffusionTrainer:
    """Trainer for the diffusion model."""

    def __init__(
        self,
        model,
        train_loader,
        val_loader,
        optimizer,
        scheduler=None,
        tokenizer=None,
        device=None,
        config=None
    ):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.tokenizer = tokenizer
        self.device = device if device is not None else torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.config = config if config is not None else {}

        self.use_amp = self.config.get('use_amp', True)
        self.checkpoint_freq = self.config.get('checkpoint_freq', 5)

        # Gradient scaler for mixed-precision training (CUDA only).
        self.scaler = torch.cuda.amp.GradScaler() if self.use_amp and torch.cuda.is_available() else None

    def train(
        self,
        num_epochs,
        checkpoint_dir,
        train_unet_only=True,
        start_epoch=0,
        global_step=0,
        best_metrics=None
    ):
        """Train the diffusion model."""
        logger.info("Starting diffusion model training")
        logger.info(f"Training {'UNet only' if train_unet_only else 'all components'}")

        # Smoke test: pull one batch and run a forward pass before committing
        # to a full epoch, so configuration errors surface immediately.
        logger.info("Testing diffusion dataloader by extracting first batch...")
        try:
            first_batch = next(iter(self.train_loader))
            logger.info("First batch loaded successfully")

            with torch.no_grad():
                loss, metrics = self.model.training_step(first_batch, train_unet_only)
                logger.info(f"Forward pass successful. Loss: {loss.item()}")

            del first_batch
            torch.cuda.empty_cache()
        except Exception as e:
            logger.error(f"Error in diffusion dataloader test: {e}")
            import traceback
            logger.error(traceback.format_exc())
            raise RuntimeError("Failed to test diffusion dataloader - check configuration")

        early_stopping_path = os.path.join(checkpoint_dir, "best_diffusion.pt")
        early_stopping = EarlyStopping(
            patience=8,
            verbose=True,
            path=early_stopping_path
        )

        best_loss = best_metrics.get('val_loss', float('inf')) if best_metrics else float('inf')

        for epoch in range(start_epoch, num_epochs):
            logger.info(f"Starting diffusion epoch {epoch+1}/{num_epochs}")

            # Frozen components stay in eval mode when only the UNet is trained.
            if train_unet_only:
                self.model.vae.eval()
                self.model.text_encoder.eval()
                self.model.unet.train()
            else:
                self.model.vae.train()
                self.model.text_encoder.train()
                self.model.unet.train()

            train_loss = 0.0
            train_diffusion_loss = 0.0
            train_vae_loss = 0.0

            processed_batches = 0

            progress_bar = tqdm(self.train_loader, desc=f"Epoch {epoch+1}/{num_epochs} (Training)")
            for batch_idx, batch in enumerate(progress_bar):
                try:
                    self.optimizer.zero_grad()

                    if self.use_amp and torch.cuda.is_available():
                        with torch.cuda.amp.autocast():
                            loss, metrics = self.model.training_step(batch, train_unet_only)

                        self.scaler.scale(loss).backward()

                        # Unscale before clipping so the gradient norm is measured correctly.
                        self.scaler.unscale_(self.optimizer)
                        if train_unet_only:
                            torch.nn.utils.clip_grad_norm_(self.model.unet.parameters(), max_norm=1.0)
                        else:
                            torch.nn.utils.clip_grad_norm_(
                                list(self.model.vae.parameters()) +
                                list(self.model.text_encoder.parameters()) +
                                list(self.model.unet.parameters()),
                                max_norm=1.0
                            )

                        self.scaler.step(self.optimizer)
                        self.scaler.update()
                    else:
                        loss, metrics = self.model.training_step(batch, train_unet_only)

                        loss.backward()

                        if train_unet_only:
                            torch.nn.utils.clip_grad_norm_(self.model.unet.parameters(), max_norm=1.0)
                        else:
                            torch.nn.utils.clip_grad_norm_(
                                list(self.model.vae.parameters()) +
                                list(self.model.text_encoder.parameters()) +
                                list(self.model.unet.parameters()),
                                max_norm=1.0
                            )

                        self.optimizer.step()

                    if self.scheduler:
                        self.scheduler.step()

                    global_step += 1

                    train_loss += metrics['total_loss']
                    train_diffusion_loss += metrics['diffusion_loss']
                    if 'vae_loss' in metrics:
                        train_vae_loss += metrics['vae_loss']

                    processed_batches += 1

                    progress_bar.set_postfix({
                        'loss': f"{metrics['total_loss']:.4f}",
                        'diff': f"{metrics['diffusion_loss']:.4f}",
                        'lr': f"{self.scheduler.get_last_lr()[0]:.6f}" if self.scheduler else "N/A"
                    })

                except Exception as e:
                    logger.error(f"Error in diffusion training batch {batch_idx}: {e}")
                    import traceback
                    logger.error(traceback.format_exc())
                    continue

            # Average over the batches that actually completed, since failed
            # batches are skipped rather than counted.
            train_loss /= max(1, processed_batches)
            train_diffusion_loss /= max(1, processed_batches)
            if not train_unet_only:
                train_vae_loss /= max(1, processed_batches)

            # --- Validation phase ---
            self.model.vae.eval()
            self.model.text_encoder.eval()
            self.model.unet.eval()

            val_loss = 0.0
            val_diffusion_loss = 0.0
            val_vae_loss = 0.0

            with torch.no_grad():
                val_progress = tqdm(self.val_loader, desc=f"Epoch {epoch+1}/{num_epochs} (Validation)")
                for batch_idx, batch in enumerate(val_progress):
                    try:
                        metrics = self.model.validation_step(batch)

                        val_loss += metrics['val_loss']
                        val_diffusion_loss += metrics['val_diffusion_loss']
                        val_vae_loss += metrics.get('val_vae_loss', 0.0)

                    except Exception as e:
                        logger.error(f"Error in diffusion validation batch {batch_idx}: {e}")
                        continue

            val_loss /= max(1, len(self.val_loader))
            val_diffusion_loss /= max(1, len(self.val_loader))
            val_vae_loss /= max(1, len(self.val_loader))

            # Save real/reconstructed/generated visualizations for this epoch.
            if Path(checkpoint_dir).exists() and self.tokenizer:
                visualize_epoch_results(
                    epoch,
                    self.model,
                    self.tokenizer,
                    self.val_loader,
                    checkpoint_dir
                )

            # Epoch summary.
            vae_loss_str = f", VAE: {train_vae_loss:.4f}/{val_vae_loss:.4f}" if not train_unet_only else ""
            logger.info(f"Epoch {epoch+1}/{num_epochs} | "
                        f"Train/Val Loss: {train_loss:.4f}/{val_loss:.4f} | "
                        f"Diff: {train_diffusion_loss:.4f}/{val_diffusion_loss:.4f}"
                        f"{vae_loss_str}")

            # Periodic checkpoint.
            if (epoch + 1) % self.checkpoint_freq == 0 or epoch == num_epochs - 1:
                metrics = {
                    'train_loss': train_loss,
                    'train_diffusion_loss': train_diffusion_loss,
                    'val_loss': val_loss,
                    'val_diffusion_loss': val_diffusion_loss
                }

                save_checkpoint(
                    self.model,
                    self.optimizer,
                    self.scheduler,
                    epoch + 1,
                    global_step,
                    metrics,
                    checkpoint_dir,
                    is_best=False
                )

            # Best-model checkpoint.
            if val_loss < best_loss:
                best_loss = val_loss

                metrics = {
                    'train_loss': train_loss,
                    'train_diffusion_loss': train_diffusion_loss,
                    'val_loss': val_loss,
                    'val_diffusion_loss': val_diffusion_loss
                }

                save_checkpoint(
                    self.model,
                    self.optimizer,
                    self.scheduler,
                    epoch + 1,
                    global_step,
                    metrics,
                    checkpoint_dir,
                    is_best=True
                )
                logger.info(f"New best model saved with val_loss={val_loss:.4f}")

            # Every 10 epochs (and at the end), generate sample images from prompts.
            if self.tokenizer is not None and ((epoch + 1) % 10 == 0 or epoch == num_epochs - 1):
                try:
                    sample_prompts = [
                        "Normal chest X-ray with clear lungs and no abnormalities.",
                        "Right lower lobe pneumonia with focal consolidation."
                    ]

                    logger.info("Generating sample images...")

                    self.model.vae.eval()
                    self.model.text_encoder.eval()
                    self.model.unet.eval()
                    samples_dir = os.path.join(checkpoint_dir, "samples")
                    os.makedirs(samples_dir, exist_ok=True)

                    from PIL import Image

                    with torch.no_grad():
                        for i, prompt in enumerate(sample_prompts):
                            results = self.model.sample(
                                prompt,
                                height=256,
                                width=256,
                                num_inference_steps=30,
                                tokenizer=self.tokenizer
                            )

                            img = results['images'][0]
                            img_np = img.cpu().numpy().transpose(1, 2, 0)
                            img_np = (img_np * 255).astype(np.uint8)
                            if img_np.shape[-1] == 1:
                                img_np = img_np.squeeze(-1)

                            img_path = os.path.join(samples_dir, f"sample_epoch{epoch+1}_{i}.png")
                            Image.fromarray(img_np).save(img_path)

                    logger.info(f"Saved sample images to {samples_dir}")

                except Exception as e:
                    logger.error(f"Error generating samples: {e}")

            if early_stopping(val_loss):
                logger.info(f"Early stopping triggered at epoch {epoch+1}")
                break

        # Restore the best checkpoint, if one was written, before returning.
        best_path = os.path.join(checkpoint_dir, "best_model.pt")
        if os.path.exists(best_path):
            _, _, _ = load_checkpoint(self.model, None, None, best_path)
            logger.info("Loaded best model from saved checkpoint")

        logger.info("Diffusion model training complete")

        return self.model