wi-lab committed
Commit e00eaca · 1 Parent(s): f439c65

Update embed_lwm.py

Files changed (1):
  1. embed_lwm.py +135 -119
embed_lwm.py CHANGED
@@ -1,69 +1,95 @@
  # embed_lwm.py
  import os
  import sys
- from typing import List, Optional, Tuple
+ from typing import List, Tuple, Optional

  import torch
+ from huggingface_hub import snapshot_download

- def _log(msg: str):
-     print(msg, flush=True)
+ _LWM_MODEL = None
+ _LWM_DIR = None

- def _maybe_add_lwm_repo_to_path():
+
+ def _add_repo_to_path(path: str):
+     if path and os.path.isdir(path) and path not in sys.path:
+         sys.path.insert(0, path)
+
+
+ def _load_state_dict_flex(model: torch.nn.Module, state):
      """
-     Ensure the HF-cloned LWM repo is importable.
-     You can override the location with the env var LWM_REPO_DIR.
+     Load a variety of saved formats into `model`:
+     - plain state_dict
+     - {"model": state_dict}
+     - with or without "module." prefixes
      """
-     candidates = [
-         os.getenv("LWM_REPO_DIR", ""),  # user override
-         "./LWM-v1.1",                   # local default
-         "/home/user/app/LWM-v1.1",      # HF Space default path
-     ]
-     for c in candidates:
-         if c and os.path.isdir(c) and c not in sys.path:
-             sys.path.insert(0, c)
+     def _try(sd, strict=False):
+         try:
+             ret = model.load_state_dict(sd, strict=strict)
+             # strict=False reports key mismatches instead of raising; treat
+             # "no key matched" as failure so the prefix fallbacks below run.
+             return not (ret.missing_keys and len(ret.missing_keys) == len(model.state_dict()))
+         except Exception:
+             return False
+
+     # Direct state dict (string keys mapping to tensors)?
+     if isinstance(state, dict) and all(isinstance(k, str) for k in state.keys()) and any(
+         torch.is_tensor(v) for v in state.values()
+     ):
+         sd = state
+     elif isinstance(state, dict) and "model" in state and isinstance(state["model"], dict):
+         sd = state["model"]
+     else:
+         raise ValueError("Unrecognized checkpoint format.")
+
+     # Try as-is
+     if _try(sd, strict=False):
+         return
+
+     # Try adding a "module." prefix
+     if not any(k.startswith("module.") for k in sd.keys()):
+         sd_mod = {f"module.{k}": v for k, v in sd.items()}
+         if _try(sd_mod, strict=False):
+             return
+
+     # Try stripping the "module." prefix
+     sd_strip = {k.replace("module.", "", 1): v for k, v in sd.items()}
+     if _try(sd_strip, strict=False):
+         return
+
+     # Last resort: strict=False on the original again
+     model.load_state_dict(sd, strict=False)

  def get_lwm_encoder():
      """
-     Try to load the encoder from pretrained_model.py in the HF repo.
-     Returns a torch.nn.Module or None if it can't be loaded.
+     Download & load wi-lab/lwm-v1.1 and create the encoder from lwm_model.py.
+     Returns a torch.nn.Module or None on failure.
      """
-     _maybe_add_lwm_repo_to_path()
-     try:
-         # HF repo exports a builder called `lwm`
-         from pretrained_model import lwm  # type: ignore
-     except Exception as e:
-         _log(f"[WARN] Failed to import pretrained_model.lwm: {e}")
-         return None
-
+     global _LWM_MODEL, _LWM_DIR
+     if _LWM_MODEL is not None:
+         return _LWM_MODEL
      try:
+         _LWM_DIR = snapshot_download(
+             repo_id="wi-lab/lwm-v1.1",
+             local_dir="./LWM-v1.1",
+             local_dir_use_symlinks=False,
+         )
+         _add_repo_to_path(_LWM_DIR)
+
+         # Import the builder from the HF repo (the module is named lwm_model.py)
+         from lwm_model import lwm  # type: ignore
          model = lwm()
-     except Exception as e:
-         _log(f"[WARN] pretrained_model.lwm() failed to build model: {e}")
-         return None

-     # Load weights if present
-     weights = None
-     for cand in ("models/model.pth", "./LWM-v1.1/models/model.pth"):
-         if os.path.isfile(cand):
-             weights = cand
-             break
+         # Load the checkpoint from models/model.pth
+         ckpt_path = os.path.join(_LWM_DIR, "models", "model.pth")
+         if os.path.isfile(ckpt_path):
+             state = torch.load(ckpt_path, map_location="cpu")
+             _load_state_dict_flex(model, state)

-     if weights:
-         try:
-             sd = torch.load(weights, map_location="cpu")
-             try:
-                 model.load_state_dict(sd)
-             except Exception:
-                 # sometimes saved as {"model": state_dict}
-                 if isinstance(sd, dict) and "model" in sd:
-                     model.load_state_dict(sd["model"])
-                 else:
-                     raise
-         except Exception as e:
-             _log(f"[WARN] Could not load weights from {weights}: {e}")
-
-     model.eval()
-     return model
+         model.eval()
+         _LWM_MODEL = model
+         return _LWM_MODEL
+     except Exception as e:
+         print(f"[WARN] Failed to load LWM encoder: {e}", flush=True)
+         return None


  @torch.no_grad()
@@ -72,101 +98,91 @@ def build_lwm_embeddings(
      datasets: List[Tuple[torch.Tensor, Optional[torch.Tensor], str]],
      n_per_dataset: int,
      label_aware: bool
- ) -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]:
+ ):
      """
-     Build per-dataset embeddings using the LWM encoder.
-
+     Build embeddings with the LWM encoder.
      Strategy:
-     1) If `utils.tokenizer` exists in the repo, try tokenizing each channel sample
-        and pass the tokenized tensor to the model.
-     2) If that fails, try feeding a flattened real-valued vector to the model.
-     3) If the forward still fails, fall back to using the flattened vector as the "embedding".
+     1) Try the repo's tokenizer if available (utils.tokenizer) and feed the result to the model.
+     2) Otherwise, try feeding flattened real vectors to the model.
+     3) If the forward still fails, fall back to using the flattened vectors as embeddings.

      Returns:
          embs: [D, n, d]
-         labels_per_ds (optional)
+         labels_per_ds: Optional[List[Tensor]]
      """
-     _maybe_add_lwm_repo_to_path()
-
-     # Try to import tokenizer if present; fall back to identity
-     def _identity(x): return x
+     # Try the optional tokenizer
      try:
          from utils import tokenizer as lwm_tokenizer  # type: ignore
      except Exception:
-         lwm_tokenizer = _identity  # type: ignore
+         lwm_tokenizer = None

-     all_feats = []
-     labels_per_ds = [] if label_aware else None
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     model = model.to(device).eval()

-     try:
-         params = list(model.parameters())
-         device = next(p.device for p in params) if params else torch.device("cpu")
-     except Exception:
-         device = torch.device("cpu")
+     all_embs = []
+     labels_per_ds = [] if label_aware else None

-     model = model.to(device)
-     model.eval()
-
-     for chs, labels, _name in datasets:
-         n = min(int(n_per_dataset), int(chs.shape[0]))
-         idx = torch.randperm(chs.shape[0])[:n]
-         sub = chs[idx]
-         feats_this = []
-
-         for x in sub:
-             # Ensure 2D (e.g., [N_ant, SC]) if possible
-             x_proc = x
-             if x_proc.ndim > 2:
-                 x_proc = x_proc.squeeze(0)
-
-             # First, try tokenizer-based forward
-             did_forward = False
+     for ch, y, _name in datasets:
+         N = int(ch.shape[0])
+         n = min(int(n_per_dataset), N)
+         idx = torch.randperm(N)[:n]
+         Xi = ch[idx]
+
+         feats = []
+         for x in Xi:
+             x2 = x
+             if x2.ndim > 2:
+                 x2 = x2.squeeze(0)
+
+             # 1) tokenizer path
+             if lwm_tokenizer is not None:
+                 try:
+                     tok = lwm_tokenizer(x2)
+                     tok = tok.to(device)
+                     out = model(tok)
+                     out = torch.as_tensor(out).reshape(1, -1).detach().cpu()
+                     feats.append(out)
+                     continue
+                 except Exception:
+                     pass
+
+             # 2) flattened forward path
              try:
-                 tok = lwm_tokenizer(x_proc)  # repo-specific; often returns a tensor
-                 tok = tok.to(device)
-                 y = model(tok)
-                 y = torch.as_tensor(y).reshape(1, -1).detach().cpu()
-                 feats_this.append(y)
-                 did_forward = True
+                 vec = x2.reshape(-1)
+                 if torch.is_complex(vec):
+                     vec = torch.cat([vec.real, vec.imag], dim=0)
+                 vec = vec.to(torch.float32).unsqueeze(0).to(device)
+                 out = model(vec)
+                 out = torch.as_tensor(out).reshape(1, -1).detach().cpu()
+                 feats.append(out)
+                 continue
              except Exception:
-                 # If tokenizer-based call fails, try flat-vector forward
                  pass

-             if not did_forward:
-                 try:
-                     # Flatten to real vector
-                     vec = x_proc.reshape(-1)
-                     if torch.is_complex(vec):
-                         vec = torch.cat([vec.real, vec.imag], dim=0)
-                     vec = vec.to(torch.float32).unsqueeze(0).to(device)  # [1, d]
-                     y2 = model(vec)
-                     y2 = torch.as_tensor(y2).reshape(1, -1).detach().cpu()
-                     feats_this.append(y2)
-                     did_forward = True
-                 except Exception:
-                     # Last resort: use the flattened vector as the embedding
-                     vec = x_proc.reshape(-1)
-                     if torch.is_complex(vec):
-                         vec = torch.cat([vec.real, vec.imag], dim=0)
-                     vec = vec.to(torch.float32).unsqueeze(0).cpu()
-                     feats_this.append(vec)
+             # 3) fallback: use the flattened vector directly
+             vec = x2.reshape(-1)
+             if torch.is_complex(vec):
+                 vec = torch.cat([vec.real, vec.imag], dim=0)
+             vec = vec.to(torch.float32).unsqueeze(0).cpu()
+             feats.append(vec)

-         embs_this = torch.cat(feats_this, dim=0)  # [n, d]
-         all_feats.append(embs_this)
+         Zi = torch.cat(feats, dim=0)  # [n, d]
+         all_embs.append(Zi)

-         if label_aware and labels is not None and labels.numel() > 0:
-             labels_per_ds.append(labels[idx].clone())
+         if label_aware:
+             if y is not None and len(y) >= n:
+                 labels_per_ds.append(y[idx].clone())
+             else:
+                 labels_per_ds.append(torch.empty((0,), dtype=torch.long))

-     # Pad to common dimension
-     max_d = max(t.shape[1] for t in all_feats)
+     # Pad to a common dim
+     max_d = max(t.shape[1] for t in all_embs)
      padded = []
-     for t in all_feats:
+     for t in all_embs:
          if t.shape[1] < max_d:
              pad = torch.zeros((t.shape[0], max_d - t.shape[1]), dtype=t.dtype)
              t = torch.cat([t, pad], dim=1)
          padded.append(t)

-     embs = torch.stack(padded, dim=0)  # [D, n, d]
-     if label_aware:
-         return embs, labels_per_ds if labels_per_ds is not None else []
-     return embs, None
+     embs = torch.stack(padded, dim=0)  # [D, n, d]
+     return embs, labels_per_ds
 
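Note: _load_state_dict_flex is meant to absorb several checkpoint layouts. A minimal sketch of the three layouts it handles, exercised on a stand-in module (a tiny Linear, not the real LWM encoder; this snippet is illustrative and not part of the commit):

import torch
from embed_lwm import _load_state_dict_flex

net = torch.nn.Linear(4, 2)
plain = net.state_dict()                                  # plain state_dict
wrapped = {"model": plain}                                # {"model": state_dict}
prefixed = {f"module.{k}": v for k, v in plain.items()}   # DataParallel-style keys

for state in (plain, wrapped, prefixed):
    _load_state_dict_flex(net, state)  # each layout should load without error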
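For reference, a minimal end-to-end sketch of the updated API. It assumes build_lwm_embeddings takes the encoder as a leading `model` argument (the hunks reference `model`, but its def line sits outside the diff context), and the toy tensors are purely illustrative:

import torch
from embed_lwm import get_lwm_encoder, build_lwm_embeddings

model = get_lwm_encoder()  # downloads wi-lab/lwm-v1.1 once, then cached in-process
if model is not None:
    # Each dataset is a (channels, optional labels, name) triple
    ds = [
        (torch.randn(10, 4, 32, dtype=torch.cfloat), torch.zeros(10, dtype=torch.long), "ds1"),
        (torch.randn(8, 4, 32, dtype=torch.cfloat), None, "ds2"),
    ]
    embs, labels = build_lwm_embeddings(model, ds, n_per_dataset=5, label_aware=False)
    print(embs.shape)  # [2, 5, d] after zero-padding to a common embedding dim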