| import matplotlib.pyplot as plt |
| import numpy as np |
| import torch |
| from torchtyping import TensorType |
|
|
| import torch |
| from torchtyping import TensorType |
| from typing import List,Tuple,Optional |
| import numpy as np |
|
|
| from sim_priors_pk.data.data_preprocessing.raw_to_tensors_bundles import substance_cvs_to_tensors_bundle,substances_csv_to_tensors |
|
|
| from typing import NamedTuple |
| import torch |
| from torchtyping import TensorType |
|
|
class SubstanceTensorGroup(NamedTuple):
    """Tensors of a single substance, kept as a batch of size one.

    The leading axis always has length 1 (one substance); ``I`` indexes the
    subjects of that substance and ``T`` indexes the time steps.
    """

    # Observed values per subject and time step.
    observations: TensorType[1, "I", "T"]
    # Sampling times aligned element-wise with ``observations``.
    times: TensorType[1, "I", "T"]
    # Per-time-step validity flags (truthy where a time step is real data).
    mask: TensorType[1, "I", "T"]
    # Per-subject validity flags (truthy where a subject slot is in use).
    subject_mask: TensorType[1, "I"]
|
|
def apply_timescale_filter(
    observations: TensorType["S", "I", "T"],
    times: TensorType["S", "I", "T"],
    masks: TensorType["S", "I", "T"],
    subject_mask: TensorType["S", "I"],
    *,
    strategy: str = "log_zscore",
    max_abs_z: float = 2.0,
    tau: float = 0.4,
) -> Tuple[
    TensorType["S", "I", "T"],
    TensorType["S", "I", "T"],
    TensorType["S", "I", "T"],
    TensorType["S", "I"],
]:
    """
    Zero out and un-mask subjects whose time span is an outlier
    w.r.t. the other subjects of the *same* substance.

    The time span of a subject is ``max(t) - min(t)`` over its valid time
    steps (``masks`` AND ``subject_mask``); outliers are detected on the
    natural log of that span. Per-substance statistics are computed only over
    subjects that have at least one valid time step, so padded subject slots
    never influence the result.

    Parameters
    ----------
    observations, times, masks : tensors of shape [S, I, T]
    subject_mask : tensor of shape [S, I]
    strategy : str
        * ``"log_zscore"``: keep subjects with ``|z| <= max_abs_z``, where
          ``z`` is the z-score of the log-span (sample standard deviation).
        * ``"median_fraction"``: keep subjects whose log-span lies within
          ``+/- tau`` of the per-substance median log-span.
        * ``"none"`` or any other value: return the inputs unchanged.
    max_abs_z : float
        Threshold used by ``"log_zscore"``.
    tau : float
        Half-width of the accepted band used by ``"median_fraction"``.

    Returns
    -------
    tuple
        ``(observations, times, masks, subject_mask)``. When filtering is
        applied these are clones of the inputs in which every dropped subject
        has its observations and times set to 0 and its mask entries set to
        False; the inputs themselves are never modified. Subjects without any
        valid time step are left untouched.
    """
    if strategy not in ("log_zscore", "median_fraction"):
        return observations, times, masks, subject_mask
    if times.numel() == 0:
        # Nothing to measure; reductions over an empty axis would raise.
        return observations, times, masks, subject_mask

    # A time step counts only if both the step and its subject are flagged.
    valid = masks.bool() & subject_mask.bool().unsqueeze(-1)
    has_data = valid.any(dim=2)

    # +/-inf fill values require a floating dtype.
    times_fp = times if times.is_floating_point() else times.to(torch.get_default_dtype())
    t_max = times_fp.masked_fill(~valid, float("-inf")).max(dim=2).values
    t_min = times_fp.masked_fill(~valid, float("inf")).min(dim=2).values
    # Subjects without data give -inf here; the clamp keeps them finite and
    # they are excluded from every statistic below.
    log_span = (t_max - t_min).clamp(min=1e-12).log()

    if strategy == "log_zscore":
        weights = has_data.to(log_span.dtype)
        count = weights.sum(dim=1, keepdim=True)
        mean = (log_span * weights).sum(dim=1, keepdim=True) / count.clamp(min=1.0)
        sq_dev = (log_span - mean).pow(2) * weights
        # Sample variance (n - 1); a lone subject gets variance 0 -> z == 0.
        var = sq_dev.sum(dim=1, keepdim=True) / (count - 1.0).clamp(min=1.0)
        std = var.sqrt().clamp(min=1e-6)
        keep = ((log_span - mean) / std).abs() <= max_abs_z
    else:  # "median_fraction"
        # NaN marks excluded subjects so nanmedian skips them.
        med = (
            log_span.masked_fill(~has_data, float("nan"))
            .nanmedian(dim=1, keepdim=True)
            .values
        )
        keep = (log_span >= med - tau) & (log_span <= med + tau)

    # Only subjects that actually contributed a span can be dropped.
    drop = has_data & ~keep

    # Work on copies so the caller's tensors stay intact.
    obs_f = observations.clone()
    times_f = times.clone()
    masks_f = masks.clone()
    subj_f = subject_mask.clone()

    subj_f[drop] = False
    masks_f[drop] = False
    obs_f[drop] = 0.0
    times_f[drop] = 0.0

    return obs_f, times_f, masks_f, subj_f
