import json
import math
import os
import time
from typing import Optional

import pandas as pd
import requests
import streamlit as st

API_URL = os.getenv("API_URL", "http://127.0.0.1:8001/api")

# Per-request timeout (seconds) for backend calls. Without a timeout, a stalled
# backend blocks the Streamlit script run indefinitely. Override with the
# API_TIMEOUT environment variable.
REQUEST_TIMEOUT = float(os.getenv("API_TIMEOUT", "60"))

# Maps raw class labels (string codes, integer ids, or stringified ids) to the
# names shown in the UI. Several raw labels share one display name
# ("hap"/"exc", 0/"0", ...), so downstream aggregation must tolerate duplicates.
LABEL_DISPLAY_MAP = {
    "ang": "anger",
    "hap": "happiness/excitement",
    "sad": "sadness",
    "neu": "neutral",
    "exc": "happiness/excitement",
    0: "anger",
    1: "happiness/excitement",
    2: "sadness",
    3: "neutral",
    "0": "anger",
    "1": "happiness/excitement",
    "2": "sadness",
    "3": "neutral",
}


def _map_label_display(label):
    """Return the display name for a raw class *label*.

    Looks the label up as-is first, then by its string form; unknown labels
    are returned as their string form.
    """
    try:
        if label in LABEL_DISPLAY_MAP:
            return LABEL_DISPLAY_MAP[label]
    except TypeError:
        # Unhashable labels (e.g. a list from a malformed payload) cannot be
        # dict keys; fall through to the string lookup instead of crashing.
        pass
    key = str(label)
    return LABEL_DISPLAY_MAP.get(key, key)


# ==========================================
# 1. PAGE SETTINGS
# ==========================================
st.set_page_config(
    page_title="Federated Learning Admin",
    layout="wide",
    page_icon="settings",
)

# ==========================================
# 2. SIDEBAR (GLOBAL ACTIONS)
# ==========================================
st.sidebar.title("Server")
st.sidebar.markdown("---")
if st.sidebar.button("Refresh Application"):
    st.rerun()


# ==========================================
# 3. MAIN CONTENT ROUTING
# ==========================================
def simple_block_header(text):
    """Render *text* as a section header through ``st.markdown``.

    *text* is interpolated without escaping while ``unsafe_allow_html=True``
    is set, so only pass trusted, code-defined strings.

    NOTE(review): the literal below contains no HTML tags even though
    ``unsafe_allow_html=True`` is set; the header markup may have been lost
    from this file — confirm against the original.
    """
    st.markdown(
        f"""

{text}

""",
        unsafe_allow_html=True,
    )


# Page-wide HTML injection point (rendered once per script run).
# NOTE(review): the literal is effectively empty; presumably a stylesheet was
# meant to live here — confirm.
st.markdown(
    """ """,
    unsafe_allow_html=True,
)


def _parse_int_field(name: str, raw: str, min_value: Optional[int] = None) -> int:
    """Parse the text-input value *raw* as an integer.

    Args:
        name: Field label used in the error message.
        raw: Raw text from the form (surrounding whitespace is ignored).
        min_value: Optional inclusive lower bound.

    Raises:
        ValueError: with a user-facing message if *raw* is not an integer or
            is below *min_value*.
    """
    try:
        value = int(str(raw).strip())
    except (TypeError, ValueError):
        raise ValueError(f"{name} must be an integer.") from None
    if min_value is not None and value < min_value:
        raise ValueError(f"{name} must be >= {min_value}.")
    return value


def _parse_float_field(
    name: str,
    raw: str,
    min_value: Optional[float] = None,
    max_value: Optional[float] = None,
) -> float:
    """Parse the text-input value *raw* as a finite float.

    Args:
        name: Field label used in the error message.
        raw: Raw text from the form (surrounding whitespace is ignored).
        min_value: Optional inclusive lower bound.
        max_value: Optional inclusive upper bound.

    Raises:
        ValueError: with a user-facing message if *raw* is not a finite number
            or falls outside the bounds.
    """
    try:
        value = float(str(raw).strip())
    except (TypeError, ValueError):
        raise ValueError(f"{name} must be a number.") from None
    # float() accepts "nan"/"inf": NaN passes every bound check below, and
    # non-finite values serialize to invalid JSON (NaN/Infinity) in the POST.
    if not math.isfinite(value):
        raise ValueError(f"{name} must be a finite number.")
    if min_value is not None and value < min_value:
        raise ValueError(f"{name} must be >= {min_value}.")
    if max_value is not None and value > max_value:
        raise ValueError(f"{name} must be <= {max_value}.")
    return value


def _upper_or(value, default="N/A"):
    """Return ``str(value).upper()``, substituting *default* when *value* is None.

    Guards backend state fields that may be missing or null, on which calling
    ``.upper()`` directly would raise ``AttributeError``.
    """
    return str(default if value is None else value).upper()


def _positive_or_none(value):
    """Return *value* if it compares greater than zero, otherwise ``None``.

    Zero, ``None`` and non-comparable values map to ``None`` so line charts
    leave a gap rather than plotting a phase that did not run.
    """
    try:
        return value if value > 0 else None
    except TypeError:
        return None


def _setup_text(setup, key, default, cast):
    """Return ``str(cast(setup[key]))`` for a form default, else ``str(cast(default))``.

    Falls back to *default* when the stored value is null or cannot be
    converted, so a bad backend value does not break the form render.
    """
    try:
        return str(cast(setup.get(key, default)))
    except (TypeError, ValueError, OverflowError):
        return str(cast(default))


