# Source: Hugging Face Space "STM32 AI Experimentation Hub" (owner FBAGSTM),
# revision 747451d. The scraped page header has been converted to comments
# so this module parses as valid Python.
import os
import sys
import dash
from dash import dcc, html, Output, Input, State
from dash import callback_context
import dash_bootstrap_components as dbc
import logging
import base64
import io
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import numpy as np
import random
import pandas as pd
from datetime import datetime
import re
import uuid
import threading
import queue
import yaml
import subprocess
import time
import signal
import urllib.parse
import zipfile
import tarfile
import shutil
from PIL import Image
import io as _io
import multiprocessing
from PIL import Image, ImageOps
import re
import matplotlib.pyplot as plt
import numpy as np
import random
import io, base64, os
import matplotlib.image as mpimg
# Configure root logging once for the whole app; the module logger follows
# the stdlib convention of logging.getLogger(__name__).
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Upstream repository linked from the navbar GitHub icon.
ST_MODELZOO_SERVICES_REPO_URL = "https://github.com/STMicroelectronics/stm32ai-modelzoo-services/tree/main"
# Owner/name of the official reference Space; both overridable via environment.
OFFICIAL_SPACE_OWNER = os.getenv("OFFICIAL_SPACE_OWNER", "STMicroelectronics")
OFFICIAL_SPACE_NAME = os.getenv("OFFICIAL_SPACE_NAME", "stm32-modelzoo-app")
def _host_variants(owner: str, space: str):
owner_l = (owner or "").strip().lower()
space_l = (space or "").strip().lower()
raw = [
f"{owner_l}-{space_l}.hf.space",
f"{owner_l}--{space_l}.hf.space",
f"{owner_l}-{space_l.replace('_','-')}.hf.space",
f"{owner_l}--{space_l.replace('_','-')}.hf.space",
]
# Preserve ordering but drop empties/duplicates
out = []
for h in raw:
if h and h not in out:
out.append(h)
return out
# Only block the official ST space and prompt users to duplicate.
_DEFAULT_ORIGINAL_HOSTS = ",".join(_host_variants(OFFICIAL_SPACE_OWNER, OFFICIAL_SPACE_NAME))
# Comma-separated blocked-host list; ORIGINAL_SPACE_HOSTS env var overrides the default.
ORIGINAL_SPACE_HOSTS = [h.strip().lower() for h in os.getenv("ORIGINAL_SPACE_HOSTS", _DEFAULT_ORIGINAL_HOSTS).split(",") if h.strip()]
# One-click "duplicate this Space" URL used by the blocking overlay button.
DUPLICATE_URL = f"https://huggingface.co/spaces/{OFFICIAL_SPACE_OWNER}/{OFFICIAL_SPACE_NAME}?duplicate=true"
# Dash application instance; static assets are served from ./assets next to
# this file. Callbacks create components dynamically, so layout validation
# ("ID not found in layout") is suppressed.
app = dash.Dash(
    __name__,
    external_stylesheets=[dbc.themes.BOOTSTRAP],
    suppress_callback_exceptions=True,
    assets_folder=os.path.join(os.path.dirname(__file__), "assets")
)
app.title = "STM32AI Model Zoo Dashboard"
# Custom HTML shell: the ST brand palette and all component CSS are embedded
# directly in the page <head>, so no extra stylesheet file is needed.
app.index_string = '''
<!DOCTYPE html>
<html>
<head>
{%metas%}
<title>{%title%}</title>
{%favicon%}
{%css%}
<style>
body { background: #ffffff !important; font-family: Arial, sans-serif; }
.navbar-custom { background-color: #0E2140 !important; box-shadow: 0 2px 8px rgba(50,50,80,0.08); }
.footer-custom { background-color: #0E2140 !important; }
.card { background: #ffffff !important; border: 1px solid #A6ADB5 !important; }
.btn-primary, .navbar-yellow-btn {
background-color: #FFD200 !important;
border-color: #FFD200 !important;
color: #03234B !important;
}
.navbar-gray-btn {
background-color: #A6ADB5 !important;
border-color: #A6ADB5 !important;
color: #03234B !important;
}
.btn-success { background-color: #525A63 !important; border-color: #525A63 !important; color: #ffffff !important; }
h2, h4 { color: #03234B !important; }
hr { border-top: 2px solid #ffffff !important; }
.dash-upload { border: 2px dashed #3CB4E6 !important; background: #ffffff !important; color: #525A63 !important; }
.nav-link-custom { color: #ffffff !important; font-weight: bold; font-size: 1.1rem; border-radius: 20px; padding: 0.5rem 1.2rem; transition: background 0.2s, color 0.2s; }
.nav-link-custom:hover { background: #3CB4E6 !important; color: #ffffff !important; }
.nav-link-active { background: #3CB4E6 !important; color: #ffffff !important; }
.navbar-separator { border-left: 3px solid #3CB4E6; height: 32px; border-radius: 2px; margin: 0 1.2rem; box-shadow: 0 0 4px #3CB4E6; }
.btn-stop-round {
width: 44px;
height: 44px;
border-radius: 50%;
padding: 0;
display: flex;
align-items: center;
justify-content: center;
font-size: 1.5rem;
background-color: #dc3545 !important;
border: none !important;
color: #fff !important;
box-shadow: 0 2px 8px rgba(50,50,80,0.08);
transition: background 0.2s;
}
.btn-stop-round:hover {
background-color: #b52a37 !important;
}
.section-title {
font-size: 1.75rem;
font-weight: 700;
margin: 0 0 0.5rem 0;
color: #03234B;
line-height: 1.2;
}
.section-subtitle {
font-size: 0.85rem;
color: #6c757d;
margin: 0 0 1.25rem 0;
}
.predict-logs {
font-family: monospace;
width: 100%;
height: 100%;
}
.predict-image {
font-family: Arial, sans-serif;
width: 100%;
height: 100%;
display: flex;
align-items: center;
justify-content: center;
overflow: auto;
}
.predict-image img {
max-width: 100%;
height: auto;
display: block;
}
/* Showcase link style */
a.info-link { color: #2167b0; text-decoration: underline; }
a.info-link:hover { color: #114a86; text-decoration: underline; }
/* Placeholder skeleton */
.img-skeleton {
width: 100%;
height: 320px;
background: linear-gradient(90deg, #f0f2f5 25%, #e6e9ef 37%, #f0f2f5 63%);
border-radius: 8px;
}
/* Responsive card heights on small screens */
@media (max-width: 768px) {
.responsive-card { min-height: 520px !important; height: 520px !important; }
}
</style>
</head>
<body>
{%app_entry%}
<footer>
{%config%}
{%scripts%}
{%renderer%}
</footer>
</body>
</html>
'''
# Pre-trained object-detection model shipped with the Space; its absence is
# reported both on stdout and as a banner at the top of the layout.
MODEL_PATH = "object_detection/pretrained_models/st_yoloxn_d033_w025_320_int8.tflite"
model_exists = os.path.isfile(MODEL_PATH)
if model_exists:
    # Fix: message previously read "Founded model".
    print(f"[INFO] Found model : {MODEL_PATH}")
else:
    print(f"[ERROR] Missing model : {MODEL_PATH}")
model_alert = None
if not model_exists:
    # Fix: the alert previously said "onnx model" for a .tflite file and
    # duplicated the path literal; reuse MODEL_PATH so the two stay in sync.
    model_alert = html.Div(
        f"[ERROR] tflite model {MODEL_PATH} not found !",
        style={"background": "#ffcccc", "color": "#a00", "padding": "1rem", "textAlign": "center", "fontWeight": "bold"}
    )
# Top-level layout: optional missing-model banner, then a full-height flex
# column holding the navbar and the sidebar + page-content row.
app.layout = html.Div([
    model_alert,
    html.Div([
        dcc.Location(id="app-url", refresh=False),
        # Client-side state shared between callbacks (session/local stores).
        dcc.Store(id="is-original-host"),
        dcc.Store(id="sidebar-mode", data="tab-train"),
        dcc.Store(id="stored-image", data=None, storage_type="session"),
        dcc.Store(id="input-image-size", data=None, storage_type="session"),
        dcc.Store(id="predict-state", data="idle"),
        dcc.Interval(id="progress-interval", interval=1000, n_intervals=0, disabled=True),
        dcc.Store(id="stm32ai-credentials", data=None),
        dcc.Store(id="user-id", storage_type="local"),
        dcc.Store(id="train-log-raw", data=""),
        dcc.Store(id="navbar-mode-store"),
        dcc.Store(id="acc-fig-store"),
        dcc.Store(id="loss-fig-store"),
        dcc.Store(id="confusion-float-store"),
        dcc.Store(id="confusion-quant-store"),
        dcc.Store(id="confusion-last-type"),
        dcc.Store(id="predict-image-store", data=None),
        dcc.Location(id="url-initial", refresh=False),
        # Full-screen overlay shown on the official Space asking the user to
        # duplicate it. Hidden by default ("display": "none"); presumably a
        # callback toggles it based on is-original-host — confirm elsewhere.
        html.Div(
            id="duplicate-blocker",
            children=html.Div(
                dbc.Card(
                    dbc.CardBody([
                        html.Div([
                            html.Div(
                                html.Img(
                                    src="/assets/ST_logo.png",
                                    style={
                                        "height": "34px",
                                        "width": "auto",
                                        "display": "inline-block",
                                        "objectFit": "contain",
                                        "objectPosition": "center"
                                    }
                                ),
                                style={
                                    "width": "100%",
                                    "textAlign": "center",
                                    "lineHeight": 0
                                }
                            ),
                            html.Div(
                                "STM32AI Model Zoo",
                                style={
                                    "width": "100%",
                                    "color": "#FFFFFF",
                                    "fontWeight": "bold",
                                    "fontSize": "1.02rem",
                                    "marginTop": "0.45rem",
                                    "textAlign": "center"
                                }
                            )
                        ], style={
                            "width": "100%",
                            "boxSizing": "border-box",
                            "background": "#03234B",
                            "borderRadius": "14px",
                            "padding": "0.95rem 1.0rem",
                            "display": "flex",
                            "flexDirection": "column",
                            "alignItems": "center",
                            "justifyContent": "center"
                        }),
                        html.H2(
                            "Duplicate this Space to continue",
                            style={
                                "color": "#03234B",
                                "marginTop": "1.25rem",
                                "marginBottom": "0.5rem",
                                "textAlign": "center",
                                "fontWeight": 800,
                                "fontSize": "1.6rem"
                            }
                        ),
                        html.P(
                            "This is the official ST reference Space. For a stable experience and to enable training and prediction, "
                            "please duplicate it to your own Hugging Face account.",
                            style={
                                "color": "#525A63",
                                "fontSize": "0.98rem",
                                "textAlign": "center",
                                "maxWidth": "720px",
                                "margin": "0 auto 1.0rem auto",
                                "lineHeight": "1.5"
                            }
                        ),
                        # Numbered 1-2-3 instruction list for duplicating.
                        html.Div([
                            html.Div(
                                [
                                    html.Div("1", style={
                                        "width": "26px", "height": "26px", "borderRadius": "50%",
                                        "background": "#03234B", "color": "#fff",
                                        "display": "flex", "alignItems": "center", "justifyContent": "center",
                                        "fontWeight": "bold", "marginRight": "0.6rem"
                                    }),
                                    html.Div("Click Duplicate Space", style={"fontWeight": 600, "color": "#03234B"})
                                ],
                                style={"display": "flex", "alignItems": "center", "marginBottom": "0.5rem"}
                            ),
                            html.Div(
                                [
                                    html.Div("2", style={
                                        "width": "26px", "height": "26px", "borderRadius": "50%",
                                        "background": "#03234B", "color": "#fff",
                                        "display": "flex", "alignItems": "center", "justifyContent": "center",
                                        "fontWeight": "bold", "marginRight": "0.6rem"
                                    }),
                                    html.Div("Wait for the build to finish", style={"fontWeight": 600, "color": "#03234B"})
                                ],
                                style={"display": "flex", "alignItems": "center", "marginBottom": "0.5rem"}
                            ),
                            html.Div(
                                [
                                    html.Div("3", style={
                                        "width": "26px", "height": "26px", "borderRadius": "50%",
                                        "background": "#03234B", "color": "#fff",
                                        "display": "flex", "alignItems": "center", "justifyContent": "center",
                                        "fontWeight": "bold", "marginRight": "0.6rem"
                                    }),
                                    html.Div("Open your duplicated Space URL", style={"fontWeight": 600, "color": "#03234B"})
                                ],
                                style={"display": "flex", "alignItems": "center"}
                            )
                        ], style={
                            "background": "#EEEFF1",
                            "border": "1px solid #A6ADB5",
                            "borderRadius": "12px",
                            "padding": "0.9rem 1.0rem",
                            "marginTop": "0.75rem",
                            "marginBottom": "1.25rem"
                        }),
                        # Call-to-action button linking to the ?duplicate=true URL.
                        html.Div([
                            html.A(
                                "Duplicate Space",
                                href=DUPLICATE_URL,
                                target="_blank",
                                style={
                                    "background": "#FFD200",
                                    "color": "#03234B",
                                    "padding": "0.9rem 1.4rem",
                                    "borderRadius": "10px",
                                    "fontWeight": "bold",
                                    "fontSize": "1.02rem",
                                    "textDecoration": "none",
                                    "display": "inline-block",
                                    "boxShadow": "0 6px 18px rgba(0,0,0,0.18)",
                                    "border": "1px solid rgba(3,35,75,0.18)"
                                }
                            )
                        ], style={"display": "flex", "justifyContent": "center", "flexWrap": "wrap", "gap": "0.5rem"}),
                        html.Div(
                            [
                                html.Div("You will be redirected to a new Space under your account.", style={"fontWeight": 600}),
                                html.Div("Once duplicated, reload your new URL to unlock all features.")
                            ],
                            style={
                                "marginTop": "1.25rem",
                                "color": "#525A63",
                                "fontSize": "0.9rem",
                                "textAlign": "center"
                            }
                        )
                    ]),
                    style={
                        "maxWidth": "780px",
                        "width": "100%",
                        "borderRadius": "16px",
                        "border": "1px solid rgba(255,210,0,0.55)",
                        "boxShadow": "0 14px 45px rgba(0,0,0,0.35)",
                        "overflow": "hidden",
                    },
                    className="card"
                ),
                style={
                    "width": "100%",
                    "height": "100%",
                    "display": "flex",
                    "alignItems": "center",
                    "justifyContent": "center"
                }
            ),
            style={
                "position": "fixed",
                "top": 0,
                "left": 0,
                "right": 0,
                "bottom": 0,
                "background": "linear-gradient(135deg, rgba(3,35,75,0.98), rgba(3,35,75,0.90))",
                "zIndex": 99999,
                "display": "none",
                "padding": "2rem"
            }
        ),
        # Top navigation bar: logo, title, mode-switch buttons, GitHub link.
        html.Div(
            id="navbar-main",
            children=dbc.Navbar(
                dbc.Container([
                    html.Div(
                        html.Img(src="/assets/ST_logo.png",
                                 style={"height": "38px", "marginRight": "1.2rem"}),
                        style={"display": "flex", "alignItems": "center"}
                    ),
                    dbc.NavbarBrand(
                        "STM32AI Model Zoo Experimentation Hub",
                        style={
                            "color": "#fff",
                            "fontWeight": "bold",
                            "fontSize": "1.3rem",
                            "marginRight": "2rem"
                        }
                    ),
                    dbc.Nav([
                        # NOTE(review): button ids look swapped relative to
                        # their labels (nav-predict shows "Training Mode",
                        # nav-train shows "Prediction Mode") — confirm against
                        # the switch_mode callback before renaming.
                        dbc.Button(
                            "Training Mode (Image classification)",
                            id="nav-predict",
                            color="primary",
                            outline=False,
                            className="me-2 navbar-yellow-btn",
                            n_clicks=0
                        ),
                        dbc.Button(
                            "Prediction Mode (Object detection)",
                            id="nav-train",
                            color="primary",
                            outline=False,
                            className="navbar-gray-btn",
                            n_clicks=0
                        ),
                        html.A(
                            html.Img(
                                src=app.get_asset_url("github.svg"),
                                style={"height": "22px", "width": "22px", "display": "block"}
                            ),
                            id="navbar-github-link",
                            href=ST_MODELZOO_SERVICES_REPO_URL,
                            target="_blank",
                            title="STM32AI Model Zoo Services Repository",
                            style={
                                "display": "inline-flex",
                                "alignItems": "center",
                                "justifyContent": "center",
                                "padding": "0.45rem",
                                "borderRadius": "999px",
                                "textDecoration": "none",
                                "background": "transparent",
                                "border": "1px solid rgba(255,255,255,0.22)",
                            },
                        ),
                    ], className="ms-auto", style={"gap": "0.5rem"})
                ], style={"display": "flex", "alignItems": "center"}),
                color="#0E2140",
                dark=True,
                style={"marginBottom": "0.5rem", "padding": "0.3rem 2rem", "height": "56px", "backgroundColor": "#0E2140"},
                className="navbar-custom"
            )
        ),
        # Main row: fixed-width icon sidebar + scrollable page content.
        html.Div([
            html.Div(
                id="sidebar-content",
                children=[
                    html.Div(style={"flex": 1})
                ],
                style={
                    "display": "flex",
                    "flexDirection": "column",
                    "alignItems": "center",
                    "background": "#EEEFF1",
                    "paddingTop": "2rem",
                    "minWidth": "90px",
                    "maxWidth": "90px",
                    "width": "90px",
                    "boxSizing": "border-box"
                }
            ),
            html.Div([
                dcc.Location(id="netron-url"),
                dcc.Store(id="yaml-modal-store"),
                # Modals opened from the sidebar buttons.
                dbc.Modal([
                    dbc.ModalHeader(dbc.ModalTitle("YAML Content")),
                    dbc.ModalBody(id="yaml-modal-body"),
                    dbc.ModalFooter(
                        dbc.Button("Close", id="close-yaml-modal", className="ms-auto", n_clicks=0)
                    ),
                ], id="yaml-modal", is_open=False, size="lg"),
                dbc.Modal([
                    dbc.ModalHeader(dbc.ModalTitle("Visualize Model with Netron")),
                    dbc.ModalBody(
                        html.Iframe(
                            id="netron-iframe",
                            src="",
                            style={"width": "100%", "height": "600px", "border": "none"}
                        )
                    ),
                    dbc.ModalFooter(
                        dbc.Button("Close", id="close-netron-modal", className="ms-auto", n_clicks=0)
                    ),
                ], id="netron-modal", is_open=False, size="xl"),
                dbc.Modal([
                    dbc.ModalHeader(dbc.ModalTitle("Visualize Flowers Dataset")),
                    dbc.ModalBody(id="dataset-modal-body"),
                    dbc.ModalFooter(
                        dbc.Button("Close", id="close-dataset-modal", className="ms-auto", n_clicks=0)
                    ),
                ], id="dataset-modal", is_open=False, size="xl"),
                dbc.Modal([
                    dbc.ModalHeader(dbc.ModalTitle("Confusion Matrix")),
                    dbc.ModalBody(id="confusion-matrix-modal-body"),
                    dbc.ModalFooter(
                        dbc.Button("Close", id="close-confusion-matrix-modal", className="ms-auto", n_clicks=0)
                    ),
                ], id="confusion-matrix-modal", is_open=False, size="xl"),
                # Hidden components so callback wiring succeeds before the
                # real instances are rendered by page callbacks.
                dcc.Upload(id="upload-image", style={"display": "none"}),
                html.Div(id="page-content", style={"marginTop": "2rem"}),
                dbc.Button("", id="show-dataset-btn", style={"display": "none"}),
                # Hidden retry button to satisfy Dash callback wiring
                dbc.Button("", id="retry-predict-btn", style={"display": "none"})
            ], style={"flex": 1, "minWidth": 0, "overflow": "auto", "marginLeft": "40px"})
        ], style={
            "display": "flex",
            "flexDirection": "row",
            "flex": "1 1 auto",
            "minHeight": 0,
            "height": "100%",
            "margin": "0"
        })
    ], style={"display": "flex", "flexDirection": "column", "height": "100vh", "minHeight": "100vh"})
])
@app.callback(
    Output("sidebar-mode", "data"),
    [Input("nav-predict", "n_clicks"), Input("nav-train", "n_clicks")],
    prevent_initial_call=True
)
def switch_mode(n_predict, n_train):
    """Map a navbar button click onto the sidebar-mode store value."""
    ctx = dash.callback_context
    if not ctx.triggered:
        return dash.no_update
    source = ctx.triggered[0]["prop_id"].split(".")[0]
    # Button id -> sidebar tab; any other trigger leaves the store untouched.
    tab_for_button = {"nav-predict": "tab-predict", "nav-train": "tab-train"}
    return tab_for_button.get(source, dash.no_update)
