import json import os import pandas as pd import requests import streamlit as st API_URL = os.getenv("API_URL", "http://localhost:8000/api") # ========================================== # 1. PAGE SETTINGS # ========================================== st.set_page_config( page_title="Federated Learning Admin", layout="wide", page_icon="settings" ) # ========================================== # 2. SIDEBAR NAVIGATION # ========================================== st.sidebar.title("FL Orchestrator") st.sidebar.markdown("---") # Define the exact menu options you requested menu_options = [ "Server Control", "Server Control Manual", "Server Validation/Test Upload", "Dataset / Client Summary", # "Client Panels", "Global / Server Panel", "Charts", # "Inference Demo", ] # Create the radio button menu in the sidebar choice = st.sidebar.radio("Navigation", menu_options) st.sidebar.markdown("---") if st.sidebar.button("Refresh Application"): st.rerun() # ========================================== # 3. MAIN CONTENT ROUTING # ========================================== def render_server_controls(): """ Renders the Server Control panel. Handles training execution and global server configuration updates. """ st.title("Server Control") st.info("Manage training sessions and update global server configuration.") # 1. Fetch current system state from backend try: res = requests.get(f"{API_URL}/config") res_data = res.json() system_state = res_data.get("data", {}) if res_data.get("status") == "success" else {} except Exception as e: st.error(f"Cannot connect to the backend server. Is it running? Details: {str(e)}") return active_run = system_state.get("active_run") current_setup = system_state.get("current_setup", {}) # ========================================== # EXECUTION CONTROLS # ========================================== st.subheader("Execution Controls") col1, col2 = st.columns(2) with col1: if st.button("Start Training", type="primary", use_container_width=True): with st.spinner("Initializing training session..."): try: train_res = requests.post(f"{API_URL}/train") data = train_res.json() if data.get("status") == "success": st.success(data.get("message", "Training started successfully.")) st.rerun() else: st.error(f"Failed to start: {data.get('message')}") except Exception as e: st.error(f"Request error: {str(e)}") with col2: if st.button("Stop Training", type="secondary", use_container_width=True): with st.spinner("Stopping current session..."): try: stop_res = requests.delete(f"{API_URL}/train") data = stop_res.json() if data.get("status") == "success": st.warning(data.get("message", "Training stopped.")) st.rerun() else: st.error(f"Failed to stop: {data.get('message')}") except Exception as e: st.error(f"Request error: {str(e)}") # Display active run status if it exists if active_run and active_run.get("status") not in ["completed", "failed"]: st.success(f"**Active Session:** {active_run.get('run_id')} **Status:** {active_run.get('status').upper()}") else: st.info("No active training session currently running.") st.markdown("---") # ========================================== # GLOBAL CONFIGURATION FORM # ========================================== st.subheader("Global Configuration") with st.form("server_config_form"): col_cfg1, col_cfg2 = st.columns(2) with col_cfg1: rounds_pretrain = st.number_input("Rounds Pretrain", min_value=1, value=int(current_setup.get("rounds_pretrain", 1))) rounds_ssl = st.number_input("Rounds SSL", min_value=0, value=int(current_setup.get("rounds_ssl", 1))) tau = st.number_input("Tau (Threshold)", min_value=0.0, max_value=1.0, value=float(current_setup.get("tau", 0.5)), step=0.05) ssl_weight = st.number_input("SSL Weight", min_value=0.0, value=float(current_setup.get("ssl_weight", 1.0)), step=0.1) with col_cfg2: seed = st.number_input("Random Seed", value=int(current_setup.get("seed", 42))) # Setup Aggregator selectbox agg_options = ["fedavg", "fedprox"] current_agg = current_setup.get("aggregator", "fedavg") agg_index = agg_options.index(current_agg) if current_agg in agg_options else 0 aggregator = st.selectbox("Aggregator", agg_options, index=agg_index) fedprox_mu = st.number_input("FedProx Mu", min_value=0.0, value=float(current_setup.get("fedprox_mu", 0.01)), step=0.01) # Setup Weighting method selectbox weight_options = ["num_samples", "equal"] current_weight = current_setup.get("weight_by", "num_samples") weight_index = weight_options.index(current_weight) if current_weight in weight_options else 0 weight_by = st.selectbox("Weight By", weight_options, index=weight_index) submit_cfg = st.form_submit_button("Save Configuration", use_container_width=True) if submit_cfg: payload = { "rounds_pretrain": rounds_pretrain, "rounds_ssl": rounds_ssl, "tau": tau, "ssl_weight": ssl_weight, "seed": seed, "aggregator": aggregator, "fedprox_mu": fedprox_mu, "weight_by": weight_by } with st.spinner("Saving configuration..."): try: cfg_res = requests.post(f"{API_URL}/config", json=payload) cfg_data = cfg_res.json() if cfg_data.get("status") == "success": st.success("Configuration updated successfully.") st.rerun() else: st.error(f"Update failed: {cfg_data.get('message')}") except Exception as e: st.error(f"Request failed: {str(e)}") def render_server_controls_manual(): """ Renders the Manual Server Control panel. Allows for step-by-step orchestration of the FL process. """ st.title("Server Control (Manual Mode)") st.info("Human-in-the-loop: start rounds, monitor progress, and aggregate models manually.") # 1. Fetch current system state try: res = requests.get(f"{API_URL}/config") system_state = res.json().get("data", {}) active_run = system_state.get("active_run") current_setup = system_state.get("current_setup", {}) except Exception as e: st.error(f"Backend connection error: {str(e)}") return # ========================================== # 1. EXECUTION CONTROLS (Manual Start/Stop) # ========================================== st.subheader("Orchestration") col1, col2 = st.columns(2) with col1: # Start manual mode and broadcast NEW_ROUND_MANUAL if st.button("Start / Broadcast Next Round", type="primary", use_container_width=True): with st.spinner("Broadcasting command..."): try: res = requests.post(f"{API_URL}/train_manual") data = res.json() if data.get("status") == "success": st.success(data.get("message")) st.rerun() else: st.error(f"Error: {data.get('message')}") except Exception as e: st.error(str(e)) with col2: # Stop manual mode if st.button("Stop Manual Session", type="secondary", use_container_width=True): with st.spinner("Stopping..."): try: res = requests.delete(f"{API_URL}/train_manual") if res.json().get("status") == "success": st.warning("Manual session stopped.") st.rerun() except Exception as e: st.error(str(e)) st.markdown("---") # ========================================== # 2. SUBMISSION MONITORING # ========================================== st.subheader("Submission Monitor") try: sub_res = requests.get(f"{API_URL}/submission_status") res_json = sub_res.json() if res_json.get("status") == "error": # If no run is active, backend returns an error state. st.warning(f"Warning: {res_json.get('message')}") st.info("Click 'Start Training' to begin a new session.") else: # Data exists only when status == "success". sub_status = res_json.get("data", {}) curr = sub_status.get("current_count", 0) total = sub_status.get("total_required", 0) progress_str = sub_status.get("progress", "0/0") is_ready = sub_status.get("is_ready", False) all_ids = sub_status.get("all_client_ids", []) submitted_ids = sub_status.get("submitted_clients", []) # --- Progress dashboard --- m_col1, m_col2 = st.columns([1, 3]) with m_col1: st.metric(label="Clients Ready", value=progress_str) with m_col2: st.write(f"**Pending round:** {sub_status.get('pending_round')}") st.progress(curr / total if total > 0 else 0) st.write("**Detailed Status:**") if all_ids: cols = st.columns(len(all_ids)) for i, cid in enumerate(all_ids): with cols[i]: if cid in submitted_ids: st.success(f"Submitted: {cid}") else: st.error(f"Missing: {cid}") st.markdown("---") # ========================================== # 4. Aggregation action area # ========================================== with st.container(): if is_ready: st.success(f"Ready: received all {progress_str} submissions.") if st.button("Aggregate and Start Next Round", type="primary", use_container_width=True): with st.spinner("Running aggregation..."): try: agg_res = requests.post(f"{API_URL}/aggregate_manual") agg_data = agg_res.json() if agg_data.get("status") == "success": st.success(f"{agg_data.get('message')}") import time time.sleep(1.5) st.rerun() else: st.error(f"Error: {agg_data.get('message')}") except Exception as e: st.error(str(e)) else: st.warning(f"Waiting for {total - curr} more client submissions...") st.button("Aggregate Submitted Models (Waiting...)", type="secondary", use_container_width=True, disabled=True) except Exception as e: st.error(f"Submission monitor connection error: {e}") st.markdown("---") # ========================================== # 3. GLOBAL CONFIGURATION (reuse existing form) # ========================================== st.subheader("Global Configuration") with st.form("server_config_form"): col_cfg1, col_cfg2 = st.columns(2) with col_cfg1: rounds_pretrain = st.number_input("Rounds Pretrain", min_value=1, value=int(current_setup.get("rounds_pretrain", 1))) rounds_ssl = st.number_input("Rounds SSL", min_value=0, value=int(current_setup.get("rounds_ssl", 1))) tau = st.number_input("Tau (Threshold)", min_value=0.0, max_value=1.0, value=float(current_setup.get("tau", 0.5)), step=0.05) ssl_weight = st.number_input("SSL Weight", min_value=0.0, value=float(current_setup.get("ssl_weight", 1.0)), step=0.1) with col_cfg2: seed = st.number_input("Random Seed", value=int(current_setup.get("seed", 42))) # Setup Aggregator selectbox agg_options = ["fedavg", "fedprox"] current_agg = current_setup.get("aggregator", "fedavg") agg_index = agg_options.index(current_agg) if current_agg in agg_options else 0 aggregator = st.selectbox("Aggregator", agg_options, index=agg_index) fedprox_mu = st.number_input("FedProx Mu", min_value=0.0, value=float(current_setup.get("fedprox_mu", 0.01)), step=0.01) # Setup Weighting method selectbox weight_options = ["num_samples", "equal"] current_weight = current_setup.get("weight_by", "num_samples") weight_index = weight_options.index(current_weight) if current_weight in weight_options else 0 weight_by = st.selectbox("Weight By", weight_options, index=weight_index) submit_cfg = st.form_submit_button("Save Configuration", use_container_width=True) if submit_cfg: payload = { "rounds_pretrain": rounds_pretrain, "rounds_ssl": rounds_ssl, "tau": tau, "ssl_weight": ssl_weight, "seed": seed, "aggregator": aggregator, "fedprox_mu": fedprox_mu, "weight_by": weight_by } with st.spinner("Saving configuration..."): try: cfg_res = requests.post(f"{API_URL}/config", json=payload) cfg_data = cfg_res.json() if cfg_data.get("status") == "success": st.success("Configuration updated successfully.") st.rerun() else: st.error(f"Update failed: {cfg_data.get('message')}") except Exception as e: st.error(f"Request failed: {str(e)}") def render_server_validation_test_upload(): """ Renders the Server Validation/Test Upload panel. Handles uploading and deleting server-side evaluation datasets (.pkl). """ st.title("Server Validation/Test Upload") st.info("Upload `.pkl` datasets for server-side global model evaluation.") # 1. Fetch current system state to get upload_status try: res = requests.get(f"{API_URL}/config") res_data = res.json() system_state = res_data.get("data", {}) if res_data.get("status") == "success" else {} except Exception as e: st.error(f"Cannot connect to the backend server. Details: {str(e)}") return upload_status = system_state.get("upload_status", {}) valid_exists = upload_status.get("valid", {}).get("uploaded", False) test_exists = upload_status.get("test", {}).get("uploaded", False) # ========================================== # DATASET STATUS CARDS # ========================================== st.subheader("Current Data Status") col_stat1, col_stat2 = st.columns(2) with col_stat1: if valid_exists: st.success("**Validation Dataset:** Available") # st.caption(f"Details: {upload_status.get('valid')}") else: st.error("**Validation Dataset:** Missing") with col_stat2: if test_exists: st.success("**Test Dataset:** Available") # st.caption(f"Details: {upload_status.get('test')}") else: st.error("**Test Dataset:** Missing") st.markdown("---") # ========================================== # UPLOAD FORM # ========================================== st.subheader("Upload Dataset") split_type = st.selectbox("Select Split Type", ["valid", "test"], help="Choose whether this data is for validation or testing.") uploaded_file = st.file_uploader(f"Choose a .pkl file for the '{split_type}' split", type=["pkl"]) if st.button("Upload File", type="primary"): if uploaded_file is not None: with st.spinner(f"Uploading {uploaded_file.name}..."): try: # Prepare the file payload for FastAPI (multipart/form-data) files = {"file": (uploaded_file.name, uploaded_file.getvalue(), "application/octet-stream")} upload_res = requests.post(f"{API_URL}/data/{split_type}", files=files) data = upload_res.json() if data.get("status") == "success": st.success(data.get("message", "File uploaded successfully.")) st.rerun() else: st.error(f"Upload failed: {data.get('message')}") except Exception as e: st.error(f"Request error: {str(e)}") else: st.warning("Please select a file to upload first.") st.markdown("---") # ========================================== # DANGER ZONE (DELETION) # ========================================== st.subheader("Danger Zone") del_col1, del_col2, del_col3 = st.columns(3) with del_col1: if st.button("Delete Valid Data", disabled=not valid_exists, use_container_width=True): try: del_res = requests.delete(f"{API_URL}/data/valid") if del_res.json().get("status") == "success": st.toast("Validation data deleted.") st.rerun() else: st.error(del_res.json().get("message")) except Exception as e: st.error(str(e)) with del_col2: if st.button("Delete Test Data", disabled=not test_exists, use_container_width=True): try: del_res = requests.delete(f"{API_URL}/data/test") if del_res.json().get("status") == "success": st.toast("Test data deleted.") st.rerun() else: st.error(del_res.json().get("message")) except Exception as e: st.error(str(e)) with del_col3: if st.button("Clear ALL Data", type="primary", disabled=not (valid_exists or test_exists), use_container_width=True): try: del_all_res = requests.delete(f"{API_URL}/data") if del_all_res.json().get("status") == "success": st.toast("All server data cleared.") st.rerun() else: st.error(del_all_res.json().get("message")) except Exception as e: st.error(str(e)) def render_dataset_client_summary(): """ Renders the Dataset and Client Summary panel. Displays registered clients, data volumes, and class distributions. """ st.title("Dataset / Client Summary") st.info("Overview of registered clients, their data volume, and label distribution.") # 1. Fetch data from backend try: # Fetch profiles from API (returns List[ClientProfileResponse]) res_profiles = requests.get(f"{API_URL}/profiles") # Keep config fetch if other global metadata is needed res_config = requests.get(f"{API_URL}/config") profiles_data = res_profiles.json().get("data", []) if res_profiles.status_code == 200 else [] except Exception as e: st.error(f"Cannot connect to the backend server. Details: {str(e)}") return # Check whether any client has registered if not profiles_data: st.warning("No clients have registered their profiles yet. Start your client scripts to register them.") return # 2. Process data for visualization (profile data is a list of objects) total_clients = len(profiles_data) total_labeled = 0 total_unlabeled = 0 profile_list = [] class_hist_list = [] # Iterate through List[ClientProfileResponse] for info in profiles_data: cid = info.get("client_id") labeled = info.get("labeled_count", 0) unlabeled = info.get("unlabeled_count", 0) total_labeled += labeled total_unlabeled += unlabeled # Build table row using fields from ClientProfileResponse profile_list.append({ "Client ID": cid, "Labeled Data": labeled, "Unlabeled Data": unlabeled, "Feature Dim": info.get("feature_dim", "N/A"), "Model Name": info.get("model_name", "N/A"), # Keep N/A when a field is unavailable in response "Batch Size": info.get("batch_size", "N/A") }) # Build class distribution chart data hist = info.get("class_hist", {}) if hist: for label, count in hist.items(): class_hist_list.append({ "Client": cid, "Class": str(label), "Count": count }) # ========================================== # OVERALL METRICS # ========================================== st.subheader("Global Metrics") col_m1, col_m2, col_m3 = st.columns(3) col_m1.metric("Registered Clients", total_clients) col_m2.metric("Total Labeled Samples", total_labeled) col_m3.metric("Total Unlabeled Samples", total_unlabeled) st.markdown("---") # ========================================== # CLIENT PROFILE TABLE # ========================================== st.subheader("Client Data Overview") df_profiles = pd.DataFrame(profile_list) st.dataframe(df_profiles, use_container_width=True, hide_index=True) st.markdown("---") # ========================================== # DATA DISTRIBUTION CHARTS # ========================================== st.subheader("Data Distribution Analysis") col_chart1, col_chart2 = st.columns(2) with col_chart1: st.markdown("**Data Volume per Client**") if not df_profiles.empty: # Bar chart showing Labeled vs Unlabeled side by side df_volume = df_profiles.set_index("Client ID")[["Labeled Data", "Unlabeled Data"]] st.bar_chart(df_volume) with col_chart2: st.markdown("**Class Label Distribution**") if class_hist_list: df_hist = pd.DataFrame(class_hist_list) # Convert Class to string so chart X-axis is categorical df_hist['Class'] = df_hist['Class'].astype(str) # Pivot by class and client for grouped class count visualization df_pivot = df_hist.pivot(index='Class', columns='Client', values='Count').fillna(0) st.bar_chart(df_pivot) else: st.info("No class history (class_hist) provided by clients.") st.markdown("---") if st.button("Clear All Profiles", type="secondary"): # Optional: add a confirmation dialog before destructive action try: res = requests.delete(f"{API_URL}/profiles") if res.status_code == 200: st.success("All client profiles have been cleared.") st.rerun() else: st.error("Cannot clear profiles.") except Exception as e: st.error(f"Error: {e}") def render_client_panels(): """ Renders the Client Panels section. Allows monitoring of individual clients and manual submission of model updates. """ st.title("Client Panels") st.info("Monitor specific clients and manually submit model weights (.pt) on their behalf.") # 1. Fetch system state to get registered clients try: res = requests.get(f"{API_URL}/config") res_data = res.json() system_state = res_data.get("data", {}) if res_data.get("status") == "success" else {} except Exception as e: st.error(f"Cannot connect to the backend server. Details: {str(e)}") return profiles = system_state.get("profiles", {}) active_run = system_state.get("active_run") if not profiles: st.warning("No clients are currently registered. Please register clients first.") return # 2. Select a Client client_ids = list(profiles.keys()) selected_client = st.selectbox("Select a Client to View/Manage", client_ids) st.markdown("---") client_info = profiles[selected_client] # ========================================== # CLIENT INFORMATION # ========================================== st.subheader(f"Profile: {selected_client}") col_info1, col_info2, col_info3 = st.columns(3) col_info1.metric("Labeled Data", client_info.get("labeled_count", 0)) col_info2.metric("Unlabeled Data", client_info.get("unlabeled_count", 0)) col_info3.metric("Batch Size", client_info.get("batch_size", "N/A")) with st.expander("View Raw Profile Data"): st.json(client_info) st.markdown("---") # ========================================== # MANUAL MODEL SUBMISSION FORM # ========================================== st.subheader("Manual Model Submission") if active_run and active_run.get("status") == "waiting_responses": st.success( f"Server is currently waiting for responses for Run: **{active_run.get('run_id')}** (Round {active_run.get('pending_round')})") with st.form(f"submit_model_form_{selected_client}"): st.write("Submit weights directly to the server for this client.") # Form Inputs num_samples = st.number_input( "Number of Samples (Weight for FedAvg)", min_value=1, value=int(client_info.get("labeled_count", 100)), help="The number of data samples used to train this model (used for weighting in FedAvg)." ) metrics_input = st.text_area( "Training Metrics (JSON format)", value='{\n "loss": 0.45,\n "accuracy": 0.82\n}', help="Provide metrics like loss and accuracy as a JSON string." ) uploaded_model = st.file_uploader("Upload Model Weights (.pt)", type=["pt", "bin", "pth"]) submit_btn = st.form_submit_button("Submit Model Update", type="primary", use_container_width=True) if submit_btn: # 1. Validate JSON metrics try: parsed_metrics = json.loads(metrics_input) valid_json_str = json.dumps(parsed_metrics) except json.JSONDecodeError: st.error("Invalid JSON format in Training Metrics.") st.stop() # 2. Check File if uploaded_model is None: st.error("Please upload a model weights file (.pt).") else: with st.spinner("Submitting model to server..."): try: # Prepare form data data = { "client_id": selected_client, "num_samples": num_samples, "metrics": valid_json_str } # Prepare file files = { "file": (uploaded_model.name, uploaded_model.getvalue(), "application/octet-stream") } # Send request submit_res = requests.post(f"{API_URL}/model", data=data, files=files) res_json = submit_res.json() if res_json.get("status") == "success": st.success(res_json.get("message")) else: st.error(f"Submission failed: {res_json.get('message')}") except Exception as e: st.error(f"Request error: {str(e)}") else: st.info( "Manual submission is disabled. The server is not currently waiting for client responses (no active pending round).") def render_global_server_panel(): """ Renders the Global / Server Panel. Displays active run details, global model status, and aggregation history. """ st.title("Global / Server Panel") st.info("Monitor the central orchestrator state, active run details, and aggregation history.") # 1. Fetch system state from backend try: res = requests.get(f"{API_URL}/config") res_data = res.json() system_state = res_data.get("data", {}) if res_data.get("status") == "success" else {} except Exception as e: st.error(f"Cannot connect to the backend server. Details: {str(e)}") return active_run = system_state.get("active_run") history_data = system_state.get("history", {}) # ========================================== # ACTIVE RUN DETAILS # ========================================== st.subheader("Current Active Run") if active_run: status_color = "" if active_run.get("status") not in ["failed", "stopped"] else "" st.markdown(f"**Run ID:** `{active_run.get('run_id', 'N/A')}`") col1, col2, col3, col4 = st.columns(4) col1.metric("Status", f"{status_color} {active_run.get('status', 'N/A').upper()}") col2.metric("Current Stage", active_run.get("current_stage", active_run.get("stage", "N/A")).upper()) col3.metric("Current Round", active_run.get("current_round", 0)) col4.metric("Pending Round", active_run.get("pending_round", "N/A")) with st.expander("View Raw Active Run State"): st.json(active_run) else: st.info("No active run is currently in progress. Start training in the 'Server Control' panel.") st.markdown("---") # ========================================== # GLOBAL MODEL DOWNLOAD # ========================================== st.subheader("Global Model Status") st.write("Fetch and download the latest aggregated global model weights (`.pt` file).") if st.button("Fetch Latest Global Model"): with st.spinner("Connecting to server to fetch model weights..."): try: model_res = requests.get(f"{API_URL}/model") content_type = model_res.headers.get('Content-Type', '') if 'application/json' in content_type: err_data = model_res.json() st.error(f"Cannot fetch model: {err_data.get('message', 'Unknown error')}") elif model_res.status_code == 200: st.success("Model weights fetched successfully! Click below to save to disk.") st.download_button( label="Save global_model.pt", data=model_res.content, file_name="global_model.pt", mime="application/octet-stream", type="primary" ) else: st.error(f"Unexpected error: HTTP {model_res.status_code}") except Exception as e: st.error(f"Request error: {str(e)}") st.markdown("---") # ========================================== # AGGREGATION HISTORY TABLE # ========================================== st.subheader("Aggregation History") # Read round records from history_data rounds_list = history_data.get("rounds", []) if rounds_list: try: df_history = pd.DataFrame(rounds_list) # Hide client_metrics because long JSON blobs reduce table readability. if "client_metrics" in df_history.columns: df_history = df_history.drop(columns=["client_metrics"]) # Keep 'round' as the first column. cols = ['round'] + [c for c in df_history.columns if c != 'round'] df_history = df_history[cols] st.dataframe(df_history, use_container_width=True, hide_index=True) except Exception as e: st.error(f"Could not parse history table: {e}") else: st.info("No aggregation history available yet. The table will populate once rounds are completed.") def render_charts(): st.title("Training Performance Analytics") try: # Fetch chart data from API res = requests.get(f"{API_URL}/charts") history_data = res.json().get("data", {}) except Exception as e: st.error(f"Cannot connect to backend: {e}") return # Read list of rounds from 'rounds' key rounds_list = history_data.get("rounds", []) if not rounds_list: st.warning("No training rounds recorded yet. Start training to see results.") return # 1. Convert list to DataFrame df = pd.DataFrame(rounds_list) # Use round as chart X-axis index df.set_index("round", inplace=True) # 2. Tabs tab1, tab2, tab3 = st.tabs(["Global Accuracy", "Local Loss (Clients)", "Global F1 Score"]) # ========================================== # TAB 1: GLOBAL ACCURACY # ========================================== with tab1: st.subheader("Global Accuracy (Validation vs Test)") # Check columns before plotting to avoid missing-key errors cols_to_plot = [c for c in ["val_accuracy", "test_accuracy"] if c in df.columns] if cols_to_plot: st.line_chart(df[cols_to_plot]) else: st.info("Accuracy data is not available yet.") # ========================================== # TAB 2: LOCAL LOSS (multi-client) # ========================================== with tab2: st.subheader("Local Training Loss") all_clients = set() for r in rounds_list: c_metrics = r.get("client_metrics", {}) if isinstance(c_metrics, dict): all_clients.update(c_metrics.keys()) all_clients = sorted(list(all_clients)) if not all_clients: st.info("No client metrics found in history.") else: # Select client to inspect selected_client = st.selectbox("Select Client to inspect:", all_clients) # Extract selected client loss per round loss_data = {} for r_idx, row in df.iterrows(): c_metrics = row.get("client_metrics", {}).get(selected_client, {}) if c_metrics: p_loss = c_metrics.get("pretrain_loss", 0) s_loss = c_metrics.get("ssl_labeled_loss", 0) # Convert 0 to None so missing phase points are not connected loss_data[r_idx] = { "Pretrain Loss": p_loss if p_loss > 0 else None, "SSL Loss": s_loss if s_loss > 0 else None } if loss_data: df_loss = pd.DataFrame.from_dict(loss_data, orient='index') # Render as two separate charts st.markdown("#### Phase 1: Pre-training Loss") st.caption("Phase where the model learns from local labeled data.") # Plot only Pretrain Loss column st.line_chart(df_loss[["Pretrain Loss"]]) st.markdown("#### Phase 2: SSL Labeled Loss") st.caption("Semi-supervised learning phase.") # Plot only SSL Loss column st.line_chart(df_loss[["SSL Loss"]]) else: st.warning(f"No loss data available for {selected_client} in the selected rounds.") # ========================================== # TAB 3: GLOBAL F1 SCORE # ========================================== with tab3: st.subheader("Global Weighted F1 Score") cols_to_plot_f1 = [c for c in ["val_WF1", "test_WF1"] if c in df.columns] if cols_to_plot_f1: st.line_chart(df[cols_to_plot_f1]) else: st.info("F1 Score data is not available yet.") # ========================================== # 3. Show detailed raw metrics table # ========================================== st.markdown("---") with st.expander("View Full Raw Metrics Table"): # Copy DataFrame and hide verbose client_metrics column for readability df_display = df.copy() if "client_metrics" in df_display.columns: df_display.drop(columns=["client_metrics"], inplace=True) st.dataframe(df_display, use_container_width=True) def render_inference_demo(): st.header("Emotion Recognition Inference") st.markdown("Upload an audio file to predict emotion using the trained global AI model.") col_config, col_main = st.columns([1, 2], gap="large") with col_config: st.subheader("Model Configuration") with st.container(border=True): fl_method = st.selectbox("FL Method", ["FedAvg", "FedProx"]) # Available label-ratio tokens used by checkpoints. label_ratio = st.selectbox("Label Ratio", ["0p3", "0p5", "0p7", "1"], index=3) num_clients = st.number_input("Num Clients", min_value=1, max_value=100, value=5) model_name = st.selectbox("Model Architecture", ["fedalmer", "memocmt", "fleser", "threemser", "cemobam"]) device_pref = st.radio("Processing Device", ["cpu", "cuda"], horizontal=True) with col_main: st.subheader("Audio Upload") uploaded_file = st.file_uploader("Drag and drop an audio file here (.wav, .mp3, .flac)", type=["wav", "mp3", "flac"]) # Preview player appears only after upload. if uploaded_file is not None: st.audio(uploaded_file, format="audio/wav") st.markdown("
", unsafe_allow_html=True) if st.button("Analyze and Predict Emotion", type="primary", use_container_width=True): if uploaded_file is None: st.warning("Please upload an audio file before running prediction.") else: with st.spinner("Analyzing audio... (first model load may take longer)"): try: # 1. Build form payload. data_payload = { "fl_method": fl_method, "label_ratio_token": label_ratio, "num_clients": num_clients, "model_name": model_name, "device_pref": device_pref } # 2. Build file payload. files_payload = { "file": (uploaded_file.name, uploaded_file.getvalue(), "audio/wav") } # 3. Send request to FastAPI. res = requests.post(f"{API_URL}/predict", data=data_payload, files=files_payload, timeout=30.00) # 4. Handle response. if res.status_code == 200: res_json = res.json() if res_json.get("status") == "success": result = res_json.get("data", {}) st.success("Prediction completed.") st.divider() # Show result metrics. m1, m2 = st.columns(2) emotion_str = str(result.get('emotion', 'Unknown')).title() confidence = result.get('confidence', 0) m1.metric(label="Predicted Emotion", value=emotion_str) m2.metric(label="Confidence", value=f"{confidence:.2%}") with st.expander("ASR Transcript", expanded=True): st.write(f"*{result.get('transcript', 'Speech not recognized')}*") with st.expander("Probability Distribution"): probs = result.get("probabilities", {}) if probs: # Render simple probability bar chart. import pandas as pd df_probs = pd.DataFrame(list(probs.items()), columns=["Emotion", "Probability"]) df_probs.set_index("Emotion", inplace=True) st.bar_chart(df_probs) else: st.json(probs) with st.expander("System Metadata"): st.json({ "Loaded Checkpoint": result.get("checkpoint"), "Audio Duration": f"{result.get('audio_seconds', 0):.2f} seconds", "Model": result.get("model_name") }) else: st.error(f"Server Error: {res_json.get('message')}") else: st.error(f"HTTP Error {res.status_code}: {res.text}") except Exception as e: st.error(f"System connection error: {str(e)}") if choice == "Server Control": # st.title("Server Control") render_server_controls() elif choice == "Server Control Manual": render_server_controls_manual() elif choice == "Server Validation/Test Upload": # st.title("Server Validation/Test Upload") # st.info("Module to upload and manage the .pkl files for validation and testing.") render_server_validation_test_upload() elif choice == "Dataset / Client Summary": # st.title("Dataset / Client Summary") # st.info("Module to view registered clients, data distribution, and profiles.") # Content will be implemented in the next step render_dataset_client_summary() # elif choice == "Client Panels": # # st.title("Client Panels") # # st.info("Module to monitor individual client states and submit models manually if needed.") # # Content will be implemented in the next step # render_client_panels() elif choice == "Global / Server Panel": # st.title("Global / Server Panel") # st.info("Module to view current global model status, active run state, and aggregation details.") # Content will be implemented in the next step render_global_server_panel() elif choice == "Charts": render_charts() # elif choice == "Inference Demo": # render_inference_demo()