Spaces:
Sleeping
Sleeping
Commit
·
6c0e8ac
1
Parent(s):
5050e58
Fix Gradio slider update and CUDA device issues for all attributes
Browse files- app.py +1 -1
- editings/ganspace.py +6 -5
app.py
CHANGED
|
@@ -206,7 +206,7 @@ def build_ui() -> gr.Blocks:
|
|
| 206 |
lo, hi = recommended_range(name)
|
| 207 |
# Keep current value within new bounds
|
| 208 |
new_val = max(lo, min(hi, strength.value if hasattr(strength, "value") else 0))
|
| 209 |
-
return gr.Slider
|
| 210 |
|
| 211 |
attr.change(_on_attr_change, inputs=attr, outputs=strength)
|
| 212 |
|
|
|
|
| 206 |
lo, hi = recommended_range(name)
|
| 207 |
# Keep current value within new bounds
|
| 208 |
new_val = max(lo, min(hi, strength.value if hasattr(strength, "value") else 0))
|
| 209 |
+
return gr.Slider(minimum=lo, maximum=hi, value=new_val)
|
| 210 |
|
| 211 |
attr.change(_on_attr_change, inputs=attr, outputs=strength)
|
| 212 |
|
editings/ganspace.py
CHANGED
|
@@ -6,16 +6,17 @@ def edit(latents, pca, edit_directions):
|
|
| 6 |
for latent in latents:
|
| 7 |
for pca_idx, start, end, strength in edit_directions:
|
| 8 |
delta = get_delta(pca, latent, pca_idx, strength)
|
| 9 |
-
delta_padded = torch.zeros(latent.shape).to(
|
| 10 |
delta_padded[start:end] += delta.repeat(end - start, 1)
|
| 11 |
edit_latents.append(latent + delta_padded)
|
| 12 |
return torch.stack(edit_latents)
|
| 13 |
|
| 14 |
|
| 15 |
def get_delta(pca, latent, idx, strength):
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
|
|
|
| 19 |
w_coord = (
|
| 20 |
torch.sum(w_centered[0].reshape(-1) * lat_comp[idx].reshape(-1)) / lat_std[idx]
|
| 21 |
)
|
|
@@ -26,7 +27,7 @@ def get_delta(pca, latent, idx, strength):
|
|
| 26 |
def edit_latent(latent, pca, edit_direction):
|
| 27 |
pca_idx, start, end, strength = edit_direction
|
| 28 |
delta = get_delta(pca, latent, pca_idx, strength)
|
| 29 |
-
delta_padded = torch.zeros(latent.shape).to(
|
| 30 |
delta_padded[start:end] += delta.repeat(end - start, 1)
|
| 31 |
edit_latent = latent + delta_padded
|
| 32 |
return edit_latent
|
|
|
|
| 6 |
for latent in latents:
|
| 7 |
for pca_idx, start, end, strength in edit_directions:
|
| 8 |
delta = get_delta(pca, latent, pca_idx, strength)
|
| 9 |
+
delta_padded = torch.zeros(latent.shape).to(latent.device)
|
| 10 |
delta_padded[start:end] += delta.repeat(end - start, 1)
|
| 11 |
edit_latents.append(latent + delta_padded)
|
| 12 |
return torch.stack(edit_latents)
|
| 13 |
|
| 14 |
|
| 15 |
def get_delta(pca, latent, idx, strength):
|
| 16 |
+
device = latent.device
|
| 17 |
+
w_centered = latent - pca["mean"].to(device)
|
| 18 |
+
lat_comp = pca["comp"].to(device)
|
| 19 |
+
lat_std = pca["std"].to(device)
|
| 20 |
w_coord = (
|
| 21 |
torch.sum(w_centered[0].reshape(-1) * lat_comp[idx].reshape(-1)) / lat_std[idx]
|
| 22 |
)
|
|
|
|
| 27 |
def edit_latent(latent, pca, edit_direction):
|
| 28 |
pca_idx, start, end, strength = edit_direction
|
| 29 |
delta = get_delta(pca, latent, pca_idx, strength)
|
| 30 |
+
delta_padded = torch.zeros(latent.shape).to(latent.device)
|
| 31 |
delta_padded[start:end] += delta.repeat(end - start, 1)
|
| 32 |
edit_latent = latent + delta_padded
|
| 33 |
return edit_latent
|