Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import logging | |
| import uuid | |
| from datetime import datetime, timezone | |
| from pathlib import Path | |
| from typing import Optional | |
| import pandas as pd | |
| logger = logging.getLogger(__name__) | |
# Canonical column schema for the aliases table: column name -> pandas dtype.
# Every column is a nullable pandas StringDtype except "confidence", which is
# a plain float64 score.
_SCHEMA = {
    "id": pd.StringDtype(),
    "raw_value": pd.StringDtype(),
    "entity_type": pd.StringDtype(),
    "canonical_id": pd.StringDtype(),
    "source_config": pd.StringDtype(),  # null => global alias, not config-scoped
    "source_field": pd.StringDtype(),
    "status": pd.StringDtype(),  # rows with status "rejected" are excluded from lookups
    "strategy": pd.StringDtype(),
    "confidence": "float64",
    "notes": pd.StringDtype(),
    "created_at": pd.StringDtype(),  # ISO-8601 UTC timestamp (see add_alias)
    "updated_at": pd.StringDtype(),  # ISO-8601 UTC timestamp (see add_alias)
}
def _empty_df() -> pd.DataFrame:
    """Return a zero-row DataFrame whose columns and dtypes follow ``_SCHEMA``."""
    columns = {name: pd.Series(dtype=kind) for name, kind in _SCHEMA.items()}
    return pd.DataFrame(columns)