def _fetch_system_state():
    """Return the ``data`` payload of ``GET /config``, or ``{}`` if not successful.

    Network and JSON-decoding errors propagate so each caller can show its own
    connection error message.
    """
    res = requests.get(f"{API_URL}/config", timeout=REQUEST_TIMEOUT)
    res_data = res.json()
    if res_data.get("status") != "success":
        return {}
    return res_data.get("data") or {}


def _fetch_profiles():
    """Fetch registered client profiles from ``GET /profiles``.

    Returns:
        ``(profiles, error)``: the ``data`` list and ``None`` on success;
        otherwise ``[]`` and a message suitable for display.
    """
    try:
        res_profiles = requests.get(f"{API_URL}/profiles", timeout=REQUEST_TIMEOUT)
        profiles_json = res_profiles.json() if res_profiles.status_code == 200 else {}
        if profiles_json.get("status") == "success":
            return profiles_json.get("data") or [], None
        return [], profiles_json.get("message", "Cannot fetch client profiles.")
    except Exception as e:
        return [], str(e)


def _render_registered_profiles(profiles_data, profiles_error):
    """Render the "Registered Client Profiles" subsection, one column per client."""
    st.subheader("Registered Client Profiles")
    if profiles_error:
        st.warning(f"Profiles API is not available right now: {profiles_error}")
    elif not profiles_data:
        st.info("No client profiles found. Start client scripts to register profile metadata.")
    else:
        client_ids = [str(item.get("client_id", "N/A")) for item in profiles_data]
        st.write("**Detailed Status:**")
        cols = st.columns(len(client_ids))
        for col, cid in zip(cols, client_ids):
            with col:
                st.success(f"Registered: {cid}")


def render_global_configuration():
    """Render the shared server-configuration form and POST edits to ``/config``.

    Current values are read from ``GET /config`` (``data.current_setup``); if
    that fetch fails a warning is shown and built-in defaults are used.
    Inputs are validated client-side before being sent.
    """
    st.subheader("Global Configuration")

    current_setup = {}
    try:
        current_setup = _fetch_system_state().get("current_setup") or {}
    except Exception as e:
        st.warning(f"Cannot fetch global configuration: {e}")

    with st.form("server_config_form_shared"):
        col_cfg1, col_cfg2 = st.columns(2)
        with col_cfg1:
            rounds_pretrain_raw = st.text_input(
                "Rounds Pretrain", value=_setup_text(current_setup, "rounds_pretrain", 1, int)
            )
            rounds_ssl_raw = st.text_input(
                "Rounds SSL", value=_setup_text(current_setup, "rounds_ssl", 1, int)
            )
            tau_raw = st.text_input(
                "Tau (Threshold)", value=_setup_text(current_setup, "tau", 0.5, float)
            )
            ssl_weight_raw = st.text_input(
                "SSL Weight", value=_setup_text(current_setup, "ssl_weight", 1.0, float)
            )
        with col_cfg2:
            seed_raw = st.text_input("Random Seed", value=_setup_text(current_setup, "seed", 42, int))
            agg_options = ["fedavg", "fedprox"]
            current_agg = current_setup.get("aggregator", "fedavg")
            agg_index = agg_options.index(current_agg) if current_agg in agg_options else 0
            aggregator = st.selectbox("Aggregator", agg_options, index=agg_index)
            fedprox_mu_raw = st.text_input(
                "FedProx Mu", value=_setup_text(current_setup, "fedprox_mu", 0.01, float)
            )
            weight_options = ["num_samples", "equal"]
            current_weight = current_setup.get("weight_by", "num_samples")
            weight_index = weight_options.index(current_weight) if current_weight in weight_options else 0
            weight_by = st.selectbox("Weight By", weight_options, index=weight_index)
        submit_cfg = st.form_submit_button("Save Configuration", use_container_width=True)

    if not submit_cfg:
        return

    try:
        rounds_pretrain = _parse_int_field("Rounds Pretrain", rounds_pretrain_raw, min_value=1)
        rounds_ssl = _parse_int_field("Rounds SSL", rounds_ssl_raw, min_value=0)
        tau = _parse_float_field("Tau (Threshold)", tau_raw, min_value=0.0, max_value=1.0)
        ssl_weight = _parse_float_field("SSL Weight", ssl_weight_raw, min_value=0.0)
        seed = _parse_int_field("Random Seed", seed_raw)
        fedprox_mu = _parse_float_field("FedProx Mu", fedprox_mu_raw, min_value=0.0)
    except ValueError as ve:
        st.error(str(ve))
        return

    payload = {
        "rounds_pretrain": rounds_pretrain,
        "rounds_ssl": rounds_ssl,
        "tau": tau,
        "ssl_weight": ssl_weight,
        "seed": seed,
        "aggregator": aggregator,
        "fedprox_mu": fedprox_mu,
        "weight_by": weight_by,
    }
    with st.spinner("Saving configuration..."):
        try:
            cfg_res = requests.post(f"{API_URL}/config", json=payload, timeout=REQUEST_TIMEOUT)
            cfg_data = cfg_res.json()
            if cfg_data.get("status") == "success":
                st.success("Configuration updated successfully.")
                st.rerun()
            else:
                st.error(f"Update failed: {cfg_data.get('message')}")
        except Exception as e:
            st.error(f"Request failed: {str(e)}")


