# AICME-runtime / sim_priors_pk / utils / experiments_potsprocessing.py
# cesarali's picture
# manual runtime bundle push from load_and_push.ipynb
# 5686f5b verified
from __future__ import annotations
import math
from typing import Any, Collection, Dict, List, Tuple, Optional
import pandas as pd # if you have it in this module already
# Standardization dictionary: raw drug label -> canonical drug name.
# `normalize_drug_name` looks labels up here by exact string match and
# returns any label without an entry unchanged, so the identity entries
# below simply document the labels that are already canonical.
_DRUG_NAME_MAP = {
    # Hydroxy-metabolite labels written without the hyphen.
    "1-hydroxymidazolam": "1-hydroxy-midazolam",
    "4-hydroxytolbutamide": "4-hydroxy-tolbutamide",
    "5-hydroxyomeprazole": "5-hydroxy-omeprazole",
    "caffeine (137X)": "caffeine",  # parenthesised suffix is dropped
    "dextromethorphan": "dextromethorphan",
    "dextrorphan": "dextrorphan",
    "digoxin": "digoxin",
    "hydroxy repaglinide": "hydroxy-repaglinide",  # space -> hyphen
    "memantine": "memantine",
    "midazolam": "midazolam",
    "omeprazole": "omeprazole",
    "omeprazole sulfone": "omeprazole sulfone",
    "paracetamol": "paracetamol",
    "paracetamol glucuronide": "paracetamol glucuronide",
    "paraxanthine (17X)": "paraxanthine",  # parenthesised suffix is dropped
    "repaglinide": "repaglinide",
    "rosuvastatin": "rosuvastatin",
    "tolbutamide": "tolbutamide",
    # Capitalised labels are mapped to their lower-case form.
    "Indometacin": "indometacin",
    "Theophylline": "theophylline",
}
# Hard-coded per-drug reference scores, stored column-wise.
# "drug" holds 18 canonical drug names; every other key is a model name whose
# list has 18 values aligned by position with "drug". Each list is laid out as
# two rows: the first ten entries, then the last eight.
# NOTE(review): the metric is not stated here. The first ten NLME values equal
# the "log-rmse" list in `reference_data_nme`, so these are presumably
# log-RMSE scores (smaller is better) - confirm.
reference_results = {
    "drug": [
        "caffeine",
        "dextromethorphan",
        "digoxin",
        "memantine",
        "midazolam",
        "omeprazole",
        "paracetamol",
        "repaglinide",
        "rosuvastatin",
        "tolbutamide",
        "1-hydroxy-midazolam",
        "4-hydroxy-tolbutamide",
        "5-hydroxy-omeprazole",
        "dextrorphan",
        "hydroxy-repaglinide",
        "omeprazole sulfone",
        "paracetamol glucuronide",
        "paraxanthine",
    ],
    # NLME has values for the first ten drugs only; the last eight are NaN.
    "NLME": [
        0.356, 0.796, 0.315, 0.411, 0.674, 1.470, 0.319, 0.632, 0.470, 0.766,
        math.nan, math.nan, math.nan, math.nan, math.nan, math.nan, math.nan, math.nan,
    ],
    "NODE-PK": [
        0.914, 0.668, 1.403, 0.549, 0.456, 1.940, 1.094, 0.879, 0.471, 0.683,
        0.741, 0.871, 2.014, 0.723, 0.340, 1.992, 0.509, 1.648,
    ],
    "T-PK": [
        0.575, 0.630, 0.717, 0.799, 0.735, 1.864, 0.825, 0.846, 0.748, 0.816,
        0.678, 0.898, 1.683, 1.001, 0.532, 1.620, 0.423, 0.646,
    ],
    "SNODE-PK": [
        0.780, 1.702, 0.501, 0.580, 0.874, 1.267, 1.115, 1.514, 0.624, 0.949,
        1.395, 0.524, 1.811, 0.904, 0.059, 1.529, 0.823, 0.653,
    ],
    "ST-PK": [
        0.984, 1.412, 0.421, 0.869, 0.817, 1.078, 1.050, 1.246, 0.604, 0.998,
        1.216, 0.742, 1.600, 0.860, 0.336, 1.294, 1.057, 0.858,
    ],
    "AICME-RNN": [
        0.646, 0.640, 0.569, 0.534, 0.548, 1.395, 0.691, 0.562, 0.578, 0.854,
        0.935, 0.274, 1.575, 0.614, 0.095, 1.438, 0.365, 0.409,
    ],
    "AICMET": [
        0.477, 0.437, 0.457, 0.362, 0.366, 1.139, 0.406, 0.583, 0.396, 0.691,
        0.729, 0.265, 1.615, 0.374, 0.113, 1.366, 0.295, 0.266,
    ],
}
# Hard-coded reference metrics for twelve drugs, stored column-wise:
# "drug" plus one list per metric ("log-rmse", "log-r2"), aligned by position.
# Drug names are raw labels here (e.g. "caffeine (137X)");
# `reference_dict_to_pandas` normalizes the "drug" column of a dict with
# this layout.
# NOTE(review): the last two "log-r2" entries (indometacin, theophylline) are
# 100.0, which looks like a placeholder rather than a real R^2 value - confirm
# before using them in a larger-is-better comparison.
reference_data_nme = {
    "drug": [
        "caffeine (137X)",
        "dextromethorphan",
        "digoxin",
        "memantine",
        "midazolam",
        "omeprazole",
        "paracetamol",
        "repaglinide",
        "rosuvastatin",
        "tolbutamide",
        "indometacin",
        "theophylline",
    ],
    "log-rmse": [0.356, 0.796, 0.315, 0.411, 0.674, 1.47, 0.319, 0.632, 0.470, 0.766, 0.604, 0.754],
    "log-r2": [0.820, 0.556, 0.482, 0.740, 0.344, -0.75, 0.905, 0.561, 0.557, 0.506, 100.0, 100.0],
}
# Tabular view of `reference_results`: one row per drug, one column per key.
reference_df = pd.DataFrame(data=reference_results)
def normalize_drug_name(raw: str) -> str:
    """
    Translate a raw drug label (e.g. from comet logs) into the canonical name
    used by the reference tables.

    The lookup in ``_DRUG_NAME_MAP`` is an exact string match; a label that
    has no entry is returned unchanged.
    """
    try:
        return _DRUG_NAME_MAP[raw]
    except KeyError:
        # Unknown label: pass it through as-is.
        return raw
