# CIFAR-10 class-conditional diffusion demo — Hugging Face Space (runs on ZeroGPU).
# (Scraped page chrome / commit-hash residue removed; it was not part of the program.)
import spaces # for using hugging face ZeroGPU
import gradio as gr
import torch
from torch import nn
import numpy as np
from PIL import Image
from tqdm.auto import tqdm
from diffusers import DDPMScheduler, UNet2DModel # Hugging Face diffusers library
# Pick the best available accelerator: Apple MPS first, then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Using device: {device}')

# DDPM noise scheduler: 1000 diffusion steps with a squared-cosine beta schedule.
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule='squaredcos_cap_v2')
class ClassConditionedUnet(nn.Module):
    """Class-conditional UNet for 32x32 RGB (CIFAR-10) diffusion.

    Each class label is embedded into a small vector, broadcast over the
    spatial dimensions, and concatenated to the noisy image as extra input
    channels, turning an unconditional UNet into a conditional one.
    """

    def __init__(self, num_classes=10, class_emb_size=4):
        super().__init__()
        # Maps each class label to a learned vector of size class_emb_size.
        self.class_emb = nn.Embedding(num_classes, class_emb_size)
        # An unconditional UNet with extra input channels that carry the
        # class-embedding conditioning information.
        self.model = UNet2DModel(
            sample_size=32,                   # target image resolution
            in_channels=3 + class_emb_size,   # RGB plus class-conditioning channels
            out_channels=3,                   # RGB output
            layers_per_block=2,               # ResNet layers per UNet block
            block_out_channels=(128, 256, 256, 512),  # larger network than the course's (32, 64, 64) baseline
            down_block_types=(
                "DownBlock2D",       # a regular ResNet downsampling block
                "AttnDownBlock2D",   # a ResNet downsampling block with spatial self-attention
                "AttnDownBlock2D",
                "AttnDownBlock2D",
            ),
            up_block_types=(
                "AttnUpBlock2D",
                "AttnUpBlock2D",
                "AttnUpBlock2D",     # a ResNet upsampling block with spatial self-attention
                "UpBlock2D",         # a regular ResNet upsampling block
            ),
        )

    def forward(self, x, t, class_labels):
        """Predict the noise residual for noisy images `x` at timestep `t`,
        conditioned on integer `class_labels`."""
        # PyTorch image tensors are (batch, channels, height, width);
        # the original destructured these as (bs, ch, w, h), which only
        # worked because the images are square.
        bs, ch, h, w = x.shape
        # Embed the labels and broadcast the embedding over the spatial dims
        # so it can be concatenated as extra input channels.
        class_cond = self.class_emb(class_labels)  # (bs, class_emb_size)
        class_cond = class_cond.view(bs, class_cond.shape[1], 1, 1).expand(bs, class_cond.shape[1], h, w)
        # x is (bs, 3, 32, 32); class_cond is (bs, class_emb_size, 32, 32).
        net_input = torch.cat((x, class_cond), 1)  # (bs, 3 + class_emb_size, 32, 32)
        # Feed the conditioned input to the UNet alongside the timestep.
        return self.model(net_input, t).sample  # (bs, 3, 32, 32)
def load_checkpoint_for_inference(filepath, model_class):
    """Instantiate `model_class`, load trained weights from `filepath`, and
    prepare the model for inference.

    Args:
        filepath: Path to a checkpoint holding a state_dict saved with torch.save.
        model_class: Zero-argument callable returning a model whose architecture
            matches the checkpoint.

    Returns:
        The model in eval mode, moved to the module-level `device`.
    """
    # Instantiate the model with the correct architecture.
    model = model_class()
    # map_location='cpu' handles checkpoints saved on a different device;
    # weights_only=True refuses to unpickle arbitrary objects, which is the
    # safe choice for a pure state_dict checkpoint.
    checkpoint = torch.load(filepath, map_location=torch.device('cpu'), weights_only=True)
    model.load_state_dict(checkpoint)
    # eval() disables dropout / batch-norm updates for inference.
    model.eval()
    model.to(device)
    print("Checkpoint loaded.")
    return model
