| import os |
| import sys |
| import dash |
| from dash import dcc, html, Output, Input, State |
| from dash import callback_context |
| import dash_bootstrap_components as dbc |
| import logging |
| import base64 |
| import io |
| import matplotlib.pyplot as plt |
| import matplotlib.image as mpimg |
| import numpy as np |
| import random |
| import pandas as pd |
| from datetime import datetime |
| import re |
| import uuid |
| import threading |
| import queue |
| import yaml |
| import subprocess |
| import time |
| import signal |
| import urllib.parse |
| import zipfile |
| import tarfile |
| import shutil |
| from PIL import Image |
| import io as _io |
| |
| import multiprocessing |
| from PIL import Image, ImageOps |
| import re |
| import matplotlib.pyplot as plt |
| import numpy as np |
| import random |
| import io, base64, os |
| import matplotlib.image as mpimg |
|
|
# Module-wide logging at INFO level.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# Target of the GitHub icon link in the navbar.
ST_MODELZOO_SERVICES_REPO_URL = "https://github.com/STMicroelectronics/stm32ai-modelzoo-services/tree/main"


# Owner and name of the official (reference) Hugging Face Space; both can be
# overridden through environment variables of the same name.
OFFICIAL_SPACE_OWNER = os.getenv("OFFICIAL_SPACE_OWNER", "STMicroelectronics")
OFFICIAL_SPACE_NAME = os.getenv("OFFICIAL_SPACE_NAME", "stm32-modelzoo-app")
|
|
|
|
| def _host_variants(owner: str, space: str): |
| owner_l = (owner or "").strip().lower() |
| space_l = (space or "").strip().lower() |
|
|
| raw = [ |
| f"{owner_l}-{space_l}.hf.space", |
| f"{owner_l}--{space_l}.hf.space", |
| f"{owner_l}-{space_l.replace('_','-')}.hf.space", |
| f"{owner_l}--{space_l.replace('_','-')}.hf.space", |
| ] |
| |
| out = [] |
| for h in raw: |
| if h and h not in out: |
| out.append(h) |
| return out |
|
|
| |
# Comma-separated default list of hostnames of the official Space.
_DEFAULT_ORIGINAL_HOSTS = ",".join(_host_variants(OFFICIAL_SPACE_OWNER, OFFICIAL_SPACE_NAME))
# Hostnames treated as the official reference Space. The ORIGINAL_SPACE_HOSTS env
# var (comma-separated) overrides the default; entries are trimmed, lower-cased,
# and blank entries are dropped.
ORIGINAL_SPACE_HOSTS = [h.strip().lower() for h in os.getenv("ORIGINAL_SPACE_HOSTS", _DEFAULT_ORIGINAL_HOSTS).split(",") if h.strip()]


# Link used by the "Duplicate Space" button of the duplicate-blocker overlay.
DUPLICATE_URL = f"https://huggingface.co/spaces/{OFFICIAL_SPACE_OWNER}/{OFFICIAL_SPACE_NAME}?duplicate=true"


app = dash.Dash(
    __name__,
    external_stylesheets=[dbc.themes.BOOTSTRAP],
    # Some callback targets (e.g. sidebar buttons, prediction page widgets) only
    # exist in certain modes/pages, so missing-component errors are suppressed.
    suppress_callback_exceptions=True,
    assets_folder=os.path.join(os.path.dirname(__file__), "assets")
)
app.title = "STM32AI Model Zoo Dashboard"
|
|
# Custom HTML shell for the Dash app. The {%...%} placeholders are filled in by
# Dash. The inline <style> defines the ST colour palette and the CSS classes the
# layout refers to (navbar-custom, navbar-yellow-btn, navbar-gray-btn,
# section-title, section-subtitle, img-skeleton, ...).
app.index_string = '''
<!DOCTYPE html>
<html>
<head>
{%metas%}
<title>{%title%}</title>
{%favicon%}
{%css%}
<style>
body { background: #ffffff !important; font-family: Arial, sans-serif; }
.navbar-custom { background-color: #0E2140 !important; box-shadow: 0 2px 8px rgba(50,50,80,0.08); }
.footer-custom { background-color: #0E2140 !important; }
.card { background: #ffffff !important; border: 1px solid #A6ADB5 !important; }
.btn-primary, .navbar-yellow-btn {
background-color: #FFD200 !important;
border-color: #FFD200 !important;
color: #03234B !important;
}
.navbar-gray-btn {
background-color: #A6ADB5 !important;
border-color: #A6ADB5 !important;
color: #03234B !important;
}
.btn-success { background-color: #525A63 !important; border-color: #525A63 !important; color: #ffffff !important; }
h2, h4 { color: #03234B !important; }
hr { border-top: 2px solid #ffffff !important; }
.dash-upload { border: 2px dashed #3CB4E6 !important; background: #ffffff !important; color: #525A63 !important; }
.nav-link-custom { color: #ffffff !important; font-weight: bold; font-size: 1.1rem; border-radius: 20px; padding: 0.5rem 1.2rem; transition: background 0.2s, color 0.2s; }
.nav-link-custom:hover { background: #3CB4E6 !important; color: #ffffff !important; }
.nav-link-active { background: #3CB4E6 !important; color: #ffffff !important; }
.navbar-separator { border-left: 3px solid #3CB4E6; height: 32px; border-radius: 2px; margin: 0 1.2rem; box-shadow: 0 0 4px #3CB4E6; }
.btn-stop-round {
width: 44px;
height: 44px;
border-radius: 50%;
padding: 0;
display: flex;
align-items: center;
justify-content: center;
font-size: 1.5rem;
background-color: #dc3545 !important;
border: none !important;
color: #fff !important;
box-shadow: 0 2px 8px rgba(50,50,80,0.08);
transition: background 0.2s;
}
.btn-stop-round:hover {
background-color: #b52a37 !important;
}
.section-title {
font-size: 1.75rem;
font-weight: 700;
margin: 0 0 0.5rem 0;
color: #03234B;
line-height: 1.2;
}
.section-subtitle {
font-size: 0.85rem;
color: #6c757d;
margin: 0 0 1.25rem 0;
}
.predict-logs {
font-family: monospace;
width: 100%;
height: 100%;
}
.predict-image {
font-family: Arial, sans-serif;
width: 100%;
height: 100%;
display: flex;
align-items: center;
justify-content: center;
overflow: auto;
}
.predict-image img {
max-width: 100%;
height: auto;
display: block;
}
/* Showcase link style */
a.info-link { color: #2167b0; text-decoration: underline; }
a.info-link:hover { color: #114a86; text-decoration: underline; }
/* Placeholder skeleton */
.img-skeleton {
width: 100%;
height: 320px;
background: linear-gradient(90deg, #f0f2f5 25%, #e6e9ef 37%, #f0f2f5 63%);
border-radius: 8px;
}
/* Responsive card heights on small screens */
@media (max-width: 768px) {
.responsive-card { min-height: 520px !important; height: 520px !important; }
}
</style>
</head>
<body>
{%app_entry%}
<footer>
{%config%}
{%scripts%}
{%renderer%}
</footer>
</body>
</html>
'''
|
|
# Pretrained object-detection model used by prediction mode. The path is
# relative to the current working directory. Its presence is checked once, at
# import time.
MODEL_PATH = "object_detection/pretrained_models/st_yoloxn_d033_w025_320_int8.tflite"
model_exists = os.path.isfile(MODEL_PATH)
if model_exists:
    print(f"[INFO] Found model : {MODEL_PATH}")
else:
    print(f"[ERROR] Missing model : {MODEL_PATH}")


# Red banner rendered at the top of the page when the model is missing; None
# (which renders nothing) otherwise. The message is built from MODEL_PATH so it
# cannot drift from the path that was actually checked, and it names the real
# model format (TFLite, not ONNX).
model_alert = None
if not model_exists:
    model_alert = html.Div(
        f"[ERROR] TFLite model {MODEL_PATH} not found !",
        style={"background": "#ffcccc", "color": "#a00", "padding": "1rem", "textAlign": "center", "fontWeight": "bold"}
    )
|
|
# Root page layout. `model_alert` is None (renders nothing) unless the model file
# is missing, in which case it is a red error banner above everything else.
app.layout = html.Div([
    model_alert,
    html.Div([
        # --- Non-visual state: URL trackers, client-side stores and timers. ---
        dcc.Location(id="app-url", refresh=False),
        dcc.Store(id="is-original-host"),
        # Current sidebar/page mode; switch_mode writes "tab-predict" / "tab-train".
        dcc.Store(id="sidebar-mode", data="tab-train"),
        # Uploaded image and its size survive page reloads within the browser session.
        dcc.Store(id="stored-image", data=None, storage_type="session"),
        dcc.Store(id="input-image-size", data=None, storage_type="session"),
        # "idle" / "running" state of the background prediction.
        dcc.Store(id="predict-state", data="idle"),
        # Ticks drive poll_prediction; start_prediction enables it when a run starts.
        dcc.Interval(id="progress-interval", interval=1000, n_intervals=0, disabled=True),
        dcc.Store(id="stm32ai-credentials", data=None),
        dcc.Store(id="user-id", storage_type="local"),
        dcc.Store(id="train-log-raw", data=""),
        dcc.Store(id="navbar-mode-store"),
        dcc.Store(id="acc-fig-store"),
        dcc.Store(id="loss-fig-store"),
        dcc.Store(id="confusion-float-store"),
        dcc.Store(id="confusion-quant-store"),
        dcc.Store(id="confusion-last-type"),
        dcc.Store(id="predict-image-store", data=None),
        dcc.Location(id="url-initial", refresh=False),
        # Full-screen "Duplicate this Space" overlay. Hidden by default
        # (display: none); NOTE(review): presumably revealed by a callback when
        # the app runs on one of ORIGINAL_SPACE_HOSTS — that callback is not
        # visible here.
        html.Div(
            id="duplicate-blocker",
            children=html.Div(
                dbc.Card(
                    dbc.CardBody([
                        # Header strip: ST logo + product name.
                        html.Div([
                            html.Div(
                                html.Img(
                                    src="/assets/ST_logo.png",
                                    style={
                                        "height": "34px",
                                        "width": "auto",
                                        "display": "inline-block",
                                        "objectFit": "contain",
                                        "objectPosition": "center"
                                    }
                                ),
                                style={
                                    "width": "100%",
                                    "textAlign": "center",
                                    "lineHeight": 0
                                }
                            ),
                            html.Div(
                                "STM32AI Model Zoo",
                                style={
                                    "width": "100%",
                                    "color": "#FFFFFF",
                                    "fontWeight": "bold",
                                    "fontSize": "1.02rem",
                                    "marginTop": "0.45rem",
                                    "textAlign": "center"
                                }
                            )
                        ], style={
                            "width": "100%",
                            "boxSizing": "border-box",
                            "background": "#03234B",
                            "borderRadius": "14px",
                            "padding": "0.95rem 1.0rem",
                            "display": "flex",
                            "flexDirection": "column",
                            "alignItems": "center",
                            "justifyContent": "center"
                        }),


                        html.H2(
                            "Duplicate this Space to continue",
                            style={
                                "color": "#03234B",
                                "marginTop": "1.25rem",
                                "marginBottom": "0.5rem",
                                "textAlign": "center",
                                "fontWeight": 800,
                                "fontSize": "1.6rem"
                            }
                        ),
                        html.P(
                            "This is the official ST reference Space. For a stable experience and to enable training and prediction, "
                            "please duplicate it to your own Hugging Face account.",
                            style={
                                "color": "#525A63",
                                "fontSize": "0.98rem",
                                "textAlign": "center",
                                "maxWidth": "720px",
                                "margin": "0 auto 1.0rem auto",
                                "lineHeight": "1.5"
                            }
                        ),


                        # Three numbered steps (round badge + label each).
                        html.Div([
                            html.Div(
                                [
                                    html.Div("1", style={
                                        "width": "26px", "height": "26px", "borderRadius": "50%",
                                        "background": "#03234B", "color": "#fff",
                                        "display": "flex", "alignItems": "center", "justifyContent": "center",
                                        "fontWeight": "bold", "marginRight": "0.6rem"
                                    }),
                                    html.Div("Click Duplicate Space", style={"fontWeight": 600, "color": "#03234B"})
                                ],
                                style={"display": "flex", "alignItems": "center", "marginBottom": "0.5rem"}
                            ),
                            html.Div(
                                [
                                    html.Div("2", style={
                                        "width": "26px", "height": "26px", "borderRadius": "50%",
                                        "background": "#03234B", "color": "#fff",
                                        "display": "flex", "alignItems": "center", "justifyContent": "center",
                                        "fontWeight": "bold", "marginRight": "0.6rem"
                                    }),
                                    html.Div("Wait for the build to finish", style={"fontWeight": 600, "color": "#03234B"})
                                ],
                                style={"display": "flex", "alignItems": "center", "marginBottom": "0.5rem"}
                            ),
                            html.Div(
                                [
                                    html.Div("3", style={
                                        "width": "26px", "height": "26px", "borderRadius": "50%",
                                        "background": "#03234B", "color": "#fff",
                                        "display": "flex", "alignItems": "center", "justifyContent": "center",
                                        "fontWeight": "bold", "marginRight": "0.6rem"
                                    }),
                                    html.Div("Open your duplicated Space URL", style={"fontWeight": 600, "color": "#03234B"})
                                ],
                                style={"display": "flex", "alignItems": "center"}
                            )
                        ], style={
                            "background": "#EEEFF1",
                            "border": "1px solid #A6ADB5",
                            "borderRadius": "12px",
                            "padding": "0.9rem 1.0rem",
                            "marginTop": "0.75rem",
                            "marginBottom": "1.25rem"
                        }),


                        # Call to action: opens Hugging Face's duplicate dialog in a new tab.
                        html.Div([
                            html.A(
                                "Duplicate Space",
                                href=DUPLICATE_URL,
                                target="_blank",
                                style={
                                    "background": "#FFD200",
                                    "color": "#03234B",
                                    "padding": "0.9rem 1.4rem",
                                    "borderRadius": "10px",
                                    "fontWeight": "bold",
                                    "fontSize": "1.02rem",
                                    "textDecoration": "none",
                                    "display": "inline-block",
                                    "boxShadow": "0 6px 18px rgba(0,0,0,0.18)",
                                    "border": "1px solid rgba(3,35,75,0.18)"
                                }
                            )
                        ], style={"display": "flex", "justifyContent": "center", "flexWrap": "wrap", "gap": "0.5rem"}),


                        html.Div(
                            [
                                html.Div("You will be redirected to a new Space under your account.", style={"fontWeight": 600}),
                                html.Div("Once duplicated, reload your new URL to unlock all features.")
                            ],
                            style={
                                "marginTop": "1.25rem",
                                "color": "#525A63",
                                "fontSize": "0.9rem",
                                "textAlign": "center"
                            }
                        )
                    ]),
                    style={
                        "maxWidth": "780px",
                        "width": "100%",
                        "borderRadius": "16px",
                        "border": "1px solid rgba(255,210,0,0.55)",
                        "boxShadow": "0 14px 45px rgba(0,0,0,0.35)",
                        "overflow": "hidden",
                    },
                    className="card"
                ),
                # Centres the card inside the overlay.
                style={
                    "width": "100%",
                    "height": "100%",
                    "display": "flex",
                    "alignItems": "center",
                    "justifyContent": "center"
                }
            ),
            # Fixed, viewport-covering overlay above everything else (zIndex 99999).
            style={
                "position": "fixed",
                "top": 0,
                "left": 0,
                "right": 0,
                "bottom": 0,
                "background": "linear-gradient(135deg, rgba(3,35,75,0.98), rgba(3,35,75,0.90))",
                "zIndex": 99999,
                "display": "none",
                "padding": "2rem"
            }
        ),
        # --- Top navigation bar ---
        html.Div(
            id="navbar-main",
            children=dbc.Navbar(
                dbc.Container([
                    html.Div(
                        html.Img(src="/assets/ST_logo.png",
                                 style={"height": "38px", "marginRight": "1.2rem"}),
                        style={"display": "flex", "alignItems": "center"}
                    ),
                    dbc.NavbarBrand(
                        "STM32AI Model Zoo Experimentation Hub",
                        style={
                            "color": "#fff",
                            "fontWeight": "bold",
                            "fontSize": "1.3rem",
                            "marginRight": "2rem"
                        }
                    ),
                    # NOTE: labels and ids look swapped — the "Training Mode" button
                    # has id "nav-predict" (switch_mode maps it to "tab-predict") and
                    # the "Prediction Mode" button has id "nav-train". Callbacks
                    # depend on these ids; rename them together if ever cleaned up.
                    dbc.Nav([
                        dbc.Button(
                            "Training Mode (Image classification)",
                            id="nav-predict",
                            color="primary",
                            outline=False,
                            className="me-2 navbar-yellow-btn",
                            n_clicks=0
                        ),
                        dbc.Button(
                            "Prediction Mode (Object detection)",
                            id="nav-train",
                            color="primary",
                            outline=False,
                            className="navbar-gray-btn",
                            n_clicks=0
                        ),
                        # Round GitHub icon linking to the model zoo services repo.
                        html.A(
                            html.Img(
                                src=app.get_asset_url("github.svg"),
                                style={"height": "22px", "width": "22px", "display": "block"}
                            ),
                            id="navbar-github-link",
                            href=ST_MODELZOO_SERVICES_REPO_URL,
                            target="_blank",
                            title="STM32AI Model Zoo Services Repository",
                            style={
                                "display": "inline-flex",
                                "alignItems": "center",
                                "justifyContent": "center",
                                "padding": "0.45rem",
                                "borderRadius": "999px",
                                "textDecoration": "none",
                                "background": "transparent",
                                "border": "1px solid rgba(255,255,255,0.22)",
                            },
                        ),
                    ], className="ms-auto", style={"gap": "0.5rem"})
                ], style={"display": "flex", "alignItems": "center"}),
                color="#0E2140",
                dark=True,
                style={"marginBottom": "0.5rem", "padding": "0.3rem 2rem", "height": "56px", "backgroundColor": "#0E2140"},
                className="navbar-custom"
            )
        ),
        # --- Body: fixed-width icon sidebar + scrollable main content ---
        html.Div([
            # Icon buttons are filled in by update_sidebar from "sidebar-mode".
            html.Div(
                id="sidebar-content",
                children=[
                    html.Div(style={"flex": 1})
                ],
                style={
                    "display": "flex",
                    "flexDirection": "column",
                    "alignItems": "center",
                    "background": "#EEEFF1",
                    "paddingTop": "2rem",
                    "minWidth": "90px",
                    "maxWidth": "90px",
                    "width": "90px",
                    "boxSizing": "border-box"
                }
            ),
            html.Div([
                dcc.Location(id="netron-url"),
                dcc.Store(id="yaml-modal-store"),
                # Modals for the sidebar tools. NOTE(review): their bodies /
                # is_open flags are presumably driven by callbacks outside this view.
                dbc.Modal([
                    dbc.ModalHeader(dbc.ModalTitle("YAML Content")),
                    dbc.ModalBody(id="yaml-modal-body"),
                    dbc.ModalFooter(
                        dbc.Button("Close", id="close-yaml-modal", className="ms-auto", n_clicks=0)
                    ),
                ], id="yaml-modal", is_open=False, size="lg"),
                dbc.Modal([
                    dbc.ModalHeader(dbc.ModalTitle("Visualize Model with Netron")),
                    dbc.ModalBody(
                        html.Iframe(
                            id="netron-iframe",
                            src="",
                            style={"width": "100%", "height": "600px", "border": "none"}
                        )
                    ),
                    dbc.ModalFooter(
                        dbc.Button("Close", id="close-netron-modal", className="ms-auto", n_clicks=0)
                    ),
                ], id="netron-modal", is_open=False, size="xl"),
                dbc.Modal([
                    dbc.ModalHeader(dbc.ModalTitle("Visualize Flowers Dataset")),
                    dbc.ModalBody(id="dataset-modal-body"),
                    dbc.ModalFooter(
                        dbc.Button("Close", id="close-dataset-modal", className="ms-auto", n_clicks=0)
                    ),
                ], id="dataset-modal", is_open=False, size="xl"),
                dbc.Modal([
                    dbc.ModalHeader(dbc.ModalTitle("Confusion Matrix")),
                    dbc.ModalBody(id="confusion-matrix-modal-body"),
                    dbc.ModalFooter(
                        dbc.Button("Close", id="close-confusion-matrix-modal", className="ms-auto", n_clicks=0)
                    ),
                ], id="confusion-matrix-modal", is_open=False, size="xl"),
                # Hidden components that keep these ids present in the layout at all
                # times; "retry-predict-btn" is an Input of start_prediction.
                dcc.Upload(id="upload-image", style={"display": "none"}),
                # Mode-specific page content (presumably set by a routing callback).
                html.Div(id="page-content", style={"marginTop": "2rem"}),
                dbc.Button("", id="show-dataset-btn", style={"display": "none"}),

                dbc.Button("", id="retry-predict-btn", style={"display": "none"})
            ], style={"flex": 1, "minWidth": 0, "overflow": "auto", "marginLeft": "40px"})
        ], style={
            "display": "flex",
            "flexDirection": "row",
            "flex": "1 1 auto",
            "minHeight": 0,
            "height": "100%",
            "margin": "0"
        })
    ], style={"display": "flex", "flexDirection": "column", "height": "100vh", "minHeight": "100vh"})
])
|
|
@app.callback(
    Output("sidebar-mode", "data"),
    [Input("nav-predict", "n_clicks"), Input("nav-train", "n_clicks")],
    prevent_initial_call=True
)
def switch_mode(n_predict, n_train):
    """Store the sidebar mode that matches the navbar button just clicked.

    "nav-predict" selects "tab-predict" and "nav-train" selects "tab-train".
    The click counts are ignored; only the triggering component id matters.
    Anything else (including no trigger) leaves the store unchanged.
    """
    fired = dash.callback_context.triggered
    if not fired:
        return dash.no_update
    source_id = fired[0]["prop_id"].split(".")[0]
    mode_for_button = {
        "nav-predict": "tab-predict",
        "nav-train": "tab-train",
    }
    return mode_for_button.get(source_id, dash.no_update)