def _extract_drug_from_metric_name(
    metric_name_full: str,
    metric_name: str,
    top_level: str | None = None,
) -> str | None:
    """
    Pull the (normalized) drug name out of a slash-separated metricName.

    Examples of accepted names:
        "Empirical/Synthetic/paracetamol glucuronide/r2"
        "Synthetic/Synthetic/substance_16/rmse"
    Names without a drug segment, such as "Empirical/epoch_399/r2", yield None.

    Parameters
    ----------
    metric_name_full : str
        Full metricName path.
    metric_name : str
        Metric of interest; must equal the last path segment.
    top_level : str | None
        If given, the first path segment must equal this value
        (e.g. "Empirical").

    Returns
    -------
    str | None
        The normalized drug name, or None when the name does not match.
    """
    # str.split always yields at least one item, so `leaf` is always bound.
    *scope, leaf = metric_name_full.split("/")

    if leaf != metric_name:
        return None

    # With no leading segments the name cannot contain a drug, whatever the
    # top-level filter says.
    if top_level is not None and (not scope or scope[0] != top_level):
        return None

    # Old-style names may carry a trailing "epoch_<n>" segment before the metric.
    if scope and scope[-1].startswith("epoch_"):
        scope.pop()

    # At least [prefix, drug] must remain.
    if len(scope) < 2:
        return None

    candidate = scope[-1]
    non_drug_segments = {"empirical", "synthetic", "train", "val", "test"}
    if not candidate or candidate.lower() in non_drug_segments:
        return None

    return normalize_drug_name(candidate)