|
|
def plot_subjects_for_substance(
    drug_data_frame,
    substance_label: str,
    *,
    z_score_normalization: bool = False,
    normalize_by_max: bool = False,
    time_strategy: str = "log_zscore",
    max_abs_z: float = 2.0,
    x_scale: str = "linear",
    y_scale: str = "linear",
    alpha: float = 1.0,
    legend_outside: bool = True,
    figsize: Tuple[float, float] = (10, 5),
    save_dir: Optional[str] = None,
    tau: float = 0.4,
) -> None:
    """
    Draw every subject trajectory (points + line) for *one* substance.

    The data frame is converted with `substance_cvs_to_tensors_bundle`, the
    requested substance is selected, subjects with an outlying time span are
    removed with `apply_timescale_filter`, and the remaining subjects are
    plotted on a single axis. The figure is always displayed with
    ``plt.show()`` and is additionally written to disk when ``save_dir`` is
    given.

    Parameters
    ----------
    drug_data_frame : pandas.DataFrame
        Forwarded unchanged to `substance_cvs_to_tensors_bundle`.
    substance_label : str
        Label of the substance to plot; must be one of the bundle's
        ``substance_names``.
    z_score_normalization, normalize_by_max : bool, optional
        Forwarded to `substance_cvs_to_tensors_bundle`.
    time_strategy : str, optional
        ``strategy`` argument of `apply_timescale_filter`.
    max_abs_z, tau : float, optional
        Thresholds forwarded to `apply_timescale_filter`.
    x_scale, y_scale : {"linear", "log"}, optional
        Axis scaling. If you pick "log", make sure data are strictly > 0
        on that axis or Matplotlib will complain.
    alpha : float in [0, 1], optional
        Transparency applied to both the line and the markers.
    legend_outside : bool, optional
        True  -> legend in a separate column to the right;
        False -> legend inside the plot.
    figsize : tuple of float, optional
        Figure size handed to ``plt.subplots``.
    save_dir : str, optional
        When not None, the directory is created if needed and the figure is
        saved there as ``"{study_name}_{substance_label}_{index}.png"``, with
        ``study_name`` and ``index`` read from the bundle's ``mapping`` entry
        for this substance.

    Raises
    ------
    ValueError
        If ``substance_label`` is not among the bundle's substance names.
    """
    data_bundle = substance_cvs_to_tensors_bundle(
        drug_data_frame,
        z_score_normalization=z_score_normalization,
        normalize_by_max=normalize_by_max,
    )

    # np.asarray makes the element-wise comparison work for lists as well.
    matches = np.where(np.asarray(data_bundle.substance_names) == substance_label)[0]
    if matches.size == 0:
        raise ValueError(f"Substance '{substance_label}' not found.")
    s_idx = int(matches[0])

    # Slice (rather than index) to keep the leading batch axis of size 1 that
    # apply_timescale_filter expects.
    batch = slice(s_idx, s_idx + 1)
    obs_b, times_b, step_mask_b, subj_mask_b = apply_timescale_filter(
        observations=data_bundle.observations[batch],
        times=data_bundle.times[batch],
        masks=data_bundle.masks[batch].bool(),
        subject_mask=data_bundle.individuals_mask[batch].bool(),
        strategy=time_strategy,
        max_abs_z=max_abs_z,
        tau=tau,
    )

    # Drop the batch axis again for plotting.
    obs = obs_b[0]
    times = times_b[0]
    step_mask = step_mask_b[0].bool()
    subj_mask = subj_mask_b[0].bool()

    fig, ax = plt.subplots(figsize=figsize)
    for i in torch.nonzero(subj_mask).flatten().tolist():
        valid = step_mask[i]
        # detach() so tensors that track gradients can be converted too.
        t = times[i][valid].detach().cpu().numpy()
        y = obs[i][valid].detach().cpu().numpy()
        ax.plot(t, y, marker="o", alpha=alpha, label=f"subject {i}")

    ax.set_title(f"All subjects - {substance_label}")
    ax.set_xlabel("Time (normalised per substance)")
    ax.set_ylabel("Observation")

    ax.set_xscale(x_scale)
    ax.set_yscale(y_scale)

    # Matplotlib warns when a legend is requested without labelled artists.
    has_curves = bool(ax.get_legend_handles_labels()[0])
    if legend_outside:
        if has_curves:
            ax.legend(
                loc="center left",
                bbox_to_anchor=(1.02, 0.5),
                borderaxespad=0.0,
                frameon=False,
            )
        # Reserve the right-hand strip of the figure for the legend column.
        fig.tight_layout(rect=[0, 0, 0.82, 1])
    else:
        if has_curves:
            ax.legend(frameon=False)
        fig.tight_layout()

    if save_dir is not None:
        from pathlib import Path

        entry = data_bundle.mapping[substance_label]
        out_dir = Path(save_dir)
        out_dir.mkdir(parents=True, exist_ok=True)
        filename = f"{entry['study_name']}_{substance_label}_{entry['index']}.png"
        fig.savefig(out_dir / filename, bbox_inches="tight", dpi=300)

    plt.show()
