| """Utilities for plotting :class:`AICMECompartmentsDataBatch` objects.""" |
|
|
| from __future__ import annotations |
|
|
| import re |
| from pathlib import Path |
| from typing import Any, Dict, List, Optional, Tuple, Union |
|
|
| import matplotlib.pyplot as plt |
| import numpy as np |
| import torch |
| from matplotlib.lines import Line2D |
| from torchtyping import TensorType |
|
|
| from sim_priors_pk.data.data_empirical.builder import EmpiricalBatchConfig, JSON2AICMEBuilder |
| from sim_priors_pk.data.data_empirical.json_schema import ( |
| StudyJSON, |
| canonicalize_study, |
| prediction_stats, |
| ) |
| from sim_priors_pk.data.datasets.aicme_batch import AICMECompartmentsDataBatch |
| import matplotlib |
|
|
# Select the non-interactive "Agg" backend at import time so figures can be
# rendered and saved without a display. Note that ``pyplot`` has already been
# imported above, so this switches the backend after the fact.
matplotlib.use("Agg")
|
|
| |
# Type alias for one individual's JSON record: string keys mapped to arbitrary
# values (``plot_ind_json`` reads the ``"observations"`` and ``"remaining"`` keys).
IndividualJSON = Dict[str, object]
| from sim_priors_pk.config_classes.data_config import MetaDosingConfig |
|
|
| |
# Default Matplotlib colours for the four trajectory groups; they are used as
# default argument values by the plotting functions below.
CONTEXT_OBS_COLOR = "tab:green"
CONTEXT_REM_COLOR = "lightgreen"
TARGET_OBS_COLOR = "blue"
TARGET_REM_COLOR = "red"
|
|
|
|
| def _detach_to_cpu(batch: AICMECompartmentsDataBatch) -> AICMECompartmentsDataBatch: |
| """Detach all tensors from computation graph and move to CPU.""" |
| return batch.detach_all().to_device(torch.device("cpu")) |
|
|
|
|
def _draw_trajectory(
    ax: plt.Axes,
    times: Any,
    values: Any,
    *,
    color: str,
    point_size: int,
    point_marker: str,
    line_width: float,
    point_alpha: Optional[float],
    line_alpha: float,
) -> None:
    """Draw one trajectory as scatter markers joined by a line of the same colour."""
    ax.scatter(
        times,
        values,
        color=color,
        s=point_size,
        alpha=point_alpha,
        marker=point_marker,
    )
    ax.plot(times, values, color=color, linewidth=line_width, alpha=line_alpha)


