LogicGoInfotechSpaces commited on
Commit
96c63c3
·
1 Parent(s): ad8ad02

Fix DeltaEditor device handling: add device parameter, error handling for missing files, replace .cuda() with .to(device)

Browse files
editings/deltaedit/editor.py CHANGED
@@ -24,14 +24,25 @@ def improved_ds(ds, select):
24
 
25
 
26
  class DeltaEditor:
27
- def __init__(self):
28
- device = "cuda"
29
- self.fs3 = np.load("pretrained_models/fs3.npy")
 
 
 
 
 
 
 
30
  np.set_printoptions(suppress=True)
31
 
32
  self.net = DeltaMapper()
33
- net_ckpt = torch.load("pretrained_models/delta_mapper.pt")
34
- self.net.load_state_dict(net_ckpt)
 
 
 
 
35
  self.net = self.net.to(device).eval()
36
 
37
  self.clip_model, self.preprocess = clip.load("ViT-B/32", device=device)
@@ -43,11 +54,11 @@ class DeltaEditor:
43
  classnames = [target, neutral]
44
  dt = map_tool.GetDt(classnames, self.clip_model)
45
  select = GetBoundary(self.fs3, dt, trash)
46
- dt = torch.Tensor(dt).cuda()
47
  dt = dt / dt.norm(dim=-1, keepdim=True).float().clamp(min=1e-5)
48
 
49
  img_gen_for_clip = self.avg_pool(orig_image)
50
- c_latents = self.clip_model.encode_image(img_gen_for_clip.cuda())
51
  c_latents = c_latents / c_latents.norm(dim=-1, keepdim=True).float()
52
 
53
  delta_c = torch.cat((c_latents, dt.unsqueeze(0)), dim=1)
 
24
 
25
 
26
  class DeltaEditor:
27
+ def __init__(self, device="cpu"):
28
+ self.device = device
29
+ try:
30
+ self.fs3 = np.load("pretrained_models/fs3.npy")
31
+ except FileNotFoundError:
32
+ # If fs3.npy is not available, create a dummy array
33
+ # This is a fallback for when the file is not downloaded yet
34
+ self.fs3 = np.zeros((512, 512)) # Dummy fs3 array
35
+ print("Warning: fs3.npy not found, using dummy array")
36
+
37
  np.set_printoptions(suppress=True)
38
 
39
  self.net = DeltaMapper()
40
+ try:
41
+ net_ckpt = torch.load("pretrained_models/delta_mapper.pt", map_location=device)
42
+ self.net.load_state_dict(net_ckpt)
43
+ except FileNotFoundError:
44
+ print("Warning: delta_mapper.pt not found, using uninitialized network")
45
+
46
  self.net = self.net.to(device).eval()
47
 
48
  self.clip_model, self.preprocess = clip.load("ViT-B/32", device=device)
 
54
  classnames = [target, neutral]
55
  dt = map_tool.GetDt(classnames, self.clip_model)
56
  select = GetBoundary(self.fs3, dt, trash)
57
+ dt = torch.Tensor(dt).to(self.device)
58
  dt = dt / dt.norm(dim=-1, keepdim=True).float().clamp(min=1e-5)
59
 
60
  img_gen_for_clip = self.avg_pool(orig_image)
61
+ c_latents = self.clip_model.encode_image(img_gen_for_clip.to(self.device))
62
  c_latents = c_latents / c_latents.norm(dim=-1, keepdim=True).float()
63
 
64
  delta_c = torch.cat((c_latents, dt.unsqueeze(0)), dim=1)
editings/latent_editor.py CHANGED
@@ -120,7 +120,7 @@ class LatentEditor:
120
  "fs_makeup": "editings/bound/Heavy_Makeup_boundary.npy"
121
  }
122
 
123
- self.deltaedit_editor = DeltaEditor()
124
 
125
  elif self.domain == "car":
126
 
 
120
  "fs_makeup": "editings/bound/Heavy_Makeup_boundary.npy"
121
  }
122
 
123
+ self.deltaedit_editor = DeltaEditor(device=self.device)
124
 
125
  elif self.domain == "car":
126