| class AliasStore: | |
| """Wraps the aliases table. Loaded into memory; writes are in-memory only.""" | |
| def __init__(self, df: pd.DataFrame, read_only: bool = False) -> None: | |
| self._df = df.copy() | |
| self.read_only = read_only | |
| # Per-entity_type caches — built lazily on first access | |
| # Normalized lookup key: (entity_type, source_config or None) | |
| self._normalized_cache: dict[tuple[str, Optional[str]], dict[str, str]] = {} | |
| self._candidates_cache: dict[tuple[str, Optional[str]], list[tuple[str, str]]] = {} | |
| self._lookup_index: dict[tuple[str, str, Optional[str]], str] | None = None | |
| def _ensure_lookup_index(self) -> None: | |
| """Build a dict index for O(1) exact lookups.""" | |
| if self._lookup_index is not None: | |
| return | |
| self._lookup_index = {} | |
| df = self._df[self._df["status"] != "rejected"] | |
| for _, row in df.iterrows(): | |
| # Convert pandas NaN to None so dict.get((..., ..., None)) | |
| # finds rows whose source_config is null. Without this, | |
| # exact_match silently misses every global alias (since | |
| # NaN != None in dict-key comparison) and the resolver | |
| # falls through to normalized_match, which can't | |
| # disambiguate when two canonicals share a normalized form. | |
| sc = row.get("source_config") | |
| if pd.isna(sc): | |
| sc = None | |
| key = (row["raw_value"], row["entity_type"], sc) | |
| self._lookup_index[key] = row["canonical_id"] | |
| def _invalidate_caches(self) -> None: | |
| self._normalized_cache.clear() | |
| self._candidates_cache.clear() | |
| self._lookup_index = None | |
| # ------------------------------------------------------------------ | |
| # Constructors | |
| # ------------------------------------------------------------------ | |
| def from_parquet(cls, path: str | Path, read_only: bool = False) -> "AliasStore": | |
| p = Path(path) / "aliases.parquet" | |
| if not p.exists(): | |
| # Missing dir / missing file is the legitimate "fresh store" | |
| # case (used by tests and first-time seed runs), so log at INFO | |
| # instead of WARNING — but still surface it. | |
| logger.info( | |
| "AliasStore.from_parquet: %s not found; falling back to empty store", | |
| p, | |
| ) | |
| return cls(_empty_df(), read_only=read_only) | |
| try: | |
| df = pd.read_parquet(p) | |
| except (OSError, ValueError) as exc: | |
| # OSError covers permission / IO errors; ValueError is what | |
| # pyarrow raises for corrupt parquet (ArrowInvalid is a | |
| # ValueError subclass). Keep the empty-store fallback so a | |
| # corrupt local cache doesn't hard-crash callers, but log so | |
| # the failure isn't silent. | |
| logger.warning( | |
| "AliasStore.from_parquet: failed to read %s (%s: %s); " | |
| "falling back to empty store", | |
| p, | |
| type(exc).__name__, | |
| exc, | |
| ) | |
| df = _empty_df() | |
| return cls(df, read_only=read_only) | |
| def from_hf(cls, repo_id: str, read_only: bool = False) -> "AliasStore": | |
| from huggingface_hub import hf_hub_download | |
| from huggingface_hub.errors import ( | |
| EntryNotFoundError, | |
| HfHubHTTPError, | |
| RepositoryNotFoundError, | |
| ) | |
| try: | |
| local = hf_hub_download( | |
| repo_id=repo_id, | |
| filename="aliases/part-0.parquet", | |
| repo_type="dataset", | |
| ) | |
| df = pd.read_parquet(local) | |
| except ( | |
| RepositoryNotFoundError, | |
| EntryNotFoundError, | |
| HfHubHTTPError, | |
| FileNotFoundError, | |
| OSError, | |
| ValueError, | |
| ) as exc: | |
| # Specific catches: | |
| # - RepositoryNotFoundError: repo missing or auth failure | |
| # (HF returns 401 disguised as 404 when token is invalid). | |
| # - EntryNotFoundError: repo exists but aliases/part-0.parquet | |
| # hasn't been seeded yet. | |
| # - HfHubHTTPError: catch-all for other HTTP failures | |
| # (network errors, 5xx, rate limits). | |
| # - FileNotFoundError / OSError: filesystem-level errors | |
| # reading the downloaded file. | |
| # - ValueError: pyarrow.lib.ArrowInvalid (parquet corruption) | |
| # subclasses ValueError. | |
| # We keep the fallback-to-empty recovery (callers expect the | |
| # store to construct), but emit a warning so the failure is | |
| # visible — silent fallback was masking auth and corruption | |
| # issues during deploys. | |
| logger.warning( | |
| "AliasStore.from_hf: failed to load aliases from %r (%s: %s); " | |
| "falling back to empty store", | |
| repo_id, | |
| type(exc).__name__, | |
| exc, | |
| ) | |
| df = _empty_df() | |
| return cls(df, read_only=read_only) | |
| # ------------------------------------------------------------------ | |
| # Lookup | |
| # ------------------------------------------------------------------ | |
| def lookup( | |
| self, | |
| raw_value: str, | |
| entity_type: str, | |
| source_config: Optional[str], | |
| ) -> Optional[str]: | |
| """Return canonical_id for first non-rejected match. Config-scoped before global.""" | |
| self._ensure_lookup_index() | |
| # Config-scoped | |
| if source_config: | |
| result = self._lookup_index.get((raw_value, entity_type, source_config)) | |
| if result is not None: | |
| return result | |
| # Global | |
| return self._lookup_index.get((raw_value, entity_type, None)) | |
| # ------------------------------------------------------------------ | |
| # Writes (in-memory only; caller is responsible for persistence) | |
| # ------------------------------------------------------------------ | |
| def add_alias( | |
| self, | |
| raw_value: str, | |
| entity_type: str, | |
| canonical_id: str, | |
| source_config: Optional[str], | |
| source_field: Optional[str], | |
| status: str, | |
| strategy: str, | |
| confidence: float, | |
| ) -> None: | |
| if self.read_only: | |
| raise RuntimeError("AliasStore is read-only") | |
| now = datetime.now(timezone.utc).isoformat() | |
| row = { | |
| "id": str(uuid.uuid4()), | |
| "raw_value": raw_value, | |
| "entity_type": entity_type, | |
| "canonical_id": canonical_id, | |
| "source_config": source_config, | |
| "source_field": source_field, | |
| "status": status, | |
| "strategy": strategy, | |
| "confidence": confidence, | |
| "notes": None, | |
| "created_at": now, | |
| "updated_at": now, | |
| } | |
| self._df = pd.concat([self._df, pd.DataFrame([row])], ignore_index=True) | |
| self._invalidate_caches() | |
| def update_alias( | |
| self, | |
| raw_value: str, | |
| entity_type: str, | |
| source_config: Optional[str], | |
| canonical_id: str, | |
| status: str, | |
| strategy: str, | |
| confidence: float, | |
| ) -> None: | |
| """Upsert: update existing alias row or add new one.""" | |
| if self.read_only: | |
| raise RuntimeError("AliasStore is read-only") | |
| df = self._df | |
| mask = (df["raw_value"] == raw_value) & (df["entity_type"] == entity_type) | |
| if source_config: | |
| mask = mask & (df["source_config"] == source_config) | |
| else: | |
| mask = mask & df["source_config"].isna() | |
| if mask.any(): | |
| now = datetime.now(timezone.utc).isoformat() | |
| self._df.loc[mask, "canonical_id"] = canonical_id | |
| self._df.loc[mask, "status"] = status | |
| self._df.loc[mask, "strategy"] = strategy | |
| self._df.loc[mask, "confidence"] = confidence | |
| self._df.loc[mask, "updated_at"] = now | |
| self._invalidate_caches() | |
| else: | |
| self.add_alias(raw_value, entity_type, canonical_id, source_config, None, status, strategy, confidence) | |
| # ------------------------------------------------------------------ | |
| # Export | |
| # ------------------------------------------------------------------ | |
| def to_dataframe(self) -> pd.DataFrame: | |
| return self._df.copy() | |
| def get_normalized_lookup( | |
| self, entity_type: str, source_config: Optional[str] = None | |
| ) -> dict[str, str]: | |
| """Return {normalized_raw_value: canonical_id} for use by strategies. | |
| When ``source_config`` is given, the returned map merges config-scoped | |
| aliases on top of global (source_config IS NULL) aliases, so scoped | |
| matches win over global for the same normalized form. When | |
| ``source_config`` is None, only global aliases are included — scoped | |
| aliases do NOT leak into unrelated lookups. | |
| """ | |
| key = (entity_type, source_config) | |
| if key in self._normalized_cache: | |
| return self._normalized_cache[key] | |
| from eval_entity_resolver.normalization import normalize | |
| base = self._df[(self._df["entity_type"] == entity_type) & (self._df["status"] != "rejected")] | |
| # Start from global aliases. | |
| global_df = base[base["source_config"].isna()] | |
| result: dict[str, str] = {} | |
| for _, row in global_df.iterrows(): | |
| result[normalize(row["raw_value"])] = row["canonical_id"] | |
| # Overlay scoped aliases for the requested source_config. | |
| if source_config: | |
| scoped_df = base[base["source_config"] == source_config] | |
| for _, row in scoped_df.iterrows(): | |
| result[normalize(row["raw_value"])] = row["canonical_id"] | |
| self._normalized_cache[key] = result | |
| return result | |
| def get_all_for_type( | |
| self, entity_type: str, source_config: Optional[str] = None | |
| ) -> list[tuple[str, str]]: | |
| """Return [(raw_value, canonical_id)] for non-rejected aliases of ``entity_type``. | |
| Filtering matches ``get_normalized_lookup`` — when ``source_config`` is | |
| given, includes global + that config's scoped aliases; otherwise global | |
| only. Cached per (entity_type, source_config). | |
| """ | |
| key = (entity_type, source_config) | |
| if key in self._candidates_cache: | |
| return self._candidates_cache[key] | |
| base = self._df[(self._df["entity_type"] == entity_type) & (self._df["status"] != "rejected")] | |
| if source_config: | |
| mask = base["source_config"].isna() | (base["source_config"] == source_config) | |
| df = base[mask] | |
| else: | |
| df = base[base["source_config"].isna()] | |
| result = list(zip(df["raw_value"].tolist(), df["canonical_id"].tolist())) | |
| self._candidates_cache[key] = result | |
| return result | |