differ-hug / differ_hug /plotting.py
componavt's picture
fix of step 1
2e27876
Raw
History Blame Contribute Delete
14.8 kB
"""
Plotting functions for the ODE research platform.
This module provides functions to create matplotlib figures
for phase portraits, time series, and shadowing analysis.
"""
import numpy as np
import matplotlib.pyplot as plt
from typing import List, Dict, Tuple, Optional
def make_phase_portrait_figure(
solutions: List[Dict],
selected_indices: List[int],
solver_type: str = "DOP853",
t_train_end: float = 1.0,
t_full_end: float = 3.0,
t_number: int = 100,
show_connections: bool = False,
connection_stride: int = 5,
) -> plt.Figure:
"""
Create phase portrait figure showing trajectories in phase space.
Args:
solutions: List of solution dictionaries with 'x', 'y', 't_full'
selected_indices: List of indices to display
solver_type: Type of solver used
t_train_end: Training interval end time
t_full_end: Full integration end time
t_number: Number of time points
show_connections: Whether to show connections between DOP853 and NN
connection_stride: Stride for connection lines
Returns:
matplotlib Figure object
"""
selected_indices = [int(i) for i in selected_indices] if selected_indices else []
if not selected_indices:
fig, ax = plt.subplots(figsize=(8, 6))
ax.text(0.5, 0.5, "No trajectories selected", ha="center", va="center", transform=ax.transAxes)
ax.set_axis_off()
return fig
fig, ax = plt.subplots(figsize=(8, 6))
styles = ['-', '--', '-.', ':']
colors = plt.cm.tab20.colors
for solution_data in solutions:
traj_idx = solution_data["idx"]
if traj_idx not in selected_indices:
continue
color = colors[traj_idx % len(colors)]
x_dop = solution_data["x"]
y_dop = solution_data["y"]
ax.plot(x_dop, y_dop, linestyle='-', color=color, linewidth=1.2,
label=f'DOP853 traj {traj_idx}' if traj_idx == selected_indices[0] else "")
ax.plot(x_dop[0], y_dop[0], 'o', color=color, markersize=4)
ax.plot(x_dop[-1], y_dop[-1], 'x', color=color, markersize=6)
ax.text(x_dop[-1] + 0.01, y_dop[-1] + 0.01, f"{traj_idx}", fontsize=8, color=color)
if "x_nn" in solution_data and solution_data["x_nn"] is not None:
x_nn = solution_data["x_nn"]
y_nn = solution_data["y_nn"]
t_full = solution_data["t_full"]
ax.plot(x_nn, y_nn, linestyle='--', color=color, linewidth=1.0, alpha=0.7,
label=f'NN traj {traj_idx}' if traj_idx == selected_indices[0] else "")
ax.plot(x_dop[0], y_dop[0], '^', color=color, markersize=8, markeredgecolor='black')
if x_nn is not None and y_nn is not None:
train_idx = np.searchsorted(t_full, t_train_end)
if train_idx < len(x_dop):
ax.plot(x_nn[train_idx], y_nn[train_idx], '^', color=color, markersize=8,
markeredgecolor='black', markerfacecolor='none')
ax.plot([x_dop[train_idx], x_nn[train_idx]],
[y_dop[train_idx], y_nn[train_idx]],
color=color, linewidth=1.0, alpha=0.7, linestyle=':')
if show_connections:
for i in range(0, len(x_dop), connection_stride):
ax.plot([x_dop[i], x_nn[i]], [y_dop[i], y_nn[i]],
color=color, linewidth=0.5, alpha=0.3, linestyle='-')
ax.text(x_dop[-1] + 0.01, y_dop[-1] + 0.01, f"{traj_idx}", fontsize=8, color=color)
if "x_nn" in solution_data and solution_data["x_nn"] is not None:
if solution_data["x_nn"] is not None:
ax.text(solution_data["x_nn"][-1] + 0.01, solution_data["y_nn"][-1] + 0.01,
f"{traj_idx}", fontsize=8, color=color)
ax.set_title(f"Gene regulatory trajectories ({solver_type}) "
f"— t_train_end={t_train_end:.2f}, t_full_end={t_full_end:.2f}, t_points={t_number}")
ax.set_xlabel("x(t)")
ax.set_ylabel("y(t)")
ax.grid(True)
if len(selected_indices) <= 3:
ax.legend()
else:
from matplotlib.lines import Line2D
legend_elements = [
Line2D([0], [0], color='gray', lw=2, linestyle='-', label='DOP853'),
Line2D([0], [0], color='gray', lw=2, linestyle='--', label='NN')
]
ax.legend(handles=legend_elements)
return fig
def make_time_series_figure(
solutions: List[Dict],
selected_indices: List[int],
solver_type: str = "DOP853",
t_train_end: float = 1.0,
t_full_end: float = 3.0,
t_number: int = 100,
) -> plt.Figure:
"""
Create time series figure showing x(t) and y(t) over time.
Args:
solutions: List of solution dictionaries
selected_indices: List of indices to display
solver_type: Type of solver used
t_train_end: Training interval end time
t_full_end: Full integration end time
t_number: Number of time points
Returns:
matplotlib Figure object
"""
selected_indices = [int(i) for i in selected_indices] if selected_indices else []
if not selected_indices:
fig, ax = plt.subplots(figsize=(8, 6))
ax.text(0.5, 0.5, "No trajectories selected", ha="center", va="center", transform=ax.transAxes)
ax.set_axis_off()
return fig
fig, ax_ts = plt.subplots(figsize=(8, 6))
colors = plt.cm.tab20.colors
for solution_data in solutions:
traj_idx = solution_data["idx"]
if traj_idx not in selected_indices:
continue
color = colors[traj_idx % len(colors)]
t_full = solution_data["t_full"]
x_dop = solution_data["x"]
y_dop = solution_data["y"]
ax_ts.plot(t_full, x_dop, linestyle='-', color=color, linewidth=1.2,
label=f'x(t) traj {traj_idx}' if traj_idx == selected_indices[0] else "")
ax_ts.plot(t_full, y_dop, linestyle='--', color=color, linewidth=1.2,
label=f'y(t) traj {traj_idx}' if traj_idx == selected_indices[0] else "")
ax_ts.set_title(f"Time series ({solver_type}) — x(t) and y(t) — "
f"t_train_end={t_train_end:.2f}, t_full_end={t_full_end:.2f}, t_points={t_number}")
ax_ts.set_xlabel("t")
ax_ts.set_ylabel("x(t), y(t)")
ax_ts.grid(True)
ax_ts.legend()
return fig
def make_shadowing_figure(
solutions: List[Dict],
selected_indices: List[int],
solver_type: str = "DOP853",
t_train_end: float = 1.0,
t_full_end: float = 3.0,
t_number: int = 100,
epsilon_threshold: float = 1e-3,
rhs_func=None,
) -> Optional[plt.Figure]:
"""
Create shadowing figure showing epsilon(t) vs time.
Args:
solutions: List of solution dictionaries
selected_indices: List of indices to display
solver_type: Type of solver used
t_train_end: Training interval end time
t_full_end: Full integration end time
t_number: Number of time points
epsilon_threshold: Threshold for shadowing breakdown
rhs_func: Right-hand side function for ODE system (for correct shadowing)
Returns:
matplotlib Figure object or None if no shadowing data available
"""
selected_indices = [int(i) for i in selected_indices] if selected_indices else []
if not selected_indices:
fig, ax = plt.subplots(figsize=(8, 6))
ax.text(0.5, 0.5, "No trajectories selected", ha="center", va="center", transform=ax.transAxes)
ax.set_axis_off()
return fig
fig, ax_shad = plt.subplots(figsize=(8, 6))
colors = plt.cm.tab20.colors
annotated = False
for solution_data in solutions:
traj_idx = solution_data["idx"]
if traj_idx not in selected_indices:
continue
color = colors[traj_idx % len(colors)]
t_full = solution_data["t_full"]
x_dop = solution_data["x"]
y_dop = solution_data["y"]
eps = 1e-6 * (1.0 + abs(x_dop[0]) + abs(y_dop[0]))
xp0, yp0 = x_dop[0] + eps, y_dop[0] + 0.5 * eps
if rhs_func is not None:
try:
sol_p = solve_ivp(
rhs_func,
(t_full[0], t_full[-1]),
(xp0, yp0),
method='DOP853',
t_eval=t_full
)
if sol_p.success:
xp, yp = sol_p.y
dist = np.sqrt((x_dop - xp)**2 + (y_dop - yp)**2)
epsilon_t = np.maximum.accumulate(dist)
ax_shad.plot(epsilon_t, t_full, color=color, linewidth=1.2,
label=f'ε(t) traj {traj_idx}' if traj_idx == selected_indices[0] else "")
exceed_indices = np.where(epsilon_t > epsilon_threshold)[0]
if len(exceed_indices) > 0:
first_exceed_idx = exceed_indices[0]
shadowing_time = t_full[first_exceed_idx]
ax_shad.axhline(y=shadowing_time, color=color, linestyle=':', alpha=0.7,
label=f't*={shadowing_time:.2f}' if not annotated else "")
annotated = True
except Exception:
pass
else:
xp0, yp0 = x_dop[0] + eps, y_dop[0] + 0.5 * eps
try:
from scipy.integrate import solve_ivp
sol_p = solve_ivp(
lambda t, s: [0, 0],
(t_full[0], t_full[-1]),
(xp0, yp0),
method='DOP853',
t_eval=t_full
)
if sol_p.success:
xp, yp = sol_p.y
dist = np.sqrt((x_dop - xp)**2 + (y_dop - yp)**2)
epsilon_t = np.maximum.accumulate(dist)
ax_shad.plot(epsilon_t, t_full, color=color, linewidth=1.2,
label=f'ε(t) traj {traj_idx}' if traj_idx == selected_indices[0] else "")
exceed_indices = np.where(epsilon_t > epsilon_threshold)[0]
if len(exceed_indices) > 0:
first_exceed_idx = exceed_indices[0]
shadowing_time = t_full[first_exceed_idx]
ax_shad.axhline(y=shadowing_time, color=color, linestyle=':', alpha=0.7,
label=f't*={shadowing_time:.2f}' if not annotated else "")
annotated = True
except Exception:
pass
ax_shad.axhline(y=t_train_end, color='red', linestyle='--', alpha=0.7,
label=f't_train_end={t_train_end}')
ax_shad.set_title(f"Shadowing analysis ({solver_type}) — ε(t) vs t — Perturbed trajectory vs primary trajectory — "
f"t_train_end={t_train_end:.2f}, t_full_end={t_full_end:.2f}, t_points={t_number}")
ax_shad.set_xlabel("ε(t)")
ax_shad.set_ylabel("t")
ax_shad.grid(True)
if len(selected_indices) <= 3:
ax_shad.legend()
else:
from matplotlib.lines import Line2D
legend_elements = [
Line2D([0], [0], color='gray', lw=2, linestyle='-', label='ε(t)'),
Line2D([0], [0], color='red', lw=1, linestyle='--', label='train/extrap boundary')
]
ax_shad.legend(handles=legend_elements)
return fig
def make_metrics_table_figure(
metrics_df,
show_extremes: bool = True,
) -> plt.Figure:
"""
Create a figure showing the metrics table.
Args:
metrics_df: DataFrame with metrics
show_extremes: Whether to highlight extreme values
Returns:
matplotlib Figure object
"""
if metrics_df.empty:
fig, ax = plt.subplots(figsize=(10, 4))
ax.text(0.5, 0.5, "No metrics to display", ha='center', va='center', fontsize=14)
ax.set_axis_off()
return fig
fig, ax = plt.subplots(figsize=(12, 6))
column_rename_map = {
'curv_radius_mean': 'curv_rad_mn',
'curv_radius_median': 'curv_rad_med',
'curv_radius_std': 'curv_rad_std',
'curv_radius_local_zscore': 'curv_rad_lcl_z',
'curv_count_finite': 'curv_ct_fin'
}
df_display = metrics_df.rename(columns=column_rename_map).copy()
df_display = df_display.reset_index(drop=True)
numeric_columns = df_display.select_dtypes(include=['number']).columns.tolist()
if show_extremes:
max_cols = ['idx', 'ftle', 'ftle_r2', 'amp', 'final_dist', 'hurst', 'curv_ct_fin',
'path_len', 'max_kappa', 'frac_high_curv', 'anomaly_score']
min_cols = ['curv_rad_mn', 'curv_rad_med', 'curv_rad_std', 'curv_rad_lcl_z', 'curv_p10', 'curv_p90']
styles = [['' for _ in range(len(df_display.columns))] for _ in range(len(df_display))]
for i, col in enumerate(df_display.columns):
if col == 'idx':
continue
if col in max_cols and df_display[col].dtype in ['float64', 'int64', 'float32', 'int32']:
if not df_display[col].isna().all():
valid = df_display[col].dropna()
if len(valid) >= 2:
top2 = valid.nlargest(2)
for idx in top2.index:
styles[idx][i] = 'background-color: #D2691E'
elif len(valid) == 1:
styles[valid.index[0]][i] = 'background-color: #D2691E'
elif col in min_cols and df_display[col].dtype in ['float64', 'int64', 'float32', 'int32']:
if not df_display[col].isna().all():
valid = df_display[col].dropna()
if len(valid) >= 2:
bot2 = valid.nsmallest(2)
for idx in bot2.index:
styles[idx][i] = 'background-color: #00CED1'
elif len(valid) == 1:
styles[valid.index[0]][i] = 'background-color: #00CED1'
df_display = df_display.style.apply(lambda x: styles, axis=None)
ax.axis('off')
table_str = df_display.format("{:.3f}")._repr_html_()
ax.text(0.01, 0.99, table_str, transform=ax.transAxes, ha='left', va='top', fontsize=8)
return fig