Spaces:
Sleeping
Sleeping
File size: 1,185 Bytes
95b1715 6c0e8ac 95b1715 6c0e8ac 95b1715 6c0e8ac 95b1715 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 |
import torch
def edit(latents, pca, edit_directions):
    """Apply every edit direction to every latent and stack the results.

    One edited latent is produced per (latent, direction) pair, ordered
    latent-major: all directions for the first latent, then all directions
    for the second, and so on. Each direction is applied to the unmodified
    input latent; edits are not accumulated across directions.

    Args:
        latents: Iterable of latent tensors.
            NOTE(review): presumably each is a 2-D (layers, dim) code --
            confirm against the caller.
        pca: Mapping with tensor entries "mean", "comp" and "std", passed
            straight through to ``get_delta``.
        edit_directions: Iterable of ``(pca_idx, start, end, strength)``
            tuples. ``start:end`` selects the rows of the latent that
            receive the offset.

    Returns:
        ``torch.stack`` of all edited latents.
    """

    def _shifted(code, direction):
        # Build one edited copy of `code` for a single direction.
        component, row_start, row_stop, target = direction
        offset = get_delta(pca, code, component, target)
        # float32 zeros on the latent's device; only rows row_start:row_stop
        # are filled, so every other row of the latent is left unchanged.
        padded = torch.zeros(code.shape, device=code.device)
        padded[row_start:row_stop] += offset.repeat(row_stop - row_start, 1)
        return code + padded

    return torch.stack(
        [
            _shifted(code, direction)
            for code in latents
            for direction in edit_directions
        ]
    )
def get_delta(pca, latent, idx, strength):
    """Compute the offset that moves a latent to ``strength`` along one PCA axis.

    The first row of the mean-centred latent is projected onto component
    ``idx`` and divided by that component's std, giving the current
    coordinate. The returned offset is the component scaled by
    ``(strength - coordinate) * std``.

    Args:
        pca: Mapping with tensor entries "mean", "comp" and "std"; each is
            moved to the latent's device before use.
        latent: Latent tensor. Only ``latent[0]`` (after centring) takes
            part in the projection.
            NOTE(review): presumably a (layers, dim) code whose rows share
            one vector -- confirm.
        idx: Index of the component in ``pca["comp"]`` / ``pca["std"]``.
        strength: Target coordinate along the component, expressed in
            units of that component's std.

    Returns:
        Tensor with the shape of ``pca["comp"][idx]``.
    """
    target_device = latent.device
    centred = latent - pca["mean"].to(target_device)
    direction = pca["comp"].to(target_device)[idx]
    scale = pca["std"].to(target_device)[idx]

    # Current position along the component, in std units (first row only).
    projection = (centred[0].reshape(-1) * direction.reshape(-1)).sum()
    coordinate = projection / scale

    # Step from the current coordinate to the requested one.
    return (strength - coordinate) * direction * scale
def edit_latent(latent, pca, edit_direction):
    """Apply a single edit direction to a single latent.

    Args:
        latent: Latent tensor to edit; it is not modified in place.
            NOTE(review): presumably a 2-D (layers, dim) code -- confirm.
        pca: Mapping with tensor entries "mean", "comp" and "std", passed
            straight through to ``get_delta``.
        edit_direction: ``(pca_idx, start, end, strength)`` tuple.
            ``start:end`` selects the rows of the latent that receive the
            offset.

    Returns:
        New tensor equal to ``latent`` plus the offset on rows
        ``start:end``.
    """
    component, row_start, row_stop, target = edit_direction
    offset = get_delta(pca, latent, component, target)

    # float32 zeros on the latent's device; rows outside row_start:row_stop
    # stay zero so those rows of the latent pass through unchanged.
    padded = torch.zeros(latent.shape, device=latent.device)
    padded[row_start:row_stop] += offset.repeat(row_stop - row_start, 1)

    return latent + padded
|