Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -88,7 +88,7 @@ def load_checkpoint_for_inference(filepath, model_class):
|
|
| 88 |
|
| 89 |
|
| 90 |
@spaces.GPU # Make this function run on GPU
|
| 91 |
-
def generate_images(selected_class_name, num_samples):
|
| 92 |
print(f"Generating {num_samples} samples for class: {selected_class_name}")
|
| 93 |
|
| 94 |
# Map class name to class ID
|
|
@@ -105,14 +105,14 @@ def generate_images(selected_class_name, num_samples):
|
|
| 105 |
y = torch.full((num_samples,), label, dtype=torch.long).to(device)
|
| 106 |
|
| 107 |
# Sampling loop
|
| 108 |
-
for i, t in tqdm(enumerate(noise_scheduler.timesteps)):
|
| 109 |
|
| 110 |
# Get model pred
|
| 111 |
with torch.no_grad():
|
| 112 |
residual = model(x, t, y) # Note that we pass in our label
|
| 113 |
|
| 114 |
# Update sample with step
|
| 115 |
-
x = noise_scheduler.step(residual, t, x).prev_sample
|
| 116 |
|
| 117 |
generated_pil_images = []
|
| 118 |
for j in range(num_samples):
|
|
|
|
| 88 |
|
| 89 |
|
| 90 |
@spaces.GPU # Make this function run on GPU
|
| 91 |
+
def generate_images(selected_class_name, num_samples, progress=gr.Progress()):
|
| 92 |
print(f"Generating {num_samples} samples for class: {selected_class_name}")
|
| 93 |
|
| 94 |
# Map class name to class ID
|
|
|
|
| 105 |
y = torch.full((num_samples,), label, dtype=torch.long).to(device)
|
| 106 |
|
| 107 |
# Sampling loop
|
| 108 |
+
for i, t in progress.tqdm(enumerate(noise_scheduler.timesteps), total=len(noise_scheduler.timesteps), desc=f"Generating {selected_class_name} images"):
|
| 109 |
|
| 110 |
# Get model pred
|
| 111 |
with torch.no_grad():
|
| 112 |
residual = model(x, t, y) # Note that we pass in our label
|
| 113 |
|
| 114 |
# Update sample with step
|
| 115 |
+
x = noise_scheduler.step(residual, t, x).prev_sample
|
| 116 |
|
| 117 |
generated_pil_images = []
|
| 118 |
for j in range(num_samples):
|