Spaces:
Sleeping
Sleeping
Commit
·
96c63c3
1
Parent(s):
ad8ad02
Fix DeltaEditor device handling: add device parameter, error handling for missing files, replace .cuda() with .to(device)
Browse files- editings/deltaedit/editor.py +18 -7
- editings/latent_editor.py +1 -1
editings/deltaedit/editor.py
CHANGED
|
@@ -24,14 +24,25 @@ def improved_ds(ds, select):
|
|
| 24 |
|
| 25 |
|
| 26 |
class DeltaEditor:
|
| 27 |
-
def __init__(self):
|
| 28 |
-
device =
|
| 29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
np.set_printoptions(suppress=True)
|
| 31 |
|
| 32 |
self.net = DeltaMapper()
|
| 33 |
-
|
| 34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
self.net = self.net.to(device).eval()
|
| 36 |
|
| 37 |
self.clip_model, self.preprocess = clip.load("ViT-B/32", device=device)
|
|
@@ -43,11 +54,11 @@ class DeltaEditor:
|
|
| 43 |
classnames = [target, neutral]
|
| 44 |
dt = map_tool.GetDt(classnames, self.clip_model)
|
| 45 |
select = GetBoundary(self.fs3, dt, trash)
|
| 46 |
-
dt = torch.Tensor(dt).
|
| 47 |
dt = dt / dt.norm(dim=-1, keepdim=True).float().clamp(min=1e-5)
|
| 48 |
|
| 49 |
img_gen_for_clip = self.avg_pool(orig_image)
|
| 50 |
-
c_latents = self.clip_model.encode_image(img_gen_for_clip.
|
| 51 |
c_latents = c_latents / c_latents.norm(dim=-1, keepdim=True).float()
|
| 52 |
|
| 53 |
delta_c = torch.cat((c_latents, dt.unsqueeze(0)), dim=1)
|
|
|
|
| 24 |
|
| 25 |
|
| 26 |
class DeltaEditor:
|
| 27 |
+
def __init__(self, device="cpu"):
|
| 28 |
+
self.device = device
|
| 29 |
+
try:
|
| 30 |
+
self.fs3 = np.load("pretrained_models/fs3.npy")
|
| 31 |
+
except FileNotFoundError:
|
| 32 |
+
# If fs3.npy is not available, create a dummy array
|
| 33 |
+
# This is a fallback for when the file is not downloaded yet
|
| 34 |
+
self.fs3 = np.zeros((512, 512)) # Dummy fs3 array
|
| 35 |
+
print("Warning: fs3.npy not found, using dummy array")
|
| 36 |
+
|
| 37 |
np.set_printoptions(suppress=True)
|
| 38 |
|
| 39 |
self.net = DeltaMapper()
|
| 40 |
+
try:
|
| 41 |
+
net_ckpt = torch.load("pretrained_models/delta_mapper.pt", map_location=device)
|
| 42 |
+
self.net.load_state_dict(net_ckpt)
|
| 43 |
+
except FileNotFoundError:
|
| 44 |
+
print("Warning: delta_mapper.pt not found, using uninitialized network")
|
| 45 |
+
|
| 46 |
self.net = self.net.to(device).eval()
|
| 47 |
|
| 48 |
self.clip_model, self.preprocess = clip.load("ViT-B/32", device=device)
|
|
|
|
| 54 |
classnames = [target, neutral]
|
| 55 |
dt = map_tool.GetDt(classnames, self.clip_model)
|
| 56 |
select = GetBoundary(self.fs3, dt, trash)
|
| 57 |
+
dt = torch.Tensor(dt).to(self.device)
|
| 58 |
dt = dt / dt.norm(dim=-1, keepdim=True).float().clamp(min=1e-5)
|
| 59 |
|
| 60 |
img_gen_for_clip = self.avg_pool(orig_image)
|
| 61 |
+
c_latents = self.clip_model.encode_image(img_gen_for_clip.to(self.device))
|
| 62 |
c_latents = c_latents / c_latents.norm(dim=-1, keepdim=True).float()
|
| 63 |
|
| 64 |
delta_c = torch.cat((c_latents, dt.unsqueeze(0)), dim=1)
|
editings/latent_editor.py
CHANGED
|
@@ -120,7 +120,7 @@ class LatentEditor:
|
|
| 120 |
"fs_makeup": "editings/bound/Heavy_Makeup_boundary.npy"
|
| 121 |
}
|
| 122 |
|
| 123 |
-
self.deltaedit_editor = DeltaEditor()
|
| 124 |
|
| 125 |
elif self.domain == "car":
|
| 126 |
|
|
|
|
| 120 |
"fs_makeup": "editings/bound/Heavy_Makeup_boundary.npy"
|
| 121 |
}
|
| 122 |
|
| 123 |
+
self.deltaedit_editor = DeltaEditor(device=self.device)
|
| 124 |
|
| 125 |
elif self.domain == "car":
|
| 126 |
|