"""Gradio front end for the "Supervised Learning Model Trainer" app.

Builds a two-pane layout: the left pane collects a CSV upload plus the
task / model / graph selections and the train button; the right pane shows
an optional CSV preview, the evaluation metrics table and the selected plot.
All callbacks are imported from the project's ``ui`` and ``core`` packages.
"""

import warnings

import gradio as gr

# Kept after the gradio import and before the project imports, exactly as in
# the original script, so the set of suppressed warnings does not change.
warnings.filterwarnings("ignore")

from ui.helpers import (
    reset_on_task_change,
    update_models,
    update_graphs,
    preview_csv,
    reset_metrics_on_file_clear,
    toggle_csv_preview,
)
from ui.theme import OrangeRedTheme
from ui.styles import CSS_STYLE
from core.training import train_model
from core.detection import auto_set_task

# Static option lists and example rows used by the widgets below.
_TASK_CHOICES = ["Regression", "Classification"]
_MODEL_GROUP_CHOICES = ["Basic", "Bagging", "Boosting", "Stacking"]
_EXAMPLE_ROWS = [
    ["assets/regression_example.csv", "Regression", "Basic"],
    ["assets/classification_example.csv", "Classification", "Basic"],
]

with gr.Blocks() as app:
    with gr.Column(elem_id="container"):
        gr.Markdown(value="## 👉 Supervised Learning Model Trainer")

        with gr.Row(equal_height=True):
            # ---- Left pane: inputs, example datasets, train button ----
            with gr.Column():
                with gr.Row(equal_height=True):
                    with gr.Column():
                        file_input = gr.File(
                            label="Upload CSV",
                            file_types=[".csv"],
                        )
                        show_preview = gr.Checkbox(
                            label="Show CSV Preview",
                            value=False,
                        )
                        task_type = gr.Dropdown(
                            choices=_TASK_CHOICES,
                            label="Task Type",
                            value="Regression",
                        )
                        model_group = gr.Dropdown(
                            label="Model Group",
                            choices=_MODEL_GROUP_CHOICES,
                            value="Basic",
                        )
                        # Created without choices here; the listeners wired
                        # below write to these two dropdowns.
                        model_name = gr.Dropdown(label="Model")
                        graph_type = gr.Dropdown(label="Graph Type")

                with gr.Row(equal_height=True):
                    gr.Examples(
                        examples=_EXAMPLE_ROWS,
                        inputs=[file_input, task_type, model_group],
                        label="Example Datasets",
                        examples_per_page=2,
                    )

                with gr.Row(equal_height=True):
                    run_btn = gr.Button(
                        value="Train & Evaluate",
                        variant="primary",
                        size="lg",
                    )

            # ---- Right pane: preview, metrics, plot ----
            with gr.Column():
                with gr.Row(equal_height=True):
                    with gr.Column():
                        # Starts hidden; toggle_csv_preview is wired to the
                        # "Show CSV Preview" checkbox further down.
                        csv_preview = gr.Dataframe(
                            label="CSV Preview",
                            interactive=False,
                            visible=False,
                        )
                        output = gr.Dataframe(
                            label="Evaluation Metrics",
                            interactive=False,
                        )

                with gr.Row(equal_height=True):
                    with gr.Column():
                        plot = gr.Plot(label="Selected Graph")

    # ---- Event wiring (registration order matches the original) ----

    # A change of the uploaded file feeds three independent callbacks.
    file_input.change(
        fn=preview_csv,
        inputs=file_input,
        outputs=csv_preview,
    )
    file_input.change(
        fn=auto_set_task,
        inputs=file_input,
        outputs=task_type,
    )
    file_input.change(
        fn=reset_metrics_on_file_clear,
        inputs=file_input,
        outputs=[output, plot],
    )

    task_type.change(
        fn=reset_on_task_change,
        inputs=task_type,
        outputs=[model_group, model_name],
    )
    model_group.change(
        fn=update_models,
        inputs=[task_type, model_group],
        outputs=model_name,
    )
    task_type.change(
        fn=update_graphs,
        inputs=task_type,
        outputs=graph_type,
    )

    show_preview.change(
        fn=toggle_csv_preview,
        inputs=show_preview,
        outputs=csv_preview,
    )

    # On page load, run the same two task-type callbacks once so the
    # dependent dropdowns are filled from the default task type.
    app.load(
        fn=reset_on_task_change,
        inputs=task_type,
        outputs=[model_group, model_name],
    )
    app.load(
        fn=update_graphs,
        inputs=task_type,
        outputs=graph_type,
    )

    run_btn.click(
        fn=train_model,
        inputs=[file_input, task_type, model_group, model_name, graph_type],
        outputs=[output, plot],
    )


if __name__ == "__main__":
    # The theme is instantiated before the queue is enabled, as in the
    # original script.
    launch_options = {
        "theme": OrangeRedTheme(),
        "css": CSS_STYLE,
        "show_error": True,
        "server_name": "0.0.0.0",  # listen on all network interfaces
        "server_port": 7860,
        "debug": True,
    }
    app.queue().launch(**launch_options)