############ Side Bar ############
@app.callback(
    Output("sidebar-content", "children"),
    Input("sidebar-mode", "data")
)
def update_sidebar(tab_value):
    """Build the icon sidebar for the active tab.

    Both tabs expose the YAML viewer and Netron buttons; the predict tab
    additionally offers the dataset visualizer.
    """
    sidebar = []
    if tab_value == "tab-predict":
        sidebar = [
            dbc.Button(html.Img(src="/assets/yaml_logo.png", style={"height": "35px"}), id="sidebar-yaml", color="light", style={"margin": "1rem 0", "background": "#EEEFF1", "border": "none"}),
            dbc.Tooltip("Visualize YAML file for image_classification", target="sidebar-yaml", placement="right"),
            dbc.Button(html.Img(src="/assets/media.png", style={"height": "35px"}), id="sidebar-visualize-dataset", color="light", style={"margin": "1rem 0", "background": "#EEEFF1", "border": "none"}),
            dbc.Tooltip("Visualize Dataset", target="sidebar-visualize-dataset", placement="right"),
            dbc.Button(html.Img(src="/assets/netron.png", style={"height": "65px", "width": "60px", "objectFit": "contain"}), id="sidebar-netron", color="light", style={"margin": "1rem 0", "background": "#EEEFF1", "border": "none"}),
            dbc.Tooltip("Visualize Model with Netron", target="sidebar-netron", placement="right"),
            html.Div(style={"flex": 1}),
        ]
    elif tab_value == "tab-train":
        sidebar = [
            dbc.Button(html.Img(src="/assets/yaml_logo.png", style={"height": "35px"}), id="sidebar-yaml", color="light", style={"margin": "1rem 0", "background": "#EEEFF1", "border": "none"}),
            # Consistency fix: the YAML button had no tooltip in this tab,
            # unlike the predict tab where every button is labelled.
            dbc.Tooltip("Visualize YAML file", target="sidebar-yaml", placement="right"),
            dbc.Button(html.Img(src="/assets/netron.png", style={"height": "65px", "width": "60px", "objectFit": "contain"}), id="sidebar-netron", color="light", style={"margin": "1rem 0", "background": "#EEEFF1", "border": "none"}),
            dbc.Tooltip("Visualize Model with Netron", target="sidebar-netron", placement="right"),
            html.Div(style={"flex": 1}),
        ]
    return sidebar
############ PREDICTION ############
def render_model_perf_od():
    """Return a card holding the ST YoloX quantized model performance table.

    The model/dataset columns are identical across every resolution, so they
    are emitted once per group with a rowSpan; only the per-resolution
    measurements vary row to row.
    """
    # Columns whose values repeat across a group (emitted once with rowSpan).
    grouped_cols = [
        "Model",
        "Hyperparameters (depth_width)",
        "Serie",
        "Dataset",
        "Format",
    ]
    # Columns that change with the input resolution.
    variant_cols = [
        "Resolution",
        "Internal RAM (KB)",
        "Weights Flash (KB)",
        "Inference Time (ms)",
    ]
    base = {
        "Model": "st_yoloxn",
        "Hyperparameters (depth_width)": "d033_w025",
        "Serie": "STM32N6",
        "Dataset": "COCO-Person",
        "Format": "Quantized INT8",
    }
    variants = [
        {"Resolution": "192x192x3", "Internal RAM (KB)": "333", "Weights Flash (KB)": "877", "Inference Time (ms)": "6"},
        {"Resolution": "256x256x3", "Internal RAM (KB)": "624", "Weights Flash (KB)": "885", "Inference Time (ms)": "9"},
        {"Resolution": "320x320x3", "Internal RAM (KB)": "1125", "Weights Flash (KB)": "895", "Inference Time (ms)": "13"},
        {"Resolution": "416x416x3", "Internal RAM (KB)": "2676", "Weights Flash (KB)": "904", "Inference Time (ms)": "21"},
    ]
    rows = [dict(base, **variant) for variant in variants]
    header = html.Thead(html.Tr([html.Th(name) for name in grouped_cols + variant_cols]))
    # Bucket rows sharing the same grouped-column values.
    grouped = {}
    for row in rows:
        grouped.setdefault(tuple(row.get(c, "-") for c in grouped_cols), []).append(row)
    table_rows = []
    for members in grouped.values():
        for idx, row in enumerate(members):
            cells = []
            if idx == 0:
                # First row of the group carries the shared cells, spanning
                # the whole group vertically.
                cells = [html.Td(row.get(c, "-"), rowSpan=len(members)) for c in grouped_cols]
            cells.extend(html.Td(row.get(c, "-")) for c in variant_cols)
            table_rows.append(html.Tr(cells))
    body = html.Tbody(table_rows)
    return dbc.Card([
        dbc.CardHeader(
            html.Div([
                html.Span("ST Yolo X Quantized Model Performances overview"),
                dbc.Badge("STM32 Reference", color="secondary", pill=True, className="ms-2")
            ], style={"display": "flex", "alignItems": "center", "gap": "8px"}),
            style={"fontWeight": "bold", "fontSize": "1.1rem", "background": "#f8f9fa"}
        ),
        dbc.CardBody([
            html.P(
                " ST Yolo X Quantized Performances on STM32N6. Values depend on resolution and quantization. Fill with STM32Cube.AI 3.0.0 results.",
                style={"color": "#525A63", "fontSize": "0.95rem", "marginBottom": "0.75rem"}
            ),
            dbc.Table([header, body], bordered=True, hover=True, responsive=True, className="table-sm"),
        ])
    ], style={
        "boxShadow": "0 2px 8px rgba(50,50,80,0.08)",
        "border": "1px solid #A6ADB5",
        "marginTop": "0.5rem",
        "marginBottom": "0.5rem",
    })
def render_object_detection_page():
    """Build the Object Detection page: upload/preview card (left) and
    prediction-logs card (right), plus the showcase info card on top.

    Returns a list of Dash components to be placed in ``page-content``.
    """
    # Card height tracks the viewport but never exceeds 760px.
    card_height = "min(760px, calc(100vh - 240px))"
    return [
        html.H2("Object Detection Prediction", className="section-title"),
        html.P("Upload an image to run object detection.", className="section-subtitle"),
        render_showcase_info_card(),
        dbc.Row([
            # Left column: upload area, launch/reset buttons, threshold slider.
            dbc.Col([
                dbc.Card([
                    dbc.CardHeader("Image Upload & Preview", style={
                        "fontWeight": "bold", "fontSize": "1.1rem", "background": "#f8f9fa"
                    }),
                    dbc.CardBody([
                        html.Div(
                            id="upload-preview-box",
                            style={
                                "flex": 1,
                                "display": "flex",
                                "flexDirection": "column",
                                "justifyContent": "center",
                                "alignItems": "center",
                                "overflow": "hidden",
                                "padding": "0.25rem"
                            }
                        ),
                        html.Div([
                            dbc.Button(
                                "Launch Prediction",
                                id="predict-btn",
                                color="primary",
                                style={
                                    "marginBottom": "1rem",
                                    "boxShadow": "0 4px 16px rgba(60,180,230,0.18), 0 1.5px 4px rgba(50,50,80,0.10)",
                                    "marginRight": "1rem"
                                }
                            ),
                            # Hidden until an image is stored (managed by callbacks).
                            dbc.Button(
                                "Reset image",
                                id="reset-image-btn",
                                color="secondary",
                                style={"display": "none", "marginBottom": "1rem"}
                            )
                        ], style={
                            "display": "flex",
                            "flexDirection": "row",
                            "justifyContent": "center",
                            "alignItems": "center"
                        }),
                        # Confidence-threshold slider; changing it triggers a
                        # new prediction run (see start_prediction).
                        html.Div([
                            html.Div(id="conf-thresh-value", style={"fontWeight": 600, "marginBottom": "0.25rem", "textAlign": "center"}),
                            dcc.Slider(
                                id="conf-thresh-slider",
                                min=0.0,
                                max=1.0,
                                step=0.01,
                                value=0.5,
                                tooltip={"placement": "bottom", "always_visible": False},
                                marks={0.0: "0.0", 0.5: "0.5", 1.0: "1.0"}
                            ),
                            html.Div("Adjust confidence threshold; changing it re-runs prediction.", style={"fontSize": "0.85rem", "color": "#6c757d", "marginTop": "0.25rem", "textAlign": "center"})
                        ], style={"marginTop": "0.5rem", "padding": "0.25rem 0.25rem 0 0.25rem"})
                    ], style={
                        "overflowY": "auto",
                        "padding": "1rem",
                        "height": "100%",
                        "display": "flex",
                        "flexDirection": "column",
                        "boxSizing": "border-box"
                    }),
                ], style={
                    "boxShadow": "0 2px 8px rgba(50,50,80,0.08)",
                    "border": "1px solid #A6ADB5",
                    "height": "800px",
                    "minHeight": "800px",
                    "maxHeight": "800px",
                    "display": "flex",
                    "flexDirection": "column"
                })
            ], width=6, style={
                "display": "flex",
                "flexDirection": "column",
                "height": card_height,
                "minHeight": card_height,
                "maxHeight": card_height,
                "boxSizing": "border-box"
            }),
            # Right column: live logs / result image, download button, poller.
            dbc.Col([
                dbc.Card([
                    dbc.CardHeader("Prediction Logs", style={
                        "fontWeight": "bold", "fontSize": "1.1rem", "background": "#f8f9fa"
                    }),
                    dbc.CardBody([
                        html.Div(id="progress-bar-box"),
                        dcc.Loading(
                            id="predict-loading",
                            type="circle",
                            color="#3CB4E6",
                            fullscreen=False,
                            children=html.Div(
                                id="predict-output",
                                style={
                                    "flex": 1,
                                    "height": "100%",
                                    "marginBottom": "1rem",
                                    "background": "#f8f9fa",
                                    "borderRadius": "8px",
                                    "padding": "0.5rem",
                                    "boxSizing": "border-box",
                                    "display": "flex",
                                    "flexDirection": "column",
                                    "alignItems": "center",
                                    "justifyContent": "center",
                                    "overflowY": "auto",
                                    "overflowX": "hidden"
                                }
                            ),
                            style={"flex": 1}
                        ),
                        dbc.Button(
                            "Download Prediction",
                            id="download-prediction-btn",
                            color="secondary",
                            style={"display":"none", "marginTop": "0.25rem", "marginBottom": "0.5rem"}
                        ),
                        dcc.Download(id="download-prediction"),
                        dcc.Interval(id="predict-interval", interval=300, n_intervals=0, disabled=True)
                    ], style={
                        "overflow": "hidden",
                        "padding": "1rem",
                        "height": "100%",
                        "display": "flex",
                        "flexDirection": "column"
                    }),
                ], style={
                    "boxShadow": "0 2px 8px rgba(50,50,80,0.08)",
                    "border": "1px solid #A6ADB5",
                    "height": card_height,
                    "minHeight": card_height,
                    "display": "flex",
                    "flexDirection": "column"
                })
            ], width=6, style={
                "display": "flex",
                "flexDirection": "column",
                "height": card_height,
                "minHeight": card_height
            })
        ], justify="center", align="stretch", style={"marginTop": "2rem", "alignItems": "stretch"})
    ]
# Mutable module-level state for the single in-flight prediction job.
# Keys: subprocess handle, full stdout, live log tail, completion flag,
# produced image path, original upload (width, height), and the uploaded
# file's basename used to match the output image.
PREDICT_JOB = {"process": None, "result": None, "tail": "", "finished": False, "image": None, "orig_size": None, "target_basename": None}
# Strategy for resizing the prediction image back to the upload size
# ("stretch" by default; overridable via environment).
PREDICT_RESIZE_MODE = os.getenv("PREDICT_RESIZE_MODE", "stretch")
# Shared Plotly layout applied to the accuracy/loss training curves;
# 'uirevision' keeps zoom/pan state stable across figure updates.
ACC_LOSS_BASE_LAYOUT = {
    'uirevision': 'metrics-static',
    'xaxis': {
        'title': 'Epochs',
        'showgrid': True,
        'gridcolor': '#EEEFF1',
        'linecolor': '#A6ADB5',
        'linewidth': 1,
        'fixedrange': True
    },
    'plot_bgcolor': '#fff',
    'paper_bgcolor': '#fff',
    'margin': {'l': 60, 'r': 40, 't': 10, 'b': 60},
    'hovermode': 'x unified',
    'showlegend': True,
    'legend': {'bgcolor': 'rgba(0,0,0,0)'}
}
def launch_prediction_async(image_path, target_basename=None):
    """Run the OD prediction pipeline in a subprocess and publish results.

    Runs ``stm32ai_main.py`` inside the object_detection folder, streams its
    stdout into ``PREDICT_JOB["tail"]`` for live display, then locates the
    produced prediction image. Meant to run in a daemon thread; all results
    are published through the ``PREDICT_JOB`` module-level dict.

    :param image_path: local copy of the uploaded image (kept for reference;
        the pipeline reads from the dataset.prediction_path configured by
        the caller).
    :param target_basename: basename (no extension) of the uploaded file,
        used to match the produced "<name>_predict.<ext>" image.
    """
    try:
        env = os.environ.copy()
        # Keep numeric libraries single-threaded so the web UI stays responsive.
        env["OMP_NUM_THREADS"] = "1"
        env["MKL_NUM_THREADS"] = "1"
        env["NUMEXPR_NUM_THREADS"] = "1"
        env.setdefault("HYDRA_FULL_ERROR", "1")
        od_cwd = os.path.join(os.path.dirname(os.path.abspath(__file__)), "object_detection")
        process = subprocess.Popen(
            [sys.executable, "stm32ai_main.py", "--config-name", "user_config"],
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            bufsize=1,
            env=env,
            cwd=od_cwd
        )
        # Fix: expose the handle (the "process" key existed but was never
        # populated), so other callbacks could terminate the job.
        PREDICT_JOB["process"] = process
        output_lines = []
        for line in process.stdout:
            output_lines.append(line)
            # Live tail update (last 80 lines)
            PREDICT_JOB["tail"] = "".join(output_lines[-80:])
        process.wait()
        PREDICT_JOB["result"] = "".join(output_lines)
        pred_image = None
        # Locate the latest predictions folder.
        # OD config typically uses hydra.run.dir: ./pt/src/experiments_outputs/${now:...}
        exp_roots = []
        try:
            yaml_cfg_path = os.path.join("object_detection", "user_config.yaml")
            with open(yaml_cfg_path, "r", encoding="utf-8") as f:
                ocfg = yaml.safe_load(f) or {}
            run_dir = (ocfg.get("hydra") or {}).get("run", {})
            run_dir_str = run_dir.get("dir") if isinstance(run_dir, dict) else None
            if isinstance(run_dir_str, str) and run_dir_str:
                run_dir_str = run_dir_str.replace("\\", "/")
                # Strip the timestamped leaf: runs live in dated subfolders.
                if "${now:" in run_dir_str:
                    run_dir_str = os.path.dirname(run_dir_str)
                if run_dir_str.startswith("./"):
                    exp_roots.append(os.path.join("object_detection", run_dir_str[2:]))
                else:
                    exp_roots.append(os.path.join("object_detection", run_dir_str) if not os.path.isabs(run_dir_str) else run_dir_str)
        except Exception:
            # Best effort: fall back to the common layouts below.
            pass
        # Fallbacks for common layouts
        exp_roots.extend([
            os.path.join("object_detection", "tf", "src", "experiments_outputs"),
            os.path.join("object_detection", "src", "experiments_outputs"),
        ])

        def _find_latest_prediction_image(exp_base: str):
            """Return the newest prediction image under exp_base, or None."""
            if not exp_base or not os.path.exists(exp_base):
                return None
            # Run folders are named YYYY_MM_DD_HH_MM_SS; newest sorts last.
            subs = [d for d in os.listdir(exp_base)
                    if os.path.isdir(os.path.join(exp_base, d)) and re.match(r"^20\d{2}_\d{2}_\d{2}_\d{2}_\d{2}_\d{2}$", d)]
            if not subs:
                return None
            subs.sort(reverse=True)
            last_exp = os.path.join(exp_base, subs[0], "predictions")
            if not os.path.exists(last_exp):
                return None
            # Prefer the image matching the uploaded file's basename.
            if target_basename:
                for ext in (".jpg", ".jpeg", ".png"):
                    candidate = os.path.join(last_exp, f"{target_basename}_predict{ext}")
                    if os.path.exists(candidate):
                        return candidate
            imgs = [os.path.join(last_exp, f) for f in os.listdir(last_exp)
                    if f.lower().endswith((".jpg", ".jpeg", ".png"))]
            if imgs:
                return max(imgs, key=os.path.getmtime)
            return None

        for root in exp_roots:
            pred_image = _find_latest_prediction_image(root)
            if pred_image:
                break
        PREDICT_JOB["image"] = pred_image
    finally:
        # Fix: always flip the flag, even if Popen or the yaml parsing raises
        # unexpectedly, so the polling callback does not spin forever.
        PREDICT_JOB["finished"] = True
