from __future__ import annotations

import math
from typing import Any, Collection, Dict, List, Optional, Tuple

import pandas as pd

# NOTE(review): "AICMECompartmentsDataBatch" is referenced below only inside a
# quoted annotation, so no import is needed at runtime. Presumably it lives in
# another project module -- confirm and add a TYPE_CHECKING import if desired.

# ---------------------------------------------------------------------------
# 1. Standardization dictionary: raw drug name (as logged) -> canonical name
# ---------------------------------------------------------------------------
_DRUG_NAME_MAP = {
    "1-hydroxymidazolam": "1-hydroxy-midazolam",
    "4-hydroxytolbutamide": "4-hydroxy-tolbutamide",
    "5-hydroxyomeprazole": "5-hydroxy-omeprazole",
    "caffeine (137X)": "caffeine",
    "dextromethorphan": "dextromethorphan",
    "dextrorphan": "dextrorphan",
    "digoxin": "digoxin",
    "hydroxy repaglinide": "hydroxy-repaglinide",
    "memantine": "memantine",
    "midazolam": "midazolam",
    "omeprazole": "omeprazole",
    "omeprazole sulfone": "omeprazole sulfone",
    "paracetamol": "paracetamol",
    "paracetamol glucuronide": "paracetamol glucuronide",
    "paraxanthine (17X)": "paraxanthine",
    "repaglinide": "repaglinide",
    "rosuvastatin": "rosuvastatin",
    "tolbutamide": "tolbutamide",
    "Indometacin": "indometacin",
    "Theophylline": "theophylline",
}

# Path segments that are structural prefixes and must never be read as a drug.
_NON_DRUG_SEGMENTS = frozenset({"empirical", "synthetic", "train", "val", "test"})

# Prefix that marks an old-style epoch segment inside a metric name.
_EPOCH_PREFIX = "epoch_"

# ---------------------------------------------------------------------------
# 2. Reference tables (one value per drug, lists aligned with "drug")
# ---------------------------------------------------------------------------
reference_results = {
    "drug": [
        "caffeine",
        "dextromethorphan",
        "digoxin",
        "memantine",
        "midazolam",
        "omeprazole",
        "paracetamol",
        "repaglinide",
        "rosuvastatin",
        "tolbutamide",
        "1-hydroxy-midazolam",
        "4-hydroxy-tolbutamide",
        "5-hydroxy-omeprazole",
        "dextrorphan",
        "hydroxy-repaglinide",
        "omeprazole sulfone",
        "paracetamol glucuronide",
        "paraxanthine",
    ],
    # No NLME value is given for the last eight entries (NaN placeholders).
    "NLME": [
        0.356, 0.796, 0.315, 0.411, 0.674, 1.470, 0.319, 0.632, 0.470, 0.766,
        math.nan, math.nan, math.nan, math.nan,
        math.nan, math.nan, math.nan, math.nan,
    ],
    "NODE-PK": [
        0.914, 0.668, 1.403, 0.549, 0.456, 1.940, 1.094, 0.879, 0.471, 0.683,
        0.741, 0.871, 2.014, 0.723, 0.340, 1.992, 0.509, 1.648,
    ],
    "T-PK": [
        0.575, 0.630, 0.717, 0.799, 0.735, 1.864, 0.825, 0.846, 0.748, 0.816,
        0.678, 0.898, 1.683, 1.001, 0.532, 1.620, 0.423, 0.646,
    ],
    "SNODE-PK": [
        0.780, 1.702, 0.501, 0.580, 0.874, 1.267, 1.115, 1.514, 0.624, 0.949,
        1.395, 0.524, 1.811, 0.904, 0.059, 1.529, 0.823, 0.653,
    ],
    "ST-PK": [
        0.984, 1.412, 0.421, 0.869, 0.817, 1.078, 1.050, 1.246, 0.604, 0.998,
        1.216, 0.742, 1.600, 0.860, 0.336, 1.294, 1.057, 0.858,
    ],
    "AICME-RNN": [
        0.646, 0.640, 0.569, 0.534, 0.548, 1.395, 0.691, 0.562, 0.578, 0.854,
        0.935, 0.274, 1.575, 0.614, 0.095, 1.438, 0.365, 0.409,
    ],
    "AICMET": [
        0.477, 0.437, 0.457, 0.362, 0.366, 1.139, 0.406, 0.583, 0.396, 0.691,
        0.729, 0.265, 1.615, 0.374, 0.113, 1.366, 0.295, 0.266,
    ],
}

