LogicGoInfotechSpaces commited on
Commit
6c0e8ac
·
1 Parent(s): 5050e58

Fix Gradio slider update and CUDA device issues for all attributes

Browse files
Files changed (2) hide show
  1. app.py +1 -1
  2. editings/ganspace.py +6 -5
app.py CHANGED
@@ -206,7 +206,7 @@ def build_ui() -> gr.Blocks:
206
  lo, hi = recommended_range(name)
207
  # Keep current value within new bounds
208
  new_val = max(lo, min(hi, strength.value if hasattr(strength, "value") else 0))
209
- return gr.Slider.update(minimum=lo, maximum=hi, value=new_val)
210
 
211
  attr.change(_on_attr_change, inputs=attr, outputs=strength)
212
 
 
206
  lo, hi = recommended_range(name)
207
  # Keep current value within new bounds
208
  new_val = max(lo, min(hi, strength.value if hasattr(strength, "value") else 0))
209
+ return gr.Slider(minimum=lo, maximum=hi, value=new_val)
210
 
211
  attr.change(_on_attr_change, inputs=attr, outputs=strength)
212
 
editings/ganspace.py CHANGED
@@ -6,16 +6,17 @@ def edit(latents, pca, edit_directions):
6
  for latent in latents:
7
  for pca_idx, start, end, strength in edit_directions:
8
  delta = get_delta(pca, latent, pca_idx, strength)
9
- delta_padded = torch.zeros(latent.shape).to("cuda")
10
  delta_padded[start:end] += delta.repeat(end - start, 1)
11
  edit_latents.append(latent + delta_padded)
12
  return torch.stack(edit_latents)
13
 
14
 
15
  def get_delta(pca, latent, idx, strength):
16
- w_centered = latent - pca["mean"].to("cuda")
17
- lat_comp = pca["comp"].to("cuda")
18
- lat_std = pca["std"].to("cuda")
 
19
  w_coord = (
20
  torch.sum(w_centered[0].reshape(-1) * lat_comp[idx].reshape(-1)) / lat_std[idx]
21
  )
@@ -26,7 +27,7 @@ def get_delta(pca, latent, idx, strength):
26
  def edit_latent(latent, pca, edit_direction):
27
  pca_idx, start, end, strength = edit_direction
28
  delta = get_delta(pca, latent, pca_idx, strength)
29
- delta_padded = torch.zeros(latent.shape).to("cuda")
30
  delta_padded[start:end] += delta.repeat(end - start, 1)
31
  edit_latent = latent + delta_padded
32
  return edit_latent
 
6
  for latent in latents:
7
  for pca_idx, start, end, strength in edit_directions:
8
  delta = get_delta(pca, latent, pca_idx, strength)
9
+ delta_padded = torch.zeros(latent.shape).to(latent.device)
10
  delta_padded[start:end] += delta.repeat(end - start, 1)
11
  edit_latents.append(latent + delta_padded)
12
  return torch.stack(edit_latents)
13
 
14
 
15
  def get_delta(pca, latent, idx, strength):
16
+ device = latent.device
17
+ w_centered = latent - pca["mean"].to(device)
18
+ lat_comp = pca["comp"].to(device)
19
+ lat_std = pca["std"].to(device)
20
  w_coord = (
21
  torch.sum(w_centered[0].reshape(-1) * lat_comp[idx].reshape(-1)) / lat_std[idx]
22
  )
 
27
  def edit_latent(latent, pca, edit_direction):
28
  pca_idx, start, end, strength = edit_direction
29
  delta = get_delta(pca, latent, pca_idx, strength)
30
+ delta_padded = torch.zeros(latent.shape).to(latent.device)
31
  delta_padded[start:end] += delta.repeat(end - start, 1)
32
  edit_latent = latent + delta_padded
33
  return edit_latent