def render_server_controls():
    """
    Renders the automatic Server Control panel.

    Provides Start/Stop buttons (``POST``/``DELETE /train``), the active-run
    status line and the list of registered client profiles. Global
    configuration is edited separately in ``render_global_configuration``.
    """
    st.info("Manage training sessions.")

    # 1. Fetch current system state from backend.
    try:
        system_state = _fetch_system_state()
    except Exception as e:
        st.error(f"Cannot connect to the backend server. Is it running? Details: {str(e)}")
        return

    profiles_data, profiles_error = _fetch_profiles()
    active_run = system_state.get("active_run")

    # ==========================================
    # EXECUTION CONTROLS
    # ==========================================
    st.subheader("Execution Controls")
    col1, col2 = st.columns(2)

    with col1:
        if st.button("Start Training", type="primary", use_container_width=True):
            with st.spinner("Initializing training session..."):
                try:
                    train_res = requests.post(f"{API_URL}/train", timeout=REQUEST_TIMEOUT)
                    data = train_res.json()
                    if data.get("status") == "success":
                        st.success(data.get("message", "Training started successfully."))
                        st.rerun()
                    else:
                        st.error(f"Failed to start: {data.get('message')}")
                except Exception as e:
                    st.error(f"Request error: {str(e)}")

    with col2:
        if st.button("Stop Training", type="secondary", use_container_width=True):
            with st.spinner("Stopping current session..."):
                try:
                    stop_res = requests.delete(f"{API_URL}/train", timeout=REQUEST_TIMEOUT)
                    data = stop_res.json()
                    if data.get("status") == "success":
                        st.warning(data.get("message", "Training stopped."))
                        st.rerun()
                    else:
                        st.error(f"Failed to stop: {data.get('message')}")
                except Exception as e:
                    st.error(f"Request error: {str(e)}")

    # A run counts as active unless it has reached a terminal status. A missing
    # status is treated as active and shown as "N/A" instead of crashing.
    if active_run and active_run.get("status") not in ["completed", "failed"]:
        st.success(
            f"**Active Session:** {active_run.get('run_id')} "
            f"**Status:** {_upper_or(active_run.get('status'))}"
        )
    else:
        st.info("No active training session currently running.")

    _render_registered_profiles(profiles_data, profiles_error)


def _render_submission_status(sub_status):
    """Render submission progress, per-client status and the aggregate action.

    *sub_status* is the ``data`` payload of ``GET /submission_status``.
    Clicking the enabled aggregate button calls ``POST /aggregate_manual``.
    """
    curr = sub_status.get("current_count") or 0
    total = sub_status.get("total_required") or 0
    progress_str = sub_status.get("progress", "0/0")
    is_ready = sub_status.get("is_ready", False)
    all_ids = sub_status.get("all_client_ids") or []
    submitted_ids = sub_status.get("submitted_clients") or []

    # --- Progress dashboard ---
    m_col1, m_col2 = st.columns([1, 3])
    with m_col1:
        st.metric(label="Clients Ready", value=progress_str)
    with m_col2:
        st.write(f"**Pending round:** {sub_status.get('pending_round')}")
        # st.progress only accepts [0, 1]; clamp in case more submissions than
        # required are reported.
        st.progress(min(curr / total, 1.0) if total > 0 else 0)

    st.write("**Detailed Status:**")
    if all_ids:
        cols = st.columns(len(all_ids))
        for col, cid in zip(cols, all_ids):
            with col:
                if cid in submitted_ids:
                    st.success(f"Submitted: {cid}")
                else:
                    st.error(f"Missing: {cid}")

    st.markdown("---")

    # ==========================================
    # Aggregation action area
    # ==========================================
    with st.container():
        if is_ready:
            st.success(f"Ready: received all {progress_str} submissions.")
            if st.button("Aggregate and Start Next Round", type="primary", use_container_width=True):
                with st.spinner("Running aggregation..."):
                    try:
                        agg_res = requests.post(f"{API_URL}/aggregate_manual", timeout=REQUEST_TIMEOUT)
                        agg_data = agg_res.json()
                        if agg_data.get("status") == "success":
                            st.success(f"{agg_data.get('message')}")
                            # Short pause so the success message is visible
                            # before the rerun clears it.
                            time.sleep(1.5)
                            st.rerun()
                        else:
                            st.error(f"Error: {agg_data.get('message')}")
                    except Exception as e:
                        st.error(str(e))
        else:
            st.warning(f"Waiting for {total - curr} more client submissions...")
            st.button(
                "Aggregate Submitted Models (Waiting...)",
                type="secondary",
                use_container_width=True,
                disabled=True,
            )