|
|
| |
@app.callback(
    Output("sidebar-content", "children"),
    Input("sidebar-mode", "data")
)
def update_sidebar(tab_value):
    """Build the sidebar icon buttons for the current mode.

    "tab-predict": YAML, dataset and Netron buttons, each with a right-hand
    tooltip. "tab-train": YAML and Netron buttons (only Netron has a tooltip).
    Any other value yields an empty sidebar. Each list ends with a flexible
    spacer so the icons stay at the top.
    """
    def icon_button(image_src, image_style, button_id):
        # Flat light button wrapping an icon, blended into the sidebar background.
        return dbc.Button(
            html.Img(src=image_src, style=image_style),
            id=button_id,
            color="light",
            style={"margin": "1rem 0", "background": "#EEEFF1", "border": "none"},
        )

    def right_tip(text, target):
        return dbc.Tooltip(text, target=target, placement="right")

    def yaml_button():
        return icon_button("/assets/yaml_logo.png", {"height": "35px"}, "sidebar-yaml")

    def netron_button():
        return icon_button(
            "/assets/netron.png",
            {"height": "65px", "width": "60px", "objectFit": "contain"},
            "sidebar-netron",
        )

    if tab_value == "tab-predict":
        return [
            yaml_button(),
            right_tip("Visualize YAML file for image_classification", "sidebar-yaml"),
            icon_button("/assets/media.png", {"height": "35px"}, "sidebar-visualize-dataset"),
            right_tip("Visualize Dataset", "sidebar-visualize-dataset"),
            netron_button(),
            right_tip("Visualize Model with Netron", "sidebar-netron"),
            html.Div(style={"flex": 1}),
        ]
    if tab_value == "tab-train":
        return [
            yaml_button(),
            netron_button(),
            right_tip("Visualize Model with Netron", "sidebar-netron"),
            html.Div(style={"flex": 1}),
        ]
    return []
|
|
|
|
| |
def render_model_perf_od():
    """Return a card with the ST YOLOX (d033_w025) INT8 performance table on STM32N6.

    Columns shared by all variants (model, hyperparameters, series, dataset,
    format) are merged vertically with ``rowSpan`` per group of identical values;
    the per-resolution columns get one cell per row. Missing values render as "-".
    """
    shared_columns = [
        "Model",
        "Hyperparameters (depth_width)",
        "Serie",
        "Dataset",
        "Format",
    ]
    per_variant_columns = [
        "Resolution",
        "Internal RAM (KB)",
        "Weights Flash (KB)",
        "Inference Time (ms)",
    ]

    common_values = {
        "Model": "st_yoloxn",
        "Hyperparameters (depth_width)": "d033_w025",
        "Serie": "STM32N6",
        "Dataset": "COCO-Person",
        "Format": "Quantized INT8",
    }
    measurements = [
        {"Resolution": "192x192x3", "Internal RAM (KB)": "333", "Weights Flash (KB)": "877", "Inference Time (ms)": "6"},
        {"Resolution": "256x256x3", "Internal RAM (KB)": "624", "Weights Flash (KB)": "885", "Inference Time (ms)": "9"},
        {"Resolution": "320x320x3", "Internal RAM (KB)": "1125", "Weights Flash (KB)": "895", "Inference Time (ms)": "13"},
        {"Resolution": "416x416x3", "Internal RAM (KB)": "2676", "Weights Flash (KB)": "904", "Inference Time (ms)": "21"},
    ]
    table_data = [{**common_values, **measurement} for measurement in measurements]

    header = html.Thead(html.Tr([html.Th(name) for name in shared_columns + per_variant_columns]))

    # Bucket rows by their shared-column values (dicts keep first-seen order).
    buckets = {}
    for record in table_data:
        signature = tuple(record.get(name, "-") for name in shared_columns)
        buckets.setdefault(signature, []).append(record)

    table_rows = []
    for members in buckets.values():
        span = len(members)
        for position, record in enumerate(members):
            # Only the first row of a bucket carries the merged (rowSpan) cells.
            merged_cells = (
                [html.Td(record.get(name, "-"), rowSpan=span) for name in shared_columns]
                if position == 0 else []
            )
            own_cells = [html.Td(record.get(name, "-")) for name in per_variant_columns]
            table_rows.append(html.Tr(merged_cells + own_cells))
    body = html.Tbody(table_rows)

    card_header = dbc.CardHeader(
        html.Div([
            html.Span("ST Yolo X Quantized Model Performances overview"),
            dbc.Badge("STM32 Reference", color="secondary", pill=True, className="ms-2")
        ], style={"display": "flex", "alignItems": "center", "gap": "8px"}),
        style={"fontWeight": "bold", "fontSize": "1.1rem", "background": "#f8f9fa"}
    )
    card_body = dbc.CardBody([
        html.P(
            " ST Yolo X Quantized Performances on STM32N6. Values depend on resolution and quantization. Fill with STM32Cube.AI 3.0.0 results.",
            style={"color": "#525A63", "fontSize": "0.95rem", "marginBottom": "0.75rem"}
        ),
        dbc.Table([header, body], bordered=True, hover=True, responsive=True, className="table-sm"),
    ])
    return dbc.Card([card_header, card_body], style={
        "boxShadow": "0 2px 8px rgba(50,50,80,0.08)",
        "border": "1px solid #A6ADB5",
        "marginTop": "0.5rem",
        "marginBottom": "0.5rem",
    })
|
|
def render_object_detection_page():
    """Build the children of the object-detection prediction page.

    Layout: section title and subtitle, the showcase info card, then a row of
    two half-width columns. The left column holds image upload/preview, the
    Launch/Reset buttons and the confidence-threshold slider. The right column
    holds prediction output, the download button and the download target.
    Both columns and both cards share the same ``card_height`` expression so
    they line up.

    Returns:
        A list of Dash components.
    """
    # At most 760px tall; on short viewports shrink so the cards fit below the
    # page chrome (240px reserved — presumably navbar + headings).
    card_height = "min(760px, calc(100vh - 240px))"
    return [
        html.H2("Object Detection Prediction", className="section-title"),
        html.P("Upload an image to run object detection.", className="section-subtitle"),
        render_showcase_info_card(),
        dbc.Row([
            dbc.Col([
                dbc.Card([
                    dbc.CardHeader("Image Upload & Preview", style={
                        "fontWeight": "bold", "fontSize": "1.1rem", "background": "#f8f9fa"
                    }),
                    dbc.CardBody([
                        html.Div(
                            id="upload-preview-box",
                            style={
                                "flex": 1,
                                "display": "flex",
                                "flexDirection": "column",
                                "justifyContent": "center",
                                "alignItems": "center",
                                "overflow": "hidden",
                                "padding": "0.25rem"
                            }
                        ),
                        html.Div([
                            dbc.Button(
                                "Launch Prediction",
                                id="predict-btn",
                                color="primary",
                                style={
                                    "marginBottom": "1rem",
                                    "boxShadow": "0 4px 16px rgba(60,180,230,0.18), 0 1.5px 4px rgba(50,50,80,0.10)",
                                    "marginRight": "1rem"
                                }
                            ),
                            # Hidden until an image is present (toggled by another callback).
                            dbc.Button(
                                "Reset image",
                                id="reset-image-btn",
                                color="secondary",
                                style={"display": "none", "marginBottom": "1rem"}
                            )
                        ], style={
                            "display": "flex",
                            "flexDirection": "row",
                            "justifyContent": "center",
                            "alignItems": "center"
                        }),
                        html.Div([
                            html.Div(id="conf-thresh-value", style={"fontWeight": 600, "marginBottom": "0.25rem", "textAlign": "center"}),
                            # Its value is an Input of start_prediction, so moving it re-runs prediction.
                            dcc.Slider(
                                id="conf-thresh-slider",
                                min=0.0,
                                max=1.0,
                                step=0.01,
                                value=0.5,
                                tooltip={"placement": "bottom", "always_visible": False},
                                marks={0.0: "0.0", 0.5: "0.5", 1.0: "1.0"}
                            ),
                            html.Div("Adjust confidence threshold; changing it re-runs prediction.", style={"fontSize": "0.85rem", "color": "#6c757d", "marginTop": "0.25rem", "textAlign": "center"})
                        ], style={"marginTop": "0.5rem", "padding": "0.25rem 0.25rem 0 0.25rem"})
                    ], style={
                        "overflowY": "auto",
                        "padding": "1rem",
                        "height": "100%",
                        "display": "flex",
                        "flexDirection": "column",
                        "boxSizing": "border-box"
                    }),
                ], style={
                    "boxShadow": "0 2px 8px rgba(50,50,80,0.08)",
                    "border": "1px solid #A6ADB5",
                    # Match the column (and the right-hand card). A fixed 800px
                    # height overflowed the column, which is capped at card_height
                    # (<= 760px).
                    "height": card_height,
                    "minHeight": card_height,
                    "maxHeight": card_height,
                    "display": "flex",
                    "flexDirection": "column"
                })
            ], width=6, style={
                "display": "flex",
                "flexDirection": "column",
                "height": card_height,
                "minHeight": card_height,
                "maxHeight": card_height,
                "boxSizing": "border-box"
            }),
            dbc.Col([
                dbc.Card([
                    dbc.CardHeader("Prediction Logs", style={
                        "fontWeight": "bold", "fontSize": "1.1rem", "background": "#f8f9fa"
                    }),
                    dbc.CardBody([
                        html.Div(id="progress-bar-box"),
                        dcc.Loading(
                            id="predict-loading",
                            type="circle",
                            color="#3CB4E6",
                            fullscreen=False,
                            children=html.Div(
                                id="predict-output",
                                style={
                                    "flex": 1,
                                    "height": "100%",
                                    "marginBottom": "1rem",
                                    "background": "#f8f9fa",
                                    "borderRadius": "8px",
                                    "padding": "0.5rem",
                                    "boxSizing": "border-box",
                                    "display": "flex",
                                    "flexDirection": "column",
                                    "alignItems": "center",
                                    "justifyContent": "center",
                                    "overflowY": "auto",
                                    "overflowX": "hidden"
                                }
                            ),
                            style={"flex": 1}
                        ),
                        dbc.Button(
                            "Download Prediction",
                            id="download-prediction-btn",
                            color="secondary",
                            style={"display":"none", "marginTop": "0.25rem", "marginBottom": "0.5rem"}
                        ),
                        dcc.Download(id="download-prediction"),
                        dcc.Interval(id="predict-interval", interval=300, n_intervals=0, disabled=True)
                    ], style={
                        "overflow": "hidden",
                        "padding": "1rem",
                        "height": "100%",
                        "display": "flex",
                        "flexDirection": "column"
                    }),
                ], style={
                    "boxShadow": "0 2px 8px rgba(50,50,80,0.08)",
                    "border": "1px solid #A6ADB5",
                    "height": card_height,
                    "minHeight": card_height,
                    "display": "flex",
                    "flexDirection": "column"
                })
            ], width=6, style={
                "display": "flex",
                "flexDirection": "column",
                "height": card_height,
                "minHeight": card_height
            })
        ], justify="center", align="stretch", style={"marginTop": "2rem", "alignItems": "stretch"})
    ]
|
|
# Process-wide state of the (single) background prediction job.
# start_prediction resets it, and launch_prediction_async fills it from a worker
# thread:
#   tail     - last 80 output lines while running
#   result   - full output
#   image    - path of the prediction image or None
#   finished - True once done
#   orig_size / target_basename - describe the uploaded image
# NOTE(review): a module-level dict shared by every browser session, so two
# users predicting at once overwrite each other's job.
PREDICT_JOB = {"process": None, "result": None, "tail": "", "finished": False, "image": None, "orig_size": None, "target_basename": None}
# Mode passed to _resize_prediction_image; set via env var PREDICT_RESIZE_MODE,
# default "stretch".
PREDICT_RESIZE_MODE = os.getenv("PREDICT_RESIZE_MODE", "stretch")
# Common Plotly layout settings (fixed x-axis titled "Epochs", white
# background, unified hover).
# NOTE(review): not used in the visible code; presumably the base layout for the
# accuracy/loss figures (acc-fig-store / loss-fig-store).
ACC_LOSS_BASE_LAYOUT = {
    'uirevision': 'metrics-static',
    'xaxis': {
        'title': 'Epochs',
        'showgrid': True,
        'gridcolor': '#EEEFF1',
        'linecolor': '#A6ADB5',
        'linewidth': 1,
        'fixedrange': True
    },
    'plot_bgcolor': '#fff',
    'paper_bgcolor': '#fff',
    'margin': {'l': 60, 'r': 40, 't': 10, 'b': 60},
    'hovermode': 'x unified',
    'showlegend': True,
    'legend': {'bgcolor': 'rgba(0,0,0,0)'}
}
|
|
| def launch_prediction_async(image_path, target_basename=None): |
| env = os.environ.copy() |
| env["OMP_NUM_THREADS"] = "1" |
| env["MKL_NUM_THREADS"] = "1" |
| env["NUMEXPR_NUM_THREADS"] = "1" |
| env.setdefault("HYDRA_FULL_ERROR", "1") |
|
|
| od_cwd = os.path.join(os.path.dirname(os.path.abspath(__file__)), "object_detection") |
| process = subprocess.Popen( |
| [sys.executable, "stm32ai_main.py", "--config-name", "user_config"], |
| stdout=subprocess.PIPE, |
| stderr=subprocess.STDOUT, |
| text=True, |
| bufsize=1, |
| env=env, |
| cwd=od_cwd |
| ) |
| output_lines = [] |
| for line in process.stdout: |
| output_lines.append(line) |
| |
| tail = "".join(output_lines[-80:]) |
| PREDICT_JOB["tail"] = tail |
| process.wait() |
| PREDICT_JOB["result"] = "".join(output_lines) |
| pred_image = None |
|
|
| |
| |
| exp_roots = [] |
| try: |
| yaml_cfg_path = os.path.join("object_detection", "user_config.yaml") |
| with open(yaml_cfg_path, "r", encoding="utf-8") as f: |
| ocfg = yaml.safe_load(f) or {} |
| run_dir = (ocfg.get("hydra") or {}).get("run", {}) |
| run_dir_str = run_dir.get("dir") if isinstance(run_dir, dict) else None |
| if isinstance(run_dir_str, str) and run_dir_str: |
| run_dir_str = run_dir_str.replace("\\", "/") |
| if "${now:" in run_dir_str: |
| run_dir_str = os.path.dirname(run_dir_str) |
| if run_dir_str.startswith("./"): |
| exp_roots.append(os.path.join("object_detection", run_dir_str[2:])) |
| else: |
| exp_roots.append(os.path.join("object_detection", run_dir_str) if not os.path.isabs(run_dir_str) else run_dir_str) |
| except Exception: |
| pass |
|
|
| |
| exp_roots.extend([ |
| os.path.join("object_detection", "tf", "src", "experiments_outputs"), |
| os.path.join("object_detection", "src", "experiments_outputs"), |
| ]) |
|
|
| def _find_latest_prediction_image(exp_base: str): |
| if not exp_base or not os.path.exists(exp_base): |
| return None |
| subs = [d for d in os.listdir(exp_base) |
| if os.path.isdir(os.path.join(exp_base, d)) and re.match(r"^20\d{2}_\d{2}_\d{2}_\d{2}_\d{2}_\d{2}$", d)] |
| if not subs: |
| return None |
| subs.sort(reverse=True) |
| last_exp = os.path.join(exp_base, subs[0], "predictions") |
| if not os.path.exists(last_exp): |
| return None |
| if target_basename: |
| for ext in (".jpg", ".jpeg", ".png"): |
| candidate = os.path.join(last_exp, f"{target_basename}_predict{ext}") |
| if os.path.exists(candidate): |
| return candidate |
| imgs = [os.path.join(last_exp, f) for f in os.listdir(last_exp) |
| if f.lower().endswith((".jpg", ".jpeg", ".png"))] |
| if imgs: |
| return max(imgs, key=os.path.getmtime) |
| return None |
|
|
| for root in exp_roots: |
| pred_image = _find_latest_prediction_image(root) |
| if pred_image: |
| break |
| PREDICT_JOB["image"] = pred_image |
| PREDICT_JOB["finished"] = True |
|
|
def _yaml_section(cfg, key):
    """Return ``cfg[key]``, creating an empty dict if the key is missing or null.

    An empty ``section:`` in YAML loads as None, and ``dict.setdefault`` would
    return that None. Other non-dict values are returned untouched so the caller
    fails (and reports) instead of silently discarding them.
    """
    section = cfg.get(key)
    if section is None:
        section = {}
        cfg[key] = section
    return section


