# NOTE(review): the three lines below are Hugging Face file-viewer residue that
# was pasted into the source and would break Python parsing; kept as comments.
# cameron-d's picture
# Update app.py
# db7b6f0 verified
import spaces # for using hugging face ZeroGPU
import gradio as gr
import torch
from torch import nn
import numpy as np
from PIL import Image
from tqdm.auto import tqdm
from diffusers import DDPMScheduler, UNet2DModel # Hugging Face diffusers library
device = 'mps' if torch.backends.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using device: {device}')
# Create a scheduler
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule='squaredcos_cap_v2')
class ClassConditionedUnet(nn.Module):
def __init__(self, num_classes=10, class_emb_size=4):
super().__init__()
# The embedding layer will map the class label to a vector of size class_emb_size
self.class_emb = nn.Embedding(num_classes, class_emb_size)
# Self.model is an unconditional UNet with extra input channels to accept the conditioning information (the class embedding)
self.model = UNet2DModel(
sample_size=32, # the target image resolution
in_channels=3 + class_emb_size, # (R, G, B) Plus additional input channels for class cond.
out_channels=3, # the number of output channels
layers_per_block=2, # how many ResNet layers to use per UNet block
# block_out_channels=(32, 64, 64),
block_out_channels=(128, 256, 256, 512), # trying a larger network
down_block_types=(
"DownBlock2D", # a regular ResNet downsampling block
"AttnDownBlock2D", # a ResNet downsampling block with spatial self-attention
"AttnDownBlock2D",
"AttnDownBlock2D",
),
up_block_types=(
"AttnUpBlock2D",
"AttnUpBlock2D",
"AttnUpBlock2D", # a ResNet upsampling block with spatial self-attention
"UpBlock2D", # a regular ResNet upsampling block
),
)
# Our forward method now takes the class labels as an additional argument
def forward(self, x, t, class_labels):
# Shape of x:
bs, ch, w, h = x.shape
# class conditioning in right shape to add as additional input channels
class_cond = self.class_emb(class_labels) # Map to embedding dimension
class_cond = class_cond.view(bs, class_cond.shape[1], 1, 1).expand(bs, class_cond.shape[1], w, h)
# x is shape (bs, 1, 28, 28) and class_cond is now (bs, 4, 28, 28)
# Net input is now x and class cond concatenated together along dimension 1
net_input = torch.cat((x, class_cond), 1) # (bs, 5, 28, 28)
# Feed this to the UNet alongside the timestep and return the prediction
return self.model(net_input, t).sample # (bs, 1, 28, 28)
def load_checkpoint_for_inference(filepath, model_class):
"""
Initializes the model architecture and loads only the trained weights for inference.
"""
# Instantiate the model with the correct architecture/arguments
# (You need the Model Class definition handy)
model = model_class()
# Load the checkpoint file with map_location to handle device mismatch
checkpoint = torch.load(filepath, map_location=torch.device('cpu'))
# Load the state dictionary into the model instance
model.load_state_dict(checkpoint)
# Set the model to evaluation mode for inference
model.eval()
# Move the model to the device (GPU/CPU)
model.to(device)
print(f"Checkpoint loaded.")
return model
@spaces.GPU # Make this function run on GPU
def generate_images(selected_class_name, num_samples, progress=gr.Progress()):
print(f"Generating {num_samples} samples for class: {selected_class_name}")
# Map class name to class ID
try:
label = cifar10_classes.index(selected_class_name)
except ValueError:
print(f"Error: Class '{selected_class_name}' not found.")
return [] # Return empty list if class not found
# Prepare random x to start from, plus a tensor for the desired label y
# num images, num channels, img width, img height
x = torch.randn(num_samples, 3, 32, 32).to(device)
# The label needs to be a tensor of shape (num_samples,) or broadcastable to it
y = torch.full((num_samples,), label, dtype=torch.long).to(device)
# Sampling loop
for i, t in progress.tqdm(enumerate(noise_scheduler.timesteps), total=len(noise_scheduler.timesteps), desc=f"Generating {selected_class_name} images"):
# Get model pred
with torch.no_grad():
residual = model(x, t, y) # Note that we pass in our label
# Update sample with step
x = noise_scheduler.step(residual, t, x).prev_sample
generated_pil_images = []
for j in range(num_samples):
# Convert tensor to PIL Image
img_tensor = x[j].detach().cpu().clip(-1, 1) * 0.5 + 0.5 # Denormalize and move to [0, 1]
img_tensor = img_tensor.permute(1, 2, 0) # C, H, W -> H, W, C
img_array = (img_tensor.numpy() * 255).astype(np.uint8)
pil_img = Image.fromarray(img_array)
generated_pil_images.append(pil_img)
return generated_pil_images
@spaces.GPU # Make this function run on GPU
def generate_all_classes(progress=gr.Progress()):
# Prepare random x to start from, plus some desired labels y
x = torch.randn(10, 3, 32, 32).to(device)
y = torch.tensor([i for i in range(10)]).to(device)
# Sampling loop
for i, t in progress.tqdm(enumerate(noise_scheduler.timesteps), total=len(noise_scheduler.timesteps), desc="Generating all 10 classes"):
# Get model pred
with torch.no_grad():
residual = model(x, t, y) # Again, note that we pass in our labels y
# Update sample with step
x = noise_scheduler.step(residual, t, x).prev_sample
generated_pil_images = []
for j in range(10):
# Convert tensor to PIL Image
img_tensor = x[j].detach().cpu().clip(-1, 1) * 0.5 + 0.5 # Denormalize and move to [0, 1]
img_tensor = img_tensor.permute(1, 2, 0) # C, H, W -> H, W, C
img_array = (img_tensor.numpy() * 255).astype(np.uint8)
pil_img = Image.fromarray(img_array)
generated_pil_images.append(pil_img)
return generated_pil_images
# CIFAR-10 class names
cifar10_classes = [
"plane", "car", "bird", "cat", "deer",
"dog", "frog", "horse", "ship", "truck"
]
# Initialize a dummy model (replace with your actual model loading)
model = load_checkpoint_for_inference(filepath="CIFAR10_unet_200_epochs_inference.pth", model_class=ClassConditionedUnet)
# Create the Gradio interface
custom_css = """
#gallery {
display: grid !important;
grid-auto-flow: column !important;
grid-auto-columns: min-content !important;
overflow-x: auto !important;
gap: 10px !important;
}
"""
with gr.Blocks(css=custom_css) as demo:
gr.Markdown("# CIFAR-10 Diffusion Model")
gr.Markdown("Select a class and click 'Generate' to create image samples.")
with gr.Row():
class_selector = gr.Radio(
cifar10_classes, label="Select CIFAR-10 Class", value=cifar10_classes[0]
)
num_samples_slider = gr.Slider(
minimum=1, maximum=10, value=4, step=1, label="Number of Images to Generate"
)
with gr.Row():
generate_btn = gr.Button("Generate Selected Class")
generate_all_btn = gr.Button("Generate All Classes")
with gr.Row():
output_gallery = gr.Gallery(label="Generated Images", show_label=True, elem_id="gallery")
generate_btn.click(
fn=generate_images,
inputs=[class_selector, num_samples_slider],
outputs=output_gallery
)
generate_all_btn.click(
fn=generate_all_classes,
inputs=[],
outputs=output_gallery
)
# Run the Gradio app
demo.launch()