File size: 14,048 Bytes
b47954d
c8deba6
b47954d
 
 
 
539c642
 
 
 
73bc3db
 
 
539c642
 
 
 
 
 
 
 
 
 
7bc226a
539c642
 
b47954d
 
c8ad6f1
f431fcf
 
 
 
c8ad6f1
f431fcf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c8ad6f1
 
7f62814
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ca70a4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7f62814
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b47954d
7bc226a
 
 
 
c8ad6f1
539c642
b47954d
 
 
 
 
 
 
 
 
 
 
 
 
 
539c642
 
c8ad6f1
 
b47954d
539c642
 
7bc226a
c939196
539c642
 
 
c939196
 
039fd7e
 
 
 
 
 
 
 
 
 
 
 
 
c939196
 
 
 
 
 
 
 
039fd7e
 
c939196
039fd7e
b47954d
c939196
 
7bc226a
 
539c642
 
 
 
 
 
 
 
7bc226a
539c642
 
 
53af87a
539c642
 
 
 
7bc226a
539c642
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53af87a
539c642
 
 
 
 
53af87a
b47954d
 
 
 
 
 
539c642
 
c8ad6f1
 
539c642
 
 
 
 
 
 
 
 
 
 
 
 
 
7bc226a
539c642
 
 
c8ad6f1
 
 
539c642
b47954d
539c642
b47954d
 
 
 
539c642
b47954d
539c642
 
b47954d
 
 
118eb7a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b47954d
7f62814
 
 
 
 
 
 
 
118eb7a
7f62814
 
 
 
 
 
539c642
 
 
c8deba6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
539c642
 
b47954d
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
"""Data loading utilities for pre-computed PETIMOT predictions."""
import os, json, glob, zipfile, io, pickle
import numpy as np
import pandas as pd
from pathlib import Path
from functools import lru_cache
import logging

logger = logging.getLogger(__name__)

# ── Root path (importable by pages) ──
# Three dirname() hops up from this file — presumably <app>/<pkg>/<module>.py
# resolving to the project root; confirm against the actual package layout.
PETIMOT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# ── Cache the zip namelist for fast lookups ──
# Maps zip path -> list of member names; populated lazily, never invalidated
# (assumes the zip does not change during the process lifetime).
_zip_namelist_cache: dict[str, list[str]] = {}

def _get_zip_namelist(zip_path: str) -> list[str]:
    """Return the member-name list of *zip_path*, memoised per path.

    Reading the namelist opens the archive once; subsequent calls for the
    same path are served from ``_zip_namelist_cache``. An unreadable zip is
    cached as an empty list (so the failure is not retried).
    """
    cached = _zip_namelist_cache.get(zip_path)
    if cached is not None:
        return cached
    try:
        with zipfile.ZipFile(zip_path, 'r') as archive:
            names = archive.namelist()
    except Exception as e:
        logger.warning(f"Failed to read zip {zip_path}: {e}")
        names = []
    _zip_namelist_cache[zip_path] = names
    return names


def get_predictions_zip(root: str) -> str | None:
    """Find valid predictions.zip (not an LFS pointer) in the root directory.
    
    If not found locally or is an LFS pointer, try to download from HF.
    """
    zip_path = os.path.join(root, "predictions.zip")
    
    # Check if it exists and is a real file (not LFS pointer ~134 bytes)
    if os.path.exists(zip_path) and os.path.getsize(zip_path) > 10000:
        return zip_path
    
    # Try auto-downloading from HuggingFace
    try:
        from app.utils.download import ensure_predictions_zip
        result = ensure_predictions_zip(root)
        if result and os.path.exists(result) and os.path.getsize(result) > 10000:
            logger.info(f"Auto-downloaded predictions.zip: {os.path.getsize(result)} bytes")
            return result
    except Exception as e:
        logger.warning(f"Auto-download failed: {e}")
    
    return None


# Per-directory "already handled" marker for ensure_ground_truth():
# maps gt_dir path -> True once extraction has been done (or attempted).
_gt_extracted_flag: dict = {}