def render_server_controls_manual():
    """
    Renders the Manual Server Control panel.

    Allows step-by-step orchestration: broadcast a round
    (``POST /train_manual``), stop the session (``DELETE /train_manual``),
    monitor client submissions and trigger aggregation manually.
    """
    st.info("Human-in-the-loop: start rounds, monitor progress, and aggregate models manually.")

    # The state itself is not used below; the request doubles as a backend
    # connectivity check so the panel bails out early when the server is down.
    try:
        _fetch_system_state()
    except Exception as e:
        st.error(f"Backend connection error: {str(e)}")
        return

    profiles_data, profiles_error = _fetch_profiles()

    # ==========================================
    # 1. EXECUTION CONTROLS (Manual Start/Stop)
    # ==========================================
    st.subheader("Orchestration")
    col1, col2 = st.columns(2)

    with col1:
        # Start manual mode and broadcast NEW_ROUND_MANUAL.
        if st.button("Start / Broadcast Next Round", type="primary", use_container_width=True):
            with st.spinner("Broadcasting command..."):
                try:
                    res = requests.post(f"{API_URL}/train_manual", timeout=REQUEST_TIMEOUT)
                    data = res.json()
                    if data.get("status") == "success":
                        st.success(data.get("message"))
                        st.rerun()
                    else:
                        st.error(f"Error: {data.get('message')}")
                except Exception as e:
                    st.error(str(e))

    with col2:
        # Stop manual mode.
        if st.button("Stop Manual Session", type="secondary", use_container_width=True):
            with st.spinner("Stopping..."):
                try:
                    res = requests.delete(f"{API_URL}/train_manual", timeout=REQUEST_TIMEOUT)
                    data = res.json()
                    if data.get("status") == "success":
                        st.warning("Manual session stopped.")
                        st.rerun()
                    else:
                        st.error(f"Failed to stop: {data.get('message')}")
                except Exception as e:
                    st.error(str(e))

    st.markdown("---")
    _render_registered_profiles(profiles_data, profiles_error)
    st.markdown("---")

    # ==========================================
    # 2. SUBMISSION MONITORING
    # ==========================================
    st.subheader("Submission Monitor")
    try:
        sub_res = requests.get(f"{API_URL}/submission_status", timeout=REQUEST_TIMEOUT)
        res_json = sub_res.json()
        if res_json.get("status") == "error":
            # If no run is active, the backend returns an error state.
            st.warning(f"Warning: {res_json.get('message')}")
            st.info("Click 'Start / Broadcast Next Round' to begin a new session.")
        else:
            _render_submission_status(res_json.get("data") or {})
    except Exception as e:
        st.error(f"Submission monitor connection error: {e}")

    st.markdown("---")


def _delete_and_rerun(path, toast_message):
    """Send ``DELETE {API_URL}{path}``; toast and rerun on success, else show the error."""
    try:
        del_res = requests.delete(f"{API_URL}{path}", timeout=REQUEST_TIMEOUT)
        del_json = del_res.json()
        if del_json.get("status") == "success":
            st.toast(toast_message)
            st.rerun()
        else:
            st.error(del_json.get("message"))
    except Exception as e:
        st.error(str(e))


def render_server_validation_test_upload():
    """
    Renders the Server Validation/Test Upload panel.

    Shows whether the ``valid``/``test`` datasets exist (from
    ``data.upload_status`` of ``GET /config``), uploads a ``.pkl`` file to
    ``POST /data/{split}`` and deletes datasets via ``DELETE /data[/{split}]``.
    """
    simple_block_header("Server Validation/Test Upload")
    st.info("Upload `.pkl` datasets for server-side global model evaluation.")

    # 1. Fetch current system state to get upload_status.
    try:
        system_state = _fetch_system_state()
    except Exception as e:
        st.error(f"Cannot connect to the backend server. Details: {str(e)}")
        return

    upload_status = system_state.get("upload_status") or {}
    valid_exists = (upload_status.get("valid") or {}).get("uploaded", False)
    test_exists = (upload_status.get("test") or {}).get("uploaded", False)

    # ==========================================
    # DATASET STATUS CARDS
    # ==========================================
    st.subheader("Current Data Status")
    col_stat1, col_stat2 = st.columns(2)
    with col_stat1:
        if valid_exists:
            st.success("**Validation Dataset:** Available")
        else:
            st.error("**Validation Dataset:** Missing")
    with col_stat2:
        if test_exists:
            st.success("**Test Dataset:** Available")
        else:
            st.error("**Test Dataset:** Missing")

    st.markdown("---")

    # ==========================================
    # UPLOAD FORM
    # ==========================================
    st.subheader("Upload Dataset")
    split_type = st.selectbox(
        "Select Split Type",
        ["valid", "test"],
        help="Choose whether this data is for validation or testing.",
    )
    uploaded_file = st.file_uploader(f"Choose a .pkl file for the '{split_type}' split", type=["pkl"])

    if st.button("Upload File", type="primary"):
        if uploaded_file is None:
            st.warning("Please select a file to upload first.")
        else:
            with st.spinner(f"Uploading {uploaded_file.name}..."):
                try:
                    # multipart/form-data payload for the backend upload endpoint.
                    # NOTE(review): .pkl files are presumably unpickled by the
                    # backend; only upload files from trusted sources.
                    files = {"file": (uploaded_file.name, uploaded_file.getvalue(), "application/octet-stream")}
                    upload_res = requests.post(
                        f"{API_URL}/data/{split_type}", files=files, timeout=REQUEST_TIMEOUT
                    )
                    data = upload_res.json()
                    if data.get("status") == "success":
                        st.success(data.get("message", "File uploaded successfully."))
                        st.rerun()
                    else:
                        st.error(f"Upload failed: {data.get('message')}")
                except Exception as e:
                    st.error(f"Request error: {str(e)}")

    st.markdown("---")

    # ==========================================
    # DANGER ZONE (DELETION)
    # ==========================================
    st.subheader("Danger Zone")
    del_col1, del_col2, del_col3 = st.columns(3)
    with del_col1:
        if st.button("Delete Valid Data", disabled=not valid_exists, use_container_width=True):
            _delete_and_rerun("/data/valid", "Validation data deleted.")
    with del_col2:
        if st.button("Delete Test Data", disabled=not test_exists, use_container_width=True):
            _delete_and_rerun("/data/test", "Test data deleted.")
    with del_col3:
        if st.button(
            "Clear ALL Data",
            type="primary",
            disabled=not (valid_exists or test_exists),
            use_container_width=True,
        ):
            _delete_and_rerun("/data", "All server data cleared.")


