cameron-d commited on
Commit
c45fd6e
·
verified ·
1 Parent(s): d98797f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -8
app.py CHANGED
@@ -12,7 +12,6 @@ print(f'Using device: {device}')
12
  # Create a scheduler
13
  noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule='squaredcos_cap_v2')
14
 
15
-
16
  class ClassConditionedUnet(nn.Module):
17
  def __init__(self, num_classes=10, class_emb_size=4):
18
  super().__init__()
@@ -90,7 +89,7 @@ def load_checkpoint_for_inference(filepath, model_class):
90
 
91
 
92
 
93
- def generate_images(selected_class_name, num_samples=4):
94
  print(f"Generating {num_samples} samples for class: {selected_class_name}")
95
 
96
  # Map class name to class ID
@@ -130,16 +129,15 @@ def generate_images(selected_class_name, num_samples=4):
130
 
131
 
132
 
133
-
134
- # Initialize a dummy model (replace with your actual model loading)
135
- model = load_checkpoint_for_inference(filepath="CIFAR10_unet_v_02_100_epochs_inference.pth", model_class=ClassConditionedUnet)
136
-
137
  # CIFAR-10 class names
138
  cifar10_classes = [
139
  "plane", "car", "bird", "cat", "deer",
140
  "dog", "frog", "horse", "ship", "truck"
141
  ]
142
 
 
 
 
143
 
144
  # Create the Gradio interface
145
  custom_css = """
@@ -173,16 +171,19 @@ with gr.Blocks(css=custom_css) as demo:
173
  class_selector = gr.Radio(
174
  cifar10_classes, label="Select CIFAR-10 Class", value=cifar10_classes[0]
175
  )
 
 
 
176
 
177
  with gr.Row():
178
- generate_btn = gr.Button("Generate Samples")
179
 
180
  with gr.Row():
181
  output_gallery = gr.Gallery(label="Generated Images", show_label=True, elem_id="gallery")
182
 
183
  generate_btn.click(
184
  fn=generate_images,
185
- inputs=class_selector,
186
  outputs=output_gallery
187
  )
188
 
 
12
  # Create a scheduler
13
  noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule='squaredcos_cap_v2')
14
 
 
15
  class ClassConditionedUnet(nn.Module):
16
  def __init__(self, num_classes=10, class_emb_size=4):
17
  super().__init__()
 
89
 
90
 
91
 
92
+ def generate_images(selected_class_name, num_samples):
93
  print(f"Generating {num_samples} samples for class: {selected_class_name}")
94
 
95
  # Map class name to class ID
 
129
 
130
 
131
 
 
 
 
 
132
  # CIFAR-10 class names
133
  cifar10_classes = [
134
  "plane", "car", "bird", "cat", "deer",
135
  "dog", "frog", "horse", "ship", "truck"
136
  ]
137
 
138
+ # Initialize a dummy model (replace with your actual model loading)
139
+ model = load_checkpoint_for_inference(filepath="CIFAR10_unet_v_02_100_epochs_inference.pth", model_class=ClassConditionedUnet)
140
+
141
 
142
  # Create the Gradio interface
143
  custom_css = """
 
171
  class_selector = gr.Radio(
172
  cifar10_classes, label="Select CIFAR-10 Class", value=cifar10_classes[0]
173
  )
174
+ num_samples_slider = gr.Slider(
175
+ minimum=1, maximum=10, value=4, step=1, label="Number of Images to Generate"
176
+ )
177
 
178
  with gr.Row():
179
+ generate_btn = gr.Button("Generate")
180
 
181
  with gr.Row():
182
  output_gallery = gr.Gallery(label="Generated Images", show_label=True, elem_id="gallery")
183
 
184
  generate_btn.click(
185
  fn=generate_images,
186
+ inputs=[class_selector, num_samples_slider],
187
  outputs=output_gallery
188
  )
189