| # Inference script for the trained diffusion model | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import matplotlib.pyplot as plt | |
| from tqdm import tqdm | |
| import math | |
| # [Paste the model architecture and scheduler definitions here - TimeEmbedding, ResidualBlock, etc.; load_model() below requires SimpleUNet and DDPMScheduler] | |
def load_model(checkpoint_path, device='cuda'):
    """Load a trained diffusion model and its noise scheduler from a checkpoint.

    The checkpoint is expected to be a dict containing the keys read below:
    ``'model_config'`` (keyword arguments for ``SimpleUNet``),
    ``'model_state_dict'``, ``'diffusion_config'`` (keyword arguments for
    ``DDPMScheduler``) and, optionally, ``'model_info'``.

    Args:
        checkpoint_path: Path or file-like object accepted by ``torch.load``.
        device: Device used as ``map_location`` for the loaded tensors, as the
            target of ``model.to``, and passed to ``DDPMScheduler``. The default
            ``'cuda'`` requires a CUDA-capable machine.

    Returns:
        A ``(model, scheduler, model_info)`` tuple. ``model`` is in eval mode;
        ``model_info`` is ``{}`` when the checkpoint does not store one.

    Raises:
        KeyError: if ``'model_config'``, ``'model_state_dict'`` or
            ``'diffusion_config'`` is missing from the checkpoint.
    """
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from trusted sources (recent PyTorch releases default to weights_only=True).
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Rebuild the network from its saved constructor arguments, then restore weights.
    model = SimpleUNet(**checkpoint['model_config'])
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    # Switch to inference mode (affects layers such as dropout/batch-norm, if present).
    model.eval()

    # Merge instead of passing ``**config, device=device``: if the saved config
    # already holds a 'device' entry, the latter raises
    # "got multiple values for keyword argument 'device'". The caller's device wins.
    scheduler_kwargs = {**checkpoint['diffusion_config'], 'device': device}
    scheduler = DDPMScheduler(**scheduler_kwargs)

    # 'model_info' is descriptive metadata; tolerate checkpoints saved without it.
    return model, scheduler, checkpoint.get('model_info', {})
| # Usage example (load_model defaults to device='cuda'; pass device='cpu' on machines without a GPU): | |
| # model, scheduler, info = load_model('complete_diffusion_model.pth') | |
| # generated_images = generate_images(model, scheduler, num_images=4) | |