# Petimot — app/utils/data_loader.py
# Commit c8deba6 (verified): Remove torch (~800MB) from build — use pickle
# for .pt files — build time: 10min → 2min
"""Data loading utilities for pre-computed PETIMOT predictions."""
import os, json, glob, zipfile, io, pickle
import numpy as np
import pandas as pd
from pathlib import Path
from functools import lru_cache
import logging
logger = logging.getLogger(__name__)
# ── Root path (importable by pages) ──
# Three dirname() hops up from this file: utils/ -> app/ -> project root.
PETIMOT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# ── Cache the zip namelist for fast lookups ──
_zip_namelist_cache: dict[str, list[str]] = {}


def _get_zip_namelist(zip_path: str) -> list[str]:
    """Return (and memoize) the member names of the zip at *zip_path*.

    On any read failure the path is memoized as an empty list, so a broken
    or missing archive is not re-opened on every subsequent call.
    """
    cached = _zip_namelist_cache.get(zip_path)
    if cached is None:
        try:
            with zipfile.ZipFile(zip_path, 'r') as archive:
                cached = archive.namelist()
        except Exception as e:
            logger.warning(f"Failed to read zip {zip_path}: {e}")
            cached = []
        _zip_namelist_cache[zip_path] = cached
    return cached
def get_predictions_zip(root: str) -> str | None:
"""Find valid predictions.zip (not an LFS pointer) in the root directory.
If not found locally or is an LFS pointer, try to download from HF.
"""
zip_path = os.path.join(root, "predictions.zip")
# Check if it exists and is a real file (not LFS pointer ~134 bytes)
if os.path.exists(zip_path) and os.path.getsize(zip_path) > 10000:
return zip_path
# Try auto-downloading from HuggingFace
try:
from app.utils.download import ensure_predictions_zip
result = ensure_predictions_zip(root)
if result and os.path.exists(result) and os.path.getsize(result) > 10000:
logger.info(f"Auto-downloaded predictions.zip: {os.path.getsize(result)} bytes")
return result
except Exception as e:
logger.warning(f"Auto-download failed: {e}")
return None
# Per-process memo: gt_dir path -> True once extraction was done (or attempted).
_gt_extracted_flag: dict = {}


def ensure_ground_truth(root: str) -> str:
    """Extract ground_truth.zip to root/ground_truth/ on first call (idempotent).

    Returns the path to the ground_truth directory (which may be missing or
    empty if neither a local zip nor the HF dataset download succeeded).
    Works both locally and on HuggingFace Space.
    """
    gt_dir = os.path.join(root, "ground_truth")
    if gt_dir in _gt_extracted_flag:
        return gt_dir
    # Already extracted?  (>100 entries is the "looks complete" heuristic)
    if os.path.isdir(gt_dir) and len(os.listdir(gt_dir)) > 100:
        _gt_extracted_flag[gt_dir] = True
        return gt_dir
    # Try extracting from ground_truth.zip
    zip_path = os.path.join(root, "ground_truth.zip")
    # If not local (or only an LFS pointer stub <=10 kB), try downloading
    # from the HF Dataset (for HF Space deployment).
    if not (os.path.exists(zip_path) and os.path.getsize(zip_path) > 10000):
        try:
            from huggingface_hub import hf_hub_download
            logger.info("Downloading ground_truth.zip from HF Dataset Valmbd/petimot-ground-truth ...")
            zip_path = hf_hub_download(
                repo_id="Valmbd/petimot-ground-truth",
                filename="ground_truth.zip",
                repo_type="dataset",
                local_dir=root,
            )
            logger.info(f"Downloaded ground_truth.zip: {os.path.getsize(zip_path)//1e6:.0f} MB")
        except Exception as e:
            logger.warning(f"Could not download ground_truth from dataset: {e}")
            # Mark as tried so every later call doesn't re-attempt the download.
            _gt_extracted_flag[gt_dir] = True
            return gt_dir
    logger.info(f"Extracting ground_truth.zip ({os.path.getsize(zip_path)//1e6:.0f} MB)...")
    os.makedirs(gt_dir, exist_ok=True)
    try:
        with zipfile.ZipFile(zip_path, 'r') as zf:
            # Extract everything, strip top-level 'ground_truth/' prefix
            for member in zf.infolist():
                name = member.filename
                # Strip leading 'ground_truth/' if present
                stripped = name[len('ground_truth/'):] if name.startswith('ground_truth/') else name
                if not stripped or stripped.endswith('/'):
                    continue  # skip directory entries and the bare prefix itself
                dest = os.path.join(gt_dir, stripped)
                os.makedirs(os.path.dirname(dest), exist_ok=True)
                # Stream each member to disk (whole-file read per member).
                with zf.open(member) as src, open(dest, 'wb') as dst:
                    dst.write(src.read())
        logger.info(f"Ground truth extracted: {len(os.listdir(gt_dir))} files")
        _gt_extracted_flag[gt_dir] = True
    except Exception as e:
        # Flag intentionally NOT set here, so a later call can retry extraction.
        logger.warning(f"Failed to extract ground_truth.zip: {e}")
    else:
        _gt_extracted_flag[gt_dir] = True  # mark as tried
    return gt_dir
