# LogicGoInfotechSpaces's picture
# Fix DeltaEditor device handling: add device parameter, error handling for missing files, replace .cuda() with .to(device)
# 96c63c3
import torch
import clip
import copy
import numpy as np
import torch.nn.functional as F
from editings.deltaedit import map_tool
from editings.deltaedit.delta_mapper import DeltaMapper
# Sizes of the consecutive chunks that DeltaEditor.get_delta_s splits the
# predicted delta into along its last axis (17 entries, 6048 values in total).
# NOTE(review): presumably the per-layer StyleGAN style-space channel widths
# -- confirm against the generator definition.
STYLE_DIM = [512] * 10 + [256, 256, 128, 128, 64, 64, 32]
def GetBoundary(fs3, dt, threshold):
    """Flag the entries whose projection onto ``dt`` is weaker than ``threshold``.

    Args:
        fs3: Array-like whose dot product with ``dt`` is taken via ``np.dot``.
        dt: Array-like direction vector.
        threshold: Strict upper bound applied to the absolute projection.

    Returns:
        Boolean mask that is ``True`` wherever ``|np.dot(fs3, dt)| < threshold``.
    """
    return np.less(np.absolute(np.dot(fs3, dt)), threshold)
def improved_ds(ds, select):
    """Zero the selected entries of a style delta and add a leading batch axis.

    Args:
        ds: ``torch.Tensor`` holding the delta to filter. It is left unmodified.
        select: Index or boolean mask accepted by tensor indexing; the
            addressed entries are set to 0.

    Returns:
        A new tensor equal to ``ds`` with the selected entries zeroed and an
        extra dimension inserted at position 0.
    """
    # ``copy.copy`` on a tensor only creates a new tensor object over the same
    # storage, so the masked assignment below would also overwrite the
    # caller's data. ``clone`` allocates independent storage.
    ds_imp = ds.clone()
    ds_imp[select] = 0
    return ds_imp.unsqueeze(0)
class DeltaEditor:
    """Predict text-driven style deltas with a DeltaMapper network and CLIP.

    The constructor loads ``pretrained_models/fs3.npy``, the DeltaMapper
    checkpoint ``pretrained_models/delta_mapper.pt`` and the CLIP
    ``ViT-B/32`` model. Missing files only produce a printed warning and a
    fallback, so the object can still be constructed.
    """

    def __init__(self, device="cpu"):
        """Load all models and move them to ``device``.

        Args:
            device: Torch device (string or ``torch.device``) used for the
                mapper network, CLIP and the tensors created in
                ``get_delta_s``.
        """
        self.device = device
        try:
            self.fs3 = np.load("pretrained_models/fs3.npy")
        except FileNotFoundError:
            # Fallback for when the file has not been downloaded yet.
            # GetBoundary(self.fs3, dt, ...) yields one mask entry per fs3 row
            # and that mask indexes a delta of length sum(STYLE_DIM), so the
            # dummy must have sum(STYLE_DIM) rows; a (512, 512) array made
            # get_delta_s fail with a mask/tensor shape mismatch.
            # The second axis keeps the original fallback width of 512
            # (presumably the CLIP ViT-B/32 embedding size -- confirm).
            # With all zeros every entry is below any positive threshold, so
            # the whole delta is masked out and no edit is applied.
            self.fs3 = np.zeros((sum(STYLE_DIM), 512))  # Dummy fs3 array
            print("Warning: fs3.npy not found, using dummy array")
        np.set_printoptions(suppress=True)

        self.net = DeltaMapper()
        try:
            net_ckpt = torch.load("pretrained_models/delta_mapper.pt", map_location=device)
        except FileNotFoundError:
            print("Warning: delta_mapper.pt not found, using uninitialized network")
        else:
            self.net.load_state_dict(net_ckpt)
        self.net = self.net.to(device).eval()

        self.clip_model, self.preprocess = clip.load("ViT-B/32", device=device)
        # Resizes images to the 224x224 resolution fed to CLIP below.
        self.avg_pool = torch.nn.AdaptiveAvgPool2d((224, 224))
        # Not used inside this class; kept because external code may read it.
        self.upsample = torch.nn.Upsample(scale_factor=7)

    def get_delta_s(self, neutral, target, trash, orig_image, start_s):
        """Compute the filtered style delta for a neutral -> target text edit.

        Args:
            neutral: Text describing the source attribute.
            target: Text describing the desired attribute.
            trash: Threshold forwarded to ``GetBoundary``; entries whose
                absolute projection is below it are zeroed in the result.
            orig_image: Image tensor accepted by ``AdaptiveAvgPool2d``; it is
                resized to 224x224 and encoded with CLIP.
            start_s: Sequence of tensors concatenated along the last axis and
                passed to the mapper network as the starting style code.

        Returns:
            Tuple of tensors obtained by splitting the filtered delta along
            its last axis into chunks of sizes ``STYLE_DIM``.
        """
        with torch.no_grad():
            classnames = [target, neutral]
            dt = map_tool.GetDt(classnames, self.clip_model)
            # The mask is computed from the un-normalized direction.
            select = GetBoundary(self.fs3, dt, trash)

            dt = torch.Tensor(dt).to(self.device)
            # Clamp the norm so a near-zero direction does not divide by zero.
            dt = dt / dt.norm(dim=-1, keepdim=True).float().clamp(min=1e-5)

            img_gen_for_clip = self.avg_pool(orig_image)
            c_latents = self.clip_model.encode_image(img_gen_for_clip.to(self.device))
            c_latents = c_latents / c_latents.norm(dim=-1, keepdim=True).float()

            # Mapper condition: image embedding followed by the text direction.
            delta_c = torch.cat((c_latents, dt.unsqueeze(0)), dim=1)
            fake_delta_s = self.net(torch.cat(start_s, dim=-1), delta_c)
            improved_fake_delta_s = improved_ds(fake_delta_s[0], select)
        return torch.split(improved_fake_delta_s, STYLE_DIM, dim=-1)