| from __future__ import annotations |
|
|
| import math |
| from typing import Any, Collection, Dict, List, Tuple, Optional |
|
|
| import pandas as pd |
|
|
| |
# Raw drug names (as they appear in comet metric names or reference inputs)
# mapped to the canonical spelling used by the reference tables in this module.
# Names with no entry here are returned unchanged by ``normalize_drug_name``.
_DRUG_NAME_MAP = {
    "1-hydroxymidazolam": "1-hydroxy-midazolam",        # add hyphen
    "4-hydroxytolbutamide": "4-hydroxy-tolbutamide",    # add hyphen
    "5-hydroxyomeprazole": "5-hydroxy-omeprazole",      # add hyphen
    "caffeine (137X)": "caffeine",                      # drop parenthesised suffix
    "dextromethorphan": "dextromethorphan",
    "dextrorphan": "dextrorphan",
    "digoxin": "digoxin",
    "hydroxy repaglinide": "hydroxy-repaglinide",       # space -> hyphen
    "memantine": "memantine",
    "midazolam": "midazolam",
    "omeprazole": "omeprazole",
    "omeprazole sulfone": "omeprazole sulfone",
    "paracetamol": "paracetamol",
    "paracetamol glucuronide": "paracetamol glucuronide",
    "paraxanthine (17X)": "paraxanthine",               # drop parenthesised suffix
    "repaglinide": "repaglinide",
    "rosuvastatin": "rosuvastatin",
    "tolbutamide": "tolbutamide",
    "Indometacin": "indometacin",                       # lower-case
    "Theophylline": "theophylline",                     # lower-case
}
|
|
|
|
# Per-drug reference values for several models. Every list is aligned by
# position with "drug". The metric is not named here; the NLME entries equal
# the "log-rmse" column of ``reference_data_nme``, so presumably these are
# log-RMSE values -- TODO confirm.
reference_results = {
    "drug": [
        # entries 0-9
        "caffeine", "dextromethorphan", "digoxin", "memantine", "midazolam",
        "omeprazole", "paracetamol", "repaglinide", "rosuvastatin", "tolbutamide",
        # entries 10-17 (no NLME value available)
        "1-hydroxy-midazolam", "4-hydroxy-tolbutamide", "5-hydroxy-omeprazole",
        "dextrorphan", "hydroxy-repaglinide", "omeprazole sulfone",
        "paracetamol glucuronide", "paraxanthine",
    ],
    "NLME": [
        0.356, 0.796, 0.315, 0.411, 0.674, 1.470, 0.319, 0.632, 0.470, 0.766,
    ] + [math.nan] * 8,
    "NODE-PK": [
        0.914, 0.668, 1.403, 0.549, 0.456, 1.940, 1.094, 0.879, 0.471, 0.683,
        0.741, 0.871, 2.014, 0.723, 0.340, 1.992, 0.509, 1.648,
    ],
    "T-PK": [
        0.575, 0.630, 0.717, 0.799, 0.735, 1.864, 0.825, 0.846, 0.748, 0.816,
        0.678, 0.898, 1.683, 1.001, 0.532, 1.620, 0.423, 0.646,
    ],
    "SNODE-PK": [
        0.780, 1.702, 0.501, 0.580, 0.874, 1.267, 1.115, 1.514, 0.624, 0.949,
        1.395, 0.524, 1.811, 0.904, 0.059, 1.529, 0.823, 0.653,
    ],
    "ST-PK": [
        0.984, 1.412, 0.421, 0.869, 0.817, 1.078, 1.050, 1.246, 0.604, 0.998,
        1.216, 0.742, 1.600, 0.860, 0.336, 1.294, 1.057, 0.858,
    ],
    "AICME-RNN": [
        0.646, 0.640, 0.569, 0.534, 0.548, 1.395, 0.691, 0.562, 0.578, 0.854,
        0.935, 0.274, 1.575, 0.614, 0.095, 1.438, 0.365, 0.409,
    ],
    "AICMET": [
        0.477, 0.437, 0.457, 0.362, 0.366, 1.139, 0.406, 0.583, 0.396, 0.691,
        0.729, 0.265, 1.615, 0.374, 0.113, 1.366, 0.295, 0.266,
    ],
}
|
|
# NLME reference metrics per drug. Drug names are raw (e.g. "caffeine (137X)")
# and are normalized with ``normalize_drug_name`` when loaded through
# ``reference_dict_to_pandas``. Missing values are ``math.nan``, the same
# convention ``reference_results`` uses, so ``dropna``-based comparisons skip them.
reference_data_nme = {
    "drug": [
        "caffeine (137X)",
        "dextromethorphan",
        "digoxin",
        "memantine",
        "midazolam",
        "omeprazole",
        "paracetamol",
        "repaglinide",
        "rosuvastatin",
        "tolbutamide",
        "indometacin",
        "theophylline",
    ],
    "log-rmse": [0.356, 0.796, 0.315, 0.411, 0.674, 1.47, 0.319, 0.632, 0.470, 0.766, 0.604, 0.754],
    # The last two entries were previously 100.0, which is not a valid R^2
    # (R^2 <= 1); they are treated as "not available" rather than as scores.
    "log-r2": [
        0.820, 0.556, 0.482, 0.740, 0.344, -0.75, 0.905, 0.561, 0.557, 0.506,
        math.nan, math.nan,
    ],
}
|
|
# Tabular view of ``reference_results``: a "drug" column plus one column per
# model key, one row per drug (from_dict's default orient="columns").
reference_df = pd.DataFrame.from_dict(reference_results)
|
|
|
|
def normalize_drug_name(raw: str) -> str:
    """
    Map a raw drug name (e.g. from comet logs) to its canonical reference name.

    The lookup in ``_DRUG_NAME_MAP`` is exact and case-sensitive; a name with
    no entry is returned unchanged.
    """
    try:
        return _DRUG_NAME_MAP[raw]
    except KeyError:
        return raw