def find_predictions_dir(root: str) -> str | None:
    """Resolve where predictions live.

    Prefers *root* itself when a valid predictions.zip is available there;
    otherwise returns the most recently modified subdirectory of
    root/predictions (i.e. the latest model), or None if nothing exists.
    """
    if get_predictions_zip(root):
        return root
    base = os.path.join(root, "predictions")
    if not os.path.isdir(base):
        return None
    candidates = []
    for entry in os.listdir(base):
        full = os.path.join(base, entry)
        if os.path.isdir(full):
            candidates.append(full)
    if not candidates:
        return None
    # Newest mtime wins — treated as "most recent model".
    return max(candidates, key=os.path.getmtime)
@lru_cache(maxsize=1)
def load_prediction_index(pred_dir: str) -> pd.DataFrame:
    """Build index of all predicted proteins with metadata.

    Columns: name, seq_len, n_modes, mean_disp_m0, max_disp_m0,
    top_residue (1-based), disp_profile (downsampled magnitude list).

    Sources are tried in order:
      1. a pre-built index.json inside predictions.zip,
      2. scanning the zip for *_mode_0.txt files,
      3. loose *_mode_0.txt files on disk in pred_dir.
    Cached with maxsize=1 since the app works against a single predictions
    directory at a time.
    """
    rows = []
    # ── Try reading from predictions.zip ──
    zip_path = get_predictions_zip(pred_dir)
    if zip_path:
        try:
            with zipfile.ZipFile(zip_path, 'r') as zf:
                # Look for index.json inside the zip (any path ending in it)
                idx_file = next((f for f in zf.namelist() if f.endswith("index.json")), None)
                index_dict = {}
                if idx_file:
                    with zf.open(idx_file) as f:
                        index_dict = json.load(f)
                if index_dict:
                    # Load external disp_profiles if available (app/data/disp_profiles.json,
                    # two dirname() hops up from utils/)
                    _prof_path = os.path.join(
                        os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
                        "data", "disp_profiles.json"
                    )
                    _profiles: dict = {}
                    try:
                        if os.path.exists(_prof_path):
                            with open(_prof_path) as _pf:
                                _profiles = json.load(_pf)
                    except Exception:
                        pass  # best-effort: external profiles are optional enrichment
                    # Use pre-built index
                    for k, v in index_dict.items():
                        rows.append({
                            "name": k,
                            "seq_len": v.get("seq_len", 0),
                            "n_modes": v.get("n_modes", 0),
                            "mean_disp_m0": v.get("mean_disp", 0.0),
                            "max_disp_m0": v.get("max_disp", 0.0),
                            "top_residue": v.get("top_residue", -1),
                            # the external profile file wins over the in-index one
                            "disp_profile": _profiles.get(k, v.get("disp_profile", [])),
                        })
                else:
                    # No index.json or empty β€” scan zip for _mode_0.txt files
                    logger.info("index.json missing or empty β€” scanning zip for mode files...")
                    mode0_files = [f for f in zf.namelist() if f.endswith("_mode_0.txt")]
                    for mf in mode0_files:
                        base = os.path.basename(mf).replace("_mode_0.txt", "")
                        try:
                            with zf.open(mf) as f:
                                vecs = np.loadtxt(f)
                            # per-residue displacement magnitude of mode 0
                            mag = np.linalg.norm(vecs, axis=1)
                            rows.append({
                                "name": base,
                                "seq_len": len(vecs),
                                "n_modes": 4,  # assume default
                                "mean_disp_m0": float(mag.mean()),
                                "max_disp_m0": float(mag.max()),
                                "top_residue": int(np.argmax(mag)) + 1,
                                # downsample to ~20 points for compact sparkline
                                "disp_profile": mag[::max(1, len(mag)//20)].tolist(),
                            })
                        except Exception:
                            continue  # skip unreadable entries, keep scanning
        except Exception as e:
            logger.warning(f"Failed to load predictions from zip: {e}")
    if rows:
        return pd.DataFrame(rows).sort_values("name").reset_index(drop=True)
    # ── Fallback to loose files on disk ──
    if os.path.isdir(pred_dir):
        mode_files = glob.glob(os.path.join(pred_dir, "*_mode_0.txt"))
        for mf in mode_files:
            base = os.path.basename(mf).replace("_mode_0.txt", "")
            try:
                vecs = np.loadtxt(mf)
                n_res = len(vecs)
                mag = np.linalg.norm(vecs, axis=1)
                # Count how many mode files exist among indices 0..9.
                n_modes = sum(1 for k in range(10)
                              if os.path.exists(os.path.join(pred_dir, f"{base}_mode_{k}.txt")))
                rows.append({
                    "name": base,
                    "seq_len": n_res,
                    "n_modes": n_modes,
                    "mean_disp_m0": float(mag.mean()),
                    "max_disp_m0": float(mag.max()),
                    "top_residue": int(np.argmax(mag)) + 1,
                    "disp_profile": mag[::max(1, len(mag)//20)].tolist(),
                })
            except Exception:
                continue  # skip unreadable files
    if not rows:
        # Empty frame with the full schema so callers can rely on columns.
        return pd.DataFrame(columns=["name", "seq_len", "n_modes", "mean_disp_m0", "max_disp_m0", "top_residue", "disp_profile"])
    return pd.DataFrame(rows).sort_values("name").reset_index(drop=True)
def load_modes(pred_dir: str, name: str) -> dict[int, np.ndarray]:
    """Load all mode files for a protein.

    Checks mode indices 0..9, trying both the 'extracted_<name>' and bare
    '<name>' filename prefixes. The scan stops at the first missing index
    above 0; a missing mode 0 alone does not stop it. Reads from
    predictions.zip when available, otherwise from loose files in pred_dir.
    Returns {mode_index: array loaded via np.loadtxt}.
    """
    modes = {}
    # ── Try from zip ──
    zip_path = get_predictions_zip(pred_dir)
    if zip_path:
        namelist = _get_zip_namelist(zip_path)
        try:
            with zipfile.ZipFile(zip_path, 'r') as zf:
                for k in range(10):
                    found = False
                    for pfx in [f"extracted_{name}", name]:
                        suffix = f"{pfx}_mode_{k}.txt"
                        # Match a nested entry (".../<suffix>") or a top-level
                        # entry exactly equal to <suffix>.
                        matched = next((f for f in namelist if f.endswith(f"/{suffix}") or f == suffix), None)
                        if matched:
                            with zf.open(matched) as f:
                                modes[k] = np.loadtxt(f)
                            found = True
                            break
                    if not found and k > 0:
                        break  # No more modes
        except Exception as e:
            logger.warning(f"Failed to load modes from zip for {name}: {e}")
        if modes:
            return modes
    # ── Fallback for loose files ──
    for k in range(10):
        found = False
        for pfx in [f"extracted_{name}", name]:
            mf = os.path.join(pred_dir, f"{pfx}_mode_{k}.txt")
            if os.path.exists(mf):
                modes[k] = np.loadtxt(mf)
                found = True
                break
        if not found and k > 0:
            break  # same early-stop rule as the zip path
    return modes
def load_embeddings(pred_dir: str, name: str) -> np.ndarray | None:
    """Load node embeddings for *name* if they exist.

    Looks inside predictions.zip first (trying both the 'extracted_<name>'
    and bare '<name>' filename prefixes), then falls back to a loose
    '<pfx>_embeddings.npy' file in pred_dir. Returns None when no
    embeddings are found or readable.
    """
    zip_path = get_predictions_zip(pred_dir)
    if zip_path:
        namelist = _get_zip_namelist(zip_path)
        try:
            with zipfile.ZipFile(zip_path, 'r') as zf:
                for pfx in [f"extracted_{name}", name]:
                    suffix = f"{pfx}_embeddings.npy"
                    # Match a nested entry (".../<suffix>") or a top-level
                    # entry exactly equal to <suffix>.
                    matched = next((f for f in namelist if f.endswith(f"/{suffix}") or f == suffix), None)
                    if matched:
                        # np.load needs a seekable buffer; zf.open() streams
                        # are not, so read fully into memory first. Uses the
                        # module-level `io` import (the previous per-call
                        # `from io import BytesIO` was redundant).
                        with zf.open(matched) as f:
                            return np.load(io.BytesIO(f.read()))
        except Exception as e:
            logger.warning(f"Failed to load embeddings from zip for {name}: {e}")
    # Fallback to loose files
    for pfx in [f"extracted_{name}", name]:
        emb_path = os.path.join(pred_dir, f"{pfx}_embeddings.npy")
        if os.path.exists(emb_path):
            try:
                return np.load(emb_path)
            except Exception:
                pass  # unreadable file — keep trying the other prefix
    return None
def load_ground_truth(gt_dir: str, name: str) -> dict | None:
    """Load ground truth data for a protein.

    Automatically extracts ground_truth.zip if the directory doesn't exist
    yet (idempotent). Searches gt_dir and one level of subdirectories for
    '<name>.pt'. Returns the unpickled dict — with torch tensors converted
    to numpy when torch is importable — or None when nothing is found or
    the first matching file fails to load.
    """
    # Auto-extract ground_truth.zip if needed (idempotent)
    root = os.path.dirname(gt_dir)
    resolved_gt_dir = ensure_ground_truth(root)
    if not os.path.isdir(resolved_gt_dir):
        return None
    # Resolve torch ONCE per call instead of attempting an import for every
    # dict entry (the old per-key `import torch` paid exception overhead on
    # each value when torch is absent).
    try:
        import torch as _torch
    except Exception:
        _torch = None
    # Search in directory and one level of subdirectories
    for search_dir in [resolved_gt_dir] + [
        os.path.join(resolved_gt_dir, d)
        for d in os.listdir(resolved_gt_dir)
        if os.path.isdir(os.path.join(resolved_gt_dir, d))
    ]:
        path = os.path.join(search_dir, f"{name}.pt")
        if not os.path.exists(path):
            continue
        try:
            # Load .pt file without torch β€” use pickle directly.
            # NOTE(security): pickle.load executes arbitrary code on load;
            # these files come from our own dataset zip, so they are
            # treated as trusted here.
            with open(path, "rb") as f:
                data = pickle.load(f)
            result = {}
            for k, v in data.items():
                if _torch is not None and isinstance(v, _torch.Tensor):
                    try:
                        result[k] = v.numpy()
                    except Exception:
                        result[k] = v  # tensor not representable as numpy — keep as-is
                else:
                    result[k] = v
            return result
        except Exception as e:
            logger.warning(f"Failed to load {path}: {e}")
            return None
    return None
def load_pdb_text(pdb_path: str) -> str | None:
"""Load PDB file as text."""
if not os.path.exists(pdb_path):
return None
with open(pdb_path) as f:
return f.read()