@app.callback(
    [Output("predict-state", "data"),
     Output("progress-interval", "disabled"),
     Output("progress-interval", "n_intervals"),
     Output("predict-output", "children", allow_duplicate=True),
     Output("predict-image-store", "data")],
    [Input("predict-btn", "n_clicks"), Input("retry-predict-btn", "n_clicks"), Input("conf-thresh-slider", "value")],
    State("stored-image", "data"),
    State("input-image-size", "data"),
    State("is-original-host", "data"),
    prevent_initial_call=True
)
def start_prediction(n, n_retry, conf_thresh, stored, input_size, is_original):
    """Kick off an object-detection run in a background thread.

    Triggered by the Launch/Retry buttons or a confidence-slider change.
    Saves the uploaded image to disk, rewrites object_detection/user_config.yaml
    (operation mode, prediction path, confidence threshold), resets
    PREDICT_JOB and starts launch_prediction_async in a daemon thread.
    Returns (predict-state, interval-disabled, interval-reset, output
    children, cleared image store).
    """
    ctx = dash.callback_context
    triggered = ctx.triggered[0]["prop_id"] if ctx.triggered else ""
    n = n or 0
    n_retry = n_retry or 0
    slider_changed = triggered.startswith("conf-thresh-slider.")
    # Nothing clicked yet and the slider did not move: no-op.
    if (n + n_retry) == 0 and not slider_changed:
        return dash.no_update, True, dash.no_update, dash.no_update, dash.no_update
    # Prediction is disabled on the official reference Space.
    if is_original:
        return "idle", True, 0, html.Div(
            "Prediction disabled on reference Space. Duplicate to enable.",
            style={"color": "#dc3545", "fontWeight": "bold"}
        ), None
    if not stored or "contents" not in stored:
        return "idle", True, 0, html.Div("Upload an image first.", style={"color": "#dc3545"}), None
    # Decode the base64 data URL; a uuid prefix avoids name collisions.
    _, b64data = stored["contents"].split(",", 1)
    raw = base64.b64decode(b64data)
    fname = f"{uuid.uuid4().hex}_{os.path.basename(stored.get('filename','input.jpg'))}"
    # Save a local copy of the upload.
    image_path = os.path.join("object_detection", "datasets", fname)
    os.makedirs(os.path.dirname(image_path), exist_ok=True)
    with open(image_path, "wb") as f:
        f.write(raw)
    # Copy into the folder used by the OD prediction pipeline (dataset.prediction_path)
    yaml_cfg_path = os.path.join("object_detection", "user_config.yaml")
    try:
        with open(yaml_cfg_path, "r", encoding="utf-8") as yf:
            ocfg = yaml.safe_load(yf) or {}
        # Force OD into prediction mode for the UI workflow.
        ocfg["operation_mode"] = "prediction"
        dataset_cfg = ocfg.setdefault("dataset", {})
        pred_dir = dataset_cfg.get("prediction_path")
        # In the repo defaults, prediction_path can be absolute (Spaces) or empty.
        # For UI-driven prediction, force it to a local relative folder.
        if not isinstance(pred_dir, str) or not pred_dir or os.path.isabs(pred_dir):
            pred_dir = "./datasets"
        dataset_cfg["prediction_path"] = pred_dir
        # Optional prediction section (no test_files_path needed in current services)
        pred_cfg = ocfg.setdefault("prediction", {})
        pred_cfg.setdefault("target", "host")
        # Remove any legacy key to avoid Hydra Unknown attribute errors
        if "test_files_path" in pred_cfg:
            try:
                del pred_cfg["test_files_path"]
            except Exception:
                pred_cfg.pop("test_files_path", None)
        # Update postprocessing confidence threshold from UI slider
        pp_cfg = ocfg.setdefault("postprocessing", {})
        try:
            val = float(conf_thresh) if conf_thresh is not None else 0.5
        except Exception:
            val = 0.5
        pp_cfg["confidence_thresh"] = val
        # Normalize to an absolute filesystem path for copying.
        if pred_dir.startswith("./"):
            pred_dir_abs = os.path.join("object_detection", pred_dir[2:])
        else:
            pred_dir_abs = os.path.join("object_detection", pred_dir) if not os.path.isabs(pred_dir) else pred_dir
        os.makedirs(pred_dir_abs, exist_ok=True)
        target_copy = os.path.join(pred_dir_abs, fname)
        with open(target_copy, "wb") as cf:
            cf.write(raw)
        # Persist the updated path so the subprocess reads the correct directory.
        with open(yaml_cfg_path, "w", encoding="utf-8") as yf:
            yaml.dump(ocfg, yf, allow_unicode=True, sort_keys=False, default_flow_style=False)
    except Exception as e:
        # Best effort: the run may still succeed with the existing config.
        print(f"[WARN] Unable to prepare prediction folder: {e}")
    # Record the original size so the result can be scaled back for display.
    orig_size = None
    if input_size and "width" in input_size and "height" in input_size:
        orig_size = (int(input_size["width"]), int(input_size["height"]))
    else:
        try:
            with Image.open(io.BytesIO(raw)) as im:
                orig_size = im.size
        except Exception:
            orig_size = None
    basename_no_ext = os.path.splitext(fname)[0]
    # Reset the shared job state before starting the worker thread.
    PREDICT_JOB.update({
        "process": None,
        "result": None,
        "tail": "",
        "finished": False,
        "image": None,
        "orig_size": orig_size,
        "target_basename": basename_no_ext
    })
    t = threading.Thread(target=launch_prediction_async, args=(image_path, basename_no_ext), daemon=True)
    t.start()
    return "running", False, 0, html.Div("Prediction launched...", style={"fontFamily": "monospace"}), None
def _resize_prediction_image(pred_path, orig_size, mode=None):
if not orig_size:
with open(pred_path, "rb") as f:
return base64.b64encode(f.read()).decode()
target_w, target_h = map(int, orig_size)
try:
with Image.open(pred_path) as im:
im = ImageOps.exif_transpose(im).convert("RGB")
out = im.resize((target_w, target_h), Image.Resampling.BILINEAR)
if out.size != (target_w, target_h):
out = out.resize((target_w, target_h), Image.Resampling.NEAREST)
buf = _io.BytesIO()
out.save(buf, format="PNG")
return base64.b64encode(buf.getvalue()).decode()
except Exception:
with open(pred_path, "rb") as f:
return base64.b64encode(f.read()).decode()
@app.callback(
    [Output("predict-output", "children", allow_duplicate=True),
     Output("predict-state", "data", allow_duplicate=True),
     Output("progress-interval", "disabled", allow_duplicate=True),
     Output("predict-image-store", "data", allow_duplicate=True)],
    Input("progress-interval", "n_intervals"),
    State("predict-state", "data"),
    prevent_initial_call=True
)
def poll_prediction(n, state):
    """Poll PREDICT_JOB every interval tick while a prediction is running.

    While unfinished, shows a skeleton placeholder. Once finished, renders
    the produced image (resized to the upload size) or a Retry button plus
    the last 80 log lines, stops the interval and stores the image base64.
    """
    if state != "running":
        return dash.no_update, dash.no_update, dash.no_update, dash.no_update
    if not PREDICT_JOB["finished"]:
        # Show a placeholder while running; progress bar is updated separately
        placeholder = html.Div([
            html.Div(className="img-skeleton")
        ], style={"width": "100%"})
        return placeholder, "running", False, dash.no_update
    blocks = []
    b64img = None
    if PREDICT_JOB.get("image") and os.path.exists(PREDICT_JOB["image"]):
        # Scale the result back to the uploaded image's dimensions.
        b64img = _resize_prediction_image(PREDICT_JOB["image"], PREDICT_JOB.get("orig_size"), PREDICT_RESIZE_MODE)
        blocks.append(
            html.Div(
                className="predict-image",
                children=html.Img(
                    src=f"data:image/png;base64,{b64img}",
                    style={
                        "borderRadius": "8px",
                        "marginBottom": "1rem"
                    },
                    alt="Prediction Image"
                ),
                style={"width": "100%", "height": "100%"}
            )
        )
    else:
        # No output image: surface the failure and offer a retry.
        blocks.append(html.Div("No prediction image produced.", style={"color": "#dc3545", "fontWeight": "bold", "marginBottom": "0.5rem"}))
        blocks.append(dbc.Button("Retry", id="retry-predict-btn", color="warning", outline=True, style={"marginBottom": "0.5rem"}))
    if PREDICT_JOB.get("result"):
        # Append the tail of the subprocess log (last 80 lines).
        tail = "".join(PREDICT_JOB["result"].splitlines(True)[-80:])
        blocks.append(html.Pre(tail, style={
            "margin": 0,
            "marginTop": "0.75rem",
            "whiteSpace": "pre-wrap",
            "fontFamily": "monospace",
            "background": "#f8f9fa",
            "padding": "0.75rem",
            "borderRadius": "6px",
            "maxHeight": "280px",
            "overflowY": "auto"
        }))
    return html.Div(blocks), "done", True, b64img
@app.callback(
    Output("conf-thresh-value", "children"),
    Input("conf-thresh-slider", "value"),
    prevent_initial_call=False
)
def update_conf_label(v):
    """Render the confidence-threshold label for the current slider value."""
    try:
        threshold = float(v)
    except Exception:
        # Unparsable slider value: fall back to the default threshold.
        threshold = 0.50
    return f"Confidence threshold: {threshold:.2f} (NMS 0.5, max boxes 10)"
@app.callback(
    [Output("stored-image", "data"),
     Output("input-image-size", "data")],
    [
        Input("upload-image", "contents"),
        Input("upload-image", "filename"),
        Input("reset-image-btn", "n_clicks")
    ],
    prevent_initial_call=True
)
def update_stored_image(contents, filename, reset_click):
    """Store the uploaded image (with its pixel size) or clear it on reset.

    Returns:
        tuple: (stored-image payload dict or None,
                {"width", "height"} dict or None).
    """
    ctx = dash.callback_context
    if not ctx.triggered:
        return dash.no_update, dash.no_update
    trigger = ctx.triggered[0]["prop_id"].split(".")[0]
    if trigger == "reset-image-btn" and reset_click:
        # Reset clears both the stored image and its recorded size.
        return None, None
    if trigger == "upload-image" and contents and filename:
        try:
            # contents is a data-URL: "<mime header>,<base64 payload>".
            _, b64data = contents.split(",", 1)
            raw = base64.b64decode(b64data)
            with Image.open(io.BytesIO(raw)) as img:
                width, height = img.size
            return {"contents": contents, "filename": filename}, {"width": width, "height": height}
        except Exception:
            # Keep the upload even if Pillow cannot decode it; the original
            # size simply stays unknown. (Fixed: dropped the unused `as e`.)
            return {"contents": contents, "filename": filename}, None
    return dash.no_update, dash.no_update
@app.callback(
    [Output("predict-image-store", "data", allow_duplicate=True),
     Output("predict-output", "children", allow_duplicate=True),
     Output("predict-state", "data", allow_duplicate=True)],
    Input("reset-image-btn", "n_clicks"),
    prevent_initial_call=True
)
def reset_prediction(n):
    """Clear the shared prediction-job state when the image is reset."""
    if not n:
        return dash.no_update, dash.no_update, dash.no_update
    # Put every field of the shared job record back to its idle defaults.
    for key in ("process", "result", "image", "orig_size", "target_basename"):
        PREDICT_JOB[key] = None
    PREDICT_JOB["finished"] = False
    return None, html.Div(), "idle"
@app.callback(
    Output("upload-preview-box", "children"),
    Input("stored-image", "data"),
    prevent_initial_call=False
)
def update_upload_preview_box(data):
    """Build the upload area: dropzone, optional preview image, reset button."""
    has_image = bool(data and "contents" in data)
    dropzone_style = {
        "width": "100%",
        "height": "120px",
        "lineHeight": "120px",
        "borderWidth": "2px",
        "borderStyle": "dashed",
        "borderRadius": "8px",
        "textAlign": "center",
        "marginBottom": "1rem",
        "background": "#fff",
        "color": "#525A63",
        # The dropzone is hidden (not removed) once an image is stored so
        # its upload callbacks keep a valid target.
        "display": "none" if has_image else "block"
    }
    children = [
        dcc.Upload(
            id="upload-image",
            children=html.Div(["Drag and Drop or ", html.A("Select Files")]),
            style=dropzone_style,
            multiple=False
        )
    ]
    if has_image:
        children.append(
            html.Img(
                src=data["contents"],
                style={
                    "maxWidth": "100%",
                    "maxHeight": "350px",
                    "display": "block",
                    "marginLeft": "auto",
                    "marginRight": "auto",
                    "borderRadius": "8px",
                    "boxShadow": "0 2px 8px rgba(50,50,80,0.08)"
                }
            )
        )
    reset_style = {
        "display": "inline-block" if has_image else "none",
        "marginBottom": "1rem"
    }
    children.append(
        dbc.Button(
            "Reset image",
            id="reset-image-btn",
            color="secondary",
            style=reset_style
        )
    )
    return html.Div(children, style={"width": "100%", "textAlign": "center", "padding": "1rem 0"})
@app.callback(
    Output("reset-image-btn", "style"),
    Input("stored-image", "data"),
    prevent_initial_call=False
)
def toggle_reset_btn(data):
    """Show the reset button only while an image is stored."""
    visible = bool(data and "contents" in data)
    return {"display": "inline-block" if visible else "none", "marginBottom": "1rem"}
@app.callback(
    Output("predict-btn", "disabled"),
    [Input("stored-image", "data"), Input("predict-state", "data")],
    prevent_initial_call=False
)
def disable_predict_btn(data, state):
    """Disable Predict when no image is stored or a job is already running."""
    if state == "running":
        return True
    return not (data and "contents" in data)
@app.callback(
    Output("progress-bar-box", "children"),
    [Input("predict-state", "data"),
     Input("progress-interval", "n_intervals")]
)
def update_progress_bar(state, n):
    """Render the coarse prediction progress bar for the current state."""
    bar_style = {"height": "20px"}
    if state == "running":
        # Synthetic linear progress: 10% per interval tick, capped at 100%.
        percent = min(100, n * 10)
        return dbc.Progress(f"{percent}%", value=percent, striped=True, animated=True, color="info", style=bar_style)
    if state == "done":
        return dbc.Progress("100%", value=100, striped=False, animated=False, color="success", style=bar_style)
    return dbc.Progress("0%", value=0, striped=False, animated=False, color="light", style=bar_style)