def _prepare_prediction_config(raw, fname, image_path, conf_thresh):
    """Put the upload in the prediction folder and switch user_config.yaml to prediction.

    Args:
        raw: decoded image bytes.
        fname: unique, basename-only file name for the upload.
        image_path: where the upload was already written (skipped if identical).
        conf_thresh: slider value; 0.5 if missing or not numeric.

    Raises:
        OSError / yaml.YAMLError / AttributeError on unreadable or malformed
        config; the caller treats this as best-effort.
    """
    yaml_cfg_path = os.path.join("object_detection", "user_config.yaml")
    with open(yaml_cfg_path, "r", encoding="utf-8") as yf:
        ocfg = yaml.safe_load(yf) or {}

    ocfg["operation_mode"] = "prediction"

    dataset_cfg = _yaml_section(ocfg, "dataset")
    pred_dir = dataset_cfg.get("prediction_path")
    # Only non-empty relative paths are kept; anything else falls back to ./datasets.
    if not isinstance(pred_dir, str) or not pred_dir or os.path.isabs(pred_dir):
        pred_dir = "./datasets"
    dataset_cfg["prediction_path"] = pred_dir

    pred_cfg = _yaml_section(ocfg, "prediction")
    pred_cfg.setdefault("target", "host")
    # NOTE(review): presumably an explicit file list that would override
    # prediction_path — removed so the uploaded image is picked up. Confirm.
    pred_cfg.pop("test_files_path", None)

    pp_cfg = _yaml_section(ocfg, "postprocessing")
    try:
        val = float(conf_thresh) if conf_thresh is not None else 0.5
    except (TypeError, ValueError):
        val = 0.5
    pp_cfg["confidence_thresh"] = val

    # prediction_path is relative to the object_detection project directory.
    pred_dir_abs = os.path.join("object_detection", pred_dir[2:] if pred_dir.startswith("./") else pred_dir)
    os.makedirs(pred_dir_abs, exist_ok=True)
    target_copy = os.path.join(pred_dir_abs, fname)
    # With the default ./datasets this is the file already written by the caller.
    if os.path.abspath(target_copy) != os.path.abspath(image_path):
        with open(target_copy, "wb") as cf:
            cf.write(raw)

    with open(yaml_cfg_path, "w", encoding="utf-8") as yf:
        yaml.dump(ocfg, yf, allow_unicode=True, sort_keys=False, default_flow_style=False)


def _upload_image_size(input_size, raw):
    """Return (width, height) of the upload, or None if it cannot be determined.

    Uses the client-reported size when it has usable numbers; otherwise opens
    the image bytes with PIL.
    """
    if input_size and "width" in input_size and "height" in input_size:
        try:
            return (int(input_size["width"]), int(input_size["height"]))
        except (TypeError, ValueError):
            pass
    try:
        with Image.open(io.BytesIO(raw)) as im:
            return im.size
    except Exception:
        return None


@app.callback(
    [Output("predict-state", "data"),
     Output("progress-interval", "disabled"),
     Output("progress-interval", "n_intervals"),
     Output("predict-output", "children", allow_duplicate=True),
     Output("predict-image-store", "data")],
    [Input("predict-btn", "n_clicks"), Input("retry-predict-btn", "n_clicks"), Input("conf-thresh-slider", "value")],
    State("stored-image", "data"),
    State("input-image-size", "data"),
    State("is-original-host", "data"),
    prevent_initial_call=True
)
def start_prediction(n, n_retry, conf_thresh, stored, input_size, is_original):
    """Start a background object-detection run for the uploaded image.

    Fired by the Launch/Retry buttons or by a change of the confidence slider.
    On success:
      - the upload is saved under object_detection/datasets;
      - user_config.yaml is switched to prediction with the chosen threshold
        (best-effort);
      - PREDICT_JOB is reset;
      - launch_prediction_async starts on a daemon thread;
      - the progress interval is enabled.

    Returns:
        (predict-state, progress-interval disabled, progress-interval
        n_intervals, predict-output children, predict-image-store data).
    """
    ctx = dash.callback_context
    triggered = ctx.triggered[0]["prop_id"] if ctx.triggered else ""
    n = n or 0
    n_retry = n_retry or 0
    slider_changed = triggered.startswith("conf-thresh-slider.")
    if (n + n_retry) == 0 and not slider_changed:
        return dash.no_update, True, dash.no_update, dash.no_update, dash.no_update
    if is_original:
        return "idle", True, 0, html.Div(
            "Prediction disabled on reference Space. Duplicate to enable.",
            style={"color": "#dc3545", "fontWeight": "bold"}
        ), None
    if not stored or "contents" not in stored:
        return "idle", True, 0, html.Div("Upload an image first.", style={"color": "#dc3545"}), None

    # Upload contents are a data URL: "data:<mime>;base64,<payload>".
    try:
        _, b64data = stored["contents"].split(",", 1)
        raw = base64.b64decode(b64data)
    except (AttributeError, TypeError, ValueError):
        return "idle", True, 0, html.Div(
            "Could not read the uploaded image. Please upload it again.",
            style={"color": "#dc3545"}
        ), None

    # basename() strips any client-supplied directories; the uuid prefix keeps names unique.
    upload_name = os.path.basename(stored.get("filename") or "") or "input.jpg"
    fname = f"{uuid.uuid4().hex}_{upload_name}"
    image_path = os.path.join("object_detection", "datasets", fname)
    os.makedirs(os.path.dirname(image_path), exist_ok=True)
    with open(image_path, "wb") as f:
        f.write(raw)

    try:
        _prepare_prediction_config(raw, fname, image_path, conf_thresh)
    except Exception as e:
        # Best effort: the run still starts with whatever config is on disk.
        print(f"[WARN] Unable to prepare prediction folder: {e}")

    orig_size = _upload_image_size(input_size, raw)
    basename_no_ext = os.path.splitext(fname)[0]
    PREDICT_JOB.update({
        "process": None,
        "result": None,
        "tail": "",
        "finished": False,
        "image": None,
        "orig_size": orig_size,
        "target_basename": basename_no_ext
    })
    t = threading.Thread(target=launch_prediction_async, args=(image_path, basename_no_ext), daemon=True)
    t.start()
    return "running", False, 0, html.Div("Prediction launched...", style={"fontFamily": "monospace"}), None
|
|
def _resize_prediction_image(pred_path, orig_size, mode=None):
    """Return the prediction image at ``pred_path`` as base64, resized to ``orig_size``.

    Args:
        pred_path: Path of the image file produced by the prediction run.
        orig_size: ``(width, height)`` to resize to. When falsy, the file's raw
            bytes are encoded unchanged.
        mode: Currently unused; kept so existing call sites keep working.

    Returns:
        ASCII base64 string. It encodes a PNG of the EXIF-transposed, RGB,
        bilinear-resized image. If no size is given, or if the size is malformed
        or the image cannot be processed, it encodes the file's raw bytes instead.
    """
    def _raw_b64():
        # Fallback: ship the file exactly as written by the prediction run.
        with open(pred_path, "rb") as f:
            return base64.b64encode(f.read()).decode()

    if not orig_size:
        return _raw_b64()
    try:
        # Parsed inside the try so a malformed size degrades to the raw image
        # instead of raising out of the polling callback.
        target_w, target_h = (int(v) for v in orig_size)
        with Image.open(pred_path) as im:
            # Image.resize always returns exactly the requested size, so a single
            # resize is sufficient.
            resized = ImageOps.exif_transpose(im).convert("RGB").resize(
                (target_w, target_h), Image.Resampling.BILINEAR
            )
        buf = _io.BytesIO()
        resized.save(buf, format="PNG")
        return base64.b64encode(buf.getvalue()).decode()
    except Exception:
        return _raw_b64()
|
|
@app.callback(
    [Output("predict-output", "children", allow_duplicate=True),
     Output("predict-state", "data", allow_duplicate=True),
     Output("progress-interval", "disabled", allow_duplicate=True),
     Output("predict-image-store", "data", allow_duplicate=True)],
    Input("progress-interval", "n_intervals"),
    State("predict-state", "data"),
    prevent_initial_call=True
)
def poll_prediction(n, state):
    """Poll the background prediction job on each interval tick.

    Outputs: (predict-output children, predict-state, interval disabled flag,
    base64 image for the download store).

    - If ``state`` is not ``"running"``, nothing is updated.
    - While ``PREDICT_JOB["finished"]`` is false, a skeleton placeholder is shown
      and polling stays enabled.
    - Once finished, the result panel is rendered, the state becomes ``"done"``
      and polling is disabled. The panel contains either the resized image or an
      error plus a Retry button, followed by the last 80 lines of
      ``PREDICT_JOB["result"]``.
    """
    skip = dash.no_update
    if state != "running":
        return skip, skip, skip, skip

    if not PREDICT_JOB["finished"]:
        skeleton = html.Div([html.Div(className="img-skeleton")], style={"width": "100%"})
        return skeleton, "running", False, skip

    image_path = PREDICT_JOB.get("image")
    b64img = None
    if image_path and os.path.exists(image_path):
        b64img = _resize_prediction_image(image_path, PREDICT_JOB.get("orig_size"), PREDICT_RESIZE_MODE)
        picture = html.Img(
            src=f"data:image/png;base64,{b64img}",
            style={"borderRadius": "8px", "marginBottom": "1rem"},
            alt="Prediction Image",
        )
        panel = [html.Div(className="predict-image", children=picture, style={"width": "100%", "height": "100%"})]
    else:
        # NOTE(review): render_page_content's layout already holds a hidden button
        # with id "retry-predict-btn"; adding another one here duplicates that id in
        # the DOM — confirm this is handled as intended.
        panel = [
            html.Div("No prediction image produced.", style={"color": "#dc3545", "fontWeight": "bold", "marginBottom": "0.5rem"}),
            dbc.Button("Retry", id="retry-predict-btn", color="warning", outline=True, style={"marginBottom": "0.5rem"}),
        ]

    job_output = PREDICT_JOB.get("result")
    if job_output:
        last_lines = "".join(job_output.splitlines(True)[-80:])
        log_style = {
            "margin": 0,
            "marginTop": "0.75rem",
            "whiteSpace": "pre-wrap",
            "fontFamily": "monospace",
            "background": "#f8f9fa",
            "padding": "0.75rem",
            "borderRadius": "6px",
            "maxHeight": "280px",
            "overflowY": "auto",
        }
        panel.append(html.Pre(last_lines, style=log_style))
    return html.Div(panel), "done", True, b64img
|
|
@app.callback(
    Output("conf-thresh-value", "children"),
    Input("conf-thresh-slider", "value"),
    prevent_initial_call=False
)
def update_conf_label(v):
    """Render the confidence-threshold label; non-numeric input is shown as 0.50."""
    try:
        threshold = float(v)
    except Exception:
        threshold = 0.50
    return "Confidence threshold: {:.2f} (NMS 0.5, max boxes 10)".format(threshold)
@app.callback(
    [Output("stored-image", "data"),
     Output("input-image-size", "data")],
    [
        Input("upload-image", "contents"),
        Input("upload-image", "filename"),
        Input("reset-image-btn", "n_clicks")
    ],
    prevent_initial_call=True
)
def update_stored_image(contents, filename, reset_click):
    """Store an uploaded image and its pixel size, or clear both on reset.

    Returns ``(stored_image, size)``:
    - ``(None, None)`` when the reset button triggered with a truthy click count.
    - ``({"contents", "filename"}, {"width", "height"})`` for an upload that
      Pillow can open.
    - ``(stored_image, None)`` if decoding or opening the upload fails.
    - ``no_update`` for both outputs otherwise.
    """
    ctx = dash.callback_context
    if not ctx.triggered:
        return dash.no_update, dash.no_update
    source_id = ctx.triggered[0]["prop_id"].split(".")[0]

    if source_id == "reset-image-btn" and reset_click:
        return None, None
    if source_id != "upload-image" or not (contents and filename):
        return dash.no_update, dash.no_update

    stored = {"contents": contents, "filename": filename}
    try:
        _, payload = contents.split(",", 1)
        with Image.open(io.BytesIO(base64.b64decode(payload))) as img:
            width, height = img.size
    except Exception:
        return stored, None
    return stored, {"width": width, "height": height}
|
|
|
|
@app.callback(
    [Output("predict-image-store", "data", allow_duplicate=True),
     Output("predict-output", "children", allow_duplicate=True),
     Output("predict-state", "data", allow_duplicate=True)],
    Input("reset-image-btn", "n_clicks"),
    prevent_initial_call=True
)
def reset_prediction(n):
    """On a reset click, clear the shared prediction job and return the UI to idle.

    Clears the stored image and the output panel, and sets the state to ``"idle"``.
    ``PREDICT_JOB["tail"]`` is left untouched.
    """
    if n:
        PREDICT_JOB.update(
            process=None,
            result=None,
            finished=False,
            image=None,
            orig_size=None,
            target_basename=None,
        )
        return None, html.Div(), "idle"
    return dash.no_update, dash.no_update, dash.no_update
|
|
|
|
@app.callback(
    Output("upload-preview-box", "children"),
    Input("stored-image", "data"),
    prevent_initial_call=False
)
def update_upload_preview_box(data):
    """Render the drop zone, the uploaded-image preview and the reset button.

    The ``dcc.Upload`` is always rendered; it is only hidden once ``data`` holds
    ``contents``. In that case the preview ``<img>`` appears and the
    "Reset image" button becomes visible.
    """
    has_image = bool(data and "contents" in data)

    drop_zone = dcc.Upload(
        id="upload-image",
        children=html.Div(["Drag and Drop or ", html.A("Select Files")]),
        style={
            "width": "100%",
            "height": "120px",
            "lineHeight": "120px",
            "borderWidth": "2px",
            "borderStyle": "dashed",
            "borderRadius": "8px",
            "textAlign": "center",
            "marginBottom": "1rem",
            "background": "#fff",
            "color": "#525A63",
            "display": "none" if has_image else "block",
        },
        multiple=False,
    )

    preview = []
    if has_image:
        preview.append(
            html.Img(
                src=data["contents"],
                style={
                    "maxWidth": "100%",
                    "maxHeight": "350px",
                    "display": "block",
                    "marginLeft": "auto",
                    "marginRight": "auto",
                    "borderRadius": "8px",
                    "boxShadow": "0 2px 8px rgba(50,50,80,0.08)",
                },
            )
        )

    reset_button = dbc.Button(
        "Reset image",
        id="reset-image-btn",
        color="secondary",
        style={"display": "inline-block" if has_image else "none", "marginBottom": "1rem"},
    )
    return html.Div(
        [drop_zone, *preview, reset_button],
        style={"width": "100%", "textAlign": "center", "padding": "1rem 0"},
    )
|
|
@app.callback(
    Output("reset-image-btn", "style"),
    Input("stored-image", "data"),
    prevent_initial_call=False
)
def toggle_reset_btn(data):
    """Show the reset button only while an uploaded image is stored."""
    visible = bool(data and "contents" in data)
    return {"display": "inline-block" if visible else "none", "marginBottom": "1rem"}
|
|
@app.callback(
    Output("predict-btn", "disabled"),
    [Input("stored-image", "data"), Input("predict-state", "data")],
    prevent_initial_call=False
)
def disable_predict_btn(data, state):
    """Disable Predict when no image is stored or a prediction is already running."""
    return state == "running" or not (data and "contents" in data)
|
|
@app.callback(
    Output("progress-bar-box", "children"),
    [Input("predict-state", "data"),
     Input("progress-interval", "n_intervals")]
)
def update_progress_bar(state, n):
    """Render a progress bar for the prediction state.

    While running, progress advances 10% per interval tick, capped at 100%.
    ``"done"`` shows a full green bar; any other state shows an empty bar.
    """
    if state == "running":
        # Guard against a missing (None) tick count: None * 10 would raise TypeError.
        percent = min(100, (n or 0) * 10)
        return dbc.Progress(f"{percent}%", value=percent, striped=True, animated=True, color="info", style={"height": "20px"})
    elif state == "done":
        return dbc.Progress("100%", value=100, striped=False, animated=False, color="success", style={"height": "20px"})
    else:
        return dbc.Progress("0%", value=0, striped=False, animated=False, color="light", style={"height": "20px"})
|
|
@app.callback(
    [Output("predict-progress", "value"),
     Output("predict-progress", "label"),
     Output("predict-progress", "animated"),
     Output("predict-progress", "striped"),
     Output("predict-live-logs", "children")],
    [Input("progress-interval", "n_intervals"),
     Input("predict-state", "data")],
    prevent_initial_call=False
)
def update_progress_and_logs(n, state):
    """Drive the prediction progress bar and the live-log panel.

    Returns ``(value, label, animated, striped, logs)``:
    - running: 5% per tick, starting at 5% and capped at 95%, with the job's
      live ``tail`` as logs.
    - done: 100% with the last 80 lines of ``PREDICT_JOB["result"]``.
    - any other state: an empty, idle bar.
    """
    if state == "running":
        # Guard against a missing (None) tick count: None + 1 would raise TypeError.
        percent = min(95, ((n or 0) + 1) * 5)
        tail = PREDICT_JOB.get("tail", "")
        return percent, f"Running... {percent}%", True, True, tail
    elif state == "done":
        tail = PREDICT_JOB.get("result", "")
        tail_last = "".join(tail.splitlines(True)[-80:]) if tail else ""
        return 100, "Done", False, False, tail_last
    else:
        return 0, "", False, False, ""
