Spaces:
Sleeping
Sleeping
Fix DeltaEditor device handling: add device parameter, error handling for missing files, replace .cuda() with .to(device)
96c63c3
| import torch | |
| import clip | |
| import copy | |
| import numpy as np | |
| import torch.nn.functional as F | |
| from editings.deltaedit import map_tool | |
| from editings.deltaedit.delta_mapper import DeltaMapper | |
# Section sizes used by DeltaEditor.get_delta_s to split the flat style-delta
# vector along its last dimension with torch.split. The sizes add up to 6048.
STYLE_DIM = [512] * 10 + [256] * 2 + [128] * 2 + [64] * 2 + [32]
def GetBoundary(fs3, dt, threshold):
    """Return a boolean mask of weak responses to a direction.

    Computes ``np.dot(fs3, dt)`` and flags every entry whose absolute value
    is strictly below ``threshold``.

    Args:
        fs3: Array whose dot product with ``dt`` is taken (left operand).
        dt: Direction array (right operand of the dot product).
        threshold: Exclusive upper bound on the absolute projection.

    Returns:
        Boolean NumPy array, ``True`` where ``|fs3 . dt| < threshold``.
    """
    return np.abs(np.dot(fs3, dt)) < threshold
def improved_ds(ds, select):
    """Return a copy of ``ds`` with the selected entries zeroed and a leading axis added.

    Args:
        ds: Tensor to mask. It is left unmodified.
        select: Boolean mask (or other valid index) naming the entries of
            ``ds`` to set to zero.

    Returns:
        A new tensor equal to ``ds`` with ``ds[select]`` replaced by 0 and
        an extra dimension of size 1 inserted at position 0.
    """
    # clone() guarantees an independent buffer, so the in-place masking below
    # never writes through to the caller's tensor. A shallow copy.copy() of a
    # tensor does not promise that.
    ds_imp = ds.clone()
    ds_imp[select] = 0
    return ds_imp.unsqueeze(0)
class DeltaEditor:
    """Text-driven style editing helper built on a DeltaMapper network and CLIP.

    On construction it loads ``pretrained_models/fs3.npy`` and the
    ``pretrained_models/delta_mapper.pt`` checkpoint (falling back with a
    printed warning when either file is missing) and the CLIP ``ViT-B/32``
    model, all placed on ``device``.
    """

    def __init__(self, device="cpu"):
        """Load the fs3 matrix, the DeltaMapper weights and the CLIP model.

        Args:
            device: Device passed to ``torch.load`` (as ``map_location``),
                ``Module.to`` and ``clip.load``. Defaults to ``"cpu"``.
        """
        self.device = device
        try:
            self.fs3 = np.load("pretrained_models/fs3.npy")
        except FileNotFoundError:
            # Fallback for when the file has not been downloaded yet.
            # The mask derived from fs3 has one entry per fs3 row and is
            # applied to the style-delta vector that get_delta_s later splits
            # by STYLE_DIM, so the row count must be sum(STYLE_DIM); a
            # mismatched dummy would fail at edit time with a mask-shape error.
            # The 512 columns are kept from the previous dummy -- presumably
            # the CLIP ViT-B/32 embedding width; TODO confirm.
            # NOTE: with an all-zero fs3 every projection is 0, so any
            # positive threshold masks every channel and edits become no-ops.
            self.fs3 = np.zeros((sum(STYLE_DIM), 512))  # Dummy fs3 array
            print("Warning: fs3.npy not found, using dummy array")
        # Process-wide NumPy print setting (kept from the original behavior).
        np.set_printoptions(suppress=True)

        self.net = DeltaMapper()
        try:
            net_ckpt = torch.load("pretrained_models/delta_mapper.pt", map_location=device)
            self.net.load_state_dict(net_ckpt)
        except FileNotFoundError:
            # Best-effort: keep the freshly constructed (untrained) network.
            print("Warning: delta_mapper.pt not found, using uninitialized network")
        self.net = self.net.to(device).eval()

        self.clip_model, self.preprocess = clip.load("ViT-B/32", device=device)
        # Resizes images to 224x224 before they are fed to the CLIP image encoder.
        self.avg_pool = torch.nn.AdaptiveAvgPool2d((224, 224))
        # Not used inside this class; kept because external code may read it.
        self.upsample = torch.nn.Upsample(scale_factor=7)

    def get_delta_s(self, neutral, target, trash, orig_image, start_s):
        """Predict the masked style delta for a neutral -> target text edit.

        Args:
            neutral: Neutral class name passed to ``map_tool.GetDt``.
            target: Target class name passed to ``map_tool.GetDt``.
            trash: Threshold handed to ``GetBoundary``; entries whose absolute
                fs3 projection is below it are zeroed in the result.
            orig_image: Image tensor; it is pooled to 224x224 and encoded
                with the CLIP image encoder.
            orig_image is moved to ``self.device`` after pooling.
            start_s: Sequence of tensors concatenated along the last
                dimension and given to the DeltaMapper network.

        Returns:
            Tuple of tensors obtained by splitting the masked delta (with a
            leading axis of size 1) along its last dimension by ``STYLE_DIM``.
        """
        with torch.no_grad():
            classnames = [target, neutral]
            dt = map_tool.GetDt(classnames, self.clip_model)
            # The mask is computed from the un-normalized direction.
            select = GetBoundary(self.fs3, dt, trash)

            dt = torch.Tensor(dt).to(self.device)
            # Clamp guards against division by a (near-)zero norm.
            dt = dt / dt.norm(dim=-1, keepdim=True).float().clamp(min=1e-5)

            img_gen_for_clip = self.avg_pool(orig_image)
            c_latents = self.clip_model.encode_image(img_gen_for_clip.to(self.device))
            c_latents = c_latents / c_latents.norm(dim=-1, keepdim=True).float()

            # Condition vector: image embedding followed by the text direction.
            delta_c = torch.cat((c_latents, dt.unsqueeze(0)), dim=1)
            fake_delta_s = self.net(torch.cat(start_s, dim=-1), delta_c)
            # Only the first item of the network output is used.
            improved_fake_delta_s = improved_ds(fake_delta_s[0], select)
            return torch.split(improved_fake_delta_s, STYLE_DIM, dim=-1)