@app.callback(
    [Output("predict-progress", "value"),
     Output("predict-progress", "label"),
     Output("predict-progress", "animated"),
     Output("predict-progress", "striped"),
     Output("predict-live-logs", "children")],
    [Input("progress-interval", "n_intervals"),
     Input("predict-state", "data")],
    prevent_initial_call=False
)
def update_progress_and_logs(n, state):
    """Drive the fine-grained progress bar and the live log tail."""
    if state == "running":
        # Advance 5% per tick but never claim completion before the job ends.
        percent = min(95, (n + 1) * 5)
        return percent, f"Running... {percent}%", True, True, PREDICT_JOB.get("tail", "")
    if state == "done":
        full_log = PREDICT_JOB.get("result", "")
        # Show only the final 80 lines of the finished job's output.
        last_lines = "".join(full_log.splitlines(True)[-80:]) if full_log else ""
        return 100, "Done", False, False, last_lines
    return 0, "", False, False, ""
@app.callback(
    Output("download-prediction", "data"),
    Input("download-prediction-btn", "n_clicks"),
    State("predict-image-store", "data"),
    prevent_initial_call=True
)
def download_prediction(n, b64img):
    """Send the stored prediction image as a timestamped PNG download."""
    if not n or not b64img:
        return dash.no_update
    try:
        payload = base64.b64decode(b64img)
    except Exception:
        # Store contents are not valid base64: silently skip the download.
        return dash.no_update
    stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    return dcc.send_bytes(lambda buff: buff.write(payload), f"prediction_{stamp}.png")
@app.callback(
    Output("download-prediction-btn", "style"),
    [Input("predict-image-store", "data"),
     Input("sidebar-mode", "data")],
    prevent_initial_call=False
)
def toggle_download_prediction_btn(data, mode):
    """Show the download button (styled per tab) only when an image exists."""
    if mode == "tab-train":
        style = {
            "marginTop": "0.75rem",
            "alignSelf": "center"
        }
    else:
        style = {
            "marginTop": "0.25rem",
            "marginBottom": "0.5rem"
        }
    if not data:
        # Hide (but keep in the layout) when there is nothing to download.
        style["display"] = "none"
    return style
@app.callback(
    Output("logs-collapse", "is_open"),
    Input("toggle-logs-btn", "n_clicks"),
    State("logs-collapse", "is_open"),
    prevent_initial_call=True
)
def toggle_logs(n, is_open):
    """Flip the log panel open/closed on each button click."""
    return (not is_open) if n else is_open
##########SIDEBAR##########
@app.callback(
    [Output("netron-modal", "is_open"),
     Output("netron-iframe", "src")],
    [Input("sidebar-netron", "n_clicks"),
     Input("close-netron-modal", "n_clicks")],
    [State("netron-modal", "is_open"),
     State("sidebar-mode", "data")],
    prevent_initial_call=True
)
def toggle_netron_modal(n_open, n_close, is_open, tab):
    """Open the Netron viewer on the model that matches the active tab."""
    ctx = dash.callback_context
    triggered = ctx.triggered[0]["prop_id"] if ctx.triggered else ""

    def build_netron_url(relative_path: str) -> str:
        # Netron loads models from a raw-GitHub URL passed as a query param.
        base_raw = "https://raw.githubusercontent.com/STMicroelectronics/stm32ai-modelzoo/main/"
        full = base_raw + relative_path.lstrip('/')
        return f"https://netron.app/?url={urllib.parse.quote(full, safe=':/')}"

    if triggered.startswith("sidebar-netron") and n_open:
        # Map each sidebar tab to the pretrained model it demonstrates.
        model_paths = {
            "tab-predict": "image_classification/mobilenetv2/ST_pretrainedmodel_public_dataset/tf_flowers/mobilenetv2_a035_224_fft/mobilenetv2_a035_224_fft_int8.tflite",
            "tab-train": "object_detection/st_yolodv2tiny_pt/ST_pretrainedmodel_public_dataset/coco_person/st_yoloxn_d033_w025_320_int8.tflite",
        }
        relative = model_paths.get(tab)
        url = build_netron_url(relative) if relative else ""
        return True, url
    if triggered.startswith("close-netron-modal") and n_close:
        return False, dash.no_update
    return is_open, dash.no_update
@app.callback(
    Output("yaml-modal", "is_open"),
    Output("yaml-modal-body", "children"),
    [Input("sidebar-yaml", "n_clicks"), Input("close-yaml-modal", "n_clicks")],
    [State("yaml-modal", "is_open"), State("sidebar-mode", "data")],
    prevent_initial_call=True
)
def toggle_yaml_modal(n_yaml, n_close, is_open, mode):
    """Show the active tab's user_config.yaml in a read-only modal."""
    ctx = dash.callback_context
    triggered = ctx.triggered[0]["prop_id"] if ctx.triggered else ""
    if triggered.startswith("sidebar-yaml") and n_yaml:
        # Map the sidebar tab to its use-case configuration file.
        yaml_by_tab = {
            "tab-train": "object_detection/user_config.yaml",
            "tab-predict": "image_classification/user_config.yaml",
        }
        yaml_path = yaml_by_tab.get(mode)
        if yaml_path is None:
            return is_open, dash.no_update
        try:
            with open(yaml_path, "r", encoding="utf-8") as f:
                content = f.read()
        except Exception as e:
            return True, html.Div(f"Error reading YAML: {e}", style={"color": "#A6ADB5"})
        return True, html.Pre(content, style={"whiteSpace": "pre-wrap", "fontFamily": "monospace", "fontSize": "1rem"})
    if triggered.startswith("close-yaml-modal") and n_close:
        return False, dash.no_update
    return is_open, dash.no_update
@app.callback(
    Output("dataset-modal", "is_open"),
    Output("dataset-modal-body", "children"),
    [Input("sidebar-visualize-dataset", "n_clicks"), Input("close-dataset-modal", "n_clicks")],
    [State("dataset-modal", "is_open")],
    prevent_initial_call=True
)
def toggle_dataset_modal(n_open, n_close, is_open):
    """Preview one random image per class of the flowers dataset in a modal.

    Fixes:
    - ``plt.subplots(1, 1)`` returns a bare Axes (not an array), so the old
      ``axs[i]`` indexing crashed for a single-class dataset; ``squeeze=False``
      keeps a 2-D Axes array in every case.
    - Image files are matched case-insensitively and .jpeg/.png are accepted
      too, consistent with the extensions advertised in the dataset help.
    - Figure-local ``tight_layout``/``savefig`` avoid the shared pyplot state
      (safer when several requests render concurrently).
    """
    ctx = dash.callback_context
    triggered = ctx.triggered[0]["prop_id"] if ctx.triggered else ""
    if triggered.startswith("sidebar-visualize-dataset") and n_open:
        path = 'image_classification/datasets/flower_photos'
        if not os.path.exists(path):
            return True, html.Div("The dataset folder does not exist. Please check the Dockerfile.")
        class_names = sorted([name for name in os.listdir(path) if os.path.isdir(os.path.join(path, name))])
        num_classes = len(class_names)
        if num_classes == 0:
            return True, html.Div("No class found in the dataset.")
        # squeeze=False -> axs is always a 2-D array, even for one class.
        fig, axs = plt.subplots(1, num_classes, figsize=(4 * num_classes, 4), squeeze=False)
        for i, class_name in enumerate(class_names):
            class_path = os.path.join(path, class_name)
            image_files = [
                f for f in os.listdir(class_path)
                if os.path.isfile(os.path.join(class_path, f))
                and f.lower().endswith(('.jpg', '.jpeg', '.png'))
            ]
            if len(image_files) == 0:
                # Black placeholder when a class folder has no readable image.
                img = np.zeros((224, 224, 3))
            else:
                img_path = os.path.join(class_path, random.choice(image_files))
                img = mpimg.imread(img_path)
            ax = axs[0][i]
            ax.imshow(img)
            ax.set_title(class_name)
            ax.axis('off')
        buf = io.BytesIO()
        fig.tight_layout()
        fig.savefig(buf, format='png')
        plt.close(fig)
        buf.seek(0)
        img_base64 = base64.b64encode(buf.read()).decode('utf-8')
        img_html = html.Img(src=f'data:image/png;base64,{img_base64}', style={"maxWidth": "100%", "height": "auto"})
        return True, img_html
    elif triggered.startswith("close-dataset-modal") and n_close:
        return False, dash.no_update
    return is_open, dash.no_update
############ TRAINING ############
# --- Training page content (simple, unique model/dataset, real-time logs) ---
def render_model_perf_ic():
    """Build the card showing MobileNetv2 on-target performance figures."""
    columns = [
        "Series",
        "Format",
        "Resolution",
        "Activation RAM (KB)",
        "Weights Flash (KB)",
        "Inference Time (ms)",
    ]
    # One tuple per benchmarked board, in column order.
    measurements = [
        ("STM32H7", "INT8", "224x224", "699", "407", "309"),
        ("STM32N6570-DK", "INT8", "224x224", "913", "436", "5.3"),
    ]
    header = html.Thead(html.Tr([html.Th(title) for title in columns]))
    body = html.Tbody([
        html.Tr([html.Td(cell) for cell in row]) for row in measurements
    ])
    return dbc.Card([
        dbc.CardHeader(
            html.Div([
                html.Span("MobileNetv2 Model Performances overview"),
                dbc.Badge("STM32 Reference", color="secondary", pill=True, className="ms-2")
            ], style={"display": "flex", "alignItems": "center", "gap": "8px"}),
            style={"fontWeight": "bold", "fontSize": "1.1rem", "background": "#f8f9fa"}
        ),
        dbc.CardBody([
            html.P(
                "MobileNetv2 Performances on STM32N6. Values depend on resolution and quantization. Fill with STM32Cube.AI 3.0.0 results.",
                style={"color": "#525A63", "fontSize": "0.95rem", "marginBottom": "0.75rem"}
            ),
            dbc.Table([header, body], bordered=True, hover=True, responsive=True, className="table-sm"),
        ])
    ], style={
        "boxShadow": "0 2px 8px rgba(50,50,80,0.08)",
        "border": "1px solid #A6ADB5",
        "marginTop": "0.5rem",
        "marginBottom": "0.5rem",
    })
def render_showcase_info_card():
    """Return the banner pointing users to the full Model Zoo services repo."""
    repo_link = html.A(
        "visit the STM32AI Model Zoo Services repository",
        href=ST_MODELZOO_SERVICES_REPO_URL,
        target="_blank",
        className="info-link"
    )
    message = html.Span([
        "This is a showcase of the STM32AI Model Zoo. "
        "To go further (training, deployment on STM32, scripts and tools), ",
        repo_link,
        "."
    ])
    banner_style = {
        "background": "#f5f8fc",
        "border": "1px solid #e5edf5",
        "color": "#425466",
        "fontSize": "0.95rem",
        "padding": "0.6rem 0.9rem",
        "borderRadius": "10px",
        "marginTop": "0.5rem",
        "marginBottom": "0.5rem",
    }
    return html.Div([message], style=banner_style)
# Shared state for the background training job.
# TRAIN_LOG_QUEUE buffers messages produced by the training worker.
TRAIN_LOG_QUEUE = queue.Queue()
# "thread": worker handle (None when idle); "finished": completion flag.
TRAIN_PROCESS = {"thread": None, "finished": False}
def _load_ic_training_params(yaml_path: str, *, default_epochs: int = 20, default_batch_size: int = 32, default_learning_rate: float = 0.001):
    """Read (epochs, batch_size, learning_rate) from a user_config.yaml.

    Best-effort loader: any value that is missing — or the whole file being
    unreadable/invalid — falls back to the supplied defaults.
    """
    epochs, batch_size, learning_rate = default_epochs, default_batch_size, default_learning_rate
    try:
        with open(yaml_path, "r", encoding="utf-8") as f:
            config = yaml.safe_load(f) or {}
        training = config.get("training") or {}
        epochs = int(training.get("epochs", epochs))
        batch_size = int(training.get("batch_size", batch_size))
        # Prefer the nested Adam learning rate; fall back to the flat key.
        nested_lr = (training.get("optimizer") or {}).get("Adam", {}).get("learning_rate")
        if nested_lr is not None:
            learning_rate = float(nested_lr)
        else:
            learning_rate = float(training.get("learning_rate", learning_rate))
    except Exception:
        # Unreadable or malformed YAML: keep whatever was resolved so far.
        pass
    return epochs, batch_size, learning_rate
def _save_ic_training_params(yaml_path: str, epochs, batch_size, learning_rate):
    """Persist the training parameters into a user_config.yaml in place.

    The learning rate is stored under training.optimizer.Adam.learning_rate
    and any legacy flat training.learning_rate key is dropped. A legacy
    model.name key is migrated to model.model_name.
    """
    with open(yaml_path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f) or {}

    training = config.setdefault('training', {})
    training['epochs'] = int(epochs)
    training['batch_size'] = int(batch_size)
    # The flat key is superseded by the nested optimizer entry below.
    training.pop('learning_rate', None)
    training.setdefault('optimizer', {}).setdefault('Adam', {})['learning_rate'] = float(learning_rate)

    model = config.get('model')
    if model and 'name' in model and 'model_name' not in model:
        model['model_name'] = model.pop('name')

    with open(yaml_path, "w", encoding="utf-8") as f:
        yaml.dump(config, f, allow_unicode=True, sort_keys=False, default_flow_style=False)
def _env_int(name: str, default: int, *, min_value: int | None = None, max_value: int | None = None) -> int:
raw = os.getenv(name)
try:
val = int(str(raw).strip()) if raw is not None and str(raw).strip() != "" else int(default)
except Exception:
val = int(default)
if min_value is not None:
val = max(min_value, val)
if max_value is not None:
val = min(max_value, val)
return val
# Dataset upload guards (configurable via env vars).
# Maximum size of the uploaded archive itself, in MB.
IC_MAX_ARCHIVE_MB = _env_int("IC_MAX_ARCHIVE_MB", 250, min_value=1, max_value=50_000)
# Maximum estimated total size after extraction, in MB (bomb guard).
IC_MAX_EXTRACT_MB = _env_int("IC_MAX_EXTRACT_MB", 1000, min_value=1, max_value=200_000)
# Number of images sampled for the Pillow corruption check.
IC_MAX_VERIFY_IMAGES = _env_int("IC_MAX_VERIFY_IMAGES", 250, min_value=0, max_value=200_000)
def _safe_extract_zip_to_dir(zf: zipfile.ZipFile, dst_dir: str) -> None:
    """Extract the regular file entries of ``zf`` under ``dst_dir``.

    Entries whose normalized path escapes the destination (Zip Slip), is
    absolute, or starts with a Windows drive letter raise ValueError.
    Explicit directory entries are skipped; parent directories are created
    on demand.

    Raises:
        ValueError: on a path-traversal or absolute-path entry.
    """
    os.makedirs(dst_dir, exist_ok=True)
    for info in zf.infolist():
        entry_name = info.filename
        # Directory entries (trailing slash) and empty names carry no data.
        if not entry_name or entry_name.endswith("/"):
            continue
        normalized = os.path.normpath(entry_name).replace("\\", "/")
        if normalized.startswith("../") or normalized.startswith("..\\"):
            raise ValueError("Invalid zip entry (path traversal)")
        if os.path.isabs(normalized) or re.match(r"^[A-Za-z]:", normalized):
            raise ValueError("Invalid zip entry (absolute path)")
        target = os.path.join(dst_dir, normalized)
        parent = os.path.dirname(target)
        if parent:
            os.makedirs(parent, exist_ok=True)
        with zf.open(info) as src, open(target, "wb") as out:
            shutil.copyfileobj(src, out)
def _safe_extract_tar_to_dir(tf: tarfile.TarFile, dst_dir: str) -> None:
    """Extract regular files and directories of ``tf`` under ``dst_dir``.

    Symlinks, hardlinks, devices and other special members are skipped.
    Entries whose normalized path escapes the destination, is absolute, or
    starts with a Windows drive letter raise ValueError (tar-slip guard).

    Raises:
        ValueError: on a path-traversal or absolute-path entry.
    """
    os.makedirs(dst_dir, exist_ok=True)
    for member in tf.getmembers():
        entry_name = member.name
        if not entry_name:
            continue
        normalized = os.path.normpath(entry_name).replace("\\", "/")
        if normalized.startswith("../") or normalized.startswith("..\\"):
            raise ValueError("Invalid tar entry (path traversal)")
        if os.path.isabs(normalized) or re.match(r"^[A-Za-z]:", normalized):
            raise ValueError("Invalid tar entry (absolute path)")
        # Anything that is not a plain file or directory is ignored.
        if not (member.isfile() or member.isdir()):
            continue
        target = os.path.join(dst_dir, normalized)
        if member.isdir():
            os.makedirs(target, exist_ok=True)
            continue
        parent = os.path.dirname(target)
        if parent:
            os.makedirs(parent, exist_ok=True)
        extracted = tf.extractfile(member)
        if extracted is None:
            continue
        with extracted as src, open(target, "wb") as out:
            shutil.copyfileobj(src, out)
