Spaces:
Sleeping
Sleeping
| from dataclasses import dataclass | |
| from typing import Any, Literal, cast | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| from sympy import sympify | |
| from logic import * | |
@dataclass
class DatasetOptions:
    """Raw dataset settings, built from a request dict via ``DatasetOptions(**options)``.

    The ``@dataclass`` decorator is required: it generates the keyword
    ``__init__`` that ``Manager.handle_set_dataset`` relies on. Every field is
    mandatory, so a missing or unexpected key raises ``TypeError``. The class
    must stay mutable (not frozen) because ``handle_set_dataset`` assigns to
    ``sample_method`` after validating it.
    """

    # Selects the data source; the manager accepts "Generate" or "CSV".
    dataset_type: str

    # --- fields read only when dataset_type == "Generate" ---
    # Expression text handed to sympy's ``sympify``.
    function: str
    # Bounds of the x range passed to ``generate_dataset``.
    xmin: float
    xmax: float
    # Third argument of ``DataGenerationOptions``; presumably the noise level
    # of the generated samples -- confirm against ``logic``.
    sigma: float
    # Second argument of ``DataGenerationOptions`` (number of samples).
    nsample: int
    # Sampling strategy; the manager accepts "Grid" or "Random".
    sample_method: str

    # --- fields read only when dataset_type == "CSV" ---
    # Arguments forwarded, in this order, to ``load_dataset_from_csv``.
    csv_path: str
    has_header: bool
    xcol: int
    ycol: int
