File size: 11,462 Bytes
484caec
 
 
 
 
 
0c4e50c
 
 
 
 
 
484caec
 
 
 
 
 
85ac76b
484caec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b43e0ab
c5b45d7
 
 
 
 
 
 
 
 
 
 
 
 
 
b43e0ab
 
 
 
 
 
 
c5b45d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
484caec
0c4e50c
 
 
 
 
 
 
 
 
 
 
484caec
0c4e50c
 
 
 
 
 
484caec
0c4e50c
 
 
 
 
 
484caec
97d2547
 
 
 
 
 
 
 
 
 
 
 
 
484caec
 
 
 
 
c5b45d7
484caec
 
 
 
 
 
 
 
 
c5b45d7
484caec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19aa7d3
484caec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19aa7d3
484caec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97d2547
 
484caec
 
97d2547
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
484caec
 
 
97d2547
 
484caec
97d2547
 
 
 
 
 
 
 
 
484caec
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
from __future__ import annotations

import matplotlib.lines as mlines
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.figure import Figure
from sympy import sympify, symbols, sin, cos, exp
from sympy.parsing.sympy_parser import (
  standard_transformations, 
  implicit_multiplication_application, 
  parse_expr,
)

from logic import (
    DataGenerationOptions,
    Dataset,
    PlotsData,
    compute_plot_values,
    compute_suggested_settings,
    generate_dataset,
    load_dataset_from_csv,
)