def _prepare_ic_uploaded_dataset(contents: str, filename: str, *, user_id: str, dataset_name: str) -> tuple[str, list[str]]:
    """Decode, validate and extract an uploaded dataset archive.

    ``contents`` is a Dash upload data-URL ("<mime header>,<base64 payload>").
    The archive is extracted under image_classification/datasets/<dataset_name>
    (replacing any existing folder of that name); a single wrapping top-level
    folder is flattened away so class folders sit directly under the root.

    NOTE(review): ``user_id`` is currently unused — presumably reserved for
    per-user dataset isolation; confirm with callers before removing.

    Returns:
        tuple[str, list[str]]: (forward-slash training path, sorted class names).

    Raises:
        ValueError: empty/malformed upload, unsupported archive type, archive
            or extracted size over the configured limits, unsafe archive
            entries, or fewer than 2 class folders.
    """
    if not contents:
        raise ValueError("No upload contents")
    if "," not in contents:
        raise ValueError("Unexpected upload encoding")
    header, b64 = contents.split(",", 1)
    fname_l = (filename or "").lower()
    is_zip = fname_l.endswith(".zip")
    is_tar = fname_l.endswith(".tar") or fname_l.endswith(".tar.gz") or fname_l.endswith(".tgz")
    if not (is_zip or is_tar):
        raise ValueError("Please upload a .zip, .tar, .tar.gz or .tgz dataset")
    raw = base64.b64decode(b64)
    # Basic size guard (avoid decompression bombs / huge uploads).
    max_archive_bytes = IC_MAX_ARCHIVE_MB * 1024 * 1024
    if len(raw) > max_archive_bytes:
        raise ValueError(f"Archive too large (max {IC_MAX_ARCHIVE_MB}MB)")
    # Sanitize the dataset name so it is safe as a directory component.
    safe_name = re.sub(r"[^a-zA-Z0-9_-]+", "_", (dataset_name or "custom_dataset").strip())[:64] or "custom_dataset"
    # Store directly under image_classification/datasets/<dataset_name>
    base_dir = os.path.join("image_classification", "datasets", safe_name)
    if os.path.isdir(base_dir):
        # Re-uploading a dataset with the same name replaces it entirely.
        shutil.rmtree(base_dir, ignore_errors=True)
    os.makedirs(base_dir, exist_ok=True)
    # Guard extracted total size (best-effort).
    max_total = IC_MAX_EXTRACT_MB * 1024 * 1024
    if is_zip:
        with zipfile.ZipFile(io.BytesIO(raw)) as zf:
            total = sum((i.file_size or 0) for i in zf.infolist())
            if total > max_total:
                raise ValueError(f"Archive expands to > {IC_MAX_EXTRACT_MB}MB")
            _safe_extract_zip_to_dir(zf, base_dir)
    else:
        with tarfile.open(fileobj=io.BytesIO(raw), mode="r:*") as tf:
            total = sum((m.size or 0) for m in tf.getmembers() if m.isfile())
            if total > max_total:
                raise ValueError(f"Archive expands to > {IC_MAX_EXTRACT_MB}MB")
            _safe_extract_tar_to_dir(tf, base_dir)
    # Ignore macOS metadata folder if present.
    macosx_dir = os.path.join(base_dir, "__MACOSX")
    if os.path.isdir(macosx_dir):
        shutil.rmtree(macosx_dir, ignore_errors=True)
    # Detect if there is a single root folder (e.g. photos_test/roses, ...)
    entries = [e for e in os.listdir(base_dir) if not e.startswith('.')]
    dirs = [e for e in entries if os.path.isdir(os.path.join(base_dir, e))]
    files = [e for e in entries if os.path.isfile(os.path.join(base_dir, e))]
    # If there is a single directory and no files, and that directory contains at least 2 subfolders, move its contents up
    if len(dirs) == 1 and not files:
        single_dir = os.path.join(base_dir, dirs[0])
        subdirs = [d for d in os.listdir(single_dir) if os.path.isdir(os.path.join(single_dir, d)) and not d.startswith('.')]
        if len(subdirs) >= 2:
            # Move all subdirs and files up
            for name in os.listdir(single_dir):
                src = os.path.join(single_dir, name)
                dst = os.path.join(base_dir, name)
                shutil.move(src, dst)
            shutil.rmtree(single_dir, ignore_errors=True)
    dataset_root = base_dir
    # Class labels are the first-level folder names under the dataset root.
    class_names = [d for d in os.listdir(dataset_root)
                   if os.path.isdir(os.path.join(dataset_root, d)) and not d.startswith('.')]
    class_names = sorted(class_names)
    if len(class_names) < 2:
        raise ValueError("Dataset must contain at least 2 class folders")
    # The YAML config uses forward slashes regardless of host OS.
    training_path = dataset_root.replace("\\", "/")
    return training_path, class_names
def _analyze_ic_dataset(dataset_root: str) -> dict:
    """Scan a class-per-folder dataset and gather summary statistics.

    Walks each first-level class folder recursively, counting images by
    extension, collecting up to 2000 sample paths, and running a Pillow
    ``verify()`` corruption check on the first IC_MAX_VERIFY_IMAGES files.

    Returns:
        dict: per-class counts, totals, extension histogram, corruption
        check results, sample paths, and min/max/avg images per class.

    Raises:
        ValueError: when ``dataset_root`` does not exist.
    """
    root = (dataset_root or "").replace("/", os.sep)
    if not root or not os.path.isdir(root):
        raise ValueError("Dataset path not found")
    allowed_ext = {".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp"}
    class_dirs = [d for d in os.listdir(root)
                  if os.path.isdir(os.path.join(root, d)) and not d.startswith('.')]
    class_dirs = sorted(class_dirs)
    per_class = []  # list of (class_name, image_count) tuples
    total = 0
    ext_counts = {}
    sample_paths = []
    corrupted = 0
    checked = 0
    corrupted_examples = []
    max_verify = IC_MAX_VERIFY_IMAGES
    max_examples = 5  # cap on corrupted paths reported back to the UI
    for class_name in class_dirs:
        class_path = os.path.join(root, class_name)
        count = 0
        for dirpath, _, filenames in os.walk(class_path):
            for fn in filenames:
                ext = os.path.splitext(fn)[1].lower()
                if ext in allowed_ext:
                    count += 1
                    ext_counts[ext] = ext_counts.get(ext, 0) + 1
                    fpath = os.path.join(dirpath, fn)
                    if len(sample_paths) < 2000:
                        sample_paths.append(fpath)
                    if checked < max_verify:
                        try:
                            # verify() detects truncated/corrupt files
                            # without fully decoding the image.
                            with Image.open(fpath) as im:
                                im.verify()
                        except Exception:
                            corrupted += 1
                            if len(corrupted_examples) < max_examples:
                                corrupted_examples.append(os.path.relpath(fpath, root).replace("\\", "/"))
                        checked += 1
        per_class.append((class_name, count))
        total += count
    counts = [c for _, c in per_class]
    stats = {
        "num_classes": len(class_dirs),
        "total_images": total,
        "per_class": per_class,
        # Extension histogram sorted by frequency (desc), then name.
        "ext_counts": dict(sorted(ext_counts.items(), key=lambda kv: (-kv[1], kv[0]))),
        "corrupted_checked": checked,
        "corrupted_count": corrupted,
        "corrupted_examples": corrupted_examples,
        "sample_paths": sample_paths,
        "min_images": min(counts) if counts else 0,
        "max_images": max(counts) if counts else 0,
        "avg_images": (float(total) / len(counts)) if counts else 0.0,
    }
    return stats
def _make_thumbnail_b64(image_path: str, *, size: int = 128) -> str | None:
    """Return a base64-encoded square JPEG thumbnail of ``image_path``.

    Returns None when the file cannot be opened or converted.
    """
    try:
        with Image.open(image_path) as source:
            thumbnail = ImageOps.fit(source.convert("RGB"), (size, size), method=Image.Resampling.LANCZOS)
        out = io.BytesIO()
        thumbnail.save(out, format="JPEG", quality=85)
        return base64.b64encode(out.getvalue()).decode("utf-8")
    except Exception:
        return None
@app.callback(
    Output("page-content", "children"),
    Input("sidebar-mode", "data"),
    State("stm32ai-credentials", "data"),
    State("page-content", "children"),
    State("acc-fig-store", "data"),
    State("loss-fig-store", "data"),
    State("predict-image-store", "data")
)
def render_page_content(tab_value, creds, current_content, acc_fig_store, loss_fig_store, predict_image_store):
    """Build the full page layout for the selected sidebar tab.

    "tab-train" renders the object-detection prediction page and
    "tab-predict" renders the image-classification training page (the tab
    names are historical — NOTE(review): confirm before renaming). Stored
    figures (acc/loss) and the last prediction image are re-injected so the
    pages survive tab switches. Returns None for any other tab value.
    """
    if tab_value == "tab-train":
        # Object detection (prediction) page
        card_height = "650px"
        pred_child = None
        if predict_image_store:
            # Restore the previously computed prediction image.
            pred_child = html.Img(
                src=f"data:image/png;base64,{predict_image_store}",
                style={"maxWidth": "100%", "borderRadius": "8px"},
                alt="Prediction image"
            )
        return [
            html.H2("Object Detection - ST Yolo X", className="section-title"),
            html.P(
                [
                    html.A(
                        "ST Yolo X ",
                        href="https://github.com/STMicroelectronics/stm32ai-modelzoo/tree/main/object_detection/st_yoloxn",
                        target="_blank",
                        className="info-link"
                    ),
                    "is a real-time object detection model targeted for real-time processing implemented in Tensorflow. This is an optimized ST version of the well known yolo x, quantized in int8 format using tensorflow lite converter. Upload an image and run a prediction.",
                ],
                className="section-subtitle"
            ),
            render_showcase_info_card(),
            render_model_perf_od(),
            # Two side-by-side cards: upload/preview (left) and output (right).
            dbc.Row([
                dbc.Col([
                    dbc.Card([
                        dbc.CardHeader("Image Upload & Preview", style={"fontWeight": "bold", "fontSize": "1.1rem", "background": "#f8f9fa"}),
                        dbc.CardBody([
                            html.Div(id="upload-preview-box", style={
                                "flex": 1, "display": "flex", "flexDirection": "column",
                                "justifyContent": "flex-start", "alignItems": "center",
                                "overflow": "hidden", "padding": "0.5rem", "width": "100%", "minHeight": 0
                            }),
                            dbc.Button("Predict", id="predict-btn", color="info",
                                       style={"marginTop": "0.5rem", "marginBottom": "0.5rem", "alignSelf": "center"}),
                            html.Div([
                                html.Div(id="conf-thresh-value", style={"fontWeight": 600, "marginBottom": "0.25rem", "textAlign": "center"}),
                                dcc.Slider(
                                    id="conf-thresh-slider",
                                    min=0.0,
                                    max=1.0,
                                    step=0.01,
                                    value=0.5,
                                    tooltip={"placement": "bottom", "always_visible": False},
                                    marks={0.0: "0.0", 0.5: "0.5", 1.0: "1.0"}
                                ),
                                html.Div("Adjust confidence threshold; changing it re-runs prediction.", style={"fontSize": "0.85rem", "color": "#6c757d", "marginTop": "0.25rem", "textAlign": "center"})
                            ], style={"marginTop": "0.5rem", "padding": "0.25rem 0.25rem 0 0.25rem"})
                        ], style={"display": "flex", "flexDirection": "column", "flex": 1, "padding": "1rem", "minHeight": 0})
                    ], style={
                        "boxShadow": "0 2px 8px rgba(50,50,80,0.08)",
                        "border": "1px solid #A6ADB5",
                        "height": "100%", "minHeight": card_height, "maxHeight": card_height,
                        "display": "flex", "flexDirection": "column"
                    })
                ], width=6, style={"display": "flex", "flexDirection": "column", "height": card_height, "minHeight": card_height, "maxHeight": card_height}),
                dbc.Col([
                    dbc.Card([
                        dbc.CardHeader("Prediction Output", style={"fontWeight": "bold", "fontSize": "1.1rem", "background": "#f8f9fa"}),
                        dbc.CardBody([
                            html.Div([
                                dbc.Progress(id="predict-progress", value=0, max=100, striped=False, animated=False, color="info", style={"height": "20px", "width": "100%"}),
                                dbc.Button("View Logs", id="toggle-logs-btn", color="link", style={"padding": 0, "marginLeft": "0.5rem"})
                            ], style={"display": "flex", "alignItems": "center", "gap": "8px", "marginBottom": "0.5rem"}),
                            dbc.Collapse(
                                html.Pre(id="predict-live-logs", style={
                                    "margin": 0,
                                    "whiteSpace": "pre-wrap",
                                    "fontFamily": "monospace",
                                    "background": "#f8f9fa",
                                    "padding": "0.75rem",
                                    "borderRadius": "6px",
                                    "maxHeight": "240px",
                                    "overflowY": "auto"
                                }),
                                id="logs-collapse", is_open=False
                            ),
                            dcc.Loading(
                                id="predict-loading-train",
                                type="circle",
                                color="#3CB4E6",
                                fullscreen=False,
                                children=html.Div(
                                    id="predict-output",
                                    children=pred_child,
                                    style={
                                        "flex": 1, "whiteSpace": "pre-wrap", "fontFamily": "monospace",
                                        "background": "#f8f9fa", "padding": "0.75rem", "borderRadius": "6px",
                                        "overflowY": "auto", "overflowX": "hidden", "height": "100%", "margin": 0, "minHeight": 0,
                                        "display": "flex", "alignItems": "center", "justifyContent": "center"
                                    }
                                ),
                                style={"flex": 1}
                            ),
                            # Hidden Retry button so its callback target always exists.
                            dbc.Button("Retry", id="retry-predict-btn", color="warning", outline=True, style={"display": "none"}),
                            dbc.Button(
                                "Download Prediction",
                                id="download-prediction-btn",
                                color="secondary",
                                style={"marginTop": "0.75rem", "alignSelf": "center"}
                            ),
                            dcc.Download(id="download-prediction")
                        ], style={"display": "flex", "flexDirection": "column", "flex": 1, "padding": "1rem", "minHeight": 0})
                    ], style={
                        "boxShadow": "0 2px 8px rgba(50,50,80,0.08)",
                        "border": "1px solid #A6ADB5",
                        "height": "100%", "minHeight": card_height, "maxHeight": card_height,
                        "display": "flex", "flexDirection": "column"
                    }, className="responsive-card")
                ], width=6, style={"display": "flex", "flexDirection": "column", "height": card_height, "minHeight": card_height, "maxHeight": card_height})
            ], justify="center", align="stretch", style={"marginTop": "2rem", "alignItems": "stretch"})
        ]
    if tab_value == "tab-predict":
        # Image classification (training) page; seed form values from the
        # current YAML config and restore any stored metric figures.
        yaml_path = "image_classification/user_config.yaml"
        epochs, batch_size, learning_rate = _load_ic_training_params(yaml_path)
        acc_initial = acc_fig_store if acc_fig_store else {"data": [], "layout": ACC_LOSS_BASE_LAYOUT}
        loss_initial = loss_fig_store if loss_fig_store else {"data": [], "layout": ACC_LOSS_BASE_LAYOUT}
        card_height = "650px"
        return [
            html.H2("Image Classification - MobileNet v2", className="section-title"),
            html.P([
                html.A(
                    "MobileNet v2",
                    href="https://github.com/STMicroelectronics/stm32ai-modelzoo/tree/main/image_classification/mobilenetv2",
                    target="_blank",
                    className="info-link"
                ),
                " is very similar to the original MobileNet, except that it uses inverted residual blocks with bottlenecking features. It has a drastically lower parameter count than the original MobileNet.",
                html.Span("(", style={"marginRight": "2px"}),
                html.A(
                    "Flowers Dataset used here",
                    href="https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz",
                    target="_blank",
                    style={"textDecoration": "underline"}
                ),
                html.Span(")")
            ], className="section-subtitle"),
            render_showcase_info_card(),
            render_model_perf_ic(),
            html.H5(
                "Step 1 — Configure and Launch Training",
                style={
                    "color": "#03234B",
                    "fontWeight": "700",
                    "margin": "0 0 0.75rem 0",
                    "textAlign": "left"
                }
            ),
            # Left card: parameters + dataset upload; right card: run/logs.
            dbc.Row([
                dbc.Col([
                    dbc.Card([
                        dbc.CardHeader("Training Parameters", style={"fontWeight": "bold", "fontSize": "1.1rem", "background": "#f8f9fa"}),
                        dbc.CardBody([
                            dbc.Form([
                                dbc.Label("Number of Epochs"),
                                dbc.Input(id="epochs", type="number", min=1, max=200, value=epochs, style={"marginBottom": "1rem"}),
                                dbc.Label("Batch Size"),
                                dbc.Input(id="batch-size", type="number", min=1, max=256, value=batch_size, style={"marginBottom": "1rem"}),
                                dbc.Label("Learning Rate"),
                                dbc.Input(id="learning-rate", type="number", value=learning_rate, style={"marginBottom": "1rem"}),
                                html.Div(
                                    "Parameters are saved automatically.",
                                    style={"color": "#525A63", "fontSize": "0.95rem", "marginBottom": "0.5rem"}
                                ),
                                html.Div(id="autosave-yaml-alert"),
                                html.Hr(style={"margin": "0.75rem 0"}),
                                dbc.Label("Upload your own dataset (optional)"),
                                dcc.Upload(
                                    id="ic-dataset-upload",
                                    accept=".zip,.tar,.tar.gz,.tgz",
                                    multiple=False,
                                    children=html.Div([
                                        html.Div("Drop a .zip/.tar(.gz) here or click to upload", style={"fontWeight": "bold"}),
                                        html.Div("Expected: dataset_root/class_name/images...", style={"fontSize": "0.85rem"})
                                    ]),
                                    className="dash-upload",
                                    style={
                                        "width": "100%",
                                        "padding": "12px",
                                        "borderRadius": "8px",
                                        "textAlign": "center",
                                        "cursor": "pointer",
                                        "marginBottom": "0.5rem"
                                    },
                                ),
                                dbc.Button(
                                    "Dataset format help",
                                    id="ic-dataset-info-btn",
                                    color="secondary",
                                    size="sm",
                                    style={"marginBottom": "0.5rem"}
                                ),
                                dcc.Loading(
                                    id="ic-dataset-upload-loading",
                                    type="circle",
                                    color="#3CB4E6",
                                    fullscreen=False,
                                    children=html.Div(id="ic-dataset-upload-alert")
                                ),
                                dbc.Modal(
                                    [
                                        dbc.ModalHeader(dbc.ModalTitle("How to format your dataset")),
                                        dbc.ModalBody(
                                            [
                                                html.Div("Upload a .zip or .tar(.gz/.tgz) that contains class folders (like Flowers):"),
                                                html.Pre(
                                                    """my_dataset/\n class1/\n img1.jpg\n img2.jpg\n class2/\n imgA.jpg\n imgB.jpg\n""",
                                                    style={"background": "#f8f9fa", "padding": "0.75rem", "borderRadius": "6px"}
                                                ),
                                                html.Div("Notes:"),
                                                html.Ul([
                                                    html.Li("Each class is a folder name (used as label)."),
                                                    html.Li("Images can be .jpg/.jpeg/.png."),
                                                    html.Li("The zip can either include a top folder (my_dataset/...) or directly the class folders.")
                                                ]),
                                                html.Div("Current limits (configurable via environment variables):", style={"marginTop": "0.5rem"}),
                                                html.Ul([
                                                    html.Li(f"Max archive size: {IC_MAX_ARCHIVE_MB}MB (env: IC_MAX_ARCHIVE_MB)"),
                                                    html.Li(f"Max extracted size (estimated): {IC_MAX_EXTRACT_MB}MB (env: IC_MAX_EXTRACT_MB)"),
                                                    html.Li(f"Corruption check sample size: {IC_MAX_VERIFY_IMAGES} images (env: IC_MAX_VERIFY_IMAGES)"),
                                                ]),
                                                html.Div("After upload, the app extracts into image_classification/datasets/<dataset_name> and writes dataset.training_path in image_classification/user_config.yaml.")
                                            ]
                                        ),
                                        dbc.ModalFooter(
                                            dbc.Button("Close", id="ic-dataset-info-close", className="ms-auto", n_clicks=0)
                                        ),
                                    ],
                                    id="ic-dataset-info-modal",
                                    is_open=False,
                                    size="lg",
                                ),
                            ])
                        ], style={"overflowY": "auto", "padding": "1rem"})
                    ], style={
                        "boxShadow": "0 2px 8px rgba(50,50,80,0.08)",
                        "border": "1px solid #A6ADB5",
                        "height": card_height, "minHeight": card_height,
                        "display": "flex", "flexDirection": "column"
                    })
                ], width=6),
                dbc.Col([
                    dbc.Card([
                        dbc.CardHeader("Training & Logs", style={"fontWeight": "bold", "fontSize": "1.1rem", "background": "#f8f9fa"}),
                        dbc.CardBody([
                            html.Div([
                                dbc.Button("Start", id="train-btn", color="warning", style={
                                    "backgroundColor": "#FFD200", "borderColor": "#FFD200", "color": "#03234B",
                                    "borderRadius": "8px", "height": "44px", "minWidth": "90px",
                                    "fontWeight": "bold", "fontSize": "1.05rem", "padding": "0 16px",
                                    "display": "inline-flex", "alignItems": "center", "justifyContent": "center",
                                    "gap": "6px", "whiteSpace": "nowrap", "marginRight": "20px"
                                }),
                                # Stop button stays hidden until a run starts.
                                dbc.Button(html.Span("\u25A0", style={"fontWeight": "bold", "animation": "pulse 1s infinite alternate"}),
                                           id="stop-train-btn", color="danger", className="btn-stop-round",
                                           style={"marginBottom": "1rem", "boxShadow": "0 0 8px #dc3545", "display": "none"},
                                           title="Stop Training")
                            ], style={"display": "flex", "alignItems": "center"}),
                            html.Div(id="train-launch-alert"),
                            dcc.Interval(id="train-log-interval", interval=2000, n_intervals=0, disabled=True),
                            dbc.Progress(id="train-progress", value=0, max=100, striped=True, animated=True,
                                         style={"height": "30px", "marginTop": "1rem", "marginBottom": "1rem"}),
                            html.Div(id="train-output", style={
                                "marginTop": "1rem", "whiteSpace": "pre-wrap", "fontFamily": "monospace",
                                "background": "#f8f9fa", "padding": "1rem", "borderRadius": "8px",
                                "minHeight": "200px", "maxHeight": "350px", "overflowY": "auto"
                            }),
                            html.Div([
                                html.Div([
                                    dbc.Label("Tail lines"),
                                    dbc.Input(id="log-tail-lines", type="number", min=50, max=10000, step=50, value=500,
                                              style={"width": "110px", "marginRight": "0.75rem"}),
                                    dbc.Checklist(
                                        options=[{"label": "Visualize all", "value": "all"}],
                                        value=[], id="log-show-all", switch=True, style={"marginRight": "1rem"}
                                    ),
                                    dbc.Button("Download Logs", id="download-train-logs-btn",
                                               color="secondary", size="sm", style={"marginRight": "0.5rem"}),
                                    dbc.Button("Copy Logs", id="copy-train-logs-btn",
                                               color="secondary", size="sm", style={"marginRight": "0.5rem"}),
                                    dcc.Download(id="download-train-logs"),
                                    html.Span(id="copy-logs-feedback", style={
                                        "fontSize": "0.8rem", "color": "#03234B", "marginLeft": "0.5rem"
                                    })
                                ], style={"display": "flex", "alignItems": "center", "flexWrap": "wrap", "gap": "0.5rem"})
                            ], style={"marginTop": "0.75rem"}),
                            dcc.Store(id="train-finished", data=False),
                        ])
                    ], style={
                        "boxShadow": "0 2px 8px rgba(50,50,80,0.08)",
                        "border": "1px solid #A6ADB5",
                        "height": card_height, "minHeight": card_height,
                        "display": "flex", "flexDirection": "column"
                    })
                ], width=6)
            ], justify="center", align="start", style={"marginTop": "2rem", "alignItems": "stretch"}),
            # Metrics section: accuracy/loss curves, gauges, confusion matrices.
            html.Div(
                id="train-metrics-box",
                children=[
                    html.H5(
                        "Step 2 — Monitor Training Metrics",
                        style={
                            "color": "#03234B",
                            "fontWeight": "700",
                            "margin": "0 0 0.75rem 0",
                            "textAlign": "left"
                        }
                    ),
                    html.Div([
                        html.Div([
                            html.H4("Accuracy", style={"textAlign": "center", "color": "#03234B", "marginBottom": "0.5rem"}),
                            dcc.Graph(
                                id="acc-visualization",
                                figure=acc_initial,
                                style={"height": "320px", "background": "#fff", "borderRadius": "10px",
                                       "boxShadow": "0 2px 8px rgba(50,50,80,0.08)", "padding": "1rem"}
                            )
                        ], style={"flex": 1, "marginRight": "1rem"}),
                        html.Div([
                            html.H4("Loss", style={"textAlign": "center", "color": "#03234B", "marginBottom": "0.5rem"}),
                            dcc.Graph(
                                id="loss-visualization",
                                figure=loss_initial,
                                style={"height": "320px", "background": "#fff", "borderRadius": "10px",
                                       "boxShadow": "0 2px 8px rgba(50,50,80,0.08)", "padding": "1rem"}
                            )
                        ], style={"flex": 1})
                    ], style={"display": "flex", "flexDirection": "row", "gap": "1rem", "marginBottom": "2rem"}),
                    html.Div([
                        html.Div([
                            html.Div("Float Model Accuracy", style={
                                "textAlign": "center", "color": "#03234B", "fontWeight": "bold",
                                "fontSize": "1.05rem", "marginBottom": "0.4rem"
                            }),
                            dcc.Graph(id="gauge-float-acc", figure={"data": [], "layout": {"height": 190}}, style={"height": "190px"})
                        ], style={"flex": 1, "marginRight": "1rem"}),
                        html.Div([
                            html.Div("Quantized Model Accuracy", style={
                                "textAlign": "center", "color": "#03234B", "fontWeight": "bold",
                                "fontSize": "1.05rem", "marginBottom": "0.4rem"
                            }),
                            dcc.Graph(id="gauge-quant-acc", figure={"data": [], "layout": {"height": 190}}, style={"height": "190px"})
                        ], style={"flex": 1})
                    ], style={"display": "flex", "flexDirection": "row", "gap": "1rem",
                              "marginBottom": "2rem", "justifyContent": "center"}),
                    html.Div([
                        dbc.Button("Display Confusion Matrix (Float Model)", id="display-confusion-matrix-btn",
                                   color="warning", style={
                                       "marginBottom": "1rem", "fontWeight": "bold",
                                       "backgroundColor": "#FFD200", "borderColor": "#FFD200",
                                       "color": "#03234B", "marginRight": "1rem"
                                   }),
                        dbc.Button("Display Confusion Matrix (Quantized Model)", id="display-confusion-matrix-quant-btn",
                                   color="warning", style={
                                       "marginBottom": "1rem", "fontWeight": "bold",
                                       "backgroundColor": "#FFD200", "borderColor": "#FFD200",
                                       "color": "#03234B"
                                   })
                    ], style={"textAlign": "center", "marginBottom": "1.5rem",
                              "display": "flex", "justifyContent": "center"}),
                    html.Div(id="metrics-error-message", style={
                        "color": "#dc3545", "marginTop": "1rem", "fontWeight": "bold", "textAlign": "center"
                    })
                ],
                style={"background": "#f8f9fa", "borderRadius": "12px", "padding": "2rem 1.5rem",
                       "boxShadow": "0 2px 8px rgba(50,50,80,0.08)", "marginBottom": "2rem", "width": "100%"}
            )
        ]
    return None
