from pathlib import Path
from typing import Literal

import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from sympy import sympify

from logic import (
    DataGenerationOptions,
    Dataset,
    PlotData,
    compute_plot_values,
    generate_dataset,
    load_dataset_from_csv,
)

# Number of grid points used to draw the noise-free "true function" curve.
_TRUE_FUNCTION_SAMPLES = 1000
# Standard-normal z-score for a two-sided 95% interval (mean +/- 1.96 * std).
_Z_95 = 1.96


class Manager:
    """Hold the current dataset / plot data and render the prediction figure.

    Attributes:
        dataset: The most recently generated or loaded dataset.
        plots_data: Result of ``compute_plot_values`` for ``dataset``, or
            ``None`` until :meth:`compute_plot_data` has been called.
    """

    def __init__(self) -> None:
        self.dataset = Dataset(x=[], y=[])
        self.plots_data: PlotData | None = None

    @staticmethod
    def _parse_function(function: str):
        """Parse *function* with SymPy, re-raising any failure as ``ValueError``.

        NOTE(security): ``sympify`` evaluates its string argument with
        ``eval``. If ``function`` comes straight from an end user, this is
        arbitrary code execution - sanitise or restrict the input upstream.
        """
        try:
            return sympify(function)
        except Exception as exc:  # sympify can raise many exception types
            raise ValueError(f"Invalid function: {exc}") from exc

    def update_dataset(
        self,
        dataset_type: Literal["Generate", "CSV"],
        function: str,
        data_xmin: float,
        data_xmax: float,
        sigma: float,
        nsample: int,
        sample_method: Literal["Grid", "Random"],
        csv_path: str | Path | None,
        has_header: bool,
        xcol: int,
        ycol: int,
    ) -> None:
        """Replace ``self.dataset`` by generating samples or loading a CSV.

        When ``dataset_type == "Generate"``, ``function`` is parsed and
        sampled on ``[data_xmin, data_xmax]`` with ``nsample`` points, the
        given sampling method and noise ``sigma``. Any other ``dataset_type``
        loads columns ``xcol``/``ycol`` from the CSV at ``csv_path``.

        Raises:
            ValueError: if the function cannot be parsed, the sampling method
                is unknown, or no CSV path was supplied.
        """
        if dataset_type == "Generate":
            parsed_function = self._parse_function(function)
            sampling = sample_method.lower()
            if sampling not in ("grid", "random"):
                raise ValueError(f"Unknown sampling method: {sample_method}")
            self.dataset = generate_dataset(
                parsed_function,
                (data_xmin, data_xmax),
                DataGenerationOptions(
                    method=sampling,
                    num_samples=nsample,
                    noise=sigma,
                ),
            )
            return

        normalized_path = self._normalize_csv_path(csv_path)
        if normalized_path is None:
            raise ValueError("Please upload a CSV file.")
        self.dataset = load_dataset_from_csv(
            normalized_path,
            has_header,
            xcol,
            ycol,
        )

    def compute_plot_data(
        self,
        kernel: str,
        distribution: Literal["Prior", "Posterior"],
        plot_xmin: float,
        plot_xmax: float,
    ) -> None:
        """Compute plot values for ``self.dataset`` and store them in ``plots_data``."""
        self.plots_data = compute_plot_values(
            self.dataset,
            kernel,
            distribution,
            plot_xmin,
            plot_xmax,
        )

    def handle_generate_plots(
        self,
        dataset_type: Literal["Generate", "CSV"],
        function: str,
        data_xmin: float,
        data_xmax: float,
        sigma: float,
        nsample: int,
        sample_method: Literal["Grid", "Random"],
        csv_path: str | Path | None,
        has_header: bool,
        xcol: int,
        ycol: int,
        kernel: str,
        distribution: Literal["Prior", "Posterior"],
        plot_xmin: float,
        plot_xmax: float,
    ) -> Figure:
        """Run the full pipeline: update data, compute predictions, plot.

        Returns:
            The rendered matplotlib ``Figure``.

        Raises:
            ValueError: propagated from :meth:`update_dataset` or
                :meth:`_build_true_dataset`.
        """
        self.update_dataset(
            dataset_type,
            function,
            data_xmin,
            data_xmax,
            sigma,
            nsample,
            sample_method,
            csv_path,
            has_header,
            xcol,
            ycol,
        )
        true_dataset = self._build_true_dataset(
            dataset_type,
            function,
            plot_xmin,
            plot_xmax,
        )
        self.compute_plot_data(
            kernel,
            distribution,
            plot_xmin,
            plot_xmax,
        )
        return self.generate_plot(true_dataset)

    def generate_plot(self, true_dataset: Dataset) -> Figure:
        """Draw data points, optional true curve, mean and 95% band.

        Args:
            true_dataset: Noise-free samples of the true function; skipped
                when its ``y`` is ``None`` or empty.

        Raises:
            ValueError: if :meth:`compute_plot_data` has not been run.
        """
        if self.plots_data is None:
            raise ValueError("Plot data has not been computed.")

        # Use the object-oriented API rather than plt.subplots(): pyplot keeps
        # every figure it creates alive in a global registry until
        # plt.close() is called, so repeated calls would leak figures. A bare
        # Figure is freed normally once the caller drops it.
        fig = Figure(figsize=(12, 9))
        ax = fig.subplots()
        cmap = plt.get_cmap("tab20")

        ax.scatter(self.dataset.x, self.dataset.y, color=cmap(0), label="Data Points")

        if true_dataset.y is not None and len(true_dataset.y) > 0:
            ax.plot(true_dataset.x, true_dataset.y, color=cmap(1), label="True Function")

        pred_mean = self.plots_data.pred_mean
        pred_std = self.plots_data.pred_std
        ax.plot(self.plots_data.x, pred_mean, color=cmap(2), label="Mean Prediction")
        # Assumes pred_mean / pred_std support elementwise arithmetic
        # (e.g. NumPy arrays) - TODO confirm against compute_plot_values.
        ax.fill_between(
            self.plots_data.x,
            pred_mean - _Z_95 * pred_std,
            pred_mean + _Z_95 * pred_std,
            color=cmap(3),
            alpha=0.2,
            label="95% Confidence Interval",
        )

        ax.legend()
        return fig

    def _build_true_dataset(
        self,
        dataset_type: Literal["Generate", "CSV"],
        function: str,
        xmin: float,
        xmax: float,
    ) -> Dataset:
        """Sample the noise-free true function on a dense grid over the plot range.

        Only generated datasets have a known true function; for anything else
        an empty ``Dataset`` is returned. The ``"Generate"`` check mirrors
        :meth:`update_dataset` so both agree on an unexpected ``dataset_type``.

        Raises:
            ValueError: if ``function`` cannot be parsed.
        """
        if dataset_type != "Generate":
            return Dataset(x=[], y=[])

        parsed_function = self._parse_function(function)
        return generate_dataset(
            parsed_function,
            (xmin, xmax),
            DataGenerationOptions(
                method="grid",
                num_samples=_TRUE_FUNCTION_SAMPLES,
                noise=0.0,
            ),
        )

    def _normalize_csv_path(self, csv_path: str | Path | None) -> str | None:
        """Convert an uploaded-file value to a filesystem path string.

        Accepts ``None``, a ``str``, a ``Path``, or any object exposing a
        truthy ``.name`` attribute (e.g. an upload/tempfile wrapper).
        Returns ``None`` when no usable path is present, including an empty
        string.
        """
        if csv_path is None:
            return None
        if isinstance(csv_path, Path):
            return str(csv_path)
        if isinstance(csv_path, str):
            # An empty string means "nothing uploaded", not a real path.
            return csv_path or None
        name = getattr(csv_path, "name", None)
        return str(name) if name else None