def ensure_ground_truth(root: str) -> str:
    """Extract ground_truth.zip to root/ground_truth/ on first call (idempotent).

    Resolution order:
      1. if ``root/ground_truth/`` already holds >100 entries, use it;
      2. else use a valid local ``root/ground_truth.zip`` (>10 KB, i.e.
         not a Git-LFS pointer);
      3. else download the zip from the HF Dataset
         ``Valmbd/petimot-ground-truth``.

    Returns the path to the ground_truth directory (which may not exist
    if no zip could be obtained). Works both locally and on HF Space.
    """
    gt_dir = os.path.join(root, "ground_truth")
    if gt_dir in _gt_extracted_flag:
        return gt_dir

    # Already extracted?
    if os.path.isdir(gt_dir) and len(os.listdir(gt_dir)) > 100:
        _gt_extracted_flag[gt_dir] = True
        return gt_dir

    zip_path = os.path.join(root, "ground_truth.zip")

    # No valid local zip (LFS pointers are ~134 bytes) — try the HF Dataset.
    if not (os.path.exists(zip_path) and os.path.getsize(zip_path) > 10000):
        try:
            from huggingface_hub import hf_hub_download
            logger.info("Downloading ground_truth.zip from HF Dataset Valmbd/petimot-ground-truth ...")
            zip_path = hf_hub_download(
                repo_id="Valmbd/petimot-ground-truth",
                filename="ground_truth.zip",
                repo_type="dataset",
                local_dir=root,
            )
            logger.info(f"Downloaded ground_truth.zip: {os.path.getsize(zip_path)//1e6:.0f} MB")
        except Exception as e:
            logger.warning(f"Could not download ground_truth from dataset: {e}")
            _gt_extracted_flag[gt_dir] = True  # mark as tried; don't retry every call
            return gt_dir

    # BUG FIX: extraction previously ran only on the download path, so a
    # valid *local* ground_truth.zip was silently never extracted. Extract
    # unconditionally now that zip_path points at a usable archive.
    logger.info(f"Extracting ground_truth.zip ({os.path.getsize(zip_path)//1e6:.0f} MB)...")
    os.makedirs(gt_dir, exist_ok=True)
    try:
        with zipfile.ZipFile(zip_path, 'r') as zf:
            # Extract everything, strip top-level 'ground_truth/' prefix
            for member in zf.infolist():
                name = member.filename
                # Strip leading 'ground_truth/' if present
                stripped = name[len('ground_truth/'):] if name.startswith('ground_truth/') else name
                if not stripped or stripped.endswith('/'):
                    continue  # skip directory entries
                dest = os.path.join(gt_dir, stripped)
                os.makedirs(os.path.dirname(dest), exist_ok=True)
                with zf.open(member) as src, open(dest, 'wb') as dst:
                    dst.write(src.read())
        logger.info(f"Ground truth extracted: {len(os.listdir(gt_dir))} files")
        _gt_extracted_flag[gt_dir] = True
    except Exception as e:
        logger.warning(f"Failed to extract ground_truth.zip: {e}")

    return gt_dir


def find_predictions_dir(root: str) -> str | None:
    """Resolve where predictions live for *root*.

    Prefers a packaged ``predictions.zip`` — in that case *root* itself is
    returned. Otherwise the most recently modified subdirectory of
    ``root/predictions`` wins. Returns ``None`` when neither exists.
    """
    if get_predictions_zip(root) is not None:
        return root

    pred_root = os.path.join(root, "predictions")
    if not os.path.isdir(pred_root):
        return None

    candidates = []
    for entry in os.listdir(pred_root):
        full = os.path.join(pred_root, entry)
        if os.path.isdir(full):
            candidates.append(full)

    if not candidates:
        return None
    # "Most recent model" == newest mtime.
    return max(candidates, key=os.path.getmtime)


