# xray_generator/models/diffusion.py
import math
import random
import logging

import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.auto import tqdm

logger = logging.getLogger(__name__)

def extract_into_tensor(a, t, shape):
    """Extract specific timestep values and broadcast to target shape."""
    if not isinstance(a, torch.Tensor):
        a = torch.tensor(a, dtype=torch.float32)
    a = a.to(t.device)
    
    out = a.gather(-1, t)
    while len(out.shape) < len(shape):
        out = out[..., None]
    
    return out.expand(shape)
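
# Example (illustrative): broadcasting per-sample schedule values.
#   a = torch.linspace(0.1, 0.9, 1000)      # schedule, shape (T,)
#   t = torch.tensor([0, 499, 999])         # timesteps, shape (B,)
#   out = extract_into_tensor(a, t, (3, 4, 64, 64))
#   # out[i] == a[t[i]], broadcast to (4, 64, 64) for each sample i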

def get_named_beta_schedule(schedule_type, num_diffusion_steps):
    """

    Get a pre-defined beta schedule for the given name.

    

    Available schedules:

    - linear: linear schedule from Ho et al

    - cosine: cosine schedule from Improved DDPM

    """
    if schedule_type == "linear":
        # Linear schedule from Ho et al.
        scale = 1000 / num_diffusion_steps
        beta_start = scale * 0.0001
        beta_end = scale * 0.02
        return torch.linspace(beta_start, beta_end, num_diffusion_steps, dtype=torch.float32)
    
    elif schedule_type == "cosine":
        # Cosine schedule from Improved DDPM
        steps = num_diffusion_steps + 1
        x = torch.linspace(0, num_diffusion_steps, steps, dtype=torch.float32)
        alphas_cumprod = torch.cos(((x / num_diffusion_steps) + 0.008) / 1.008 * math.pi / 2) ** 2
        alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
        betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
        return torch.clip(betas, 0.0001, 0.9999)
    
    elif schedule_type == "scaled_linear":
        # Scaled linear schedule
        beta_start = 0.0001
        beta_end = 0.02
        return torch.linspace(beta_start**0.5, beta_end**0.5, num_diffusion_steps, dtype=torch.float32) ** 2
    
    else:
        raise ValueError(f"Unknown beta schedule: {schedule_type}")
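
# Quick sanity check (illustrative): any valid schedule yields betas in (0, 1)
# whose cumulative alpha product decays toward zero over the schedule.
#   betas = get_named_beta_schedule("cosine", 1000)
#   alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
#   assert (betas > 0).all() and (betas < 1).all()
#   assert alphas_cumprod[-1] < alphas_cumprod[0]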

class DiffusionModel:
    """

    Diffusion model for medical image generation.

    Combines VAE, UNet, and text encoder with diffusion process.

    """
    def __init__(
        self,
        vae,
        unet,
        text_encoder,
        scheduler_type="ddpm",
        num_train_timesteps=1000,
        beta_schedule="linear",
        prediction_type="epsilon",
        guidance_scale=7.5,
        device=None
    ):
        """Initialize diffusion model."""
        self.vae = vae
        self.unet = unet
        self.text_encoder = text_encoder
        self.scheduler_type = scheduler_type
        self.num_train_timesteps = num_train_timesteps
        self.beta_schedule = beta_schedule
        self.prediction_type = prediction_type
        self.guidance_scale = guidance_scale
        self.device = device if device is not None else torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Initialize diffusion parameters
        self._initialize_diffusion_parameters()

        logger.info(f"Initialized diffusion model with {scheduler_type} scheduler, {beta_schedule} beta schedule")

    def _initialize_diffusion_parameters(self):
        """Initialize diffusion parameters."""
        # Get beta schedule
        self.betas = get_named_beta_schedule(
            self.beta_schedule, self.num_train_timesteps
        ).to(self.device)
        
        # Calculate alphas
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.alphas_cumprod_prev = torch.cat([torch.ones(1, device=self.device), self.alphas_cumprod[:-1]])
        
        # Calculate diffusion q(x_t | x_{t-1}) and others
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)
        self.log_one_minus_alphas_cumprod = torch.log(1.0 - self.alphas_cumprod)
        
        # Calculate posterior q(x_{t-1} | x_t, x_0)
        self.posterior_variance = self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        # The posterior variance is 0 at t = 0; substitute the t = 1 value
        # before taking the log to avoid -inf
        self.posterior_log_variance_clipped = torch.log(
            torch.cat([self.posterior_variance[1:2], self.posterior_variance[1:]])
        )
        self.posterior_mean_coef1 = self.betas * torch.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        self.posterior_mean_coef2 = (1.0 - self.alphas_cumprod_prev) * torch.sqrt(self.alphas) / (1.0 - self.alphas_cumprod)
    
    def q_sample(self, x_start, t, noise=None):
        """Forward diffusion: q(x_t | x_0)."""
        if noise is None:
            noise = torch.randn_like(x_start)
        
        sqrt_alphas_cumprod_t = extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape)
        sqrt_one_minus_alphas_cumprod_t = extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
        
        return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise
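
    # Example (illustrative): noising a latent batch at uniformly sampled timesteps.
    #   t = torch.randint(0, model.num_train_timesteps, (latents.shape[0],),
    #                     device=model.device)
    #   noisy = model.q_sample(latents, t)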
    
    def predict_start_from_noise(self, x_t, t, noise):
        """Predict x_0 from x_t and predicted noise by inverting q_sample:
        x_0 = sqrt(1/abar_t) * x_t - sqrt(1/abar_t - 1) * noise."""
        sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod)
        sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod - 1)
        
        sqrt_recip_alphas_cumprod_t = extract_into_tensor(sqrt_recip_alphas_cumprod, t, x_t.shape)
        sqrt_recipm1_alphas_cumprod_t = extract_into_tensor(sqrt_recipm1_alphas_cumprod, t, x_t.shape)
        
        return sqrt_recip_alphas_cumprod_t * x_t - sqrt_recipm1_alphas_cumprod_t * noise
    
    def q_posterior_mean_variance(self, x_start, x_t, t):
        """Compute posterior mean and variance: q(x_{t-1} | x_t, x_0)."""
        posterior_mean_coef1_t = extract_into_tensor(self.posterior_mean_coef1, t, x_start.shape)
        posterior_mean_coef2_t = extract_into_tensor(self.posterior_mean_coef2, t, x_start.shape)
        
        posterior_mean = posterior_mean_coef1_t * x_start + posterior_mean_coef2_t * x_t
        posterior_variance_t = extract_into_tensor(self.posterior_variance, t, x_start.shape)
        posterior_log_variance_t = extract_into_tensor(self.posterior_log_variance_clipped, t, x_start.shape)
        
        return posterior_mean, posterior_variance_t, posterior_log_variance_t
    
    def p_mean_variance(self, x_t, t, context, noise_pred=None):
        """Predict mean and variance for the denoising process.

        A precomputed (e.g. classifier-free guided) noise prediction can be
        passed via noise_pred; otherwise the UNet is called here.
        """
        # Predict noise using UNet unless a prediction was supplied
        if noise_pred is None:
            noise_pred = self.unet(x_t, t, context)
        
        # Predict x_0
        x_0 = self.predict_start_from_noise(x_t, t, noise_pred)
        
        # Clip prediction
        x_0 = torch.clamp(x_0, -1.0, 1.0)
        
        # Get posterior parameters
        mean, var, log_var = self.q_posterior_mean_variance(x_0, x_t, t)
        
        return mean, var, log_var
    
    def p_sample(self, x_t, t, context, noise_pred=None):
        """Sample from p(x_{t-1} | x_t)."""
        # Get mean and variance
        mean, _, log_var = self.p_mean_variance(x_t, t, context, noise_pred=noise_pred)
        
        # Sample; the mask suppresses noise at t == 0 so the final step is deterministic
        noise = torch.randn_like(x_t)
        mask = (t > 0).float().reshape(-1, *([1] * (len(x_t.shape) - 1)))
        
        return mean + mask * torch.exp(0.5 * log_var) * noise
    
    def ddim_sample(self, x_t, t, prev_t, context, eta=0.0, noise_pred=None):
        """DDIM sampling step (Song et al.); eta = 0 is deterministic DDIM."""
        # Get cumulative alphas, broadcast to the latent shape so the
        # per-sample coefficients multiply (B, C, H, W) tensors correctly
        alpha_t = extract_into_tensor(self.alphas_cumprod, t, x_t.shape)
        alpha_prev = extract_into_tensor(self.alphas_cumprod, prev_t, x_t.shape)
        
        # Predict noise unless a (e.g. guided) prediction is supplied
        if noise_pred is None:
            noise_pred = self.unet(x_t, t, context)
        
        # Predict x_0
        x_0_pred = self.predict_start_from_noise(x_t, t, noise_pred)
        
        # Clip prediction
        x_0_pred = torch.clamp(x_0_pred, -1.0, 1.0)
        
        # DDIM sigma_t
        variance = eta * torch.sqrt((1 - alpha_prev) / (1 - alpha_t) * (1 - alpha_t / alpha_prev))
        
        # Deterministic component
        mean = torch.sqrt(alpha_prev) * x_0_pred + torch.sqrt(1 - alpha_prev - variance**2) * noise_pred
        
        # Add stochastic noise only if eta > 0
        x_prev = mean
        if eta > 0:
            x_prev = x_prev + variance * torch.randn_like(x_t)
            
        return x_prev
    
    def training_step(self, batch, train_unet_only=True):
        """Training step for diffusion model."""
        # Extract data
        images = batch['image'].to(self.device)
        input_ids = batch['input_ids'].to(self.device) if 'input_ids' in batch else None
        attention_mask = batch['attention_mask'].to(self.device) if 'attention_mask' in batch else None
        
        if input_ids is None or attention_mask is None:
            raise ValueError("Batch must contain tokenized text")
        
        # Metrics dictionary
        metrics = {}
        
        try:
            # Encode images to latent space
            with torch.set_grad_enabled(not train_unet_only):
                # Get latent distribution
                mu, logvar = self.vae.encode(images)
                
                # Use latent mean for stability in early training
                latents = mu
                
                # Scale latents 
                latents = latents * 0.18215
                
                # Compute VAE loss if not training UNet only
                if not train_unet_only:
                    recon, mu, logvar = self.vae(images)
                    
                    # Reconstruction loss
                    recon_loss = F.mse_loss(recon, images)
                    
                    # KL divergence
                    kl_loss = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
                    
                    # Total VAE loss
                    vae_loss_val = recon_loss + 1e-4 * kl_loss
                    
                    metrics['vae_loss'] = vae_loss_val.item()
                    metrics['recon_loss'] = recon_loss.item()
                    metrics['kl_loss'] = kl_loss.item()
                    
            # Encode text
            with torch.set_grad_enabled(not train_unet_only):
                context = self.text_encoder(input_ids, attention_mask)
                
            # Sample timestep
            batch_size = images.shape[0]
            t = torch.randint(0, self.num_train_timesteps, (batch_size,), device=self.device).long()
            
            # Generate noise
            noise = torch.randn_like(latents)
            
            # Add noise to latents (forward diffusion)
            noisy_latents = self.q_sample(latents, t, noise=noise)
            
            # Classifier-free guidance training: drop the text conditioning
            # ~10% of the time so the model also learns an unconditional prior
            if random.random() < 0.1:
                context = torch.zeros_like(context)
                
            # Predict noise
            noise_pred = self.unet(noisy_latents, t, context)
            
            # Compute loss based on prediction type
            if self.prediction_type == "epsilon":
                # Predict noise (ε)
                diffusion_loss = F.mse_loss(noise_pred, noise)
                
            elif self.prediction_type == "v_prediction":
                # Predict velocity: v = sqrt(abar_t) * eps - sqrt(1 - abar_t) * x_0,
                # with per-sample coefficients broadcast to the latent shape
                sqrt_ac_t = extract_into_tensor(self.sqrt_alphas_cumprod, t, latents.shape)
                sqrt_omac_t = extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, latents.shape)
                velocity = sqrt_ac_t * noise - sqrt_omac_t * latents
                diffusion_loss = F.mse_loss(noise_pred, velocity)
                
            else:
                raise ValueError(f"Unknown prediction type: {self.prediction_type}")
                
            metrics['diffusion_loss'] = diffusion_loss.item()
            
            # Total loss
            if train_unet_only:
                total_loss = diffusion_loss
            else:
                total_loss = diffusion_loss + vae_loss_val
                
            metrics['total_loss'] = total_loss.item()
            
            return total_loss, metrics
            
        except Exception as e:
            logger.error(f"Error in training step: {e}")
            import traceback
            logger.error(traceback.format_exc())
            
            # Return dummy values to avoid breaking training loop
            dummy_loss = torch.tensor(0.0, device=self.device, requires_grad=True)
            return dummy_loss, {'total_loss': 0.0, 'diffusion_loss': 0.0}
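
    # Usage sketch (illustrative): one optimizer step driven by training_step.
    # Assumes `model` is a DiffusionModel and `batch` provides 'image',
    # 'input_ids', and 'attention_mask' tensors (names as used above).
    #   optimizer = torch.optim.AdamW(model.unet.parameters(), lr=1e-4)
    #   loss, metrics = model.training_step(batch, train_unet_only=True)
    #   optimizer.zero_grad()
    #   loss.backward()
    #   optimizer.step()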
    
    def validation_step(self, batch):
        """Validation step for diffusion model."""
        with torch.no_grad():
            # Extract data
            images = batch['image'].to(self.device)
            input_ids = batch['input_ids'].to(self.device) if 'input_ids' in batch else None
            attention_mask = batch['attention_mask'].to(self.device) if 'attention_mask' in batch else None
            
            if input_ids is None or attention_mask is None:
                raise ValueError("Batch must contain tokenized text")
            
            try:
                # Encode images to latent space
                mu, logvar = self.vae.encode(images)
                latents = mu  # Use mean for validation
                
                # Scale latents
                latents = latents * 0.18215
                
                # Compute VAE loss
                recon, mu, logvar = self.vae(images)
                
                # Reconstruction loss
                recon_loss = F.mse_loss(recon, images)
                
                # KL divergence
                kl_loss = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
                
                # Total VAE loss
                vae_loss_val = recon_loss + 1e-4 * kl_loss
                
                # Encode text
                context = self.text_encoder(input_ids, attention_mask)
                
                # Sample timestep
                batch_size = images.shape[0]
                t = torch.randint(0, self.num_train_timesteps, (batch_size,), device=self.device).long()
                
                # Generate noise
                noise = torch.randn_like(latents)
                
                # Add noise to latents
                noisy_latents = self.q_sample(latents, t, noise=noise)
                
                # Predict noise
                noise_pred = self.unet(noisy_latents, t, context)
                
                # Compute diffusion loss
                if self.prediction_type == "epsilon":
                    diffusion_loss = F.mse_loss(noise_pred, noise)
                elif self.prediction_type == "v_prediction":
                    sqrt_ac_t = extract_into_tensor(self.sqrt_alphas_cumprod, t, latents.shape)
                    sqrt_omac_t = extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, latents.shape)
                    velocity = sqrt_ac_t * noise - sqrt_omac_t * latents
                    diffusion_loss = F.mse_loss(noise_pred, velocity)
                else:
                    raise ValueError(f"Unknown prediction type: {self.prediction_type}")
                
                # Total loss
                total_loss = diffusion_loss + vae_loss_val
                
                # Return metrics
                return {
                    'val_loss': total_loss.item(),
                    'val_diffusion_loss': diffusion_loss.item(),
                    'val_vae_loss': vae_loss_val.item(),
                    'val_recon_loss': recon_loss.item(),
                    'val_kl_loss': kl_loss.item()
                }
                
            except Exception as e:
                logger.error(f"Error in validation step: {e}")
                
                # Return dummy metrics
                return {
                    'val_loss': 0.0,
                    'val_diffusion_loss': 0.0,
                    'val_vae_loss': 0.0
                }
    
    @torch.no_grad()
    def sample(
        self,
        text,
        height=256,
        width=256,
        num_inference_steps=50,
        guidance_scale=None,
        eta=0.0,
        tokenizer=None,
        latents=None,
        return_all_latents=False
    ):
        """Sample from diffusion model given text prompt."""
        # Default guidance scale
        if guidance_scale is None:
            guidance_scale = self.guidance_scale
            
        # Ensure text is a list
        if isinstance(text, str):
            text = [text]
        
        batch_size = len(text)
        
        # Check if tokenizer is provided
        if tokenizer is None:
            raise ValueError("Tokenizer must be provided for sampling")
        
        # Encode text
        tokens = tokenizer(
            text,
            padding="max_length",
            max_length=256,  # Replace with your max token length
            truncation=True,
            return_tensors="pt"
        ).to(self.device)
        
        context = self.text_encoder(tokens.input_ids, tokens.attention_mask)
        
        # Calculate latent size
        latent_height = height // 8  # VAE downsampling factor
        latent_width = width // 8
        
        # Generate random latents if not provided. The reverse process starts
        # from the standard normal prior, so the initial noise is not scaled;
        # the 0.18215 factor applies only to VAE-encoded latents
        if latents is None:
            latents = torch.randn(
                (batch_size, self.vae.latent_channels, latent_height, latent_width),
                device=self.device
            )
        
        # Store all latents if requested
        if return_all_latents:
            all_latents = [latents.clone()]
        
        # Prepare scheduler timesteps
        if self.scheduler_type == "ddim":
            # DDIM timesteps
            timesteps = torch.linspace(
                self.num_train_timesteps - 1,
                0,
                num_inference_steps,
                dtype=torch.long,
                device=self.device
            )
        else:
            # DDPM timesteps (note: p_sample computes the exact t -> t-1 posterior,
            # so subsampling the schedule is an approximation)
            step_indices = list(range(0, self.num_train_timesteps, self.num_train_timesteps // num_inference_steps))
            timesteps = torch.tensor(sorted(step_indices, reverse=True), dtype=torch.long, device=self.device)
        
        # Unconditional (zeroed) text embeddings for classifier-free guidance,
        # matching the zeroed-context dropout used during training
        uncond_context = torch.zeros_like(context)
        
        # Sampling loop
        for i, t in enumerate(tqdm(timesteps, desc="Generating image")):
            # Expand inputs for classifier-free guidance (unconditional + conditional)
            latent_model_input = torch.cat([latents] * 2)
            t_input = t.repeat(2 * batch_size)
            
            # Get text conditioning
            text_embeddings = torch.cat([uncond_context, context])
            
            # Predict noise
            noise_pred = self.unet(latent_model_input, t_input, text_embeddings)
            
            # Perform guidance
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            
            # Sampling step, reusing the guided noise prediction
            if self.scheduler_type == "ddim":
                # DDIM step
                prev_t = timesteps[i + 1] if i < len(timesteps) - 1 else torch.tensor(0, device=self.device)
                latents = self.ddim_sample(
                    latents, t.repeat(batch_size), prev_t.repeat(batch_size), context,
                    eta, noise_pred=noise_pred
                )
            else:
                # DDPM step
                latents = self.p_sample(latents, t.repeat(batch_size), context, noise_pred=noise_pred)
            
            # Store latent if requested
            if return_all_latents:
                all_latents.append(latents.clone())
        
        # Scale latents
        latents = 1 / 0.18215 * latents
        
        # Decode latents
        images = self.vae.decode(latents)
        
        # Normalize to [0, 1]
        images = (images + 1) / 2
        images = torch.clamp(images, 0, 1)
        
        result = {
            'images': images,
            'latents': latents
        }
        
        if return_all_latents:
            result['all_latents'] = all_latents
            
        return result
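
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of this module's API). It assumes
# `build_vae`, `build_unet`, and `build_text_encoder` factories and a Hugging
# Face tokenizer; substitute the project's actual constructors and checkpoints.
# ---------------------------------------------------------------------------
#   from transformers import AutoTokenizer
#   from xray_generator.models import build_vae, build_unet, build_text_encoder
#
#   tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
#   model = DiffusionModel(
#       vae=build_vae(),
#       unet=build_unet(),
#       text_encoder=build_text_encoder(),
#       scheduler_type="ddim",
#       beta_schedule="cosine",
#   )
#   out = model.sample(
#       "No acute cardiopulmonary abnormality.",
#       num_inference_steps=50,
#       guidance_scale=7.5,
#       tokenizer=tokenizer,
#   )
#   images = out["images"]  # (B, C, H, W), values in [0, 1]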