|
|
|
|
def _extract_drug_from_metric_name(
    metric_name_full: str,
    metric_name: str,
    top_level: str | None = None,
) -> str | None:
    """
    Pull the (normalized) drug name out of a slash-separated comet metricName.

    The last path segment must equal ``metric_name``. After discarding one
    trailing ``epoch_*`` segment, the segment just before the metric is taken
    as the drug. Examples:

        "Empirical/Synthetic/paracetamol glucuronide/r2" -> "paracetamol glucuronide"
        "Synthetic/Synthetic/substance_16/rmse"          -> "substance_16"
        "Empirical/epoch_399/r2"                          -> None

    Parameters
    ----------
    metric_name_full : str
        Full metricName as logged.
    metric_name : str
        Required last segment, e.g. "r2" or "rmse".
    top_level : str | None
        If given, the first segment must equal this value, e.g. "Empirical".

    Returns
    -------
    str | None
        The candidate passed through ``normalize_drug_name``. Returns None when:
        the name does not match; fewer than two segments remain before the
        metric; or the candidate is empty or one of "empirical", "synthetic",
        "train", "val", "test" (compared case-insensitively).
    """
    segments = metric_name_full.split("/")  # always at least one element
    wrong_metric = segments[-1] != metric_name
    wrong_group = top_level is not None and segments[0] != top_level
    if wrong_metric or wrong_group:
        return None

    path = segments[:-1]
    # Old-style names carry the epoch right before the metric segment.
    if path and path[-1].startswith("epoch_"):
        path.pop()

    # Need at least "<group>/<drug>" left over.
    if len(path) < 2:
        return None

    candidate = path[-1]
    reserved = ("empirical", "synthetic", "train", "val", "test")
    if not candidate or candidate.lower() in reserved:
        return None

    return normalize_drug_name(candidate)
|
|
|
|
def _coerce_int_or_none(value: Any) -> int | None:
    """Return ``int(value)``, or None if *value* is None or not int-convertible."""
    if value is None:
        return None
    try:
        return int(value)
    except (TypeError, ValueError):
        return None


