File size: 7,752 Bytes
042ee08
c252b52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d98797f
 
c252b52
 
 
 
 
 
 
d98797f
c252b52
 
 
 
 
 
 
 
0e3b20e
647b15b
c252b52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
647b15b
c252b52
 
 
 
 
 
647b15b
c252b52
 
 
 
 
 
 
 
 
 
 
 
 
0e3b20e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c252b52
b27a722
 
 
 
 
 
 
c45fd6e
db7b6f0
c45fd6e
b27a722
c252b52
 
 
0e3b20e
 
 
 
 
c252b52
 
 
0e3b20e
c252b52
 
 
 
 
 
 
c45fd6e
 
 
c252b52
 
0e3b20e
 
c252b52
 
 
 
 
 
c45fd6e
c252b52
 
 
0e3b20e
 
 
 
 
c252b52
 
 
0e3b20e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
import spaces # for using hugging face ZeroGPU
import gradio as gr
import torch
from torch import nn
import numpy as np
from PIL import Image
from tqdm.auto import tqdm
from diffusers import DDPMScheduler, UNet2DModel # Hugging Face diffusers library

# Pick the best available accelerator: Apple MPS first, then CUDA, else CPU.
device = 'mps' if torch.backends.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using device: {device}')

# Create a scheduler: DDPM with 1000 timesteps and the cosine-squared
# ("squaredcos_cap_v2") beta schedule; used by both sampling loops below.
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule='squaredcos_cap_v2')

class ClassConditionedUnet(nn.Module):
  """UNet2DModel conditioned on a class label.

  The label is embedded to a `class_emb_size`-dim vector, broadcast over the
  spatial dimensions, and concatenated to the image along the channel axis,
  so the UNet simply sees extra input channels carrying the class identity.
  """

  def __init__(self, num_classes=10, class_emb_size=4):
    super().__init__()

    # The embedding layer will map the class label to a vector of size class_emb_size
    self.class_emb = nn.Embedding(num_classes, class_emb_size)

    # Self.model is an unconditional UNet with extra input channels to accept the conditioning information (the class embedding)
    self.model = UNet2DModel(
        sample_size=32,           # the target image resolution
        in_channels=3 + class_emb_size, # (R, G, B) Plus additional input channels for class cond.
        out_channels=3,           # the number of output channels
        layers_per_block=2,       # how many ResNet layers to use per UNet block
        # block_out_channels=(32, 64, 64),
        block_out_channels=(128, 256, 256, 512), # trying a larger network
        down_block_types=(
            "DownBlock2D",        # a regular ResNet downsampling block
            "AttnDownBlock2D",    # a ResNet downsampling block with spatial self-attention
            "AttnDownBlock2D",
            "AttnDownBlock2D",
            ),
        up_block_types=(
            "AttnUpBlock2D",
            "AttnUpBlock2D",
            "AttnUpBlock2D",      # a ResNet upsampling block with spatial self-attention
            "UpBlock2D",          # a regular ResNet upsampling block
            ),
    )

  # Our forward method now takes the class labels as an additional argument
  def forward(self, x, t, class_labels):
    """Predict the noise residual for images x at timestep t, conditioned on class_labels."""
    # Shape of x: (bs, 3, 32, 32). NOTE(review): the unpacked names assume
    # (width, height) order; PyTorch convention is (..., H, W) — harmless
    # here because the images are square.
    bs, ch, w, h = x.shape

    # class conditioning in right shape to add as additional input channels
    class_cond = self.class_emb(class_labels) # Map to embedding dimension: (bs, class_emb_size)
    class_cond = class_cond.view(bs, class_cond.shape[1], 1, 1).expand(bs, class_cond.shape[1], w, h)
    # x is shape (bs, 3, 32, 32) and class_cond is now (bs, class_emb_size, 32, 32)

    # Net input is now x and class cond concatenated together along dimension 1
    net_input = torch.cat((x, class_cond), 1) # (bs, 3 + class_emb_size, 32, 32)

    # Feed this to the UNet alongside the timestep and return the prediction
    return self.model(net_input, t).sample # (bs, 3, 32, 32)



def load_checkpoint_for_inference(filepath, model_class, target_device=None):
    """
    Initializes the model architecture and loads only the trained weights for inference.

    Args:
        filepath: Path to a checkpoint file containing a state_dict saved with torch.save.
        model_class: Zero-argument callable that builds a model instance whose
            architecture matches the saved weights (e.g. ClassConditionedUnet).
        target_device: Optional device to move the model to; when None, falls
            back to the module-level `device` (backward-compatible default).

    Returns:
        The model instance in eval mode, moved to the target device.
    """

    # Instantiate the model with the correct architecture/arguments
    # (You need the Model Class definition handy)
    model = model_class()

    # map_location='cpu' lets a GPU-saved checkpoint load on any machine.
    # weights_only=True prevents arbitrary-code execution from a malicious
    # pickle — safe here because the file holds a plain state_dict of tensors.
    checkpoint = torch.load(filepath, map_location=torch.device('cpu'), weights_only=True)

    # Load the state dictionary into the model instance
    model.load_state_dict(checkpoint)

    # eval() disables dropout / uses batch-norm running stats for inference
    model.eval()

    # Move the model to the requested device (or the module-level default)
    model.to(target_device if target_device is not None else device)

    print("Checkpoint loaded.")

    return model