def render_dataset_client_summary():
    """
    Renders the Dataset and Client Summary panel.

    Reads ``GET /profiles`` and shows totals, a per-client table, a
    labeled/unlabeled volume chart and a per-class label distribution chart.
    """
    simple_block_header("Dataset / Client Summary")
    st.info("Overview of registered clients, their data volume, and label distribution.")

    # 1. Fetch profiles (a list of per-client profile objects).
    try:
        res_profiles = requests.get(f"{API_URL}/profiles", timeout=REQUEST_TIMEOUT)
        profiles_data = (res_profiles.json().get("data") or []) if res_profiles.status_code == 200 else []
    except Exception as e:
        st.error(f"Cannot connect to the backend server. Details: {str(e)}")
        return

    if not profiles_data:
        st.warning("No clients have registered their profiles yet. Start your client scripts to register them.")
        return

    # 2. Build table rows, totals and long-form class-count rows.
    total_clients = len(profiles_data)
    total_labeled = 0
    total_unlabeled = 0
    profile_list = []
    class_hist_list = []

    for info in profiles_data:
        cid = info.get("client_id")
        # Null counts are treated as 0 so the totals cannot raise TypeError.
        labeled = info.get("labeled_count") or 0
        unlabeled = info.get("unlabeled_count") or 0
        total_labeled += labeled
        total_unlabeled += unlabeled

        profile_list.append({
            "Client ID": cid,
            "Labeled Data": labeled,
            "Unlabeled Data": unlabeled,
            "Feature Dim": info.get("feature_dim", "N/A"),
            "Model Name": info.get("model_name", "N/A"),
            "Batch Size": info.get("batch_size", "N/A"),
        })

        for label, count in (info.get("class_hist") or {}).items():
            class_hist_list.append({
                "Client": cid,
                "Class": _map_label_display(label),
                "Count": count,
            })

    # ==========================================
    # OVERALL METRICS
    # ==========================================
    st.subheader("Global Metrics")
    col_m1, col_m2, col_m3 = st.columns(3)
    col_m1.metric("Registered Clients", total_clients)
    col_m2.metric("Total Labeled Samples", total_labeled)
    col_m3.metric("Total Unlabeled Samples", total_unlabeled)

    st.markdown("---")

    # ==========================================
    # CLIENT PROFILE TABLE
    # ==========================================
    st.subheader("Client Data Overview")
    df_profiles = pd.DataFrame(profile_list)
    st.dataframe(df_profiles, use_container_width=True, hide_index=True)

    st.markdown("---")

    # ==========================================
    # DATA DISTRIBUTION CHARTS
    # ==========================================
    st.subheader("Data Distribution Analysis")
    col_chart1, col_chart2 = st.columns(2)

    with col_chart1:
        st.markdown("**Data Volume per Client**")
        if not df_profiles.empty:
            # Labeled vs unlabeled counts side by side per client.
            df_volume = df_profiles.set_index("Client ID")[["Labeled Data", "Unlabeled Data"]]
            st.bar_chart(df_volume)

    with col_chart2:
        st.markdown("**Class Label Distribution**")
        if class_hist_list:
            df_hist = pd.DataFrame(class_hist_list)
            # String classes keep the chart x-axis categorical.
            df_hist["Class"] = df_hist["Class"].astype(str)
            # pivot_table(sum) rather than pivot(): several raw labels map to
            # the same display name (e.g. "hap" and "exc"), and pivot() raises
            # ValueError on duplicate (Class, Client) pairs.
            df_pivot = df_hist.pivot_table(
                index="Class", columns="Client", values="Count", aggfunc="sum", fill_value=0
            )
            st.bar_chart(df_pivot)
        else:
            st.info("No class history (class_hist) provided by clients.")

    st.markdown("---")
    if st.button("Clear All Profiles", type="secondary"):
        # TODO: consider a confirmation step before this destructive action.
        try:
            res = requests.delete(f"{API_URL}/profiles", timeout=REQUEST_TIMEOUT)
            if res.status_code == 200:
                st.success("All client profiles have been cleared.")
                st.rerun()
            else:
                st.error("Cannot clear profiles.")
        except Exception as e:
            st.error(f"Error: {e}")


