Spaces:
Running on Zero
Running on Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -12,6 +12,7 @@ print(f'Using device: {device}')
|
|
| 12 |
# Create a scheduler
|
| 13 |
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule='squaredcos_cap_v2')
|
| 14 |
|
|
|
|
| 15 |
class ClassConditionedUnet(nn.Module):
|
| 16 |
def __init__(self, num_classes=10, class_emb_size=4):
|
| 17 |
super().__init__()
|
|
@@ -58,12 +59,6 @@ class ClassConditionedUnet(nn.Module):
|
|
| 58 |
return self.model(net_input, t).sample # (bs, 1, 28, 28)
|
| 59 |
|
| 60 |
|
| 61 |
-
# CIFAR-10 class names
|
| 62 |
-
cifar10_classes = [
|
| 63 |
-
"plane", "car", "bird", "cat", "deer",
|
| 64 |
-
"dog", "frog", "horse", "ship", "truck"
|
| 65 |
-
]
|
| 66 |
-
|
| 67 |
|
| 68 |
def load_checkpoint_for_inference(filepath, model_class):
|
| 69 |
"""
|
|
@@ -95,9 +90,6 @@ def load_checkpoint_for_inference(filepath, model_class):
|
|
| 95 |
# loaded_model = load_checkpoint_for_inference("model_path", ClassConditionedUnet)
|
| 96 |
|
| 97 |
|
| 98 |
-
# Initialize a dummy model (replace with your actual model loading)
|
| 99 |
-
model = load_checkpoint_for_inference(filepath="/content/drive/MyDrive/Colab Notebooks/HF_Diffusion_Course/model_v02/CIFAR10_unet_v_02_100_epochs_inference.pth", model_class=ClassConditionedUnet)
|
| 100 |
-
|
| 101 |
|
| 102 |
def generate_images(selected_class_name, num_samples=4):
|
| 103 |
print(f"Generating {num_samples} samples for class: {selected_class_name}")
|
|
@@ -138,6 +130,18 @@ def generate_images(selected_class_name, num_samples=4):
|
|
| 138 |
|
| 139 |
|
| 140 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
# Create the Gradio interface
|
| 142 |
custom_css = """
|
| 143 |
#gallery {
|
|
|
|
| 12 |
# Create a scheduler
|
| 13 |
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule='squaredcos_cap_v2')
|
| 14 |
|
| 15 |
+
|
| 16 |
class ClassConditionedUnet(nn.Module):
|
| 17 |
def __init__(self, num_classes=10, class_emb_size=4):
|
| 18 |
super().__init__()
|
|
|
|
| 59 |
return self.model(net_input, t).sample # (bs, 1, 28, 28)
|
| 60 |
|
| 61 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
|
| 63 |
def load_checkpoint_for_inference(filepath, model_class):
|
| 64 |
"""
|
|
|
|
| 90 |
# loaded_model = load_checkpoint_for_inference("model_path", ClassConditionedUnet)
|
| 91 |
|
| 92 |
|
|
|
|
|
|
|
|
|
|
| 93 |
|
| 94 |
def generate_images(selected_class_name, num_samples=4):
|
| 95 |
print(f"Generating {num_samples} samples for class: {selected_class_name}")
|
|
|
|
| 130 |
|
| 131 |
|
| 132 |
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
# Initialize a dummy model (replace with your actual model loading)
|
| 136 |
+
model = load_checkpoint_for_inference(filepath="CIFAR10_unet_v_02_100_epochs_inference.pth", model_class=ClassConditionedUnet)
|
| 137 |
+
|
| 138 |
+
# CIFAR-10 class names
|
| 139 |
+
cifar10_classes = [
|
| 140 |
+
"plane", "car", "bird", "cat", "deer",
|
| 141 |
+
"dog", "frog", "horse", "ship", "truck"
|
| 142 |
+
]
|
| 143 |
+
|
| 144 |
+
|
| 145 |
# Create the Gradio interface
|
| 146 |
custom_css = """
|
| 147 |
#gallery {
|