Namhyun Kim committed on
Commit
2a6ccf4
·
1 Parent(s): 0275ff2

Harden demo data loading (token, LFS, schema)

Browse files
Files changed (1) hide show
  1. app.py +85 -2
app.py CHANGED
@@ -22,7 +22,19 @@ APP_DIR = Path(__file__).resolve().parent
22
  DEMO_DATA_PATH = APP_DIR / "demo_data.pt"
23
  MOE_DATA_PATH = APP_DIR / "demo_data_moe.pt"
24
  HUB_REPO_ID = "wi-lab/lwm-spectro"
25
- HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HF_HUB_TOKEN")
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  # Fixed ordering for the 14 joint SNR/Doppler labels
28
  JOINT_LABELS = [
@@ -72,6 +84,62 @@ def _safe_load_tensor(path: Path):
72
  return torch.load(path, weights_only=False)
73
 
74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  def _create_dummy_dataset(base_path: Path, moe_path: Path) -> None:
76
  """Create a tiny synthetic dataset so the Space can start even if hub download fails."""
77
  print(f"[WARN] Creating synthetic demo dataset at {base_path}")
@@ -109,7 +177,7 @@ def _create_dummy_dataset(base_path: Path, moe_path: Path) -> None:
109
 
110
  def _ensure_local_file(local_path: Path, hub_filename: str) -> Optional[Path]:
111
  """Ensure a file exists locally; try Hub download if missing."""
112
- if local_path.exists():
113
  return local_path
114
  try:
115
  cached = hf_hub_download(
@@ -145,7 +213,20 @@ def load_data(mapping: Dict[str, object]):
145
  pair_to_id = mapping["pair_to_id"]
146
 
147
  records = []
 
148
  for i, sample in enumerate(data):
 
 
 
 
 
 
 
 
 
 
 
 
149
  embedding = sample["embedding"]
150
  if isinstance(embedding, torch.Tensor):
151
  base_embedding = embedding.detach().cpu().numpy()
@@ -212,6 +293,8 @@ def load_data(mapping: Dict[str, object]):
212
  )
213
 
214
  df = pd.DataFrame(records)
 
 
215
  print(f"[INFO] Loaded {len(df)} samples (MoE embeddings: {has_moe})")
216
  return df, has_moe
217
 
 
22
  DEMO_DATA_PATH = APP_DIR / "demo_data.pt"
23
  MOE_DATA_PATH = APP_DIR / "demo_data_moe.pt"
24
  HUB_REPO_ID = "wi-lab/lwm-spectro"
25
+
26
+
27
+ def _get_hf_token() -> str | None:
28
+ # Spaces / HF Hub tooling uses a few common names.
29
+ return (
30
+ os.getenv("HF_TOKEN")
31
+ or os.getenv("HF_HUB_TOKEN")
32
+ or os.getenv("HUGGINGFACEHUB_API_TOKEN")
33
+ or os.getenv("HF_API_TOKEN")
34
+ )
35
+
36
+
37
+ HF_TOKEN = _get_hf_token()
38
 
39
  # Fixed ordering for the 14 joint SNR/Doppler labels
40
  JOINT_LABELS = [
 
84
  return torch.load(path, weights_only=False)
85
 
86
 
87
def _is_git_lfs_pointer(path: Path) -> bool:
    """Return True when *path* holds a Git LFS pointer stub, not real data.

    Only the first 256 bytes are inspected (pointer files are tiny).  Files
    that are missing or unreadable count as not-a-pointer so callers can
    fall through to a normal download attempt.
    """
    marker = b"git-lfs.github.com/spec"
    try:
        with path.open("rb") as fh:
            return marker in fh.read(256)
    except OSError:
        return False
94
+
95
+
96
def _normalize_tech_label(value: object) -> object:
    """Map free-form technology labels onto the canonical WiFi/LTE/5G names.

    ``None`` and blank values are returned unchanged; unrecognized labels are
    returned in their stripped text form.
    """
    if value is None:
        return value
    text = str(value).strip()
    if not text:
        return value
    # Compare on a fully collapsed key: lowercase with spaces, hyphens AND
    # underscores removed.  (The old code stripped only spaces/hyphens, so
    # set members like "wi-fi" were dead and "_"-separated variants such as
    # "L_TE" were never matched.)
    key = text.lower().replace(" ", "").replace("-", "").replace("_", "")
    if key == "wifi":
        return "WiFi"
    if key == "lte":
        return "LTE"
    if key in {"5g", "nr", "5gnr", "sub6", "sub6ghz", "5gsub6", "5gsub6ghz"}:
        return "5G"
    return text
110
+
111
+
112
def _normalize_mobility_label(value: object) -> object:
    """Collapse mobility label variants onto 'pedestrian' / 'vehicular'.

    ``None`` and blank values pass through unchanged; anything unrecognized
    is returned as its stripped text form.
    """
    if value is None:
        return value
    text = str(value).strip()
    if not text:
        return value
    key = text.lower().replace(" ", "").replace("-", "")
    alias_table = (
        ("pedestrian", ("ped", "pedestrian", "walking")),
        ("vehicular", ("veh", "vehicular", "vehicle", "driving", "car")),
    )
    for canonical, aliases in alias_table:
        if key in aliases:
            return canonical
    return text
124
+
125
+
126
def _normalize_sample(sample: Dict[str, object]) -> Dict[str, object]:
    """Return a shallow copy of *sample* with canonical short keys and
    normalized tech/mobility labels.

    The input mapping is never mutated.
    """
    out = dict(sample)
    # Schema aliases (some artifacts use longer names): backfill the short
    # key only when it is absent and the long-form key is present.
    alias_pairs = (
        ("tech", "technology"),
        ("mod", "modulation"),
        ("mob", "mobility"),
        ("snr", "snr_label"),
    )
    for short_key, long_key in alias_pairs:
        if short_key not in out and long_key in out:
            out[short_key] = out.get(long_key)

    out["tech"] = _normalize_tech_label(out.get("tech"))
    out["mob"] = _normalize_mobility_label(out.get("mob"))
    return out
141
+
142
+
143
  def _create_dummy_dataset(base_path: Path, moe_path: Path) -> None:
144
  """Create a tiny synthetic dataset so the Space can start even if hub download fails."""
145
  print(f"[WARN] Creating synthetic demo dataset at {base_path}")
 
177
 
178
  def _ensure_local_file(local_path: Path, hub_filename: str) -> Optional[Path]:
179
  """Ensure a file exists locally; try Hub download if missing."""
180
+ if local_path.exists() and not _is_git_lfs_pointer(local_path):
181
  return local_path
182
  try:
183
  cached = hf_hub_download(
 
213
  pair_to_id = mapping["pair_to_id"]
214
 
215
  records = []
216
+ skipped = 0
217
  for i, sample in enumerate(data):
218
+ if not isinstance(sample, dict):
219
+ skipped += 1
220
+ continue
221
+ sample = _normalize_sample(sample)
222
+
223
+ if not sample.get("tech") or not sample.get("snr") or not sample.get("mob") or not sample.get("mod"):
224
+ skipped += 1
225
+ continue
226
+ if "embedding" not in sample or "data" not in sample:
227
+ skipped += 1
228
+ continue
229
+
230
  embedding = sample["embedding"]
231
  if isinstance(embedding, torch.Tensor):
232
  base_embedding = embedding.detach().cpu().numpy()
 
293
  )
294
 
295
  df = pd.DataFrame(records)
296
+ if skipped:
297
+ print(f"[WARN] Skipped {skipped} malformed samples while loading demo data")
298
  print(f"[INFO] Loaded {len(df)} samples (MoE embeddings: {has_moe})")
299
  return df, has_moe
300