|
|
@app.callback(
    Output("download-prediction", "data"),
    Input("download-prediction-btn", "n_clicks"),
    State("predict-image-store", "data"),
    prevent_initial_call=True
)
def download_prediction(n, b64img):
    """Send the stored base64 prediction image as a timestamped PNG download.

    Nothing is sent if the button was not clicked, no image is stored, or the
    stored string cannot be base64-decoded.
    """
    if not (n and b64img):
        return dash.no_update
    try:
        png_bytes = base64.b64decode(b64img)
    except Exception:
        return dash.no_update

    def _write(buffer):
        buffer.write(png_bytes)

    stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    return dcc.send_bytes(_write, f"prediction_{stamp}.png")
|
|
|
|
@app.callback(
    Output("download-prediction-btn", "style"),
    [Input("predict-image-store", "data"),
     Input("sidebar-mode", "data")],
    prevent_initial_call=False
)
def toggle_download_prediction_btn(data, mode):
    """Style the download button for the active tab, hiding it until an image exists."""
    if mode == "tab-train":
        style = {"marginTop": "0.75rem", "alignSelf": "center"}
    else:
        style = {"marginTop": "0.25rem", "marginBottom": "0.5rem"}
    if not data:
        style["display"] = "none"
    return style
|
|
@app.callback(
    Output("logs-collapse", "is_open"),
    Input("toggle-logs-btn", "n_clicks"),
    State("logs-collapse", "is_open"),
    prevent_initial_call=True
)
def toggle_logs(n, is_open):
    """Flip the log panel's open state on a click; otherwise keep it unchanged."""
    return (not is_open) if n else is_open
|
|
| |
| @app.callback( |
| [Output("netron-modal", "is_open"), |
| Output("netron-iframe", "src")], |
| [Input("sidebar-netron", "n_clicks"), |
| Input("close-netron-modal", "n_clicks")], |
| [State("netron-modal", "is_open"), |
| State("sidebar-mode", "data")], |
| prevent_initial_call=True |
| ) |
| def toggle_netron_modal(n_open, n_close, is_open, tab): |
| ctx = dash.callback_context |
| triggered = ctx.triggered[0]["prop_id"] if ctx.triggered else "" |
| if triggered.startswith("sidebar-netron") and n_open: |
| def build_netron_url(relative_path: str) -> str: |
| base_raw = "https://raw.githubusercontent.com/STMicroelectronics/stm32ai-modelzoo/main/" |
| full = base_raw + relative_path.lstrip('/') |
| return f"https://netron.app/?url={urllib.parse.quote(full, safe=':/') }" |
|
|
| if tab == "tab-predict": |
| url = build_netron_url( |
| "image_classification/mobilenetv2/ST_pretrainedmodel_public_dataset/tf_flowers/mobilenetv2_a035_224_fft/mobilenetv2_a035_224_fft_int8.tflite" |
| ) |
| elif tab == "tab-train": |
| url = build_netron_url( |
| "object_detection/st_yolodv2tiny_pt/ST_pretrainedmodel_public_dataset/coco_person/st_yoloxn_d033_w025_320_int8.tflite" |
| ) |
| else: |
| url = "" |
| return True, url |
| elif triggered.startswith("close-netron-modal") and n_close: |
| return False, dash.no_update |
| return is_open, dash.no_update |
|
|
| @app.callback( |
| Output("yaml-modal", "is_open"), |
| Output("yaml-modal-body", "children"), |
| [Input("sidebar-yaml", "n_clicks"), Input("close-yaml-modal", "n_clicks")], |
| [State("yaml-modal", "is_open"), State("sidebar-mode", "data")], |
| prevent_initial_call=True |
| ) |
| def toggle_yaml_modal(n_yaml, n_close, is_open, mode): |
| |
| ctx = dash.callback_context |
| triggered = ctx.triggered[0]["prop_id"] if ctx.triggered else "" |
| if triggered.startswith("sidebar-yaml") and n_yaml: |
| |
| if mode == "tab-train": |
| yaml_path = "object_detection/user_config.yaml" |
| elif mode == "tab-predict": |
| yaml_path = "image_classification/user_config.yaml" |
| else: |
| return is_open, dash.no_update |
| try: |
| with open(yaml_path, "r", encoding="utf-8") as f: |
| content = f.read() |
| return True, html.Pre(content, style={"whiteSpace": "pre-wrap", "fontFamily": "monospace", "fontSize": "1rem"}) |
| except Exception as e: |
| return True, html.Div(f"Error reading YAML: {e}", style={"color": "#A6ADB5"}) |
| elif triggered.startswith("close-yaml-modal") and n_close: |
| return False, dash.no_update |
| return is_open, dash.no_update |
|
|
|
|
@app.callback(
    Output("dataset-modal", "is_open"),
    Output("dataset-modal-body", "children"),
    [Input("sidebar-visualize-dataset", "n_clicks"), Input("close-dataset-modal", "n_clicks")],
    [State("dataset-modal", "is_open")],
    prevent_initial_call=True
)
def toggle_dataset_modal(n_open, n_close, is_open):
    """Open the dataset modal with one random sample per class, or close it.

    Opening draws a single-row figure with one column per class folder of
    ``image_classification/datasets/flower_photos``. The figure is embedded as a
    base64 PNG. A class without a readable .jpg/.jpeg/.png file shows a black
    224x224 placeholder.
    """
    ctx = dash.callback_context
    triggered = ctx.triggered[0]["prop_id"] if ctx.triggered else ""
    if triggered.startswith("sidebar-visualize-dataset") and n_open:
        path = 'image_classification/datasets/flower_photos'
        if not os.path.exists(path):
            return True, html.Div("The dataset folder does not exist. Please check the Dockerfile.")
        class_names = sorted(name for name in os.listdir(path) if os.path.isdir(os.path.join(path, name)))
        num_classes = len(class_names)
        if num_classes == 0:
            return True, html.Div("No class found in the dataset.")

        image_exts = ('.jpg', '.jpeg', '.png')
        # squeeze=False always yields a 2-D Axes array. With the default,
        # subplots(1, 1) returns a bare Axes and indexing it fails.
        fig, axs = plt.subplots(1, num_classes, figsize=(4 * num_classes, 4), squeeze=False)
        try:
            for ax, class_name in zip(axs[0], class_names):
                class_path = os.path.join(path, class_name)
                image_files = [
                    f for f in os.listdir(class_path)
                    if os.path.isfile(os.path.join(class_path, f)) and f.lower().endswith(image_exts)
                ]
                img = np.zeros((224, 224, 3))
                if image_files:
                    img_path = os.path.join(class_path, random.choice(image_files))
                    try:
                        img = mpimg.imread(img_path)
                    except Exception as exc:
                        # One unreadable file should not break the whole preview.
                        logger.warning("Could not read dataset sample %s: %s", img_path, exc)
                ax.imshow(img)
                ax.set_title(class_name)
                ax.axis('off')
            buf = io.BytesIO()
            # Use Figure methods rather than pyplot's implicit "current figure".
            # pyplot state is process-global, so another thread's figure could
            # otherwise be the one laid out and saved.
            fig.tight_layout()
            fig.savefig(buf, format='png')
        finally:
            # Always release the figure, even if reading or plotting failed.
            plt.close(fig)

        img_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
        img_html = html.Img(src=f'data:image/png;base64,{img_base64}', style={"maxWidth": "100%", "height": "auto"})
        return True, img_html
    elif triggered.startswith("close-dataset-modal") and n_close:
        return False, dash.no_update
    return is_open, dash.no_update
|
|
| |
| |
def render_model_perf_ic():
    """Build the card showing MobileNetv2 reference figures on STM32 targets.

    The table lists activation RAM, weight flash and inference time for each
    listed series.
    """
    columns = (
        "Series",
        "Format",
        "Resolution",
        "Activation RAM (KB)",
        "Weights Flash (KB)",
        "Inference Time (ms)",
    )
    # One tuple per board, values in the same order as `columns`.
    measurements = (
        ("STM32H7", "INT8", "224x224", "699", "407", "309"),
        ("STM32N6570-DK", "INT8", "224x224", "913", "436", "5.3"),
    )

    header = html.Thead(html.Tr([html.Th(title) for title in columns]))
    body = html.Tbody([html.Tr([html.Td(value) for value in row]) for row in measurements])

    title_bar = html.Div(
        [
            html.Span("MobileNetv2 Model Performances overview"),
            dbc.Badge("STM32 Reference", color="secondary", pill=True, className="ms-2"),
        ],
        style={"display": "flex", "alignItems": "center", "gap": "8px"},
    )
    description = html.P(
        "MobileNetv2 Performances on STM32N6. Values depend on resolution and quantization. Fill with STM32Cube.AI 3.0.0 results.",
        style={"color": "#525A63", "fontSize": "0.95rem", "marginBottom": "0.75rem"},
    )
    card_style = {
        "boxShadow": "0 2px 8px rgba(50,50,80,0.08)",
        "border": "1px solid #A6ADB5",
        "marginTop": "0.5rem",
        "marginBottom": "0.5rem",
    }
    return dbc.Card(
        [
            dbc.CardHeader(title_bar, style={"fontWeight": "bold", "fontSize": "1.1rem", "background": "#f8f9fa"}),
            dbc.CardBody([
                description,
                dbc.Table([header, body], bordered=True, hover=True, responsive=True, className="table-sm"),
            ]),
        ],
        style=card_style,
    )
|
|
def render_showcase_info_card():
    """Build the banner that points visitors to the STM32AI Model Zoo Services repo."""
    repo_link = html.A(
        "visit the STM32AI Model Zoo Services repository",
        href=ST_MODELZOO_SERVICES_REPO_URL,
        target="_blank",
        className="info-link",
    )
    banner_style = {
        "background": "#f5f8fc",
        "border": "1px solid #e5edf5",
        "color": "#425466",
        "fontSize": "0.95rem",
        "padding": "0.6rem 0.9rem",
        "borderRadius": "10px",
        "marginTop": "0.5rem",
        "marginBottom": "0.5rem",
    }
    intro = "This is a showcase of the STM32AI Model Zoo. To go further (training, deployment on STM32, scripts and tools), "
    return html.Div([html.Span([intro, repo_link, "."])], style=banner_style)
|
|
|
|
# Module-level, unbounded FIFO for training log messages
# (producers and consumers are defined outside this section).
TRAIN_LOG_QUEUE = queue.Queue()
# Shared training status: the worker thread handle and a completion flag.
TRAIN_PROCESS = dict(thread=None, finished=False)
|
|
|
|
def _load_ic_training_params(yaml_path: str, *, default_epochs: int = 20, default_batch_size: int = 32, default_learning_rate: float = 0.001):
    """Read epochs, batch size and learning rate from an image-classification YAML config.

    Best-effort: each value that is missing or cannot be parsed falls back to its
    own default, independently of the others. The learning rate is taken from
    ``training.optimizer.Adam.learning_rate`` when present. Otherwise it comes
    from ``training.learning_rate``.

    Args:
        yaml_path: Path of the ``user_config.yaml`` file.
        default_epochs: Fallback epoch count.
        default_batch_size: Fallback batch size.
        default_learning_rate: Fallback learning rate.

    Returns:
        ``(epochs, batch_size, learning_rate)``.
    """
    log = logging.getLogger(__name__)
    try:
        with open(yaml_path, "r", encoding="utf-8") as f:
            config = yaml.safe_load(f) or {}
    except Exception as exc:
        log.debug("Could not load training params from %s: %s", yaml_path, exc)
        return default_epochs, default_batch_size, default_learning_rate

    training = config.get("training") if isinstance(config, dict) else None
    if not isinstance(training, dict):
        training = {}

    def _coerce(value, cast, default):
        # Convert one field; a bad value only affects that field.
        if value is None:
            return default
        try:
            return cast(value)
        except (TypeError, ValueError):
            log.debug("Ignoring invalid training value %r in %s", value, yaml_path)
            return default

    epochs = _coerce(training.get("epochs"), int, default_epochs)
    batch_size = _coerce(training.get("batch_size"), int, default_batch_size)

    # isinstance checks tolerate sections left empty in the YAML (e.g. "Adam:" -> None).
    optimizer = training.get("optimizer")
    adam = optimizer.get("Adam") if isinstance(optimizer, dict) else None
    lr_nested = adam.get("learning_rate") if isinstance(adam, dict) else None
    if lr_nested is not None:
        learning_rate = _coerce(lr_nested, float, default_learning_rate)
    else:
        learning_rate = _coerce(training.get("learning_rate"), float, default_learning_rate)
    return epochs, batch_size, learning_rate
|
|
|
|
def _save_ic_training_params(yaml_path: str, epochs, batch_size, learning_rate):
    """Write epochs, batch size and Adam learning rate into an image-classification YAML.

    Changes made to the loaded config:
    - ``training.epochs`` and ``training.batch_size`` are set.
    - Any top-level ``training.learning_rate`` is removed; the value is stored
      under ``training.optimizer.Adam.learning_rate`` instead.
    - A legacy ``model.name`` key is renamed to ``model.model_name``.

    The file is rewritten atomically (temp file + ``os.replace``), so concurrent
    readers never see a partially written config.

    Raises:
        OSError: if the file cannot be read or written.
        TypeError / ValueError: if a value cannot be converted to int/float.
    """
    with open(yaml_path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f) or {}

    def _section(parent, key):
        # setdefault alone is not enough. A key present with a null value
        # (e.g. "optimizer:" left empty in the YAML) would come back as None.
        node = parent.get(key)
        if not isinstance(node, dict):
            node = {}
            parent[key] = node
        return node

    training = _section(config, 'training')
    training['epochs'] = int(epochs)
    training['batch_size'] = int(batch_size)

    training.pop('learning_rate', None)
    adam = _section(_section(training, 'optimizer'), 'Adam')
    adam['learning_rate'] = float(learning_rate)

    model = config.get('model')
    if isinstance(model, dict) and 'name' in model and 'model_name' not in model:
        model['model_name'] = model.pop('name')

    tmp_path = f"{yaml_path}.{uuid.uuid4().hex}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            yaml.dump(config, f, allow_unicode=True, sort_keys=False, default_flow_style=False)
        os.replace(tmp_path, yaml_path)
    except BaseException:
        try:
            os.remove(tmp_path)
        except OSError:
            pass  # best-effort cleanup; the original error is re-raised below
        raise
|
|
|
|
| def _env_int(name: str, default: int, *, min_value: int | None = None, max_value: int | None = None) -> int: |
| raw = os.getenv(name) |
| try: |
| val = int(str(raw).strip()) if raw is not None and str(raw).strip() != "" else int(default) |
| except Exception: |
| val = int(default) |
| if min_value is not None: |
| val = max(min_value, val) |
| if max_value is not None: |
| val = min(max_value, val) |
| return val |
|
|
|
|
| |
# Largest accepted uploaded (compressed) dataset archive, in MB.
IC_MAX_ARCHIVE_MB = _env_int("IC_MAX_ARCHIVE_MB", default=250, min_value=1, max_value=50_000)
# Largest total uncompressed size an archive may declare, in MB.
IC_MAX_EXTRACT_MB = _env_int("IC_MAX_EXTRACT_MB", default=1000, min_value=1, max_value=200_000)
# Maximum number of dataset images opened for a PIL integrity check (0 disables it).
IC_MAX_VERIFY_IMAGES = _env_int("IC_MAX_VERIFY_IMAGES", default=250, min_value=0, max_value=200_000)
|
|
|
|
def _safe_extract_zip_to_dir(zf: zipfile.ZipFile, dst_dir: str) -> None:
    """Extract every file entry of ``zf`` under ``dst_dir``, refusing unsafe paths.

    Directory entries (names ending in "/" or "\\") are skipped; parent folders
    are created as needed.

    Raises:
        ValueError: if an entry is absolute or would resolve outside ``dst_dir``.
    """
    os.makedirs(dst_dir, exist_ok=True)
    root = os.path.realpath(dst_dir)
    for info in zf.infolist():
        name = info.filename
        if not name or name.endswith(("/", "\\")):
            continue

        # Convert Windows separators *before* normpath. On POSIX, normpath treats
        # "a\\..\\..\\x" as one component, so converting afterwards would turn it
        # into "a/../../x" after the traversal check had already passed.
        norm = os.path.normpath(name.replace("\\", "/")).replace("\\", "/")
        if norm == ".." or norm.startswith("../"):
            raise ValueError("Invalid zip entry (path traversal)")
        if os.path.isabs(norm) or re.match(r"^[A-Za-z]:", norm):
            raise ValueError("Invalid zip entry (absolute path)")
        if norm == ".":
            continue  # e.g. "a/.." — nothing to write

        out_path = os.path.realpath(os.path.join(root, norm))
        # Defense in depth: also catches symlinks already present under dst_dir.
        if os.path.commonpath([root, out_path]) != root:
            raise ValueError("Invalid zip entry (path traversal)")

        out_path_dir = os.path.dirname(out_path)
        if out_path_dir:
            os.makedirs(out_path_dir, exist_ok=True)

        with zf.open(info) as src, open(out_path, "wb") as dst:
            shutil.copyfileobj(src, dst)
|
|
|
|
def _safe_extract_tar_to_dir(tf: tarfile.TarFile, dst_dir: str) -> None:
    """Extract regular files and directories of ``tf`` under ``dst_dir``, refusing unsafe paths.

    Symlinks, hard links, devices and other special members are skipped. Every
    member name is still validated first, so an unsafe link name is rejected too.

    Raises:
        ValueError: if a member name is absolute or would resolve outside ``dst_dir``.
    """
    os.makedirs(dst_dir, exist_ok=True)
    root = os.path.realpath(dst_dir)
    for member in tf.getmembers():
        name = member.name
        if not name:
            continue

        # Convert Windows separators *before* normpath. On POSIX, normpath treats
        # "a\\..\\..\\x" as one component, so converting afterwards would turn it
        # into "a/../../x" after the traversal check had already passed.
        norm = os.path.normpath(name.replace("\\", "/")).replace("\\", "/")
        if norm == ".." or norm.startswith("../"):
            raise ValueError("Invalid tar entry (path traversal)")
        if os.path.isabs(norm) or re.match(r"^[A-Za-z]:", norm):
            raise ValueError("Invalid tar entry (absolute path)")

        # Only plain files and directories are materialized.
        if not (member.isfile() or member.isdir()):
            continue

        out_path = os.path.realpath(os.path.join(root, norm))
        # Defense in depth: also catches symlinks already present under dst_dir.
        if os.path.commonpath([root, out_path]) != root:
            raise ValueError("Invalid tar entry (path traversal)")

        if member.isdir():
            os.makedirs(out_path, exist_ok=True)
            continue

        out_path_dir = os.path.dirname(out_path)
        if out_path_dir:
            os.makedirs(out_path_dir, exist_ok=True)

        fobj = tf.extractfile(member)
        if fobj is None:
            continue
        with fobj as src, open(out_path, "wb") as dst:
            shutil.copyfileobj(src, dst)
