# Author: Pawan Mane
# Commit: "Examples Changes" (c1d1768)
import gradio as gr
import warnings
warnings.filterwarnings("ignore")
from ui.helpers import (
reset_on_task_change,
update_models,
update_graphs,
preview_csv,
reset_metrics_on_file_clear,
toggle_csv_preview
)
from ui.theme import OrangeRedTheme
from ui.styles import CSS_STYLE
from core.training import train_model
from core.detection import auto_set_task
# Build the trainer UI.
# - Left column: dataset/model controls, example datasets and the train button.
# - Right column: optional CSV preview, the metrics table and the selected plot.
with gr.Blocks() as app:
    with gr.Column(elem_id="container"):
        # The pointing-hand emoji (U+1F449) is written as an escape so the title
        # cannot be garbled again by a wrong source-file encoding.
        gr.Markdown("## \U0001F449 Supervised Learning Model Trainer")
        with gr.Row(equal_height=True):
            # ---- Left column: inputs and actions ----
            with gr.Column():
                with gr.Row(equal_height=True):
                    with gr.Column():
                        file_input = gr.File(label="Upload CSV", file_types=[".csv"])
                        show_preview = gr.Checkbox(
                            label="Show CSV Preview",
                            value=False,
                        )
                        task_type = gr.Dropdown(
                            ["Regression", "Classification"],
                            label="Task Type",
                            value="Regression",
                        )
                        model_group = gr.Dropdown(
                            label="Model Group",
                            choices=["Basic", "Bagging", "Boosting", "Stacking"],
                            value="Basic",
                        )
                        # Created without choices; the load/change handlers
                        # registered below write to these dropdowns.
                        model_name = gr.Dropdown(label="Model")
                        graph_type = gr.Dropdown(label="Graph Type")
                with gr.Row(equal_height=True):
                    # Selecting an example fills the file, task type and model group.
                    gr.Examples(
                        examples=[
                            ["assets/regression_example.csv", "Regression", "Basic"],
                            ["assets/classification_example.csv", "Classification", "Basic"],
                        ],
                        inputs=[file_input, task_type, model_group],
                        label="Example Datasets",
                        examples_per_page=2,
                    )
                with gr.Row(equal_height=True):
                    run_btn = gr.Button("Train & Evaluate", variant="primary", size="lg")
            # ---- Right column: outputs ----
            with gr.Column():
                with gr.Row(equal_height=True):
                    with gr.Column():
                        # Starts hidden; toggle_csv_preview (wired to the
                        # "Show CSV Preview" checkbox) controls it.
                        csv_preview = gr.Dataframe(
                            label="CSV Preview", interactive=False, visible=False
                        )
                        output = gr.Dataframe(label="Evaluation Metrics", interactive=False)
                with gr.Row(equal_height=True):
                    with gr.Column():
                        plot = gr.Plot(label="Selected Graph")

    # ---- Event wiring (registered inside the Blocks context) ----
    # A changed upload refreshes the preview and lets auto_set_task choose the
    # task type. reset_metrics_on_file_clear also runs; presumably it clears
    # the metrics/plot when the file is removed (see ui.helpers).
    file_input.change(preview_csv, file_input, csv_preview)
    file_input.change(auto_set_task, file_input, task_type)
    file_input.change(reset_metrics_on_file_clear, inputs=file_input, outputs=[output, plot])
    # Changing the task rewrites the model group and model dropdowns.
    task_type.change(
        reset_on_task_change,
        inputs=task_type,
        outputs=[model_group, model_name],
    )
    # The model list depends on both the task and the chosen group.
    model_group.change(
        update_models,
        inputs=[task_type, model_group],
        outputs=model_name,
    )
    # Graph choices depend on the task type.
    task_type.change(update_graphs, task_type, graph_type)
    show_preview.change(
        toggle_csv_preview,
        inputs=show_preview,
        outputs=csv_preview,
    )
    # Run the same task-dependent updates on page load, so the Model and Graph
    # Type dropdowns (created without choices) are filled before any interaction.
    app.load(
        reset_on_task_change,
        inputs=task_type,
        outputs=[model_group, model_name],
    )
    app.load(update_graphs, task_type, graph_type)
    # Train the selected model on the uploaded CSV; show metrics and the chosen graph.
    run_btn.click(
        train_model,
        inputs=[file_input, task_type, model_group, model_name, graph_type],
        outputs=[output, plot],
    )
if __name__ == "__main__":
    # Script entry point: build the custom theme, then start the queued server.
    # theme/css are passed to launch() rather than gr.Blocks(); presumably this
    # targets a Gradio version where they live on launch() — confirm against
    # the pinned gradio version.
    orange_red_theme = OrangeRedTheme()
    launch_kwargs = {
        "theme": orange_red_theme,
        "css": CSS_STYLE,
        "show_error": True,        # report handler exceptions in the browser UI
        "server_name": "0.0.0.0",  # listen on all network interfaces
        "server_port": 7860,
        "debug": True,
    }
    queued_app = app.queue()
    queued_app.launch(**launch_kwargs)