"""Orchestration layer that turns raw user inputs into datasets and figures.

``Manager`` validates string/number inputs, builds a ``Dataset`` (sampled from
a symbolic function or loaded from a CSV file), computes ``PlotsData`` via the
``logic`` module, and renders the resulting matplotlib figures.
"""

from __future__ import annotations

import math

import matplotlib.lines as mlines
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.figure import Figure
from sympy import sympify, symbols, sin, cos, exp
from sympy.parsing.sympy_parser import (
    standard_transformations,
    implicit_multiplication_application,
    parse_expr,
)

from logic import (
    DataGenerationOptions,
    Dataset,
    PlotsData,
    compute_plot_values,
    compute_suggested_settings,
    generate_dataset,
    load_dataset_from_csv,
)


class Manager:
    """Hold the current dataset and plot data, and build figures from inputs.

    Attributes:
        dataset: The most recently validated dataset, or ``None``.
        plots_data: The most recently computed plot data, or ``None``.
    """

    def __init__(self, dataset: Dataset | None = None, plots_data: PlotsData | None = None):
        self.dataset = dataset
        self.plots_data = plots_data

    def update_dataset(
        self,
        dataset_type: str,
        function: str,
        x1_range_input: str,
        x2_range_input: str,
        x_selection_method: str,
        sigma: float,
        nsample: int,
        csv_file: str,
        has_header: bool,
        x1_col: int,
        x2_col: int,
        y_col: int,
    ) -> None:
        """Build a dataset from the given inputs and store it on ``self.dataset``.

        All arguments are forwarded unchanged to :meth:`_compute_dataset`.

        Raises:
            ValueError: If any input is invalid or the resulting dataset has
                fewer than two points. ``self.dataset`` is left unchanged.
        """
        dataset = self._compute_dataset(
            dataset_type,
            function,
            x1_range_input,
            x2_range_input,
            x_selection_method,
            sigma,
            nsample,
            csv_file,
            has_header,
            x1_col,
            x2_col,
            y_col,
        )
        n_points = len(dataset.x1)
        if n_points == 0:
            raise ValueError("Dataset cannot be empty")
        if n_points == 1:
            # TODO: remove this restriction once the single-point edge cases are fixed.
            raise ValueError("Dataset must contain at least 2 points")
        self.dataset = dataset

    def _compute_dataset(
        self,
        dataset_type: str,
        function: str,
        x1_range_input: str,
        x2_range_input: str,
        x_selection_method: str,
        sigma: float,
        nsample: int,
        csv_file: str,
        has_header: bool,
        x1_col: int,
        x2_col: int,
        y_col: int,
    ) -> Dataset:
        """Create a ``Dataset`` by sampling a function or by loading a CSV file.

        Args:
            dataset_type: ``"Generate"`` to sample ``function``; ``"CSV"`` to
                load ``csv_file``.
            function: Expression in ``x1``/``x2``; ``sin``, ``cos`` and ``exp``
                are available and implicit multiplication is accepted.
            x1_range_input: ``"low, high"`` bounds for ``x1``.
            x2_range_input: ``"low, high"`` bounds for ``x2``.
            x_selection_method: ``"grid"`` or ``"random"`` (case-insensitive).
            sigma: Forwarded to ``DataGenerationOptions`` as a float.
            nsample: Forwarded to ``DataGenerationOptions`` as an int.
            csv_file: Path string or upload object (see :meth:`_resolve_csv_path`).
            has_header: Forwarded to ``load_dataset_from_csv`` as a bool.
            x1_col: Forwarded to ``load_dataset_from_csv`` as an int.
            x2_col: Forwarded to ``load_dataset_from_csv`` as an int.
            y_col: Forwarded to ``load_dataset_from_csv`` as an int.

        Raises:
            ValueError: On any invalid input or loader failure.
        """
        if dataset_type == "Generate":
            parsed_function = self._parse_function(function)

            if not x1_range_input.strip():
                raise ValueError("x1 range cannot be empty")
            if not x2_range_input.strip():
                raise ValueError("x2 range cannot be empty")
            try:
                x1_range = self._parse_range(x1_range_input)
            except ValueError as e:
                raise ValueError(f"Invalid x1 range: {e}") from e
            try:
                x2_range = self._parse_range(x2_range_input)
            except ValueError as e:
                raise ValueError(f"Invalid x2 range: {e}") from e

            method = x_selection_method.lower()
            if method not in ("grid", "random"):
                raise ValueError(f"Invalid x_selection_method: {x_selection_method}")

            return generate_dataset(
                parsed_function,
                x1_range,
                x2_range,
                DataGenerationOptions(method, int(nsample), float(sigma)),
            )

        if dataset_type == "CSV":
            csv_path = self._resolve_csv_path(csv_file)
            try:
                return load_dataset_from_csv(
                    csv_path,
                    bool(has_header),
                    int(x1_col),
                    int(x2_col),
                    int(y_col),
                )
            except Exception as e:
                # The loader's failure modes are not visible here, so wrap
                # everything into the ValueError contract of this method.
                raise ValueError(f"Failed to load dataset from CSV: {e}") from e

        raise ValueError(f"Invalid dataset_type: {dataset_type}")

    @staticmethod
    def _parse_function(function: str):
        """Parse ``function`` into a SymPy expression over ``x1`` and ``x2``.

        Raises:
            ValueError: If the text is blank, cannot be parsed, or references
                symbols other than ``x1`` and ``x2``.
        """
        x1, x2 = symbols("x1 x2")
        allowed_locals = {
            "x1": x1,
            "x2": x2,
            "sin": sin,
            "cos": cos,
            "exp": exp,
        }
        if not function.strip():
            raise ValueError("Function cannot be empty")
        try:
            # SECURITY: parse_expr evaluates its input with eval(); restricting
            # local_dict does NOT sandbox it. Do not expose this to untrusted
            # users without additional input sanitisation.
            parsed_function = parse_expr(
                function,
                local_dict=allowed_locals,
                transformations=standard_transformations + (implicit_multiplication_application,),
                evaluate=True,
            )
        except Exception as e:
            # parse_expr can raise many types (SyntaxError, TokenError, TypeError, ...).
            raise ValueError(f"Invalid function: {e}") from e

        unknown_symbols = parsed_function.free_symbols - {x1, x2}
        if unknown_symbols:
            unknown_names = ", ".join(sorted(str(s) for s in unknown_symbols))
            raise ValueError(f"Unknown variable(s): {unknown_names}. Allowed: x1, x2")
        return parsed_function

    def compute_plots_data(
        self,
        loss_type: str,
        regularizer_type: str,
        resolution: int,
    ) -> None:
        """Compute plot data for the current dataset and store it on ``self.plots_data``.

        Weight ranges and regularization levels come from
        ``compute_suggested_settings(self.dataset)``.

        Args:
            loss_type: ``"l1"`` or ``"l2"``.
            regularizer_type: ``"l1"`` or ``"l2"``.
            resolution: Forwarded to ``compute_plot_values`` as an int.

        Raises:
            ValueError: If no dataset is loaded or a type is not recognised.
        """
        if self.dataset is None:
            raise ValueError("Dataset is not initialized")
        if loss_type not in ("l1", "l2"):
            raise ValueError(f"Invalid loss_type: {loss_type}")
        if regularizer_type not in ("l1", "l2"):
            raise ValueError(f"Invalid regularizer_type: {regularizer_type}")

        w1_range, w2_range, reg_levels = compute_suggested_settings(self.dataset)
        self.plots_data = compute_plot_values(
            self.dataset,
            loss_type,
            regularizer_type,
            reg_levels,
            w1_range,
            w2_range,
            int(resolution),
        )

    def handle_generate_plots(
        self,
        dataset_type: str,
        function: str,
        x1_range_input: str,
        x2_range_input: str,
        x_selection_method: str,
        sigma: float,
        nsample: int,
        csv_file: str,
        has_header: bool,
        x1_col: int,
        x2_col: int,
        y_col: int,
        loss_type: str,
        regularizer_type: str,
        resolution: int,
    ) -> tuple[Manager, Figure, Figure, Figure]:
        """Rebuild the dataset and plot data, then render all figures.

        Returns:
            ``(self, contour_figure, data_figure, strength_figure)``.

        Raises:
            ValueError: Propagated from validation, or if data is missing.
        """
        self.update_dataset(
            dataset_type,
            function,
            x1_range_input,
            x2_range_input,
            x_selection_method,
            sigma,
            nsample,
            csv_file,
            has_header,
            x1_col,
            x2_col,
            y_col,
        )
        self.compute_plots_data(
            loss_type,
            regularizer_type,
            resolution,
        )
        if self.dataset is None or self.plots_data is None:
            raise ValueError("Failed to generate plot data")

        contour_plot = self._generate_contour_plot(self.plots_data)
        data_plot = self._generate_data_plot(self.dataset)
        strength_plot = self._generate_strength_plot(self.plots_data.path)
        return self, contour_plot, data_plot, strength_plot

    @staticmethod
    def _generate_contour_plot(plots_data: PlotsData) -> Figure:
        """Draw regularizer and loss contours together with the solution paths.

        Dashed contours show ``plots_data.norms`` at ``reg_levels``. Solid
        contours show ``loss_values`` at ``loss_levels``, using the same colors
        in reverse order. The unregularized solution is drawn as a marker when
        it is 1-D (a single point) and as a line otherwise.
        """
        # Build the figure without pyplot: figures created via plt.subplots are
        # kept alive by pyplot's global registry and are never closed here.
        fig = Figure(figsize=(8, 8))
        ax = fig.subplots()
        ax.set_xlabel("w1")
        ax.set_ylabel("w2")

        cmap = plt.get_cmap("viridis")
        n_levels = len(plots_data.reg_levels)
        if n_levels == 1:
            colors = [cmap(0.5)]
        else:
            colors = [cmap(i / (n_levels - 1)) for i in range(n_levels)]

        cs1 = ax.contour(
            plots_data.W1,
            plots_data.W2,
            plots_data.norms,
            levels=plots_data.reg_levels,
            colors=colors,
            linestyles="dashed",
        )
        ax.clabel(cs1, inline=True, fontsize=8)

        cs2 = ax.contour(
            plots_data.W1,
            plots_data.W2,
            plots_data.loss_values,
            levels=plots_data.loss_levels,
            colors=colors[::-1],
        )
        ax.clabel(cs2, inline=True, fontsize=8)

        unreg = plots_data.unreg_solution
        if unreg.ndim == 1:
            ax.plot(unreg[0], unreg[1], "bx", markersize=5, label="unregularized solution")
            unreg_handle = mlines.Line2D(
                [], [], color="blue", marker="x", linestyle="None", label="unregularized solution"
            )
        else:
            ax.plot(unreg[:, 0], unreg[:, 1], "b-", label="unregularized solution")
            unreg_handle = mlines.Line2D(
                [], [], color="blue", linestyle="-", label="unregularized solution"
            )

        ax.plot(plots_data.path[:, 0], plots_data.path[:, 1], "r-", label="regularization path")

        # Proxy artists: the contour colors vary per level, so the legend shows
        # neutral black entries that describe only the line style.
        handles = [
            mlines.Line2D([], [], color="black", linestyle="-", label="loss"),
            mlines.Line2D([], [], color="black", linestyle="--", label="regularization"),
            mlines.Line2D([], [], color="red", linestyle="-", label="regularization path"),
            unreg_handle,
        ]
        ax.legend(handles=handles)
        ax.grid(True)
        return fig

    @staticmethod
    def _generate_data_plot(dataset: Dataset) -> Figure:
        """Scatter ``(x1, x2)`` points colored by ``y``, with a colorbar."""
        # Pyplot-free figure so repeated renders do not accumulate open figures.
        fig = Figure(figsize=(8, 8))
        ax = fig.subplots()
        ax.set_xlabel("x1")
        ax.set_ylabel("x2")
        scatter = ax.scatter(dataset.x1, dataset.x2, c=dataset.y, cmap="viridis")
        ax.grid(True)
        fig.colorbar(scatter, ax=ax)
        return fig

    @staticmethod
    def _generate_strength_plot(
        path: np.ndarray,
        reg_strengths: np.ndarray | None = None,
    ) -> Figure:
        """Plot each weight along the regularization path against its strength.

        Args:
            path: 2-D array; columns 0 and 1 are plotted as ``w1`` and ``w2``.
            reg_strengths: Optional x-values, one per row of ``path``. Defaults
                to ``np.logspace(-4, 4, path.shape[0])``.
                NOTE(review): the default assumes the path was computed on that
                log-spaced grid; confirm against ``compute_plot_values``.

        Raises:
            ValueError: If ``reg_strengths`` does not have one value per row of
                ``path``.
        """
        if reg_strengths is None:
            reg_strengths = np.logspace(-4, 4, path.shape[0])
        else:
            reg_strengths = np.asarray(reg_strengths, dtype=float)
            if reg_strengths.shape != (path.shape[0],):
                raise ValueError("reg_strengths must have one value per row of path")

        # Pyplot-free figure so repeated renders do not accumulate open figures.
        fig = Figure(figsize=(8, 6))
        ax = fig.subplots()
        ax.set_xlabel("Regularization Strength")
        ax.set_ylabel("Weight")
        ax.plot(reg_strengths, path[:, 0], "r-", label="w1")
        ax.plot(reg_strengths, path[:, 1], "b-", label="w2")
        ax.set_xscale("log")
        ax.legend()
        ax.grid(True)
        return fig

    @staticmethod
    def _parse_range(range_input: str) -> tuple[float, float]:
        """Parse ``"low, high"`` into a pair of finite floats with ``low < high``.

        Raises:
            ValueError: If there are not exactly two values, a value is empty,
                not numeric, or not finite, or ``low >= high``.
        """
        parts = [x.strip() for x in range_input.split(",")]
        if len(parts) != 2:
            raise ValueError("Range must contain exactly two comma-separated values")
        low_text, high_text = parts
        if low_text == "":
            raise ValueError("Range lower bound cannot be empty")
        if high_text == "":
            raise ValueError("Range upper bound cannot be empty")
        try:
            low = float(low_text)
            high = float(high_text)
        except ValueError as e:
            raise ValueError("Range values must be valid numbers") from e
        # float() accepts "nan"/"inf"; NaN would also slip past the ordering
        # check below because every comparison with NaN is False.
        if not (math.isfinite(low) and math.isfinite(high)):
            raise ValueError("Range values must be finite numbers")
        if low >= high:
            raise ValueError("Range lower bound must be less than upper bound")
        return low, high

    @staticmethod
    def _parse_levels(levels_input: str) -> list[float]:
        """Parse a comma-separated list of finite floats.

        Raises:
            ValueError: If the list is blank, contains an empty entry, or
                contains a non-numeric or non-finite value.
        """
        values = [x.strip() for x in levels_input.split(",")]
        # str.split always yields at least one item, so only blanks matter here.
        if all(x == "" for x in values):
            raise ValueError("At least one regularization level is required")
        if any(x == "" for x in values):
            raise ValueError("Regularization levels cannot contain empty values")
        try:
            levels = [float(x) for x in values]
        except ValueError as e:
            raise ValueError("Level values must be valid numbers") from e
        if not all(math.isfinite(level) for level in levels):
            raise ValueError("Level values must be finite numbers")
        return levels

    @staticmethod
    def _resolve_csv_path(csv_file: object) -> str:
        """Return a filesystem path from a CSV input.

        Accepts a non-blank path string, a dict with a ``"name"`` key, or any
        object with a ``name`` attribute (e.g. an uploaded/temporary file).

        Raises:
            ValueError: If the input is missing, blank, or unsupported.
        """
        if csv_file is None:
            raise ValueError("CSV file is required")
        if isinstance(csv_file, str):
            if not csv_file.strip():
                raise ValueError("CSV file is required")
            return csv_file
        if isinstance(csv_file, dict) and "name" in csv_file:
            return csv_file["name"]
        if hasattr(csv_file, "name"):
            return csv_file.name
        raise ValueError("Unsupported CSV file input")