cameron-d commited on
Commit
647b15b
·
verified ·
1 Parent(s): 0e3b20e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -88,7 +88,7 @@ def load_checkpoint_for_inference(filepath, model_class):
88
 
89
 
90
  @spaces.GPU # Make this function run on GPU
91
- def generate_images(selected_class_name, num_samples):
92
  print(f"Generating {num_samples} samples for class: {selected_class_name}")
93
 
94
  # Map class name to class ID
@@ -105,14 +105,14 @@ def generate_images(selected_class_name, num_samples):
105
  y = torch.full((num_samples,), label, dtype=torch.long).to(device)
106
 
107
  # Sampling loop
108
- for i, t in tqdm(enumerate(noise_scheduler.timesteps)):
109
 
110
  # Get model pred
111
  with torch.no_grad():
112
  residual = model(x, t, y) # Note that we pass in our label
113
 
114
  # Update sample with step
115
- x = noise_scheduler.step(residual, t, x).prev_sample # Correctly update x
116
 
117
  generated_pil_images = []
118
  for j in range(num_samples):
 
88
 
89
 
90
  @spaces.GPU # Make this function run on GPU
91
+ def generate_images(selected_class_name, num_samples, progress=gr.Progress()):
92
  print(f"Generating {num_samples} samples for class: {selected_class_name}")
93
 
94
  # Map class name to class ID
 
105
  y = torch.full((num_samples,), label, dtype=torch.long).to(device)
106
 
107
  # Sampling loop
108
+ for i, t in progress.tqdm(enumerate(noise_scheduler.timesteps), total=len(noise_scheduler.timesteps), desc=f"Generating {selected_class_name} images"):
109
 
110
  # Get model pred
111
  with torch.no_grad():
112
  residual = model(x, t, y) # Note that we pass in our label
113
 
114
  # Update sample with step
115
+ x = noise_scheduler.step(residual, t, x).prev_sample
116
 
117
  generated_pil_images = []
118
  for j in range(num_samples):