File size: 1,185 Bytes
95b1715
 
 
 
 
 
 
 
6c0e8ac
95b1715
 
 
 
 
 
6c0e8ac
 
 
 
95b1715
 
 
 
 
 
 
 
 
 
6c0e8ac
95b1715
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import torch


def edit(latents, pca, edit_directions):
    """Apply every edit direction to every latent and stack the results.

    Each ``(latent, direction)`` pair yields its own edited latent; the
    directions are NOT accumulated onto one latent. Output order is
    latent-major: all directions for the first latent, then the second, etc.

    Args:
        latents: Iterable of latent tensors.
            NOTE(review): presumably each is (num_layers, dim) -- confirm.
        pca: Mapping with ``"mean"``, ``"comp"`` and ``"std"`` tensors, passed
            through to ``get_delta``.
        edit_directions: Iterable of ``(pca_idx, start, end, strength)``
            tuples; the offset is written into rows ``start:end`` only.

    Returns:
        The edited latents joined with ``torch.stack``.
    """
    results = []
    for w in latents:
        for component, first, last, target in edit_directions:
            shift = get_delta(pca, w, component, target)
            # Zero everywhere except the selected rows, so adding it to the
            # latent leaves the remaining rows unchanged.
            offset = torch.zeros(w.shape, device=w.device)
            offset[first:last] += shift.repeat(last - first, 1)
            results.append(w + offset)
    return torch.stack(results)


def get_delta(pca, latent, idx, strength):
    """Return the offset that moves ``latent`` to ``strength`` on one PCA axis.

    The latent is centred with ``pca["mean"]``, and the first row of the
    centred latent is projected onto component ``idx`` and divided by that
    component's standard deviation to obtain its current coordinate. The
    returned offset is the gap between ``strength`` and that coordinate,
    scaled back by the component direction and its standard deviation.

    Args:
        pca: Mapping with ``"mean"``, ``"comp"`` and ``"std"`` tensors; each is
            moved to the latent's device before use.
        latent: Latent tensor; only ``latent[0]`` (after centring) is used for
            the projection.
        idx: Index of the PCA component to edit along.
        strength: Target coordinate along the component, in units of that
            component's standard deviation.

    Returns:
        A tensor shaped like ``pca["comp"][idx]``.
    """
    target_device = latent.device
    centered = latent - pca["mean"].to(target_device)
    component = pca["comp"].to(target_device)[idx]
    scale = pca["std"].to(target_device)[idx]

    # Current position of the latent along the chosen component.
    projection = (centered[0].reshape(-1) * component.reshape(-1)).sum()
    coordinate = projection / scale

    return (strength - coordinate) * component * scale


def edit_latent(latent, pca, edit_direction):
    """Apply a single edit direction to a single latent.

    Args:
        latent: Latent tensor to edit.
            NOTE(review): presumably (num_layers, dim) -- confirm with callers.
        pca: Mapping with ``"mean"``, ``"comp"`` and ``"std"`` tensors, passed
            through to ``get_delta``.
        edit_direction: A ``(pca_idx, start, end, strength)`` tuple; the
            offset is written into rows ``start:end`` only.

    Returns:
        A new tensor equal to ``latent`` plus the offset on the selected rows.
    """
    component, first, last, target = edit_direction
    shift = get_delta(pca, latent, component, target)

    # Offset is zero outside rows first:last, so those rows stay as they were.
    offset = torch.zeros(latent.shape, device=latent.device)
    offset[first:last] += shift.repeat(last - first, 1)
    return latent + offset