Spaces:
Running
Running
| """ | |
| Plotting functions for the ODE research platform. | |
| This module provides functions to create matplotlib figures | |
| for phase portraits, time series, and shadowing analysis. | |
| """ | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| from typing import List, Dict, Tuple, Optional | |
| def make_phase_portrait_figure( | |
| solutions: List[Dict], | |
| selected_indices: List[int], | |
| solver_type: str = "DOP853", | |
| t_train_end: float = 1.0, | |
| t_full_end: float = 3.0, | |
| t_number: int = 100, | |
| show_connections: bool = False, | |
| connection_stride: int = 5, | |
| ) -> plt.Figure: | |
| """ | |
| Create phase portrait figure showing trajectories in phase space. | |
| Args: | |
| solutions: List of solution dictionaries with 'x', 'y', 't_full' | |
| selected_indices: List of indices to display | |
| solver_type: Type of solver used | |
| t_train_end: Training interval end time | |
| t_full_end: Full integration end time | |
| t_number: Number of time points | |
| show_connections: Whether to show connections between DOP853 and NN | |
| connection_stride: Stride for connection lines | |
| Returns: | |
| matplotlib Figure object | |
| """ | |
| selected_indices = [int(i) for i in selected_indices] if selected_indices else [] | |
| if not selected_indices: | |
| fig, ax = plt.subplots(figsize=(8, 6)) | |
| ax.text(0.5, 0.5, "No trajectories selected", ha="center", va="center", transform=ax.transAxes) | |
| ax.set_axis_off() | |
| return fig | |
| fig, ax = plt.subplots(figsize=(8, 6)) | |
| styles = ['-', '--', '-.', ':'] | |
| colors = plt.cm.tab20.colors | |
| for solution_data in solutions: | |
| traj_idx = solution_data["idx"] | |
| if traj_idx not in selected_indices: | |
| continue | |
| color = colors[traj_idx % len(colors)] | |
| x_dop = solution_data["x"] | |
| y_dop = solution_data["y"] | |
| ax.plot(x_dop, y_dop, linestyle='-', color=color, linewidth=1.2, | |
| label=f'DOP853 traj {traj_idx}' if traj_idx == selected_indices[0] else "") | |
| ax.plot(x_dop[0], y_dop[0], 'o', color=color, markersize=4) | |
| ax.plot(x_dop[-1], y_dop[-1], 'x', color=color, markersize=6) | |
| ax.text(x_dop[-1] + 0.01, y_dop[-1] + 0.01, f"{traj_idx}", fontsize=8, color=color) | |
| if "x_nn" in solution_data and solution_data["x_nn"] is not None: | |
| x_nn = solution_data["x_nn"] | |
| y_nn = solution_data["y_nn"] | |
| t_full = solution_data["t_full"] | |
| ax.plot(x_nn, y_nn, linestyle='--', color=color, linewidth=1.0, alpha=0.7, | |
| label=f'NN traj {traj_idx}' if traj_idx == selected_indices[0] else "") | |
| ax.plot(x_dop[0], y_dop[0], '^', color=color, markersize=8, markeredgecolor='black') | |
| if x_nn is not None and y_nn is not None: | |
| train_idx = np.searchsorted(t_full, t_train_end) | |
| if train_idx < len(x_dop): | |
| ax.plot(x_nn[train_idx], y_nn[train_idx], '^', color=color, markersize=8, | |
| markeredgecolor='black', markerfacecolor='none') | |
| ax.plot([x_dop[train_idx], x_nn[train_idx]], | |
| [y_dop[train_idx], y_nn[train_idx]], | |
| color=color, linewidth=1.0, alpha=0.7, linestyle=':') | |
| if show_connections: | |
| for i in range(0, len(x_dop), connection_stride): | |
| ax.plot([x_dop[i], x_nn[i]], [y_dop[i], y_nn[i]], | |
| color=color, linewidth=0.5, alpha=0.3, linestyle='-') | |
| ax.text(x_dop[-1] + 0.01, y_dop[-1] + 0.01, f"{traj_idx}", fontsize=8, color=color) | |
| if "x_nn" in solution_data and solution_data["x_nn"] is not None: | |
| if solution_data["x_nn"] is not None: | |
| ax.text(solution_data["x_nn"][-1] + 0.01, solution_data["y_nn"][-1] + 0.01, | |
| f"{traj_idx}", fontsize=8, color=color) | |
| ax.set_title(f"Gene regulatory trajectories ({solver_type}) " | |
| f"— t_train_end={t_train_end:.2f}, t_full_end={t_full_end:.2f}, t_points={t_number}") | |
| ax.set_xlabel("x(t)") | |
| ax.set_ylabel("y(t)") | |
| ax.grid(True) | |
| if len(selected_indices) <= 3: | |
| ax.legend() | |
| else: | |
| from matplotlib.lines import Line2D | |
| legend_elements = [ | |
| Line2D([0], [0], color='gray', lw=2, linestyle='-', label='DOP853'), | |
| Line2D([0], [0], color='gray', lw=2, linestyle='--', label='NN') | |
| ] | |
| ax.legend(handles=legend_elements) | |
| return fig | |
| def make_time_series_figure( | |
| solutions: List[Dict], | |
| selected_indices: List[int], | |
| solver_type: str = "DOP853", | |
| t_train_end: float = 1.0, | |
| t_full_end: float = 3.0, | |
| t_number: int = 100, | |
| ) -> plt.Figure: | |
| """ | |
| Create time series figure showing x(t) and y(t) over time. | |
| Args: | |
| solutions: List of solution dictionaries | |
| selected_indices: List of indices to display | |
| solver_type: Type of solver used | |
| t_train_end: Training interval end time | |
| t_full_end: Full integration end time | |
| t_number: Number of time points | |
| Returns: | |
| matplotlib Figure object | |
| """ | |
| selected_indices = [int(i) for i in selected_indices] if selected_indices else [] | |
| if not selected_indices: | |
| fig, ax = plt.subplots(figsize=(8, 6)) | |
| ax.text(0.5, 0.5, "No trajectories selected", ha="center", va="center", transform=ax.transAxes) | |
| ax.set_axis_off() | |
| return fig | |
| fig, ax_ts = plt.subplots(figsize=(8, 6)) | |
| colors = plt.cm.tab20.colors | |
| for solution_data in solutions: | |
| traj_idx = solution_data["idx"] | |
| if traj_idx not in selected_indices: | |
| continue | |
| color = colors[traj_idx % len(colors)] | |
| t_full = solution_data["t_full"] | |
| x_dop = solution_data["x"] | |
| y_dop = solution_data["y"] | |
| ax_ts.plot(t_full, x_dop, linestyle='-', color=color, linewidth=1.2, | |
| label=f'x(t) traj {traj_idx}' if traj_idx == selected_indices[0] else "") | |
| ax_ts.plot(t_full, y_dop, linestyle='--', color=color, linewidth=1.2, | |
| label=f'y(t) traj {traj_idx}' if traj_idx == selected_indices[0] else "") | |
| ax_ts.set_title(f"Time series ({solver_type}) — x(t) and y(t) — " | |
| f"t_train_end={t_train_end:.2f}, t_full_end={t_full_end:.2f}, t_points={t_number}") | |
| ax_ts.set_xlabel("t") | |
| ax_ts.set_ylabel("x(t), y(t)") | |
| ax_ts.grid(True) | |
| ax_ts.legend() | |
| return fig | |
| def make_shadowing_figure( | |
| solutions: List[Dict], | |
| selected_indices: List[int], | |
| solver_type: str = "DOP853", | |
| t_train_end: float = 1.0, | |
| t_full_end: float = 3.0, | |
| t_number: int = 100, | |
| epsilon_threshold: float = 1e-3, | |
| rhs_func=None, | |
| ) -> Optional[plt.Figure]: | |
| """ | |
| Create shadowing figure showing epsilon(t) vs time. | |
| Args: | |
| solutions: List of solution dictionaries | |
| selected_indices: List of indices to display | |
| solver_type: Type of solver used | |
| t_train_end: Training interval end time | |
| t_full_end: Full integration end time | |
| t_number: Number of time points | |
| epsilon_threshold: Threshold for shadowing breakdown | |
| rhs_func: Right-hand side function for ODE system (for correct shadowing) | |
| Returns: | |
| matplotlib Figure object or None if no shadowing data available | |
| """ | |
| selected_indices = [int(i) for i in selected_indices] if selected_indices else [] | |
| if not selected_indices: | |
| fig, ax = plt.subplots(figsize=(8, 6)) | |
| ax.text(0.5, 0.5, "No trajectories selected", ha="center", va="center", transform=ax.transAxes) | |
| ax.set_axis_off() | |
| return fig | |
| fig, ax_shad = plt.subplots(figsize=(8, 6)) | |
| colors = plt.cm.tab20.colors | |
| annotated = False | |
| for solution_data in solutions: | |
| traj_idx = solution_data["idx"] | |
| if traj_idx not in selected_indices: | |
| continue | |
| color = colors[traj_idx % len(colors)] | |
| t_full = solution_data["t_full"] | |
| x_dop = solution_data["x"] | |
| y_dop = solution_data["y"] | |
| eps = 1e-6 * (1.0 + abs(x_dop[0]) + abs(y_dop[0])) | |
| xp0, yp0 = x_dop[0] + eps, y_dop[0] + 0.5 * eps | |
| if rhs_func is not None: | |
| try: | |
| sol_p = solve_ivp( | |
| rhs_func, | |
| (t_full[0], t_full[-1]), | |
| (xp0, yp0), | |
| method='DOP853', | |
| t_eval=t_full | |
| ) | |
| if sol_p.success: | |
| xp, yp = sol_p.y | |
| dist = np.sqrt((x_dop - xp)**2 + (y_dop - yp)**2) | |
| epsilon_t = np.maximum.accumulate(dist) | |
| ax_shad.plot(epsilon_t, t_full, color=color, linewidth=1.2, | |
| label=f'ε(t) traj {traj_idx}' if traj_idx == selected_indices[0] else "") | |
| exceed_indices = np.where(epsilon_t > epsilon_threshold)[0] | |
| if len(exceed_indices) > 0: | |
| first_exceed_idx = exceed_indices[0] | |
| shadowing_time = t_full[first_exceed_idx] | |
| ax_shad.axhline(y=shadowing_time, color=color, linestyle=':', alpha=0.7, | |
| label=f't*={shadowing_time:.2f}' if not annotated else "") | |
| annotated = True | |
| except Exception: | |
| pass | |
| else: | |
| xp0, yp0 = x_dop[0] + eps, y_dop[0] + 0.5 * eps | |
| try: | |
| from scipy.integrate import solve_ivp | |
| sol_p = solve_ivp( | |
| lambda t, s: [0, 0], | |
| (t_full[0], t_full[-1]), | |
| (xp0, yp0), | |
| method='DOP853', | |
| t_eval=t_full | |
| ) | |
| if sol_p.success: | |
| xp, yp = sol_p.y | |
| dist = np.sqrt((x_dop - xp)**2 + (y_dop - yp)**2) | |
| epsilon_t = np.maximum.accumulate(dist) | |
| ax_shad.plot(epsilon_t, t_full, color=color, linewidth=1.2, | |
| label=f'ε(t) traj {traj_idx}' if traj_idx == selected_indices[0] else "") | |
| exceed_indices = np.where(epsilon_t > epsilon_threshold)[0] | |
| if len(exceed_indices) > 0: | |
| first_exceed_idx = exceed_indices[0] | |
| shadowing_time = t_full[first_exceed_idx] | |
| ax_shad.axhline(y=shadowing_time, color=color, linestyle=':', alpha=0.7, | |
| label=f't*={shadowing_time:.2f}' if not annotated else "") | |
| annotated = True | |
| except Exception: | |
| pass | |
| ax_shad.axhline(y=t_train_end, color='red', linestyle='--', alpha=0.7, | |
| label=f't_train_end={t_train_end}') | |
| ax_shad.set_title(f"Shadowing analysis ({solver_type}) — ε(t) vs t — Perturbed trajectory vs primary trajectory — " | |
| f"t_train_end={t_train_end:.2f}, t_full_end={t_full_end:.2f}, t_points={t_number}") | |
| ax_shad.set_xlabel("ε(t)") | |
| ax_shad.set_ylabel("t") | |
| ax_shad.grid(True) | |
| if len(selected_indices) <= 3: | |
| ax_shad.legend() | |
| else: | |
| from matplotlib.lines import Line2D | |
| legend_elements = [ | |
| Line2D([0], [0], color='gray', lw=2, linestyle='-', label='ε(t)'), | |
| Line2D([0], [0], color='red', lw=1, linestyle='--', label='train/extrap boundary') | |
| ] | |
| ax_shad.legend(handles=legend_elements) | |
| return fig | |
| def make_metrics_table_figure( | |
| metrics_df, | |
| show_extremes: bool = True, | |
| ) -> plt.Figure: | |
| """ | |
| Create a figure showing the metrics table. | |
| Args: | |
| metrics_df: DataFrame with metrics | |
| show_extremes: Whether to highlight extreme values | |
| Returns: | |
| matplotlib Figure object | |
| """ | |
| if metrics_df.empty: | |
| fig, ax = plt.subplots(figsize=(10, 4)) | |
| ax.text(0.5, 0.5, "No metrics to display", ha='center', va='center', fontsize=14) | |
| ax.set_axis_off() | |
| return fig | |
| fig, ax = plt.subplots(figsize=(12, 6)) | |
| column_rename_map = { | |
| 'curv_radius_mean': 'curv_rad_mn', | |
| 'curv_radius_median': 'curv_rad_med', | |
| 'curv_radius_std': 'curv_rad_std', | |
| 'curv_radius_local_zscore': 'curv_rad_lcl_z', | |
| 'curv_count_finite': 'curv_ct_fin' | |
| } | |
| df_display = metrics_df.rename(columns=column_rename_map).copy() | |
| df_display = df_display.reset_index(drop=True) | |
| numeric_columns = df_display.select_dtypes(include=['number']).columns.tolist() | |
| if show_extremes: | |
| max_cols = ['idx', 'ftle', 'ftle_r2', 'amp', 'final_dist', 'hurst', 'curv_ct_fin', | |
| 'path_len', 'max_kappa', 'frac_high_curv', 'anomaly_score'] | |
| min_cols = ['curv_rad_mn', 'curv_rad_med', 'curv_rad_std', 'curv_rad_lcl_z', 'curv_p10', 'curv_p90'] | |
| styles = [['' for _ in range(len(df_display.columns))] for _ in range(len(df_display))] | |
| for i, col in enumerate(df_display.columns): | |
| if col == 'idx': | |
| continue | |
| if col in max_cols and df_display[col].dtype in ['float64', 'int64', 'float32', 'int32']: | |
| if not df_display[col].isna().all(): | |
| valid = df_display[col].dropna() | |
| if len(valid) >= 2: | |
| top2 = valid.nlargest(2) | |
| for idx in top2.index: | |
| styles[idx][i] = 'background-color: #D2691E' | |
| elif len(valid) == 1: | |
| styles[valid.index[0]][i] = 'background-color: #D2691E' | |
| elif col in min_cols and df_display[col].dtype in ['float64', 'int64', 'float32', 'int32']: | |
| if not df_display[col].isna().all(): | |
| valid = df_display[col].dropna() | |
| if len(valid) >= 2: | |
| bot2 = valid.nsmallest(2) | |
| for idx in bot2.index: | |
| styles[idx][i] = 'background-color: #00CED1' | |
| elif len(valid) == 1: | |
| styles[valid.index[0]][i] = 'background-color: #00CED1' | |
| df_display = df_display.style.apply(lambda x: styles, axis=None) | |
| ax.axis('off') | |
| table_str = df_display.format("{:.3f}")._repr_html_() | |
| ax.text(0.01, 0.99, table_str, transform=ax.transAxes, ha='left', va='top', fontsize=8) | |
| return fig | |