from dataclasses import dataclass
from typing import Any, Literal, cast

import torch.nn as nn
import torch.optim as optim
from sympy import sympify

# Names used below that are not defined in this file (Dataset, PlotData,
# DataGenerationOptions, generate_dataset, load_dataset_from_csv,
# build_model_from_architecture, train_step, generate_test_predictions)
# are expected to come from this star import.
from logic import *

# Optimizer types that ``Manager._build_optimizer`` knows how to construct.
_SUPPORTED_OPTIMIZERS = ("SGD", "Adam")

# Batch size used when the optimizer options do not provide one.
_DEFAULT_BATCH_SIZE = 32


def _or_default(value, default):
    """Return *value* unless it is ``None``, in which case return *default*.

    Unlike ``value or default`` this keeps explicit falsy settings such as
    ``0.0`` instead of silently replacing them with the default.
    """
    return default if value is None else value


@dataclass
class DatasetOptions:
    """Raw dataset settings, built directly from the caller's options dict.

    No type validation happens here; the fields hold whatever was passed.
    """

    dataset_type: str  # "Generate" or "CSV"; anything else is rejected
    function: str  # expression text handed to sympify ("Generate" only)
    xmin: float  # lower bound of the sampling interval
    xmax: float  # upper bound of the sampling interval
    sigma: float  # third DataGenerationOptions argument (presumably noise level)
    nsample: int  # second DataGenerationOptions argument (presumably sample count)
    sample_method: str  # "Grid" or "Random"
    csv_path: str  # path handed to load_dataset_from_csv ("CSV" only)
    has_header: bool  # forwarded to load_dataset_from_csv
    xcol: int  # forwarded to load_dataset_from_csv (presumably x column index)
    ycol: int  # forwarded to load_dataset_from_csv (presumably y column index)


@dataclass
class OptimizerOptions:
    """Raw optimizer settings; ``None`` means "use the built-in default"."""

    optimizer_type: str  # "SGD" or "Adam"
    learning_rate: float | None  # default 0.01 for SGD, 0.001 for Adam
    beta1: float | None  # Adam only, default 0.9
    beta2: float | None  # Adam only, default 0.999
    momentum: float | None  # SGD only, default 0.0
    weight_decay: float | None  # default 0.0
    batch_size: int | None  # default 32


