# Provenance note (web-scrape residue from the model-hub page, commented out so
# the file remains valid Python): uploaded by "pyamy" — "Upload 31 files" — commit 0a0f923 (verified).
# xray_generator/models/diffusion.py
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import logging
from tqdm.auto import tqdm
# Module-level logger named after this module, so log output can be filtered hierarchically.
logger = logging.getLogger(__name__)
def extract_into_tensor(a, t, shape):
    """Gather per-timestep schedule values and make them broadcastable.

    Args:
        a: 1-D schedule tensor (or array-like) indexed by timestep.
        t: Long tensor of timestep indices, shape (batch,).
        shape: Target shape to expand the gathered values to.

    Returns:
        Tensor of ``shape`` where each batch entry holds ``a[t[i]]``
        replicated across all trailing dimensions.
    """
    table = a if isinstance(a, torch.Tensor) else torch.tensor(a, dtype=torch.float32)
    table = table.to(t.device)
    gathered = table.gather(-1, t)
    # Append singleton dims until the rank matches the target, then expand.
    missing = len(shape) - gathered.dim()
    gathered = gathered.reshape(gathered.shape + (1,) * missing)
    return gathered.expand(shape)
def get_named_beta_schedule(schedule_type, num_diffusion_steps):
    """Return a 1-D float32 tensor of betas for the named noise schedule.

    Supported names:
        linear        -- Ho et al. (DDPM); endpoints rescaled so any step
                         count behaves like the reference 1000-step schedule.
        cosine        -- Nichol & Dhariwal (Improved DDPM), derived from a
                         squared-cosine alpha-bar curve.
        scaled_linear -- linear in sqrt(beta) space (Stable Diffusion style).

    Raises:
        ValueError: if ``schedule_type`` is not one of the names above.
    """
    if schedule_type == "linear":
        # Rescale the classic 1000-step endpoints to this step count.
        scale = 1000 / num_diffusion_steps
        return torch.linspace(
            scale * 0.0001, scale * 0.02, num_diffusion_steps, dtype=torch.float32
        )
    if schedule_type == "cosine":
        # beta_t = 1 - alpha_bar(t) / alpha_bar(t-1) on a squared-cosine curve.
        grid = torch.linspace(0, num_diffusion_steps, num_diffusion_steps + 1, dtype=torch.float32)
        alpha_bar = torch.cos(((grid / num_diffusion_steps) + 0.008) / 1.008 * math.pi / 2) ** 2
        alpha_bar = alpha_bar / alpha_bar[0]
        betas = 1 - (alpha_bar[1:] / alpha_bar[:-1])
        # Clip away degenerate endpoints for numerical stability.
        return torch.clip(betas, 0.0001, 0.9999)
    if schedule_type == "scaled_linear":
        # Interpolate linearly in sqrt-space, then square back.
        root = torch.linspace(0.0001 ** 0.5, 0.02 ** 0.5, num_diffusion_steps, dtype=torch.float32)
        return root ** 2
    raise ValueError(f"Unknown beta schedule: {schedule_type}")
