Spaces:
Sleeping
Sleeping
| from pathlib import Path | |
| from typing import Literal | |
| import matplotlib.pyplot as plt | |
| from sympy import sympify | |
| from logic import ( | |
| DataGenerationOptions, | |
| Dataset, | |
| PlotData, | |
| compute_plot_values, | |
| generate_dataset, | |
| load_dataset_from_csv, | |
| ) | |
| class Manager: | |
| def __init__(self) -> None: | |
| self.dataset = Dataset(x=[], y=[]) | |
| self.plots_data: PlotData | None = None | |
| def update_dataset( | |
| self, | |
| dataset_type: Literal["Generate", "CSV"], | |
| function: str, | |
| data_xmin: float, | |
| data_xmax: float, | |
| sigma: float, | |
| nsample: int, | |
| sample_method: Literal["Grid", "Random"], | |
| csv_path: str | Path | None, | |
| has_header: bool, | |
| xcol: int, | |
| ycol: int, | |
| ) -> None: | |
| if dataset_type == "Generate": | |
| try: | |
| parsed_function = sympify(function) | |
| except Exception as exc: | |
| raise ValueError(f"Invalid function: {exc}") from exc | |
| sampling = sample_method.lower() | |
| if sampling not in ["grid", "random"]: | |
| raise ValueError(f"Unknown sampling method: {sample_method}") | |
| self.dataset = generate_dataset( | |
| parsed_function, | |
| (data_xmin, data_xmax), | |
| DataGenerationOptions( | |
| method=sampling, | |
| num_samples=nsample, | |
| noise=sigma, | |
| ), | |
| ) | |
| return | |
| normalized_path = self._normalize_csv_path(csv_path) | |
| if normalized_path is None: | |
| raise ValueError("Please upload a CSV file.") | |
| self.dataset = load_dataset_from_csv( | |
| normalized_path, | |
| has_header, | |
| xcol, | |
| ycol, | |
| ) | |
| def compute_plot_data( | |
| self, | |
| kernel: str, | |
| distribution: Literal["Prior", "Posterior"], | |
| plot_xmin: float, | |
| plot_xmax: float, | |
| ) -> None: | |
| self.plots_data = compute_plot_values( | |
| self.dataset, | |
| kernel, | |
| distribution, | |
| plot_xmin, | |
| plot_xmax, | |
| ) | |
| def handle_generate_plots( | |
| self, | |
| dataset_type: Literal["Generate", "CSV"], | |
| function: str, | |
| data_xmin: float, | |
| data_xmax: float, | |
| sigma: float, | |
| nsample: int, | |
| sample_method: Literal["Grid", "Random"], | |
| csv_path: str | Path | None, | |
| has_header: bool, | |
| xcol: int, | |
| ycol: int, | |
| kernel: str, | |
| distribution: Literal["Prior", "Posterior"], | |
| plot_xmin: float, | |
| plot_xmax: float, | |
| ): | |
| self.update_dataset( | |
| dataset_type, | |
| function, | |
| data_xmin, | |
| data_xmax, | |
| sigma, | |
| nsample, | |
| sample_method, | |
| csv_path, | |
| has_header, | |
| xcol, | |
| ycol, | |
| ) | |
| true_dataset = self._build_true_dataset( | |
| dataset_type, | |
| function, | |
| plot_xmin, | |
| plot_xmax, | |
| ) | |
| self.compute_plot_data( | |
| kernel, | |
| distribution, | |
| plot_xmin, | |
| plot_xmax, | |
| ) | |
| return self.generate_plot(true_dataset) | |
| def generate_plot(self, true_dataset: Dataset): | |
| if self.plots_data is None: | |
| raise ValueError("Plot data has not been computed.") | |
| fig, ax = plt.subplots(figsize=(12, 9)) | |
| cmap = plt.get_cmap("tab20") | |
| ax.scatter(self.dataset.x, self.dataset.y, color=cmap(0), label="Data Points") | |
| if true_dataset.y is not None and len(true_dataset.y) > 0: | |
| ax.plot(true_dataset.x, true_dataset.y, color=cmap(1), label="True Function") | |
| ax.plot(self.plots_data.x, self.plots_data.pred_mean, color=cmap(2), label="Mean Prediction") | |
| ax.fill_between( | |
| self.plots_data.x, | |
| self.plots_data.pred_mean - 1.96 * self.plots_data.pred_std, | |
| self.plots_data.pred_mean + 1.96 * self.plots_data.pred_std, | |
| color=cmap(3), | |
| alpha=0.2, | |
| label="95% Confidence Interval", | |
| ) | |
| ax.legend() | |
| return fig | |
| def _build_true_dataset( | |
| self, | |
| dataset_type: Literal["Generate", "CSV"], | |
| function: str, | |
| xmin: float, | |
| xmax: float, | |
| ) -> Dataset: | |
| if dataset_type == "CSV": | |
| return Dataset(x=[], y=[]) | |
| try: | |
| parsed_function = sympify(function) | |
| except Exception as exc: | |
| raise ValueError(f"Invalid function: {exc}") from exc | |
| return generate_dataset( | |
| parsed_function, | |
| (xmin, xmax), | |
| DataGenerationOptions( | |
| method="grid", | |
| num_samples=1000, | |
| noise=0.0, | |
| ), | |
| ) | |
| def _normalize_csv_path(self, csv_path: str | Path | None) -> str | None: | |
| if csv_path is None: | |
| return None | |
| if isinstance(csv_path, Path): | |
| return str(csv_path) | |
| if isinstance(csv_path, str): | |
| return csv_path | |
| name = getattr(csv_path, "name", None) | |
| if name: | |
| return str(name) | |
| return None | |