hkayabilisim commited on
Commit
e00a48f
·
1 Parent(s): c6ff61b

Added testing page

Browse files
Files changed (2) hide show
  1. .gitignore +10 -0
  2. agent/dashboard/testing.py +61 -0
.gitignore ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ agent/backend/__pycache__/data.cpython-310.pyc
2
+ agent/__pycache__/__init__.cpython-310.pyc
3
+ agent/backend/__pycache__/data.cpython-310.pyc
4
+ agent/backend/__pycache__/loss.cpython-310.pyc
5
+ agent/backend/__pycache__/models.cpython-310.pyc
6
+ agent/backend/__pycache__/utils.cpython-310.pyc
7
+ agent/dashboard/__pycache__/__init__.cpython-310.pyc
8
+ agent/dashboard/__pycache__/data.cpython-310.pyc
9
+ agent/dashboard/__pycache__/testing.cpython-310.pyc
10
+ agent/dashboard/__pycache__/training.cpython-310.pyc
agent/dashboard/testing.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import solara
2
+ import torch
3
+ import solara.express as solara_px
4
+ from .training import local_state as training_state
5
+ from .data import state as data_state
6
+ from ..backend.utils import predict
7
+
8
+
9
+
10
+ local_state = solara.reactive(
11
+ {
12
+ 'predictions': solara.reactive({}),
13
+ 'render_count': solara.reactive(0),
14
+ }
15
+ )
16
+
17
+ def force_render():
18
+ local_state.value['render_count'].set(1 + local_state.value['render_count'].value)
19
+
20
+ @solara.component
21
+ def ScatterPlot(predictions, render_count):
22
+ for col in predictions.keys():
23
+ with solara.Row():
24
+ solara_px.scatter(
25
+ predictions[col]['training'],
26
+ x = 'prediction',
27
+ y = 'target',
28
+ title=f'{col}'
29
+ )
30
+
31
+
32
+ @solara.component
33
+ def Page():
34
+ df = data_state.value['data']
35
+
36
+ filter, set_filter = solara.use_cross_filter(id(df))
37
+
38
+ dff = df
39
+ if filter is not None:
40
+ dff = df[filter]
41
+
42
+ def make_predictions():
43
+ model = training_state.value['model'].value
44
+ if model is None:
45
+ print('There is no pre-trained model! Please train your model.')
46
+ else:
47
+ print('There is a pre-trained model')
48
+ input_cols = training_state.value['input_cols'].value
49
+ output_cols = training_state.value['output_cols'].value
50
+ trn_ratio = training_state.value['trn_ratio'].value
51
+ batch_size_trn = training_state.value['batch_size_trn'].value
52
+ batch_size_val = training_state.value['batch_size_val'].value
53
+ seed = training_state.value['seed'].value
54
+ predictions = predict(model, dff, input_cols, output_cols, trn_ratio,
55
+ batch_size_trn, batch_size_val, seed)
56
+
57
+ local_state.value['predictions'].set(predictions)
58
+ force_render()
59
+
60
+ solara.Button(label='Output Predictions', on_click=make_predictions)
61
+ ScatterPlot(local_state.value['predictions'].value, local_state.value['render_count'].value)