reference_data_nme = {
    "drug": [
        "caffeine (137X)",
        "dextromethorphan",
        "digoxin",
        "memantine",
        "midazolam",
        "omeprazole",
        "paracetamol",
        "repaglinide",
        "rosuvastatin",
        "tolbutamide",
        "indometacin",
        "theophylline",
    ],
    "log-rmse": [
        0.356, 0.796, 0.315, 0.411, 0.674, 1.47,
        0.319, 0.632, 0.470, 0.766, 0.604, 0.754,
    ],
    # NOTE(review): the two trailing 100.0 entries (indometacin, theophylline)
    # look like placeholders rather than real R^2 values -- confirm before
    # using them in comparisons.
    "log-r2": [
        0.820, 0.556, 0.482, 0.740, 0.344, -0.75,
        0.905, 0.561, 0.557, 0.506, 100.0, 100.0,
    ],
}

reference_df = pd.DataFrame(reference_results)


# ---------------------------------------------------------------------------
# 3. Small private helpers
# ---------------------------------------------------------------------------
def _to_int(value: Any) -> Optional[int]:
    """Return ``int(value)``, or None when the value cannot be converted."""
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError):
        return None


def _per_drug_frame(values_by_drug: Dict[str, float], model_name: str) -> pd.DataFrame:
    """Build the ["drug", model_name] frame, sorted by drug, from a mapping."""
    if not values_by_drug:
        return pd.DataFrame(columns=["drug", model_name])
    records = [
        {"drug": drug, model_name: value} for drug, value in values_by_drug.items()
    ]
    return pd.DataFrame(records).sort_values("drug").reset_index(drop=True)


def _epoch_sort_key(label: str) -> Tuple[bool, bool, int, str]:
    """Sort key: "last" first, then numeric epochs ascending, then other labels."""
    if label == "last":
        return (False, False, 0, label)
    try:
        return (True, False, int(label), label)
    except ValueError:
        return (True, True, 0, label)


def _resolve_target_epoch(
    metrics_list: List[Dict[str, Any]],
    epoch: int | str,
) -> Optional[int]:
    """Turn the user-facing ``epoch`` argument into a concrete epoch number.

    Returns None when the epoch cannot be resolved (unknown label, or "last"
    requested while no entry carries a usable "epoch" field).
    """
    if not isinstance(epoch, str):
        return int(epoch)
    if epoch == "last":
        known = [
            e for e in (_to_int(m.get("epoch")) for m in metrics_list) if e is not None
        ]
        return max(known) if known else None
    # Numeric strings (e.g. "399", as reported by available_epochs_and_metrics)
    # select that epoch; any other label resolves to None.
    return _to_int(epoch)


# ---------------------------------------------------------------------------
# 4. Public API
# ---------------------------------------------------------------------------
def normalize_drug_name(raw: str) -> str:
    """
    Normalize drug names from comet logs to match reference table names.

    The lookup in the standardization dictionary is exact (case-sensitive).
    Falls back to the raw name if no mapping exists.
    """
    return _DRUG_NAME_MAP.get(raw, raw)


def _extract_drug_from_metric_name(
    metric_name_full: str,
    metric_name: str,
    top_level: str | None = None,
) -> str | None:
    """
    Extract the (normalized) drug name from a metricName path.

    Handles patterns like:
        "Empirical/Synthetic/paracetamol glucuronide/r2"
        "Synthetic/Synthetic/substance_16/rmse"
    (and ignores things like "Empirical/epoch_399/r2")

    Parameters
    ----------
    metric_name_full : str
        Full "/"-separated metric path.
    metric_name : str
        Metric of interest; must equal the last path segment.
    top_level : str | None
        If given, require the path to start with this first segment,
        e.g. "Empirical".

    Returns
    -------
    str | None
        Normalized drug name, or None when the path does not describe a
        per-drug value of the requested metric.
    """
    parts = metric_name_full.split("/")

    # The last segment must be the metric we are interested in.
    if parts[-1] != metric_name:
        return None

    # Optional filter on the very first segment: "Empirical", "Synthetic", etc.
    if top_level is not None and parts[0] != top_level:
        return None

    # Drop the metric name at the end.
    core = parts[:-1]

    # Old-style names might have a trailing "epoch_399" segment; drop it.
    if core and core[-1].startswith(_EPOCH_PREFIX):
        core = core[:-1]

    # We expect at least [prefix, drug].
    if len(core) < 2:
        return None

    raw_drug = core[-1]
    # Empty segments and structural prefixes are not drugs.
    if not raw_drug or raw_drug.lower() in _NON_DRUG_SEGMENTS:
        return None

    return normalize_drug_name(raw_drug)


