gp_visualizer/backend/src/manager.py
Joel Woodfield
Refactor to use manager in backend
de9ce02
from pathlib import Path
from typing import Literal
import matplotlib.pyplot as plt
from sympy import sympify
from logic import (
DataGenerationOptions,
Dataset,
PlotData,
compute_plot_values,
generate_dataset,
load_dataset_from_csv,
)
class Manager:
def __init__(self) -> None:
self.dataset = Dataset(x=[], y=[])
self.plots_data: PlotData | None = None
def update_dataset(
self,
dataset_type: Literal["Generate", "CSV"],
function: str,
data_xmin: float,
data_xmax: float,
sigma: float,
nsample: int,
sample_method: Literal["Grid", "Random"],
csv_path: str | Path | None,
has_header: bool,
xcol: int,
ycol: int,
) -> None:
if dataset_type == "Generate":
try:
parsed_function = sympify(function)
except Exception as exc:
raise ValueError(f"Invalid function: {exc}") from exc
sampling = sample_method.lower()
if sampling not in ["grid", "random"]:
raise ValueError(f"Unknown sampling method: {sample_method}")
self.dataset = generate_dataset(
parsed_function,
(data_xmin, data_xmax),
DataGenerationOptions(
method=sampling,
num_samples=nsample,
noise=sigma,
),
)
return
normalized_path = self._normalize_csv_path(csv_path)
if normalized_path is None:
raise ValueError("Please upload a CSV file.")
self.dataset = load_dataset_from_csv(
normalized_path,
has_header,
xcol,
ycol,
)
def compute_plot_data(
self,
kernel: str,
distribution: Literal["Prior", "Posterior"],
plot_xmin: float,
plot_xmax: float,
) -> None:
self.plots_data = compute_plot_values(
self.dataset,
kernel,
distribution,
plot_xmin,
plot_xmax,
)
def handle_generate_plots(
self,
dataset_type: Literal["Generate", "CSV"],
function: str,
data_xmin: float,
data_xmax: float,
sigma: float,
nsample: int,
sample_method: Literal["Grid", "Random"],
csv_path: str | Path | None,
has_header: bool,
xcol: int,
ycol: int,
kernel: str,
distribution: Literal["Prior", "Posterior"],
plot_xmin: float,
plot_xmax: float,
):
self.update_dataset(
dataset_type,
function,
data_xmin,
data_xmax,
sigma,
nsample,
sample_method,
csv_path,
has_header,
xcol,
ycol,
)
true_dataset = self._build_true_dataset(
dataset_type,
function,
plot_xmin,
plot_xmax,
)
self.compute_plot_data(
kernel,
distribution,
plot_xmin,
plot_xmax,
)
return self.generate_plot(true_dataset)
def generate_plot(self, true_dataset: Dataset):
if self.plots_data is None:
raise ValueError("Plot data has not been computed.")
fig, ax = plt.subplots(figsize=(12, 9))
cmap = plt.get_cmap("tab20")
ax.scatter(self.dataset.x, self.dataset.y, color=cmap(0), label="Data Points")
if true_dataset.y is not None and len(true_dataset.y) > 0:
ax.plot(true_dataset.x, true_dataset.y, color=cmap(1), label="True Function")
ax.plot(self.plots_data.x, self.plots_data.pred_mean, color=cmap(2), label="Mean Prediction")
ax.fill_between(
self.plots_data.x,
self.plots_data.pred_mean - 1.96 * self.plots_data.pred_std,
self.plots_data.pred_mean + 1.96 * self.plots_data.pred_std,
color=cmap(3),
alpha=0.2,
label="95% Confidence Interval",
)
ax.legend()
return fig
def _build_true_dataset(
self,
dataset_type: Literal["Generate", "CSV"],
function: str,
xmin: float,
xmax: float,
) -> Dataset:
if dataset_type == "CSV":
return Dataset(x=[], y=[])
try:
parsed_function = sympify(function)
except Exception as exc:
raise ValueError(f"Invalid function: {exc}") from exc
return generate_dataset(
parsed_function,
(xmin, xmax),
DataGenerationOptions(
method="grid",
num_samples=1000,
noise=0.0,
),
)
def _normalize_csv_path(self, csv_path: str | Path | None) -> str | None:
if csv_path is None:
return None
if isinstance(csv_path, Path):
return str(csv_path)
if isinstance(csv_path, str):
return csv_path
name = getattr(csv_path, "name", None)
if name:
return str(name)
return None