@spaces.GPU # Make this function run on GPU
def generate_images(selected_class_name, num_samples, progress=gr.Progress()):
    """Sample `num_samples` CIFAR-10 images of one class via the DDPM reverse loop.

    Args:
        selected_class_name: One of the names in `cifar10_classes`.
        num_samples: Number of images to generate. Gradio sliders can deliver
            this as a float, so it is coerced to int before use.
        progress: Gradio progress tracker shown during the sampling loop.

    Returns:
        A list of PIL images, or an empty list if the class name is unknown.
    """
    num_samples = int(num_samples)  # slider value may arrive as a float
    print(f"Generating {num_samples} samples for class: {selected_class_name}")

    # Map class name to class ID
    try:
        label = cifar10_classes.index(selected_class_name)
    except ValueError:
        print(f"Error: Class '{selected_class_name}' not found.")
        return [] # Return empty list if class not found

    # Start from pure Gaussian noise: (num images, channels, height, width)
    x = torch.randn(num_samples, 3, 32, 32).to(device)
    # One label per sample, shape (num_samples,)
    y = torch.full((num_samples,), label, dtype=torch.long).to(device)

    # Reverse-diffusion sampling loop; no_grad hoisted around the whole loop
    # since no gradients are ever needed during inference.
    with torch.no_grad():
        for t in progress.tqdm(noise_scheduler.timesteps, desc=f"Generating {selected_class_name} images"):
            # Predict the noise residual, conditioned on the class labels
            residual = model(x, t, y)
            # Step the scheduler to the previous (less noisy) sample
            x = noise_scheduler.step(residual, t, x).prev_sample

    # Convert each [-1, 1] tensor to a PIL image
    generated_pil_images = []
    for sample in x:
        img_tensor = sample.detach().cpu().clip(-1, 1) * 0.5 + 0.5 # Denormalize and move to [0, 1]
        img_tensor = img_tensor.permute(1, 2, 0) # C, H, W -> H, W, C
        img_array = (img_tensor.numpy() * 255).astype(np.uint8)
        generated_pil_images.append(Image.fromarray(img_array))

    return generated_pil_images


@spaces.GPU # Make this function run on GPU
def generate_all_classes(progress=gr.Progress()):
    """Generate one sample for each of the 10 CIFAR-10 classes.

    Returns a list of 10 PIL images, one per class label 0-9.
    """
    # Begin from Gaussian noise: one 3x32x32 image per class
    x = torch.randn(10, 3, 32, 32).to(device)
    # Labels 0..9, one per image
    y = torch.arange(10).to(device)

    # Reverse-diffusion loop over every scheduler timestep
    total_steps = len(noise_scheduler.timesteps)
    for step, t in progress.tqdm(enumerate(noise_scheduler.timesteps), total=total_steps, desc="Generating all 10 classes"):
        with torch.no_grad():
            # Conditional noise prediction for all 10 labels at once
            residual = model(x, t, y)
        # Advance each sample one step toward the clean image
        x = noise_scheduler.step(residual, t, x).prev_sample

    def _to_pil(sample):
        # Denormalize from [-1, 1] to [0, 255] uint8, channel-last, wrap as PIL
        denorm = sample.detach().cpu().clip(-1, 1) * 0.5 + 0.5
        arr = (denorm.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
        return Image.fromarray(arr)

    return [_to_pil(x[idx]) for idx in range(10)]


# CIFAR-10 class names, indexed by label id (0-9)
cifar10_classes = [
    "plane", "car", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck"
]

# Load the trained conditional UNet weights and prepare the model for
# inference (eval mode, moved to `device`).
model = load_checkpoint_for_inference(filepath="CIFAR10_unet_200_epochs_inference.pth", model_class=ClassConditionedUnet)


# Create the Gradio interface
# CSS forcing the gallery into a single horizontally-scrollable row.
custom_css = """
#gallery {
    display: grid !important;
    grid-auto-flow: column !important;
    grid-auto-columns: min-content !important;
    overflow-x: auto !important;
    gap: 10px !important;
}
"""

with gr.Blocks(css=custom_css) as demo:
    gr.Markdown("# CIFAR-10 Diffusion Model")
    gr.Markdown("Select a class and click 'Generate' to create image samples.")

    # Inputs: class picker and sample-count slider
    with gr.Row():
        class_selector = gr.Radio(
            cifar10_classes, label="Select CIFAR-10 Class", value=cifar10_classes[0]
        )
        num_samples_slider = gr.Slider(
            minimum=1, maximum=10, value=4, step=1, label="Number of Images to Generate"
        )

    # Two generation modes: a single selected class, or one image per class
    with gr.Row():
        generate_btn = gr.Button("Generate Selected Class")
        generate_all_btn = gr.Button("Generate All Classes")

    # Both buttons render their results into the same gallery
    with gr.Row():
        output_gallery = gr.Gallery(label="Generated Images", show_label=True, elem_id="gallery")

    generate_btn.click(
        fn=generate_images,
        inputs=[class_selector, num_samples_slider],
        outputs=output_gallery
    )

    generate_all_btn.click(
        fn=generate_all_classes,
        inputs=[],
        outputs=output_gallery
    )


# Run the Gradio app
demo.launch()