def metrics_list_to_pandas(
    metrics_list: List[Dict[str, Any]],
    model_name: str,
    metric_name: str,
    epoch: int | str,
    top_level: str | None = None,
) -> pd.DataFrame:
    """
    Convert comet_ml metrics to a per-drug DataFrame for a given metric and epoch.

    metrics_list entries look like:
        {
            "metricName": "Empirical/Synthetic/paracetamol glucuronide/r2",
            "metricValue": "-0.09778215289115906",
            "timestamp": 1764093835814,
            "step": 2,
            "epoch": 0,
            ...
        }

    Parameters
    ----------
    metrics_list : list of dict
        Raw metric entries.
    model_name : str
        Name of the value column in the result.
    metric_name : str
        Metric to extract (last path segment of "metricName").
    epoch : int | str
        Epoch number, a numeric string such as "399", or "last" for the
        largest value found in the "epoch" fields. Any other string yields
        an empty frame.
    top_level : str | None
        Optional filter on the first path segment in metricName,
        e.g. "Empirical" or "Synthetic".

    Returns
    -------
    pd.DataFrame
        Columns ["drug", model_name], sorted by drug. When a drug has several
        matching entries, the one with the largest timestamp wins (the first
        one seen on ties). Empty when nothing matches.
    """
    target_epoch = _resolve_target_epoch(metrics_list, epoch)
    if target_epoch is None:
        return pd.DataFrame(columns=["drug", model_name])

    # drug -> (value, timestamp) of the most recent matching entry
    latest: Dict[str, Tuple[float, int]] = {}
    for entry in metrics_list:
        drug = _extract_drug_from_metric_name(
            metric_name_full=entry.get("metricName") or "",
            metric_name=metric_name,
            top_level=top_level,
        )
        if not drug:
            continue

        # Filter by epoch field (new comet format); entries without a usable
        # epoch never match because target_epoch is an int here.
        if _to_int(entry.get("epoch")) != target_epoch:
            continue

        try:
            value = float(entry.get("metricValue"))
        except (TypeError, ValueError):
            continue

        # A missing or malformed timestamp must not abort the conversion;
        # treat it as the oldest possible entry instead.
        timestamp = _to_int(entry.get("timestamp"))
        if timestamp is None:
            timestamp = 0

        current = latest.get(drug)
        if current is None or timestamp > current[1]:
            latest[drug] = (value, timestamp)

    return _per_drug_frame(
        {drug: value for drug, (value, _ts) in latest.items()}, model_name
    )


def empirical_batches_to_pandas(
    all_empirical_batches: Dict[str, List["AICMECompartmentsDataBatch"]],
    model: Any,
    model_name: str,
    metric_name: str,
    repo_filter: Optional[Collection[str]] = None,
) -> pd.DataFrame:
    """
    Aggregate per-drug metrics computed from all_empirical_batches into a
    DataFrame with columns ["drug", model_name], analogous to
    metrics_list_to_pandas.

    Parameters
    ----------
    all_empirical_batches : Dict[str, List[AICMECompartmentsDataBatch]]
        Mapping repo_id -> list of batches.
    model : Any
        Model instance exposing
        `_compute_metrics_from_batch_list(batch_list, repo_id)`, which must
        return a pair whose first item maps raw drug name -> metric dict
        (or None).
    model_name : str
        Name of the model; becomes the metric column name in the DataFrame.
    metric_name : str
        Which metric to extract ("rmse", "log_rmse", "r2", "log_r2", ...).
    repo_filter : Optional[Collection[str]]
        If given, only these repo_ids are processed.

    Returns
    -------
    pd.DataFrame
        Columns: ["drug", model_name], sorted by drug. If a drug appears
        multiple times (e.g. in several repos), the last one seen is kept.
    """
    latest_by_drug: Dict[str, float] = {}
    for repo_id, batch_list in all_empirical_batches.items():
        if repo_filter is not None and repo_id not in repo_filter:
            continue

        # metrics: dict[raw_drug -> dict[metric_name -> value, ...]]
        metrics, _prediction_cache = model._compute_metrics_from_batch_list(
            batch_list, repo_id
        )

        for raw_drug, metric_dict in metrics.items():
            if metric_dict is None or metric_name not in metric_dict:
                continue
            try:
                # float(None) raises TypeError, so None values are skipped too.
                value = float(metric_dict[metric_name])
            except (TypeError, ValueError):
                continue
            # Later occurrences overwrite earlier ones.
            latest_by_drug[normalize_drug_name(raw_drug)] = value

    return _per_drug_frame(latest_by_drug, model_name)


