import json
import math
import os
import time

import pandas as pd
import requests
import streamlit as st
API_URL = os.getenv("API_URL", "http://127.0.0.1:8001/api")
LABEL_DISPLAY_MAP = {
"ang": "anger",
"hap": "happiness/excitement",
"sad": "sadness",
"neu": "neutral",
"exc": "happiness/excitement",
0: "anger",
1: "happiness/excitement",
2: "sadness",
3: "neutral",
"0": "anger",
"1": "happiness/excitement",
"2": "sadness",
"3": "neutral",
}
def _map_label_display(label):
if label in LABEL_DISPLAY_MAP:
return LABEL_DISPLAY_MAP[label]
key = str(label)
return LABEL_DISPLAY_MAP.get(key, key)
# ==========================================
# 1. PAGE SETTINGS
# ==========================================
# st.set_page_config is the first Streamlit call in this script's top-level
# flow, before any element is rendered.
# NOTE(review): page_icon="settings" is a plain word, not an emoji or a
# ":material/...:" icon name — confirm Streamlit shows the intended favicon.
st.set_page_config(
    page_title="Federated Learning Admin",
    layout="wide",
    page_icon="settings"
)
# ==========================================
# 2. SIDEBAR (GLOBAL ACTIONS)
# ==========================================
st.sidebar.title("Server")
st.sidebar.markdown("---")
# Manual refresh: re-executes the whole script, so every panel re-requests
# its data from the backend.
if st.sidebar.button("Refresh Application"):
    st.rerun()
# ==========================================
# 3. MAIN CONTENT ROUTING
# ==========================================
def simple_block_header(text):
    """Render *text* as a section header, followed by a second markdown element.

    Both elements go through ``st.markdown`` with ``unsafe_allow_html=True`` and
    *text* is interpolated verbatim, so only pass trusted, app-defined strings.
    NOTE(review): the second element's content is just a newline; its original
    HTML appears to have been stripped — confirm the intended markup.
    """
    header_markup = "\n{}\n".format(text)
    trailer_markup = "\n"
    for markup in (header_markup, trailer_markup):
        st.markdown(markup, unsafe_allow_html=True)
