File size: 4,010 Bytes
d7e53e8
 
 
 
 
4928a1a
d7e53e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c0c822a
d7e53e8
 
 
c1d1768
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d7e53e8
c1d1768
 
 
 
 
 
 
 
 
 
 
 
d7e53e8
 
c1d1768
 
 
 
d7e53e8
c1d1768
 
 
d7e53e8
 
 
 
 
4928a1a
 
 
 
 
 
 
 
 
 
 
 
d7e53e8
 
 
 
 
 
 
 
4928a1a
 
 
 
 
 
d7e53e8
 
 
 
4928a1a
d7e53e8
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import gradio as gr
import warnings
warnings.filterwarnings("ignore")

from ui.helpers import (
    reset_on_task_change,
    update_models, 
    update_graphs, 
    preview_csv, 
    reset_metrics_on_file_clear, 
    toggle_csv_preview
)

from ui.theme import OrangeRedTheme
from ui.styles import CSS_STYLE

from core.training import train_model
from core.detection import auto_set_task

# Module-level Gradio UI: builds the layout and wires all event listeners.
# Layout order below is significant: Gradio places components in the order
# they are created inside these context managers.
# Theme/CSS are not passed to gr.Blocks() here; NOTE(review): they are only
# supplied to launch() in the __main__ guard, so they apply only when this
# file is run as a script — confirm that is intended for other launchers.
with gr.Blocks() as app:
    with gr.Column(elem_id="container"):
        gr.Markdown("## 👉 Supervised Learning Model Trainer")

        with gr.Row(equal_height=True):
            # Left column: inputs and controls.
            with gr.Column():
                with gr.Row(equal_height=True):
                    with gr.Column():
                        file_input = gr.File(label="Upload CSV", file_types=[".csv"])
                        show_preview = gr.Checkbox(
                            label="Show CSV Preview",
                            value=False,
                        )
                        task_type = gr.Dropdown(
                            ["Regression", "Classification"],
                            label="Task Type",
                            value="Regression",
                        )
                        model_group = gr.Dropdown(
                            label="Model Group",
                            choices=["Basic", "Bagging", "Boosting", "Stacking"],
                            value="Basic",
                        )
                        # Created without choices; they are filled in by the
                        # app.load / .change handlers registered below.
                        model_name = gr.Dropdown(label="Model")
                        graph_type = gr.Dropdown(label="Graph Type")

                with gr.Row(equal_height=True):
                    # Clicking an example fills file_input, task_type and
                    # model_group with the listed values.
                    gr.Examples(
                        examples=[
                            ["assets/regression_example.csv", "Regression", "Basic"],
                            ["assets/classification_example.csv", "Classification", "Basic"],
                        ],
                        inputs=[file_input, task_type, model_group],
                        label="Example Datasets",
                        examples_per_page=2,
                    )
                with gr.Row(equal_height=True):
                    run_btn = gr.Button("Train & Evaluate", variant="primary", size="lg")

            # Right column: outputs.
            with gr.Column():
                with gr.Row(equal_height=True):
                    with gr.Column():
                        # Hidden initially; visibility is driven by the
                        # show_preview checkbox listener below.
                        csv_preview = gr.Dataframe(label="CSV Preview", interactive=False, visible=False,)
                        output = gr.Dataframe(label="Evaluation Metrics", interactive=False)

                with gr.Row(equal_height=True):
                    with gr.Column():
                        plot = gr.Plot(label="Selected Graph")

        # Three independent listeners fire whenever the uploaded file changes
        # (including, presumably, when it is cleared — value may be None).
        file_input.change(preview_csv, file_input, csv_preview)
        # NOTE(review): presumably infers Regression/Classification from the
        # CSV contents; auto_set_task is defined in core.detection — confirm.
        file_input.change(auto_set_task, file_input, task_type)
        # NOTE(review): by its name, clears metrics/plot when the file is
        # removed; see ui.helpers for the exact condition.
        file_input.change(reset_metrics_on_file_clear,inputs=file_input,outputs=[output, plot])

        # Changing the task resets the model group and model selection.
        task_type.change(
            reset_on_task_change,
            inputs=task_type,
            outputs=[model_group, model_name],
        )

        # Repopulate the model dropdown for the chosen task + group.
        model_group.change(
            update_models,
            inputs=[task_type, model_group],
            outputs=model_name,
        )

        # Second listener on task_type: refresh the available graph types.
        task_type.change(update_graphs, task_type, graph_type)

        # Show/hide the CSV preview table.
        show_preview.change(
            toggle_csv_preview,
            inputs=show_preview,
            outputs=csv_preview,
        )

        # On page load, run the same initialisers as a task change so the
        # model and graph dropdowns start with choices for the default task.
        app.load(
            reset_on_task_change,
            inputs=task_type,
            outputs=[model_group, model_name],
        )

        app.load(update_graphs, task_type, graph_type)

        # Main action: train and evaluate; results go to the metrics table
        # and the plot.
        run_btn.click(
            train_model,
            inputs=[file_input, task_type, model_group, model_name, graph_type],
            outputs=[output, plot]
        )


if __name__ == "__main__":
    # Local import keeps this script-only dependency inside the entry point.
    import os

    # Theme and CSS are applied at launch time rather than in gr.Blocks().
    # NOTE(review): whether launch() accepts theme/css depends on the
    # installed Gradio version — confirm against the pinned version.
    orange_red_theme = OrangeRedTheme()
    app.queue().launch(
        theme=orange_red_theme,
        css=CSS_STYLE,
        show_error=True,
        # Gradio only reads GRADIO_SERVER_NAME / GRADIO_SERVER_PORT when these
        # arguments are None. Passing literals disabled those deployment
        # overrides, so honour the env vars and fall back to the previous
        # hard-coded values ("0.0.0.0" listens on all interfaces).
        server_name=os.getenv("GRADIO_SERVER_NAME", "0.0.0.0"),
        server_port=int(os.getenv("GRADIO_SERVER_PORT", "7860")),
        debug=True,
    )