def plot_aicme_databatch(
    databatch: AICMECompartmentsDataBatch,
    *,
    batch_index: int = 0,
    ax: Optional[plt.Axes] = None,
    log_scale: bool = True,
    file_name: Optional[str] = None,
    point_size: int = 5,
    line_width: float = 0.75,
    point_marker: str = "o",
    context_obs_color: str = CONTEXT_OBS_COLOR,
    context_rem_color: str = CONTEXT_REM_COLOR,
    target_obs_color: str = TARGET_OBS_COLOR,
    target_rem_color: str = TARGET_REM_COLOR,
    axis_label_font_size: Optional[float] = None,
    tick_label_font_size: Optional[float] = None,
) -> plt.Axes:
    """Plot one batch entry with configurable marker size/style and colors.

    For every individual (second tensor axis) the masked time/value pairs of
    the context observations, context remainder, target observations and
    target remainder are drawn as markers joined by a line. Only index ``0``
    of the last tensor axis is plotted. When a target individual has both
    valid observations and valid remainder points, a dashed gray segment
    connects the last observation to the first remainder point.

    Parameters
    ----------
    databatch:
        Batch to plot. It is detached and moved to CPU before plotting.
    batch_index:
        Index along the first tensor axis selecting the entry to draw.
    ax:
        Axis to draw on. If ``None`` a new figure and axis are created.
    log_scale:
        If ``True`` (default) the y-axis is set to logarithmic scale.
    file_name:
        If provided, the parent directory is created, the figure owning
        ``ax`` is saved there and then closed -- also when ``ax`` was
        supplied by the caller.
    point_size:
        Marker area passed to ``Axes.scatter``.
    line_width:
        Width of every connecting line, including the target remainder.
    point_marker:
        Marker style for all scatter points.
    context_obs_color, context_rem_color, target_obs_color, target_rem_color:
        Colors of the four trajectory groups.
    axis_label_font_size:
        Font size for the x/y axis labels; Matplotlib's default if ``None``.
    tick_label_font_size:
        Font size for the tick labels; left untouched if ``None``.

    Returns
    -------
    matplotlib.axes.Axes
        The axis that was drawn on.
    """
    batch_cpu = _detach_to_cpu(databatch)

    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = ax.figure

    context_point_alpha = 0.5
    context_line_alpha = 0.6
    target_line_alpha = 0.8

    # Context trajectories: semi-transparent so the target stands out.
    context_groups = (
        (
            batch_cpu.context_obs,
            batch_cpu.context_obs_time,
            batch_cpu.context_obs_mask,
            context_obs_color,
        ),
        (
            batch_cpu.context_rem_sim,
            batch_cpu.context_rem_sim_time,
            batch_cpu.context_rem_sim_mask,
            context_rem_color,
        ),
    )
    for group_values, group_times, group_mask, group_color in context_groups:
        for ind in range(group_values.shape[1]):
            mask = group_mask[batch_index, ind].bool()
            _draw_trajectory(
                ax,
                group_times[batch_index, ind, mask, 0],
                group_values[batch_index, ind, mask, 0],
                color=group_color,
                point_size=point_size,
                point_marker=point_marker,
                line_width=line_width,
                point_alpha=context_point_alpha,
                line_alpha=context_line_alpha,
            )

    # Target observations, each followed by its connector to the remainder.
    for ind in range(batch_cpu.target_obs.shape[1]):
        mask = batch_cpu.target_obs_mask[batch_index, ind].bool()
        times = batch_cpu.target_obs_time[batch_index, ind, mask, 0]
        values = batch_cpu.target_obs[batch_index, ind, mask, 0]
        _draw_trajectory(
            ax,
            times,
            values,
            color=target_obs_color,
            point_size=point_size,
            point_marker=point_marker,
            line_width=line_width,
            point_alpha=None,
            line_alpha=target_line_alpha,
        )

        rem_mask = batch_cpu.target_rem_sim_mask[batch_index, ind].bool()
        if rem_mask.any() and mask.any():
            first_rem_time = batch_cpu.target_rem_sim_time[batch_index, ind, rem_mask, 0][0]
            first_rem_val = batch_cpu.target_rem_sim[batch_index, ind, rem_mask, 0][0]
            ax.plot(
                [times[-1].item(), first_rem_time.item()],
                [values[-1].item(), first_rem_val.item()],
                color="gray",
                linestyle="--",
                linewidth=line_width,
                alpha=0.7,
            )

    # Target remainder. The line width follows ``line_width`` like every other
    # trajectory (it used to be hard-coded to 1, ignoring the parameter).
    for ind in range(batch_cpu.target_rem_sim.shape[1]):
        mask = batch_cpu.target_rem_sim_mask[batch_index, ind].bool()
        _draw_trajectory(
            ax,
            batch_cpu.target_rem_sim_time[batch_index, ind, mask, 0],
            batch_cpu.target_rem_sim[batch_index, ind, mask, 0],
            color=target_rem_color,
            point_size=point_size,
            point_marker=point_marker,
            line_width=line_width,
            point_alpha=None,
            line_alpha=target_line_alpha,
        )

    if log_scale:
        ax.set_yscale("log")

    # Only pass ``fontsize`` when requested so Matplotlib's default label size
    # is kept otherwise.
    label_kwargs: Dict[str, Any] = {}
    if axis_label_font_size is not None:
        label_kwargs["fontsize"] = float(axis_label_font_size)
    ax.set_xlabel("time", **label_kwargs)
    ax.set_ylabel("concentration", **label_kwargs)
    if tick_label_font_size is not None:
        ax.tick_params(axis="both", labelsize=float(tick_label_font_size))

    if file_name is not None:
        Path(file_name).parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(file_name, bbox_inches="tight")
        plt.close(fig)

    return ax