class Manager:
    """Holds the current dataset, model and optimizer and reacts to requests.

    Every ``handle_*`` method mutates this state and returns a fresh
    ``PlotData`` snapshot via :meth:`get_plot_data`.
    """

    def __init__(self) -> None:
        """Start with nothing configured."""
        self._dataset: Dataset | None = None
        self._true_dataset: Dataset | None = None
        self._architecture: str | None = None
        self._model: nn.Module | None = None
        self._optimizer: optim.Optimizer | None = None
        self._optimizer_options: OptimizerOptions | None = None
        self._batch_size: int | None = None

    def handle_set_dataset(self, options: dict[str, Any]) -> PlotData:
        """Replace the training dataset and reset any existing model.

        Args:
            options: Mapping whose keys match the fields of
                :class:`DatasetOptions`.

        Returns:
            Plot data reflecting the new dataset.

        Raises:
            ValueError: If the options do not match ``DatasetOptions``, the
                function expression cannot be parsed, or the sample method
                or dataset type is unknown.
        """
        try:
            parsed = DatasetOptions(**options)
        except TypeError as e:
            raise ValueError(f"Invalid dataset options: {e}") from e

        if parsed.dataset_type == "Generate":
            dataset, true_dataset = self._generate_datasets(parsed)
        elif parsed.dataset_type == "CSV":
            dataset = load_dataset_from_csv(
                parsed.csv_path,
                parsed.has_header,
                parsed.xcol,
                parsed.ycol,
            )
            # No reference curve is available for CSV data.
            true_dataset = Dataset(x=[], y=[])
        else:
            raise ValueError(f"Unknown dataset type: {parsed.dataset_type}")

        # State is only touched once the request is known to be valid, so a
        # rejected request leaves the current dataset and model untouched.
        self._dataset = dataset
        self._true_dataset = true_dataset
        if self._model is not None:
            self._reinitialize_model()
        return self.get_plot_data()

    @staticmethod
    def _generate_datasets(parsed: DatasetOptions) -> tuple[Dataset, Dataset]:
        """Build the sampled dataset and the reference curve for "Generate"."""
        # NOTE(review): SymPy documents that sympify relies on eval; the
        # expression string should not come from untrusted callers.
        try:
            function_expr = sympify(parsed.function)
        except Exception as e:
            raise ValueError(f"Invalid function expression: {e}") from e

        if parsed.sample_method not in ("Grid", "Random"):
            raise ValueError(f"Invalid sample method: {parsed.sample_method}")
        sample_method = cast(Literal["Grid", "Random"], parsed.sample_method)

        dataset = generate_dataset(
            function_expr,
            (parsed.xmin, parsed.xmax),
            DataGenerationOptions(sample_method, parsed.nsample, parsed.sigma),
        )
        # Reference curve: 1000 grid points with sigma 0.0 on an interval
        # widened by one unit on each side.
        true_dataset = generate_dataset(
            function_expr,
            (parsed.xmin - 1, parsed.xmax + 1),
            DataGenerationOptions("Grid", 1000, 0.0),
        )
        return dataset, true_dataset

    def handle_set_architecture(self, architecture_str: str) -> PlotData:
        """Build a new model from *architecture_str* and make it current.

        If optimizer options were supplied earlier, the optimizer is rebuilt
        so that it refers to the new model's parameters.

        Returns:
            Plot data reflecting the new model.
        """
        # Build first so a failing build does not leave a bad architecture
        # string stored next to the previous model.
        model = build_model_from_architecture(architecture_str)
        self._architecture = architecture_str
        self._model = model
        # important! must reset optimizer
        if self._optimizer_options is not None:
            self._optimizer = self._build_optimizer()
        return self.get_plot_data()

    def handle_set_optimizer(self, options: dict[str, Any]) -> PlotData:
        """Store optimizer settings and build the optimizer for the model.

        Valid options are remembered even when no model exists yet, so a
        later :meth:`handle_set_architecture` can build the optimizer; the
        ``ValueError`` is still raised because no optimizer was built now.

        Args:
            options: Mapping whose keys match the fields of
                :class:`OptimizerOptions`.

        Returns:
            Plot data for the current state.

        Raises:
            ValueError: If the options do not match ``OptimizerOptions``,
                the optimizer type is unknown, or no model is set.
        """
        try:
            parsed = OptimizerOptions(**options)
        except TypeError as e:
            raise ValueError(f"Invalid optimizer options: {e}") from e

        # Reject unknown types before storing anything; otherwise every
        # later optimizer rebuild would fail on the stored options.
        if parsed.optimizer_type not in _SUPPORTED_OPTIMIZERS:
            raise ValueError(f"Unknown optimizer type: {parsed.optimizer_type}")

        self._optimizer_options = parsed
        # Stored together with the options so the batch size is not lost
        # when the optimizer itself is only built later.
        self._batch_size = parsed.batch_size or _DEFAULT_BATCH_SIZE

        if self._model is None:
            raise ValueError("Model must be set before configuring the optimizer.")
        self._optimizer = self._build_optimizer()
        return self.get_plot_data()

    def _build_optimizer(self) -> optim.Optimizer:
        """Create an optimizer for the current model from the stored options.

        Options left as ``None`` fall back to defaults; explicit values,
        including ``0.0``, are passed through unchanged.

        Raises:
            ValueError: If the model or the optimizer options are missing,
                or the optimizer type is unknown.
        """
        if self._model is None:
            raise ValueError("Model must be set before configuring the optimizer.")
        if self._optimizer_options is None:
            raise ValueError(
                "Optimizer options must be set before configuring the optimizer."
            )

        options = self._optimizer_options
        weight_decay = _or_default(options.weight_decay, 0.0)

        if options.optimizer_type == "SGD":
            return optim.SGD(
                self._model.parameters(),
                lr=_or_default(options.learning_rate, 0.01),
                momentum=_or_default(options.momentum, 0.0),
                weight_decay=weight_decay,
            )
        if options.optimizer_type == "Adam":
            return optim.Adam(
                self._model.parameters(),
                lr=_or_default(options.learning_rate, 0.001),
                betas=(
                    _or_default(options.beta1, 0.9),
                    _or_default(options.beta2, 0.999),
                ),
                weight_decay=weight_decay,
            )
        raise ValueError(f"Unknown optimizer type: {options.optimizer_type}")

    def handle_train_step(self, num_steps: int = 1) -> PlotData:
        """Run ``train_step`` for *num_steps* steps on the current dataset.

        Returns:
            Plot data after training.

        Raises:
            ValueError: If the model, optimizer or dataset is missing.
        """
        if self._model is None or self._optimizer is None or self._dataset is None:
            raise ValueError(
                "Model, optimizer, and dataset must be set before training."
            )
        train_step(
            self._model,
            self._optimizer,
            self._dataset,
            batch_size=self._batch_size or _DEFAULT_BATCH_SIZE,
            num_steps=num_steps,
        )
        return self.get_plot_data()

    def handle_reset_model(self) -> PlotData:
        """Rebuild the model from the stored architecture.

        Works whether or not an optimizer has been configured.

        Returns:
            Plot data for the re-initialised model.

        Raises:
            ValueError: If no architecture has been set.
        """
        self._reinitialize_model()
        return self.get_plot_data()

    def _reinitialize_model(self) -> None:
        """Rebuild the model and, when options exist, its optimizer."""
        if self._architecture is None:
            raise ValueError("Architecture must be set before resetting the model.")
        self._model = build_model_from_architecture(self._architecture)
        # An optimizer can only be built once options exist; without them
        # there is nothing to rebuild (and no stale optimizer to replace).
        if self._optimizer_options is not None:
            self._optimizer = self._build_optimizer()
        else:
            self._optimizer = None

    def get_plot_data(self) -> PlotData:
        """Assemble a ``PlotData`` snapshot of the current state.

        Missing datasets are replaced by empty ``Dataset`` objects.
        Predictions are only produced when a model exists and the reference
        dataset has both x and y values; otherwise they are ``None``.
        """
        if self._dataset is not None:
            dataset = self._dataset
        else:
            dataset = Dataset(x=[], y=[])

        if self._true_dataset is not None:
            test_dataset = self._true_dataset
        else:
            test_dataset = Dataset(x=[], y=[])

        test_predictions = None
        if test_dataset.x and test_dataset.y and self._model is not None:
            test_predictions = generate_test_predictions(test_dataset, self._model)

        return PlotData(
            dataset=dataset,
            test_dataset=test_dataset,
            test_predictions=test_predictions,
        )