def _parse_int_field(name: str, raw: str, min_value: int = None) -> int:
try:
value = int(str(raw).strip())
except Exception:
raise ValueError(f"{name} must be an integer.")
if min_value is not None and value < min_value:
raise ValueError(f"{name} must be >= {min_value}.")
return value
def _parse_float_field(name: str, raw: str, min_value: float = None, max_value: float = None) -> float:
try:
value = float(str(raw).strip())
except Exception:
raise ValueError(f"{name} must be a number.")
if min_value is not None and value < min_value:
raise ValueError(f"{name} must be >= {min_value}.")
if max_value is not None and value > max_value:
raise ValueError(f"{name} must be <= {max_value}.")
return value
def render_global_configuration():
st.subheader("Global Configuration")
current_setup = {}
try:
res = requests.get(f"{API_URL}/config")
res_data = res.json()
if res_data.get("status") == "success":
current_setup = (res_data.get("data") or {}).get("current_setup", {})
except Exception as e:
st.warning(f"Cannot fetch global configuration: {e}")
with st.form("server_config_form_shared"):
col_cfg1, col_cfg2 = st.columns(2)
with col_cfg1:
rounds_pretrain_raw = st.text_input("Rounds Pretrain", value=str(int(current_setup.get("rounds_pretrain", 1))))
rounds_ssl_raw = st.text_input("Rounds SSL", value=str(int(current_setup.get("rounds_ssl", 1))))
tau_raw = st.text_input("Tau (Threshold)", value=str(float(current_setup.get("tau", 0.5))))
ssl_weight_raw = st.text_input("SSL Weight", value=str(float(current_setup.get("ssl_weight", 1.0))))
with col_cfg2:
seed_raw = st.text_input("Random Seed", value=str(int(current_setup.get("seed", 42))))
agg_options = ["fedavg", "fedprox"]
current_agg = current_setup.get("aggregator", "fedavg")
agg_index = agg_options.index(current_agg) if current_agg in agg_options else 0
aggregator = st.selectbox("Aggregator", agg_options, index=agg_index)
fedprox_mu_raw = st.text_input("FedProx Mu", value=str(float(current_setup.get("fedprox_mu", 0.01))))
weight_options = ["num_samples", "equal"]
current_weight = current_setup.get("weight_by", "num_samples")
weight_index = weight_options.index(current_weight) if current_weight in weight_options else 0
weight_by = st.selectbox("Weight By", weight_options, index=weight_index)
submit_cfg = st.form_submit_button("Save Configuration", use_container_width=True)
if submit_cfg:
try:
rounds_pretrain = _parse_int_field("Rounds Pretrain", rounds_pretrain_raw, min_value=1)
rounds_ssl = _parse_int_field("Rounds SSL", rounds_ssl_raw, min_value=0)
tau = _parse_float_field("Tau (Threshold)", tau_raw, min_value=0.0, max_value=1.0)
ssl_weight = _parse_float_field("SSL Weight", ssl_weight_raw, min_value=0.0)
seed = _parse_int_field("Random Seed", seed_raw)
fedprox_mu = _parse_float_field("FedProx Mu", fedprox_mu_raw, min_value=0.0)
except ValueError as ve:
st.error(str(ve))
return
payload = {
"rounds_pretrain": rounds_pretrain,
"rounds_ssl": rounds_ssl,
"tau": tau,
"ssl_weight": ssl_weight,
"seed": seed,
"aggregator": aggregator,
"fedprox_mu": fedprox_mu,
"weight_by": weight_by,
}
with st.spinner("Saving configuration..."):
try:
cfg_res = requests.post(f"{API_URL}/config", json=payload)
cfg_data = cfg_res.json()
if cfg_data.get("status") == "success":
st.success("Configuration updated successfully.")
st.rerun()
else:
st.error(f"Update failed: {cfg_data.get('message')}")
except Exception as e:
st.error(f"Request failed: {str(e)}")
def render_server_controls():
"""
Renders the Server Control panel.
Handles training execution and global server configuration updates.
"""
# st.subheader("Server Control")
st.info("Manage training sessions.")
# 1. Fetch current system state from backend
try:
res = requests.get(f"{API_URL}/config")
res_data = res.json()
system_state = res_data.get("data", {}) if res_data.get("status") == "success" else {}
except Exception as e:
st.error(f"Cannot connect to the backend server. Is it running? Details: {str(e)}")
return
profiles_data = []
profiles_error = None
try:
res_profiles = requests.get(f"{API_URL}/profiles")
profiles_json = res_profiles.json() if res_profiles.status_code == 200 else {}
if profiles_json.get("status") == "success":
profiles_data = profiles_json.get("data", [])
else:
profiles_error = profiles_json.get("message", "Cannot fetch client profiles.")
except Exception as e:
profiles_error = str(e)
active_run = system_state.get("active_run")
# ==========================================
# EXECUTION CONTROLS
# ==========================================
st.subheader("Execution Controls")
col1, col2 = st.columns(2)
with col1:
if st.button("Start Training", type="primary", use_container_width=True):
with st.spinner("Initializing training session..."):
try:
train_res = requests.post(f"{API_URL}/train")
data = train_res.json()
if data.get("status") == "success":
st.success(data.get("message", "Training started successfully."))
st.rerun()
else:
st.error(f"Failed to start: {data.get('message')}")
except Exception as e:
st.error(f"Request error: {str(e)}")
with col2:
if st.button("Stop Training", type="secondary", use_container_width=True):
with st.spinner("Stopping current session..."):
try:
stop_res = requests.delete(f"{API_URL}/train")
data = stop_res.json()
if data.get("status") == "success":
st.warning(data.get("message", "Training stopped."))
st.rerun()
else:
st.error(f"Failed to stop: {data.get('message')}")
except Exception as e:
st.error(f"Request error: {str(e)}")
# Display active run status if it exists
if active_run and active_run.get("status") not in ["completed", "failed"]:
st.success(f"**Active Session:** {active_run.get('run_id')} **Status:** {active_run.get('status').upper()}")
else:
st.info("No active training session currently running.")
st.subheader("Registered Client Profiles")
if profiles_error:
st.warning(f"Profiles API is not available right now: {profiles_error}")
elif not profiles_data:
st.info("No client profiles found. Start client scripts to register profile metadata.")
else:
client_ids = [str(item.get("client_id", "N/A")) for item in profiles_data]
st.write("**Detailed Status:**")
cols = st.columns(len(client_ids))
for i, cid in enumerate(client_ids):
with cols[i]:
st.success(f"Registered: {cid}")
# st.markdown("---")
def render_server_controls_manual():
"""
Renders the Manual Server Control panel.
Allows for step-by-step orchestration of the FL process.
"""
# st.subheader("Server Control (Manual Mode)")
st.info("Human-in-the-loop: start rounds, monitor progress, and aggregate models manually.")
# 1. Fetch current system state
try:
res = requests.get(f"{API_URL}/config")
system_state = res.json().get("data", {})
active_run = system_state.get("active_run")
except Exception as e:
st.error(f"Backend connection error: {str(e)}")
return
profiles_data = []
profiles_error = None
try:
res_profiles = requests.get(f"{API_URL}/profiles")
profiles_json = res_profiles.json() if res_profiles.status_code == 200 else {}
if profiles_json.get("status") == "success":
profiles_data = profiles_json.get("data", [])
else:
profiles_error = profiles_json.get("message", "Cannot fetch client profiles.")
except Exception as e:
profiles_error = str(e)
# ==========================================
# 1. EXECUTION CONTROLS (Manual Start/Stop)
# ==========================================
st.subheader("Orchestration")
col1, col2 = st.columns(2)
with col1:
# Start manual mode and broadcast NEW_ROUND_MANUAL
if st.button("Start / Broadcast Next Round", type="primary", use_container_width=True):
with st.spinner("Broadcasting command..."):
try:
res = requests.post(f"{API_URL}/train_manual")
data = res.json()
if data.get("status") == "success":
st.success(data.get("message"))
st.rerun()
else:
st.error(f"Error: {data.get('message')}")
except Exception as e:
st.error(str(e))
with col2:
# Stop manual mode
if st.button("Stop Manual Session", type="secondary", use_container_width=True):
with st.spinner("Stopping..."):
try:
res = requests.delete(f"{API_URL}/train_manual")
if res.json().get("status") == "success":
st.warning("Manual session stopped.")
st.rerun()
except Exception as e:
st.error(str(e))
st.markdown("---")
st.subheader("Registered Client Profiles")
if profiles_error:
st.warning(f"Profiles API is not available right now: {profiles_error}")
elif not profiles_data:
st.info("No client profiles found. Start client scripts to register profile metadata.")
else:
client_ids = [str(item.get("client_id", "N/A")) for item in profiles_data]
st.write("**Detailed Status:**")
cols = st.columns(len(client_ids))
for i, cid in enumerate(client_ids):
with cols[i]:
st.success(f"Registered: {cid}")
st.markdown("---")
# ==========================================
# 2. SUBMISSION MONITORING
# ==========================================
st.subheader("Submission Monitor")
try:
sub_res = requests.get(f"{API_URL}/submission_status")
res_json = sub_res.json()
if res_json.get("status") == "error":
# If no run is active, backend returns an error state.
st.warning(f"Warning: {res_json.get('message')}")
st.info("Click 'Start Training' to begin a new session.")
else:
# Data exists only when status == "success".
sub_status = res_json.get("data", {})
curr = sub_status.get("current_count", 0)
total = sub_status.get("total_required", 0)
progress_str = sub_status.get("progress", "0/0")
is_ready = sub_status.get("is_ready", False)
all_ids = sub_status.get("all_client_ids", [])
submitted_ids = sub_status.get("submitted_clients", [])
# --- Progress dashboard ---
m_col1, m_col2 = st.columns([1, 3])
with m_col1:
st.metric(label="Clients Ready", value=progress_str)
with m_col2:
st.write(f"**Pending round:** {sub_status.get('pending_round')}")
st.progress(curr / total if total > 0 else 0)
st.write("**Detailed Status:**")
if all_ids:
cols = st.columns(len(all_ids))
for i, cid in enumerate(all_ids):
with cols[i]:
if cid in submitted_ids:
st.success(f"Submitted: {cid}")
else:
st.error(f"Missing: {cid}")
st.markdown("---")
# ==========================================
# 4. Aggregation action area
# ==========================================
with st.container():
if is_ready:
st.success(f"Ready: received all {progress_str} submissions.")
if st.button("Aggregate and Start Next Round", type="primary", use_container_width=True):
with st.spinner("Running aggregation..."):
try:
agg_res = requests.post(f"{API_URL}/aggregate_manual")
agg_data = agg_res.json()
if agg_data.get("status") == "success":
st.success(f"{agg_data.get('message')}")
import time
time.sleep(1.5)
st.rerun()
else:
st.error(f"Error: {agg_data.get('message')}")
except Exception as e:
st.error(str(e))
else:
st.warning(f"Waiting for {total - curr} more client submissions...")
st.button("Aggregate Submitted Models (Waiting...)",
type="secondary", use_container_width=True, disabled=True)
except Exception as e:
st.error(f"Submission monitor connection error: {e}")
st.markdown("---")
def render_server_validation_test_upload():
"""
Renders the Server Validation/Test Upload panel.
Handles uploading and deleting server-side evaluation datasets (.pkl).
"""
simple_block_header("Server Validation/Test Upload")
st.info("Upload `.pkl` datasets for server-side global model evaluation.")
# 1. Fetch current system state to get upload_status
try:
res = requests.get(f"{API_URL}/config")
res_data = res.json()
system_state = res_data.get("data", {}) if res_data.get("status") == "success" else {}
except Exception as e:
st.error(f"Cannot connect to the backend server. Details: {str(e)}")
return
upload_status = system_state.get("upload_status", {})
valid_exists = upload_status.get("valid", {}).get("uploaded", False)
test_exists = upload_status.get("test", {}).get("uploaded", False)
# ==========================================
# DATASET STATUS CARDS
# ==========================================
st.subheader("Current Data Status")
col_stat1, col_stat2 = st.columns(2)
with col_stat1:
if valid_exists:
st.success("**Validation Dataset:** Available")
# st.caption(f"Details: {upload_status.get('valid')}")
else:
st.error("**Validation Dataset:** Missing")
with col_stat2:
if test_exists:
st.success("**Test Dataset:** Available")
# st.caption(f"Details: {upload_status.get('test')}")
else:
st.error("**Test Dataset:** Missing")
st.markdown("---")
# ==========================================
# UPLOAD FORM
# ==========================================
st.subheader("Upload Dataset")
split_type = st.selectbox("Select Split Type", ["valid", "test"],
help="Choose whether this data is for validation or testing.")
uploaded_file = st.file_uploader(f"Choose a .pkl file for the '{split_type}' split", type=["pkl"])
if st.button("Upload File", type="primary"):
if uploaded_file is not None:
with st.spinner(f"Uploading {uploaded_file.name}..."):
try:
# Prepare the file payload for FastAPI (multipart/form-data)
files = {"file": (uploaded_file.name, uploaded_file.getvalue(), "application/octet-stream")}
upload_res = requests.post(f"{API_URL}/data/{split_type}", files=files)
data = upload_res.json()
if data.get("status") == "success":
st.success(data.get("message", "File uploaded successfully."))
st.rerun()
else:
st.error(f"Upload failed: {data.get('message')}")
except Exception as e:
st.error(f"Request error: {str(e)}")
else:
st.warning("Please select a file to upload first.")
st.markdown("---")
# ==========================================
# DANGER ZONE (DELETION)
# ==========================================
st.subheader("Danger Zone")
del_col1, del_col2, del_col3 = st.columns(3)
with del_col1:
if st.button("Delete Valid Data", disabled=not valid_exists, use_container_width=True):
try:
del_res = requests.delete(f"{API_URL}/data/valid")
if del_res.json().get("status") == "success":
st.toast("Validation data deleted.")
st.rerun()
else:
st.error(del_res.json().get("message"))
except Exception as e:
st.error(str(e))
with del_col2:
if st.button("Delete Test Data", disabled=not test_exists, use_container_width=True):
try:
del_res = requests.delete(f"{API_URL}/data/test")
if del_res.json().get("status") == "success":
st.toast("Test data deleted.")
st.rerun()
else:
st.error(del_res.json().get("message"))
except Exception as e:
st.error(str(e))
with del_col3:
if st.button("Clear ALL Data", type="primary", disabled=not (valid_exists or test_exists),
use_container_width=True):
try:
del_all_res = requests.delete(f"{API_URL}/data")
if del_all_res.json().get("status") == "success":
st.toast("All server data cleared.")
st.rerun()
else:
st.error(del_all_res.json().get("message"))
except Exception as e:
st.error(str(e))
def render_dataset_client_summary():
"""
Renders the Dataset and Client Summary panel.
Displays registered clients, data volumes, and class distributions.
"""
simple_block_header("Dataset / Client Summary")
st.info("Overview of registered clients, their data volume, and label distribution.")
# 1. Fetch data from backend
try:
# Fetch profiles from API (returns List[ClientProfileResponse])
res_profiles = requests.get(f"{API_URL}/profiles")
# Keep config fetch if other global metadata is needed
res_config = requests.get(f"{API_URL}/config")
profiles_data = res_profiles.json().get("data", []) if res_profiles.status_code == 200 else []
except Exception as e:
st.error(f"Cannot connect to the backend server. Details: {str(e)}")
return
# Check whether any client has registered
if not profiles_data:
st.warning("No clients have registered their profiles yet. Start your client scripts to register them.")
return
# 2. Process data for visualization (profile data is a list of objects)
total_clients = len(profiles_data)
total_labeled = 0
total_unlabeled = 0
profile_list = []
class_hist_list = []
# Iterate through List[ClientProfileResponse]
for info in profiles_data:
cid = info.get("client_id")
labeled = info.get("labeled_count", 0)
unlabeled = info.get("unlabeled_count", 0)
total_labeled += labeled
total_unlabeled += unlabeled
# Build table row using fields from ClientProfileResponse
profile_list.append({
"Client ID": cid,
"Labeled Data": labeled,
"Unlabeled Data": unlabeled,
"Feature Dim": info.get("feature_dim", "N/A"),
"Model Name": info.get("model_name", "N/A"),
# Keep N/A when a field is unavailable in response
"Batch Size": info.get("batch_size", "N/A")
})
# Build class distribution chart data
hist = info.get("class_hist", {})
if hist:
for label, count in hist.items():
class_hist_list.append({
"Client": cid,
"Class": _map_label_display(label),
"Count": count
})
# ==========================================
# OVERALL METRICS
# ==========================================
st.subheader("Global Metrics")
col_m1, col_m2, col_m3 = st.columns(3)
col_m1.metric("Registered Clients", total_clients)
col_m2.metric("Total Labeled Samples", total_labeled)
col_m3.metric("Total Unlabeled Samples", total_unlabeled)
st.markdown("---")
# ==========================================
# CLIENT PROFILE TABLE
# ==========================================
st.subheader("Client Data Overview")
df_profiles = pd.DataFrame(profile_list)
st.dataframe(df_profiles, use_container_width=True, hide_index=True)
st.markdown("---")
# ==========================================
# DATA DISTRIBUTION CHARTS
# ==========================================
st.subheader("Data Distribution Analysis")
col_chart1, col_chart2 = st.columns(2)
with col_chart1:
st.markdown("**Data Volume per Client**")
if not df_profiles.empty:
# Bar chart showing Labeled vs Unlabeled side by side
df_volume = df_profiles.set_index("Client ID")[["Labeled Data", "Unlabeled Data"]]
st.bar_chart(df_volume)
with col_chart2:
st.markdown("**Class Label Distribution**")
if class_hist_list:
df_hist = pd.DataFrame(class_hist_list)
# Convert Class to string so chart X-axis is categorical
df_hist['Class'] = df_hist['Class'].astype(str)
# Pivot by class and client for grouped class count visualization
df_pivot = df_hist.pivot(index='Class', columns='Client', values='Count').fillna(0)
st.bar_chart(df_pivot)
else:
st.info("No class history (class_hist) provided by clients.")
st.markdown("---")
if st.button("Clear All Profiles", type="secondary"):
# Optional: add a confirmation dialog before destructive action
try:
res = requests.delete(f"{API_URL}/profiles")
if res.status_code == 200:
st.success("All client profiles have been cleared.")
st.rerun()
else:
st.error("Cannot clear profiles.")
except Exception as e:
st.error(f"Error: {e}")
def render_client_panels():
"""
Renders the Client Panels section.
Allows monitoring of individual clients and manual submission of model updates.
"""
simple_block_header("Client Panels")
st.info("Monitor specific clients and manually submit model weights (.pt) on their behalf.")
# 1. Fetch system state to get registered clients
try:
res = requests.get(f"{API_URL}/config")
res_data = res.json()
system_state = res_data.get("data", {}) if res_data.get("status") == "success" else {}
except Exception as e:
st.error(f"Cannot connect to the backend server. Details: {str(e)}")
return
profiles = system_state.get("profiles", {})
active_run = system_state.get("active_run")
if not profiles:
st.warning("No clients are currently registered. Please register clients first.")
return
# 2. Select a Client
client_ids = list(profiles.keys())
selected_client = st.selectbox("Select a Client to View/Manage", client_ids)
st.markdown("---")
client_info = profiles[selected_client]
# ==========================================
# CLIENT INFORMATION
# ==========================================
st.subheader(f"Profile: {selected_client}")
col_info1, col_info2, col_info3 = st.columns(3)
col_info1.metric("Labeled Data", client_info.get("labeled_count", 0))
col_info2.metric("Unlabeled Data", client_info.get("unlabeled_count", 0))
col_info3.metric("Batch Size", client_info.get("batch_size", "N/A"))
with st.expander("View Raw Profile Data"):
st.json(client_info)
st.markdown("---")
# ==========================================
# MANUAL MODEL SUBMISSION FORM
# ==========================================
st.subheader("Manual Model Submission")
if active_run and active_run.get("status") == "waiting_responses":
st.success(
f"Server is currently waiting for responses for Run: **{active_run.get('run_id')}** (Round {active_run.get('pending_round')})")
with st.form(f"submit_model_form_{selected_client}"):
st.write("Submit weights directly to the server for this client.")
# Form Inputs
num_samples = st.number_input(
"Number of Samples (Weight for FedAvg)",
min_value=1,
value=int(client_info.get("labeled_count", 100)),
help="The number of data samples used to train this model (used for weighting in FedAvg)."
)
metrics_input = st.text_area(
"Training Metrics (JSON format)",
value='{\n "loss": 0.45,\n "accuracy": 0.82\n}',
help="Provide metrics like loss and accuracy as a JSON string."
)
uploaded_model = st.file_uploader("Upload Model Weights (.pt)", type=["pt", "bin", "pth"])
submit_btn = st.form_submit_button("Submit Model Update", type="primary", use_container_width=True)
if submit_btn:
# 1. Validate JSON metrics
try:
parsed_metrics = json.loads(metrics_input)
valid_json_str = json.dumps(parsed_metrics)
except json.JSONDecodeError:
st.error("Invalid JSON format in Training Metrics.")
st.stop()
# 2. Check File
if uploaded_model is None:
st.error("Please upload a model weights file (.pt).")
else:
with st.spinner("Submitting model to server..."):
try:
# Prepare form data
data = {
"client_id": selected_client,
"num_samples": num_samples,
"metrics": valid_json_str
}
# Prepare file
files = {
"file": (uploaded_model.name, uploaded_model.getvalue(), "application/octet-stream")
}
# Send request
submit_res = requests.post(f"{API_URL}/model", data=data, files=files)
res_json = submit_res.json()
if res_json.get("status") == "success":
st.success(res_json.get("message"))
else:
st.error(f"Submission failed: {res_json.get('message')}")
except Exception as e:
st.error(f"Request error: {str(e)}")
else:
st.info(
"Manual submission is disabled. The server is not currently waiting for client responses (no active pending round).")
def render_global_server_panel():
"""
Renders the Global / Server Panel.
Displays active run details, global model status, and aggregation history.
"""
simple_block_header("Global / Server Panel")
st.info("Monitor the central orchestrator state, active run details, and aggregation history.")
# 1. Fetch system state from backend
try:
res = requests.get(f"{API_URL}/config")
res_data = res.json()
system_state = res_data.get("data", {}) if res_data.get("status") == "success" else {}
except Exception as e:
st.error(f"Cannot connect to the backend server. Details: {str(e)}")
return
active_run = system_state.get("active_run")
history_data = system_state.get("history", {})
# ==========================================
# ACTIVE RUN DETAILS
# ==========================================
st.subheader("Current Active Run")
if active_run:
status_color = "" if active_run.get("status") not in ["failed", "stopped"] else ""
st.markdown(f"**Run ID:** `{active_run.get('run_id', 'N/A')}`")
col1, col2, col3, col4 = st.columns(4)
col1.metric("Status", f"{status_color} {active_run.get('status', 'N/A').upper()}")
col2.metric("Current Stage", active_run.get("current_stage", active_run.get("stage", "N/A")).upper())
col3.metric("Current Round", active_run.get("current_round", 0))
col4.metric("Pending Round", active_run.get("pending_round", "N/A"))
with st.expander("View Raw Active Run State"):
st.json(active_run)
else:
st.info("No active run is currently in progress. Start training in the 'Server Control' panel.")
st.markdown("---")
# ==========================================
# GLOBAL MODEL DOWNLOAD
# ==========================================
st.subheader("Global Model Status")
st.write("Fetch and download the latest aggregated global model weights (`.pt` file).")
if st.button("Fetch Latest Global Model"):
with st.spinner("Connecting to server to fetch model weights..."):
try:
model_res = requests.get(f"{API_URL}/model")
content_type = model_res.headers.get('Content-Type', '')
if 'application/json' in content_type:
err_data = model_res.json()
st.error(f"Cannot fetch model: {err_data.get('message', 'Unknown error')}")
elif model_res.status_code == 200:
st.success("Model weights fetched successfully! Click below to save to disk.")
st.download_button(
label="Save global_model.pt",
data=model_res.content,
file_name="global_model.pt",
mime="application/octet-stream",
type="primary"
)
else:
st.error(f"Unexpected error: HTTP {model_res.status_code}")
except Exception as e:
st.error(f"Request error: {str(e)}")
st.markdown("---")
# ==========================================
# AGGREGATION HISTORY TABLE
# ==========================================
st.subheader("Aggregation History")
# Read round records from history_data
rounds_list = history_data.get("rounds", [])
if rounds_list:
try:
df_history = pd.DataFrame(rounds_list)
# Hide client_metrics because long JSON blobs reduce table readability.
if "client_metrics" in df_history.columns:
df_history = df_history.drop(columns=["client_metrics"])
# Keep 'round' as the first column.
cols = ['round'] + [c for c in df_history.columns if c != 'round']
df_history = df_history[cols]
st.dataframe(df_history, use_container_width=True, hide_index=True)
except Exception as e:
st.error(f"Could not parse history table: {e}")
else:
st.info("No aggregation history available yet. The table will populate once rounds are completed.")
def render_charts():
simple_block_header("Training Performance Analytics")
try:
# Fetch chart data from API
res = requests.get(f"{API_URL}/charts")
history_data = res.json().get("data", {})
except Exception as e:
st.error(f"Cannot connect to backend: {e}")
return
# Read list of rounds from 'rounds' key
rounds_list = history_data.get("rounds", [])
if not rounds_list:
st.warning("No training rounds recorded yet. Start training to see results.")
return
# 1. Convert list to DataFrame
df = pd.DataFrame(rounds_list)
# Use round as chart X-axis index
df.set_index("round", inplace=True)
# 2. Tabs
tab1, tab2, tab3 = st.tabs(["Global Accuracy", "Local Loss (Clients)", "Global F1 Score"])
# ==========================================
# TAB 1: GLOBAL ACCURACY
# ==========================================
with tab1:
st.subheader("Global Accuracy (Validation vs Test)")
# Check columns before plotting to avoid missing-key errors
cols_to_plot = [c for c in ["val_accuracy", "test_accuracy"] if c in df.columns]
if cols_to_plot:
st.line_chart(df[cols_to_plot])
else:
st.info("Accuracy data is not available yet.")
# ==========================================
# TAB 2: LOCAL LOSS (multi-client)
# ==========================================
with tab2:
st.subheader("Local Training Loss")
all_clients = set()
for r in rounds_list:
c_metrics = r.get("client_metrics", {})
if isinstance(c_metrics, dict):
all_clients.update(c_metrics.keys())
all_clients = sorted(list(all_clients))
if not all_clients:
st.info("No client metrics found in history.")
else:
# Select client to inspect
selected_client = st.selectbox("Select Client to inspect:", all_clients)
# Extract selected client loss per round
loss_data = {}
for r_idx, row in df.iterrows():
c_metrics = row.get("client_metrics", {}).get(selected_client, {})
if c_metrics:
p_loss = c_metrics.get("pretrain_loss", 0)
s_loss = c_metrics.get("ssl_labeled_loss", 0)
# Convert 0 to None so missing phase points are not connected
loss_data[r_idx] = {
"Pretrain Loss": p_loss if p_loss > 0 else None,
"SSL Loss": s_loss if s_loss > 0 else None
}
if loss_data:
df_loss = pd.DataFrame.from_dict(loss_data, orient='index')
# Render as two separate charts
st.markdown("#### Phase 1: Pre-training Loss")
st.caption("Phase where the model learns from local labeled data.")
# Plot only Pretrain Loss column
st.line_chart(df_loss[["Pretrain Loss"]])
st.markdown("#### Phase 2: SSL Labeled Loss")
st.caption("Semi-supervised learning phase.")
# Plot only SSL Loss column
st.line_chart(df_loss[["SSL Loss"]])
else:
st.warning(f"No loss data available for {selected_client} in the selected rounds.")
# ==========================================
# TAB 3: GLOBAL F1 SCORE
# ==========================================
with tab3:
st.subheader("Global Weighted F1 Score")
cols_to_plot_f1 = [c for c in ["val_WF1", "test_WF1"] if c in df.columns]
if cols_to_plot_f1:
st.line_chart(df[cols_to_plot_f1])
else:
st.info("F1 Score data is not available yet.")
# ==========================================
# 3. Show detailed raw metrics table
# ==========================================
st.markdown("---")
with st.expander("View Full Raw Metrics Table"):
# Copy DataFrame and hide verbose client_metrics column for readability
df_display = df.copy()
if "client_metrics" in df_display.columns:
df_display.drop(columns=["client_metrics"], inplace=True)
st.dataframe(df_display, use_container_width=True)
def render_server_controls_combined():
    """Render the combined "Server Control" section.

    Shows the shared global-configuration form once, then two tabs: the
    automatic controls and the manual (human-in-the-loop) controls.
    """
    simple_block_header("Server Control")
    st.info("Manage automatic and manual orchestration in one section.")
    render_global_configuration()
    st.divider()
    tab_renderers = {
        "Server Control": render_server_controls,
        "Server Control Manual": render_server_controls_manual,
    }
    for tab, render_tab in zip(st.tabs(list(tab_renderers)), tab_renderers.values()):
        with tab:
            render_tab()