@app.callback(
    Output("autosave-yaml-alert", "children"),
    [Input("epochs", "value"),
     Input("batch-size", "value"),
     Input("learning-rate", "value")],
    prevent_initial_call=True
)
def autosave_ic_yaml_params(epochs, batch_size, learning_rate):
    """Persist the training hyper-parameters to the IC YAML config on change.

    Returns a small success/error alert, or no_update while the form is partial.
    """
    config_path = "image_classification/user_config.yaml"
    # Only autosave once every field actually has a value.
    if None in (epochs, batch_size, learning_rate):
        return dash.no_update
    try:
        _save_ic_training_params(config_path, epochs, batch_size, learning_rate)
        stamp = datetime.now().strftime("%H:%M:%S")
        return dbc.Alert(f"Autosaved at {stamp}", color="success", dismissable=True, duration=1500)
    except Exception as err:
        return dbc.Alert(f"Erreur autosave : {err}", color="danger", dismissable=True)
@app.callback(
    Output("ic-dataset-info-modal", "is_open"),
    [Input("ic-dataset-info-btn", "n_clicks"),
     Input("ic-dataset-info-close", "n_clicks")],
    State("ic-dataset-info-modal", "is_open"),
    prevent_initial_call=True
)
def toggle_ic_dataset_info_modal(n_open, n_close, is_open):
    """Open or close the dataset-info modal depending on which button fired."""
    trigger = ""
    if dash.callback_context.triggered:
        trigger = dash.callback_context.triggered[0]["prop_id"]
    # Open button opens, close button closes, anything else keeps current state.
    if trigger.startswith("ic-dataset-info-btn"):
        return True
    return False if trigger.startswith("ic-dataset-info-close") else is_open
@app.callback(
    Output("ic-dataset-upload-alert", "children"),
    Input("ic-dataset-upload", "contents"),
    [State("ic-dataset-upload", "filename"),
     State("user-id", "data")],
    prevent_initial_call=True
)
def handle_ic_dataset_upload(contents, filename, user_id):
    """Extract an uploaded image-classification dataset, point the YAML config
    at it, and render a statistics report (per-class distribution, file types,
    corruption check, sample thumbnails).

    :param contents: base64 payload from the dcc.Upload component.
    :param filename: original upload filename; used to derive the dataset name.
    :param user_id: per-session id used to isolate extracted datasets.
    :return: html.Div report on success, danger Alert on failure.
    """
    if not contents:
        return dash.no_update
    # Note: visual feedback handled by surrounding dcc.Loading in layout
    yaml_path = "image_classification/user_config.yaml"
    try:
        # Infer dataset name from filename (without extension)
        fname = (filename or "dataset.zip")
        user_dataset_name = os.path.splitext(os.path.basename(fname))[0]
        # Sanitize to a filesystem-safe identifier, capped at 64 chars.
        safe_name = re.sub(r"[^a-zA-Z0-9_-]+", "_", user_dataset_name)[:64] or "custom_dataset"
        training_path, class_names = _prepare_ic_uploaded_dataset(
            contents,
            filename or "dataset.zip",
            user_id=str(user_id or "anonymous"),
            dataset_name=safe_name,
        )
        # Rewrite the dataset section of the training config in place.
        with open(yaml_path, "r", encoding="utf-8") as f:
            config = yaml.safe_load(f) or {}
        dataset_cfg = config.setdefault("dataset", {})
        dataset_cfg["training_path"] = training_path
        # NOTE(review): dataset_name is hard-coded here even though safe_name
        # was derived above — confirm whether the derived name should be used.
        dataset_cfg["dataset_name"] = "custom_dataset"
        if class_names:
            dataset_cfg["class_names"] = class_names
        with open(yaml_path, "w", encoding="utf-8") as f:
            yaml.dump(config, f, allow_unicode=True, sort_keys=False, default_flow_style=False)
        stats = _analyze_ic_dataset(training_path)
        table_rows = [
            html.Tr([html.Td(name), html.Td(str(cnt))])
            for name, cnt in stats.get("per_class", [])
        ]
        # Bar chart (images per class) with per-class colors
        cls_names = [n for n, _ in stats.get("per_class", [])]
        cls_counts = [c for _, c in stats.get("per_class", [])]
        palette = [
            "#3CB4E6", "#FFD200", "#E6007E", "#04572F", "#8191a5",
            "#03234b", "#fff4bf", "#ceecf9"
        ]
        # Cycle through the palette when there are more classes than colors.
        colors = [palette[i % len(palette)] for i in range(len(cls_names))]
        bar_fig = {
            "data": [
                {
                    "type": "bar",
                    "x": cls_names,
                    "y": cls_counts,
                    "marker": {"color": colors},
                }
            ],
            "layout": {
                "height": 280,
                "margin": {"l": 40, "r": 10, "t": 10, "b": 120},
                "xaxis": {"tickangle": -45, "automargin": True},
                "yaxis": {"title": "#Images"},
                "paper_bgcolor": "white",
                "plot_bgcolor": "white",
            },
        }
        # Extension summary (show at most 6 extensions, then an ellipsis)
        ext_counts = stats.get("ext_counts", {})
        ext_summary = ", ".join([f"{k}:{v}" for k, v in list(ext_counts.items())[:6]])
        if len(ext_counts) > 6:
            ext_summary += ", ..."
        # Imbalance hint: flag a >=5x spread between largest and smallest class.
        imbalance_hint = None
        if stats["min_images"] > 0 and stats["max_images"] / max(1, stats["min_images"]) >= 5:
            imbalance_hint = f"Class imbalance detected (min={stats['min_images']}, max={stats['max_images']})."
        # Thumbnail previews (random sample)
        previews = []
        sample_paths = stats.get("sample_paths", [])
        if sample_paths:
            # Try a few random images; skip ones that fail.
            for p in random.sample(sample_paths, k=min(12, len(sample_paths))):
                b64 = _make_thumbnail_b64(p, size=120)
                if not b64:
                    continue
                previews.append(
                    html.Div(
                        html.Img(
                            src=f"data:image/jpeg;base64,{b64}",
                            style={"width": "120px", "height": "120px", "borderRadius": "8px", "objectFit": "cover"},
                        ),
                        style={"display": "inline-block", "margin": "6px"},
                    )
                )
                # Sample up to 12 candidates but display at most 8 thumbnails.
                if len(previews) >= 8:
                    break
        return html.Div([
            dbc.Alert(
                f"Dataset uploaded. training_path set to: {training_path}",
                color="success",
                dismissable=True,
            ),
            html.H4("Dataset Statistics Report", style={"marginTop": "1rem", "marginBottom": "0.5rem", "color": "#425a78"}),
            html.Div(
                f"Classes: {stats['num_classes']} | Total images: {stats['total_images']} | "
                f"Min/Max per class: {stats['min_images']}/{stats['max_images']} | "
                f"Avg per class: {stats['avg_images']:.1f}",
                style={"fontFamily": "monospace", "fontSize": "0.9rem", "marginBottom": "0.5rem"}
            ),
            html.H5("1. Per-Class Distribution", style={"color": "#03234B", "margin": "0.5rem 0"}),
            html.Div(
                f"File types: {ext_summary}",
                style={"fontFamily": "monospace", "fontSize": "0.85rem", "color": "#525A63", "marginBottom": "0.25rem"}
            ),
            html.Div(
                f"Corrupted check: {stats.get('corrupted_count', 0)} corrupted in first {stats.get('corrupted_checked', 0)} images checked",
                style={"fontFamily": "monospace", "fontSize": "0.85rem", "color": "#525A63", "marginBottom": "0.5rem"}
            ),
            # Conditional children: None entries are simply skipped by Dash.
            (dbc.Alert(imbalance_hint, color="warning", dismissable=True) if imbalance_hint else None),
            dcc.Graph(figure=bar_fig, config={"displayModeBar": False}),
            html.H5("2. Classes Table", style={"color": "#03234B", "margin": "0.75rem 0 0.5rem 0"}),
            dbc.Table(
                [
                    html.Thead(html.Tr([html.Th("Class"), html.Th("#Images")])),
                    html.Tbody(table_rows)
                ],
                bordered=True,
                hover=True,
                size="sm",
                style={"maxHeight": "260px", "overflowY": "auto", "display": "block"}
            ),
            (html.Div([
                html.H5("3. Sample Previews", style={"color": "#03234B", "margin": "0.75rem 0 0.5rem 0"}),
                html.Div(previews, style={"marginTop": "0.25rem"})
            ]) if previews else None),
            (html.Div([
                html.H5("4. Corrupted Examples", style={"color": "#03234B", "margin": "0.75rem 0 0.5rem 0"}),
                html.Details([
                    html.Summary("Show list"),
                    html.Pre("\n".join(stats.get("corrupted_examples", [])), style={"whiteSpace": "pre-wrap"})
                ])
            ]) if stats.get("corrupted_examples") else None),
        ])
    except Exception as e:
        return html.Div([
            dbc.Alert(f"Dataset upload failed: {e}", color="danger", dismissable=True)
        ])
