# regularization/backend/src/manager.py
# Last commit 19aa7d3 (joel-woodfield): "Automatically set the suggested settings"
from __future__ import annotations
import matplotlib.lines as mlines
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.figure import Figure
from sympy import sympify, symbols, sin, cos, exp
from sympy.parsing.sympy_parser import (
standard_transformations,
implicit_multiplication_application,
parse_expr,
)
from logic import (
DataGenerationOptions,
Dataset,
PlotsData,
compute_plot_values,
compute_suggested_settings,
generate_dataset,
load_dataset_from_csv,
)
class Manager:
    """Coordinate dataset creation, plot-data computation and figure rendering.

    The instance holds the current ``dataset`` and the most recently computed
    ``plots_data``. ``handle_generate_plots`` returns ``self`` as the first
    element of its result so a UI layer can carry the instance along as state.
    """

    def __init__(self, dataset: Dataset | None = None, plots_data: PlotsData | None = None):
        # Both start unset unless supplied; ``update_dataset`` fills ``dataset``
        # and ``compute_plots_data`` fills ``plots_data``.
        self.dataset = dataset
        self.plots_data = plots_data

    def update_dataset(
        self,
        dataset_type: str,
        function: str,
        x1_range_input: str,
        x2_range_input: str,
        x_selection_method: str,
        sigma: float,
        nsample: int,
        csv_file: str,
        has_header: bool,
        x1_col: int,
        x2_col: int,
        y_col: int,
    ) -> None:
        """Build a dataset from the given inputs and store it on ``self.dataset``.

        See ``_compute_dataset`` for how each argument is used. ``csv_file``
        may be a path string or an object exposing a ``name`` (see
        ``_resolve_csv_path``).

        Raises:
            ValueError: If any input is invalid or the resulting dataset has
                fewer than two points. ``self.dataset`` is left unchanged in
                that case, because it is only assigned after all checks pass.
        """
        dataset = self._compute_dataset(
            dataset_type,
            function,
            x1_range_input,
            x2_range_input,
            x_selection_method,
            sigma,
            nsample,
            csv_file,
            has_header,
            x1_col,
            x2_col,
            y_col,
        )
        n_points = len(dataset.x1)
        if n_points == 0:
            raise ValueError("Dataset cannot be empty")
        if n_points == 1:
            # TODO: remove this condition after fixing weird cases
            raise ValueError("Dataset must contain at least 2 points")
        self.dataset = dataset

    def _compute_dataset(
        self,
        dataset_type: str,
        function: str,
        x1_range_input: str,
        x2_range_input: str,
        x_selection_method: str,
        sigma: float,
        nsample: int,
        csv_file: str,
        has_header: bool,
        x1_col: int,
        x2_col: int,
        y_col: int,
    ) -> Dataset:
        """Create a new ``Dataset`` without storing it.

        ``"Generate"``: parse ``function`` (in x1, x2) and the two
        ``"low, high"`` ranges, then call ``generate_dataset`` with
        ``DataGenerationOptions(method, nsample, sigma)``.
        ``x_selection_method`` must be ``"grid"`` or ``"random"``
        (case-insensitive).

        ``"CSV"``: load columns ``x1_col``/``x2_col``/``y_col`` from
        ``csv_file``.

        Arguments belonging to the other source are ignored.

        Raises:
            ValueError: On any invalid input, unknown ``dataset_type``, or CSV
                load failure.
        """
        if dataset_type == "Generate":
            parsed_function = self._parse_function(function)

            # Both emptiness checks run before either range is parsed, so an
            # empty x2 range is reported even when the x1 range is malformed.
            if not x1_range_input.strip():
                raise ValueError("x1 range cannot be empty")
            if not x2_range_input.strip():
                raise ValueError("x2 range cannot be empty")
            try:
                x1_range = self._parse_range(x1_range_input)
            except Exception as e:
                raise ValueError(f"Invalid x1 range: {e}") from e
            try:
                x2_range = self._parse_range(x2_range_input)
            except Exception as e:
                raise ValueError(f"Invalid x2 range: {e}") from e

            method = x_selection_method.lower()
            if method not in ("grid", "random"):
                raise ValueError(f"Invalid x_selection_method: {x_selection_method}")

            return generate_dataset(
                parsed_function,
                x1_range,
                x2_range,
                DataGenerationOptions(method, int(nsample), float(sigma)),
            )

        if dataset_type == "CSV":
            csv_path = self._resolve_csv_path(csv_file)
            try:
                return load_dataset_from_csv(
                    csv_path,
                    bool(has_header),
                    int(x1_col),
                    int(x2_col),
                    int(y_col),
                )
            except Exception as e:
                raise ValueError(f"Failed to load dataset from CSV: {e}") from e

        raise ValueError(f"Invalid dataset_type: {dataset_type}")

    @staticmethod
    def _parse_function(function: str):
        """Parse ``function`` into a SymPy expression in the symbols x1 and x2.

        Implicit multiplication is enabled (e.g. ``"2x1"``). Only ``sin``,
        ``cos`` and ``exp`` are pre-bound by name besides x1 and x2.

        Raises:
            ValueError: If the text is blank, cannot be parsed, or the result
                contains free symbols other than x1 and x2.
        """
        x1, x2 = symbols("x1 x2")
        allowed_locals = {
            "x1": x1,
            "x2": x2,
            "sin": sin,
            "cos": cos,
            "exp": exp,
        }
        if not function.strip():
            raise ValueError("Function cannot be empty")
        # NOTE(security): SymPy documents that parse_expr uses eval() and must
        # not be given unsanitized input; restricting local_dict does not make
        # it safe. If `function` comes from untrusted users, this needs a
        # sandbox or a dedicated expression parser.
        try:
            parsed_function = parse_expr(
                function,
                local_dict=allowed_locals,
                transformations=standard_transformations + (implicit_multiplication_application,),
                evaluate=True,
            )
        except Exception as e:
            raise ValueError(f"Invalid function: {e}") from e

        unknown_symbols = parsed_function.free_symbols - {x1, x2}
        if unknown_symbols:
            unknown_names = ", ".join(sorted(str(s) for s in unknown_symbols))
            raise ValueError(f"Unknown variable(s): {unknown_names}. Allowed: x1, x2")
        return parsed_function

    def compute_plots_data(
        self,
        loss_type: str,
        regularizer_type: str,
        resolution: int,
    ) -> None:
        """Compute contour/path data for the current dataset into ``self.plots_data``.

        The weight-space ranges and regularizer contour levels are not user
        inputs. They come from ``compute_suggested_settings(self.dataset)``.

        Args:
            loss_type: ``"l1"`` or ``"l2"``.
            regularizer_type: ``"l1"`` or ``"l2"``.
            resolution: Grid resolution forwarded (as ``int``) to
                ``compute_plot_values``.

        Raises:
            ValueError: If no dataset is loaded or a type is not ``"l1"``/``"l2"``.
        """
        if self.dataset is None:
            raise ValueError("Dataset is not initialized")
        if loss_type not in ("l1", "l2"):
            raise ValueError(f"Invalid loss_type: {loss_type}")
        if regularizer_type not in ("l1", "l2"):
            raise ValueError(f"Invalid regularizer_type: {regularizer_type}")

        w1_range, w2_range, reg_levels = compute_suggested_settings(self.dataset)
        self.plots_data = compute_plot_values(
            self.dataset,
            loss_type,
            regularizer_type,
            reg_levels,
            w1_range,
            w2_range,
            int(resolution),
        )

    def handle_generate_plots(
        self,
        dataset_type: str,
        function: str,
        x1_range_input: str,
        x2_range_input: str,
        x_selection_method: str,
        sigma: float,
        nsample: int,
        csv_file: str,
        has_header: bool,
        x1_col: int,
        x2_col: int,
        y_col: int,
        loss_type: str,
        regularizer_type: str,
        resolution: int,
    ) -> tuple[Manager, Figure, Figure, Figure]:
        """Rebuild the dataset and plot data, then render all three figures.

        Returns:
            ``(self, contour_figure, data_figure, strength_figure)``.

        Raises:
            ValueError: Propagated from ``update_dataset`` /
                ``compute_plots_data``, or if either result is still missing.
        """
        self.update_dataset(
            dataset_type,
            function,
            x1_range_input,
            x2_range_input,
            x_selection_method,
            sigma,
            nsample,
            csv_file,
            has_header,
            x1_col,
            x2_col,
            y_col,
        )
        self.compute_plots_data(
            loss_type,
            regularizer_type,
            resolution,
        )
        # Defensive: both calls above either set these or raise.
        if self.dataset is None or self.plots_data is None:
            raise ValueError("Failed to generate plot data")

        contour_plot = self._generate_contour_plot(self.plots_data)
        data_plot = self._generate_data_plot(self.dataset)
        strength_plot = self._generate_strength_plot(self.plots_data.path)
        return self, contour_plot, data_plot, strength_plot

    @staticmethod
    def _generate_contour_plot(plots_data: PlotsData) -> Figure:
        """Draw loss and regularizer contours in (w1, w2) space with both solutions.

        - Regularizer-norm contours: dashed, at ``plots_data.reg_levels``.
        - Loss contours: solid, at ``plots_data.loss_levels``, using the same
          colors in reverse order.
        - Unregularized solution: a blue ``x`` when it is a 1-D array (one
          point), otherwise a blue polyline.
        - Regularization path: a red line.
        """
        # Build the Figure directly instead of via pyplot so it is never
        # registered in pyplot's global figure manager; otherwise each call
        # leaks a figure in a long-running server process.
        fig = Figure(figsize=(8, 8))
        ax = fig.subplots()
        ax.set_xlabel("w1")
        ax.set_ylabel("w2")

        cmap = plt.get_cmap("viridis")
        n_levels = len(plots_data.reg_levels)
        # With one level the spread formula would divide by zero; use the
        # colormap midpoint instead.
        if n_levels == 1:
            colors = [cmap(0.5)]
        else:
            colors = [cmap(i / (n_levels - 1)) for i in range(n_levels)]

        reg_contours = ax.contour(
            plots_data.W1,
            plots_data.W2,
            plots_data.norms,
            levels=plots_data.reg_levels,
            colors=colors,
            linestyles="dashed",
        )
        ax.clabel(reg_contours, inline=True, fontsize=8)

        # NOTE(review): `colors` has one entry per regularizer level; this
        # presumes loss_levels has the same length so colors line up — confirm.
        loss_contours = ax.contour(
            plots_data.W1,
            plots_data.W2,
            plots_data.loss_values,
            levels=plots_data.loss_levels,
            colors=colors[::-1],
        )
        ax.clabel(loss_contours, inline=True, fontsize=8)

        unreg = plots_data.unreg_solution
        # A 1-D array is a single (w1, w2) point; otherwise rows are points
        # along a set of solutions (presumably a non-unique minimizer).
        unreg_is_point = unreg.ndim == 1
        if unreg_is_point:
            ax.plot(unreg[0], unreg[1], "bx", markersize=5, label="unregularized solution")
            unreg_handle = mlines.Line2D(
                [], [], color="blue", marker="x", linestyle="None", label="unregularized solution"
            )
        else:
            ax.plot(unreg[:, 0], unreg[:, 1], "b-", label="unregularized solution")
            unreg_handle = mlines.Line2D([], [], color="blue", linestyle="-", label="unregularized solution")

        ax.plot(plots_data.path[:, 0], plots_data.path[:, 1], "r-", label="regularization path")

        # Proxy artists: the contours use a color gradient, so the legend shows
        # one neutral black entry per line style instead.
        handles = [
            mlines.Line2D([], [], color="black", linestyle="-", label="loss"),
            mlines.Line2D([], [], color="black", linestyle="--", label="regularization"),
            mlines.Line2D([], [], color="red", linestyle="-", label="regularization path"),
            unreg_handle,
        ]
        ax.legend(handles=handles)
        ax.grid(True)
        return fig

    @staticmethod
    def _generate_data_plot(dataset: Dataset) -> Figure:
        """Scatter the dataset in (x1, x2), colored by ``y`` with a colorbar."""
        # Pyplot-free Figure: see _generate_contour_plot for why.
        fig = Figure(figsize=(8, 8))
        ax = fig.subplots()
        ax.set_xlabel("x1")
        ax.set_ylabel("x2")
        scatter = ax.scatter(dataset.x1, dataset.x2, c=dataset.y, cmap="viridis")
        ax.grid(True)
        fig.colorbar(scatter, ax=ax)
        return fig

    @staticmethod
    def _generate_strength_plot(path: np.ndarray) -> Figure:
        """Plot w1 and w2 along the regularization path against strength (log x-axis).

        Args:
            path: Array indexed as ``path[:, 0]`` (w1) and ``path[:, 1]`` (w2),
                one row per strength value.
        """
        # NOTE(review): assumes the path was computed over exactly
        # np.logspace(-4, 4, len(path)) strengths in the logic module — confirm,
        # otherwise the x-axis is mislabeled.
        strengths = np.logspace(-4, 4, path.shape[0])

        # Pyplot-free Figure: see _generate_contour_plot for why.
        fig = Figure(figsize=(8, 6))
        ax = fig.subplots()
        ax.set_xlabel("Regularization Strength")
        ax.set_ylabel("Weight")
        ax.plot(strengths, path[:, 0], "r-", label="w1")
        ax.plot(strengths, path[:, 1], "b-", label="w2")
        ax.set_xscale("log")
        ax.legend()
        ax.grid(True)
        return fig

    @staticmethod
    def _parse_range(range_input: str) -> tuple[float, float]:
        """Parse ``"low, high"`` into a pair of finite floats with ``low < high``.

        Raises:
            ValueError: If there are not exactly two values, either is empty or
                non-numeric or non-finite (NaN/inf), or ``low >= high``.
        """
        parts = [part.strip() for part in range_input.split(",")]
        if len(parts) != 2:
            raise ValueError("Range must contain exactly two comma-separated values")
        low_text, high_text = parts
        if low_text == "":
            raise ValueError("Range lower bound cannot be empty")
        if high_text == "":
            raise ValueError("Range upper bound cannot be empty")
        try:
            low = float(low_text)
            high = float(high_text)
        except ValueError as err:
            raise ValueError("Range values must be valid numbers") from err
        # float() accepts "nan" and "inf". Neither gives a usable sampling
        # range, and NaN would also pass the ordering check below, since
        # every comparison with NaN is False.
        if not (np.isfinite(low) and np.isfinite(high)):
            raise ValueError("Range values must be finite numbers")
        if low >= high:
            raise ValueError("Range lower bound must be less than upper bound")
        return low, high

    @staticmethod
    def _parse_levels(levels_input: str) -> list[float]:
        """Parse comma-separated regularization levels into floats.

        Not called within this class (levels now come from
        ``compute_suggested_settings``); kept for any external callers.

        Raises:
            ValueError: If the input is blank, contains an empty item, or an
                item is not a number.
        """
        # str.split always yields at least one item, so only blankness matters.
        items = [item.strip() for item in levels_input.split(",")]
        if all(item == "" for item in items):
            raise ValueError("At least one regularization level is required")
        if any(item == "" for item in items):
            raise ValueError("Regularization levels cannot contain empty values")
        try:
            return [float(item) for item in items]
        except ValueError as err:
            raise ValueError("Level values must be valid numbers") from err

    @staticmethod
    def _resolve_csv_path(csv_file: object) -> str:
        """Return a filesystem path for ``csv_file``.

        Accepts a path string, a dict with a ``"name"`` key, or any object with
        a ``name`` attribute (e.g. an uploaded-file wrapper).

        Raises:
            ValueError: If ``csv_file`` is None or of an unsupported form.
        """
        if csv_file is None:
            raise ValueError("CSV file is required")
        if isinstance(csv_file, str):
            return csv_file
        if isinstance(csv_file, dict) and "name" in csv_file:
            return csv_file["name"]
        if hasattr(csv_file, "name"):
            return csv_file.name
        raise ValueError("Unsupported CSV file input")