SupraDashboard / src /data /loader.py
Tianyi-Billy-Ma
Deploy: simplify codebase (dead-code removal, behavior-preserving)
1727091
Raw
History Blame Contribute Delete
10.5 kB
"""Load CB[7] host-guest features by joining four private HF datasets.
The dashboard does NOT recompute docking/xtb features live. It reads the GEOM
guest table as the left-side base and joins ligand, pose, and cavity feature
tables on ``inchikey``. Missing feature datasets degrade to empty/NaN columns so
the UI can still list guests; a missing GEOM base is fatal.
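Environment variables:
    HF_TOKEN      read token with access to the private datasets
    HF_DS_GEOM    dataset repo id (default SupraBench/SupraDB-GEOM)
    HF_DS_LIGAND  dataset repo id (default SupraBench/SupraDB-LigandScore)
    HF_DS_POSE    dataset repo id (default SupraBench/SupraDB-PoseFeat)
    HF_DS_CAVITY  dataset repo id (default SupraBench/SupraDB-CavityScore)
The file inside each repo can be overridden with HF_DS_<NAME>_FILE (default
guests.csv for GEOM, features.csv for the others). Each dataset also supports a
LOCAL_GEOM / LOCAL_LIGAND / LOCAL_POSE / LOCAL_CAVITY CSV path override for
offline development; when one is set (non-empty), the Hub is not contacted for
that dataset.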
"""
from __future__ import annotations
import functools
import logging
import os
# pandas is imported lazily inside load_features() so that importing this module
# (and by extension app/prompts) works without pandas — e.g. in the smoke-test
# environment or any context that only needs FEATURES / build_prompt logic.
_LOG = logging.getLogger(__name__)

# Source-dataset registry, keyed by short dataset name. Fields of each entry:
#   repo_env / file_env  env vars that override the HF repo id / file name
#   local_env            env var holding a local CSV path (bypasses the Hub)
#   repo / file          default HF dataset repo id and file name
#   columns              expected schema, "inchikey" (the join key) first;
#                        optional tables are trimmed/padded to exactly these
#   required             True -> a load failure is fatal; False -> NaN columns
_DATASETS = {
    # Left-side base table. load_features() renames "name" to "guest_name"
    # when the file has no "guest_name" column of its own.
    "geom": {
        "repo_env": "HF_DS_GEOM",
        "file_env": "HF_DS_GEOM_FILE",
        "local_env": "LOCAL_GEOM",
        "repo": "SupraBench/SupraDB-GEOM",
        "file": "guests.csv",
        "columns": [
            "inchikey", "name", "smiles", "logka",
            "scaffold_family", "known_scaffold", "tmax_known", "is_novel",
        ],
        "required": True,
    },
    # Optional feature group.
    "ligand": {
        "repo_env": "HF_DS_LIGAND",
        "file_env": "HF_DS_LIGAND_FILE",
        "local_env": "LOCAL_LIGAND",
        "repo": "SupraBench/SupraDB-LigandScore",
        "file": "features.csv",
        "columns": [
            "inchikey",
            "S_charge", "S_hydrophobic", "S_rigidity", "S_desolvation",
            "S_packing", "S_shape", "S_conformer_diversity",
            "S_boltzmann_concentration", "S_bad",
        ],
        "required": False,
    },
    # Optional feature group. May hold several rows per inchikey;
    # load_features() keeps the one with the largest boltzmann_weight.
    "pose": {
        "repo_env": "HF_DS_POSE",
        "file_env": "HF_DS_POSE_FILE",
        "local_env": "LOCAL_POSE",
        "repo": "SupraBench/SupraDB-PoseFeat",
        "file": "features.csv",
        "columns": [
            "inchikey",
            "DockingScore", "Pose_Energy",
            "Distance_to_Cavity_Center", "Distance_to_Portal",
            "Insertion_Depth", "Packing_Coefficient", "Occupancy",
            "Hydrophobic_Occupancy", "Shape_Complementarity", "Steric_Clash",
            "Guest_CB7_Min_Distance", "Pose_RMSD_to_Template",
            "Portal_Compatibility", "Positive_Center_to_Portal_Distance",
            "Positive_Center_Orientation", "Charge_Accessibility",
            "Portal_Facing_Accessibility", "HBond_Count", "HBond_Geometry",
            "Carbonyl_Oxygen_Contact_Count", "Hydrophobic_Contact",
            "Polar_Contact_Penalty", "Bad_Group_Portal_Exposure",
            "Desolvation_Penalty", "boltzmann_weight", "delta_e",
        ],
        "required": False,
    },
    # Optional feature group.
    "cavity": {
        "repo_env": "HF_DS_CAVITY",
        "file_env": "HF_DS_CAVITY_FILE",
        "local_env": "LOCAL_CAVITY",
        "repo": "SupraBench/SupraDB-CavityScore",
        "file": "features.csv",
        "columns": [
            "inchikey",
            "S_occupancy", "S_portal", "S_accessibility", "S_orientation",
        ],
        "required": False,
    },
}

# Per-dataset "loaded successfully" flags; all False until load_features() runs.
_LOAD_STATUS = dict.fromkeys(_DATASETS, False)
# The 22 features surfaced in the UI, as (display label, column) pairs in
# prompt order (mirrors gen_label_studio.py). The trailing comment on each
# entry names the _DATASETS group that supplies the column.
FEATURES = [
    ("Binding energy ΔE_bind", "DockingScore"),  # pose
    ("Packing coefficient", "Packing_Coefficient"),  # pose
    ("Cavity occupancy", "Occupancy"),  # pose
    ("Hydrophobic occupancy", "Hydrophobic_Occupancy"),  # pose
    ("Shape complementarity", "Shape_Complementarity"),  # pose
    ("Insertion depth", "Insertion_Depth"),  # pose
    ("Steric clashes", "Steric_Clash"),  # pose
    ("Positive-center-to-portal distance", "Positive_Center_to_Portal_Distance"),  # pose
    ("H-bond count", "HBond_Count"),  # pose
    ("Carbonyl-oxygen contacts", "Carbonyl_Oxygen_Contact_Count"),  # pose
    ("Charge", "S_charge"),  # ligand
    ("Hydrophobicity", "S_hydrophobic"),  # ligand
    ("Rigidity", "S_rigidity"),  # ligand
    ("Desolvation ease", "S_desolvation"),  # ligand
    ("Packing quality", "S_packing"),  # ligand
    ("Hydrophobic cavity filling", "S_occupancy"),  # cavity
    ("Shape compactness", "S_shape"),  # ligand
    ("Preorganization", "S_conformer_diversity"),  # ligand
    ("Portal engagement", "S_portal"),  # cavity
    ("Positive-center exposure", "S_accessibility"),  # cavity
    ("Positive-center orientation score", "S_orientation"),  # cavity
    ("Unfavorable-feature penalty", "S_bad"),  # ligand
]
@functools.lru_cache(maxsize=1)
def load_features():
    """Return the per-guest feature table (pandas DataFrame), indexed by inchikey.

    GEOM is the left-side base and provides ``inchikey``, ``guest_name``,
    ``smiles``, optional ``logka`` and the novelty annotation columns
    (``scaffold_family``, ``known_scaffold``, ``tmax_known``, ``is_novel``).
    The ligand, pose and cavity datasets are optional feature groups that are
    left-joined on ``inchikey``; a group that fails to load contributes
    all-NaN columns instead. Pose rows are collapsed to the row with the
    largest ``boltzmann_weight`` per ``inchikey`` (ties keep the earliest row
    of the file). ``inchikey`` is kept as a regular column as well as the
    index. Cached for the process lifetime.

    Side effect: records each dataset's load outcome in ``_LOAD_STATUS``.

    Raises:
        RuntimeError: the GEOM base could not be downloaded or read.
        ValueError: the GEOM base has no ``inchikey`` column.
    """
    import pandas as pd  # lazy: keeps module importable without pandas installed

    for name in _LOAD_STATUS:
        _LOAD_STATUS[name] = False

    def _empty(columns: list[str]):
        # Zero-row frame with the expected schema; left-joining it yields NaN columns.
        return pd.DataFrame(columns=columns)

    def _read_dataset(name: str):
        # Read one dataset from its LOCAL_* CSV override or from the HF Hub.
        # Required datasets raise on failure; optional ones log a warning and
        # return an empty frame so the join still produces (NaN) columns.
        cfg = _DATASETS[name]
        label = name.upper()  # e.g. "GEOM", used in the fatal error messages
        local = os.environ.get(cfg["local_env"])
        try:
            if local:
                if not os.path.exists(local):
                    raise FileNotFoundError(local)
                df = pd.read_csv(local)
            else:
                from huggingface_hub import hf_hub_download

                repo = os.environ.get(cfg["repo_env"], cfg["repo"])
                fname = os.environ.get(cfg["file_env"], cfg["file"])
                token = os.environ.get("HF_TOKEN")
                path = hf_hub_download(
                    repo_id=repo,
                    filename=fname,
                    repo_type="dataset",
                    token=token,
                )
                df = pd.read_csv(path)
        except Exception as exc:  # deliberate best-effort: import/network/auth/parse errors
            _LOAD_STATUS[name] = False
            if cfg["required"]:
                raise RuntimeError(f"failed to load required {label} dataset: {exc}") from exc
            _LOG.warning(
                "failed to load optional %s dataset; continuing with NaN columns: %s",
                name,
                exc,
            )
            return _empty(cfg["columns"])
        if "inchikey" not in df.columns:
            _LOAD_STATUS[name] = False
            if cfg["required"]:
                raise ValueError(f"required {label} dataset is missing 'inchikey'")
            _LOG.warning(
                "optional %s dataset is missing 'inchikey'; continuing with NaN columns",
                name,
            )
            return _empty(cfg["columns"])
        _LOAD_STATUS[name] = True
        return df

    geom = _read_dataset("geom")
    if "guest_name" not in geom.columns and "name" in geom.columns:
        geom = geom.rename(columns={"name": "guest_name"})
    if "guest_name" in geom.columns:
        # Per-row fallback: a blank name cell would otherwise make the guest
        # disappear from guest_choices(), which drops missing labels.
        geom["guest_name"] = geom["guest_name"].fillna(geom["inchikey"])
    else:
        geom["guest_name"] = geom["inchikey"]
    for col in ("smiles", "logka"):
        if col not in geom.columns:
            geom[col] = pd.NA
    geom = geom.drop_duplicates(subset="inchikey", keep="first").reset_index(drop=True)

    # Novelty annotation (computed offline by engineering/annotate_geom_novelty.py
    # and shipped in guests.csv). Missing on older dataset snapshots -> NaN columns
    # so the board degrades to a blank Novelty cell instead of crashing.
    novelty = ["scaffold_family", "known_scaffold", "tmax_known", "is_novel"]
    for col in novelty:
        if col not in geom.columns:
            geom[col] = pd.NA
    merged = geom[["inchikey", "guest_name", "smiles", "logka", *novelty]].copy()

    for name in ("ligand", "pose", "cavity"):
        cfg = _DATASETS[name]
        # pandas' merge() matches null keys against each other, so a keyless
        # feature row would be attached to a keyless GEOM row; drop them first.
        df = _read_dataset(name).dropna(subset=["inchikey"])
        if name == "pose" and "boltzmann_weight" in df.columns:
            # Highest weight first; mergesort is stable, so ties keep file order
            # and the surviving pose is deterministic.
            df = df.sort_values("boltzmann_weight", ascending=False, kind="mergesort")
        df = df.drop_duplicates(subset="inchikey", keep="first").reset_index(drop=True)
        # Pad absent columns so every group contributes its full schema.
        for col in cfg["columns"]:
            if col not in df.columns:
                df[col] = pd.NA
        merged = merged.merge(df[cfg["columns"]], on="inchikey", how="left")
    return merged.set_index("inchikey", drop=False)
def load_status() -> dict[str, bool]:
    """Report, for each of the four source datasets, whether it loaded.

    If ``load_features()`` has not populated its cache yet, it is called first
    so the flags reflect a real load attempt; a GEOM failure raised by that
    call propagates to the caller.

    Returns:
        A fresh ``{dataset_name: loaded}`` dict copied from ``_LOAD_STATUS``.
    """
    if not load_features.cache_info().currsize:
        load_features()
    return {name: loaded for name, loaded in _LOAD_STATUS.items()}
def guest_choices() -> list[str]:
    """Return the sorted, de-duplicated dropdown labels.

    Labels come from the ``guest_name`` column, or from ``inchikey`` when that
    column is absent. Missing values are skipped and every label is a string.
    """
    table = load_features()
    if "guest_name" in table.columns:
        label_column = "guest_name"
    else:
        label_column = "inchikey"
    labels = table[label_column].dropna().astype(str).unique()
    return sorted(labels.tolist())
def get_record(guest_name: str) -> dict:
    """Return one guest's feature row as a plain dict.

    ``guest_name`` is compared, as a string, against the ``guest_name`` column
    (or ``inchikey`` when that column is absent). If nothing matches, it is
    then compared against ``inchikey``, so either identifier works. The first
    matching row wins.

    Raises:
        KeyError: no row matches in any searched column.
    """
    table = load_features()
    wanted = str(guest_name)
    primary = "guest_name" if "guest_name" in table.columns else "inchikey"
    lookup_order = [primary]
    if "inchikey" in table.columns:
        lookup_order.append("inchikey")
    for column in lookup_order:
        rows = table.loc[table[column].astype(str) == wanted]
        if not rows.empty:
            return rows.iloc[0].to_dict()
    raise KeyError(f"no feature row for {guest_name!r}")
# CB[7] is the only host in this benchmark, so its identifiers are hard-coded.
# PubChem CID 6096207 resolves the canonical 2D/3D depiction; the connectivity
# SMILES below is the RDKit fallback.
_HOST = {
    "inchikey": "ZDOBFUIMGBWEAB-UHFFFAOYSA-N",
    "smiles": (
        "C1N2C3C4N(C2=O)CN5C6C7N(C5=O)CN8C9C2N(C8=O)CN5C8C%10N(C5=O)CN5C%11C%12"
        "N(C5=O)CN5C%13C%14N(C5=O)CN5C%15C(N1C5=O)N1CN3C(=O)N4CN6C(=O)N7CN9C(=O)"
        "N2CN8C(=O)N%10CN%11C(=O)N%12CN%13C(=O)N%14CN%15C1=O"
    ),
    "guest_name": "Cucurbit[7]uril (CB[7])",
}


def host_record() -> dict:
    """Return a shallow copy of the fixed CB[7] host identifiers.

    Keys are ``inchikey``, ``smiles`` and ``guest_name`` (the display label,
    using the same key as guest rows). All values are strings, so callers may
    mutate the returned dict without affecting later calls.
    """
    return {**_HOST}