Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import matplotlib.lines as mlines | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| from matplotlib.figure import Figure | |
| from sympy import sympify, symbols, sin, cos, exp | |
| from sympy.parsing.sympy_parser import ( | |
| standard_transformations, | |
| implicit_multiplication_application, | |
| parse_expr, | |
| ) | |
| from logic import ( | |
| DataGenerationOptions, | |
| Dataset, | |
| PlotsData, | |
| compute_plot_values, | |
| compute_suggested_settings, | |
| generate_dataset, | |
| load_dataset_from_csv, | |
| ) | |
class Manager:
    """Own the current dataset and plot data and render the app's figures.

    ``handle_generate_plots`` is the single entry point: it rebuilds the
    dataset, recomputes the plot data and returns three figures. It also
    returns ``self`` — presumably so a UI framework can keep the manager as
    session state between calls (confirm against the caller).
    """

    def __init__(self, dataset: Dataset | None = None, plots_data: PlotsData | None = None):
        """Create a manager, optionally pre-populated with a dataset and plot data."""
        self.dataset = dataset
        self.plots_data = plots_data

    def update_dataset(
        self,
        dataset_type: str,
        function: str,
        x1_range_input: str,
        x2_range_input: str,
        x_selection_method: str,
        sigma: float,
        nsample: int,
        csv_file: str,
        has_header: bool,
        x1_col: int,
        x2_col: int,
        y_col: int,
    ) -> None:
        """Build a new dataset from the inputs and store it in ``self.dataset``.

        All arguments are forwarded unchanged to :meth:`_compute_dataset`;
        see there for their meaning.

        Raises:
            ValueError: for invalid inputs (see :meth:`_compute_dataset`) or
                when the resulting dataset has fewer than two points.
                ``self.dataset`` is only replaced after every check passes.
        """
        dataset = self._compute_dataset(
            dataset_type,
            function,
            x1_range_input,
            x2_range_input,
            x_selection_method,
            sigma,
            nsample,
            csv_file,
            has_header,
            x1_col,
            x2_col,
            y_col,
        )
        n_points = len(dataset.x1)
        if n_points == 0:
            raise ValueError("Dataset cannot be empty")
        if n_points == 1:
            # TODO: remove this condition after fixing weird cases
            raise ValueError("Dataset must contain at least 2 points")
        self.dataset = dataset

    def _compute_dataset(
        self,
        dataset_type: str,
        function: str,
        x1_range_input: str,
        x2_range_input: str,
        x_selection_method: str,
        sigma: float,
        nsample: int,
        csv_file: str,
        has_header: bool,
        x1_col: int,
        x2_col: int,
        y_col: int,
    ) -> Dataset:
        """Create a dataset either from a formula or from a CSV file.

        Args:
            dataset_type: ``"Generate"`` to sample ``function`` via
                ``generate_dataset``, or ``"CSV"`` to load ``csv_file``.
            function: formula in ``x1`` and ``x2`` ("Generate" only).
            x1_range_input, x2_range_input: ``"low, high"`` strings
                ("Generate" only).
            x_selection_method: ``"grid"`` or ``"random"``, case-insensitive
                ("Generate" only).
            sigma, nsample: passed to ``DataGenerationOptions`` as float/int
                ("Generate" only).
            csv_file, has_header, x1_col, x2_col, y_col: passed to
                ``load_dataset_from_csv`` ("CSV" only).

        Raises:
            ValueError: on invalid input, an unknown ``dataset_type``, or any
                failure while loading the CSV file. Errors raised inside
                ``generate_dataset`` propagate unchanged.
        """
        if dataset_type == "Generate":
            parsed_function = self._parse_function(function)
            if not x1_range_input.strip():
                raise ValueError("x1 range cannot be empty")
            if not x2_range_input.strip():
                raise ValueError("x2 range cannot be empty")
            try:
                x1_range = self._parse_range(x1_range_input)
            except ValueError as e:
                raise ValueError(f"Invalid x1 range: {e}") from e
            try:
                x2_range = self._parse_range(x2_range_input)
            except ValueError as e:
                raise ValueError(f"Invalid x2 range: {e}") from e
            method = x_selection_method.lower()
            if method not in ("grid", "random"):
                raise ValueError(f"Invalid x_selection_method: {x_selection_method}")
            return generate_dataset(
                parsed_function,
                x1_range,
                x2_range,
                DataGenerationOptions(method, int(nsample), float(sigma)),
            )

        if dataset_type == "CSV":
            csv_path = self._resolve_csv_path(csv_file)
            try:
                return load_dataset_from_csv(
                    csv_path,
                    bool(has_header),
                    int(x1_col),
                    int(x2_col),
                    int(y_col),
                )
            except Exception as e:
                # The loader is opaque here; surface every failure as a ValueError
                # so callers only need to handle one exception type.
                raise ValueError(f"Failed to load dataset from CSV: {e}") from e

        raise ValueError(f"Invalid dataset_type: {dataset_type}")

    @staticmethod
    def _parse_function(function: str):
        """Parse a user formula in ``x1``/``x2`` into a SymPy expression.

        Raises:
            ValueError: if the formula is empty or cannot be parsed, or if it
                references symbols other than ``x1``/``x2`` or functions
                SymPy does not define.
        """
        if not function.strip():
            raise ValueError("Function cannot be empty")
        from sympy.core.function import AppliedUndef

        x1, x2 = symbols("x1 x2")
        allowed_locals = {
            "x1": x1,
            "x2": x2,
            "sin": sin,
            "cos": cos,
            "exp": exp,
        }
        # SECURITY(review): parse_expr evaluates its input with eval(), and SymPy's
        # documentation warns against passing unsanitized input. `function` is
        # user-supplied here — sanitize or sandbox before exposing this publicly.
        try:
            parsed_function = parse_expr(
                function,
                local_dict=allowed_locals,
                transformations=standard_transformations + (implicit_multiplication_application,),
                evaluate=True,
            )
        except Exception as e:
            raise ValueError(f"Invalid function: {e}") from e

        unknown_symbols = parsed_function.free_symbols - {x1, x2}
        if unknown_symbols:
            unknown_names = ", ".join(sorted(str(s) for s in unknown_symbols))
            raise ValueError(f"Unknown variable(s): {unknown_names}. Allowed: x1, x2")

        # An unknown name followed by "(" parses as an undefined Function whose
        # free symbols are only its arguments, so the check above misses it.
        undefined_functions = parsed_function.atoms(AppliedUndef)
        if undefined_functions:
            names = ", ".join(sorted({str(f.func) for f in undefined_functions}))
            raise ValueError(f"Unknown function(s): {names}")
        return parsed_function

    def compute_plots_data(
        self,
        loss_type: str,
        regularizer_type: str,
        resolution: int,
    ) -> None:
        """Compute the plot data for the current dataset into ``self.plots_data``.

        The weight ranges and regularization levels come from
        ``compute_suggested_settings(self.dataset)``.

        Args:
            loss_type: ``"l1"`` or ``"l2"``.
            regularizer_type: ``"l1"`` or ``"l2"``.
            resolution: passed to ``compute_plot_values`` as an int.

        Raises:
            ValueError: if no dataset is loaded or a type name is not supported.
        """
        if self.dataset is None:
            raise ValueError("Dataset is not initialized")
        if loss_type not in ("l1", "l2"):
            raise ValueError(f"Invalid loss_type: {loss_type}")
        if regularizer_type not in ("l1", "l2"):
            raise ValueError(f"Invalid regularizer_type: {regularizer_type}")
        w1_range, w2_range, reg_levels = compute_suggested_settings(self.dataset)
        self.plots_data = compute_plot_values(
            self.dataset,
            loss_type,
            regularizer_type,
            reg_levels,
            w1_range,
            w2_range,
            int(resolution),
        )

    def handle_generate_plots(
        self,
        dataset_type: str,
        function: str,
        x1_range_input: str,
        x2_range_input: str,
        x_selection_method: str,
        sigma: float,
        nsample: int,
        csv_file: str,
        has_header: bool,
        x1_col: int,
        x2_col: int,
        y_col: int,
        loss_type: str,
        regularizer_type: str,
        resolution: int,
    ) -> tuple[Manager, Figure, Figure, Figure]:
        """Rebuild the dataset and plot data, then render every figure.

        Returns:
            ``(self, contour_plot, data_plot, strength_plot)``.

        Raises:
            ValueError: propagated from :meth:`update_dataset` or
                :meth:`compute_plots_data`.
        """
        self.update_dataset(
            dataset_type,
            function,
            x1_range_input,
            x2_range_input,
            x_selection_method,
            sigma,
            nsample,
            csv_file,
            has_header,
            x1_col,
            x2_col,
            y_col,
        )
        self.compute_plots_data(
            loss_type,
            regularizer_type,
            resolution,
        )
        # Defensive: both calls above assign these on success.
        if self.dataset is None or self.plots_data is None:
            raise ValueError("Failed to generate plot data")
        contour_plot = self._generate_contour_plot(self.plots_data)
        data_plot = self._generate_data_plot(self.dataset)
        strength_plot = self._generate_strength_plot(self.plots_data.path)
        return self, contour_plot, data_plot, strength_plot

    # Figures are built with matplotlib.figure.Figure rather than pyplot:
    # pyplot keeps every figure it creates alive in a global registry until
    # plt.close() is called, so per-request figures would accumulate forever.

    @staticmethod
    def _generate_contour_plot(plots_data: PlotsData) -> Figure:
        """Draw regularizer (dashed) and loss (solid) contours over the w1/w2 grid.

        Also draws the unregularized solution (blue) and the regularization
        path (red), with a hand-built legend describing each element.
        """
        fig = Figure(figsize=(8, 8))
        ax = fig.subplots()
        ax.set_xlabel("w1")
        ax.set_ylabel("w2")

        cmap = plt.get_cmap("viridis")
        n_levels = len(plots_data.reg_levels)
        if n_levels == 1:
            colors = [cmap(0.5)]
        else:
            colors = [cmap(i / (n_levels - 1)) for i in range(n_levels)]

        reg_contours = ax.contour(
            plots_data.W1,
            plots_data.W2,
            plots_data.norms,
            levels=plots_data.reg_levels,
            colors=colors,
            linestyles="dashed",
        )
        ax.clabel(reg_contours, inline=True, fontsize=8)
        # Loss contours reuse the same palette in reverse order.
        loss_contours = ax.contour(
            plots_data.W1,
            plots_data.W2,
            plots_data.loss_values,
            levels=plots_data.loss_levels,
            colors=colors[::-1],
        )
        ax.clabel(loss_contours, inline=True, fontsize=8)

        unreg = plots_data.unreg_solution
        if unreg.ndim == 1:
            # A single (w1, w2) point: mark it with a cross.
            ax.plot(unreg[0], unreg[1], "bx", markersize=5, label="unregularized solution")
            unreg_handle = mlines.Line2D(
                [], [], color="blue", marker="x", linestyle="None", label="unregularized solution"
            )
        else:
            # Several (w1, w2) rows — presumably a set of equally good solutions; draw as a line.
            ax.plot(unreg[:, 0], unreg[:, 1], "b-", label="unregularized solution")
            unreg_handle = mlines.Line2D([], [], color="blue", linestyle="-", label="unregularized solution")

        ax.plot(plots_data.path[:, 0], plots_data.path[:, 1], "r-", label="regularization path")

        # Contour sets have no legend entries of their own, so describe the
        # line styles explicitly.
        handles = [
            mlines.Line2D([], [], color="black", linestyle="-", label="loss"),
            mlines.Line2D([], [], color="black", linestyle="--", label="regularization"),
            mlines.Line2D([], [], color="red", linestyle="-", label="regularization path"),
            unreg_handle,
        ]
        ax.legend(handles=handles)
        ax.grid(True)
        return fig

    @staticmethod
    def _generate_data_plot(dataset: Dataset) -> Figure:
        """Scatter the samples in the x1/x2 plane, coloured by ``y``, with a colorbar."""
        fig = Figure(figsize=(8, 8))
        ax = fig.subplots()
        ax.set_xlabel("x1")
        ax.set_ylabel("x2")
        scatter = ax.scatter(dataset.x1, dataset.x2, c=dataset.y, cmap="viridis")
        ax.grid(True)
        fig.colorbar(scatter, ax=ax)
        return fig

    @staticmethod
    def _generate_strength_plot(path: np.ndarray) -> Figure:
        """Plot w1 (red) and w2 (blue) along the path against strength on a log axis.

        ``path`` is indexed as ``path[:, 0]`` / ``path[:, 1]``, i.e. one row of
        (w1, w2) per path step.
        """
        # NOTE(review): the strengths are assumed to be logspace(-4, 4) with one
        # value per path row; this must match how compute_plot_values builds
        # the path — confirm.
        reg_levels = np.logspace(-4, 4, path.shape[0])
        fig = Figure(figsize=(8, 6))
        ax = fig.subplots()
        ax.set_xlabel("Regularization Strength")
        ax.set_ylabel("Weight")
        ax.plot(reg_levels, path[:, 0], "r-", label="w1")
        ax.plot(reg_levels, path[:, 1], "b-", label="w2")
        ax.set_xscale("log")
        ax.legend()
        ax.grid(True)
        return fig

    @staticmethod
    def _parse_range(range_input: str) -> tuple[float, float]:
        """Parse a ``"low, high"`` string into two floats with ``low < high``.

        Raises:
            ValueError: if there are not exactly two comma-separated values,
                a value is empty, not a number, or non-finite (NaN/inf), or
                ``low >= high``.
        """
        values = tuple(x.strip() for x in range_input.split(","))
        if len(values) != 2:
            raise ValueError("Range must contain exactly two comma-separated values")
        low_text, high_text = values
        if low_text == "":
            raise ValueError("Range lower bound cannot be empty")
        if high_text == "":
            raise ValueError("Range upper bound cannot be empty")
        try:
            low = float(low_text)
            high = float(high_text)
        except ValueError as e:
            raise ValueError("Range values must be valid numbers") from e
        # float() accepts "nan" and "inf"; NaN would also slip past the ordering
        # check below because every comparison with NaN is False.
        if not (np.isfinite(low) and np.isfinite(high)):
            raise ValueError("Range values must be finite numbers")
        if low >= high:
            raise ValueError("Range lower bound must be less than upper bound")
        return low, high

    @staticmethod
    def _parse_levels(levels_input: str) -> list[float]:
        """Parse a comma-separated list of finite regularization levels into floats.

        Raises:
            ValueError: if the input is blank, contains an empty item, or any
                item is not a finite number.
        """
        values = [x.strip() for x in levels_input.split(",")]
        # str.split always returns at least one item, so a blank input shows up
        # as a list of empty strings.
        if all(x == "" for x in values):
            raise ValueError("At least one regularization level is required")
        if any(x == "" for x in values):
            raise ValueError("Regularization levels cannot contain empty values")
        try:
            levels = [float(x) for x in values]
        except ValueError as e:
            raise ValueError("Level values must be valid numbers") from e
        if not all(np.isfinite(level) for level in levels):
            raise ValueError("Level values must be finite numbers")
        return levels

    @staticmethod
    def _resolve_csv_path(csv_file: object) -> str:
        """Return a file path from a CSV input given as a path, a dict or a file object.

        Accepts a non-blank string path, a dict with a ``"name"`` key, or any
        object with a ``name`` attribute (presumably an upload/tempfile object
        from the UI — confirm).

        Raises:
            ValueError: if the input is missing/blank or of an unsupported kind.
        """
        if csv_file is None:
            raise ValueError("CSV file is required")
        if isinstance(csv_file, str):
            if not csv_file.strip():
                raise ValueError("CSV file is required")
            return csv_file
        if isinstance(csv_file, dict) and "name" in csv_file:
            return csv_file["name"]
        if hasattr(csv_file, "name"):
            return csv_file.name
        raise ValueError("Unsupported CSV file input")