|
|
|
|
def _prepare_ic_uploaded_dataset(contents: str, filename: str, *, user_id: str, dataset_name: str) -> tuple[str, list[str]]:
    """Decode an uploaded archive and unpack it as an image-classification dataset.

    The archive (.zip, .tar, .tar.gz or .tgz) is extracted into
    ``image_classification/datasets/<sanitized dataset_name>``, replacing that
    folder's previous content. A ``__MACOSX`` folder is removed. Some archives
    wrap everything in one top-level folder that holds at least two sub-folders;
    that wrapper level is removed so the class folders sit directly under the
    dataset root. If any step fails, the dataset folder is removed rather than
    left half-populated.

    Args:
        contents: Dash upload payload (``"<header>,<base64 data>"``).
        filename: Original file name; its extension selects the archive format.
        user_id: Accepted but currently unused.
        dataset_name: Requested name. Runs of characters outside ``[A-Za-z0-9_-]``
            become "_", and the result is truncated to 64 characters.

    Returns:
        ``(training_path, class_names)``: the dataset path with forward slashes,
        and the sorted class folder names.

    Raises:
        ValueError: malformed upload, unsupported extension, size limits
            exceeded, unsafe archive entries, or fewer than 2 class folders.
        zipfile.BadZipFile / tarfile.TarError: corrupt archive.
    """
    if not contents:
        raise ValueError("No upload contents")
    if "," not in contents:
        raise ValueError("Unexpected upload encoding")
    _header, b64 = contents.split(",", 1)

    fname_l = (filename or "").lower()
    is_zip = fname_l.endswith(".zip")
    is_tar = fname_l.endswith((".tar", ".tar.gz", ".tgz"))
    if not (is_zip or is_tar):
        raise ValueError("Please upload a .zip, .tar, .tar.gz or .tgz dataset")

    raw = base64.b64decode(b64)
    if len(raw) > IC_MAX_ARCHIVE_MB * 1024 * 1024:
        raise ValueError(f"Archive too large (max {IC_MAX_ARCHIVE_MB}MB)")

    safe_name = re.sub(r"[^a-zA-Z0-9_-]+", "_", (dataset_name or "custom_dataset").strip())[:64] or "custom_dataset"
    # NOTE(review): user_id is not part of the path. Every upload with the same
    # dataset_name shares (and overwrites) one folder, including bundled datasets
    # such as "flower_photos". Confirm this is intended.
    base_dir = os.path.join("image_classification", "datasets", safe_name)
    if os.path.isdir(base_dir):
        shutil.rmtree(base_dir, ignore_errors=True)
    os.makedirs(base_dir, exist_ok=True)

    try:
        # Reject archives whose declared uncompressed size exceeds the limit.
        max_total = IC_MAX_EXTRACT_MB * 1024 * 1024
        if is_zip:
            with zipfile.ZipFile(io.BytesIO(raw)) as zf:
                if sum((i.file_size or 0) for i in zf.infolist()) > max_total:
                    raise ValueError(f"Archive expands to > {IC_MAX_EXTRACT_MB}MB")
                _safe_extract_zip_to_dir(zf, base_dir)
        else:
            with tarfile.open(fileobj=io.BytesIO(raw), mode="r:*") as tf:
                if sum((m.size or 0) for m in tf.getmembers() if m.isfile()) > max_total:
                    raise ValueError(f"Archive expands to > {IC_MAX_EXTRACT_MB}MB")
                _safe_extract_tar_to_dir(tf, base_dir)

        macosx_dir = os.path.join(base_dir, "__MACOSX")
        if os.path.isdir(macosx_dir):
            shutil.rmtree(macosx_dir, ignore_errors=True)

        # Flatten a single wrapper folder that itself contains >= 2 sub-folders.
        entries = [e for e in os.listdir(base_dir) if not e.startswith('.')]
        dirs = [e for e in entries if os.path.isdir(os.path.join(base_dir, e))]
        files = [e for e in entries if os.path.isfile(os.path.join(base_dir, e))]
        if len(dirs) == 1 and not files:
            single_dir = os.path.join(base_dir, dirs[0])
            subdirs = [d for d in os.listdir(single_dir)
                       if os.path.isdir(os.path.join(single_dir, d)) and not d.startswith('.')]
            if len(subdirs) >= 2:
                # Rename the wrapper first so a child with the wrapper's own name
                # (e.g. "flowers/flowers/") can be moved up without a clash.
                staging = os.path.join(base_dir, f".flatten_{uuid.uuid4().hex}")
                os.rename(single_dir, staging)
                for name in os.listdir(staging):
                    shutil.move(os.path.join(staging, name), os.path.join(base_dir, name))
                shutil.rmtree(staging, ignore_errors=True)

        class_names = sorted(d for d in os.listdir(base_dir)
                             if os.path.isdir(os.path.join(base_dir, d)) and not d.startswith('.'))
        if len(class_names) < 2:
            raise ValueError("Dataset must contain at least 2 class folders")
    except Exception:
        # Do not leave a half-extracted or invalid dataset behind.
        shutil.rmtree(base_dir, ignore_errors=True)
        raise

    training_path = base_dir.replace("\\", "/")
    return training_path, class_names
|
|
|
|
| def _analyze_ic_dataset(dataset_root: str) -> dict: |
| root = (dataset_root or "").replace("/", os.sep) |
| if not root or not os.path.isdir(root): |
| raise ValueError("Dataset path not found") |
|
|
| allowed_ext = {".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp"} |
| class_dirs = [d for d in os.listdir(root) |
| if os.path.isdir(os.path.join(root, d)) and not d.startswith('.')] |
| class_dirs = sorted(class_dirs) |
| per_class = [] |
| total = 0 |
| ext_counts = {} |
| sample_paths = [] |
| corrupted = 0 |
| checked = 0 |
| corrupted_examples = [] |
| max_verify = IC_MAX_VERIFY_IMAGES |
| max_examples = 5 |
| for class_name in class_dirs: |
| class_path = os.path.join(root, class_name) |
| count = 0 |
| for dirpath, _, filenames in os.walk(class_path): |
| for fn in filenames: |
| ext = os.path.splitext(fn)[1].lower() |
| if ext in allowed_ext: |
| count += 1 |
| ext_counts[ext] = ext_counts.get(ext, 0) + 1 |
| fpath = os.path.join(dirpath, fn) |
| if len(sample_paths) < 2000: |
| sample_paths.append(fpath) |
| if checked < max_verify: |
| try: |
| with Image.open(fpath) as im: |
| im.verify() |
| except Exception: |
| corrupted += 1 |
| if len(corrupted_examples) < max_examples: |
| corrupted_examples.append(os.path.relpath(fpath, root).replace("\\", "/")) |
| checked += 1 |
| per_class.append((class_name, count)) |
| total += count |
|
|
| counts = [c for _, c in per_class] |
| stats = { |
| "num_classes": len(class_dirs), |
| "total_images": total, |
| "per_class": per_class, |
| "ext_counts": dict(sorted(ext_counts.items(), key=lambda kv: (-kv[1], kv[0]))), |
| "corrupted_checked": checked, |
| "corrupted_count": corrupted, |
| "corrupted_examples": corrupted_examples, |
| "sample_paths": sample_paths, |
| "min_images": min(counts) if counts else 0, |
| "max_images": max(counts) if counts else 0, |
| "avg_images": (float(total) / len(counts)) if counts else 0.0, |
| } |
| return stats |
|
|
|
|
| def _make_thumbnail_b64(image_path: str, *, size: int = 128) -> str | None: |
| try: |
| with Image.open(image_path) as im: |
| im = im.convert("RGB") |
| im = ImageOps.fit(im, (size, size), method=Image.Resampling.LANCZOS) |
| buf = io.BytesIO() |
| im.save(buf, format="JPEG", quality=85) |
| return base64.b64encode(buf.getvalue()).decode("utf-8") |
| except Exception: |
| return None |
|
|
|
|
| @app.callback( |
| Output("page-content", "children"), |
| Input("sidebar-mode", "data"), |
| State("stm32ai-credentials", "data"), |
| State("page-content", "children"), |
| State("acc-fig-store", "data"), |
| State("loss-fig-store", "data"), |
| State("predict-image-store", "data") |
| ) |
| def render_page_content(tab_value, creds, current_content, acc_fig_store, loss_fig_store, predict_image_store): |
| if tab_value == "tab-train": |
| |
| card_height = "650px" |
| pred_child = None |
| if predict_image_store: |
| pred_child = html.Img( |
| src=f"data:image/png;base64,{predict_image_store}", |
| style={"maxWidth": "100%", "borderRadius": "8px"}, |
| alt="Prediction image" |
| ) |
| return [ |
| html.H2("Object Detection - ST Yolo X", className="section-title"), |
| html.P( |
| [ |
| html.A( |
| "ST Yolo X ", |
| href="https://github.com/STMicroelectronics/stm32ai-modelzoo/tree/main/object_detection/st_yoloxn", |
| target="_blank", |
| className="info-link" |
| ), |
| "is a real-time object detection model targeted for real-time processing implemented in Tensorflow. This is an optimized ST version of the well known yolo x, quantized in int8 format using tensorflow lite converter. Upload an image and run a prediction.", |
| ], |
| className="section-subtitle" |
| ), |
| render_showcase_info_card(), |
| render_model_perf_od(), |
| dbc.Row([ |
| dbc.Col([ |
| dbc.Card([ |
| dbc.CardHeader("Image Upload & Preview", style={"fontWeight": "bold", "fontSize": "1.1rem", "background": "#f8f9fa"}), |
| dbc.CardBody([ |
| html.Div(id="upload-preview-box", style={ |
| "flex": 1, "display": "flex", "flexDirection": "column", |
| "justifyContent": "flex-start", "alignItems": "center", |
| "overflow": "hidden", "padding": "0.5rem", "width": "100%", "minHeight": 0 |
| }), |
| dbc.Button("Predict", id="predict-btn", color="info", |
| style={"marginTop": "0.5rem", "marginBottom": "0.5rem", "alignSelf": "center"}), |
| html.Div([ |
| html.Div(id="conf-thresh-value", style={"fontWeight": 600, "marginBottom": "0.25rem", "textAlign": "center"}), |
| dcc.Slider( |
| id="conf-thresh-slider", |
| min=0.0, |
| max=1.0, |
| step=0.01, |
| value=0.5, |
| tooltip={"placement": "bottom", "always_visible": False}, |
| marks={0.0: "0.0", 0.5: "0.5", 1.0: "1.0"} |
| ), |
| html.Div("Adjust confidence threshold; changing it re-runs prediction.", style={"fontSize": "0.85rem", "color": "#6c757d", "marginTop": "0.25rem", "textAlign": "center"}) |
| ], style={"marginTop": "0.5rem", "padding": "0.25rem 0.25rem 0 0.25rem"}) |
| ], style={"display": "flex", "flexDirection": "column", "flex": 1, "padding": "1rem", "minHeight": 0}) |
| ], style={ |
| "boxShadow": "0 2px 8px rgba(50,50,80,0.08)", |
| "border": "1px solid #A6ADB5", |
| "height": "100%", "minHeight": card_height, "maxHeight": card_height, |
| "display": "flex", "flexDirection": "column" |
| }) |
| ], width=6, style={"display": "flex", "flexDirection": "column", "height": card_height, "minHeight": card_height, "maxHeight": card_height}), |
| dbc.Col([ |
| dbc.Card([ |
| dbc.CardHeader("Prediction Output", style={"fontWeight": "bold", "fontSize": "1.1rem", "background": "#f8f9fa"}), |
| dbc.CardBody([ |
| html.Div([ |
| dbc.Progress(id="predict-progress", value=0, max=100, striped=False, animated=False, color="info", style={"height": "20px", "width": "100%"}), |
| dbc.Button("View Logs", id="toggle-logs-btn", color="link", style={"padding": 0, "marginLeft": "0.5rem"}) |
| ], style={"display": "flex", "alignItems": "center", "gap": "8px", "marginBottom": "0.5rem"}), |
| dbc.Collapse( |
| html.Pre(id="predict-live-logs", style={ |
| "margin": 0, |
| "whiteSpace": "pre-wrap", |
| "fontFamily": "monospace", |
| "background": "#f8f9fa", |
| "padding": "0.75rem", |
| "borderRadius": "6px", |
| "maxHeight": "240px", |
| "overflowY": "auto" |
| }), |
| id="logs-collapse", is_open=False |
| ), |
| dcc.Loading( |
| id="predict-loading-train", |
| type="circle", |
| color="#3CB4E6", |
| fullscreen=False, |
| children=html.Div( |
| id="predict-output", |
| children=pred_child, |
| style={ |
| "flex": 1, "whiteSpace": "pre-wrap", "fontFamily": "monospace", |
| "background": "#f8f9fa", "padding": "0.75rem", "borderRadius": "6px", |
| "overflowY": "auto", "overflowX": "hidden", "height": "100%", "margin": 0, "minHeight": 0, |
| "display": "flex", "alignItems": "center", "justifyContent": "center" |
| } |
| ), |
| style={"flex": 1} |
| ), |
| dbc.Button("Retry", id="retry-predict-btn", color="warning", outline=True, style={"display": "none"}), |
| dbc.Button( |
| "Download Prediction", |
| id="download-prediction-btn", |
| color="secondary", |
| style={"marginTop": "0.75rem", "alignSelf": "center"} |
| ), |
| dcc.Download(id="download-prediction") |
| ], style={"display": "flex", "flexDirection": "column", "flex": 1, "padding": "1rem", "minHeight": 0}) |
| ], style={ |
| "boxShadow": "0 2px 8px rgba(50,50,80,0.08)", |
| "border": "1px solid #A6ADB5", |
| "height": "100%", "minHeight": card_height, "maxHeight": card_height, |
| "display": "flex", "flexDirection": "column" |
| }, className="responsive-card") |
| ], width=6, style={"display": "flex", "flexDirection": "column", "height": card_height, "minHeight": card_height, "maxHeight": card_height}) |
| ], justify="center", align="stretch", style={"marginTop": "2rem", "alignItems": "stretch"}) |
| ] |
|
|
| if tab_value == "tab-predict": |
| yaml_path = "image_classification/user_config.yaml" |
| epochs, batch_size, learning_rate = _load_ic_training_params(yaml_path) |
|
|
| acc_initial = acc_fig_store if acc_fig_store else {"data": [], "layout": ACC_LOSS_BASE_LAYOUT} |
| loss_initial = loss_fig_store if loss_fig_store else {"data": [], "layout": ACC_LOSS_BASE_LAYOUT} |
| card_height = "650px" |
| return [ |
| html.H2("Image Classification - MobileNet v2", className="section-title"), |
| html.P([ |
| html.A( |
| "MobileNet v2", |
| href="https://github.com/STMicroelectronics/stm32ai-modelzoo/tree/main/image_classification/mobilenetv2", |
| target="_blank", |
| className="info-link" |
| ), |
| " is very similar to the original MobileNet, except that it uses inverted residual blocks with bottlenecking features. It has a drastically lower parameter count than the original MobileNet.", |
| html.Span("(", style={"marginRight": "2px"}), |
| html.A( |
| "Flowers Dataset used here", |
| href="https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz", |
| target="_blank", |
| style={"textDecoration": "underline"} |
| ), |
| html.Span(")") |
| ], className="section-subtitle"), |
| render_showcase_info_card(), |
| render_model_perf_ic(), |
| html.H5( |
| "Step 1 — Configure and Launch Training", |
| style={ |
| "color": "#03234B", |
| "fontWeight": "700", |
| "margin": "0 0 0.75rem 0", |
| "textAlign": "left" |
| } |
| ), |
| dbc.Row([ |
| dbc.Col([ |
| dbc.Card([ |
| dbc.CardHeader("Training Parameters", style={"fontWeight": "bold", "fontSize": "1.1rem", "background": "#f8f9fa"}), |
| dbc.CardBody([ |
| dbc.Form([ |
| dbc.Label("Number of Epochs"), |
| dbc.Input(id="epochs", type="number", min=1, max=200, value=epochs, style={"marginBottom": "1rem"}), |
| dbc.Label("Batch Size"), |
| dbc.Input(id="batch-size", type="number", min=1, max=256, value=batch_size, style={"marginBottom": "1rem"}), |
| dbc.Label("Learning Rate"), |
| dbc.Input(id="learning-rate", type="number", value=learning_rate, style={"marginBottom": "1rem"}), |
| html.Div( |
| "Parameters are saved automatically.", |
| style={"color": "#525A63", "fontSize": "0.95rem", "marginBottom": "0.5rem"} |
| ), |
| html.Div(id="autosave-yaml-alert"), |
| html.Hr(style={"margin": "0.75rem 0"}), |
| dbc.Label("Upload your own dataset (optional)"), |
| dcc.Upload( |
| id="ic-dataset-upload", |
| accept=".zip,.tar,.tar.gz,.tgz", |
| multiple=False, |
| children=html.Div([ |
| html.Div("Drop a .zip/.tar(.gz) here or click to upload", style={"fontWeight": "bold"}), |
| html.Div("Expected: dataset_root/class_name/images...", style={"fontSize": "0.85rem"}) |
| ]), |
| className="dash-upload", |
| style={ |
| "width": "100%", |
| "padding": "12px", |
| "borderRadius": "8px", |
| "textAlign": "center", |
| "cursor": "pointer", |
| "marginBottom": "0.5rem" |
| }, |
| ), |
| dbc.Button( |
| "Dataset format help", |
| id="ic-dataset-info-btn", |
| color="secondary", |
| size="sm", |
| style={"marginBottom": "0.5rem"} |
| ), |
| dcc.Loading( |
| id="ic-dataset-upload-loading", |
| type="circle", |
| color="#3CB4E6", |
| fullscreen=False, |
| children=html.Div(id="ic-dataset-upload-alert") |
| ), |
| dbc.Modal( |
| [ |
| dbc.ModalHeader(dbc.ModalTitle("How to format your dataset")), |
| dbc.ModalBody( |
| [ |
| html.Div("Upload a .zip or .tar(.gz/.tgz) that contains class folders (like Flowers):"), |
| html.Pre( |
| """my_dataset/\n class1/\n img1.jpg\n img2.jpg\n class2/\n imgA.jpg\n imgB.jpg\n""", |
| style={"background": "#f8f9fa", "padding": "0.75rem", "borderRadius": "6px"} |
| ), |
| html.Div("Notes:"), |
| html.Ul([ |
| html.Li("Each class is a folder name (used as label)."), |
| html.Li("Images can be .jpg/.jpeg/.png."), |
| html.Li("The zip can either include a top folder (my_dataset/...) or directly the class folders.") |
| ]), |
| html.Div("Current limits (configurable via environment variables):", style={"marginTop": "0.5rem"}), |
| html.Ul([ |
| html.Li(f"Max archive size: {IC_MAX_ARCHIVE_MB}MB (env: IC_MAX_ARCHIVE_MB)"), |
| html.Li(f"Max extracted size (estimated): {IC_MAX_EXTRACT_MB}MB (env: IC_MAX_EXTRACT_MB)"), |
| html.Li(f"Corruption check sample size: {IC_MAX_VERIFY_IMAGES} images (env: IC_MAX_VERIFY_IMAGES)"), |
| ]), |
| html.Div("After upload, the app extracts into image_classification/datasets/<dataset_name> and writes dataset.training_path in image_classification/user_config.yaml.") |
| ] |
| ), |
| dbc.ModalFooter( |
| dbc.Button("Close", id="ic-dataset-info-close", className="ms-auto", n_clicks=0) |
| ), |
| ], |
| id="ic-dataset-info-modal", |
| is_open=False, |
| size="lg", |
| ), |
| ]) |
| ], style={"overflowY": "auto", "padding": "1rem"}) |
| ], style={ |
| "boxShadow": "0 2px 8px rgba(50,50,80,0.08)", |
| "border": "1px solid #A6ADB5", |
| "height": card_height, "minHeight": card_height, |
| "display": "flex", "flexDirection": "column" |
| }) |
| ], width=6), |
| dbc.Col([ |
| dbc.Card([ |
| dbc.CardHeader("Training & Logs", style={"fontWeight": "bold", "fontSize": "1.1rem", "background": "#f8f9fa"}), |
| dbc.CardBody([ |
| html.Div([ |
| dbc.Button("Start", id="train-btn", color="warning", style={ |
| "backgroundColor": "#FFD200", "borderColor": "#FFD200", "color": "#03234B", |
| "borderRadius": "8px", "height": "44px", "minWidth": "90px", |
| "fontWeight": "bold", "fontSize": "1.05rem", "padding": "0 16px", |
| "display": "inline-flex", "alignItems": "center", "justifyContent": "center", |
| "gap": "6px", "whiteSpace": "nowrap", "marginRight": "20px" |
| }), |
| dbc.Button(html.Span("\u25A0", style={"fontWeight": "bold", "animation": "pulse 1s infinite alternate"}), |
| id="stop-train-btn", color="danger", className="btn-stop-round", |
| style={"marginBottom": "1rem", "boxShadow": "0 0 8px #dc3545", "display": "none"}, |
| title="Stop Training") |
| ], style={"display": "flex", "alignItems": "center"}), |
| html.Div(id="train-launch-alert"), |
| dcc.Interval(id="train-log-interval", interval=2000, n_intervals=0, disabled=True), |
| dbc.Progress(id="train-progress", value=0, max=100, striped=True, animated=True, |
| style={"height": "30px", "marginTop": "1rem", "marginBottom": "1rem"}), |
| html.Div(id="train-output", style={ |
| "marginTop": "1rem", "whiteSpace": "pre-wrap", "fontFamily": "monospace", |
| "background": "#f8f9fa", "padding": "1rem", "borderRadius": "8px", |
| "minHeight": "200px", "maxHeight": "350px", "overflowY": "auto" |
| }), |
| html.Div([ |
| html.Div([ |
| dbc.Label("Tail lines"), |
| dbc.Input(id="log-tail-lines", type="number", min=50, max=10000, step=50, value=500, |
| style={"width": "110px", "marginRight": "0.75rem"}), |
| dbc.Checklist( |
| options=[{"label": "Visualize all", "value": "all"}], |
| value=[], id="log-show-all", switch=True, style={"marginRight": "1rem"} |
| ), |
| dbc.Button("Download Logs", id="download-train-logs-btn", |
| color="secondary", size="sm", style={"marginRight": "0.5rem"}), |
| dbc.Button("Copy Logs", id="copy-train-logs-btn", |
| color="secondary", size="sm", style={"marginRight": "0.5rem"}), |
| dcc.Download(id="download-train-logs"), |
| html.Span(id="copy-logs-feedback", style={ |
| "fontSize": "0.8rem", "color": "#03234B", "marginLeft": "0.5rem" |
| }) |
| ], style={"display": "flex", "alignItems": "center", "flexWrap": "wrap", "gap": "0.5rem"}) |
| ], style={"marginTop": "0.75rem"}), |
| dcc.Store(id="train-finished", data=False), |
| ]) |
| ], style={ |
| "boxShadow": "0 2px 8px rgba(50,50,80,0.08)", |
| "border": "1px solid #A6ADB5", |
| "height": card_height, "minHeight": card_height, |
| "display": "flex", "flexDirection": "column" |
| }) |
| ], width=6) |
| ], justify="center", align="start", style={"marginTop": "2rem", "alignItems": "stretch"}), |
| html.Div( |
| id="train-metrics-box", |
| children=[ |
| html.H5( |
| "Step 2 — Monitor Training Metrics", |
| style={ |
| "color": "#03234B", |
| "fontWeight": "700", |
| "margin": "0 0 0.75rem 0", |
| "textAlign": "left" |
| } |
| ), |
| html.Div([ |
| html.Div([ |
| html.H4("Accuracy", style={"textAlign": "center", "color": "#03234B", "marginBottom": "0.5rem"}), |
| dcc.Graph( |
| id="acc-visualization", |
| figure=acc_initial, |
| style={"height": "320px", "background": "#fff", "borderRadius": "10px", |
| "boxShadow": "0 2px 8px rgba(50,50,80,0.08)", "padding": "1rem"} |
| ) |
| ], style={"flex": 1, "marginRight": "1rem"}), |
| html.Div([ |
| html.H4("Loss", style={"textAlign": "center", "color": "#03234B", "marginBottom": "0.5rem"}), |
| dcc.Graph( |
| id="loss-visualization", |
| figure=loss_initial, |
| style={"height": "320px", "background": "#fff", "borderRadius": "10px", |
| "boxShadow": "0 2px 8px rgba(50,50,80,0.08)", "padding": "1rem"} |
| ) |
| ], style={"flex": 1}) |
| ], style={"display": "flex", "flexDirection": "row", "gap": "1rem", "marginBottom": "2rem"}), |
| html.Div([ |
| html.Div([ |
| html.Div("Float Model Accuracy", style={ |
| "textAlign": "center", "color": "#03234B", "fontWeight": "bold", |
| "fontSize": "1.05rem", "marginBottom": "0.4rem" |
| }), |
| dcc.Graph(id="gauge-float-acc", figure={"data": [], "layout": {"height": 190}}, style={"height": "190px"}) |
| ], style={"flex": 1, "marginRight": "1rem"}), |
| html.Div([ |
| html.Div("Quantized Model Accuracy", style={ |
| "textAlign": "center", "color": "#03234B", "fontWeight": "bold", |
| "fontSize": "1.05rem", "marginBottom": "0.4rem" |
| }), |
| dcc.Graph(id="gauge-quant-acc", figure={"data": [], "layout": {"height": 190}}, style={"height": "190px"}) |
| ], style={"flex": 1}) |
| ], style={"display": "flex", "flexDirection": "row", "gap": "1rem", |
| "marginBottom": "2rem", "justifyContent": "center"}), |
| html.Div([ |
| dbc.Button("Display Confusion Matrix (Float Model)", id="display-confusion-matrix-btn", |
| color="warning", style={ |
| "marginBottom": "1rem", "fontWeight": "bold", |
| "backgroundColor": "#FFD200", "borderColor": "#FFD200", |
| "color": "#03234B", "marginRight": "1rem" |
| }), |
| dbc.Button("Display Confusion Matrix (Quantized Model)", id="display-confusion-matrix-quant-btn", |
| color="warning", style={ |
| "marginBottom": "1rem", "fontWeight": "bold", |
| "backgroundColor": "#FFD200", "borderColor": "#FFD200", |
| "color": "#03234B" |
| }) |
| ], style={"textAlign": "center", "marginBottom": "1.5rem", |
| "display": "flex", "justifyContent": "center"}), |
| html.Div(id="metrics-error-message", style={ |
| "color": "#dc3545", "marginTop": "1rem", "fontWeight": "bold", "textAlign": "center" |
| }) |
| ], |
| style={"background": "#f8f9fa", "borderRadius": "12px", "padding": "2rem 1.5rem", |
| "boxShadow": "0 2px 8px rgba(50,50,80,0.08)", "marginBottom": "2rem", "width": "100%"} |
| ) |
| ] |
| return None |
|
|
|
|
@app.callback(
    Output("autosave-yaml-alert", "children"),
    [Input("epochs", "value"),
     Input("batch-size", "value"),
     Input("learning-rate", "value")],
    prevent_initial_call=True
)
def autosave_ic_yaml_params(epochs, batch_size, learning_rate):
    """Persist the training hyper-parameters to the IC YAML on every edit.

    Returns ``dash.no_update`` while any of the three fields is empty (the user
    is still typing). Otherwise it returns either a short-lived success alert
    stamped with the save time, or a danger alert carrying the exception text.
    """
    params = (epochs, batch_size, learning_rate)
    if any(value is None for value in params):
        # An emptied field must not overwrite the saved configuration.
        return dash.no_update

    config_path = "image_classification/user_config.yaml"
    try:
        _save_ic_training_params(config_path, *params)
        stamp = datetime.now().strftime("%H:%M:%S")
        return dbc.Alert(f"Autosaved at {stamp}", color="success", dismissable=True, duration=1500)
    except Exception as err:
        return dbc.Alert(f"Erreur autosave : {err}", color="danger", dismissable=True)