def render_client_panels():
    """
    Renders the Client Panels section.

    Lets an operator inspect one registered client's profile and, while the
    active run's status is ``waiting_responses``, upload model weights to
    ``POST /model`` on that client's behalf.

    NOTE(review): profiles are read here from ``GET /config`` as a dict keyed
    by client id, whereas other panels read a list from ``GET /profiles``;
    confirm the backend still exposes this shape.
    """
    simple_block_header("Client Panels")
    st.info("Monitor specific clients and manually submit model weights (.pt) on their behalf.")

    # 1. Fetch system state to get registered clients.
    try:
        system_state = _fetch_system_state()
    except Exception as e:
        st.error(f"Cannot connect to the backend server. Details: {str(e)}")
        return

    profiles = system_state.get("profiles") or {}
    active_run = system_state.get("active_run")

    if not profiles:
        st.warning("No clients are currently registered. Please register clients first.")
        return

    # 2. Select a client.
    client_ids = list(profiles.keys())
    selected_client = st.selectbox("Select a Client to View/Manage", client_ids)
    st.markdown("---")

    client_info = profiles[selected_client] or {}

    # ==========================================
    # CLIENT INFORMATION
    # ==========================================
    st.subheader(f"Profile: {selected_client}")
    col_info1, col_info2, col_info3 = st.columns(3)
    col_info1.metric("Labeled Data", client_info.get("labeled_count", 0))
    col_info2.metric("Unlabeled Data", client_info.get("unlabeled_count", 0))
    col_info3.metric("Batch Size", client_info.get("batch_size", "N/A"))

    with st.expander("View Raw Profile Data"):
        st.json(client_info)

    st.markdown("---")

    # ==========================================
    # MANUAL MODEL SUBMISSION FORM
    # ==========================================
    st.subheader("Manual Model Submission")

    if not (active_run and active_run.get("status") == "waiting_responses"):
        st.info(
            "Manual submission is disabled. The server is not currently waiting for client responses (no active pending round).")
        return

    st.success(
        f"Server is currently waiting for responses for Run: **{active_run.get('run_id')}** (Round {active_run.get('pending_round')})")

    try:
        default_samples = int(client_info.get("labeled_count", 100))
    except (TypeError, ValueError):
        default_samples = 100
    # number_input raises if value < min_value, so a client reporting 0
    # labeled samples would otherwise break the whole panel.
    default_samples = max(1, default_samples)

    with st.form(f"submit_model_form_{selected_client}"):
        st.write("Submit weights directly to the server for this client.")

        num_samples = st.number_input(
            "Number of Samples (Weight for FedAvg)",
            min_value=1,
            value=default_samples,
            help="The number of data samples used to train this model (used for weighting in FedAvg).",
        )
        metrics_input = st.text_area(
            "Training Metrics (JSON format)",
            value='{\n "loss": 0.45,\n "accuracy": 0.82\n}',
            help="Provide metrics like loss and accuracy as a JSON string.",
        )
        uploaded_model = st.file_uploader("Upload Model Weights (.pt)", type=["pt", "bin", "pth"])
        submit_btn = st.form_submit_button("Submit Model Update", type="primary", use_container_width=True)

    if not submit_btn:
        return

    # 1. Validate JSON metrics (re-serialized so the server gets canonical JSON).
    try:
        parsed_metrics = json.loads(metrics_input)
        valid_json_str = json.dumps(parsed_metrics)
    except json.JSONDecodeError:
        st.error("Invalid JSON format in Training Metrics.")
        st.stop()

    # 2. Check file.
    if uploaded_model is None:
        st.error("Please upload a model weights file (.pt).")
        return

    with st.spinner("Submitting model to server..."):
        try:
            data = {
                "client_id": selected_client,
                "num_samples": num_samples,
                "metrics": valid_json_str,
            }
            files = {
                "file": (uploaded_model.name, uploaded_model.getvalue(), "application/octet-stream")
            }
            submit_res = requests.post(f"{API_URL}/model", data=data, files=files, timeout=REQUEST_TIMEOUT)
            res_json = submit_res.json()
            if res_json.get("status") == "success":
                st.success(res_json.get("message"))
            else:
                st.error(f"Submission failed: {res_json.get('message')}")
        except Exception as e:
            st.error(f"Request error: {str(e)}")


def render_global_server_panel():
    """
    Renders the Global / Server Panel.

    Shows the active run (from ``GET /config``), offers a download of the
    latest global model (``GET /model``) and tabulates ``history.rounds``.
    """
    simple_block_header("Global / Server Panel")
    st.info("Monitor the central orchestrator state, active run details, and aggregation history.")

    # 1. Fetch system state from backend.
    try:
        system_state = _fetch_system_state()
    except Exception as e:
        st.error(f"Cannot connect to the backend server. Details: {str(e)}")
        return

    active_run = system_state.get("active_run")
    history_data = system_state.get("history") or {}

    # ==========================================
    # ACTIVE RUN DETAILS
    # ==========================================
    st.subheader("Current Active Run")
    if active_run:
        st.markdown(f"**Run ID:** `{active_run.get('run_id', 'N/A')}`")
        col1, col2, col3, col4 = st.columns(4)
        col1.metric("Status", _upper_or(active_run.get("status")))
        # Older payloads may use "stage" instead of "current_stage".
        col2.metric("Current Stage", _upper_or(active_run.get("current_stage", active_run.get("stage"))))
        col3.metric("Current Round", active_run.get("current_round", 0))
        col4.metric("Pending Round", active_run.get("pending_round", "N/A"))

        with st.expander("View Raw Active Run State"):
            st.json(active_run)
    else:
        st.info("No active run is currently in progress. Start training in the 'Server Control' panel.")

    st.markdown("---")

    # ==========================================
    # GLOBAL MODEL DOWNLOAD
    # ==========================================
    st.subheader("Global Model Status")
    st.write("Fetch and download the latest aggregated global model weights (`.pt` file).")

    if st.button("Fetch Latest Global Model"):
        with st.spinner("Connecting to server to fetch model weights..."):
            try:
                model_res = requests.get(f"{API_URL}/model", timeout=REQUEST_TIMEOUT)
                content_type = model_res.headers.get("Content-Type", "")
                # A JSON body signals an error payload rather than model bytes.
                if "application/json" in content_type:
                    err_data = model_res.json()
                    st.error(f"Cannot fetch model: {err_data.get('message', 'Unknown error')}")
                elif model_res.status_code == 200:
                    st.success("Model weights fetched successfully! Click below to save to disk.")
                    st.download_button(
                        label="Save global_model.pt",
                        data=model_res.content,
                        file_name="global_model.pt",
                        mime="application/octet-stream",
                        type="primary",
                    )
                else:
                    st.error(f"Unexpected error: HTTP {model_res.status_code}")
            except Exception as e:
                st.error(f"Request error: {str(e)}")

    st.markdown("---")

    # ==========================================
    # AGGREGATION HISTORY TABLE
    # ==========================================
    st.subheader("Aggregation History")
    rounds_list = history_data.get("rounds") or []
    if not rounds_list:
        st.info("No aggregation history available yet. The table will populate once rounds are completed.")
        return

    try:
        df_history = pd.DataFrame(rounds_list)
        # Hide client_metrics: long nested blobs make the table unreadable.
        if "client_metrics" in df_history.columns:
            df_history = df_history.drop(columns=["client_metrics"])
        # Keep 'round' as the first column when present.
        if "round" in df_history.columns:
            cols = ["round"] + [c for c in df_history.columns if c != "round"]
            df_history = df_history[cols]
        st.dataframe(df_history, use_container_width=True, hide_index=True)
    except Exception as e:
        st.error(f"Could not parse history table: {e}")


