Spaces:
Running on Zero
Running on Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -12,7 +12,6 @@ print(f'Using device: {device}')
|
|
| 12 |
# Create a scheduler
|
| 13 |
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule='squaredcos_cap_v2')
|
| 14 |
|
| 15 |
-
|
| 16 |
class ClassConditionedUnet(nn.Module):
|
| 17 |
def __init__(self, num_classes=10, class_emb_size=4):
|
| 18 |
super().__init__()
|
|
@@ -90,7 +89,7 @@ def load_checkpoint_for_inference(filepath, model_class):
|
|
| 90 |
|
| 91 |
|
| 92 |
|
| 93 |
-
def generate_images(selected_class_name, num_samples
|
| 94 |
print(f"Generating {num_samples} samples for class: {selected_class_name}")
|
| 95 |
|
| 96 |
# Map class name to class ID
|
|
@@ -130,16 +129,15 @@ def generate_images(selected_class_name, num_samples=4):
|
|
| 130 |
|
| 131 |
|
| 132 |
|
| 133 |
-
|
| 134 |
-
# Initialize a dummy model (replace with your actual model loading)
|
| 135 |
-
model = load_checkpoint_for_inference(filepath="CIFAR10_unet_v_02_100_epochs_inference.pth", model_class=ClassConditionedUnet)
|
| 136 |
-
|
| 137 |
# CIFAR-10 class names
|
| 138 |
cifar10_classes = [
|
| 139 |
"plane", "car", "bird", "cat", "deer",
|
| 140 |
"dog", "frog", "horse", "ship", "truck"
|
| 141 |
]
|
| 142 |
|
|
|
|
|
|
|
|
|
|
| 143 |
|
| 144 |
# Create the Gradio interface
|
| 145 |
custom_css = """
|
|
@@ -173,16 +171,19 @@ with gr.Blocks(css=custom_css) as demo:
|
|
| 173 |
class_selector = gr.Radio(
|
| 174 |
cifar10_classes, label="Select CIFAR-10 Class", value=cifar10_classes[0]
|
| 175 |
)
|
|
|
|
|
|
|
|
|
|
| 176 |
|
| 177 |
with gr.Row():
|
| 178 |
-
generate_btn = gr.Button("Generate
|
| 179 |
|
| 180 |
with gr.Row():
|
| 181 |
output_gallery = gr.Gallery(label="Generated Images", show_label=True, elem_id="gallery")
|
| 182 |
|
| 183 |
generate_btn.click(
|
| 184 |
fn=generate_images,
|
| 185 |
-
inputs=class_selector,
|
| 186 |
outputs=output_gallery
|
| 187 |
)
|
| 188 |
|
|
|
|
| 12 |
# Create a scheduler
|
| 13 |
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule='squaredcos_cap_v2')
|
| 14 |
|
|
|
|
| 15 |
class ClassConditionedUnet(nn.Module):
|
| 16 |
def __init__(self, num_classes=10, class_emb_size=4):
|
| 17 |
super().__init__()
|
|
|
|
| 89 |
|
| 90 |
|
| 91 |
|
| 92 |
+
def generate_images(selected_class_name, num_samples):
|
| 93 |
print(f"Generating {num_samples} samples for class: {selected_class_name}")
|
| 94 |
|
| 95 |
# Map class name to class ID
|
|
|
|
| 129 |
|
| 130 |
|
| 131 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
# CIFAR-10 class names
|
| 133 |
cifar10_classes = [
|
| 134 |
"plane", "car", "bird", "cat", "deer",
|
| 135 |
"dog", "frog", "horse", "ship", "truck"
|
| 136 |
]
|
| 137 |
|
| 138 |
+
# Initialize a dummy model (replace with your actual model loading)
|
| 139 |
+
model = load_checkpoint_for_inference(filepath="CIFAR10_unet_v_02_100_epochs_inference.pth", model_class=ClassConditionedUnet)
|
| 140 |
+
|
| 141 |
|
| 142 |
# Create the Gradio interface
|
| 143 |
custom_css = """
|
|
|
|
| 171 |
class_selector = gr.Radio(
|
| 172 |
cifar10_classes, label="Select CIFAR-10 Class", value=cifar10_classes[0]
|
| 173 |
)
|
| 174 |
+
num_samples_slider = gr.Slider(
|
| 175 |
+
minimum=1, maximum=10, value=4, step=1, label="Number of Images to Generate"
|
| 176 |
+
)
|
| 177 |
|
| 178 |
with gr.Row():
|
| 179 |
+
generate_btn = gr.Button("Generate")
|
| 180 |
|
| 181 |
with gr.Row():
|
| 182 |
output_gallery = gr.Gallery(label="Generated Images", show_label=True, elem_id="gallery")
|
| 183 |
|
| 184 |
generate_btn.click(
|
| 185 |
fn=generate_images,
|
| 186 |
+
inputs=[class_selector, num_samples_slider],
|
| 187 |
outputs=output_gallery
|
| 188 |
)
|
| 189 |
|