|
|
|
|
@app.callback(
    Output("ic-dataset-info-modal", "is_open"),
    [Input("ic-dataset-info-btn", "n_clicks"),
     Input("ic-dataset-info-close", "n_clicks")],
    State("ic-dataset-info-modal", "is_open"),
    prevent_initial_call=True
)
def toggle_ic_dataset_info_modal(n_open, n_close, is_open):
    """Open the dataset-format help modal from its button, close it from "Close".

    Any other trigger leaves the modal in its current state.
    """
    fired = dash.callback_context.triggered
    prop_id = fired[0]["prop_id"] if fired else ""
    # Checked in order: the open button wins over the close button.
    for prefix, new_state in (("ic-dataset-info-btn", True),
                              ("ic-dataset-info-close", False)):
        if prop_id.startswith(prefix):
            return new_state
    return is_open
|
|
|
|
# Bar colours for the per-class distribution chart (cycled when there are more
# classes than colours).
_IC_STATS_PALETTE = [
    "#3CB4E6", "#FFD200", "#E6007E", "#04572F", "#8191a5",
    "#03234b", "#fff4bf", "#ceecf9",
]

# Archive suffixes accepted by the upload widget. ".tar.gz" is listed explicitly
# because os.path.splitext would only strip its last part (leaving "name.tar").
_IC_ARCHIVE_SUFFIXES = (".tar.gz", ".tgz", ".tar", ".zip")


def _ic_safe_dataset_name(filename):
    """Derive a filesystem-safe dataset folder name from an uploaded archive name.

    The archive suffix is removed (including the two-part ``.tar.gz``).
    Characters outside ``[A-Za-z0-9_-]`` are collapsed to ``_`` and the result
    is capped at 64 characters. ``"custom_dataset"`` is used when nothing is left.
    """
    base = os.path.basename(filename)
    lowered = base.lower()
    for suffix in _IC_ARCHIVE_SUFFIXES:
        if lowered.endswith(suffix):
            base = base[: -len(suffix)]
            break
    else:
        base = os.path.splitext(base)[0]
    return re.sub(r"[^a-zA-Z0-9_-]+", "_", base)[:64] or "custom_dataset"


def _ic_write_dataset_to_yaml(yaml_path, training_path, class_names):
    """Point the ``dataset`` section of the IC YAML at an uploaded dataset.

    Sets ``training_path`` and ``dataset_name``, and also ``class_names`` when
    it is non-empty. A missing or null ``dataset:`` section is replaced by a
    fresh mapping instead of crashing.
    """
    with open(yaml_path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f) or {}
    dataset_cfg = config.get("dataset")
    if not isinstance(dataset_cfg, dict):
        dataset_cfg = config["dataset"] = {}
    dataset_cfg["training_path"] = training_path
    dataset_cfg["dataset_name"] = "custom_dataset"
    if class_names:
        dataset_cfg["class_names"] = class_names
    with open(yaml_path, "w", encoding="utf-8") as f:
        yaml.dump(config, f, allow_unicode=True, sort_keys=False, default_flow_style=False)


def _ic_class_distribution_figure(per_class):
    """Build the Plotly figure dict: one bar per ``(class_name, count)`` pair."""
    names = [name for name, _ in per_class]
    counts = [count for _, count in per_class]
    colors = [_IC_STATS_PALETTE[i % len(_IC_STATS_PALETTE)] for i in range(len(names))]
    return {
        "data": [
            {
                "type": "bar",
                "x": names,
                "y": counts,
                "marker": {"color": colors},
            }
        ],
        "layout": {
            "height": 280,
            "margin": {"l": 40, "r": 10, "t": 10, "b": 120},
            "xaxis": {"tickangle": -45, "automargin": True},
            "yaxis": {"title": "#Images"},
            "paper_bgcolor": "white",
            "plot_bgcolor": "white",
        },
    }


def _ic_extension_summary(ext_counts, limit=6):
    """Format ``{ext: count}`` as ``"ext:count, ..."``, listing at most *limit* entries."""
    summary = ", ".join(f"{ext}:{count}" for ext, count in list(ext_counts.items())[:limit])
    if len(ext_counts) > limit:
        summary += ", ..."
    return summary


def _ic_preview_thumbnails(sample_paths, pool=12, keep=8, size=120):
    """Return up to *keep* thumbnail Divs drawn from a random subset of *sample_paths*.

    *pool* candidates are sampled so that images which fail to thumbnail
    (``_make_thumbnail_b64`` returns a falsy value) can be skipped.
    """
    previews = []
    if not sample_paths:
        return previews
    candidates = list(sample_paths)
    for path in random.sample(candidates, k=min(pool, len(candidates))):
        b64 = _make_thumbnail_b64(path, size=size)
        if not b64:
            continue
        previews.append(
            html.Div(
                html.Img(
                    src=f"data:image/jpeg;base64,{b64}",
                    style={"width": f"{size}px", "height": f"{size}px",
                           "borderRadius": "8px", "objectFit": "cover"},
                ),
                style={"display": "inline-block", "margin": "6px"},
            )
        )
        if len(previews) >= keep:
            break
    return previews


@app.callback(
    Output("ic-dataset-upload-alert", "children"),
    Input("ic-dataset-upload", "contents"),
    [State("ic-dataset-upload", "filename"),
     State("user-id", "data")],
    prevent_initial_call=True
)
def handle_ic_dataset_upload(contents, filename, user_id):
    """Install an uploaded image-classification archive and report its statistics.

    The archive is prepared via ``_prepare_ic_uploaded_dataset``. The resulting
    ``training_path`` (and the class names, when returned) are then written
    into ``image_classification/user_config.yaml``.

    On success, returns a success alert plus a report built from
    ``_analyze_ic_dataset``:
    - a summary line and file-type list;
    - a corruption-check line and an optional class-imbalance warning;
    - a per-class bar chart and class table;
    - random thumbnails;
    - the list of corrupted files, if any.

    On any failure, logs the traceback and returns a danger alert.
    """
    if not contents:
        return dash.no_update

    yaml_path = "image_classification/user_config.yaml"
    archive_name = filename or "dataset.zip"
    try:
        training_path, class_names = _prepare_ic_uploaded_dataset(
            contents,
            archive_name,
            user_id=str(user_id or "anonymous"),
            dataset_name=_ic_safe_dataset_name(archive_name),
        )
        _ic_write_dataset_to_yaml(yaml_path, training_path, class_names)

        stats = _analyze_ic_dataset(training_path)
        per_class = stats.get("per_class") or []
        table_rows = [html.Tr([html.Td(name), html.Td(str(cnt))]) for name, cnt in per_class]
        bar_fig = _ic_class_distribution_figure(per_class)
        ext_summary = _ic_extension_summary(stats.get("ext_counts", {}))

        # Flag a max/min class-size ratio of 5x or more (classes must be non-empty).
        imbalance_hint = None
        if stats["min_images"] > 0 and stats["max_images"] / max(1, stats["min_images"]) >= 5:
            imbalance_hint = f"Class imbalance detected (min={stats['min_images']}, max={stats['max_images']})."

        previews = _ic_preview_thumbnails(stats.get("sample_paths", []))
        corrupted_examples = stats.get("corrupted_examples")

        return html.Div([
            dbc.Alert(
                f"Dataset uploaded. training_path set to: {training_path}",
                color="success",
                dismissable=True,
            ),
            html.H4("Dataset Statistics Report", style={"marginTop": "1rem", "marginBottom": "0.5rem", "color": "#425a78"}),
            html.Div(
                f"Classes: {stats['num_classes']} | Total images: {stats['total_images']} | "
                f"Min/Max per class: {stats['min_images']}/{stats['max_images']} | "
                f"Avg per class: {stats['avg_images']:.1f}",
                style={"fontFamily": "monospace", "fontSize": "0.9rem", "marginBottom": "0.5rem"}
            ),
            html.H5("1. Per-Class Distribution", style={"color": "#03234B", "margin": "0.5rem 0"}),
            html.Div(
                f"File types: {ext_summary}",
                style={"fontFamily": "monospace", "fontSize": "0.85rem", "color": "#525A63", "marginBottom": "0.25rem"}
            ),
            html.Div(
                f"Corrupted check: {stats.get('corrupted_count', 0)} corrupted in first {stats.get('corrupted_checked', 0)} images checked",
                style={"fontFamily": "monospace", "fontSize": "0.85rem", "color": "#525A63", "marginBottom": "0.5rem"}
            ),
            (dbc.Alert(imbalance_hint, color="warning", dismissable=True) if imbalance_hint else None),
            dcc.Graph(figure=bar_fig, config={"displayModeBar": False}),
            html.H5("2. Classes Table", style={"color": "#03234B", "margin": "0.75rem 0 0.5rem 0"}),
            dbc.Table(
                [
                    html.Thead(html.Tr([html.Th("Class"), html.Th("#Images")])),
                    html.Tbody(table_rows)
                ],
                bordered=True,
                hover=True,
                size="sm",
                style={"maxHeight": "260px", "overflowY": "auto", "display": "block"}
            ),
            (html.Div([
                html.H5("3. Sample Previews", style={"color": "#03234B", "margin": "0.75rem 0 0.5rem 0"}),
                html.Div(previews, style={"marginTop": "0.25rem"})
            ]) if previews else None),
            (html.Div([
                html.H5("4. Corrupted Examples", style={"color": "#03234B", "margin": "0.75rem 0 0.5rem 0"}),
                html.Details([
                    html.Summary("Show list"),
                    html.Pre("\n".join(corrupted_examples), style={"whiteSpace": "pre-wrap"})
                ])
            ]) if corrupted_examples else None),
        ])
    except Exception as e:
        # Keep the UI responsive, but leave a traceback for whoever operates the Space.
        logger.exception("Image-classification dataset upload failed (filename=%r)", filename)
        return html.Div([
            dbc.Alert(f"Dataset upload failed: {e}", color="danger", dismissable=True)
        ])