|
|
|
|
def plot_list_aicme_databatch(
    databatch_list: List[AICMECompartmentsDataBatch],
    file_name: Optional[str] = None,
    number_of_rows: int = 3,
    number_of_columns: int = 3,
    log_scale: bool = True,
) -> Optional[str]:
    """Plot a grid of :class:`AICMECompartmentsDataBatch` objects.

    Column ``c`` shows ``databatch_list[c]`` and row ``r`` shows batch entry
    ``r`` of that batch. The y-labels of the first column are extended with
    the ``substance_name`` entries of the first batch.

    Parameters
    ----------
    databatch_list:
        List of batches to plot. An empty list returns immediately.
    file_name:
        Path where the figure should be saved. If ``None`` the plot is not
        saved and ``plt.show()`` is called instead.
    number_of_rows:
        Maximum number of rows in the grid (capped by the first batch's size).
    number_of_columns:
        Maximum number of columns in the grid (capped by the list length).
    log_scale:
        If ``True`` (default) the y-axis is set to logarithmic scale for all subplots.

    Returns
    -------
    str | None
        ``file_name`` if provided else ``None``.

    Raises
    ------
    ValueError
        If ``number_of_rows`` or ``number_of_columns`` is not positive.
    """
    if not databatch_list:
        return file_name

    if number_of_rows <= 0:
        raise ValueError("'number_of_rows' must be a positive integer")
    if number_of_columns <= 0:
        raise ValueError("'number_of_columns' must be a positive integer")

    first_batch = databatch_list[0]
    batch_size = first_batch.target_obs.shape[0]
    n_rows = min(number_of_rows, batch_size)
    n_cols = min(number_of_columns, len(databatch_list))

    # ``squeeze=False`` always yields a 2-D array of axes. Without it a 1x1
    # grid returns a bare ``Axes`` object and the former ``axes.reshape`` call
    # failed with ``AttributeError``.
    fig, axes = plt.subplots(
        n_rows,
        n_cols,
        figsize=(4 * n_cols, 3 * n_rows),
        squeeze=False,
    )

    for col in range(n_cols):
        batch = databatch_list[col]
        entries_in_batch = batch.target_obs.shape[0]
        for row in range(n_rows):
            ax = axes[row, col]
            if row >= entries_in_batch:
                # This batch is smaller than the first one: leave the cell empty.
                ax.axis("off")
                continue
            plot_aicme_databatch(
                batch,
                batch_index=row,
                ax=ax,
                log_scale=log_scale,
            )

    # Label the rows using the substance names of the first batch.
    for row in range(n_rows):
        if row < len(first_batch.substance_name):
            label = first_batch.substance_name[row]
            if isinstance(label, tuple):
                label = ""
            axes[row, 0].set_ylabel(f"concentration\n{label}")

    fig.tight_layout()

    if file_name is not None:
        Path(file_name).parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(file_name, bbox_inches="tight")
        plt.close(fig)
    else:
        plt.show()

    return file_name
|
|
|
|
def plot_ind_json(
    individual: IndividualJSON,
    *,
    file_name: Optional[str] = None,
    log_scale: bool = True,
) -> Optional[str]:
    """Render one ``IndividualJSON`` record as a single-column figure.

    The individual is wrapped into a study with an empty context and itself
    as the only target, converted to a batch through
    :class:`JSON2AICMEBuilder`, and handed to
    :func:`plot_list_aicme_databatch`.

    Parameters
    ----------
    individual:
        Mapping describing one subject. The lengths of its
        ``"observations"`` and ``"remaining"`` entries (empty when absent)
        size the batch configuration.
    file_name:
        Optional path where the resulting figure should be saved. If ``None``
        the figure is not written to disk.
    log_scale:
        If ``True`` (default) the y-axis of the plot is drawn on a logarithmic scale.

    Returns
    -------
    str | None
        ``file_name`` if provided else ``None``.
    """
    n_observations = len(individual.get("observations", []))
    n_remaining = len(individual.get("remaining", []))

    single_target_study = {"context": [], "target": [individual]}

    config = EmpiricalBatchConfig(
        max_observations=n_observations,
        max_remaining=n_remaining,
        max_individuals=1,
    )
    batch = JSON2AICMEBuilder(config).build_one_aicmebatch(
        [single_target_study], MetaDosingConfig()
    )
    entries = batch.target_obs.shape[0]

    return plot_list_aicme_databatch(
        [batch],
        file_name=file_name,
        number_of_rows=entries,
        number_of_columns=1,
        log_scale=log_scale,
    )