def render_single_page_layout():
    """Render the full admin page: a title line followed by every panel in order.

    A divider separates the panels, except after the combined Server Control
    section and after the final charts section.
    """
    # Page title rendered via markdown in place of st.title.
    st.markdown(
        "\nFederated Learning Server\n",
        unsafe_allow_html=True
    )
    sections = (
        (render_server_validation_test_upload, True),
        (render_dataset_client_summary, True),
        (render_server_controls_combined, False),
        (render_global_server_panel, True),
        (render_charts, False),
    )
    for render_section, divider_after in sections:
        render_section()
        if divider_after:
            st.divider()
def render_inference_demo():
simple_block_header("Emotion Recognition Inference")
st.markdown("Upload an audio file to predict emotion using the trained global AI model.")
col_config, col_main = st.columns([1, 2], gap="large")
with col_config:
st.subheader("Model Configuration")
with st.container(border=True):
fl_method = st.selectbox("FL Method", ["FedAvg", "FedProx"])
# Available label-ratio tokens used by checkpoints.
label_ratio = st.selectbox("Label Ratio", ["0p3", "0p5", "0p7", "1"], index=3)
num_clients = st.number_input("Num Clients", min_value=1, max_value=100, value=5)
model_name = st.selectbox("Model Architecture", ["fedalmer", "memocmt", "fleser", "threemser", "cemobam"])
device_pref = st.radio("Processing Device", ["cpu", "cuda"], horizontal=True)
with col_main:
st.subheader("Audio Upload")
uploaded_file = st.file_uploader("Drag and drop an audio file here (.wav, .mp3, .flac)",
type=["wav", "mp3", "flac"])
# Preview player appears only after upload.
if uploaded_file is not None:
st.audio(uploaded_file, format="audio/wav")
st.markdown("
", unsafe_allow_html=True)
if st.button("Analyze and Predict Emotion", type="primary", use_container_width=True):
if uploaded_file is None:
st.warning("Please upload an audio file before running prediction.")
else:
with st.spinner("Analyzing audio... (first model load may take longer)"):
try:
# 1. Build form payload.
data_payload = {
"fl_method": fl_method,
"label_ratio_token": label_ratio,
"num_clients": num_clients,
"model_name": model_name,
"device_pref": device_pref
}
# 2. Build file payload.
files_payload = {
"file": (uploaded_file.name, uploaded_file.getvalue(), "audio/wav")
}
# 3. Send request to FastAPI.
res = requests.post(f"{API_URL}/predict", data=data_payload, files=files_payload, timeout=30.00)
# 4. Handle response.
if res.status_code == 200:
res_json = res.json()
if res_json.get("status") == "success":
result = res_json.get("data", {})
st.success("Prediction completed.")
st.divider()
# Show result metrics.
m1, m2 = st.columns(2)
emotion_str = str(result.get('emotion', 'Unknown')).title()
confidence = result.get('confidence', 0)
m1.metric(label="Predicted Emotion", value=emotion_str)
m2.metric(label="Confidence", value=f"{confidence:.2%}")
with st.expander("ASR Transcript", expanded=True):
st.write(f"*{result.get('transcript', 'Speech not recognized')}*")
with st.expander("Probability Distribution"):
probs = result.get("probabilities", {})
if probs:
# Render simple probability bar chart.
import pandas as pd
df_probs = pd.DataFrame(list(probs.items()), columns=["Emotion", "Probability"])
df_probs.set_index("Emotion", inplace=True)
st.bar_chart(df_probs)
else:
st.json(probs)
with st.expander("System Metadata"):
st.json({
"Loaded Checkpoint": result.get("checkpoint"),
"Audio Duration": f"{result.get('audio_seconds', 0):.2f} seconds",
"Model": result.get("model_name")
})
else:
st.error(f"Server Error: {res_json.get('message')}")
else:
st.error(f"HTTP Error {res.status_code}: {res.text}")
except Exception as e:
st.error(f"System connection error: {str(e)}")
# Entry point: Streamlit executes this script top to bottom on every rerun,
# so this call renders the whole admin page each time.
# NOTE(review): render_inference_demo is defined in this file but is not
# rendered from render_single_page_layout — confirm whether that is intended.
render_single_page_layout()