@app.callback(
    Output("stop-train-btn", "style"),
    [Input("train-log-interval", "disabled"),
     Input("train-finished", "data")],
    prevent_initial_call=False
)
def toggle_stop_button(interval_disabled, finished):
    """Show the Stop button only while a training run is active."""
    style = {
        "marginBottom": "1rem",
        "boxShadow": "0 0 8px #dc3545",
    }
    # Active run <=> log-polling interval is enabled and training not flagged done.
    running = interval_disabled is False and not finished
    if not running:
        style["display"] = "none"
    return style
@app.callback(
    Output("train-btn", "disabled"),
    [Input("train-log-interval", "disabled"),
     Input("train-finished", "data")],
    prevent_initial_call=False
)
def toggle_start_button(interval_disabled, finished):
    """Disable the Start button while a training run is in progress."""
    # Running <=> the polling interval is enabled and training is not finished.
    return bool(interval_disabled is False and not finished)
@app.callback(
    [Output("train-btn", "children"),
     Output("train-btn", "style")],
    [Input("train-log-interval", "disabled"),
     Input("train-finished", "data")],
    prevent_initial_call=False
)
def update_start_label(interval_disabled, finished):
    """Swap the Start button label for a spinner + 'Running' while training runs."""
    style = {
        "backgroundColor": "#FFD200",
        "borderColor": "#FFD200",
        "color": "#03234B",
        "borderRadius": "8px",
        "height": "44px",
        "minWidth": "90px",
        "fontWeight": "bold",
        "fontSize": "1.05rem",
        "padding": "0 16px",
        "display": "inline-flex",
        "alignItems": "center",
        "justifyContent": "center",
        "gap": "6px",
        "whiteSpace": "nowrap",
        "marginRight": "20px",
    }
    # Idle / finished: plain "Start" label.
    if interval_disabled is not False or finished:
        return "Start", style
    # Active run: spinner alongside the "Running" label.
    label = [
        dbc.Spinner(size="sm", color="dark", spinner_style={"margin": "0"}),
        "Running",
    ]
    return label, style
# Module-level handle on the currently running training subprocess (None when
# idle); shared between the background launcher thread and the stop callback.
TRAIN_PROCESS_HANDLE = {"process": None}
@app.callback(
    [Output("train-launch-alert", "children"),
     Output("train-log-interval", "disabled"),
     Output("train-finished", "data"),
     Output("train-progress", "value"),
     Output("train-progress", "label"),
     Output("train-log-raw", "data")],
    [Input("train-btn", "n_clicks"),
     Input("train-log-interval", "n_intervals"),
     Input("stop-train-btn", "n_clicks")],
    [State("stm32ai-credentials", "data"),
     State("epochs", "value"),
     State("batch-size", "value"),
     State("learning-rate", "value"),
     State("train-finished", "data"),
     State("user-id", "data"),
     State("is-original-host", "data")],
    prevent_initial_call=False
)
def train_and_log(n_clicks, n_intervals, stop_clicks, creds, epochs, batch_size, learning_rate, finished_state, user_id, is_original):
    """Start, stop and poll the training subprocess from a single callback.

    Three triggers share this callback: the Start button launches a daemon
    thread running stm32ai_main.py via subprocess, the Stop button terminates
    it, and the interval trigger drains the log queue and updates the progress
    bar. TRAIN_PROCESS / TRAIN_LOG_QUEUE / TRAIN_PROCESS_HANDLE are
    module-level shared state defined elsewhere in this file.
    """
    # Refuse to run on the official reference Space; users must duplicate it.
    if is_original:
        return dbc.Alert("Training disabled on reference Space. Duplicate to enable.",
                         color="danger", dismissable=True, duration=4000), True, True, 0, "Disabled", ""
    ctx = dash.callback_context
    triggered = ctx.triggered[0]["prop_id"] if ctx.triggered else ""
    alert = None
    if not user_id:
        user_id = str(uuid.uuid4())
    user_tmp_dir = os.path.join("tmp_sessions", user_id)
    os.makedirs(user_tmp_dir, exist_ok=True)
    # Locate the most recent timestamped experiment directory for the log file.
    logs_path = os.path.join("tf", "src", "experiments_outputs")
    log_file_path = None
    if os.path.exists(logs_path):
        dated_directories = [d for d in os.listdir(logs_path) if os.path.isdir(os.path.join(logs_path, d)) and d.startswith('20')]
        if dated_directories:
            recent_directory = max(dated_directories, key=lambda d: datetime.strptime(d, '%Y_%m_%d_%H_%M_%S'))
            log_file_path = os.path.join(logs_path, recent_directory, "logs", "train.log")
            os.makedirs(os.path.dirname(log_file_path), exist_ok=True)
    if not log_file_path:
        # Fallback to user tmp dir
        log_file_path = os.path.join(user_tmp_dir, "train.log")
    # For log history per user (function attribute acts as a process-wide cache)
    if not hasattr(train_and_log, "history_dict"):
        train_and_log.history_dict = {}
    if user_id not in train_and_log.history_dict:
        if os.path.exists(log_file_path):
            with open(log_file_path, "r", encoding="utf-8") as f:
                train_and_log.history_dict[user_id] = f.readlines()
        else:
            train_and_log.history_dict[user_id] = []
    history = train_and_log.history_dict[user_id]
    # If stop button is clicked, terminate training process
    if triggered.startswith("stop-train-btn") and stop_clicks:
        if TRAIN_PROCESS_HANDLE["process"] is not None:
            try:
                TRAIN_PROCESS_HANDLE["process"].terminate()
            except Exception:
                # Best effort: process may already have exited.
                pass
        TRAIN_PROCESS["finished"] = True
        # Persist whatever log history we have so far.
        with open(log_file_path, "w", encoding="utf-8") as f:
            f.write("".join(history))
        return dbc.Alert("Training has been stopped !", color="danger", dismissable=True, duration=3000), True, True, 0, "Stopped", "".join(history)
    # If training button is clicked, start training thread
    if triggered.startswith("train-btn") and n_clicks:
        yaml_path = "image_classification/user_config.yaml"
        try:
            _save_ic_training_params(yaml_path, epochs, batch_size, learning_rate)
        except Exception as e:
            return (
                dbc.Alert(f"Error saving parameters: {e}", color="danger", dismissable=True),
                True,
                True,
                0,
                "Error",
                "".join(history),
            )
        def run_training(epochs, batch_size, learning_rate):
            # Worker thread: run the training script, stream stdout to both the
            # shared queue (for the UI) and the log file.
            TRAIN_PROCESS["finished"] = False
            TRAIN_LOG_QUEUE.queue.clear()
            isolated_env = os.environ.copy()
            isolated_env['STATS_TYPE'] = 'HuggingFace_devcloud'
            # Line-buffered, unbuffered child output so logs appear live.
            isolated_env['PYTHONUNBUFFERED'] = '1'
            with open(log_file_path, "w", encoding="utf-8") as log_f:
                process = subprocess.Popen([
                    sys.executable, "image_classification/stm32ai_main.py"
                ], stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1, env=isolated_env)
                TRAIN_PROCESS_HANDLE["process"] = process
                for line in process.stdout:
                    TRAIN_LOG_QUEUE.put(line)
                    log_f.write(line)
                process.wait()
            TRAIN_PROCESS_HANDLE["process"] = None
            TRAIN_PROCESS["finished"] = True
        # Only launch if no training thread is currently alive.
        if TRAIN_PROCESS["thread"] is None or not TRAIN_PROCESS["thread"].is_alive():
            t = threading.Thread(target=run_training, args=(epochs, batch_size, learning_rate), daemon=True)
            t.start()
            TRAIN_PROCESS["thread"] = t
        alert = dbc.Alert("Parameters autosaved. Launching training !", color="info", dismissable=True, duration=3000)
        return alert, False, False, 0, "0%", ""
    # If polling logs: drain everything the worker queued since last tick.
    log_lines = []
    while not TRAIN_LOG_QUEUE.empty():
        log_lines.append(TRAIN_LOG_QUEUE.get())
    if not log_lines and not TRAIN_PROCESS["thread"]:
        return None, True, False, 0, "0%", "".join(history)
    history.extend(log_lines)
    with open(log_file_path, "w", encoding="utf-8") as f:
        f.write("".join(history))
    finished = TRAIN_PROCESS["finished"]
    def extract_metrics(log_lines):
        # Scan from the end of the log for the last reported accuracies.
        float_acc = None
        quant_acc = None
        for line in reversed(log_lines):
            m_float = re.search(r"accuracy of float model.*?=\s*([0-9]+\.?[0-9]*)%", line, re.IGNORECASE)
            if m_float:
                try:
                    float_acc = float(m_float.group(1))
                    break
                except Exception:
                    float_acc = None
        for line in reversed(log_lines):
            m_quant = re.search(r"accuracy of quantized model.*?=\s*([0-9]+\.?[0-9]*)%", line, re.IGNORECASE)
            if m_quant:
                try:
                    quant_acc = float(m_quant.group(1))
                    break
                except Exception:
                    quant_acc = None
        return float_acc, quant_acc
    float_acc = None
    quant_acc = None
    logs_path = os.path.join("tf", "src", "experiments_outputs")
    log_lines = []
    if os.path.exists(logs_path):
        dated_directories = [d for d in os.listdir(logs_path) if os.path.isdir(os.path.join(logs_path, d)) and d.startswith('20')]
        if dated_directories:
            recent_directory = max(dated_directories, key=lambda d: datetime.strptime(d, '%Y_%m_%d_%H_%M_%S'))
            log_file = os.path.join(logs_path, recent_directory, "logs", "train.log")
            if os.path.exists(log_file):
                try:
                    with open(log_file, "r", encoding="utf-8") as f:
                        log_lines = f.readlines()
                except Exception:
                    pass
    # NOTE(review): float_acc/quant_acc are computed here but never used in the
    # returned outputs (the gauges callback re-parses the log itself) — confirm
    # whether this extraction is dead code.
    float_acc, quant_acc = extract_metrics(log_lines)
    progress = 0
    label = "0%"
    # Extract epoch and total_epochs from log_lines
    epoch = 0
    total_epochs = 0
    for line in reversed(history):
        m_epoch = re.search(r"Epoch (\d+)/(\d+)", line)
        if m_epoch:
            epoch = int(m_epoch.group(1))
            total_epochs = int(m_epoch.group(2))
            break
    if total_epochs > 0:
        progress = int(100 * epoch / total_epochs)
        label = f"{epoch}/{total_epochs} ({progress}% Done, {100-progress}% Remaining)"
    # Only flip to the terminal state once the queue is fully drained.
    if finished and TRAIN_LOG_QUEUE.empty():
        return None, True, True, 100, "End", "".join(history)
    return None, False, False, progress, label, "".join(history)
@app.callback(
    Output("train-output", "children"),
    [Input("train-log-raw", "data"), Input("log-tail-lines", "value"), Input("log-show-all", "value")],
    prevent_initial_call=False
)
def display_train_logs(raw, tail_lines, show_all):
    """Render the raw training log, tailing the last N lines unless 'show all'.

    :param raw: the full log text (or None before any training has run).
    :param tail_lines: user-selected tail size; falls back to 500 when invalid.
    :param show_all: checklist value; containing "all" disables tailing.
    """
    # None and "" both mean "nothing to show"; the original handled them in two
    # inconsistent branches (one unstyled) — render one consistent empty <pre>.
    if not raw:
        return html.Pre("", style={"margin": 0})
    lines = raw.splitlines()
    if show_all and "all" in show_all:
        display_lines = lines
    else:
        try:
            n = int(tail_lines or 500)
        except (TypeError, ValueError):
            # Non-numeric input from the control: fall back to the default tail.
            n = 500
        display_lines = lines[-n:]
    return html.Pre("\n".join(display_lines), style={"margin": 0})
@app.callback(
    Output("download-train-logs", "data"),
    Input("download-train-logs-btn", "n_clicks"),
    State("train-log-raw", "data"),
    prevent_initial_call=True
)
def download_train_logs(n, raw):
    """Send the current raw training log as a timestamped .txt download."""
    if not n:
        return dash.no_update
    stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    content = raw if raw is not None else ""
    return {"content": content, "filename": f"training_logs_{stamp}.txt"}
app.clientside_callback(
"""
function(n_clicks, raw){
if(!n_clicks){ return window.dash_clientside.no_update || '' }
try {
if(raw === null || raw === undefined){
if (navigator?.clipboard?.writeText) { navigator.clipboard.writeText(''); }
return 'Aucun log';
}
const MAX = 1000000;
let txt = String(raw);
txt = txt.replace(/\\r\\n?/g, '\\n').replace(/\\x1B\\[[0-9;]*m/g, '');
let trunc = false;
if (txt.length > MAX){ txt = txt.slice(0, MAX); trunc = true; }
if (navigator?.clipboard?.writeText){
navigator.clipboard.writeText(txt);
} else {
const ta = document.createElement('textarea');
ta.value = txt;
ta.style.position = 'fixed';
ta.style.top = '-2000px';
document.body.appendChild(ta);
ta.select();
try { document.execCommand('copy'); } catch(e){}
document.body.removeChild(ta);
}
return 'Logs copied' + (trunc ? ' (truncated)' : '') + ' [' + txt.length + ' chars]';
} catch(e){
return 'Copy failed';
}
}
""",
Output("copy-logs-feedback", "children"),
Input("copy-train-logs-btn", "n_clicks"),
State("train-log-raw", "data"),
prevent_initial_call=True
)
app.clientside_callback(
f"""
function(href){{
if(!href) return null;
const host = window.location.host.toLowerCase();
const originals = {ORIGINAL_SPACE_HOSTS};
return Array.isArray(originals) && originals.includes(host);
}}
""",
Output("is-original-host", "data"),
Input("app-url", "href")
)
@app.callback(
    Output("duplicate-blocker", "style"),
    Input("is-original-host", "data"),
    prevent_initial_call=False
)
def show_blocker(is_original):
    """Show the full-screen 'duplicate this Space' overlay on the official host."""
    visibility = "flex" if is_original else "none"
    return {
        "position": "fixed",
        "top": 0, "left": 0, "right": 0, "bottom": 0,
        "background": "rgba(3,35,75,0.96)",
        "zIndex": 99999,
        "display": visibility,
        "padding": "2rem",
    }
