Spaces: Running on Zero
| import spaces # for using hugging face ZeroGPU | |
| import gradio as gr | |
| import torch | |
| from torch import nn | |
| import numpy as np | |
| from PIL import Image | |
| from tqdm.auto import tqdm | |
| from diffusers import DDPMScheduler, UNet2DModel # Hugging Face diffusers library | |
# Pick the best available accelerator: Apple MPS first, then CUDA, else CPU.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Using device: {device}')

# DDPM noise scheduler (1000 steps, squared-cosine beta schedule) shared by
# both sampling functions below.
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule='squaredcos_cap_v2')
class ClassConditionedUnet(nn.Module):
    """Class-conditional UNet for 32x32 RGB (CIFAR-10) diffusion.

    The class label is embedded to a small vector, broadcast spatially, and
    concatenated to the noisy image as extra input channels before the UNet.
    """

    def __init__(self, num_classes=10, class_emb_size=4):
        super().__init__()
        # The embedding layer will map the class label to a vector of size class_emb_size
        self.class_emb = nn.Embedding(num_classes, class_emb_size)
        # self.model is an unconditional UNet with extra input channels to accept
        # the conditioning information (the class embedding)
        self.model = UNet2DModel(
            sample_size=32,  # the target image resolution
            in_channels=3 + class_emb_size,  # (R, G, B) plus additional input channels for class cond.
            out_channels=3,  # the number of output channels
            layers_per_block=2,  # how many ResNet layers to use per UNet block
            # block_out_channels=(32, 64, 64),
            block_out_channels=(128, 256, 256, 512),  # trying a larger network
            down_block_types=(
                "DownBlock2D",  # a regular ResNet downsampling block
                "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
                "AttnDownBlock2D",
                "AttnDownBlock2D",
            ),
            up_block_types=(
                "AttnUpBlock2D",
                "AttnUpBlock2D",
                "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
                "UpBlock2D",  # a regular ResNet upsampling block
            ),
        )

    # Our forward method now takes the class labels as an additional argument
    def forward(self, x, t, class_labels):
        # Shape of x: (bs, 3, 32, 32) for this CIFAR-10 setup
        bs, ch, w, h = x.shape
        # Class conditioning reshaped so it can be added as extra input channels
        class_cond = self.class_emb(class_labels)  # (bs, class_emb_size)
        class_cond = class_cond.view(bs, class_cond.shape[1], 1, 1).expand(bs, class_cond.shape[1], w, h)
        # x is (bs, 3, 32, 32) and class_cond is now (bs, class_emb_size, 32, 32)
        # Net input is x and class_cond concatenated together along the channel dimension
        net_input = torch.cat((x, class_cond), 1)  # (bs, 3 + class_emb_size, 32, 32)
        # Feed this to the UNet alongside the timestep and return the noise prediction
        return self.model(net_input, t).sample  # (bs, 3, 32, 32)
def load_checkpoint_for_inference(filepath, model_class):
    """
    Instantiate `model_class`, load trained weights from `filepath`, and
    prepare the model for inference.

    Args:
        filepath: Path to a checkpoint file containing the model's state_dict.
        model_class: Zero-argument callable returning a model whose
            architecture matches the saved weights.

    Returns:
        The model in eval mode, moved to the module-level `device`.
    """
    # Instantiate the model with the correct architecture/arguments
    # (You need the Model Class definition handy)
    model = model_class()
    # map_location='cpu' handles checkpoints saved on a different device.
    # weights_only=True restricts unpickling to tensors/containers — safe here
    # because the checkpoint is a plain state_dict, and it avoids arbitrary
    # code execution from an untrusted pickle.
    checkpoint = torch.load(filepath, map_location=torch.device('cpu'), weights_only=True)
    # Load the state dictionary into the model instance
    model.load_state_dict(checkpoint)
    # Evaluation mode for inference (disables dropout, uses running BN stats)
    model.eval()
    # Move the model to the inference device (GPU/MPS/CPU)
    model.to(device)
    print("Checkpoint loaded.")
    return model