@lru_cache(maxsize=1)
def load_prediction_index(pred_dir: str) -> pd.DataFrame:
    """Build index of all predicted proteins with metadata."""
    rows = []
    
    # ── Try reading from predictions.zip ──
    zip_path = get_predictions_zip(pred_dir)
    if zip_path:
        try:
            with zipfile.ZipFile(zip_path, 'r') as zf:
                # Look for index.json inside the zip
                idx_file = next((f for f in zf.namelist() if f.endswith("index.json")), None)
                index_dict = {}
                if idx_file:
                    with zf.open(idx_file) as f:
                        index_dict = json.load(f)

                if index_dict:
                    # Load external disp_profiles if available
                    _prof_path = os.path.join(
                        os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
                        "data", "disp_profiles.json"
                    )
                    _profiles: dict = {}
                    try:
                        if os.path.exists(_prof_path):
                            with open(_prof_path) as _pf:
                                _profiles = json.load(_pf)
                    except Exception:
                        pass

                    # Use pre-built index
                    for k, v in index_dict.items():
                        rows.append({
                            "name": k,
                            "seq_len": v.get("seq_len", 0),
                            "n_modes": v.get("n_modes", 0),
                            "mean_disp_m0": v.get("mean_disp", 0.0),
                            "max_disp_m0": v.get("max_disp", 0.0),
                            "top_residue": v.get("top_residue", -1),
                            "disp_profile": _profiles.get(k, v.get("disp_profile", [])),
                        })

                else:
                    # No index.json or empty β€” scan zip for _mode_0.txt files
                    logger.info("index.json missing or empty β€” scanning zip for mode files...")
                    mode0_files = [f for f in zf.namelist() if f.endswith("_mode_0.txt")]
                    for mf in mode0_files:
                        base = os.path.basename(mf).replace("_mode_0.txt", "")
                        try:
                            with zf.open(mf) as f:
                                vecs = np.loadtxt(f)
                            mag = np.linalg.norm(vecs, axis=1)
                            rows.append({
                                "name": base,
                                "seq_len": len(vecs),
                                "n_modes": 4,  # assume default
                                "mean_disp_m0": float(mag.mean()),
                                "max_disp_m0": float(mag.max()),
                                "top_residue": int(np.argmax(mag)) + 1,
                                "disp_profile": mag[::max(1, len(mag)//20)].tolist(),
                            })
                        except Exception:
                            continue
        except Exception as e:
            logger.warning(f"Failed to load predictions from zip: {e}")

        if rows:
            return pd.DataFrame(rows).sort_values("name").reset_index(drop=True)

    # ── Fallback to loose files on disk ──
    if os.path.isdir(pred_dir):
        mode_files = glob.glob(os.path.join(pred_dir, "*_mode_0.txt"))
        for mf in mode_files:
            base = os.path.basename(mf).replace("_mode_0.txt", "")
            try:
                vecs = np.loadtxt(mf)
                n_res = len(vecs)
                mag = np.linalg.norm(vecs, axis=1)
                n_modes = sum(1 for k in range(10)
                              if os.path.exists(os.path.join(pred_dir, f"{base}_mode_{k}.txt")))
                rows.append({
                    "name": base,
                    "seq_len": n_res,
                    "n_modes": n_modes,
                    "mean_disp_m0": float(mag.mean()),
                    "max_disp_m0": float(mag.max()),
                    "top_residue": int(np.argmax(mag)) + 1,
                    "disp_profile": mag[::max(1, len(mag)//20)].tolist(),
                })
            except Exception:
                continue

    if not rows:
        return pd.DataFrame(columns=["name", "seq_len", "n_modes", "mean_disp_m0", "max_disp_m0", "top_residue", "disp_profile"])
    return pd.DataFrame(rows).sort_values("name").reset_index(drop=True)


def load_modes(pred_dir: str, name: str) -> dict[int, np.ndarray]:
    """Load every available mode-vector file for *name*.

    Modes are probed as ``{prefix}_mode_{k}.txt`` for k = 0..9, where the
    prefix is ``extracted_{name}`` or the bare name. Probing stops at the
    first missing mode with k > 0 (numbering is assumed contiguous).
    The predictions zip is preferred; loose files in *pred_dir* are the
    fallback.
    """
    modes: dict[int, np.ndarray] = {}
    prefixes = (f"extracted_{name}", name)

    # ── Try from zip ──
    zip_path = get_predictions_zip(pred_dir)
    if zip_path:
        namelist = _get_zip_namelist(zip_path)
        try:
            with zipfile.ZipFile(zip_path, 'r') as zf:
                for k in range(10):
                    hit = None
                    for pfx in prefixes:
                        suffix = f"{pfx}_mode_{k}.txt"
                        hit = next((n for n in namelist
                                    if n.endswith(f"/{suffix}") or n == suffix), None)
                        if hit:
                            with zf.open(hit) as f:
                                modes[k] = np.loadtxt(f)
                            break
                    if not hit and k > 0:
                        break  # No more modes
        except Exception as e:
            logger.warning(f"Failed to load modes from zip for {name}: {e}")

        if modes:
            return modes

    # ── Fallback for loose files ──
    for k in range(10):
        path = None
        for pfx in prefixes:
            candidate = os.path.join(pred_dir, f"{pfx}_mode_{k}.txt")
            if os.path.exists(candidate):
                path = candidate
                break
        if path is None:
            if k > 0:
                break  # contiguous numbering: nothing beyond this
            continue  # mode 0 absent under both prefixes; still probe k=1..9
        modes[k] = np.loadtxt(path)
    return modes


def load_embeddings(pred_dir: str, name: str) -> np.ndarray | None:
    """Return the node-embedding array for *name*, or ``None`` if absent.

    Looks for ``{prefix}_embeddings.npy`` (prefix ``extracted_{name}`` or
    the bare name) inside the predictions zip first, then among loose
    files in *pred_dir*.
    """
    prefixes = (f"extracted_{name}", name)

    zip_path = get_predictions_zip(pred_dir)
    if zip_path:
        namelist = _get_zip_namelist(zip_path)
        try:
            with zipfile.ZipFile(zip_path, 'r') as zf:
                for pfx in prefixes:
                    suffix = f"{pfx}_embeddings.npy"
                    entry = next((n for n in namelist
                                  if n.endswith(f"/{suffix}") or n == suffix), None)
                    if entry:
                        with zf.open(entry) as f:
                            # np.load needs a seekable buffer; zip streams aren't.
                            return np.load(io.BytesIO(f.read()))
        except Exception as e:
            logger.warning(f"Failed to load embeddings from zip for {name}: {e}")

    # Fallback to loose files
    for pfx in prefixes:
        candidate = os.path.join(pred_dir, f"{pfx}_embeddings.npy")
        if os.path.exists(candidate):
            try:
                return np.load(candidate)
            except Exception:
                pass
    return None


def load_ground_truth(gt_dir: str, name: str) -> dict | None:
    """Load ground truth data for a protein.

    Automatically extracts ground_truth.zip if the directory doesn't exist yet.

    Searches ``gt_dir`` and one level of subdirectories for ``{name}.pt``.
    Returns the deserialized dict with torch tensors converted to numpy
    arrays (when torch is importable), or ``None`` if the file is missing
    or unreadable.
    """
    # Auto-extract ground_truth.zip if needed (idempotent)
    root = os.path.dirname(gt_dir)
    resolved_gt_dir = ensure_ground_truth(root)
    if not os.path.isdir(resolved_gt_dir):
        return None

    # Hoisted out of the per-key loop: resolve torch.Tensor once instead of
    # re-attempting the import for every dict entry.
    try:
        import torch as _torch
        _tensor_cls = _torch.Tensor
    except Exception:
        _tensor_cls = None  # torch unavailable — values pass through as-is

    # Search in directory and one level of subdirectories
    search_dirs = [resolved_gt_dir] + [
        os.path.join(resolved_gt_dir, d)
        for d in os.listdir(resolved_gt_dir)
        if os.path.isdir(os.path.join(resolved_gt_dir, d))
    ]
    for search_dir in search_dirs:
        path = os.path.join(search_dir, f"{name}.pt")
        if not os.path.exists(path):
            continue
        try:
            # NOTE(review): .pt files written by torch >= 1.6 are zip
            # archives that plain pickle cannot read — this works only for
            # legacy pickle-protocol saves. Confirm against the files
            # actually shipped in ground_truth.zip.
            with open(path, "rb") as f:
                data = pickle.load(f)
            result = {}
            for k, v in data.items():
                if _tensor_cls is not None and isinstance(v, _tensor_cls):
                    try:
                        result[k] = v.numpy()
                    except Exception:
                        result[k] = v  # e.g. GPU/grad tensor — keep raw
                else:
                    result[k] = v
            return result
        except Exception as e:
            logger.warning(f"Failed to load {path}: {e}")
            return None
    return None


def load_pdb_text(pdb_path: str) -> str | None:
    """Load PDB file as text."""
    if not os.path.exists(pdb_path):
        return None
    with open(pdb_path) as f:
        return f.read()