|
|
def substances_with_min_timesteps(
    drug_data_frame,
    min_timesteps: int = 140,
    *,
    z_score_normalization: bool = False,
    normalize_by_max: bool = False,
) -> List[str]:
    """
    Return the substance labels whose **best** subject has at least
    `min_timesteps` valid observations.

    A time step is counted as valid when both its entry in the bundle's
    ``masks`` and the owning subject's entry in ``individuals_mask`` are set.

    Parameters
    ----------
    drug_data_frame : pandas.DataFrame
        Same dataframe you pass to `substance_cvs_to_tensors_bundle`.
    min_timesteps : int, default = 140
        Threshold on the number of valid (unpadded) time points.
    z_score_normalization : bool, default = False
        Passed straight through to `substance_cvs_to_tensors_bundle`.
    normalize_by_max : bool, default = False
        Passed straight through to `substance_cvs_to_tensors_bundle`.

    Returns
    -------
    List[str]
        Substance labels that satisfy the criterion, in bundle order.
    """
    data_bundle = substance_cvs_to_tensors_bundle(
        drug_data_frame,
        z_score_normalization=z_score_normalization,
        normalize_by_max=normalize_by_max,
    )

    # [S, I, T]: a step only counts when its subject slot is in use.
    step_valid = data_bundle.masks.bool() & data_bundle.individuals_mask.bool().unsqueeze(-1)

    # [S, I]: number of valid time points per subject.
    counts = step_valid.sum(dim=2)

    # [S]: "the best subject reaches the threshold" is the same as "any
    # subject reaches it"; unlike max(), any() also copes with an empty
    # subject axis.
    qualifying = (counts >= min_timesteps).any(dim=1)

    labels = np.asarray(data_bundle.substance_names).tolist()
    return [label for label, ok in zip(labels, qualifying.tolist()) if ok]
|
|
def get_substance_tensors_by_label(
    drug_data_frame,
    substance_label: str,
    *,
    z_score_normalization: bool = False,
    normalize_by_max: bool = False,
) -> SubstanceTensorGroup:
    """
    Fetch the tensors of one substance as a batch of size one (S=1).

    The data frame is converted with `substance_cvs_to_tensors_bundle`
    (both normalisation flags are forwarded) and the entry matching
    `substance_label` is selected.

    Returned shapes:
        observations : [1, I, T]
        times        : [1, I, T]
        mask         : [1, I, T]  (bool)
        subject_mask : [1, I]     (bool)

    Raises
    ------
    ValueError
        If `substance_label` is not among the bundle's substance names.
    """
    bundle = substance_cvs_to_tensors_bundle(
        drug_data_frame,
        z_score_normalization=z_score_normalization,
        normalize_by_max=normalize_by_max,
    )

    # Later duplicates overwrite earlier ones, so the last occurrence wins.
    position_of = dict(
        (name, position) for position, name in enumerate(bundle.substance_names)
    )
    s_idx = position_of.get(substance_label)
    if s_idx is None:
        raise ValueError(f"Substance label '{substance_label}' not found.")

    def as_single_batch(stacked):
        # Select the substance, then restore a leading axis of length 1.
        return stacked[s_idx].unsqueeze(0)

    return SubstanceTensorGroup(
        observations=as_single_batch(bundle.observations),
        times=as_single_batch(bundle.times),
        mask=as_single_batch(bundle.masks).bool(),
        subject_mask=as_single_batch(bundle.individuals_mask).bool(),
    )
|
|
|
|