def metrics_list_to_pandas(
    metrics_list: List[Dict[str, Any]],
    model_name: str,
    metric_name: str,
    epoch: int | str,
    top_level: str | None = None,
) -> pd.DataFrame:
    """
    Convert comet_ml metrics to a per-drug DataFrame for a given metric and epoch.

    metrics_list entries look like:
        {
            "metricName": "Empirical/Synthetic/paracetamol glucuronide/r2",
            "metricValue": "-0.09778215289115906",
            "timestamp": 1764093835814,
            "step": 2,
            "epoch": 0,
            ...
        }

    Parameters
    ----------
    metrics_list : list of dict
        Raw comet_ml metric records. A record is skipped if any of these holds:
        its metricName does not match ``metric_name`` / ``top_level``; its
        "epoch" is not int-convertible; its "metricValue" is not
        float-convertible.
    model_name : str
        Name of the value column in the result.
    metric_name : str
        Last path segment of metricName to select, e.g. "r2" or "rmse".
    epoch : int | str
        Which epoch to select:
        - an int;
        - an integer string such as "399" (the form returned by
          ``available_epochs_and_metrics``);
        - "last" for the highest epoch among the *matching* records.
        Any other string yields an empty DataFrame.
    top_level : str | None
        Optional filter on the first path segment in metricName, e.g.
        "Empirical" or "Synthetic".

    Returns
    -------
    pd.DataFrame
        Columns ["drug", model_name], one row per drug, sorted by drug.
        If a drug was logged several times at the target epoch, the record with
        the largest "timestamp" wins; on equal timestamps the first record seen
        is kept. A missing or invalid timestamp counts as 0.
    """
    columns = ["drug", model_name]

    # (drug, epoch, value, timestamp) for every usable record of the requested metric.
    records: list[tuple[str, int, float, int]] = []
    for m in metrics_list:
        drug = _extract_drug_from_metric_name(
            metric_name_full=m.get("metricName") or "",
            metric_name=metric_name,
            top_level=top_level,
        )
        if not drug:
            continue

        e_val = _coerce_int_or_none(m.get("epoch"))
        if e_val is None:
            continue

        try:
            value = float(m.get("metricValue"))
        except (TypeError, ValueError):
            continue

        # A None/malformed timestamp used to raise here; rank it lowest instead.
        ts = _coerce_int_or_none(m.get("timestamp"))
        records.append((drug, e_val, value, 0 if ts is None else ts))

    target_epoch: int | None
    if isinstance(epoch, str):
        if epoch == "last":
            # Latest epoch for *this* metric/top-level, not for the whole log,
            # so another group logging a later epoch cannot yield an empty frame.
            target_epoch = max((r[1] for r in records), default=None)
        else:
            # Integer strings (e.g. "399") are accepted; anything else matches nothing.
            target_epoch = _coerce_int_or_none(epoch)
    else:
        target_epoch = int(epoch)

    if target_epoch is None:
        return pd.DataFrame(columns=columns)

    latest: dict[str, tuple[float, int]] = {}
    for drug, e_val, value, ts in records:
        if e_val != target_epoch:
            continue
        current = latest.get(drug)
        # Strict ">" keeps the first record seen when timestamps tie.
        if current is None or ts > current[1]:
            latest[drug] = (value, ts)

    if not latest:
        return pd.DataFrame(columns=columns)

    data = [{"drug": d, model_name: v} for d, (v, _ts) in latest.items()]
    return pd.DataFrame(data).sort_values("drug").reset_index(drop=True)
|
|
|
|
def empirical_batches_to_pandas(
    all_empirical_batches: Dict[str, List["AICMECompartmentsDataBatch"]],
    model: Any,
    model_name: str,
    metric_name: str,
    repo_filter: Optional[Collection[str]] = None,
) -> pd.DataFrame:
    """
    Build a per-drug DataFrame from metrics the model computes on empirical batches.

    The output shape matches ``metrics_list_to_pandas``: columns ["drug", model_name].

    Parameters
    ----------
    all_empirical_batches : Dict[str, List[AICMECompartmentsDataBatch]]
        Mapping repo_id -> list of batches.
    model : Any
        Object exposing ``_compute_metrics_from_batch_list(batch_list, repo_id)``.
        That method must return a 2-tuple; only the first element is used, as a
        mapping raw_drug -> metric dict (or None).
    model_name : str
        Name of the value column in the result.
    metric_name : str
        Key to read from each per-drug metric dict ("rmse", "log_rmse", "r2", ...).
    repo_filter : Optional[Collection[str]]
        If given, only these repo_ids are processed.

    Returns
    -------
    pd.DataFrame
        Columns ["drug", model_name], sorted by drug. A drug is skipped if any
        of these holds: its metric dict is None; the dict lacks ``metric_name``;
        the value is None or not float-convertible. If the same normalized drug
        name appears in several repos, the value from the repo iterated last wins.
    """
    value_by_drug: Dict[str, float] = {}

    for repo_id, batch_list in all_empirical_batches.items():
        wanted = repo_filter is None or repo_id in repo_filter
        if not wanted:
            continue

        # The second element of the returned pair is not needed here.
        per_drug_metrics, _ = model._compute_metrics_from_batch_list(batch_list, repo_id)

        for raw_drug, metric_dict in per_drug_metrics.items():
            if metric_dict is None or metric_name not in metric_dict:
                continue
            raw_value = metric_dict[metric_name]
            if raw_value is None:
                continue
            try:
                numeric = float(raw_value)
            except (TypeError, ValueError):
                continue
            value_by_drug[normalize_drug_name(raw_drug)] = numeric

    if not value_by_drug:
        return pd.DataFrame(columns=["drug", model_name])

    table = pd.DataFrame(
        [{"drug": drug, model_name: val} for drug, val in value_by_drug.items()]
    )
    return table.sort_values("drug").reset_index(drop=True)