def render_charts():
    """Render accuracy, per-client loss and F1 charts from ``GET /charts`` (``data.rounds``)."""
    simple_block_header("Training Performance Analytics")

    try:
        res = requests.get(f"{API_URL}/charts", timeout=REQUEST_TIMEOUT)
        history_data = res.json().get("data") or {}
    except Exception as e:
        st.error(f"Cannot connect to backend: {e}")
        return

    rounds_list = history_data.get("rounds") or []
    if not rounds_list:
        st.warning("No training rounds recorded yet. Start training to see results.")
        return

    # 1. One row per round, indexed by round number for the chart x-axis.
    df = pd.DataFrame(rounds_list)
    if "round" in df.columns:
        df = df.set_index("round")

    # 2. Tabs
    tab1, tab2, tab3 = st.tabs(["Global Accuracy", "Local Loss (Clients)", "Global F1 Score"])

    # ==========================================
    # TAB 1: GLOBAL ACCURACY
    # ==========================================
    with tab1:
        st.subheader("Global Accuracy (Validation vs Test)")
        cols_to_plot = [c for c in ["val_accuracy", "test_accuracy"] if c in df.columns]
        if cols_to_plot:
            st.line_chart(df[cols_to_plot])
        else:
            st.info("Accuracy data is not available yet.")

    # ==========================================
    # TAB 2: LOCAL LOSS (multi-client)
    # ==========================================
    with tab2:
        st.subheader("Local Training Loss")
        # Read the raw round dicts: in the DataFrame, rounds without
        # client_metrics become NaN, which has no .get().
        all_clients = sorted({
            cid
            for r in rounds_list
            if isinstance(r.get("client_metrics"), dict)
            for cid in r["client_metrics"]
        })

        if not all_clients:
            st.info("No client metrics found in history.")
        else:
            selected_client = st.selectbox("Select Client to inspect:", all_clients)

            loss_data = {}
            for position, r in enumerate(rounds_list):
                c_metrics = r.get("client_metrics")
                if not isinstance(c_metrics, dict):
                    continue
                client_m = c_metrics.get(selected_client)
                if not isinstance(client_m, dict) or not client_m:
                    continue
                # Same x value as the DataFrame index: the round number, or the
                # row position when rounds carry no "round" field.
                loss_data[r.get("round", position)] = {
                    # Non-positive/missing losses become None so the chart does
                    # not connect points across a phase that did not run.
                    "Pretrain Loss": _positive_or_none(client_m.get("pretrain_loss", 0)),
                    "SSL Loss": _positive_or_none(client_m.get("ssl_labeled_loss", 0)),
                }

            if loss_data:
                df_loss = pd.DataFrame.from_dict(loss_data, orient="index")

                st.markdown("#### Phase 1: Pre-training Loss")
                st.caption("Phase where the model learns from local labeled data.")
                st.line_chart(df_loss[["Pretrain Loss"]])

                st.markdown("#### Phase 2: SSL Labeled Loss")
                st.caption("Semi-supervised learning phase.")
                st.line_chart(df_loss[["SSL Loss"]])
            else:
                st.warning(f"No loss data available for {selected_client} in the selected rounds.")

    # ==========================================
    # TAB 3: GLOBAL F1 SCORE
    # ==========================================
    with tab3:
        st.subheader("Global Weighted F1 Score")
        cols_to_plot_f1 = [c for c in ["val_WF1", "test_WF1"] if c in df.columns]
        if cols_to_plot_f1:
            st.line_chart(df[cols_to_plot_f1])
        else:
            st.info("F1 Score data is not available yet.")

    # ==========================================
    # 3. Detailed raw metrics table
    # ==========================================
    st.markdown("---")
    with st.expander("View Full Raw Metrics Table"):
        # Hide the verbose client_metrics column for readability.
        df_display = df.drop(columns=["client_metrics"], errors="ignore")
        st.dataframe(df_display, use_container_width=True)


