cameron-d committed on
Commit
c252b52
·
verified ·
1 Parent(s): e8a5f5f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +189 -0
app.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from torch import nn
4
+ import numpy as np
5
+ from PIL import Image
6
+ from tqdm.auto import tqdm
7
+ from diffusers import DDPMScheduler, UNet2DModel # Hugging Face diffusers library
8
+
9
+ device = 'mps' if torch.backends.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu'
10
+ print(f'Using device: {device}')
11
+
12
+ # Create a scheduler
13
+ noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule='squaredcos_cap_v2')
14
+
15
class ClassConditionedUnet(nn.Module):
    """Class-conditional UNet.

    The integer class label is embedded and broadcast over the spatial grid,
    then concatenated to the noisy image as extra input channels before being
    fed to an otherwise-unconditional UNet2DModel.
    """

    def __init__(self, num_classes=10, class_emb_size=4):
        super().__init__()

        # Maps each integer class label to a learned vector of size class_emb_size.
        self.class_emb = nn.Embedding(num_classes, class_emb_size)

        # An unconditional UNet whose input channels are widened to carry the
        # (R, G, B) image plus the class-embedding channels.
        self.model = UNet2DModel(
            sample_size=32,                  # target image resolution
            in_channels=3 + class_emb_size,  # RGB + class-conditioning channels
            out_channels=3,                  # number of output channels
            layers_per_block=2,              # ResNet layers per UNet block
            # block_out_channels=(32, 64, 64),
            block_out_channels=(128, 256, 256, 512),  # trying a larger network
            down_block_types=(
                "DownBlock2D",      # plain ResNet downsampling block
                "AttnDownBlock2D",  # downsampling block with spatial self-attention
                "AttnDownBlock2D",
                "AttnDownBlock2D",
            ),
            up_block_types=(
                "AttnUpBlock2D",
                "AttnUpBlock2D",
                "AttnUpBlock2D",  # upsampling block with spatial self-attention
                "UpBlock2D",      # plain ResNet upsampling block
            ),
        )

    def forward(self, x, t, class_labels):
        """Predict the noise residual for noisy images ``x`` at timestep ``t``,
        conditioned on integer ``class_labels``."""
        # x is (batch, channels, height, width); for this model (bs, 3, 32, 32).
        batch, _, height, width = x.shape

        # Embed the labels and broadcast over the spatial grid so the embedding
        # can ride along as extra input channels: (batch, emb) -> (batch, emb, H, W).
        emb = self.class_emb(class_labels)
        emb = emb.view(batch, emb.shape[1], 1, 1).expand(batch, emb.shape[1], height, width)

        # Concatenate image and conditioning along the channel dimension:
        # (batch, 3 + class_emb_size, H, W).
        net_input = torch.cat((x, emb), 1)

        # The UNet sees the widened input alongside the timestep; .sample is the prediction.
        return self.model(net_input, t).sample
# Human-readable labels for the ten CIFAR-10 classes, indexed by class id.
cifar10_classes = [
    "plane",
    "car",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]
def load_checkpoint_for_inference(filepath, model_class):
    """
    Build a fresh model and load trained weights from a checkpoint for inference.

    Parameters
    ----------
    filepath : str
        Path to a checkpoint file containing a plain state_dict.
    model_class : callable
        Zero-argument callable returning the model architecture
        (e.g. the ClassConditionedUnet class itself).

    Returns
    -------
    torch.nn.Module
        The model in eval() mode, moved to the best available device.
    """
    # Instantiate the architecture; the checkpoint holds the weights only.
    model = model_class()

    # map_location='cpu' lets a GPU-saved checkpoint load on any machine
    # (the original crashed on CPU/MPS hosts for CUDA-saved files);
    # weights_only=True avoids unpickling arbitrary objects from the file.
    checkpoint = torch.load(filepath, map_location='cpu', weights_only=True)

    # Load the state dictionary into the model instance.
    model.load_state_dict(checkpoint)

    # Inference mode: disables dropout and freezes batch-norm statistics.
    model.eval()

    # Use the same device-preference order as the rest of the app (MPS -> CUDA
    # -> CPU). The original ignored MPS here, which could leave the model on
    # CPU while sampling tensors were created on 'mps'.
    if torch.backends.mps.is_available():
        target = torch.device('mps')
    elif torch.cuda.is_available():
        target = torch.device('cuda')
    else:
        target = torch.device('cpu')
    model.to(target)

    print("Checkpoint loaded.")

    return model

