cameron-d commited on
Commit
b27a722
·
verified ·
1 Parent(s): b93505c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -9
app.py CHANGED
@@ -12,6 +12,7 @@ print(f'Using device: {device}')
12
  # Create a scheduler
13
  noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule='squaredcos_cap_v2')
14
 
 
15
  class ClassConditionedUnet(nn.Module):
16
  def __init__(self, num_classes=10, class_emb_size=4):
17
  super().__init__()
@@ -58,12 +59,6 @@ class ClassConditionedUnet(nn.Module):
58
  return self.model(net_input, t).sample # (bs, 1, 28, 28)
59
 
60
 
61
- # CIFAR-10 class names
62
- cifar10_classes = [
63
- "plane", "car", "bird", "cat", "deer",
64
- "dog", "frog", "horse", "ship", "truck"
65
- ]
66
-
67
 
68
  def load_checkpoint_for_inference(filepath, model_class):
69
  """
@@ -95,9 +90,6 @@ def load_checkpoint_for_inference(filepath, model_class):
95
  # loaded_model = load_checkpoint_for_inference("model_path", ClassConditionedUnet)
96
 
97
 
98
- # Initialize a dummy model (replace with your actual model loading)
99
- model = load_checkpoint_for_inference(filepath="/content/drive/MyDrive/Colab Notebooks/HF_Diffusion_Course/model_v02/CIFAR10_unet_v_02_100_epochs_inference.pth", model_class=ClassConditionedUnet)
100
-
101
 
102
  def generate_images(selected_class_name, num_samples=4):
103
  print(f"Generating {num_samples} samples for class: {selected_class_name}")
@@ -138,6 +130,18 @@ def generate_images(selected_class_name, num_samples=4):
138
 
139
 
140
 
 
 
 
 
 
 
 
 
 
 
 
 
141
  # Create the Gradio interface
142
  custom_css = """
143
  #gallery {
 
12
  # Create a scheduler
13
  noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule='squaredcos_cap_v2')
14
 
15
+
16
  class ClassConditionedUnet(nn.Module):
17
  def __init__(self, num_classes=10, class_emb_size=4):
18
  super().__init__()
 
59
  return self.model(net_input, t).sample # (bs, 1, 28, 28)
60
 
61
 
 
 
 
 
 
 
62
 
63
  def load_checkpoint_for_inference(filepath, model_class):
64
  """
 
90
  # loaded_model = load_checkpoint_for_inference("model_path", ClassConditionedUnet)
91
 
92
 
 
 
 
93
 
94
  def generate_images(selected_class_name, num_samples=4):
95
  print(f"Generating {num_samples} samples for class: {selected_class_name}")
 
130
 
131
 
132
 
133
+
134
+
135
+ # Initialize a dummy model (replace with your actual model loading)
136
+ model = load_checkpoint_for_inference(filepath="CIFAR10_unet_v_02_100_epochs_inference.pth", model_class=ClassConditionedUnet)
137
+
138
+ # CIFAR-10 class names
139
+ cifar10_classes = [
140
+ "plane", "car", "bird", "cat", "deer",
141
+ "dog", "frog", "horse", "ship", "truck"
142
+ ]
143
+
144
+
145
  # Create the Gradio interface
146
  custom_css = """
147
  #gallery {