def metrics_list_to_pandas(
    metrics_list: List[Dict[str, Any]],
    model_name: str,
    metric_name: str,
    epoch: int | str,
    top_level: str | None = None,
) -> pd.DataFrame:
    """
    Convert comet_ml metrics to a per-drug DataFrame for a given metric and epoch.

    metrics_list entries look like:
        {
            "metricName": "Empirical/Synthetic/paracetamol glucuronide/r2",
            "metricValue": "-0.09778215289115906",
            "timestamp": 1764093835814,
            "step": 2,
            "epoch": 0,
            ...
        }

    Parameters
    ----------
    metrics_list : List[Dict[str, Any]]
        Raw metric entries.
    model_name : str
        Name of the value column in the returned DataFrame.
    metric_name : str
        Metric to extract; must equal the last segment of metricName.
    epoch : int | str
        Epoch to select. The string "last" selects the highest "epoch" value
        found in any entry of ``metrics_list`` (not only in entries matching
        ``metric_name``). Any other string yields an empty result.
    top_level : str | None
        Optional filter on the first path segment in metricName,
        e.g. "Empirical" or "Synthetic".

    Returns
    -------
    pd.DataFrame
        Columns ["drug", model_name], sorted by drug. When a drug has several
        entries for the epoch, the one with the largest timestamp is kept
        (the first one seen wins on equal timestamps). Empty, with the same
        columns, when nothing matches.
    """

    def _no_rows() -> pd.DataFrame:
        return pd.DataFrame(columns=["drug", model_name])

    # -----------------------
    # 1) Resolve target epoch
    # -----------------------
    if isinstance(epoch, str):
        if epoch != "last":
            # Unknown epoch label -> nothing to do
            return _no_rows()
        seen_epochs: List[int] = []
        for m in metrics_list:
            try:
                seen_epochs.append(int(m.get("epoch")))
            except (TypeError, ValueError):
                # Missing or unparsable epoch field
                continue
        if not seen_epochs:
            return _no_rows()
        target_epoch = max(seen_epochs)
    else:
        target_epoch = int(epoch)

    # ------------------------------------------
    # 2) Single pass: filter and keep the latest
    # ------------------------------------------
    latest: dict[str, tuple[float, int]] = {}
    for m in metrics_list:
        drug = _extract_drug_from_metric_name(
            metric_name_full=m.get("metricName") or "",
            metric_name=metric_name,
            top_level=top_level,
        )
        if not drug:
            continue

        # Filter by epoch field (new comet format)
        try:
            if int(m.get("epoch")) != target_epoch:
                continue
        except (TypeError, ValueError):
            continue

        # Metric value
        try:
            value = float(m.get("metricValue"))
        except (TypeError, ValueError):
            continue

        # A null or unparsable timestamp must not abort the whole conversion;
        # treat it like a missing one (0), i.e. as the oldest possible entry.
        try:
            ts = int(m.get("timestamp", 0))
        except (TypeError, ValueError, OverflowError):
            ts = 0

        previous = latest.get(drug)
        if previous is None or ts > previous[1]:
            latest[drug] = (value, ts)

    if not latest:
        return _no_rows()

    data = [{"drug": d, model_name: value_ts[0]} for d, value_ts in latest.items()]
    return pd.DataFrame(data).sort_values("drug").reset_index(drop=True)
def empirical_batches_to_pandas(
    all_empirical_batches: Dict[str, List["AICMECompartmentsDataBatch"]],
    model: Any,
    model_name: str,
    metric_name: str,
    repo_filter: Optional[Collection[str]] = None,
) -> pd.DataFrame:
    """
    Build a ["drug", model_name] DataFrame from metrics the model computes on
    each repo's batch list (same output layout as metrics_list_to_pandas).

    Parameters
    ----------
    all_empirical_batches : Dict[str, List[AICMECompartmentsDataBatch]]
        Mapping repo_id -> list of batches.
    model : Any
        Object exposing `_compute_metrics_from_batch_list(batch_list, repo_id)`,
        which must return a 2-tuple whose first item maps
        raw drug name -> {metric name -> value}.
    model_name : str
        Name of the metric column in the returned DataFrame.
    metric_name : str
        Which metric to extract ("rmse", "log_rmse", "r2", "log_r2", ...).
    repo_filter : Optional[Collection[str]]
        If given, only these repo_ids are processed.

    Returns
    -------
    pd.DataFrame
        Columns: ["drug", model_name], sorted by drug. Empty (same columns)
        when no usable value was found.
    """
    # If a drug shows up more than once (e.g. in several repos), the value
    # seen last overwrites the earlier ones.
    score_by_drug: Dict[str, float] = {}

    for repo_id, batch_list in all_empirical_batches.items():
        if not (repo_filter is None or repo_id in repo_filter):
            continue

        # The second element of the returned pair is not used here.
        per_drug, _unused = model._compute_metrics_from_batch_list(batch_list, repo_id)

        for raw_drug, metric_dict in per_drug.items():
            if metric_dict is None or metric_name not in metric_dict:
                continue
            raw_value = metric_dict[metric_name]
            if raw_value is None:
                continue
            try:
                score = float(raw_value)
            except (TypeError, ValueError):
                # Not convertible to a number: skip this entry.
                continue
            score_by_drug[normalize_drug_name(raw_drug)] = score

    if not score_by_drug:
        return pd.DataFrame(columns=["drug", model_name])

    records = [{"drug": name, model_name: score} for name, score in score_by_drug.items()]
    table = pd.DataFrame(records)
    return table.sort_values("drug").reset_index(drop=True)