@app.callback(
    [Output("acc-visualization", "figure"),
     Output("loss-visualization", "figure"),
     Output("acc-fig-store", "data"),
     Output("loss-fig-store", "data")],
    [Input("train-log-interval", "n_intervals"),
     Input("train-finished", "data")],
    prevent_initial_call=False
)
def refresh_metrics(n_intervals, finished):
    """Rebuild the accuracy/loss figures from the newest run's train_metrics.csv.

    Returns each figure twice: once for the visible dcc.Graph and once for its
    dcc.Store copy. Returns no_update everywhere when no usable metrics exist.
    """
    outputs_folder = os.path.join("tf", "src", "experiments_outputs")
    if not os.path.exists(outputs_folder):
        return dash.no_update, dash.no_update, dash.no_update, dash.no_update
    # Experiment dirs are timestamped YYYY_MM_DD_HH_MM_SS; pick the newest.
    dated_dirs = [d for d in os.listdir(outputs_folder)
                  if os.path.isdir(os.path.join(outputs_folder, d)) and
                  re.match(r"^20\d{2}_\d{2}_\d{2}_\d{2}_\d{2}_\d{2}$", d)]
    if not dated_dirs:
        return dash.no_update, dash.no_update, dash.no_update, dash.no_update
    recent = max(dated_dirs, key=lambda d: datetime.strptime(d, "%Y_%m_%d_%H_%M_%S"))
    metrics_file = os.path.join(outputs_folder, recent, "logs", "metrics", "train_metrics.csv")
    if not os.path.exists(metrics_file):
        return dash.no_update, dash.no_update, dash.no_update, dash.no_update
    try:
        df = pd.read_csv(metrics_file)
    except Exception:
        # File may be mid-write while training streams rows; just skip the tick.
        return dash.no_update, dash.no_update, dash.no_update, dash.no_update
    if df.empty or "epoch" not in df.columns:
        return dash.no_update, dash.no_update, dash.no_update, dash.no_update
    def normalize_metric(series):
        # Accuracy may be logged as a 0-1 fraction; scale to percent if so.
        # Returns (series, was_normalized).
        if series.max() <= 1.2:
            return series * 100.0, True
        return series, False
    # Accuracy
    traces_acc = []; y_acc_values = []; norm_used = False
    mk = {"size": 8, "symbol": "circle", "line": {"width": 1, "color": "#fff"}}
    if "accuracy" in df.columns:
        acc_vals, n1 = normalize_metric(df["accuracy"]); norm_used |= n1; y_acc_values.append(acc_vals)
        traces_acc.append({"x": df["epoch"], "y": acc_vals, "type": "scatter", "mode": "markers+lines",
                           "name": "Accuracy", "line": {"color": "#1976D2", "width": 2},
                           "marker": {**mk, "color": "#1976D2"},
                           "hovertemplate": "Epoch %{x}<br>Acc %{y:.2f}%<extra></extra>"})
    if "val_accuracy" in df.columns:
        val_acc_vals, n2 = normalize_metric(df["val_accuracy"]); norm_used |= n2; y_acc_values.append(val_acc_vals)
        traces_acc.append({"x": df["epoch"], "y": val_acc_vals, "type": "scatter", "mode": "markers+lines",
                           "name": "Val Accuracy", "line": {"color": "#FF9800", "width": 2},
                           "marker": {**mk, "color": "#FF9800", "symbol": "diamond"},
                           "hovertemplate": "Epoch %{x}<br>Val Acc %{y:.2f}%<extra></extra>"})
        # Star marker on the best validation-accuracy epoch; best-effort only.
        try:
            best_idx = df["val_accuracy"].idxmax()
            traces_acc.append({"x": [df.loc[best_idx, "epoch"]], "y": [val_acc_vals.loc[best_idx]],
                               "type": "scatter", "mode": "markers", "name": "Best Val Acc",
                               "marker": {"size": 16, "color": "#2E7D32", "symbol": "star",
                                          "line": {"width": 2, "color": "#1B5E20"}},
                               "hovertemplate": "Best Val Acc<br>Epoch %{x}<br>%{y:.2f}%<extra></extra>"})
        except Exception:
            pass
    xmin = int(df["epoch"].min()); xmax = int(df["epoch"].max())
    # Dict-merge (|, Python 3.9+) the shared base layout with figure specifics.
    acc_layout = ACC_LOSS_BASE_LAYOUT | {
        "xaxis": {**ACC_LOSS_BASE_LAYOUT["xaxis"], "range": [max(0, xmin - 0.15), xmax + 0.35]},
        "yaxis": {"title": "Accuracy (%)", "showgrid": True, "gridcolor": "#EEEFF1", "fixedrange": True},
        "legend": {"orientation": "h", "y": -0.25},
        "hovermode": "x unified",
        "dragmode": False,
        "uirevision": "persist-metrics",
        "annotations": ([{
            "text": "Normalized Accuracy (%)",
            "showarrow": False, "xref": "paper", "yref": "paper", "x": 0, "y": 1.12,
            "font": {"size": 11, "color": "#525A63"}
        }] if norm_used else [])
    }
    if y_acc_values:
        # Fit the y-range to the data with 8% padding; widen degenerate ranges.
        all_acc = pd.concat(y_acc_values)
        ymin, ymax = float(all_acc.min()), float(all_acc.max())
        if ymin == ymax:
            ymin -= 1; ymax += 1
        pad = (ymax - ymin) * 0.08
        acc_layout["yaxis"]["range"] = [ymin - pad, ymax + pad]
    acc_fig = {"data": traces_acc, "layout": acc_layout}
    # Loss
    traces_loss = []; y_loss_values = []
    if "loss" in df.columns:
        y_loss_values.append(df["loss"])
        traces_loss.append({"x": df["epoch"], "y": df["loss"], "type": "scatter", "mode": "markers+lines",
                            "name": "Loss", "line": {"color": "#1976D2", "width": 2},
                            "marker": {"size": 8, "symbol": "circle", "line": {"width": 1, "color": "#fff"}, "color": "#1976D2"},
                            "hovertemplate": "Epoch %{x}<br>Loss %{y:.4f}<extra></extra>"})
    if "val_loss" in df.columns:
        y_loss_values.append(df["val_loss"])
        traces_loss.append({"x": df["epoch"], "y": df["val_loss"], "type": "scatter", "mode": "markers+lines",
                            "name": "Val Loss", "line": {"color": "#FF9800", "width": 2},
                            "marker": {"size": 8, "symbol": "diamond", "line": {"width": 1, "color": "#fff"}, "color": "#FF9800"},
                            "hovertemplate": "Epoch %{x}<br>Val Loss %{y:.4f}<extra></extra>"})
        # Star marker on the best (minimum) validation-loss epoch.
        try:
            best_loss_idx = df["val_loss"].idxmin()
            traces_loss.append({"x": [df.loc[best_loss_idx, "epoch"]],
                                "y": [df.loc[best_loss_idx, "val_loss"]],
                                "type": "scatter", "mode": "markers", "name": "Best Val Loss",
                                "marker": {"size": 16, "color": "#D32F2F", "symbol": "star",
                                           "line": {"width": 2, "color": "#B71C1C"}},
                                "hovertemplate": "Best Val Loss<br>Epoch %{x}<br>%{y:.4f}<extra></extra>"})
        except Exception:
            pass
    loss_layout = ACC_LOSS_BASE_LAYOUT | {
        "xaxis": {**ACC_LOSS_BASE_LAYOUT["xaxis"], "range": [max(0, xmin - 0.15), xmax + 0.35]},
        "yaxis": {"title": "Loss", "showgrid": True, "gridcolor": "#EEEFF1", "fixedrange": True},
        "legend": {"orientation": "h", "y": -0.25},
        "hovermode": "x unified",
        "dragmode": False,
        "uirevision": "persist-metrics"
    }
    if y_loss_values:
        all_loss = pd.concat(y_loss_values)
        ymin, ymax = float(all_loss.min()), float(all_loss.max())
        if ymin == ymax:
            # Scale the degenerate-range padding to the magnitude of the loss.
            ymin -= 0.05 * (abs(ymin) + 1); ymax += 0.05 * (abs(ymax) + 1)
        pad = (ymax - ymin) * 0.08
        loss_layout["yaxis"]["range"] = [ymin - pad, ymax + pad]
    loss_fig = {"data": traces_loss, "layout": loss_layout}
    return acc_fig, loss_fig, acc_fig, loss_fig
@app.callback(
    [Output("gauge-float-acc", "figure"),
     Output("gauge-quant-acc", "figure")],
    [Input("train-log-interval", "n_intervals"),
     Input("train-finished", "data")],
    prevent_initial_call=False
)
def update_gauges(n, finished):
    """Parse the final float/quantized model accuracies from the newest
    train.log and render them as two Plotly gauge indicators.
    """
    outputs_folder = os.path.join("tf", "src", "experiments_outputs")
    float_acc = None
    quant_acc = None
    if os.path.exists(outputs_folder):
        # Pick the most recent timestamped experiment directory.
        dated_dirs = [d for d in os.listdir(outputs_folder)
                      if os.path.isdir(os.path.join(outputs_folder, d)) and d.startswith("20")]
        if dated_dirs:
            recent = max(dated_dirs, key=lambda d: datetime.strptime(d, "%Y_%m_%d_%H_%M_%S"))
            log_file = os.path.join(outputs_folder, recent, "logs", "train.log")
            if os.path.exists(log_file):
                try:
                    with open(log_file, "r", encoding="utf-8") as f:
                        lines = f.readlines()
                    # Scan from the end so the most recent evaluation wins.
                    for line in reversed(lines):
                        m = re.search(r"accuracy of float model.*?=\s*([0-9]+\.?[0-9]*)%", line, re.IGNORECASE)
                        if m:
                            float_acc = float(m.group(1))
                            break
                    for line in reversed(lines):
                        m = re.search(r"accuracy of quantized model.*?=\s*([0-9]+\.?[0-9]*)%", line, re.IGNORECASE)
                        if m:
                            quant_acc = float(m.group(1))
                            break
                except Exception:
                    # Log unreadable or mid-write: keep gauges in unknown state.
                    pass
    def color(v):
        # Map an accuracy value to a band color (gray when unknown).
        if v is None:
            return "#c0c8d2"
        if v < 40:
            return "#FFD200"
        if v < 70:
            return "#6dc7ec"
        if v < 85:
            return "#49B170"
        if v < 92:
            return "#E6007E"
        return "#04572F"
    def gauge_figure(val):
        # Build one 0-100% gauge; value 0 with no suffix when unknown.
        col = color(val)
        return {
            "data": [{
                "type": "indicator",
                "mode": "gauge+number",
                "value": float(val) if val is not None else 0,
                "gauge": {
                    "axis": {"range": [0, 100], "tickwidth": 2, "tickcolor": "#03234B"},
                    "bar": {"color": col},
                    "bgcolor": "#EEEFF1",
                    "borderwidth": 2,
                    "bordercolor": "#A6ADB5"
                },
                "number": {"suffix": "%" if val is not None else "", "font": {"size": 24, "color": col}},
                "domain": {"x": [0, 1], "y": [0, 1]}
            }],
            "layout": {
                "margin": {"l": 15, "r": 15, "t": 10, "b": 10},
                "paper_bgcolor": "#EEEFF1",
                "plot_bgcolor": "#EEEFF1",
                "height": 190
            }
        }
    return gauge_figure(float_acc), gauge_figure(quant_acc)
@app.callback(
    Output("metrics-error-message", "children"),
    [Input("acc-visualization", "figure"),
     Input("loss-visualization", "figure")],
    prevent_initial_call=False
)
def show_metrics_error(acc_fig, loss_fig):
    """Placeholder: metrics errors are currently suppressed (always empty)."""
    return ""
@app.callback(
    Output("user-id", "data"),
    Input("user-id", "data"),
    prevent_initial_call=False
)
def ensure_user_id(user_id):
    """Lazily assign a random session id the first time the store is empty."""
    return user_id if user_id else str(uuid.uuid4())
@app.callback(
    [Output("confusion-matrix-modal", "is_open"),
     Output("confusion-matrix-modal-body", "children"),
     Output("confusion-float-store", "data"),
     Output("confusion-quant-store", "data"),
     Output("confusion-last-type", "data")],
    [Input("display-confusion-matrix-btn", "n_clicks"),
     Input("display-confusion-matrix-quant-btn", "n_clicks"),
     Input("close-confusion-matrix-modal", "n_clicks")],
    [State("confusion-matrix-modal", "is_open"),
     State("confusion-float-store", "data"),
     State("confusion-quant-store", "data"),
     State("confusion-last-type", "data")],
    prevent_initial_call=True
)
def toggle_confusion_matrix_modal(n_float, n_quant, n_close, is_open,
                                  store_float, store_quant, last_type):
    """Open the confusion-matrix modal for the float or quantized model.

    Images are loaded from the newest experiment directory, base64-encoded,
    and cached in dcc.Stores so repeated opens skip the disk read.
    """
    ctx = dash.callback_context
    triggered = ctx.triggered[0]["prop_id"] if ctx.triggered else ""
    if not triggered:
        return is_open, dash.no_update, dash.no_update, dash.no_update, last_type
    if triggered.startswith("close-confusion-matrix-modal") and n_close:
        return False, dash.no_update, dash.no_update, dash.no_update, last_type
    # Determine which matrix was requested.
    matrix_type = None
    if triggered.startswith("display-confusion-matrix-btn") and n_float:
        matrix_type = "float"
    elif triggered.startswith("display-confusion-matrix-quant-btn") and n_quant:
        matrix_type = "quant"
    else:
        return is_open, dash.no_update, dash.no_update, dash.no_update, last_type
    # Serve from the store cache when the image was already loaded once.
    if matrix_type == "float" and store_float:
        img = html.Img(src=f"data:image/png;base64,{store_float}",
                       style={"maxWidth": "100%", "height": "auto"})
        return True, img, dash.no_update, dash.no_update, "float"
    if matrix_type == "quant" and store_quant:
        img = html.Img(src=f"data:image/png;base64,{store_quant}",
                       style={"maxWidth": "100%", "height": "auto"})
        return True, img, dash.no_update, dash.no_update, "quant"
    outputs_folder = os.path.join("tf", "src", "experiments_outputs")
    if not os.path.exists(outputs_folder):
        msg = html.Div("Outputs folder not found.", style={"color": "#dc3545", "textAlign": "center"})
        return True, msg, dash.no_update, dash.no_update, last_type
    # Find the newest timestamped experiment directory.
    try:
        pattern = re.compile(r'^20\d{2}_\d{2}_\d{2}_\d{2}_\d{2}_\d{2}$')
        candidates = [d for d in os.listdir(outputs_folder)
                      if os.path.isdir(os.path.join(outputs_folder, d)) and pattern.match(d)]
        if not candidates:
            return True, html.Div("No experiment directory found.", style={"color": "#dc3545"}), dash.no_update, dash.no_update, last_type
        recent = max(candidates, key=lambda d: datetime.strptime(d, "%Y_%m_%d_%H_%M_%S"))
    except Exception as e:
        return True, html.Div(f"Scan error: {e}", style={"color": "#dc3545"}), dash.no_update, dash.no_update, last_type
    if matrix_type == "float":
        file_path = os.path.join(outputs_folder, recent, "float_model_confusion_matrix_validation_set.png")
    else:
        file_path = os.path.join(outputs_folder, recent, "quantized_model_confusion_matrix_validation_set.png")
    if not os.path.exists(file_path):
        return True, html.Div(f"Confusion matrix file not found: {file_path}",
                              style={"color": "#dc3545", "wordBreak": "break-all"}), dash.no_update, dash.no_update, last_type
    try:
        with Image.open(file_path) as im:
            # Downscale very large matrices (>5 MB) before base64-encoding.
            if os.path.getsize(file_path) > 5 * 1024 * 1024:
                im.thumbnail((1400, 1400))
            buf = _io.BytesIO()
            im.save(buf, format="PNG")
        b64 = base64.b64encode(buf.getvalue()).decode()
        img = html.Img(src=f"data:image/png;base64,{b64}", style={"maxWidth": "100%", "height": "auto"})
        # Cache only the slot that was just loaded.
        if matrix_type == "float":
            return True, img, b64, dash.no_update, "float"
        else:
            return True, img, dash.no_update, b64, "quant"
    except Exception as e:
        return True, html.Div(f"Error loading image: {e}", style={"color": "#dc3545"}), dash.no_update, dash.no_update, last_type
@app.callback(
    Output("navbar-main", "children"),
    Input("sidebar-mode", "data")
)
def render_navbar(tab):
    """Build the top navbar, turning the active tab's button gray."""
    # Default: both nav buttons yellow; the active tab's button turns gray.
    if tab == "tab-predict":
        predict_cls, train_cls = "me-2 navbar-gray-btn", "navbar-yellow-btn"
    elif tab == "tab-train":
        predict_cls, train_cls = "me-2 navbar-yellow-btn", "navbar-gray-btn"
    else:
        predict_cls, train_cls = "me-2 navbar-yellow-btn", "navbar-yellow-btn"
    logo = html.Div(
        html.Img(src="/assets/ST_logo.png",
                 style={"height": "38px", "marginRight": "1.2rem"}),
        style={"display": "flex", "alignItems": "center"}
    )
    brand = dbc.NavbarBrand(
        "STM32AI Experimentation Hub",
        style={
            "color": "#fff",
            "fontWeight": "bold",
            "fontSize": "1.3rem",
            "marginRight": "2rem"
        }
    )
    predict_button = dbc.Button(
        "Training Mode (Image classification)",
        id="nav-predict",
        color="primary",
        outline=False,
        className=predict_cls,
        n_clicks=0
    )
    train_button = dbc.Button(
        "Prediction Mode (Object detection)",
        id="nav-train",
        color="primary",
        outline=False,
        className=train_cls,
        n_clicks=0
    )
    github_link = html.A(
        html.Img(
            src=app.get_asset_url("github.svg"),
            style={"height": "22px", "width": "22px", "display": "block"}
        ),
        id="navbar-github-link",
        href=ST_MODELZOO_SERVICES_REPO_URL,
        target="_blank",
        title="STM32AI Model Zoo Services Repository",
        style={
            "display": "inline-flex",
            "alignItems": "center",
            "justifyContent": "center",
            "padding": "0.45rem",
            "borderRadius": "999px",
            "textDecoration": "none",
            "background": "transparent",
            "border": "1px solid rgba(255,255,255,0.22)",
        },
    )
    nav = dbc.Nav(
        [predict_button, train_button, github_link],
        className="ms-auto",
        style={"gap": "0.5rem"}
    )
    return dbc.Navbar(
        dbc.Container(
            [logo, brand, nav],
            style={"display": "flex", "alignItems": "center"}
        ),
        color="#0E2140",
        dark=True,
        style={"marginBottom": "0.5rem", "padding": "0.3rem 2rem", "height": "56px", "backgroundColor": "#0E2140"},
        className="navbar-custom"
    )
if __name__ == "__main__":
    # Per-session artifacts (logs, uploaded datasets) live under tmp_sessions/.
    os.makedirs("tmp_sessions", exist_ok=True)
    # NOTE(review): debug=True with hot reload is enabled unconditionally —
    # confirm this is intended for the deployed Space, not just local dev.
    app.run(host="0.0.0.0", port=7860, debug=True,
            dev_tools_ui=True, dev_tools_hot_reload=True, threaded=True)