# Example Usage:
# loaded_model = load_checkpoint_for_inference("model_path", ClassConditionedUnet)
+
98
+ # Initialize a dummy model (replace with your actual model loading)
99
+ model = load_checkpoint_for_inference(filepath="/content/drive/MyDrive/Colab Notebooks/HF_Diffusion_Course/model_v02/CIFAR10_unet_v_02_100_epochs_inference.pth", model_class=ClassConditionedUnet)
100
+
101
+
102
def generate_images(selected_class_name, num_samples=4):
    """
    Sample images of one CIFAR-10 class from the diffusion model.

    Parameters
    ----------
    selected_class_name : str
        One of the names in ``cifar10_classes``.
    num_samples : int, optional
        Number of images to generate (default 4).

    Returns
    -------
    list[PIL.Image.Image]
        The generated 32x32 RGB images; an empty list for an unknown class name.
    """
    print(f"Generating {num_samples} samples for class: {selected_class_name}")

    # Map class name to its integer label (the conditioning input).
    try:
        label = cifar10_classes.index(selected_class_name)
    except ValueError:
        print(f"Error: Class '{selected_class_name}' not found.")
        return []  # Return empty list if class not found

    # Create all tensors on the model's own device: the checkpoint loader may
    # have placed the model somewhere other than the module-level `device`,
    # and a mismatch would crash the forward pass.
    model_device = next(model.parameters()).device

    # Pure-noise starting point: (num images, channels, height, width).
    x = torch.randn(num_samples, 3, 32, 32, device=model_device)
    # One copy of the class label per sample, shape (num_samples,).
    y = torch.full((num_samples,), label, dtype=torch.long, device=model_device)

    # Reverse-diffusion sampling loop; no gradients are needed anywhere here,
    # so hoist no_grad around the whole loop (the original only wrapped the
    # model call, leaving the scheduler step outside).
    timesteps = noise_scheduler.timesteps
    with torch.no_grad():
        for t in tqdm(timesteps, total=len(timesteps)):
            # Predict the noise residual, conditioned on the label.
            residual = model(x, t, y)
            # One denoising step toward t = 0.
            x = noise_scheduler.step(residual, t, x).prev_sample

    # Convert each sample from a [-1, 1] tensor to an 8-bit PIL image.
    generated_pil_images = []
    for sample in x:
        img = sample.detach().cpu().clip(-1, 1) * 0.5 + 0.5  # denormalize to [0, 1]
        img = img.permute(1, 2, 0)  # C, H, W -> H, W, C
        img_array = (img.numpy() * 255).astype(np.uint8)
        generated_pil_images.append(Image.fromarray(img_array))

    return generated_pil_images
+
141
+ # Create the Gradio interface
142
+ custom_css = """
143
+ #gallery {
144
+ display: flex; /* Use flexbox for layout */
145
+ flex-wrap: nowrap; /* Prevent wrapping to multiple rows */
146
+ overflow-x: auto; /* Enable horizontal scrolling if content overflows */
147
+ align-items: flex-start; /* Align items to the start of the cross axis (top) */
148
+ }
149
+ #gallery .thumbnail-item { /* Targeting the individual image containers within the gallery */
150
+ flex-shrink: 0; /* Prevent items from shrinking */
151
+ width: 120px; /* Give each item a fixed width, slightly larger than the image */
152
+ height: auto; /* Allow height to adjust */
153
+ margin: 5px; /* Add some spacing between images */
154
+ display: flex; /* Make the item itself a flex container to center the image */
155
+ justify-content: center; /* Center image horizontally */
156
+ align-items: center; /* Center image vertically */
157
+ }
158
+ #gallery img {
159
+ max-width: 100px !important;
160
+ max-height: 100px !important;
161
+ object-fit: contain; /* Ensure the entire image is visible within its bounds */
162
+ }
163
+ """
164
+
165
+ with gr.Blocks(css=custom_css) as demo:
166
+ gr.Markdown("# CIFAR-10 Diffusion Model")
167
+ gr.Markdown("Select a class and click 'Generate' to create image samples.")
168
+
169
+ with gr.Row():
170
+ class_selector = gr.Radio(
171
+ cifar10_classes, label="Select CIFAR-10 Class", value=cifar10_classes[0]
172
+ )
173
+
174
+ with gr.Row():
175
+ generate_btn = gr.Button("Generate Samples")
176
+
177
+ with gr.Row():
178
+ output_gallery = gr.Gallery(label="Generated Images", show_label=True, elem_id="gallery")
179
+
180
+ generate_btn.click(
181
+ fn=generate_images,
182
+ inputs=class_selector,
183
+ outputs=output_gallery
184
+ )
185
+
186
+
187
+
188
+ # Run the Gradio app
189
+ demo.launch()