class DiffusionModel:
    """
    Diffusion model for medical image generation.

    Combines a VAE (image <-> latent), a UNet (noise prediction in latent
    space) and a text encoder (conditioning) with a DDPM/DDIM diffusion
    process. Latents are scaled by 0.18215 (Stable Diffusion convention).
    """
    def __init__(
        self,
        vae,
        unet,
        text_encoder,
        scheduler_type="ddpm",
        num_train_timesteps=1000,
        beta_schedule="linear",
        prediction_type="epsilon",
        guidance_scale=7.5,
        device=None
    ):
        """Initialize diffusion model.

        Args:
            vae: Autoencoder exposing ``encode(images) -> (mu, logvar)``,
                ``decode(latents)``, ``__call__(images) -> (recon, mu, logvar)``
                and a ``latent_channels`` attribute.
            unet: Noise-prediction network called as ``unet(x_t, t, context)``.
            text_encoder: Maps ``(input_ids, attention_mask)`` to context embeddings.
            scheduler_type: "ddpm" (ancestral) or "ddim" sampling scheme.
            num_train_timesteps: Number of diffusion steps T used for training.
            beta_schedule: Name accepted by ``get_named_beta_schedule``.
            prediction_type: "epsilon" (predict noise) or "v_prediction".
            guidance_scale: Default classifier-free guidance weight for sampling.
            device: Torch device; auto-detects CUDA when None.
        """
        self.vae = vae
        self.unet = unet
        self.text_encoder = text_encoder
        self.scheduler_type = scheduler_type
        self.num_train_timesteps = num_train_timesteps
        self.beta_schedule = beta_schedule
        self.prediction_type = prediction_type
        self.guidance_scale = guidance_scale
        self.device = device if device is not None else torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Precompute all schedule-derived buffers once.
        self._initialize_diffusion_parameters()
        logger.info(f"Initialized diffusion model with {scheduler_type} scheduler, {beta_schedule} beta schedule")

    def _initialize_diffusion_parameters(self):
        """Precompute beta/alpha schedules and posterior coefficients (DDPM, Ho et al.)."""
        self.betas = get_named_beta_schedule(
            self.beta_schedule, self.num_train_timesteps
        ).to(self.device)
        # alpha_t = 1 - beta_t; alpha_bar_t = prod_{s<=t} alpha_s
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        # alpha_bar_{t-1}, with alpha_bar_{-1} defined as 1.
        self.alphas_cumprod_prev = torch.cat([torch.ones(1, device=self.device), self.alphas_cumprod[:-1]])
        # Coefficients of the forward process q(x_t | x_0).
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)
        self.log_one_minus_alphas_cumprod = torch.log(1.0 - self.alphas_cumprod)
        # Posterior q(x_{t-1} | x_t, x_0) variance and mean coefficients.
        self.posterior_variance = self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        # Log-variance is clipped at t=0 (posterior variance is 0 there) by
        # substituting the t=1 value, matching the reference implementation.
        self.posterior_log_variance_clipped = torch.log(
            torch.cat([self.posterior_variance[1:2], self.posterior_variance[1:]])
        )
        self.posterior_mean_coef1 = self.betas * torch.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        self.posterior_mean_coef2 = (1.0 - self.alphas_cumprod_prev) * torch.sqrt(self.alphas) / (1.0 - self.alphas_cumprod)

    def q_sample(self, x_start, t, noise=None):
        """Forward diffusion q(x_t | x_0): noise ``x_start`` to timestep ``t``.

        Args:
            x_start: Clean latents, shape (B, ...).
            t: Long tensor of timesteps, shape (B,).
            noise: Optional Gaussian noise; sampled when None.
        """
        if noise is None:
            noise = torch.randn_like(x_start)
        sqrt_alphas_cumprod_t = extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape)
        sqrt_one_minus_alphas_cumprod_t = extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
        return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise

    def predict_start_from_noise(self, x_t, t, noise):
        """Recover the predicted x_0 from x_t and a noise prediction (epsilon parameterization)."""
        sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod)
        sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod - 1)
        sqrt_recip_alphas_cumprod_t = extract_into_tensor(sqrt_recip_alphas_cumprod, t, x_t.shape)
        sqrt_recipm1_alphas_cumprod_t = extract_into_tensor(sqrt_recipm1_alphas_cumprod, t, x_t.shape)
        return sqrt_recip_alphas_cumprod_t * x_t - sqrt_recipm1_alphas_cumprod_t * noise

    def q_posterior_mean_variance(self, x_start, x_t, t):
        """Compute mean, variance and clipped log-variance of q(x_{t-1} | x_t, x_0)."""
        posterior_mean_coef1_t = extract_into_tensor(self.posterior_mean_coef1, t, x_start.shape)
        posterior_mean_coef2_t = extract_into_tensor(self.posterior_mean_coef2, t, x_start.shape)
        posterior_mean = posterior_mean_coef1_t * x_start + posterior_mean_coef2_t * x_t
        posterior_variance_t = extract_into_tensor(self.posterior_variance, t, x_start.shape)
        posterior_log_variance_t = extract_into_tensor(self.posterior_log_variance_clipped, t, x_start.shape)
        return posterior_mean, posterior_variance_t, posterior_log_variance_t

    def p_mean_variance(self, x_t, t, context, noise_pred=None):
        """Predict mean and (log-)variance of p(x_{t-1} | x_t).

        Args:
            noise_pred: Optional precomputed noise prediction (e.g. after
                classifier-free guidance); the UNet is called when None.
        """
        if noise_pred is None:
            noise_pred = self.unet(x_t, t, context)
        x_0 = self.predict_start_from_noise(x_t, t, noise_pred)
        # Keep the predicted clean latent inside the training range.
        x_0 = torch.clamp(x_0, -1.0, 1.0)
        mean, var, log_var = self.q_posterior_mean_variance(x_0, x_t, t)
        return mean, var, log_var

    def p_sample(self, x_t, t, context, noise_pred=None):
        """Ancestral (DDPM) step: sample x_{t-1} ~ p(x_{t-1} | x_t).

        ``noise_pred`` optionally supplies a precomputed (guided) noise
        prediction instead of re-running the UNet.
        """
        mean, _, log_var = self.p_mean_variance(x_t, t, context, noise_pred=noise_pred)
        noise = torch.randn_like(x_t)
        # No noise is added at t == 0: the final step returns the mean.
        mask = (t > 0).float().reshape(-1, *([1] * (len(x_t.shape) - 1)))
        return mean + mask * torch.exp(0.5 * log_var) * noise

    def ddim_sample(self, x_t, t, prev_t, context, eta=0.0, noise_pred=None):
        """DDIM step from timestep ``t`` to ``prev_t`` (Song et al., 2021).

        Args:
            t, prev_t: Long tensors of shape (B,).
            eta: 0 gives deterministic DDIM; >0 interpolates toward DDPM.
            noise_pred: Optional precomputed (guided) noise prediction.
        """
        # Broadcast per-sample alpha-bars to the latent shape. Plain indexing
        # (alphas_cumprod[t]) returned a (B,) tensor that mis-broadcast
        # against (B, C, H, W) latents — fixed via extract_into_tensor.
        alpha_t = extract_into_tensor(self.alphas_cumprod, t, x_t.shape)
        alpha_prev = extract_into_tensor(self.alphas_cumprod, prev_t, x_t.shape)
        if noise_pred is None:
            noise_pred = self.unet(x_t, t, context)
        x_0_pred = self.predict_start_from_noise(x_t, t, noise_pred)
        x_0_pred = torch.clamp(x_0_pred, -1.0, 1.0)
        # sigma_t (DDIM eq. 16); zero when eta == 0.
        variance = eta * torch.sqrt((1 - alpha_prev) / (1 - alpha_t) * (1 - alpha_t / alpha_prev))
        mean = torch.sqrt(alpha_prev) * x_0_pred + torch.sqrt(1 - alpha_prev - variance**2) * noise_pred
        x_prev = mean
        if eta > 0:
            x_prev = x_prev + variance * torch.randn_like(x_t)
        return x_prev

    def training_step(self, batch, train_unet_only=True):
        """Run one training step and return ``(loss, metrics)``.

        Args:
            batch: Dict with 'image', 'input_ids' and 'attention_mask' tensors.
            train_unet_only: When True the VAE/text encoder are frozen and only
                the diffusion (UNet) loss is optimized.

        Raises:
            ValueError: if the batch lacks tokenized text.
        """
        images = batch['image'].to(self.device)
        input_ids = batch['input_ids'].to(self.device) if 'input_ids' in batch else None
        attention_mask = batch['attention_mask'].to(self.device) if 'attention_mask' in batch else None
        if input_ids is None or attention_mask is None:
            raise ValueError("Batch must contain tokenized text")
        metrics = {}
        try:
            # Encode images to latent space (no grad unless the VAE is trained).
            with torch.set_grad_enabled(not train_unet_only):
                mu, logvar = self.vae.encode(images)
                # Use the latent mean (not a sample) for stability in early training.
                latents = mu
                # Stable Diffusion latent scale factor.
                latents = latents * 0.18215
                if not train_unet_only:
                    recon, mu, logvar = self.vae(images)
                    recon_loss = F.mse_loss(recon, images)
                    # KL(q(z|x) || N(0, I)) with a small weight to avoid posterior collapse.
                    kl_loss = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
                    vae_loss_val = recon_loss + 1e-4 * kl_loss
                    metrics['vae_loss'] = vae_loss_val.item()
                    metrics['recon_loss'] = recon_loss.item()
                    metrics['kl_loss'] = kl_loss.item()
            with torch.set_grad_enabled(not train_unet_only):
                context = self.text_encoder(input_ids, attention_mask)
            # Uniformly sample a timestep per example.
            batch_size = images.shape[0]
            t = torch.randint(0, self.num_train_timesteps, (batch_size,), device=self.device).long()
            noise = torch.randn_like(latents)
            # Forward diffusion: noise the latents to timestep t.
            noisy_latents = self.q_sample(latents, t, noise=noise)
            # Context dropout (10%) enables classifier-free guidance at inference.
            import random
            if random.random() < 0.1:
                context = torch.zeros_like(context)
            noise_pred = self.unet(noisy_latents, t, context)
            if self.prediction_type == "epsilon":
                # Target is the injected noise.
                diffusion_loss = F.mse_loss(noise_pred, noise)
            elif self.prediction_type == "v_prediction":
                # v = sqrt(alpha_bar_t) * eps - sqrt(1 - alpha_bar_t) * x_0.
                # The per-sample coefficients must be broadcast to the latent
                # shape; plain (B,)-indexing mis-broadcast against (B, C, H, W).
                velocity = (
                    extract_into_tensor(self.sqrt_alphas_cumprod, t, latents.shape) * noise
                    - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, latents.shape) * latents
                )
                diffusion_loss = F.mse_loss(noise_pred, velocity)
            else:
                raise ValueError(f"Unknown prediction type: {self.prediction_type}")
            metrics['diffusion_loss'] = diffusion_loss.item()
            if train_unet_only:
                total_loss = diffusion_loss
            else:
                total_loss = diffusion_loss + vae_loss_val
            metrics['total_loss'] = total_loss.item()
            return total_loss, metrics
        except Exception as e:
            logger.error(f"Error in training step: {e}")
            import traceback
            logger.error(traceback.format_exc())
            # Deliberate best-effort: return a zero loss (with grad) so one bad
            # batch does not kill the training loop.
            dummy_loss = torch.tensor(0.0, device=self.device, requires_grad=True)
            return dummy_loss, {'total_loss': 0.0, 'diffusion_loss': 0.0}

    def validation_step(self, batch):
        """Run one no-grad validation step and return a metrics dict.

        Raises:
            ValueError: if the batch lacks tokenized text.
        """
        with torch.no_grad():
            images = batch['image'].to(self.device)
            input_ids = batch['input_ids'].to(self.device) if 'input_ids' in batch else None
            attention_mask = batch['attention_mask'].to(self.device) if 'attention_mask' in batch else None
            if input_ids is None or attention_mask is None:
                raise ValueError("Batch must contain tokenized text")
            try:
                mu, logvar = self.vae.encode(images)
                latents = mu  # Use mean for validation
                latents = latents * 0.18215
                # VAE losses are always reported during validation.
                recon, mu, logvar = self.vae(images)
                recon_loss = F.mse_loss(recon, images)
                kl_loss = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
                vae_loss_val = recon_loss + 1e-4 * kl_loss
                context = self.text_encoder(input_ids, attention_mask)
                batch_size = images.shape[0]
                t = torch.randint(0, self.num_train_timesteps, (batch_size,), device=self.device).long()
                noise = torch.randn_like(latents)
                noisy_latents = self.q_sample(latents, t, noise=noise)
                noise_pred = self.unet(noisy_latents, t, context)
                if self.prediction_type == "epsilon":
                    diffusion_loss = F.mse_loss(noise_pred, noise)
                elif self.prediction_type == "v_prediction":
                    # Same broadcast-safe v-target as in training_step.
                    velocity = (
                        extract_into_tensor(self.sqrt_alphas_cumprod, t, latents.shape) * noise
                        - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, latents.shape) * latents
                    )
                    diffusion_loss = F.mse_loss(noise_pred, velocity)
                else:
                    # Consistent with training_step (previously a NameError path).
                    raise ValueError(f"Unknown prediction type: {self.prediction_type}")
                total_loss = diffusion_loss + vae_loss_val
                return {
                    'val_loss': total_loss.item(),
                    'val_diffusion_loss': diffusion_loss.item(),
                    'val_vae_loss': vae_loss_val.item(),
                    'val_recon_loss': recon_loss.item(),
                    'val_kl_loss': kl_loss.item()
                }
            except Exception as e:
                logger.error(f"Error in validation step: {e}")
                # Best-effort: keep the validation loop running.
                return {
                    'val_loss': 0.0,
                    'val_diffusion_loss': 0.0,
                    'val_vae_loss': 0.0
                }

    @torch.no_grad()
    def sample(
        self,
        text,
        height=256,
        width=256,
        num_inference_steps=50,
        guidance_scale=None,
        eta=0.0,
        tokenizer=None,
        latents=None,
        return_all_latents=False
    ):
        """Generate images from text prompts with classifier-free guidance.

        Args:
            text: Prompt string or list of prompt strings.
            height, width: Output image size in pixels (latents are 8x smaller).
            num_inference_steps: Number of denoising steps.
            guidance_scale: CFG weight; defaults to ``self.guidance_scale``.
            eta: DDIM stochasticity (0 = deterministic); DDIM scheduler only.
            tokenizer: Required HF-style tokenizer for the prompts.
            latents: Optional initial latents; random noise when None.
            return_all_latents: Also return the latent at every step.

        Returns:
            Dict with 'images' in [0, 1], final 'latents', and optionally
            'all_latents'.

        Raises:
            ValueError: if ``tokenizer`` is None.
        """
        if guidance_scale is None:
            guidance_scale = self.guidance_scale
        if isinstance(text, str):
            text = [text]
        batch_size = len(text)
        if tokenizer is None:
            raise ValueError("Tokenizer must be provided for sampling")
        tokens = tokenizer(
            text,
            padding="max_length",
            max_length=256,  # Replace with your max token length
            truncation=True,
            return_tensors="pt"
        ).to(self.device)
        context = self.text_encoder(tokens.input_ids, tokens.attention_mask)
        # VAE downsampling factor is 8.
        latent_height = height // 8
        latent_width = width // 8
        if latents is None:
            latents = torch.randn(
                (batch_size, self.vae.latent_channels, latent_height, latent_width),
                device=self.device
            )
            # NOTE(review): scaling fresh N(0, I) noise by 0.18215 shrinks its
            # variance; standard latent-diffusion pipelines start from unit
            # noise. Kept as-is to preserve existing behavior — confirm intent.
            latents = latents * 0.18215  # Scale factor
        if return_all_latents:
            all_latents = [latents.clone()]
        if self.scheduler_type == "ddim":
            # Evenly spaced timesteps from T-1 down to 0.
            timesteps = torch.linspace(
                self.num_train_timesteps - 1,
                0,
                num_inference_steps,
                dtype=torch.long,
                device=self.device
            )
        else:
            # Strided DDPM timesteps, descending.
            # NOTE(review): p_sample's posterior assumes consecutive t -> t-1
            # steps, so strided DDPM sampling is an approximation here.
            step_indices = list(range(0, self.num_train_timesteps, self.num_train_timesteps // num_inference_steps))
            timesteps = torch.tensor(sorted(step_indices, reverse=True), dtype=torch.long, device=self.device)
        # Unconditional context matches the zeroed-context dropout used in training.
        uncond_context = torch.zeros_like(context)
        for i, t in enumerate(tqdm(timesteps, desc="Generating image")):
            # Run conditional and unconditional passes in a single batch.
            latent_model_input = torch.cat([latents] * 2)
            t_batch = t.repeat(batch_size)
            t_input = torch.cat([t_batch] * 2)
            text_embeddings = torch.cat([uncond_context, context])
            noise_pred = self.unet(latent_model_input, t_input, text_embeddings)
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            # Pass the guided prediction into the sampler so guidance actually
            # takes effect (previously the samplers re-ran the UNet with the
            # conditional context only and discarded the guided prediction).
            if self.scheduler_type == "ddim":
                prev_t = timesteps[i + 1] if i < len(timesteps) - 1 else torch.tensor([0], device=self.device)
                latents = self.ddim_sample(
                    latents, t_batch, prev_t.repeat(batch_size), context, eta, noise_pred=noise_pred
                )
            else:
                latents = self.p_sample(latents, t_batch, context, noise_pred=noise_pred)
            if return_all_latents:
                all_latents.append(latents.clone())
        # Undo the latent scale factor before decoding.
        latents = 1 / 0.18215 * latents
        images = self.vae.decode(latents)
        # Map from [-1, 1] to [0, 1].
        images = (images + 1) / 2
        images = torch.clamp(images, 0, 1)
        result = {
            'images': images,
            'latents': latents
        }
        if return_all_latents:
            result['all_latents'] = all_latents
        return result