def reference_dict_to_pandas(
reference_data: Dict[str, list],
model_name: str,
metric_name: str,
) -> pd.DataFrame:
"""
Convert a reference dictionary with drug-level metrics into a pandas DataFrame.
The dictionary must have at least the keys:
- "drug": list[str]
- <metric_name>: list[float]
Applies normalization of drug names to ensure consistency.
Parameters
----------
reference_data : dict
Dictionary with keys "drug" and metric names (e.g., "log-rmse", "log-r2").
model_name : str
Name for the output value column (like "NodePK" or "GP").
metric_name : str
Which metric to extract (must be in the dict).
Returns
-------
pd.DataFrame
Two-column DataFrame with:
- "drug": standardized drug names
- model_name: metric values
Sorted by drug name.
"""
if metric_name not in reference_data:
raise ValueError(
f"Metric '{metric_name}' not in reference_data keys {list(reference_data.keys())}"
)
drugs = [normalize_drug_name(d) for d in reference_data["drug"]]
values = reference_data[metric_name]
df = pd.DataFrame({"drug": drugs, model_name: values})
return df.sort_values("drug").reset_index(drop=True)
def available_epochs_and_metrics(metrics_list: List[Dict[str, Any]]) -> Dict[str, list[str]]:
    """
    Summarize which epochs, metrics and top-level prefixes are available in
    a comet_ml metrics list.

    This handles both:
        - New-style: epoch is in the 'epoch' field and metricName is something like
          "Empirical/Synthetic/paracetamol glucuronide/r2"
        - Old-style: epoch encoded in metricName, e.g.
          "Empirical/epoch_399/r2"

    The name-encoded epoch is only consulted when the entry has no 'epoch'
    field (or it is None); an 'epoch' field that cannot be converted to int
    contributes no epoch. Entries with an empty or missing metricName are
    skipped.

    Returns
    -------
    Dict[str, list[str]]
        {
            "epochs_available": unique epoch identifiers as strings; "last"
                (if present) first, then numeric epochs in ascending numeric
                order, then any other labels alphabetically,
            "metrics_available": sorted unique metric names (last path segment),
            "top_levels_available": sorted unique top-level prefixes (first path segment)
        }
    """
    epoch_prefix = "epoch_"

    def _epoch_order(label: str) -> tuple:
        # "last" first; numeric labels by value (so "2" sorts before "10");
        # anything else afterwards, alphabetically.
        if label == "last":
            return (0, 0, label)
        try:
            return (1, int(label), label)
        except ValueError:
            return (2, 0, label)

    epochs: set[str] = set()
    metrics: set[str] = set()
    top_levels: set[str] = set()

    for m in metrics_list:
        name = m.get("metricName") or ""
        if not name:
            continue
        # A non-empty string always splits into at least one segment.
        parts = name.split("/")

        # top-level, e.g. "Empirical" or "Synthetic"
        top_levels.add(parts[0])
        # metric name is always the last segment, e.g. "rmse", "r2"
        metrics.add(parts[-1])

        e_field = m.get("epoch")
        if e_field is not None:
            # --- New-style: epoch field present ---
            try:
                epochs.add(str(int(e_field)))
            except (TypeError, ValueError):
                pass
        elif len(parts) >= 2 and parts[-2].startswith(epoch_prefix):
            # --- Fallback: old-style epoch encoded in the parent segment ---
            # Strip only the leading prefix, not every occurrence of it.
            epochs.add(parts[-2][len(epoch_prefix):])

    return {
        "epochs_available": sorted(epochs, key=_epoch_order),
        "metrics_available": sorted(metrics),
        "top_levels_available": sorted(top_levels),
    }
def count_model_wins(
    df: pd.DataFrame,
    model_a: str,
    model_b: str,
    *,
    smaller_is_better: bool = True,
) -> Tuple[int, int, int]:
    """
    Compare two model columns row by row and count wins and ties.

    Rows where either column is missing (NaN/None) are ignored.

    Parameters
    ----------
    df : pd.DataFrame
        Must contain the two columns `model_a` and `model_b` with numeric values.
    model_a : str
        Name of the first model column in df.
    model_b : str
        Name of the second model column in df.
    smaller_is_better : bool, default=True
        If True, smaller values are considered better (e.g. RMSE).
        If False, larger values are considered better (e.g. R^2).

    Returns
    -------
    wins_a : int
        Number of rows where model_a outperforms model_b.
    wins_b : int
        Number of rows where model_b outperforms model_a.
    ties : int
        Number of rows where they are equal (after dropping NaNs).

    All three counts are plain Python ints. Passing the same column as both
    models yields (0, 0, <number of non-missing rows>).
    """
    # Work on the two columns as Series so the comparison also behaves when
    # model_a == model_b (selecting [a, a] would give duplicate columns).
    col_a = df[model_a]
    col_b = df[model_b]

    # Keep only rows where both models have a value.
    both_present = col_a.notna() & col_b.notna()
    col_a = col_a[both_present]
    col_b = col_b[both_present]

    a_smaller = int((col_a < col_b).sum())
    b_smaller = int((col_b < col_a).sum())
    # Series.sum() returns a NumPy integer; convert so the result matches the
    # annotated return type (and is e.g. JSON-serialisable).
    ties = int((col_a == col_b).sum())

    if smaller_is_better:
        return a_smaller, b_smaller, ties
    # Larger is better: the model with the smaller value loses the row.
    return b_smaller, a_smaller, ties