@spaces.GPU  # Make this function run on GPU (Hugging Face ZeroGPU)
def generate_images(selected_class_name, num_samples, progress=gr.Progress()):
    """Sample `num_samples` 32x32 images of the selected CIFAR-10 class.

    Returns:
        A list of PIL images; empty list if the class name is unknown.
    """
    print(f"Generating {num_samples} samples for class: {selected_class_name}")
    # Map the class name to its integer label.
    try:
        label = cifar10_classes.index(selected_class_name)
    except ValueError:
        print(f"Error: Class '{selected_class_name}' not found.")
        return []  # unknown class: nothing to generate
    # Gradio sliders may deliver the value as a float; tensor shapes need ints.
    num_samples = int(num_samples)
    # Start from pure noise: (num images, channels, height, width).
    x = torch.randn(num_samples, 3, 32, 32).to(device)
    # One label per sample, shape (num_samples,).
    y = torch.full((num_samples,), label, dtype=torch.long).to(device)
    # Reverse-diffusion sampling loop over the scheduler's timesteps.
    for t in progress.tqdm(noise_scheduler.timesteps, total=len(noise_scheduler.timesteps), desc=f"Generating {selected_class_name} images"):
        with torch.no_grad():
            residual = model(x, t, y)  # class-conditioned noise prediction
        # Step the scheduler to obtain the previous (less noisy) sample.
        x = noise_scheduler.step(residual, t, x).prev_sample
    # Convert each sample tensor to a PIL image.
    generated_pil_images = []
    for j in range(num_samples):
        img_tensor = x[j].detach().cpu().clip(-1, 1) * 0.5 + 0.5  # denormalize [-1, 1] -> [0, 1]
        img_tensor = img_tensor.permute(1, 2, 0)  # C, H, W -> H, W, C
        img_array = (img_tensor.numpy() * 255).astype(np.uint8)
        generated_pil_images.append(Image.fromarray(img_array))
    return generated_pil_images
@spaces.GPU  # Make this function run on GPU
def generate_all_classes(progress=gr.Progress()):
    """Generate one sample image per CIFAR-10 class and return them as PIL images."""
    # Pure-noise starting point: one 32x32 RGB sample for each of the 10 classes.
    x = torch.randn(10, 3, 32, 32).to(device)
    # Labels 0..9, one per sample.
    y = torch.arange(10).to(device)
    # Reverse-diffusion loop over the scheduler's timesteps.
    for t in progress.tqdm(noise_scheduler.timesteps, total=len(noise_scheduler.timesteps), desc="Generating all 10 classes"):
        with torch.no_grad():
            residual = model(x, t, y)  # class-conditioned noise prediction
        x = noise_scheduler.step(residual, t, x).prev_sample
    # Denormalize each sample from [-1, 1] to [0, 255] and convert to PIL.
    images = []
    for sample in x:
        pixels = sample.detach().cpu().clip(-1, 1) * 0.5 + 0.5
        pixels = pixels.permute(1, 2, 0)  # C, H, W -> H, W, C
        images.append(Image.fromarray((pixels.numpy() * 255).astype(np.uint8)))
    return images
# The ten CIFAR-10 class names, indexed by their integer label (0-9).
cifar10_classes = "plane car bird cat deer dog frog horse ship truck".split()
# Load the trained class-conditional UNet weights for inference.
# NOTE(review): the checkpoint file must ship with this Space — confirm it exists in the repo.
model = load_checkpoint_for_inference(filepath="CIFAR10_unet_200_epochs_inference.pth", model_class=ClassConditionedUnet)
# Create the Gradio interface
custom_css = """
#gallery {
display: grid !important;
grid-auto-flow: column !important;
grid-auto-columns: min-content !important;
overflow-x: auto !important;
gap: 10px !important;
}
"""
# Assemble the Gradio UI: class picker + sample-count slider, two generate
# buttons, and a shared output gallery.
with gr.Blocks(css=custom_css) as demo:
    gr.Markdown("# CIFAR-10 Diffusion Model")
    gr.Markdown("Select a class and click 'Generate' to create image samples.")
    with gr.Row():
        # Radio buttons choose which CIFAR-10 class to sample.
        class_selector = gr.Radio(
            cifar10_classes, label="Select CIFAR-10 Class", value=cifar10_classes[0]
        )
        # How many images to draw for the selected class (1-10).
        num_samples_slider = gr.Slider(
            minimum=1, maximum=10, value=4, step=1, label="Number of Images to Generate"
        )
    with gr.Row():
        generate_btn = gr.Button("Generate Selected Class")
        generate_all_btn = gr.Button("Generate All Classes")
    with gr.Row():
        # elem_id="gallery" hooks this component into the custom CSS.
        output_gallery = gr.Gallery(label="Generated Images", show_label=True, elem_id="gallery")
    # Wire each button to its generation function; both render into the same gallery.
    generate_btn.click(
        fn=generate_images,
        inputs=[class_selector, num_samples_slider],
        outputs=output_gallery
    )
    generate_all_btn.click(
        fn=generate_all_classes,
        inputs=[],
        outputs=output_gallery
    )
# Run the Gradio app. The original line ended with a stray " |" scrape
# artifact, which was a syntax error; the __main__ guard prevents the server
# from auto-launching if this module is ever imported.
if __name__ == "__main__":
    demo.launch()