def render_server_controls_combined():
    """Render the shared configuration form plus automatic/manual control tabs."""
    simple_block_header("Server Control")
    st.info("Manage automatic and manual orchestration in one section.")
    render_global_configuration()
    st.divider()
    tab_auto, tab_manual = st.tabs(["Server Control", "Server Control Manual"])
    with tab_auto:
        render_server_controls()
    with tab_manual:
        render_server_controls_manual()


def render_single_page_layout():
    """Render the whole admin dashboard as one scrolling page, in display order."""
    # Page title rendered via markdown in place of st.title.
    # NOTE(review): the literal contains no HTML tags although
    # unsafe_allow_html=True is set; title markup may have been lost — confirm.
    st.markdown(
        """

Federated Learning Server

""",
        unsafe_allow_html=True,
    )
    render_server_validation_test_upload()
    st.divider()
    render_dataset_client_summary()
    st.divider()
    render_server_controls_combined()
    render_global_server_panel()
    st.divider()
    render_charts()


def render_inference_demo():
    """Render the audio emotion-recognition demo backed by ``POST /predict``.

    Not called by ``render_single_page_layout`` in this file.
    """
    simple_block_header("Emotion Recognition Inference")
    st.markdown("Upload an audio file to predict emotion using the trained global AI model.")

    col_config, col_main = st.columns([1, 2], gap="large")

    with col_config:
        st.subheader("Model Configuration")
        with st.container(border=True):
            fl_method = st.selectbox("FL Method", ["FedAvg", "FedProx"])
            # Available label-ratio tokens used by checkpoints.
            label_ratio = st.selectbox("Label Ratio", ["0p3", "0p5", "0p7", "1"], index=3)
            num_clients = st.number_input("Num Clients", min_value=1, max_value=100, value=5)
            model_name = st.selectbox("Model Architecture", ["fedalmer", "memocmt", "fleser", "threemser", "cemobam"])
            device_pref = st.radio("Processing Device", ["cpu", "cuda"], horizontal=True)

    with col_main:
        st.subheader("Audio Upload")
        uploaded_file = st.file_uploader(
            "Drag and drop an audio file here (.wav, .mp3, .flac)", type=["wav", "mp3", "flac"]
        )

        # Use the browser-reported MIME type so mp3/flac are not labelled as WAV.
        audio_mime = (uploaded_file.type if uploaded_file is not None else None) or "audio/wav"

        # Preview player appears only after upload.
        if uploaded_file is not None:
            st.audio(uploaded_file, format=audio_mime)

        # Vertical spacer.
        # NOTE(review): this literal was split across lines (a syntax error) in
        # the received source; reconstructed as "<br>" — confirm.
        st.markdown("<br>", unsafe_allow_html=True)

        if st.button("Analyze and Predict Emotion", type="primary", use_container_width=True):
            if uploaded_file is None:
                st.warning("Please upload an audio file before running prediction.")
            else:
                with st.spinner("Analyzing audio... (first model load may take longer)"):
                    try:
                        # 1. Build form payload.
                        data_payload = {
                            "fl_method": fl_method,
                            "label_ratio_token": label_ratio,
                            "num_clients": num_clients,
                            "model_name": model_name,
                            "device_pref": device_pref,
                        }
                        # 2. Build file payload.
                        files_payload = {
                            "file": (uploaded_file.name, uploaded_file.getvalue(), audio_mime)
                        }
                        # 3. Send request to the backend.
                        res = requests.post(
                            f"{API_URL}/predict", data=data_payload, files=files_payload, timeout=30.00
                        )

                        # 4. Handle response.
                        if res.status_code != 200:
                            st.error(f"HTTP Error {res.status_code}: {res.text}")
                            return

                        res_json = res.json()
                        if res_json.get("status") != "success":
                            st.error(f"Server Error: {res_json.get('message')}")
                            return

                        result = res_json.get("data") or {}
                        st.success("Prediction completed.")
                        st.divider()

                        # Result metrics. Null numeric fields default to 0 so
                        # the format specs below cannot raise TypeError.
                        m1, m2 = st.columns(2)
                        emotion_str = str(result.get("emotion", "Unknown")).title()
                        confidence = result.get("confidence") or 0
                        m1.metric(label="Predicted Emotion", value=emotion_str)
                        m2.metric(label="Confidence", value=f"{confidence:.2%}")

                        with st.expander("ASR Transcript", expanded=True):
                            st.write(f"*{result.get('transcript', 'Speech not recognized')}*")

                        with st.expander("Probability Distribution"):
                            probs = result.get("probabilities", {})
                            if probs:
                                df_probs = pd.DataFrame(list(probs.items()), columns=["Emotion", "Probability"])
                                df_probs.set_index("Emotion", inplace=True)
                                st.bar_chart(df_probs)
                            else:
                                st.json(probs)

                        audio_seconds = result.get("audio_seconds") or 0
                        with st.expander("System Metadata"):
                            st.json({
                                "Loaded Checkpoint": result.get("checkpoint"),
                                "Audio Duration": f"{audio_seconds:.2f} seconds",
                                "Model": result.get("model_name"),
                            })
                    except Exception as e:
                        st.error(f"System connection error: {str(e)}")


render_single_page_layout()