|
|
@app.callback(
    Output("stop-train-btn", "style"),
    [Input("train-log-interval", "disabled"),
     Input("train-finished", "data")],
    prevent_initial_call=False
)
def toggle_stop_button(interval_disabled, finished):
    """Show the Stop button only while the log poller runs and training is unfinished."""
    style = {"marginBottom": "1rem", "boxShadow": "0 0 8px #dc3545"}
    training_active = interval_disabled is False and not finished
    if not training_active:
        style["display"] = "none"
    return style
|
|
@app.callback(
    Output("train-btn", "disabled"),
    [Input("train-log-interval", "disabled"),
     Input("train-finished", "data")],
    prevent_initial_call=False
)
def toggle_start_button(interval_disabled, finished):
    """Disable the Start button while a run is being polled and has not finished."""
    polling = interval_disabled is False
    return bool(polling and not finished)
|
|
@app.callback(
    [Output("train-btn", "children"),
     Output("train-btn", "style")],
    [Input("train-log-interval", "disabled"),
     Input("train-finished", "data")],
    prevent_initial_call=False
)
def update_start_label(interval_disabled, finished):
    """Label the Start button "Start", or spinner + "Running" while a run is polled.

    The button style is identical in both states; only its children change.
    """
    style = dict(
        backgroundColor="#FFD200",
        borderColor="#FFD200",
        color="#03234B",
        borderRadius="8px",
        height="44px",
        minWidth="90px",
        fontWeight="bold",
        fontSize="1.05rem",
        padding="0 16px",
        display="inline-flex",
        alignItems="center",
        justifyContent="center",
        gap="6px",
        whiteSpace="nowrap",
        marginRight="20px",
    )
    idle = interval_disabled is not False or bool(finished)
    if idle:
        return "Start", style
    running_label = [
        dbc.Spinner(size="sm", color="dark", spinner_style={"margin": "0"}),
        "Running",
    ]
    return running_label, style
|
|
# Holder for the currently running training subprocess.
# - ``train_and_log``'s background ``run_training`` sets "process" to the
#   ``subprocess.Popen`` handle.
# - It is reset to None when that process exits.
# - The Stop button calls ``terminate()`` on it.
# Being a module-level global, it is shared by every session served by this
# worker process.
TRAIN_PROCESS_HANDLE = {"process": None}
|
|
# Timestamped run folders ("YYYY_MM_DD_HH_MM_SS") under tf/src/experiments_outputs.
_IC_EXPERIMENT_DIR_RE = re.compile(r"^20\d{2}_\d{2}_\d{2}_\d{2}_\d{2}_\d{2}$")
# Progress lines such as "Epoch 3/20".
_IC_EPOCH_RE = re.compile(r"Epoch (\d+)/(\d+)")


def _ic_latest_experiment_dir(outputs_root):
    """Return the name of the newest timestamped run folder in *outputs_root*, or None.

    Names that do not match ``YYYY_MM_DD_HH_MM_SS`` are ignored. So are names
    that match the pattern but are not real dates (e.g. month 13). Neither
    case raises ``ValueError``.
    """
    if not os.path.isdir(outputs_root):
        return None
    candidates = []
    for name in os.listdir(outputs_root):
        if not _IC_EXPERIMENT_DIR_RE.match(name):
            continue
        if not os.path.isdir(os.path.join(outputs_root, name)):
            continue
        try:
            stamp = datetime.strptime(name, "%Y_%m_%d_%H_%M_%S")
        except ValueError:
            continue
        candidates.append((stamp, name))
    return max(candidates)[1] if candidates else None


@app.callback(
    [Output("train-launch-alert", "children"),
     Output("train-log-interval", "disabled"),
     Output("train-finished", "data"),
     Output("train-progress", "value"),
     Output("train-progress", "label"),
     Output("train-log-raw", "data")],
    [Input("train-btn", "n_clicks"),
     Input("train-log-interval", "n_intervals"),
     Input("stop-train-btn", "n_clicks")],
    [State("stm32ai-credentials", "data"),
     State("epochs", "value"),
     State("batch-size", "value"),
     State("learning-rate", "value"),
     State("train-finished", "data"),
     State("user-id", "data"),
     State("is-original-host", "data")],
    prevent_initial_call=False
)
def train_and_log(n_clicks, n_intervals, stop_clicks, creds, epochs, batch_size, learning_rate, finished_state, user_id, is_original):
    """Start, stop and poll the image-classification training subprocess.

    Triggers:

    * ``stop-train-btn``: terminate the running subprocess (if any), mark the
      run finished and flush this session's history to the log file.
    * ``train-btn``: save the hyper-parameters to the IC YAML, then launch
      ``image_classification/stm32ai_main.py`` in a daemon thread. The thread
      pushes each stdout line into ``TRAIN_LOG_QUEUE`` and also writes it to
      the log file.
    * anything else (interval tick, initial call): drain ``TRAIN_LOG_QUEUE``
      into this session's history, mirror it to the log file, and derive the
      progress bar from the last ``Epoch i/N`` line.

    Returns ``(alert, interval_disabled, finished, progress_value,
    progress_label, raw_log_text)``, matching the callback outputs. Training is
    refused on the reference Space (``is_original``).
    """
    if is_original:
        return dbc.Alert("Training disabled on reference Space. Duplicate to enable.",
                         color="danger", dismissable=True, duration=4000), True, True, 0, "Disabled", ""
    ctx = dash.callback_context
    triggered = ctx.triggered[0]["prop_id"] if ctx.triggered else ""

    if not user_id:
        user_id = str(uuid.uuid4())
    user_tmp_dir = os.path.join("tmp_sessions", user_id)
    os.makedirs(user_tmp_dir, exist_ok=True)

    # Mirror logs into the newest run's logs/train.log when a run folder exists,
    # otherwise into a per-session temp file.
    logs_path = os.path.join("tf", "src", "experiments_outputs")
    recent_directory = _ic_latest_experiment_dir(logs_path)
    if recent_directory:
        log_file_path = os.path.join(logs_path, recent_directory, "logs", "train.log")
        os.makedirs(os.path.dirname(log_file_path), exist_ok=True)
    else:
        log_file_path = os.path.join(user_tmp_dir, "train.log")

    # Per-session log history, kept as a function attribute and seeded from disk once.
    if not hasattr(train_and_log, "history_dict"):
        train_and_log.history_dict = {}
    if user_id not in train_and_log.history_dict:
        if os.path.exists(log_file_path):
            with open(log_file_path, "r", encoding="utf-8") as f:
                train_and_log.history_dict[user_id] = f.readlines()
        else:
            train_and_log.history_dict[user_id] = []
    history = train_and_log.history_dict[user_id]

    if triggered.startswith("stop-train-btn") and stop_clicks:
        # Read the handle once: the worker thread may reset it to None concurrently.
        process = TRAIN_PROCESS_HANDLE["process"]
        if process is not None:
            try:
                process.terminate()
            except Exception:
                # Best effort (e.g. the process already exited); keep a trace of it.
                logger.warning("Could not terminate the training process", exc_info=True)
        TRAIN_PROCESS["finished"] = True
        with open(log_file_path, "w", encoding="utf-8") as f:
            f.write("".join(history))
        return dbc.Alert("Training has been stopped !", color="danger", dismissable=True, duration=3000), True, True, 0, "Stopped", "".join(history)

    if triggered.startswith("train-btn") and n_clicks:
        yaml_path = "image_classification/user_config.yaml"
        try:
            _save_ic_training_params(yaml_path, epochs, batch_size, learning_rate)
        except Exception as e:
            return (
                dbc.Alert(f"Error saving parameters: {e}", color="danger", dismissable=True),
                True,
                True,
                0,
                "Error",
                "".join(history),
            )

        def run_training(epochs, batch_size, learning_rate):
            # The script reads its parameters from the YAML saved above; the
            # arguments are kept only for the thread signature.
            TRAIN_PROCESS["finished"] = False
            TRAIN_LOG_QUEUE.queue.clear()
            isolated_env = os.environ.copy()
            isolated_env['STATS_TYPE'] = 'HuggingFace_devcloud'
            isolated_env['PYTHONUNBUFFERED'] = '1'
            with open(log_file_path, "w", encoding="utf-8") as log_f:
                process = subprocess.Popen([
                    sys.executable, "image_classification/stm32ai_main.py"
                ], stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1, env=isolated_env)
                TRAIN_PROCESS_HANDLE["process"] = process
                for line in process.stdout:
                    TRAIN_LOG_QUEUE.put(line)
                    log_f.write(line)
                process.wait()
            TRAIN_PROCESS_HANDLE["process"] = None
            TRAIN_PROCESS["finished"] = True

        # Launch only if no previous training thread is still alive.
        if TRAIN_PROCESS["thread"] is None or not TRAIN_PROCESS["thread"].is_alive():
            t = threading.Thread(target=run_training, args=(epochs, batch_size, learning_rate), daemon=True)
            t.start()
            TRAIN_PROCESS["thread"] = t
        alert = dbc.Alert("Parameters autosaved. Launching training !", color="info", dismissable=True, duration=3000)
        return alert, False, False, 0, "0%", ""

    # Poll: drain without blocking. Another session's poll can empty the queue
    # between an empty() check and a blocking get(), which would hang this worker.
    log_lines = []
    while True:
        try:
            log_lines.append(TRAIN_LOG_QUEUE.get_nowait())
        except queue.Empty:
            break
    if not log_lines and not TRAIN_PROCESS["thread"]:
        # Nothing has ever been launched in this process.
        return None, True, False, 0, "0%", "".join(history)
    history.extend(log_lines)
    with open(log_file_path, "w", encoding="utf-8") as f:
        f.write("".join(history))
    finished = TRAIN_PROCESS["finished"]

    # Progress comes from the most recent "Epoch i/N" line.
    progress = 0
    label = "0%"
    epoch = 0
    total_epochs = 0
    for line in reversed(history):
        m_epoch = _IC_EPOCH_RE.search(line)
        if m_epoch:
            epoch = int(m_epoch.group(1))
            total_epochs = int(m_epoch.group(2))
            break
    if total_epochs > 0:
        progress = int(100 * epoch / total_epochs)
        label = f"{epoch}/{total_epochs} ({progress}% Done, {100-progress}% Remaining)"

    if finished and TRAIN_LOG_QUEUE.empty():
        return None, True, True, 100, "End", "".join(history)
    return None, False, False, progress, label, "".join(history)
|
|
@app.callback(
    Output("train-output", "children"),
    [Input("train-log-raw", "data"), Input("log-tail-lines", "value"), Input("log-show-all", "value")],
    prevent_initial_call=False
)
def display_train_logs(raw, tail_lines, show_all):
    """Render the raw training log, either in full or only its last N lines.

    *tail_lines* falls back to 500 when it is empty, not a number, or not
    positive. *show_all* is the checklist value; it shows every line when it
    contains ``"all"``.
    """
    if raw is None:
        return html.Pre("")
    if not raw:
        return html.Pre("", style={"margin": 0})
    lines = raw.splitlines()
    if show_all and "all" in show_all:
        display_lines = lines
    else:
        try:
            n = int(tail_lines or 500)
        except (TypeError, ValueError, OverflowError):
            n = 500
        if n <= 0:
            # lines[-n:] with a negative n would drop the head instead of keeping the tail.
            n = 500
        display_lines = lines[-n:]
    return html.Pre("\n".join(display_lines), style={"margin": 0})
|
|
|
|
@app.callback(
    Output("download-train-logs", "data"),
    Input("download-train-logs-btn", "n_clicks"),
    State("train-log-raw", "data"),
    prevent_initial_call=True
)
def download_train_logs(n, raw):
    """Send the accumulated training log to the browser as a timestamped .txt file."""
    if not n:
        return dash.no_update
    stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    return {
        "content": "" if raw is None else raw,
        "filename": f"training_logs_{stamp}.txt",
    }
|
|
# Browser-side "Copy Logs" handler: copies the raw training log to the clipboard
# and returns a status message for the copy-logs-feedback span.
# Before copying, the JavaScript:
# - normalises CRLF / lone CR line endings to LF;
# - strips ANSI SGR escape sequences (ESC[...m);
# - truncates the text to 1,000,000 characters.
# When navigator.clipboard.writeText is unavailable, it falls back to a hidden
# <textarea> + document.execCommand('copy').
# A null/undefined log writes an empty string to the clipboard.
# NOTE(review): the Promise returned by navigator.clipboard.writeText is not
# awaited, so "Logs copied" is reported even if the write is later rejected
# (e.g. no clipboard permission or a non-secure context).
# NOTE(review): the "no log" message ('Aucun log') is French while the rest of
# the UI is English.
app.clientside_callback(
    """
function(n_clicks, raw){
if(!n_clicks){ return window.dash_clientside.no_update || '' }
try {
if(raw === null || raw === undefined){
if (navigator?.clipboard?.writeText) { navigator.clipboard.writeText(''); }
return 'Aucun log';
}
const MAX = 1000000;
let txt = String(raw);
txt = txt.replace(/\\r\\n?/g, '\\n').replace(/\\x1B\\[[0-9;]*m/g, '');
let trunc = false;
if (txt.length > MAX){ txt = txt.slice(0, MAX); trunc = true; }
if (navigator?.clipboard?.writeText){
navigator.clipboard.writeText(txt);
} else {
const ta = document.createElement('textarea');
ta.value = txt;
ta.style.position = 'fixed';
ta.style.top = '-2000px';
document.body.appendChild(ta);
ta.select();
try { document.execCommand('copy'); } catch(e){}
document.body.removeChild(ta);
}
return 'Logs copied' + (trunc ? ' (truncated)' : '') + ' [' + txt.length + ' chars]';
} catch(e){
return 'Copy failed';
}
}
""",
    Output("copy-logs-feedback", "children"),
    Input("copy-train-logs-btn", "n_clicks"),
    State("train-log-raw", "data"),  # the same raw text the log viewer renders
    prevent_initial_call=True
)
|
|
import json  # serialise the Python host list as a valid JavaScript array literal

# Browser-side detection of the official (reference) Space. It compares
# window.location.host, lower-cased, against ORIGINAL_SPACE_HOSTS and stores
# the boolean in "is-original-host". Training and the duplicate blocker read
# that value. It returns null until the URL component has an href.
# json.dumps is used instead of the Python list repr: repr() is not a
# JavaScript serializer (e.g. it can emit "\U0001xxxx" escapes that JS rejects).
app.clientside_callback(
    f"""
    function(href){{
        if(!href) return null;
        const host = window.location.host.toLowerCase();
        const originals = {json.dumps(ORIGINAL_SPACE_HOSTS)};
        return Array.isArray(originals) && originals.includes(host);
    }}
    """,
    Output("is-original-host", "data"),
    Input("app-url", "href")
)
|
|
@app.callback(
    Output("duplicate-blocker", "style"),
    Input("is-original-host", "data"),
    prevent_initial_call=False
)
def show_blocker(is_original):
    """Return the full-screen overlay style, visible only on the reference Space host."""
    visibility = "flex" if is_original else "none"
    return dict(
        position="fixed",
        top=0, left=0, right=0, bottom=0,
        background="rgba(3,35,75,0.96)",
        zIndex=99999,
        display=visibility,
        padding="2rem",
    )