def reference_dict_to_pandas(
    reference_data: Dict[str, list],
    model_name: str,
    metric_name: str,
) -> pd.DataFrame:
    """
    Convert a reference dictionary with drug-level metrics into a DataFrame.

    The dictionary must have at least the keys:
        - "drug": list[str]
        - <metric_name>: list[float]

    Applies normalization of drug names to ensure consistency.

    Parameters
    ----------
    reference_data : dict
        Dictionary with keys "drug" and metric names (e.g., "log-rmse",
        "log-r2").
    model_name : str
        Name for the output value column (like "NodePK" or "GP").
    metric_name : str
        Which metric to extract (must be in the dict).

    Returns
    -------
    pd.DataFrame
        Two-column DataFrame with:
            - "drug": standardized drug names
            - model_name: metric values
        Sorted by drug name.

    Raises
    ------
    ValueError
        If `metric_name` is not a key of `reference_data`.
    """
    if metric_name not in reference_data:
        raise ValueError(
            f"Metric '{metric_name}' not in reference_data keys {list(reference_data.keys())}"
        )

    frame = pd.DataFrame(
        {
            "drug": [normalize_drug_name(d) for d in reference_data["drug"]],
            model_name: reference_data[metric_name],
        }
    )
    return frame.sort_values("drug").reset_index(drop=True)


def available_epochs_and_metrics(metrics_list: List[Dict[str, Any]]) -> Dict[str, list[str]]:
    """
    Summarize which epochs, metrics and top-level prefixes are available in a
    comet_ml metrics list.

    This handles both:
        - New-style: epoch is in the 'epoch' field and metricName is something
          like "Empirical/Synthetic/paracetamol glucuronide/r2"
        - Old-style: epoch encoded in metricName, e.g. "Empirical/epoch_399/r2"

    Returns
    -------
    Dict[str, list[str]]
        {
            "epochs_available": unique epoch identifiers (strings); "last"
                first if present, then numeric epochs in ascending numeric
                order, then any other labels,
            "metrics_available": sorted unique metric names (last path segment),
            "top_levels_available": sorted unique top-level prefixes
                (first path segment)
        }
    """
    epochs: set[str] = set()
    metrics: set[str] = set()
    top_levels: set[str] = set()

    for entry in metrics_list:
        name = entry.get("metricName") or ""
        if not name:
            continue

        parts = name.split("/")
        top_levels.add(parts[0])  # e.g. "Empirical" or "Synthetic"
        metrics.add(parts[-1])  # e.g. "rmse", "r2"

        epoch_field = entry.get("epoch")
        if epoch_field is not None:
            # New-style: epoch field present; unusable values are ignored.
            epoch_number = _to_int(epoch_field)
            if epoch_number is not None:
                epochs.add(str(epoch_number))
        elif len(parts) >= 2 and parts[-2].startswith(_EPOCH_PREFIX):
            # Old-style fallback: epoch encoded in the parent segment. Strip
            # only the leading prefix, not every occurrence of it.
            epochs.add(parts[-2][len(_EPOCH_PREFIX):])

    return {
        "epochs_available": sorted(epochs, key=_epoch_sort_key),
        "metrics_available": sorted(metrics),
        "top_levels_available": sorted(top_levels),
    }


def count_model_wins(
    df: pd.DataFrame,
    model_a: str,
    model_b: str,
    *,
    smaller_is_better: bool = True,
) -> Tuple[int, int, int]:
    """
    Compare two model columns row by row in a merged DataFrame and count wins.

    Parameters
    ----------
    df : pd.DataFrame
        Must contain the two columns `model_a` and `model_b` with numeric
        values.
    model_a : str
        Name of the first model column in df.
    model_b : str
        Name of the second model column in df.
    smaller_is_better : bool, default=True
        If True, smaller values are considered better (e.g. RMSE).
        If False, larger values are considered better (e.g. R^2).

    Returns
    -------
    wins_a : int
        Number of rows where model_a outperforms model_b.
    wins_b : int
        Number of rows where model_b outperforms model_a.
    ties : int
        Number of rows where they are exactly equal (after dropping NaNs).
    """
    # Rows where either model has no value cannot be compared.
    valid = df[[model_a, model_b]].dropna()
    scores_a = valid[model_a]
    scores_b = valid[model_b]

    if smaller_is_better:
        a_beats_b = scores_a < scores_b
        b_beats_a = scores_b < scores_a
    else:
        a_beats_b = scores_a > scores_b
        b_beats_a = scores_b > scores_a
    ties = scores_a == scores_b

    # Series.sum() yields numpy integers; convert so the result matches the
    # declared Tuple[int, int, int].
    return int(a_beats_b.sum()), int(b_beats_a.sum()), int(ties.sum())