|
|
|
|
def plot_study_json(
    study: StudyJSON,
    *,
    file_name: Optional[str] = None,
    log_scale: bool = True,
) -> Optional[str]:
    """Canonicalize a ``StudyJSON`` record and plot it on one axis.

    The study is passed through :func:`canonicalize_study`, converted to an
    :class:`AICMECompartmentsDataBatch` sized to fit its longest individual,
    and drawn with :func:`plot_aicme_databatch` (batch entry ``0``).

    Parameters
    ----------
    study:
        Mapping describing a study with ``"context"`` and ``"target"``
        individuals.
    file_name:
        Optional path where the resulting figure should be saved. If ``None``
        the figure is not written to disk.
    log_scale:
        If ``True`` (default) the y-axis of the plot is drawn on a logarithmic scale.

    Returns
    -------
    str | None
        ``file_name`` if provided else ``None``.
    """
    canon = canonicalize_study(study)
    context_individuals = canon["context"]
    target_individuals = canon["target"]
    everyone = context_individuals + target_individuals

    def _longest(key: str) -> int:
        # Largest number of entries stored under ``key``; 0 without individuals.
        return max((len(person.get(key, [])) for person in everyone), default=0)

    config = EmpiricalBatchConfig(
        max_observations=_longest("observations"),
        max_remaining=_longest("remaining"),
        max_individuals=max(len(context_individuals), len(target_individuals)),
    )
    batch = JSON2AICMEBuilder(config).build_one_aicmebatch([canon], MetaDosingConfig())

    plot_aicme_databatch(
        batch,
        batch_index=0,
        log_scale=log_scale,
        file_name=file_name,
    )
    return file_name