| class OptimizerOptions: | |
| optimizer_type: str | |
| learning_rate: float | None | |
| beta1: float | None | |
| beta2: float | None | |
| momentum: float | None | |
| weight_decay: float | None | |
| batch_size: int | None | |
class Manager:
    """Owns the dataset, model and optimizer of one interactive training session.

    Each ``handle_*`` method applies one request, updates the stored state and
    returns the ``PlotData`` produced by :meth:`get_plot_data`. Invalid
    requests are reported as ``ValueError``.
    """

    # Batch size used when the optimizer options give none (or a falsy value).
    _DEFAULT_BATCH_SIZE = 32
    # Values accepted for ``DatasetOptions.sample_method``.
    _SAMPLE_METHODS = ("Grid", "Random")
    # Values accepted for ``OptimizerOptions.optimizer_type``.
    _OPTIMIZER_TYPES = ("SGD", "Adam")
    # The reference curve extends this far beyond [xmin, xmax] on each side.
    _TRUE_DATASET_MARGIN = 1
    # Number of grid samples in the reference curve.
    _TRUE_DATASET_POINTS = 1000

    def __init__(self) -> None:
        """Start with no dataset, architecture, model or optimizer."""
        self._dataset: Dataset | None = None
        self._true_dataset: Dataset | None = None
        self._architecture: str | None = None
        self._model: nn.Module | None = None
        self._optimizer: optim.Optimizer | None = None
        self._optimizer_options: OptimizerOptions | None = None
        self._batch_size: int | None = None

    def handle_set_dataset(self, options: dict[str, Any]) -> PlotData:
        """Replace the training dataset and re-initialise any existing model.

        Args:
            options: Keyword arguments for ``DatasetOptions``.

        Returns:
            Plot data reflecting the new dataset.

        Raises:
            ValueError: If the options do not match ``DatasetOptions``, the
                function expression cannot be parsed, or the sample method or
                dataset type is not recognised.
        """
        try:
            parsed_options = DatasetOptions(**options)
        except TypeError as e:
            raise ValueError(f"Invalid dataset options: {e}") from e

        if parsed_options.dataset_type == "Generate":
            dataset, true_dataset = self._generate_datasets(parsed_options)
        elif parsed_options.dataset_type == "CSV":
            dataset = load_dataset_from_csv(
                parsed_options.csv_path,
                parsed_options.has_header,
                parsed_options.xcol,
                parsed_options.ycol,
            )
            # A CSV has no known generating function, so there is no
            # reference curve to plot.
            true_dataset = Dataset(x=[], y=[])
        else:
            raise ValueError(f"Unknown dataset type: {parsed_options.dataset_type}")

        self._dataset = dataset
        self._true_dataset = true_dataset

        # Re-initialise the model only once the new dataset is known to be
        # valid, so a rejected request does not discard training progress.
        if self._model is not None:
            self._rebuild_model()
        return self.get_plot_data()

    def _generate_datasets(self, parsed_options: DatasetOptions) -> tuple[Dataset, Dataset]:
        """Build the sampled dataset and its reference curve from an expression."""
        # NOTE(security): sympify() evaluates its argument with eval(); the
        # expression comes from the request and should be treated as
        # untrusted input.
        try:
            function_expr = sympify(parsed_options.function)
        except Exception as e:  # sympify can fail with many exception types
            raise ValueError(f"Invalid function expression: {e}") from e

        if parsed_options.sample_method not in self._SAMPLE_METHODS:
            raise ValueError(f"Invalid sample method: {parsed_options.sample_method}")
        sample_method = cast(Literal["Grid", "Random"], parsed_options.sample_method)

        dataset = generate_dataset(
            function_expr,
            (parsed_options.xmin, parsed_options.xmax),
            DataGenerationOptions(
                sample_method,
                parsed_options.nsample,
                parsed_options.sigma,
            ),
        )
        # Reference curve: grid-sampled with sigma 0.0 over a slightly wider
        # x range than the training data.
        true_dataset = generate_dataset(
            function_expr,
            (
                parsed_options.xmin - self._TRUE_DATASET_MARGIN,
                parsed_options.xmax + self._TRUE_DATASET_MARGIN,
            ),
            DataGenerationOptions(
                "Grid",
                self._TRUE_DATASET_POINTS,
                0.0,
            ),
        )
        return dataset, true_dataset

    def handle_set_architecture(self, architecture_str: str) -> PlotData:
        """Build a new model from ``architecture_str`` and make it current.

        Returns:
            Plot data reflecting the new model.
        """
        # Build first so a failing architecture string is never remembered.
        model = build_model_from_architecture(architecture_str)
        self._architecture = architecture_str
        self._model = model
        # important! must reset optimizer: the old one is bound to the
        # previous model's parameters.
        if self._optimizer_options is not None:
            self._optimizer = self._build_optimizer()
        return self.get_plot_data()

    def handle_set_optimizer(self, options: dict[str, Any]) -> PlotData:
        """Store new optimizer options and build the optimizer for the model.

        The options (including the batch size) are remembered even when no
        model exists yet, so a later ``handle_set_architecture`` can use them.

        Args:
            options: Keyword arguments for ``OptimizerOptions``.

        Raises:
            ValueError: If the options do not match ``OptimizerOptions``, the
                optimizer type is unknown, or no model has been set.
        """
        try:
            parsed_options = OptimizerOptions(**options)
        except TypeError as e:
            raise ValueError(f"Invalid optimizer options: {e}") from e

        # Reject unknown types before storing anything; stored options are
        # reused on every model rebuild and must be buildable.
        if parsed_options.optimizer_type not in self._OPTIMIZER_TYPES:
            raise ValueError(f"Unknown optimizer type: {parsed_options.optimizer_type}")

        # Keep the batch size in step with the stored options, so it is not
        # lost when the model check below raises.
        self._optimizer_options = parsed_options
        self._batch_size = parsed_options.batch_size or self._DEFAULT_BATCH_SIZE

        if self._model is None:
            raise ValueError("Model must be set before configuring the optimizer.")
        self._optimizer = self._build_optimizer()
        return self.get_plot_data()

    def _build_optimizer(self) -> optim.Optimizer:
        """Create an optimizer for the current model from the stored options.

        Options that are ``None`` (or otherwise falsy) use these defaults:
        SGD lr 0.01 and momentum 0.0; Adam lr 0.001 and betas (0.9, 0.999);
        weight decay 0.0 for both.

        Raises:
            ValueError: If the model or the options are missing, or the
                optimizer type is neither "SGD" nor "Adam".
        """
        if self._model is None:
            raise ValueError("Model must be set before configuring the optimizer.")
        if self._optimizer_options is None:
            raise ValueError("Optimizer options must be set before configuring the optimizer.")

        options = self._optimizer_options
        if options.optimizer_type == "SGD":
            return optim.SGD(
                self._model.parameters(),
                lr=options.learning_rate or 0.01,
                momentum=options.momentum or 0.0,
                weight_decay=options.weight_decay or 0.0,
            )
        if options.optimizer_type == "Adam":
            return optim.Adam(
                self._model.parameters(),
                lr=options.learning_rate or 0.001,
                betas=(
                    options.beta1 or 0.9,
                    options.beta2 or 0.999,
                ),
                weight_decay=options.weight_decay or 0.0,
            )
        raise ValueError(f"Unknown optimizer type: {options.optimizer_type}")

    def handle_train_step(self, num_steps: int = 1) -> PlotData:
        """Call ``train_step`` on the current model, optimizer and dataset.

        Args:
            num_steps: Forwarded to ``train_step`` as ``num_steps``.

        Raises:
            ValueError: If the model, optimizer or dataset is missing.
        """
        if self._model is None or self._optimizer is None or self._dataset is None:
            raise ValueError("Model, optimizer, and dataset must be set before training.")
        train_step(
            self._model,
            self._optimizer,
            self._dataset,
            batch_size=self._batch_size or self._DEFAULT_BATCH_SIZE,
            num_steps=num_steps,
        )
        return self.get_plot_data()

    def handle_reset_model(self) -> PlotData:
        """Rebuild the model from the stored architecture.

        Raises:
            ValueError: If no architecture has been set.
        """
        self._rebuild_model()
        return self.get_plot_data()

    def _rebuild_model(self) -> None:
        """Recreate the model, and the optimizer when options are stored."""
        if self._architecture is None:
            raise ValueError("Architecture must be set before resetting the model.")
        self._model = build_model_from_architecture(self._architecture)
        # As in handle_set_architecture: an optimizer can only be built once
        # options exist, so resetting before that must not fail.
        if self._optimizer_options is not None:
            self._optimizer = self._build_optimizer()

    def get_plot_data(self) -> PlotData:
        """Bundle the datasets and the model's reference predictions.

        A missing dataset is replaced by an empty ``Dataset``. Predictions are
        produced only when the reference dataset has both x and y values and
        a model exists; otherwise ``test_predictions`` is ``None``.
        """
        dataset = self._dataset if self._dataset is not None else Dataset(x=[], y=[])
        if self._true_dataset is not None:
            test_dataset = self._true_dataset
        else:
            test_dataset = Dataset(x=[], y=[])

        test_predictions = None
        if test_dataset.x and test_dataset.y and self._model is not None:
            test_predictions = generate_test_predictions(
                test_dataset,
                self._model,
            )

        return PlotData(
            dataset=dataset,
            test_dataset=test_dataset,
            test_predictions=test_predictions,
        )