|
|
|
|
def reference_dict_to_pandas(
    reference_data: Dict[str, list],
    model_name: str,
    metric_name: str,
) -> pd.DataFrame:
    """
    Turn a column-oriented reference dict into a two-column per-drug DataFrame.

    ``reference_data`` must provide a "drug" list plus a list under
    ``metric_name`` (e.g. "log-rmse", "log-r2"), aligned by position. Drug names
    are passed through ``normalize_drug_name``.

    Parameters
    ----------
    reference_data : dict
        Column-oriented reference table.
    model_name : str
        Name for the value column (like "NodePK" or "GP").
    metric_name : str
        Key of the metric column to extract.

    Returns
    -------
    pd.DataFrame
        Columns ["drug", model_name], sorted by drug.

    Raises
    ------
    ValueError
        If ``metric_name`` is not a key of ``reference_data``.
    """
    if metric_name not in reference_data:
        available = list(reference_data.keys())
        raise ValueError(
            f"Metric '{metric_name}' not in reference_data keys {available}"
        )

    frame = pd.DataFrame(
        {
            "drug": list(map(normalize_drug_name, reference_data["drug"])),
            model_name: reference_data[metric_name],
        }
    )
    return frame.sort_values("drug").reset_index(drop=True)
|
|
|
|
def available_epochs_and_metrics(metrics_list: List[Dict[str, Any]]) -> Dict[str, list[str]]:
    """
    Summarize which epochs, metrics and top-level prefixes a comet_ml metrics list contains.

    Two record styles are handled:
      - New-style: the epoch is in the 'epoch' field and metricName looks like
        "Empirical/Synthetic/paracetamol glucuronide/r2".
      - Old-style: no 'epoch' field (or None) and the epoch is encoded in
        metricName, e.g. "Empirical/epoch_399/r2".
    A present but non-int-convertible 'epoch' field contributes no epoch.

    Returns
    -------
    Dict[str, list[str]]
        {
            "epochs_available": unique epoch labels as strings. "last" comes
                first, then numeric labels in numeric order ("2" before "10"),
                then any other labels alphabetically,
            "metrics_available": sorted unique metric names (last path segment),
            "top_levels_available": sorted unique first path segments,
        }
    """
    epochs: set[str] = set()
    metrics: set[str] = set()
    top_levels: set[str] = set()

    for m in metrics_list:
        name = m.get("metricName") or ""
        if not name:
            continue

        parts = name.split("/")  # never empty for a non-empty string
        top_levels.add(parts[0])
        metrics.add(parts[-1])

        e_field = m.get("epoch")
        if e_field is not None:
            try:
                epochs.add(str(int(e_field)))
            except (TypeError, ValueError):
                pass  # unusable epoch field: contributes no epoch
        elif len(parts) >= 2 and parts[-2].startswith("epoch_"):
            # Old-style name: strip only the leading "epoch_" prefix.
            epochs.add(parts[-2][len("epoch_"):])

    def _epoch_sort_key(label: str) -> tuple[int, int, str]:
        # "last" first, then numeric epochs numerically, then other labels.
        # The label itself is the final tie-breaker so ordering is deterministic.
        if label == "last":
            return (0, 0, label)
        try:
            return (1, int(label), label)
        except ValueError:
            return (2, 0, label)

    return {
        "epochs_available": sorted(epochs, key=_epoch_sort_key),
        "metrics_available": sorted(metrics),
        "top_levels_available": sorted(top_levels),
    }
|
|
|
|
def count_model_wins(
    df: pd.DataFrame,
    model_a: str,
    model_b: str,
    *,
    smaller_is_better: bool = True,
) -> Tuple[int, int, int]:
    """
    Compare two model columns row by row and count wins and ties.

    Parameters
    ----------
    df : pd.DataFrame
        Must contain the numeric columns ``model_a`` and ``model_b``.
    model_a : str
        Name of the first model column in df.
    model_b : str
        Name of the second model column in df (may equal ``model_a``).
    smaller_is_better : bool, default=True
        If True, smaller values are considered better (e.g. RMSE).
        If False, larger values are considered better (e.g. R^2).

    Returns
    -------
    wins_a : int
        Number of rows where model_a outperforms model_b.
    wins_b : int
        Number of rows where model_b outperforms model_a.
    ties : int
        Number of rows where the two values are exactly equal.

    Rows where either value is NaN are ignored. Counts are plain Python ints.

    Raises
    ------
    KeyError
        If either column is missing from ``df``.
    """
    a = df[model_a]
    b = df[model_b]

    # Only compare rows where both models have a value. Selecting Series
    # individually (rather than df[[a, b]]) also works when model_a == model_b.
    both_present = a.notna() & b.notna()
    a, b = a[both_present], b[both_present]

    if smaller_is_better:
        a_better, b_better = a < b, b < a
    else:
        a_better, b_better = a > b, b > a

    # .sum() yields NumPy integers; convert to match the annotated return type.
    return int(a_better.sum()), int(b_better.sum()), int((a == b).sum())
|
|