|
|
|
|
def plot_study_json_with_prediction(
    study: StudyJSON,
    *,
    ax: Optional[plt.Axes] = None,
    file_name: Optional[str] = None,
    log_scale: bool = True,
    point_size: int = 5,
    line_width: float = 0.75,
    point_marker: str = "o",
    context_obs_color: str = CONTEXT_OBS_COLOR,
    context_rem_color: str = CONTEXT_REM_COLOR,
    target_obs_color: str = TARGET_OBS_COLOR,
    target_rem_color: str = TARGET_REM_COLOR,
    prediction_marker: str = "o",
    prediction_marker_size: float = 4.0,
    prediction_color: str = "black",
    prediction_error_color: str = "gray",
    prediction_line_style: str = "-",
    figure_size: Optional[Tuple[float, float]] = None,
    show_legend: bool = False,
    legend_font_size: float = 10.0,
    legend_loc: str = "best",
    axis_label_font_size: Optional[float] = None,
    tick_label_font_size: Optional[float] = None,
) -> Optional[str]:
    """Plot a ``StudyJSON`` record including prediction statistics.

    The study is passed through :func:`prediction_stats` and
    :func:`canonicalize_study`, drawn with :func:`plot_aicme_databatch`, and
    then, for every target individual carrying ``"prediction_times"``,
    ``"prediction_mean"`` and ``"prediction_std"``, an error-bar series is
    overlaid. Prediction entries whose time equals ``0`` are dropped before
    plotting (presumably padding -- note that a genuine prediction at time
    zero is dropped as well).

    Parameters
    ----------
    study:
        Study description to plot.
    ax:
        Optional Matplotlib axis to draw on. If ``None`` a new figure and axis
        are created.
    file_name:
        If provided, the figure is stored at this path; missing parent
        directories are created. The figure is not closed by this function.
    log_scale:
        Whether to draw the y-axis on a logarithmic scale (``True`` by default).
    point_size:
        Marker area passed to the underlying observed/target scatter calls.
    line_width:
        Width of observed/target connecting lines.
    point_marker:
        Marker used for observed/target points (for example ``"o"`` for circles).
    context_obs_color:
        Color for context observation trajectories.
    context_rem_color:
        Color for context remainder trajectories.
    target_obs_color:
        Color for target observation trajectories.
    target_rem_color:
        Color for target remainder trajectories.
    prediction_marker:
        Marker used for predictive mean points.
    prediction_marker_size:
        Marker size used for predictive mean points.
    prediction_color:
        Color used for predictive mean markers and connecting line.
    prediction_error_color:
        Color used for predictive error bars.
    prediction_line_style:
        Linestyle used for predictive mean line.
    figure_size:
        Reserved for compatibility with ``plot_kwargs`` forwarding. Figure size
        is managed by :func:`plot_list_list_study_json`, so this argument is
        ignored here.
    show_legend:
        If ``True``, draws a legend for context/target/prediction elements.
    legend_font_size:
        Font size used for the legend.
    legend_loc:
        Matplotlib legend location string.
    axis_label_font_size:
        Font size for x/y axis labels.
    tick_label_font_size:
        Font size for x/y tick labels.

    Returns
    -------
    str | None
        ``file_name`` if provided else ``None``.
    """
    _ = figure_size  # accepted for signature compatibility only
    study = prediction_stats(study)
    canon = canonicalize_study(study, drop_tgt_too_few=False)

    all_inds = canon["context"] + canon["target"]
    max_obs = max((len(ind.get("observations", [])) for ind in all_inds), default=0)
    max_rem = max((len(ind.get("remaining", [])) for ind in all_inds), default=0)
    max_inds = max(len(canon["context"]), len(canon["target"]))

    builder = JSON2AICMEBuilder(
        EmpiricalBatchConfig(
            max_observations=max_obs,
            max_remaining=max_rem,
            max_individuals=max_inds,
        )
    )
    batch = builder.build_one_aicmebatch([canon], MetaDosingConfig())
    # Saving is handled at the end of this function, hence ``file_name=None``.
    ax = plot_aicme_databatch(
        batch,
        batch_index=0,
        ax=ax,
        log_scale=log_scale,
        file_name=None,
        point_size=point_size,
        line_width=line_width,
        point_marker=point_marker,
        context_obs_color=context_obs_color,
        context_rem_color=context_rem_color,
        target_obs_color=target_obs_color,
        target_rem_color=target_rem_color,
        axis_label_font_size=axis_label_font_size,
        tick_label_font_size=tick_label_font_size,
    )

    prediction_keys = ("prediction_times", "prediction_mean", "prediction_std")
    for ind in canon["target"]:
        if not all(key in ind for key in prediction_keys):
            continue

        times = torch.as_tensor(ind["prediction_times"], dtype=torch.float32).view(-1)
        mean = torch.as_tensor(ind["prediction_mean"], dtype=torch.float32).view(-1)
        std = torch.as_tensor(ind["prediction_std"], dtype=torch.float32).view(-1)

        if times.numel() == 0:
            continue

        # Keep only entries with a non-zero time stamp.
        keep_mask = times != 0
        if not keep_mask.any():
            continue

        ax.errorbar(
            times[keep_mask],
            mean[keep_mask],
            yerr=std[keep_mask],
            fmt=prediction_marker,
            linestyle=prediction_line_style,
            color=prediction_color,
            ecolor=prediction_error_color,
            elinewidth=1,
            capsize=3,
            markersize=prediction_marker_size,
        )

    if show_legend:
        # ``point_size`` is a scatter area, so its square root approximates the
        # marker diameter used by ``Line2D``; never go below 3 points.
        marker_size = max(3.0, float(point_size) ** 0.5)
        trajectory_entries = (
            (context_obs_color, "Context Obs"),
            (context_rem_color, "Context Remainder"),
            (target_obs_color, "Target Obs"),
            (target_rem_color, "Target Remainder"),
        )
        handles = [
            Line2D(
                [0],
                [0],
                color=color,
                marker=point_marker,
                linewidth=line_width,
                markersize=marker_size,
                label=label,
            )
            for color, label in trajectory_entries
        ]
        handles.append(
            Line2D(
                [0],
                [0],
                color=prediction_color,
                marker=prediction_marker,
                linewidth=1.0,
                markersize=prediction_marker_size,
                label="Prediction Mean",
            )
        )
        ax.legend(handles=handles, fontsize=legend_font_size, loc=legend_loc)

    if file_name is not None:
        # Create the target directory first, like the other plotting helpers.
        Path(file_name).parent.mkdir(parents=True, exist_ok=True)
        ax.figure.savefig(file_name)
    return file_name
