File size: 2,477 Bytes
95b1715
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96c63c3
 
 
 
 
 
 
 
 
 
95b1715
 
 
96c63c3
 
 
 
 
 
95b1715
 
 
 
 
 
 
 
 
 
 
96c63c3
95b1715
 
 
96c63c3
95b1715
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import torch
import clip
import copy
import numpy as np
import torch.nn.functional as F

from editings.deltaedit import map_tool
from editings.deltaedit.delta_mapper import DeltaMapper


# Chunk sizes used by ``torch.split`` in ``DeltaEditor.get_delta_s`` to cut the
# mapper's concatenated style delta back into pieces; the delta's last
# dimension must therefore equal sum(STYLE_DIM) == 6048.
STYLE_DIM = [512] * 10 + [256] * 2 + [128] * 2 + [64] * 2 + [32]


def GetBoundary(fs3, dt, threshold):
    """Return a boolean mask of entries whose projection onto ``dt`` is small.

    Computes ``np.dot(fs3, dt)`` (for a 2-D ``fs3`` and 1-D ``dt`` this is one
    projection per row of ``fs3``) and flags every entry whose absolute value
    is strictly below ``threshold``.

    Args:
        fs3: Array multiplied against ``dt`` with ``np.dot``.
        dt: Direction array.
        threshold: Cut-off; entries with ``|projection| < threshold`` are True.

    Returns:
        numpy boolean array with the shape of ``np.dot(fs3, dt)``.
    """
    return np.abs(np.dot(fs3, dt)) < threshold

def improved_ds(ds, select):
    """Zero the selected entries of a delta tensor and prepend a batch axis.

    Args:
        ds: Tensor of style deltas (``get_delta_s`` passes
            ``fake_delta_s[0]``; presumably 1-D — confirm for other callers).
        select: Index accepted by tensor item assignment, typically the
            boolean mask returned by ``GetBoundary``; selected entries are
            set to 0.

    Returns:
        A new tensor equal to ``ds`` with ``ds[select]`` set to 0 and a
        leading axis of size 1 added. ``ds`` itself is left unmodified.
    """
    # ``copy.copy`` on a torch tensor rebuilds it on the *same* storage (via
    # ``Tensor.__reduce_ex__``), so the in-place zeroing below would also
    # have overwritten the caller's tensor. ``clone`` allocates new memory.
    ds_imp = ds.clone()
    ds_imp[select] = 0
    return ds_imp.unsqueeze(0)


class DeltaEditor:
    """Predict style deltas for a text-driven edit using CLIP and DeltaMapper.

    Holds the ``fs3`` array used by ``GetBoundary``, a ``DeltaMapper``
    network, and a CLIP ViT-B/32 model. ``get_delta_s`` combines a CLIP text
    direction with the CLIP embedding of an image, runs the mapper, suppresses
    entries via ``GetBoundary``/``improved_ds``, and splits the result by
    ``STYLE_DIM``.
    """

    def __init__(self, device="cpu"):
        """Load ``fs3``, the DeltaMapper weights and CLIP ViT-B/32.

        A missing ``pretrained_models/fs3.npy`` or
        ``pretrained_models/delta_mapper.pt`` only prints a warning, so the
        editor can be constructed before those assets are downloaded.

        Args:
            device: Torch device the mapper and CLIP model are placed on.
        """
        self.device = device
        try:
            self.fs3 = np.load("pretrained_models/fs3.npy")
        except FileNotFoundError:
            # Fallback for when fs3.npy has not been downloaded yet.
            # GetBoundary yields one mask entry per row of fs3, and that mask
            # indexes a delta whose last dimension is sum(STYLE_DIM) (it is
            # split by STYLE_DIM in get_delta_s), so the dummy needs exactly
            # that many rows. An all-zero fs3 makes the mask all True for any
            # positive threshold, i.e. every delta entry is suppressed.
            # The 512 columns match the original dummy; presumably the CLIP
            # text-direction width — confirm against map_tool.GetDt.
            self.fs3 = np.zeros((sum(STYLE_DIM), 512))
            print("Warning: fs3.npy not found, using dummy array")

        # Process-wide numpy print setting (not scoped to this object).
        np.set_printoptions(suppress=True)

        self.net = DeltaMapper()
        try:
            net_ckpt = torch.load("pretrained_models/delta_mapper.pt", map_location=device)
            self.net.load_state_dict(net_ckpt)
        except FileNotFoundError:
            print("Warning: delta_mapper.pt not found, using uninitialized network")

        self.net = self.net.to(device).eval()

        self.clip_model, self.preprocess = clip.load("ViT-B/32", device=device)
        # Resizes images to 224x224 before CLIP's encode_image.
        self.avg_pool = torch.nn.AdaptiveAvgPool2d((224, 224))
        # Not used within this class; kept for external callers.
        self.upsample = torch.nn.Upsample(scale_factor=7)

    def get_delta_s(self, neutral, target, trash, orig_image, start_s):
        """Predict style deltas editing ``orig_image`` from ``neutral`` to ``target``.

        Args:
            neutral: Source text prompt (second entry given to ``GetDt``).
            target: Target text prompt (first entry given to ``GetDt``).
            trash: Threshold passed to ``GetBoundary``; delta entries whose
                ``|fs3 . dt|`` is below it are zeroed. (Despite the name this
                is the threshold; the name is kept for compatibility.)
            orig_image: Image tensor pooled to 224x224 and encoded by CLIP;
                presumably (B, 3, H, W) with B == 1 — confirm with callers,
                since the text direction is concatenated with a batch of one.
            start_s: Sequence of tensors concatenated along the last
                dimension as the mapper's first input.

        Returns:
            Tuple of tensors from ``torch.split(..., STYLE_DIM, dim=-1)``
            applied to ``fake_delta_s[0]`` with a leading axis of size 1.
        """
        with torch.no_grad():
            # GetDt receives the prompts in [target, neutral] order.
            dt = map_tool.GetDt([target, neutral], self.clip_model)
            # Mask is computed on dt before it is converted to a tensor.
            select = GetBoundary(self.fs3, dt, trash)

            dt = torch.Tensor(dt).to(self.device)
            dt = dt / dt.norm(dim=-1, keepdim=True).float().clamp(min=1e-5)

            img_gen_for_clip = self.avg_pool(orig_image)
            c_latents = self.clip_model.encode_image(img_gen_for_clip.to(self.device))
            # Clamp the norm like dt above so a degenerate embedding cannot
            # divide by zero.
            c_latents = c_latents / c_latents.norm(dim=-1, keepdim=True).float().clamp(min=1e-5)

            delta_c = torch.cat((c_latents, dt.unsqueeze(0)), dim=1)
            fake_delta_s = self.net(torch.cat(start_s, dim=-1), delta_c)
            improved_fake_delta_s = improved_ds(fake_delta_s[0], select)
            return torch.split(improved_fake_delta_s, STYLE_DIM, dim=-1)