# DeltaEdit-based text-driven editing utilities.
import torch
import clip
import copy
import numpy as np
import torch.nn.functional as F
from editings.deltaedit import map_tool
from editings.deltaedit.delta_mapper import DeltaMapper
# Per-layer channel counts for splitting the flat StyleSpace delta vector
# (sums to 6048) — presumably StyleGAN2's per-layer style dimensions as used
# by DeltaMapper; TODO confirm against the DeltaEdit release.
STYLE_DIM = [512] * 10 + [256, 256, 128, 128, 64, 64, 32]
def GetBoundary(fs3, dt, threshold):
    """Return a boolean mask of channels whose |fs3 @ dt| falls below threshold.

    Channels with low absolute projection onto the text direction *dt* are
    considered irrelevant to the edit and are selected for suppression.
    """
    projections = fs3.dot(dt)
    return np.abs(projections) < threshold
def improved_ds(ds, select):
    """Zero out the selected channels of a delta-s tensor and add a batch dim.

    Args:
        ds: 1-D tensor of style-space deltas.
        select: boolean mask (numpy array or tensor) marking channels to
            suppress; must have the same length as ``ds``.

    Returns:
        Tensor of shape ``(1, len(ds))`` with the masked channels set to 0.
        The input tensor is left unmodified.
    """
    # clone() is the idiomatic tensor copy; the previous copy.copy relied on
    # pickle-based copying and was less explicit about not aliasing storage.
    ds_imp = ds.clone()
    # Convert the (possibly numpy) mask to a bool tensor on ds's device so
    # the indexing semantics are explicit rather than relying on implicit
    # numpy-array index conversion.
    mask = torch.as_tensor(select, dtype=torch.bool, device=ds.device)
    ds_imp[mask] = 0
    return ds_imp.unsqueeze(0)
class DeltaEditor:
    """Text-driven StyleSpace editor built on DeltaEdit.

    Wraps a pretrained DeltaMapper network and a CLIP ViT-B/32 model to map
    a (neutral, target) text pair plus an input image into per-layer
    style-space deltas.
    """

    def __init__(self, device="cpu"):
        # Device string passed straight to torch/clip (e.g. "cpu", "cuda").
        self.device = device
        try:
            # fs3: relevance matrix between StyleSpace channels and CLIP
            # directions — presumably (n_channels, 512); TODO confirm shape
            # against the DeltaEdit release.
            self.fs3 = np.load("pretrained_models/fs3.npy")
        except FileNotFoundError:
            # If fs3.npy is not available, create a dummy array
            # This is a fallback for when the file is not downloaded yet
            self.fs3 = np.zeros((512, 512))  # Dummy fs3 array
            print("Warning: fs3.npy not found, using dummy array")
        np.set_printoptions(suppress=True)
        self.net = DeltaMapper()
        try:
            net_ckpt = torch.load("pretrained_models/delta_mapper.pt", map_location=device)
            self.net.load_state_dict(net_ckpt)
        except FileNotFoundError:
            # Best-effort: proceed with random weights so the object is usable
            # for smoke tests even without the checkpoint.
            print("Warning: delta_mapper.pt not found, using uninitialized network")
        self.net = self.net.to(device).eval()
        # CLIP supplies both the text edit direction and the image features.
        self.clip_model, self.preprocess = clip.load("ViT-B/32", device=device)
        # Pools generator output down to CLIP's expected 224x224 input.
        self.avg_pool = torch.nn.AdaptiveAvgPool2d((224, 224))
        # NOTE(review): upsample appears unused within this file — confirm
        # whether external callers rely on it before removing.
        self.upsample = torch.nn.Upsample(scale_factor=7)

    def get_delta_s(self, neutral, target, trash, orig_image, start_s):
        """Compute per-layer style-space deltas for editing toward *target*.

        Args:
            neutral: neutral/source text prompt.
            target: target text prompt describing the desired edit.
            trash: disentanglement threshold forwarded to GetBoundary
                (presumably a typo for "thresh" — name kept for interface
                compatibility).
            orig_image: image tensor fed to CLIP's image encoder after
                adaptive pooling to 224x224.
            start_s: sequence of per-layer style tensors; concatenated along
                the last dim before being passed to the mapper network.

        Returns:
            Tuple of tensors, one per StyleSpace layer, split per STYLE_DIM.
        """
        with torch.no_grad():
            classnames = [target, neutral]
            # CLIP text direction between the target and neutral prompts.
            dt = map_tool.GetDt(classnames, self.clip_model)
            # Boolean mask of channels considered irrelevant at this threshold.
            select = GetBoundary(self.fs3, dt, trash)
            dt = torch.Tensor(dt).to(self.device)
            # Normalize; clamp avoids division by ~0 for degenerate directions.
            dt = dt / dt.norm(dim=-1, keepdim=True).float().clamp(min=1e-5)
            # Downscale to CLIP's input resolution before encoding.
            img_gen_for_clip = self.avg_pool(orig_image)
            c_latents = self.clip_model.encode_image(img_gen_for_clip.to(self.device))
            c_latents = c_latents / c_latents.norm(dim=-1, keepdim=True).float()
            # Condition the mapper on [image features | text direction].
            delta_c = torch.cat((c_latents, dt.unsqueeze(0)), dim=1)
            fake_delta_s = self.net(torch.cat(start_s, dim=-1), delta_c)
            # Suppress irrelevant channels, then split into per-layer chunks.
            improved_fake_delta_s = improved_ds(fake_delta_s[0], select)
            return torch.split(improved_fake_delta_s, STYLE_DIM, dim=-1)