class Manager:
    """Hold the current dataset and plot data and turn raw UI inputs into figures.

    The main entry point is :meth:`handle_generate_plots`, which rebuilds
    ``self.dataset`` and ``self.plots_data`` from its arguments and returns
    ``self`` together with three matplotlib figures (contour, data, strength).
    """

    def __init__(self, dataset: Dataset | None = None, plots_data: PlotsData | None = None):
        # Both stay ``None`` until update_dataset / compute_plots_data succeed.
        self.dataset = dataset
        self.plots_data = plots_data

    def update_dataset(
        self,
        dataset_type: str,
        function: str,
        x1_range_input: str,
        x2_range_input: str,
        x_selection_method: str,
        sigma: float,
        nsample: int,
        csv_file: str,
        has_header: bool,
        x1_col: int,
        x2_col: int,
        y_col: int,
    ) -> None:
        """Build a dataset from the given inputs and store it on ``self.dataset``.

        All arguments are forwarded unchanged to :meth:`_compute_dataset`.

        Raises:
            ValueError: If any input is invalid or the resulting dataset has
                fewer than two points. ``self.dataset`` is only replaced after
                these checks pass.
        """
        dataset = self._compute_dataset(
            dataset_type,
            function,
            x1_range_input,
            x2_range_input,
            x_selection_method,
            sigma,
            nsample,
            csv_file,
            has_header,
            x1_col,
            x2_col,
            y_col,
        )

        n_points = len(dataset.x1)
        if n_points == 0:
            raise ValueError("Dataset cannot be empty")
        if n_points == 1:
            # TODO: remove this condition after fixing weird cases
            raise ValueError("Dataset must contain at least 2 points")
        self.dataset = dataset

    def _compute_dataset(
        self,
        dataset_type: str,
        function: str,
        x1_range_input: str,
        x2_range_input: str,
        x_selection_method: str,
        sigma: float,
        nsample: int,
        csv_file: str,
        has_header: bool,
        x1_col: int,
        x2_col: int,
        y_col: int,
    ) -> Dataset:
        """Create a dataset either from a formula ("Generate") or a file ("CSV").

        Args:
            dataset_type: ``"Generate"`` or ``"CSV"``.
            function: Formula in ``x1`` and ``x2`` (Generate only).
            x1_range_input: ``"low, high"`` range for x1 (Generate only).
            x2_range_input: ``"low, high"`` range for x2 (Generate only).
            x_selection_method: ``"grid"`` or ``"random"``, case-insensitive
                (Generate only).
            sigma: Forwarded as ``float(sigma)`` to ``DataGenerationOptions``.
            nsample: Forwarded as ``int(nsample)`` to ``DataGenerationOptions``.
            csv_file: Path string, ``{"name": path}`` dict, or object with a
                ``name`` attribute (CSV only); see :meth:`_resolve_csv_path`.
            has_header, x1_col, x2_col, y_col: Forwarded to
                ``load_dataset_from_csv`` (CSV only).

        Raises:
            ValueError: On any invalid input or CSV loading failure; the
                underlying exception is chained as ``__cause__``.
        """
        if dataset_type == "Generate":
            parsed_function = self._parse_function(function)

            if not x1_range_input.strip():
                raise ValueError("x1 range cannot be empty")
            if not x2_range_input.strip():
                raise ValueError("x2 range cannot be empty")

            try:
                x1_range = self._parse_range(x1_range_input)
            except Exception as e:
                raise ValueError(f"Invalid x1 range: {e}") from e
            try:
                x2_range = self._parse_range(x2_range_input)
            except Exception as e:
                raise ValueError(f"Invalid x2 range: {e}") from e

            method = x_selection_method.lower()
            if method not in ("grid", "random"):
                raise ValueError(f"Invalid x_selection_method: {x_selection_method}")

            return generate_dataset(
                parsed_function,
                x1_range,
                x2_range,
                DataGenerationOptions(method, int(nsample), float(sigma)),
            )

        if dataset_type == "CSV":
            csv_path = self._resolve_csv_path(csv_file)
            try:
                return load_dataset_from_csv(
                    csv_path,
                    bool(has_header),
                    int(x1_col),
                    int(x2_col),
                    int(y_col),
                )
            except Exception as e:
                raise ValueError(f"Failed to load dataset from CSV: {e}") from e

        raise ValueError(f"Invalid dataset_type: {dataset_type}")

    @staticmethod
    def _parse_function(function: str):
        """Parse ``function`` into a SymPy expression over ``x1`` and ``x2``.

        Implicit multiplication/application (``2x1``, ``sin x1``) is accepted,
        and ``^`` is treated as exponentiation rather than XOR.

        Raises:
            ValueError: If the text is empty, cannot be parsed, does not yield an
                expression, or references symbols other than ``x1``/``x2``.
        """
        if not function.strip():
            raise ValueError("Function cannot be empty")

        # Local import keeps the module-level import block unchanged.
        from sympy.parsing.sympy_parser import convert_xor

        x1, x2 = symbols("x1 x2")
        allowed_locals = {
            "x1": x1,
            "x2": x2,
            "sin": sin,
            "cos": cos,
            "exp": exp,
        }
        # convert_xor must run before implicit_multiplication_application so that
        # forms like ``sin^2 x1`` reach its function-exponentiation step as ``**``.
        transformations = standard_transformations + (
            convert_xor,
            implicit_multiplication_application,
        )

        # NOTE(security): SymPy's parse_expr evaluates the transformed input with
        # ``eval`` and is documented as unsafe for untrusted strings. If this
        # text can come from untrusted users, sandbox or whitelist it first.
        try:
            parsed_function = parse_expr(
                function,
                local_dict=allowed_locals,
                transformations=transformations,
                evaluate=True,
            )
        except Exception as e:
            raise ValueError(f"Invalid function: {e}") from e

        free_symbols = getattr(parsed_function, "free_symbols", None)
        if free_symbols is None:
            # e.g. a quoted string literal parses to a plain Python object.
            raise ValueError("Invalid function: not a mathematical expression")

        unknown_symbols = free_symbols - {x1, x2}
        if unknown_symbols:
            unknown_names = ", ".join(sorted(str(s) for s in unknown_symbols))
            raise ValueError(f"Unknown variable(s): {unknown_names}. Allowed: x1, x2")
        return parsed_function

    def compute_plots_data(
        self,
        loss_type: str,
        regularizer_type: str,
        resolution: int,
    ) -> None:
        """Compute contour/path data for the current dataset into ``self.plots_data``.

        Weight ranges and regularization levels come from
        ``compute_suggested_settings(self.dataset)``.

        Args:
            loss_type: ``"l1"`` or ``"l2"``.
            regularizer_type: ``"l1"`` or ``"l2"``.
            resolution: Forwarded as ``int(resolution)`` to ``compute_plot_values``.

        Raises:
            ValueError: If no dataset is loaded or a type name is not recognised.
        """
        if self.dataset is None:
            raise ValueError("Dataset is not initialized")

        if loss_type not in ("l1", "l2"):
            raise ValueError(f"Invalid loss_type: {loss_type}")
        if regularizer_type not in ("l1", "l2"):
            raise ValueError(f"Invalid regularizer_type: {regularizer_type}")

        w1_range, w2_range, reg_levels = compute_suggested_settings(self.dataset)

        self.plots_data = compute_plot_values(
            self.dataset,
            loss_type,
            regularizer_type,
            reg_levels,
            w1_range,
            w2_range,
            int(resolution),
        )

    def handle_generate_plots(
        self,
        dataset_type: str,
        function: str,
        x1_range_input: str,
        x2_range_input: str,
        x_selection_method: str,
        sigma: float,
        nsample: int,
        csv_file: str,
        has_header: bool,
        x1_col: int,
        x2_col: int,
        y_col: int,
        loss_type: str,
        regularizer_type: str,
        resolution: int,
    ) -> tuple[Manager, Figure, Figure, Figure]:
        """Rebuild the dataset and plot data, then render all three figures.

        Dataset arguments are as in :meth:`_compute_dataset`; the last three are
        as in :meth:`compute_plots_data`.

        Returns:
            ``(self, contour_figure, data_figure, strength_figure)``.

        Raises:
            ValueError: Propagated from dataset construction or plot-data
                computation.
        """
        self.update_dataset(
            dataset_type,
            function,
            x1_range_input,
            x2_range_input,
            x_selection_method,
            sigma,
            nsample,
            csv_file,
            has_header,
            x1_col,
            x2_col,
            y_col,
        )

        self.compute_plots_data(
            loss_type,
            regularizer_type,
            resolution,
        )

        # Defensive: both attributes are set by the two calls above.
        if self.dataset is None or self.plots_data is None:
            raise ValueError("Failed to generate plot data")

        contour_plot = self._generate_contour_plot(self.plots_data)
        data_plot = self._generate_data_plot(self.dataset)
        strength_plot = self._generate_strength_plot(self.plots_data.path)
        return self, contour_plot, data_plot, strength_plot

    @staticmethod
    def _generate_contour_plot(plots_data: PlotsData) -> Figure:
        """Draw regularizer (dashed) and loss (solid) contours in (w1, w2) space.

        The regularization path is overlaid in red and the unregularized
        solution in blue: a single ``x`` marker when ``unreg_solution`` is 1-D,
        otherwise a line through its rows.
        """
        # Build the Figure directly instead of via pyplot so it is not kept
        # alive in pyplot's global figure registry after being returned.
        fig = Figure(figsize=(8, 8))
        ax = fig.subplots()
        ax.set_xlabel("w1")
        ax.set_ylabel("w2")

        cmap = plt.get_cmap("viridis")
        n_levels = len(plots_data.reg_levels)
        if n_levels == 1:
            colors = [cmap(0.5)]
        else:
            colors = [cmap(i / (n_levels - 1)) for i in range(n_levels)]

        reg_contours = ax.contour(
            plots_data.W1,
            plots_data.W2,
            plots_data.norms,
            levels=plots_data.reg_levels,
            colors=colors,
            linestyles="dashed",
        )
        ax.clabel(reg_contours, inline=True, fontsize=8)

        # Loss contours reuse the same palette in reverse order.
        loss_contours = ax.contour(
            plots_data.W1,
            plots_data.W2,
            plots_data.loss_values,
            levels=plots_data.loss_levels,
            colors=colors[::-1],
        )
        ax.clabel(loss_contours, inline=True, fontsize=8)

        unreg = plots_data.unreg_solution
        if unreg.ndim == 1:
            ax.plot(unreg[0], unreg[1], "bx", markersize=5, label="unregularized solution")
            unreg_handle = mlines.Line2D(
                [], [], color="blue", marker="x", linestyle="None", label="unregularized solution"
            )
        else:
            ax.plot(unreg[:, 0], unreg[:, 1], "b-", label="unregularized solution")
            unreg_handle = mlines.Line2D([], [], color="blue", linestyle="-", label="unregularized solution")

        ax.plot(plots_data.path[:, 0], plots_data.path[:, 1], "r-", label="regularization path")

        # The legend is built from proxy artists rather than the plotted lines.
        handles = [
            mlines.Line2D([], [], color="black", linestyle="-", label="loss"),
            mlines.Line2D([], [], color="black", linestyle="--", label="regularization"),
            mlines.Line2D([], [], color="red", linestyle="-", label="regularization path"),
            unreg_handle,
        ]
        ax.legend(handles=handles)
        ax.grid(True)
        return fig

    @staticmethod
    def _generate_data_plot(dataset: Dataset) -> Figure:
        """Scatter the samples in (x1, x2), coloured by ``y``, with a colorbar."""
        # Created without pyplot so the figure is not retained globally.
        fig = Figure(figsize=(8, 8))
        ax = fig.subplots()
        ax.set_xlabel("x1")
        ax.set_ylabel("x2")

        scatter = ax.scatter(dataset.x1, dataset.x2, c=dataset.y, cmap="viridis")
        ax.grid(True)
        fig.colorbar(scatter, ax=ax)
        return fig

    @staticmethod
    def _generate_strength_plot(path: np.ndarray) -> Figure:
        """Plot w1 and w2 along the regularization path on a log-scaled x-axis.

        ``path`` is indexed as ``path[:, 0]`` (w1) and ``path[:, 1]`` (w2).
        """
        # NOTE(review): the x values assume the path was sampled at
        # np.logspace(-4, 4, len(path)) strengths; this must match how
        # compute_plot_values builds the path — TODO confirm.
        strengths = np.logspace(-4, 4, path.shape[0])

        # Created without pyplot so the figure is not retained globally.
        fig = Figure(figsize=(8, 6))
        ax = fig.subplots()
        ax.set_xlabel("Regularization Strength")
        ax.set_ylabel("Weight")

        ax.plot(strengths, path[:, 0], "r-", label="w1")
        ax.plot(strengths, path[:, 1], "b-", label="w2")
        ax.set_xscale("log")
        ax.legend()
        ax.grid(True)
        return fig

    @staticmethod
    def _parse_range(range_input: str) -> tuple[float, float]:
        """Parse ``"low, high"`` into a ``(low, high)`` pair of floats.

        Raises:
            ValueError: If there are not exactly two comma-separated values,
                either is empty or not a number, either is NaN/infinite, or
                ``low >= high``.
        """
        parts = [x.strip() for x in range_input.split(",")]
        if len(parts) != 2:
            raise ValueError("Range must contain exactly two comma-separated values")

        low_text, high_text = parts
        if low_text == "":
            raise ValueError("Range lower bound cannot be empty")
        if high_text == "":
            raise ValueError("Range upper bound cannot be empty")

        try:
            low = float(low_text)
            high = float(high_text)
        except ValueError as e:
            raise ValueError("Range values must be valid numbers") from e

        # float() accepts "nan"/"inf"; NaN would also slip past the ordering check.
        if not (np.isfinite(low) and np.isfinite(high)):
            raise ValueError("Range values must be finite numbers")
        if low >= high:
            raise ValueError("Range lower bound must be less than upper bound")

        return low, high

    @staticmethod
    def _parse_levels(levels_input: str) -> list[float]:
        """Parse a comma-separated list of regularization levels into floats.

        Raises:
            ValueError: If the input is blank, contains an empty entry, or has a
                value that is not a finite number.
        """
        # str.split always yields at least one item, so only blank entries matter.
        parts = [x.strip() for x in levels_input.split(",")]
        if all(x == "" for x in parts):
            raise ValueError("At least one regularization level is required")
        if any(x == "" for x in parts):
            raise ValueError("Regularization levels cannot contain empty values")

        try:
            levels = [float(x) for x in parts]
        except ValueError as e:
            raise ValueError("Level values must be valid numbers") from e

        if not all(np.isfinite(v) for v in levels):
            raise ValueError("Level values must be finite numbers")
        return levels

    @staticmethod
    def _resolve_csv_path(csv_file: object) -> str:
        """Return a filesystem path for the CSV input.

        Accepts a non-blank path string, a dict with a ``"name"`` key, or any
        object exposing a ``name`` attribute (e.g. an uploaded-file wrapper).

        Raises:
            ValueError: If no file is given or the input type is unsupported.
        """
        if csv_file is None:
            raise ValueError("CSV file is required")
        if isinstance(csv_file, str):
            if not csv_file.strip():
                raise ValueError("CSV file is required")
            return csv_file
        if isinstance(csv_file, dict) and "name" in csv_file:
            return csv_file["name"]
        if hasattr(csv_file, "name"):
            return csv_file.name
        raise ValueError("Unsupported CSV file input")