|
|
|
|
| def _normalise_substance_name(raw_name: object, fallback: str) -> str: |
| """Return a clean substance name extracted from StudyJSON metadata. |
| |
| The empirical StudyJSON metadata occasionally stores the ``"substance_name"`` |
| field as strings, tuples or other containers. This helper converts the |
| value into a readable string and falls back to ``fallback`` whenever the |
| metadata entry is missing or empty. The normalised name is used both for |
| axis labelling and for generating deterministic file names when each study |
| is plotted separately. |
| """ |
|
|
| if raw_name is None: |
| return fallback |
|
|
| if isinstance(raw_name, str): |
| candidate = raw_name.strip() |
| return candidate or fallback |
|
|
| if isinstance(raw_name, (list, tuple)): |
| parts = [str(part).strip() for part in raw_name if str(part).strip()] |
| if parts: |
| return " ".join(parts) |
| return fallback |
|
|
| try: |
| candidate = str(raw_name).strip() |
| except Exception: |
| return fallback |
|
|
| return candidate or fallback |
|
|
|
|
| def _separate_plot_file_name( |
| base_file_name: str, |
| *, |
| substance_name: str, |
| permutation_index: int, |
| ) -> str: |
| """Return a filename for a single-study plot derived from ``base_file_name``. |
| |
| Parameters |
| ---------- |
| base_file_name: |
| Reference file name used when plotting multiple studies in a single figure. |
| substance_name: |
| Name of the simulated substance associated with the plot. The value is |
| sanitised so that it can safely be used inside the file name. |
| permutation_index: |
| Index of the permutation that produced the study. Including the |
| permutation makes every generated file name deterministic and unique. |
| |
| Returns |
| ------- |
| str |
| A new filename that appends ``substance_name`` and ``permutation`` |
| information to ``base_file_name`` while preserving the original suffix. |
| """ |
|
|
| base_path = Path(base_file_name) |
| stem = base_path.stem |
| suffix = base_path.suffix |
|
|
| safe_substance = re.sub(r"[^0-9A-Za-z]+", "_", substance_name).strip("_") |
| if not safe_substance: |
| safe_substance = "substance" |
|
|
| new_stem = f"{stem}_{safe_substance}_permutation_{permutation_index}" |
| return str(base_path.with_name(f"{new_stem}{suffix}")) |
|
|
|
|
def plot_list_list_study_json(
    studies: List[List[StudyJSON]],
    *,
    file_name: Optional[str] = None,
    number_of_rows: Optional[int] = 3,
    number_of_columns: Optional[int] = 3,
    log_scale: bool = True,
    plot_all_separately: bool = False,
    plot_kwargs: Optional[Dict[str, Any]] = None,
) -> Optional[Union[str, List[str]]]:
    """Plot ``StudyJSON`` records either as a grid or as individual figures.

    When ``plot_all_separately`` is ``False`` (the default) the function retains
    the historical behaviour and renders the studies on a single grid. When the
    flag is ``True`` each study is saved to its own image whose name is derived
    from ``file_name`` via :func:`_separate_plot_file_name`, and the list of
    generated filenames is returned.

    Parameters
    ----------
    studies:
        Outer list: one entry per permutation (grid column). Inner list: the
        studies of that permutation (grid rows). If ``studies`` or its first
        entry is empty, nothing is plotted and ``file_name`` is returned.
    file_name:
        Output path. Required when ``plot_all_separately`` is ``True``. In
        grid mode ``None`` means the figure is shown via ``plt.show()``
        instead of being saved.
    number_of_rows:
        Maximum number of different substances to render per permutation. When
        ``None`` every available substance is shown.
    number_of_columns:
        Maximum number of permutations to render. When ``None`` every
        permutation is shown.
    log_scale:
        Forwarded to :func:`plot_study_json_with_prediction`.
    plot_all_separately:
        Select one-figure-per-study output instead of a single grid.
    plot_kwargs:
        Optional keyword arguments forwarded to
        :func:`plot_study_json_with_prediction` to control visual styling
        (for example marker size, marker type, or colors). Three keys are
        consumed by this function instead of being forwarded:

        - ``"figure_size"``: 2-item tuple/list of positive numbers; the
          per-image size in separate mode, the full grid size in grid mode.
        - ``"title"``: string overriding the plot title (separate mode only;
          grid cells get no title).
        - ``"title_font_size"``: font size of that title (separate mode only).

    Returns
    -------
    str | list[str] | None
        The list of written files in separate mode, otherwise ``file_name``.

    Raises
    ------
    ValueError
        If ``"title"`` is not a string, ``"figure_size"`` is malformed,
        ``number_of_rows``/``number_of_columns`` is not positive, or
        ``file_name`` is missing in separate mode.
    """
    if not studies or not studies[0]:
        return file_name

    # Work on a copy so the caller's dictionary is never mutated.
    study_plot_kwargs = dict(plot_kwargs) if plot_kwargs else {}
    figure_size = study_plot_kwargs.pop("figure_size", None)
    title_font_size = study_plot_kwargs.pop("title_font_size", None)
    title_override = study_plot_kwargs.pop("title", None)
    if title_override is not None and not isinstance(title_override, str):
        raise ValueError("'plot_kwargs[\"title\"]' must be a string when provided.")

    explicit_figsize: Optional[Tuple[float, float]] = None
    if figure_size is not None:
        if not isinstance(figure_size, (list, tuple)) or len(figure_size) != 2:
            raise ValueError("'plot_kwargs[\"figure_size\"]' must be a 2-item tuple/list.")
        width = float(figure_size[0])
        height = float(figure_size[1])
        if width <= 0 or height <= 0:
            raise ValueError("'plot_kwargs[\"figure_size\"]' values must be > 0.")
        explicit_figsize = (width, height)

    # Argument validation shared by both modes, done before any figure exists.
    if plot_all_separately and file_name is None:
        raise ValueError("'file_name' must be provided when plotting separately")
    if number_of_rows is not None and number_of_rows <= 0:
        raise ValueError("'number_of_rows' must be a positive integer or None")
    if number_of_columns is not None and number_of_columns <= 0:
        raise ValueError("'number_of_columns' must be a positive integer or None")

    total_permutations = len(studies)
    n_cols = (
        total_permutations
        if number_of_columns is None
        else min(number_of_columns, total_permutations)
    )

    if plot_all_separately:
        separate_figsize = (4, 3) if explicit_figsize is None else explicit_figsize
        separate_files: List[str] = []

        for permutation_index in range(n_cols):
            permutation_studies = studies[permutation_index]
            max_rows = (
                len(permutation_studies)
                if number_of_rows is None
                else min(number_of_rows, len(permutation_studies))
            )

            for row, study in enumerate(permutation_studies[:max_rows]):
                fig, ax = plt.subplots(figsize=separate_figsize)
                try:
                    plot_study_json_with_prediction(
                        study,
                        ax=ax,
                        log_scale=log_scale,
                        **study_plot_kwargs,
                    )

                    raw_substance_name = study["meta_data"].get("substance_name")
                    substance_name = _normalise_substance_name(
                        raw_substance_name, f"substance_{row}"
                    )
                    display_title = title_override if title_override else substance_name
                    if title_font_size is not None:
                        ax.set_title(display_title, fontsize=float(title_font_size))
                    else:
                        ax.set_title(display_title)

                    study_file_name = _separate_plot_file_name(
                        file_name,
                        substance_name=substance_name,
                        permutation_index=permutation_index,
                    )
                    Path(study_file_name).parent.mkdir(parents=True, exist_ok=True)
                    fig.savefig(study_file_name, bbox_inches="tight")
                finally:
                    # Always release the figure, even if plotting or saving
                    # failed, so a bad study does not leak open figures.
                    plt.close(fig)
                separate_files.append(study_file_name)

        return separate_files

    batch_size = len(studies[0])
    n_rows = batch_size if number_of_rows is None else min(number_of_rows, batch_size)

    grid_figsize = (4 * n_cols, 3 * n_rows) if explicit_figsize is None else explicit_figsize
    # ``squeeze=False`` always returns a 2-D ``(n_rows, n_cols)`` array of axes.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=grid_figsize, squeeze=False)

    try:
        for col in range(n_cols):
            permutation_studies = studies[col]
            for row in range(n_rows):
                ax = axes[row, col]
                if row >= len(permutation_studies):
                    # This permutation has fewer studies than the first one.
                    ax.axis("off")
                    continue
                study = permutation_studies[row]
                plot_study_json_with_prediction(
                    study,
                    ax=ax,
                    log_scale=log_scale,
                    **study_plot_kwargs,
                )

                # The first column carries the substance name as its y-label.
                if col == 0:
                    raw_substance_name = study["meta_data"].get("substance_name")
                    substance_name = _normalise_substance_name(
                        raw_substance_name, f"substance_{row}"
                    )
                    ax.set_ylabel(substance_name, fontsize=10, rotation=90, labelpad=10)

        fig.tight_layout()

        if file_name is not None:
            Path(file_name).parent.mkdir(parents=True, exist_ok=True)
            fig.savefig(file_name, bbox_inches="tight")
    except Exception:
        # Do not leave a half-drawn grid registered with pyplot.
        plt.close(fig)
        raise

    if file_name is not None:
        plt.close(fig)
    else:
        plt.show()

    return file_name
|
|