|
|
| @app.callback( |
| [Output("acc-visualization", "figure"), |
| Output("loss-visualization", "figure"), |
| Output("acc-fig-store", "data"), |
| Output("loss-fig-store", "data")], |
| [Input("train-log-interval", "n_intervals"), |
| Input("train-finished", "data")], |
| prevent_initial_call=False |
| ) |
| def refresh_metrics(n_intervals, finished): |
| outputs_folder = os.path.join("tf", "src", "experiments_outputs") |
| if not os.path.exists(outputs_folder): |
| return dash.no_update, dash.no_update, dash.no_update, dash.no_update |
| dated_dirs = [d for d in os.listdir(outputs_folder) |
| if os.path.isdir(os.path.join(outputs_folder, d)) and |
| re.match(r"^20\d{2}_\d{2}_\d{2}_\d{2}_\d{2}_\d{2}$", d)] |
| if not dated_dirs: |
| return dash.no_update, dash.no_update, dash.no_update, dash.no_update |
| recent = max(dated_dirs, key=lambda d: datetime.strptime(d, "%Y_%m_%d_%H_%M_%S")) |
| metrics_file = os.path.join(outputs_folder, recent, "logs", "metrics", "train_metrics.csv") |
| if not os.path.exists(metrics_file): |
| return dash.no_update, dash.no_update, dash.no_update, dash.no_update |
| try: |
| df = pd.read_csv(metrics_file) |
| except Exception: |
| return dash.no_update, dash.no_update, dash.no_update, dash.no_update |
| if df.empty or "epoch" not in df.columns: |
| return dash.no_update, dash.no_update, dash.no_update, dash.no_update |
|
|
| def normalize_metric(series): |
| if series.max() <= 1.2: |
| return series * 100.0, True |
| return series, False |
|
|
| |
| traces_acc = []; y_acc_values = []; norm_used = False |
| mk = {"size": 8, "symbol": "circle", "line": {"width": 1, "color": "#fff"}} |
| if "accuracy" in df.columns: |
| acc_vals, n1 = normalize_metric(df["accuracy"]); norm_used |= n1; y_acc_values.append(acc_vals) |
| traces_acc.append({"x": df["epoch"], "y": acc_vals, "type": "scatter", "mode": "markers+lines", |
| "name": "Accuracy", "line": {"color": "#1976D2", "width": 2}, |
| "marker": {**mk, "color": "#1976D2"}, |
| "hovertemplate": "Epoch %{x}<br>Acc %{y:.2f}%<extra></extra>"}) |
| if "val_accuracy" in df.columns: |
| val_acc_vals, n2 = normalize_metric(df["val_accuracy"]); norm_used |= n2; y_acc_values.append(val_acc_vals) |
| traces_acc.append({"x": df["epoch"], "y": val_acc_vals, "type": "scatter", "mode": "markers+lines", |
| "name": "Val Accuracy", "line": {"color": "#FF9800", "width": 2}, |
| "marker": {**mk, "color": "#FF9800", "symbol": "diamond"}, |
| "hovertemplate": "Epoch %{x}<br>Val Acc %{y:.2f}%<extra></extra>"}) |
| try: |
| best_idx = df["val_accuracy"].idxmax() |
| traces_acc.append({"x": [df.loc[best_idx, "epoch"]], "y": [val_acc_vals.loc[best_idx]], |
| "type": "scatter", "mode": "markers", "name": "Best Val Acc", |
| "marker": {"size": 16, "color": "#2E7D32", "symbol": "star", |
| "line": {"width": 2, "color": "#1B5E20"}}, |
| "hovertemplate": "Best Val Acc<br>Epoch %{x}<br>%{y:.2f}%<extra></extra>"}) |
| except Exception: |
| pass |
| xmin = int(df["epoch"].min()); xmax = int(df["epoch"].max()) |
| acc_layout = ACC_LOSS_BASE_LAYOUT | { |
| "xaxis": {**ACC_LOSS_BASE_LAYOUT["xaxis"], "range": [max(0, xmin - 0.15), xmax + 0.35]}, |
| "yaxis": {"title": "Accuracy (%)", "showgrid": True, "gridcolor": "#EEEFF1", "fixedrange": True}, |
| "legend": {"orientation": "h", "y": -0.25}, |
| "hovermode": "x unified", |
| "dragmode": False, |
| "uirevision": "persist-metrics", |
| "annotations": ([{ |
| "text": "Normalized Accuracy (%)", |
| "showarrow": False, "xref": "paper", "yref": "paper", "x": 0, "y": 1.12, |
| "font": {"size": 11, "color": "#525A63"} |
| }] if norm_used else []) |
| } |
| if y_acc_values: |
| all_acc = pd.concat(y_acc_values) |
| ymin, ymax = float(all_acc.min()), float(all_acc.max()) |
| if ymin == ymax: |
| ymin -= 1; ymax += 1 |
| pad = (ymax - ymin) * 0.08 |
| acc_layout["yaxis"]["range"] = [ymin - pad, ymax + pad] |
| acc_fig = {"data": traces_acc, "layout": acc_layout} |
|
|
| |
| traces_loss = []; y_loss_values = [] |
| if "loss" in df.columns: |
| y_loss_values.append(df["loss"]) |
| traces_loss.append({"x": df["epoch"], "y": df["loss"], "type": "scatter", "mode": "markers+lines", |
| "name": "Loss", "line": {"color": "#1976D2", "width": 2}, |
| "marker": {"size": 8, "symbol": "circle", "line": {"width": 1, "color": "#fff"}, "color": "#1976D2"}, |
| "hovertemplate": "Epoch %{x}<br>Loss %{y:.4f}<extra></extra>"}) |
| if "val_loss" in df.columns: |
| y_loss_values.append(df["val_loss"]) |
| traces_loss.append({"x": df["epoch"], "y": df["val_loss"], "type": "scatter", "mode": "markers+lines", |
| "name": "Val Loss", "line": {"color": "#FF9800", "width": 2}, |
| "marker": {"size": 8, "symbol": "diamond", "line": {"width": 1, "color": "#fff"}, "color": "#FF9800"}, |
| "hovertemplate": "Epoch %{x}<br>Val Loss %{y:.4f}<extra></extra>"}) |
| try: |
| best_loss_idx = df["val_loss"].idxmin() |
| traces_loss.append({"x": [df.loc[best_loss_idx, "epoch"]], |
| "y": [df.loc[best_loss_idx, "val_loss"]], |
| "type": "scatter", "mode": "markers", "name": "Best Val Loss", |
| "marker": {"size": 16, "color": "#D32F2F", "symbol": "star", |
| "line": {"width": 2, "color": "#B71C1C"}}, |
| "hovertemplate": "Best Val Loss<br>Epoch %{x}<br>%{y:.4f}<extra></extra>"}) |
| except Exception: |
| pass |
| loss_layout = ACC_LOSS_BASE_LAYOUT | { |
| "xaxis": {**ACC_LOSS_BASE_LAYOUT["xaxis"], "range": [max(0, xmin - 0.15), xmax + 0.35]}, |
| "yaxis": {"title": "Loss", "showgrid": True, "gridcolor": "#EEEFF1", "fixedrange": True}, |
| "legend": {"orientation": "h", "y": -0.25}, |
| "hovermode": "x unified", |
| "dragmode": False, |
| "uirevision": "persist-metrics" |
| } |
| if y_loss_values: |
| all_loss = pd.concat(y_loss_values) |
| ymin, ymax = float(all_loss.min()), float(all_loss.max()) |
| if ymin == ymax: |
| ymin -= 0.05 * (abs(ymin) + 1); ymax += 0.05 * (abs(ymax) + 1) |
| pad = (ymax - ymin) * 0.08 |
| loss_layout["yaxis"]["range"] = [ymin - pad, ymax + pad] |
| loss_fig = {"data": traces_loss, "layout": loss_layout} |
|
|
| return acc_fig, loss_fig, acc_fig, loss_fig |
|
|
|
|
@app.callback(
    [Output("gauge-float-acc", "figure"),
     Output("gauge-quant-acc", "figure")],
    [Input("train-log-interval", "n_intervals"),
     Input("train-finished", "data")],
    prevent_initial_call=False
)
def update_gauges(n, finished):
    """Build the float / quantized accuracy gauge figures from the latest train.log.

    Scans ``tf/src/experiments_outputs`` for run directories named
    ``YYYY_MM_DD_HH_MM_SS``, picks the most recent one, and extracts the last
    reported "accuracy of float model ... = X%" and
    "accuracy of quantized model ... = Y%" values from its ``logs/train.log``.

    Args:
        n: Interval tick count; only used as a refresh trigger.
        finished: Training-finished store value; only used as a refresh trigger.

    Returns:
        A ``(float_gauge, quant_gauge)`` pair of Plotly figure dicts. A gauge
        whose accuracy could not be found shows 0, a neutral gray color and no
        "%" suffix.
    """
    outputs_folder = os.path.join("tf", "src", "experiments_outputs")
    float_acc = None
    quant_acc = None

    def run_timestamp(name):
        # Returns the run's datetime, or None for names that are not valid
        # run timestamps. Skipping them (instead of letting strptime raise)
        # keeps one stray directory from breaking the callback on every tick.
        if not name.startswith("20"):
            return None
        try:
            return datetime.strptime(name, "%Y_%m_%d_%H_%M_%S")
        except ValueError:
            return None

    def latest_log_file():
        # Path of train.log inside the most recent run directory, or None.
        if not os.path.isdir(outputs_folder):
            return None
        try:
            entries = os.listdir(outputs_folder)
        except OSError as exc:
            logger.warning("Cannot list %s: %s", outputs_folder, exc)
            return None
        runs = []
        for name in entries:
            if not os.path.isdir(os.path.join(outputs_folder, name)):
                continue
            ts = run_timestamp(name)
            if ts is not None:
                runs.append((ts, name))
        if not runs:
            return None
        _, recent = max(runs)
        return os.path.join(outputs_folder, recent, "logs", "train.log")

    float_re = re.compile(r"accuracy of float model.*?=\s*([0-9]+\.?[0-9]*)%", re.IGNORECASE)
    quant_re = re.compile(r"accuracy of quantized model.*?=\s*([0-9]+\.?[0-9]*)%", re.IGNORECASE)

    log_file = latest_log_file()
    if log_file and os.path.exists(log_file):
        try:
            # errors="replace": the log may be read while training is still
            # writing it, so a truncated multi-byte character must not abort
            # the whole parse.
            with open(log_file, "r", encoding="utf-8", errors="replace") as f:
                lines = f.readlines()
        except OSError as exc:
            logger.warning("Cannot read %s: %s", log_file, exc)
            lines = []
        # Walk backwards so the most recently reported value of each metric wins.
        for line in reversed(lines):
            if float_acc is None:
                m = float_re.search(line)
                if m:
                    float_acc = float(m.group(1))
            if quant_acc is None:
                m = quant_re.search(line)
                if m:
                    quant_acc = float(m.group(1))
            if float_acc is not None and quant_acc is not None:
                break

    def color(v):
        # Bar/number color by accuracy band; gray when no value is available.
        if v is None:
            return "#c0c8d2"
        if v < 40:
            return "#FFD200"
        if v < 70:
            return "#6dc7ec"
        if v < 85:
            return "#49B170"
        if v < 92:
            return "#E6007E"
        return "#04572F"

    def gauge_figure(val):
        # Plotly "indicator" gauge on a 0-100 scale for a single accuracy value.
        col = color(val)
        return {
            "data": [{
                "type": "indicator",
                "mode": "gauge+number",
                "value": float(val) if val is not None else 0,
                "gauge": {
                    "axis": {"range": [0, 100], "tickwidth": 2, "tickcolor": "#03234B"},
                    "bar": {"color": col},
                    "bgcolor": "#EEEFF1",
                    "borderwidth": 2,
                    "bordercolor": "#A6ADB5"
                },
                "number": {"suffix": "%" if val is not None else "", "font": {"size": 24, "color": col}},
                "domain": {"x": [0, 1], "y": [0, 1]}
            }],
            "layout": {
                "margin": {"l": 15, "r": 15, "t": 10, "b": 10},
                "paper_bgcolor": "#EEEFF1",
                "plot_bgcolor": "#EEEFF1",
                "height": 190
            }
        }

    return gauge_figure(float_acc), gauge_figure(quant_acc)
|
|
@app.callback(
    Output("metrics-error-message", "children"),
    [Input("acc-visualization", "figure"),
     Input("loss-visualization", "figure")],
    prevent_initial_call=False
)
def show_metrics_error(acc_fig, loss_fig):
    """Clear the metrics error message whenever either metrics figure updates.

    The figure arguments are not inspected; they only act as triggers.

    Returns:
        An empty string, so no error text is rendered.
    """
    no_error = ""
    return no_error
|
|
@app.callback(
    Output("user-id", "data"),
    Input("user-id", "data"),
    prevent_initial_call=False
)
def ensure_user_id(user_id):
    """Return the stored session id, generating a random UUID4 string if absent.

    Any falsy value (``None``, ``""``) is treated as missing; an existing id is
    returned unchanged.
    """
    return user_id or str(uuid.uuid4())
|
|
| |
@app.callback(
    [Output("confusion-matrix-modal", "is_open"),
     Output("confusion-matrix-modal-body", "children"),
     Output("confusion-float-store", "data"),
     Output("confusion-quant-store", "data"),
     Output("confusion-last-type", "data")],
    [Input("display-confusion-matrix-btn", "n_clicks"),
     Input("display-confusion-matrix-quant-btn", "n_clicks"),
     Input("close-confusion-matrix-modal", "n_clicks")],
    [State("confusion-matrix-modal", "is_open"),
     State("confusion-float-store", "data"),
     State("confusion-quant-store", "data"),
     State("confusion-last-type", "data")],
    prevent_initial_call=True
)
def toggle_confusion_matrix_modal(n_float, n_quant, n_close, is_open,
                                  store_float, store_quant, last_type):
    """Open/close the confusion-matrix modal and load the requested matrix image.

    The float or quantized confusion matrix PNG is read from the most recent
    ``tf/src/experiments_outputs/YYYY_MM_DD_HH_MM_SS`` run directory. Images
    larger than 5 MiB are downscaled to fit 1400x1400. The base64 payload is
    written to the matching store so later clicks reuse it.

    Returns:
        ``(is_open, modal_body, float_store, quant_store, last_type)``.
        Entries that should not change are ``dash.no_update``.
    """
    ctx = dash.callback_context
    triggered = ctx.triggered[0]["prop_id"] if ctx.triggered else ""
    if not triggered:
        return is_open, dash.no_update, dash.no_update, dash.no_update, last_type
    if triggered.startswith("close-confusion-matrix-modal") and n_close:
        return False, dash.no_update, dash.no_update, dash.no_update, last_type

    if triggered.startswith("display-confusion-matrix-btn") and n_float:
        matrix_type = "float"
    elif triggered.startswith("display-confusion-matrix-quant-btn") and n_quant:
        matrix_type = "quant"
    else:
        return is_open, dash.no_update, dash.no_update, dash.no_update, last_type

    # Serve the cached image when available.
    # NOTE(review): the cache is not invalidated here when a newer run appears;
    # confirm the stores are reset elsewhere when training restarts.
    if matrix_type == "float" and store_float:
        img = html.Img(src=f"data:image/png;base64,{store_float}",
                       style={"maxWidth": "100%", "height": "auto"})
        return True, img, dash.no_update, dash.no_update, "float"
    if matrix_type == "quant" and store_quant:
        img = html.Img(src=f"data:image/png;base64,{store_quant}",
                       style={"maxWidth": "100%", "height": "auto"})
        return True, img, dash.no_update, dash.no_update, "quant"

    outputs_folder = os.path.join("tf", "src", "experiments_outputs")
    if not os.path.isdir(outputs_folder):
        msg = html.Div("Outputs folder not found.", style={"color": "#dc3545", "textAlign": "center"})
        return True, msg, dash.no_update, dash.no_update, last_type

    pattern = re.compile(r'^20\d{2}_\d{2}_\d{2}_\d{2}_\d{2}_\d{2}$')
    try:
        entries = os.listdir(outputs_folder)
    except OSError as e:
        return True, html.Div(f"Scan error: {e}", style={"color": "#dc3545"}), dash.no_update, dash.no_update, last_type

    runs = []
    for d in entries:
        if not pattern.match(d) or not os.path.isdir(os.path.join(outputs_folder, d)):
            continue
        try:
            runs.append((datetime.strptime(d, "%Y_%m_%d_%H_%M_%S"), d))
        except ValueError:
            # Right shape but not a real date/time (e.g. month 13): skip this
            # entry instead of failing the whole scan.
            continue
    if not runs:
        return True, html.Div("No experiment directory found.", style={"color": "#dc3545"}), dash.no_update, dash.no_update, last_type
    _, recent = max(runs)

    if matrix_type == "float":
        file_path = os.path.join(outputs_folder, recent, "float_model_confusion_matrix_validation_set.png")
    else:
        file_path = os.path.join(outputs_folder, recent, "quantized_model_confusion_matrix_validation_set.png")

    if not os.path.exists(file_path):
        return True, html.Div(f"Confusion matrix file not found: {file_path}",
                              style={"color": "#dc3545", "wordBreak": "break-all"}), dash.no_update, dash.no_update, last_type
    try:
        with Image.open(file_path) as im:
            # Downscale very large files so the inline data URI stays manageable.
            if os.path.getsize(file_path) > 5 * 1024 * 1024:
                im.thumbnail((1400, 1400))
            buf = _io.BytesIO()
            im.save(buf, format="PNG")
        b64 = base64.b64encode(buf.getvalue()).decode()
    except Exception as e:
        # UI boundary: report the failure in the modal and keep the traceback in the logs.
        logger.exception("Failed to load confusion matrix image %s", file_path)
        return True, html.Div(f"Error loading image: {e}", style={"color": "#dc3545"}), dash.no_update, dash.no_update, last_type

    img = html.Img(src=f"data:image/png;base64,{b64}", style={"maxWidth": "100%", "height": "auto"})
    if matrix_type == "float":
        return True, img, b64, dash.no_update, "float"
    return True, img, dash.no_update, b64, "quant"
|
|
|
|
@app.callback(
    Output("navbar-main", "children"),
    Input("sidebar-mode", "data")
)
def render_navbar(tab):
    """Build the top navigation bar for the current sidebar mode.

    Args:
        tab: Value of the ``sidebar-mode`` store. ``"tab-predict"`` grays out
            the ``nav-predict`` button, ``"tab-train"`` grays out the
            ``nav-train`` button, and any other value leaves both yellow.

    Returns:
        A ``dbc.Navbar`` holding the ST logo, the app title, the two mode
        buttons and a link to the model zoo services repository.
    """
    if tab == "tab-predict":
        pred_class, train_class = "me-2 navbar-gray-btn", "navbar-yellow-btn"
    elif tab == "tab-train":
        pred_class, train_class = "me-2 navbar-yellow-btn", "navbar-gray-btn"
    else:
        pred_class, train_class = "me-2 navbar-yellow-btn", "navbar-yellow-btn"

    logo = html.Div(
        html.Img(src="/assets/ST_logo.png",
                 style={"height": "38px", "marginRight": "1.2rem"}),
        style={"display": "flex", "alignItems": "center"},
    )
    brand = dbc.NavbarBrand(
        "STM32AI Experimentation Hub",
        style={
            "color": "#fff",
            "fontWeight": "bold",
            "fontSize": "1.3rem",
            "marginRight": "2rem",
        },
    )
    # NOTE(review): the "nav-predict" button carries the training-mode label and
    # "nav-train" the prediction-mode label. The ids are kept as-is because they
    # are presumably referenced by other callbacks; confirm the labels are intended.
    predict_button = dbc.Button(
        "Training Mode (Image classification)",
        id="nav-predict",
        color="primary",
        outline=False,
        className=pred_class,
        n_clicks=0,
    )
    train_button = dbc.Button(
        "Prediction Mode (Object detection)",
        id="nav-train",
        color="primary",
        outline=False,
        className=train_class,
        n_clicks=0,
    )
    github_icon = html.Img(
        src=app.get_asset_url("github.svg"),
        style={"height": "22px", "width": "22px", "display": "block"},
    )
    github_link = html.A(
        github_icon,
        id="navbar-github-link",
        href=ST_MODELZOO_SERVICES_REPO_URL,
        target="_blank",
        title="STM32AI Model Zoo Services Repository",
        style={
            "display": "inline-flex",
            "alignItems": "center",
            "justifyContent": "center",
            "padding": "0.45rem",
            "borderRadius": "999px",
            "textDecoration": "none",
            "background": "transparent",
            "border": "1px solid rgba(255,255,255,0.22)",
        },
    )
    nav = dbc.Nav(
        [predict_button, train_button, github_link],
        className="ms-auto",
        style={"gap": "0.5rem"},
    )
    container = dbc.Container(
        [logo, brand, nav],
        style={"display": "flex", "alignItems": "center"},
    )
    return dbc.Navbar(
        container,
        color="#0E2140",
        dark=True,
        style={"marginBottom": "0.5rem", "padding": "0.3rem 2rem", "height": "56px", "backgroundColor": "#0E2140"},
        className="navbar-custom",
    )
|
|
|
|
if __name__ == "__main__":
    # Debug mode turns on the Werkzeug interactive debugger, the reloader and
    # the Dash dev-tools UI. That is unsafe on a server bound to 0.0.0.0, so
    # it is opt-in: set DASH_DEBUG=1 (or true/yes/on) for local development.
    debug_enabled = os.getenv("DASH_DEBUG", "").strip().lower() in ("1", "true", "yes", "on")
    os.makedirs("tmp_sessions", exist_ok=True)
    app.run(host="0.0.0.0", port=7860, debug=debug_enabled,
            dev_tools_ui=debug_enabled, dev_tools_hot_reload=debug_enabled, threaded=True)