wi-lab committed
Commit f439c65 · 1 Parent(s): 7589a7e

Update embed_lwm.py

Files changed (1)
  1. embed_lwm.py +108 -115
embed_lwm.py CHANGED
@@ -1,3 +1,4 @@
+# embed_lwm.py
 import os
 import sys
 from typing import List, Optional, Tuple
@@ -7,110 +8,63 @@ import torch
 def _log(msg: str):
     print(msg, flush=True)

-def _candidate_repo_dirs():
-    return [
-        os.getenv("LWM_REPO_DIR", "").strip(),
-        "./LWM-v1.1",
-        "/home/user/app/LWM-v1.1",
-    ]
-
-def _ensure_repo_on_path() -> Optional[str]:
-    for d in _candidate_repo_dirs():
-        if d and os.path.isdir(d):
-            if d not in sys.path:
-                sys.path.insert(0, d)
-            return d
-    return None
-
-def _ensure_pretrained_model_shim(repo_dir: str) -> None:
+def _maybe_add_lwm_repo_to_path():
     """
-    Some LWM examples import: `from pretrained_model import lwm`
-    If the repo doesn't ship `pretrained_model.py`, but has `lwm_model.py` with class `LWM`,
-    we create a tiny shim so imports succeed.
+    Ensure the HF-cloned LWM repo is importable.
+    You can override the location with env var LWM_REPO_DIR.
     """
-    shim_path = os.path.join(repo_dir, "pretrained_model.py")
-    lwm_path = os.path.join(repo_dir, "lwm_model.py")
-    if os.path.isfile(shim_path):
-        return
-    if not os.path.isfile(lwm_path):
-        return  # nothing we can do
-
-    # Create a simple factory around LWM
-    shim_code = """# Auto-generated shim to satisfy `from pretrained_model import lwm`
-import torch
-try:
-    from lwm_model import LWM
-except Exception as e:
-    raise ImportError(f"Shim could not import LWM from lwm_model.py: {e}")
-
-def lwm():
-    # Build a default LWM encoder (adjust constructor args if your repo requires them)
-    return LWM()
-"""
-    try:
-        with open(shim_path, "w", encoding="utf-8") as f:
-            f.write(shim_code)
-        _log(f"[INFO] Created shim: {shim_path}")
-    except Exception as e:
-        _log(f"[WARN] Could not create pretrained_model shim: {e}")
-
-def _maybe_load_weights(model, repo_dir: str):
-    # Try common weight locations
     candidates = [
-        os.path.join(repo_dir, "models", "model.pth"),
-        os.path.join(repo_dir, "model.pth"),
+        os.getenv("LWM_REPO_DIR", ""),  # user override
+        "./LWM-v1.1",                   # local default
+        "/home/user/app/LWM-v1.1",      # HF Space default path
     ]
-    for w in candidates:
-        if os.path.isfile(w):
-            try:
-                sd = torch.load(w, map_location="cpu")
-                # Sometimes saved as {'model': state_dict}
-                if isinstance(sd, dict) and "state_dict" in sd:
-                    sd = sd["state_dict"]
-                elif isinstance(sd, dict) and "model" in sd:
-                    sd = sd["model"]
-                model.load_state_dict(sd, strict=False)
-                _log(f"[INFO] Loaded LWM weights from {w}")
-                return
-            except Exception as e:
-                _log(f"[WARN] Failed to load weights from {w}: {e}")
-    _log("[WARN] No weights file found; using randomly-initialized LWM.")
+    for c in candidates:
+        if c and os.path.isdir(c) and c not in sys.path:
+            sys.path.insert(0, c)

 def get_lwm_encoder():
     """
-    Try to build an LWM encoder using the cloned repo.
-    Returns a torch.nn.Module or None.
+    Try to load the encoder from pretrained_model.py in the HF repo.
+    Returns a torch.nn.Module or None if it can’t be loaded.
     """
-    repo_dir = _ensure_repo_on_path()
-    if not repo_dir:
-        _log("[WARN] LWM repo not found; set LWM_REPO_DIR or clone to ./LWM-v1.1")
+    _maybe_add_lwm_repo_to_path()
+    try:
+        # HF repo exports a builder called `lwm`
+        from pretrained_model import lwm  # type: ignore
+    except Exception as e:
+        _log(f"[WARN] Failed to import pretrained_model.lwm: {e}")
         return None

-    # If the repo's modules expect `pretrained_model`, make sure it exists
-    _ensure_pretrained_model_shim(repo_dir)
-
-    # Try the most common entry point used in examples
     try:
-        # Import order: prefer pretrained_model.lwm() if available
-        import pretrained_model  # type: ignore
-        if hasattr(pretrained_model, "lwm"):
-            model = pretrained_model.lwm()
-        else:
-            # Fallback: try lwm_model directly
-            import lwm_model  # type: ignore
-            if hasattr(lwm_model, "LWM"):
-                model = lwm_model.LWM()
-            elif hasattr(lwm_model, "build_model"):
-                model = lwm_model.build_model()
-            else:
-                raise ImportError("No LWM builder found in lwm_model or pretrained_model")
-        _maybe_load_weights(model, repo_dir)
-        model.eval()
-        return model
+        model = lwm()
     except Exception as e:
-        _log(f"[WARN] Failed to load LWM encoder: {e}")
+        _log(f"[WARN] pretrained_model.lwm() failed to build model: {e}")
         return None

+    # Load weights if present
+    weights = None
+    for cand in ("models/model.pth", "./LWM-v1.1/models/model.pth"):
+        if os.path.isfile(cand):
+            weights = cand
+            break
+
+    if weights:
+        try:
+            sd = torch.load(weights, map_location="cpu")
+            try:
+                model.load_state_dict(sd)
+            except Exception:
+                # sometimes saved as {"model": state_dict}
+                if isinstance(sd, dict) and "model" in sd:
+                    model.load_state_dict(sd["model"])
+                else:
+                    raise
+        except Exception as e:
+            _log(f"[WARN] Could not load weights from {weights}: {e}")
+
+    model.eval()
+    return model
+

 @torch.no_grad()
 def build_lwm_embeddings(
@@ -120,51 +74,90 @@ def build_lwm_embeddings(
     label_aware: bool
 ) -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]:
     """
-    Generic embedding builder:
-      - Flattens each complex channel (concat real/imag),
-      - Forwards through the model if it accepts a flat vector,
-      - Pads to a common embedding dim.
-    If forward fails, falls back to the raw flattened vector.
+    Build per-dataset embeddings using the LWM encoder.
+
+    Strategy:
+      1) If `utils.tokenizer` exists in the repo, try tokenizing each channel sample
+         and pass the tokenized tensor to the model.
+      2) If that fails, try feeding a flattened real-valued vector to the model.
+      3) If the forward still fails, fall back to using the flattened vector as the “embedding”.
+
+    Returns:
+      embs: [D, n, d]
+      labels_per_ds (optional)
     """
+    _maybe_add_lwm_repo_to_path()
+
+    # Try to import tokenizer if present; fall back to identity
+    def _identity(x): return x
+    try:
+        from utils import tokenizer as lwm_tokenizer  # type: ignore
+    except Exception:
+        lwm_tokenizer = _identity  # type: ignore
+
     all_feats = []
     labels_per_ds = [] if label_aware else None

     try:
-        device = next(model.parameters()).device
-    except StopIteration:
+        params = list(model.parameters())
+        device = next(p.device for p in params) if params else torch.device("cpu")
+    except Exception:
         device = torch.device("cpu")
-    model = model.to(device).eval()

-    for chs, y, _name in datasets:
+    model = model.to(device)
+    model.eval()
+
+    for chs, labels, _name in datasets:
         n = min(int(n_per_dataset), int(chs.shape[0]))
         idx = torch.randperm(chs.shape[0])[:n]
         sub = chs[idx]
         feats_this = []

         for x in sub:
-            if x.ndim > 2:
-                x = x.squeeze(0)
-            vec = x.reshape(-1)
-            if torch.is_complex(vec):
-                vec = torch.cat([vec.real, vec.imag], dim=0)
-            vec = vec.to(torch.float32).unsqueeze(0).to(device)  # [1, d]
+            # Ensure 2D (e.g., [N_ant, SC]) if possible
+            x_proc = x
+            if x_proc.ndim > 2:
+                x_proc = x_proc.squeeze(0)

+            # First, try tokenizer-based forward
+            did_forward = False
             try:
-                out = model(vec)  # adapt here if your model expects another shape
-                out = out.reshape(1, -1).detach().cpu()
+                tok = lwm_tokenizer(x_proc)  # repo-specific; often returns a tensor
+                tok = tok.to(device)
+                y = model(tok)
+                y = torch.as_tensor(y).reshape(1, -1).detach().cpu()
+                feats_this.append(y)
+                did_forward = True
             except Exception:
-                # If the model forward signature mismatches, use the raw vector
-                out = vec.detach().cpu()
-
-            feats_this.append(out)
+                # If tokenizer-based call fails, try flat-vector forward
+                pass
+
+            if not did_forward:
+                try:
+                    # Flatten to real vector
+                    vec = x_proc.reshape(-1)
+                    if torch.is_complex(vec):
+                        vec = torch.cat([vec.real, vec.imag], dim=0)
+                    vec = vec.to(torch.float32).unsqueeze(0).to(device)  # [1, d]
+                    y2 = model(vec)
+                    y2 = torch.as_tensor(y2).reshape(1, -1).detach().cpu()
+                    feats_this.append(y2)
+                    did_forward = True
+                except Exception:
+                    # Last resort: use the flattened vector as the embedding
+                    vec = x_proc.reshape(-1)
+                    if torch.is_complex(vec):
+                        vec = torch.cat([vec.real, vec.imag], dim=0)
+                    vec = vec.to(torch.float32).unsqueeze(0).cpu()
+                    feats_this.append(vec)

         embs_this = torch.cat(feats_this, dim=0)  # [n, d’]
         all_feats.append(embs_this)

-        if label_aware and y is not None and y.numel() > 0:
-            labels_per_ds.append(y[idx].clone())
+        if label_aware and labels is not None and labels.numel() > 0:
+            labels_per_ds.append(labels[idx].clone())

-    # Pad to common dim
+    # Pad to common dimension
     max_d = max(t.shape[1] for t in all_feats)
     padded = []
     for t in all_feats:
@@ -173,7 +166,7 @@ def build_lwm_embeddings(
             t = torch.cat([t, pad], dim=1)
         padded.append(t)

-    embs = torch.stack(padded, dim=0)  # [D, n, d]
+    embs = torch.stack(padded, dim=0)  # [D, n, d]
     if label_aware:
         return embs, labels_per_ds if labels_per_ds is not None else []
     return embs, None
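
A minimal usage sketch of the updated module follows. It is not part of the commit: the full `build_lwm_embeddings` signature is only partially visible in the hunks above, so the argument order `(model, datasets, n_per_dataset, label_aware)` is an assumption inferred from the names used in the function body, and the random complex tensors are stand-in data only.

# Usage sketch; assumes build_lwm_embeddings(model, datasets, n_per_dataset, label_aware).
import torch

from embed_lwm import get_lwm_encoder, build_lwm_embeddings

# Each dataset entry is (channels, labels, name), matching the unpacking
# `for chs, labels, _name in datasets:` in the updated function.
datasets = [
    (torch.randn(32, 4, 64, dtype=torch.complex64), torch.randint(0, 5, (32,)), "toy_a"),
    (torch.randn(32, 4, 64, dtype=torch.complex64), None, "toy_b"),
]

model = get_lwm_encoder()  # returns None if the LWM repo or pretrained_model.lwm is unavailable
if model is not None:
    embs, labels_per_ds = build_lwm_embeddings(model, datasets, 16, label_aware=False)
    # embs: [D, n, d] -> 2 datasets, 16 samples each, zero-padded to a common dim d
    print(embs.shape, labels_per_ds)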