import gradio as gr
import warnings
# Silence every warning process-wide. The filter is installed after `import gradio`
# (so warnings emitted while gradio itself is imported are NOT suppressed) but before
# the project imports below, so anything those modules warn about at import time is hidden.
# NOTE(review): a blanket "ignore" also hides deprecation warnings from gradio and from
# whatever core.training pulls in — consider narrowing to specific categories/modules.
warnings.filterwarnings("ignore")
# UI event handlers wired to the components built in the gr.Blocks() section.
from ui.helpers import (
    reset_on_task_change,
    update_models,
    update_graphs,
    preview_csv,
    reset_metrics_on_file_clear,
    toggle_csv_preview
)
# Theme and stylesheet passed to launch() in the __main__ section.
from ui.theme import OrangeRedTheme
from ui.styles import CSS_STYLE
# Handler for the "Train & Evaluate" button.
from core.training import train_model
# Handler that updates the Task Type dropdown when a file is uploaded.
from core.detection import auto_set_task
# Single-page UI: data/model selection, example datasets and the train button in
# the left column; CSV preview, metrics table and plot in the right column.
# Event listeners must be registered while the gr.Blocks() context is open, so the
# wiring section lives inside the `with` statement as well.
with gr.Blocks() as app:
    with gr.Column(elem_id="container"):
        # NOTE(review): "๐" looks like an emoji damaged by a wrong text encoding;
        # confirm the intended character.
        gr.Markdown("## ๐ Supervised Learning Model Trainer")

        with gr.Row(equal_height=True):
            # ---- Left column: inputs ----
            with gr.Column():
                with gr.Row(equal_height=True), gr.Column():
                    file_input = gr.File(label="Upload CSV", file_types=[".csv"])
                    show_preview = gr.Checkbox(label="Show CSV Preview", value=False)
                    task_type = gr.Dropdown(
                        choices=["Regression", "Classification"],
                        label="Task Type",
                        value="Regression",
                    )
                    model_group = gr.Dropdown(
                        choices=["Basic", "Bagging", "Boosting", "Stacking"],
                        label="Model Group",
                        value="Basic",
                    )
                    # Created without choices; they are updated by the handlers
                    # wired below (including the app.load calls).
                    model_name = gr.Dropdown(label="Model")
                    graph_type = gr.Dropdown(label="Graph Type")

                with gr.Row(equal_height=True):
                    # Each example row fills [file_input, task_type, model_group].
                    gr.Examples(
                        examples=[
                            ["assets/regression_example.csv", "Regression", "Basic"],
                            ["assets/classification_example.csv", "Classification", "Basic"],
                        ],
                        inputs=[file_input, task_type, model_group],
                        label="Example Datasets",
                        examples_per_page=2,
                    )

                with gr.Row(equal_height=True):
                    run_btn = gr.Button("Train & Evaluate", variant="primary", size="lg")

            # ---- Right column: results ----
            with gr.Column():
                with gr.Row(equal_height=True), gr.Column():
                    # Hidden until toggled via the "Show CSV Preview" checkbox.
                    csv_preview = gr.Dataframe(
                        label="CSV Preview", interactive=False, visible=False
                    )
                    output = gr.Dataframe(label="Evaluation Metrics", interactive=False)
                with gr.Row(equal_height=True), gr.Column():
                    plot = gr.Plot(label="Selected Graph")

    # ---- Event wiring (registration order is significant: keep it stable) ----
    # A file change feeds three handlers: preview table, task type, metrics/plot.
    file_input.change(preview_csv, inputs=file_input, outputs=csv_preview)
    file_input.change(auto_set_task, inputs=file_input, outputs=task_type)
    file_input.change(
        reset_metrics_on_file_clear, inputs=file_input, outputs=[output, plot]
    )

    # Task / group changes refresh the dependent dropdowns.
    task_type.change(
        reset_on_task_change, inputs=task_type, outputs=[model_group, model_name]
    )
    model_group.change(
        update_models, inputs=[task_type, model_group], outputs=model_name
    )
    task_type.change(update_graphs, inputs=task_type, outputs=graph_type)

    show_preview.change(toggle_csv_preview, inputs=show_preview, outputs=csv_preview)

    # Run the task-dependent handlers once when the page loads so the dependent
    # dropdowns reflect the initial task type.
    app.load(
        reset_on_task_change, inputs=task_type, outputs=[model_group, model_name]
    )
    app.load(update_graphs, inputs=task_type, outputs=graph_type)

    run_btn.click(
        train_model,
        inputs=[file_input, task_type, model_group, model_name, graph_type],
        outputs=[output, plot],
    )
def resolve_server_address(environ=None):
    """Return the ``(server_name, server_port)`` pair to pass to ``launch()``.

    Honors Gradio's standard ``GRADIO_SERVER_NAME`` / ``GRADIO_SERVER_PORT``
    environment variables, so the bind address can be changed without editing
    code. Unset or empty variables fall back to ``"0.0.0.0"`` and ``7860``,
    the values this script previously hard-coded.

    Args:
        environ: Mapping to read the variables from; defaults to ``os.environ``.

    Returns:
        tuple[str, int]: interface/host name and TCP port.

    Raises:
        ValueError: if ``GRADIO_SERVER_PORT`` is set but is not an integer.
    """
    import os

    env = os.environ if environ is None else environ
    server_name = env.get("GRADIO_SERVER_NAME") or "0.0.0.0"
    raw_port = env.get("GRADIO_SERVER_PORT") or "7860"
    try:
        server_port = int(raw_port)
    except ValueError:
        raise ValueError(
            f"GRADIO_SERVER_PORT must be an integer, got {raw_port!r}"
        ) from None
    return server_name, server_port


if __name__ == "__main__":
    # Explicit launch() arguments take precedence over Gradio's own env-var
    # defaults, so resolve them here to keep the env overrides working.
    server_name, server_port = resolve_server_address()
    orange_red_theme = OrangeRedTheme()
    # NOTE(review): theme/css are given to launch() rather than gr.Blocks();
    # this requires a Gradio version whose launch() accepts them — confirm the
    # pinned version.
    # NOTE(review): show_error=True surfaces error messages in the browser; with
    # the default 0.0.0.0 bind that is visible to anyone who can reach the port.
    app.queue().launch(
        theme=orange_red_theme,
        css=CSS_STYLE,
        show_error=True,
        server_name=server_name,
        server_port=server_port,
        debug=True,
    )