Spaces:
Running
Running
Namhyun Kim
committed on
Commit
·
2a6ccf4
1
Parent(s):
0275ff2
Harden demo data loading (token, LFS, schema)
Browse files
app.py
CHANGED
|
@@ -22,7 +22,19 @@ APP_DIR = Path(__file__).resolve().parent
|
|
| 22 |
DEMO_DATA_PATH = APP_DIR / "demo_data.pt"
|
| 23 |
MOE_DATA_PATH = APP_DIR / "demo_data_moe.pt"
|
| 24 |
HUB_REPO_ID = "wi-lab/lwm-spectro"
|
| 25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
# Fixed ordering for the 14 joint SNR/Doppler labels
|
| 28 |
JOINT_LABELS = [
|
|
@@ -72,6 +84,62 @@ def _safe_load_tensor(path: Path):
|
|
| 72 |
return torch.load(path, weights_only=False)
|
| 73 |
|
| 74 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
def _create_dummy_dataset(base_path: Path, moe_path: Path) -> None:
|
| 76 |
"""Create a tiny synthetic dataset so the Space can start even if hub download fails."""
|
| 77 |
print(f"[WARN] Creating synthetic demo dataset at {base_path}")
|
|
@@ -109,7 +177,7 @@ def _create_dummy_dataset(base_path: Path, moe_path: Path) -> None:
|
|
| 109 |
|
| 110 |
def _ensure_local_file(local_path: Path, hub_filename: str) -> Optional[Path]:
|
| 111 |
"""Ensure a file exists locally; try Hub download if missing."""
|
| 112 |
-
if local_path.exists():
|
| 113 |
return local_path
|
| 114 |
try:
|
| 115 |
cached = hf_hub_download(
|
|
@@ -145,7 +213,20 @@ def load_data(mapping: Dict[str, object]):
|
|
| 145 |
pair_to_id = mapping["pair_to_id"]
|
| 146 |
|
| 147 |
records = []
|
|
|
|
| 148 |
for i, sample in enumerate(data):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 149 |
embedding = sample["embedding"]
|
| 150 |
if isinstance(embedding, torch.Tensor):
|
| 151 |
base_embedding = embedding.detach().cpu().numpy()
|
|
@@ -212,6 +293,8 @@ def load_data(mapping: Dict[str, object]):
|
|
| 212 |
)
|
| 213 |
|
| 214 |
df = pd.DataFrame(records)
|
|
|
|
|
|
|
| 215 |
print(f"[INFO] Loaded {len(df)} samples (MoE embeddings: {has_moe})")
|
| 216 |
return df, has_moe
|
| 217 |
|
|
|
|
| 22 |
DEMO_DATA_PATH = APP_DIR / "demo_data.pt"
|
| 23 |
MOE_DATA_PATH = APP_DIR / "demo_data_moe.pt"
|
| 24 |
HUB_REPO_ID = "wi-lab/lwm-spectro"
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def _get_hf_token() -> str | None:
|
| 28 |
+
# Spaces / HF Hub tooling uses a few common names.
|
| 29 |
+
return (
|
| 30 |
+
os.getenv("HF_TOKEN")
|
| 31 |
+
or os.getenv("HF_HUB_TOKEN")
|
| 32 |
+
or os.getenv("HUGGINGFACEHUB_API_TOKEN")
|
| 33 |
+
or os.getenv("HF_API_TOKEN")
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
HF_TOKEN = _get_hf_token()
|
| 38 |
|
| 39 |
# Fixed ordering for the 14 joint SNR/Doppler labels
|
| 40 |
JOINT_LABELS = [
|
|
|
|
| 84 |
return torch.load(path, weights_only=False)
|
| 85 |
|
| 86 |
|
| 87 |
+
def _is_git_lfs_pointer(path: Path) -> bool:
|
| 88 |
+
try:
|
| 89 |
+
with path.open("rb") as handle:
|
| 90 |
+
head = handle.read(256)
|
| 91 |
+
return b"git-lfs.github.com/spec" in head
|
| 92 |
+
except OSError:
|
| 93 |
+
return False
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def _normalize_tech_label(value: object) -> object:
|
| 97 |
+
if value is None:
|
| 98 |
+
return value
|
| 99 |
+
text = str(value).strip()
|
| 100 |
+
if not text:
|
| 101 |
+
return value
|
| 102 |
+
normalized = text.lower().replace(" ", "").replace("-", "")
|
| 103 |
+
if normalized in {"wifi", "wi-fi", "wi_fi"}:
|
| 104 |
+
return "WiFi"
|
| 105 |
+
if normalized == "lte":
|
| 106 |
+
return "LTE"
|
| 107 |
+
if normalized in {"5g", "nr", "5gnr", "sub6", "sub6ghz", "5gsub6", "5gsub6ghz"}:
|
| 108 |
+
return "5G"
|
| 109 |
+
return text
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def _normalize_mobility_label(value: object) -> object:
|
| 113 |
+
if value is None:
|
| 114 |
+
return value
|
| 115 |
+
text = str(value).strip()
|
| 116 |
+
if not text:
|
| 117 |
+
return value
|
| 118 |
+
normalized = text.lower().replace(" ", "").replace("-", "")
|
| 119 |
+
if normalized in {"ped", "pedestrian", "walking"}:
|
| 120 |
+
return "pedestrian"
|
| 121 |
+
if normalized in {"veh", "vehicular", "vehicle", "driving", "car"}:
|
| 122 |
+
return "vehicular"
|
| 123 |
+
return text
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def _normalize_sample(sample: Dict[str, object]) -> Dict[str, object]:
    """Return a copy of *sample* with the short field schema and canonical labels.

    Some artifacts ship the long field names ("technology", "modulation",
    "mobility", "snr_label"); those are aliased onto the short keys the
    rest of the app expects, then tech/mobility labels are canonicalized.
    The input dict is never mutated.
    """
    out = dict(sample)
    # Schema aliases (some artifacts use longer names); never clobber an
    # already-present short key.
    aliases = (
        ("tech", "technology"),
        ("mod", "modulation"),
        ("mob", "mobility"),
        ("snr", "snr_label"),
    )
    for short, long in aliases:
        if short not in out and long in out:
            out[short] = out.get(long)

    out["tech"] = _normalize_tech_label(out.get("tech"))
    out["mob"] = _normalize_mobility_label(out.get("mob"))
    return out
|
| 141 |
+
|
| 142 |
+
|
| 143 |
def _create_dummy_dataset(base_path: Path, moe_path: Path) -> None:
|
| 144 |
"""Create a tiny synthetic dataset so the Space can start even if hub download fails."""
|
| 145 |
print(f"[WARN] Creating synthetic demo dataset at {base_path}")
|
|
|
|
| 177 |
|
| 178 |
def _ensure_local_file(local_path: Path, hub_filename: str) -> Optional[Path]:
|
| 179 |
"""Ensure a file exists locally; try Hub download if missing."""
|
| 180 |
+
if local_path.exists() and not _is_git_lfs_pointer(local_path):
|
| 181 |
return local_path
|
| 182 |
try:
|
| 183 |
cached = hf_hub_download(
|
|
|
|
| 213 |
pair_to_id = mapping["pair_to_id"]
|
| 214 |
|
| 215 |
records = []
|
| 216 |
+
skipped = 0
|
| 217 |
for i, sample in enumerate(data):
|
| 218 |
+
if not isinstance(sample, dict):
|
| 219 |
+
skipped += 1
|
| 220 |
+
continue
|
| 221 |
+
sample = _normalize_sample(sample)
|
| 222 |
+
|
| 223 |
+
if not sample.get("tech") or not sample.get("snr") or not sample.get("mob") or not sample.get("mod"):
|
| 224 |
+
skipped += 1
|
| 225 |
+
continue
|
| 226 |
+
if "embedding" not in sample or "data" not in sample:
|
| 227 |
+
skipped += 1
|
| 228 |
+
continue
|
| 229 |
+
|
| 230 |
embedding = sample["embedding"]
|
| 231 |
if isinstance(embedding, torch.Tensor):
|
| 232 |
base_embedding = embedding.detach().cpu().numpy()
|
|
|
|
| 293 |
)
|
| 294 |
|
| 295 |
df = pd.DataFrame(records)
|
| 296 |
+
if skipped:
|
| 297 |
+
print(f"[WARN] Skipped {skipped} malformed samples while loading demo data")
|
| 298 |
print(f"[INFO] Loaded {len(df)} samples (MoE embeddings: {has_moe})")
|
| 299 |
return df, has_moe
|
| 300 |
|