# Make this function run on GPU — on Hugging Face ZeroGPU a GPU is only
# allocated for the duration of a call decorated with @spaces.GPU.
@spaces.GPU
def generate_images(selected_class_name, num_samples, progress=gr.Progress()):
    """Sample `num_samples` images of one CIFAR-10 class via the DDPM loop.

    Args:
        selected_class_name: One of the names in `cifar10_classes`.
        num_samples: Number of images to generate (Gradio slider value).
        progress: Gradio progress tracker (injected by the UI).

    Returns:
        A list of PIL.Image objects, or an empty list for an unknown class.
    """
    num_samples = int(num_samples)  # slider values can arrive as floats
    print(f"Generating {num_samples} samples for class: {selected_class_name}")
    # Map class name to class ID
    try:
        label = cifar10_classes.index(selected_class_name)
    except ValueError:
        print(f"Error: Class '{selected_class_name}' not found.")
        return []  # Return empty list if class not found
    # Start from pure noise: (num images, num channels, img width, img height)
    x = torch.randn(num_samples, 3, 32, 32).to(device)
    # One label per sample: tensor of shape (num_samples,)
    y = torch.full((num_samples,), label, dtype=torch.long).to(device)
    # Reverse-diffusion sampling loop
    for t in progress.tqdm(noise_scheduler.timesteps, total=len(noise_scheduler.timesteps), desc=f"Generating {selected_class_name} images"):
        # Get model pred (no gradients needed at inference time)
        with torch.no_grad():
            residual = model(x, t, y)  # Note that we pass in our label
        # Update sample with step
        x = noise_scheduler.step(residual, t, x).prev_sample
    # Convert each sample tensor to a PIL image
    generated_pil_images = []
    for j in range(num_samples):
        img_tensor = x[j].detach().cpu().clip(-1, 1) * 0.5 + 0.5  # Denormalize to [0, 1]
        img_tensor = img_tensor.permute(1, 2, 0)  # C, H, W -> H, W, C
        img_array = (img_tensor.numpy() * 255).astype(np.uint8)
        generated_pil_images.append(Image.fromarray(img_array))
    return generated_pil_images
# Make this function run on GPU — on Hugging Face ZeroGPU a GPU is only
# allocated for the duration of a call decorated with @spaces.GPU.
@spaces.GPU
def generate_all_classes(progress=gr.Progress()):
    """Sample one image per CIFAR-10 class (labels 0..9) via the DDPM loop.

    Args:
        progress: Gradio progress tracker (injected by the UI).

    Returns:
        A list of 10 PIL.Image objects, one per class, in class-ID order.
    """
    # Start from pure noise, one sample per class; the labels are simply 0..9
    x = torch.randn(10, 3, 32, 32).to(device)
    y = torch.arange(10).to(device)
    # Reverse-diffusion sampling loop
    for t in progress.tqdm(noise_scheduler.timesteps, total=len(noise_scheduler.timesteps), desc="Generating all 10 classes"):
        # Get model pred (no gradients needed at inference time)
        with torch.no_grad():
            residual = model(x, t, y)  # Again, note that we pass in our labels y
        # Update sample with step
        x = noise_scheduler.step(residual, t, x).prev_sample
    # Convert each sample tensor to a PIL image
    generated_pil_images = []
    for j in range(10):
        img_tensor = x[j].detach().cpu().clip(-1, 1) * 0.5 + 0.5  # Denormalize to [0, 1]
        img_tensor = img_tensor.permute(1, 2, 0)  # C, H, W -> H, W, C
        img_array = (img_tensor.numpy() * 255).astype(np.uint8)
        generated_pil_images.append(Image.fromarray(img_array))
    return generated_pil_images
# CIFAR-10 class names, ordered by class ID: list index == the integer label
# the model was trained with (generate_images relies on this ordering).
cifar10_classes = [
    "plane", "car", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck"
]
# Load the trained class-conditional UNet weights for inference
model = load_checkpoint_for_inference(filepath="CIFAR10_unet_200_epochs_inference.pth", model_class=ClassConditionedUnet)
# ---- Gradio interface ----
# CSS to lay the gallery out as a single horizontally-scrollable row.
custom_css = """
#gallery {
display: grid !important;
grid-auto-flow: column !important;
grid-auto-columns: min-content !important;
overflow-x: auto !important;
gap: 10px !important;
}
"""

with gr.Blocks(css=custom_css) as demo:
    gr.Markdown("# CIFAR-10 Diffusion Model")
    gr.Markdown("Select a class and click 'Generate' to create image samples.")

    # Input controls: class picker and sample count.
    with gr.Row():
        class_choice = gr.Radio(
            cifar10_classes,
            label="Select CIFAR-10 Class",
            value=cifar10_classes[0],
        )
        sample_count = gr.Slider(
            minimum=1,
            maximum=10,
            value=4,
            step=1,
            label="Number of Images to Generate",
        )

    # Action buttons.
    with gr.Row():
        single_class_btn = gr.Button("Generate Selected Class")
        all_classes_btn = gr.Button("Generate All Classes")

    # Output gallery (styled by #gallery above).
    with gr.Row():
        gallery = gr.Gallery(label="Generated Images", show_label=True, elem_id="gallery")

    # Wire the buttons to the sampling functions.
    single_class_btn.click(fn=generate_images, inputs=[class_choice, sample_count], outputs=gallery)
    all_classes_btn.click(fn=generate_all_classes, inputs=[], outputs=gallery)

# Run the Gradio app
demo.launch()