Spaces:
Running
Running
| import json | |
| import os | |
| import pandas as pd | |
| import requests | |
| import streamlit as st | |
# Base URL of the FL orchestrator REST API; the API_URL environment variable
# overrides the local default.
API_URL = os.getenv("API_URL", "http://localhost:8000/api")

# ==========================================
# 1. PAGE SETTINGS
# ==========================================
# Browser-tab title, wide page layout and tab icon for the admin dashboard.
# NOTE(review): Streamlit's page_icon takes an emoji, a ":material/<name>:" icon
# or an image path/URL; confirm the plain string "settings" renders as intended.
st.set_page_config(
    page_icon="settings",
    layout="wide",
    page_title="Federated Learning Admin",
)
# ==========================================
# 2. SIDEBAR NAVIGATION
# ==========================================
with st.sidebar:
    st.title("FL Orchestrator")
    st.markdown("---")

    # Pages offered in the navigation radio, in display order. The "Client Panels"
    # (render_client_panels) and "Inference Demo" (render_inference_demo) pages
    # exist in this file but are currently left out of the menu.
    menu_options = [
        "Server Control",
        "Server Control Manual",
        "Server Validation/Test Upload",
        "Dataset / Client Summary",
        "Global / Server Panel",
        "Charts",
    ]

    # Label of the selected page; presumably consumed by the page router below.
    choice = st.radio("Navigation", menu_options)
    st.markdown("---")

    # Manual refresh: restart the script so every page re-fetches backend state.
    if st.button("Refresh Application"):
        st.rerun()
| # ========================================== | |
| # 3. MAIN CONTENT ROUTING | |
| # ========================================== | |
| def render_server_controls(): | |
| """ | |
| Renders the Server Control panel. | |
| Handles training execution and global server configuration updates. | |
| """ | |
| st.title("Server Control") | |
| st.info("Manage training sessions and update global server configuration.") | |
| # 1. Fetch current system state from backend | |
| try: | |
| res = requests.get(f"{API_URL}/config") | |
| res_data = res.json() | |
| system_state = res_data.get("data", {}) if res_data.get("status") == "success" else {} | |
| except Exception as e: | |
| st.error(f"Cannot connect to the backend server. Is it running? Details: {str(e)}") | |
| return | |
| active_run = system_state.get("active_run") | |
| current_setup = system_state.get("current_setup", {}) | |
| # ========================================== | |
| # EXECUTION CONTROLS | |
| # ========================================== | |
| st.subheader("Execution Controls") | |
| col1, col2 = st.columns(2) | |
| with col1: | |
| if st.button("Start Training", type="primary", use_container_width=True): | |
| with st.spinner("Initializing training session..."): | |
| try: | |
| train_res = requests.post(f"{API_URL}/train") | |
| data = train_res.json() | |
| if data.get("status") == "success": | |
| st.success(data.get("message", "Training started successfully.")) | |
| st.rerun() | |
| else: | |
| st.error(f"Failed to start: {data.get('message')}") | |
| except Exception as e: | |
| st.error(f"Request error: {str(e)}") | |
| with col2: | |
| if st.button("Stop Training", type="secondary", use_container_width=True): | |
| with st.spinner("Stopping current session..."): | |
| try: | |
| stop_res = requests.delete(f"{API_URL}/train") | |
| data = stop_res.json() | |
| if data.get("status") == "success": | |
| st.warning(data.get("message", "Training stopped.")) | |
| st.rerun() | |
| else: | |
| st.error(f"Failed to stop: {data.get('message')}") | |
| except Exception as e: | |
| st.error(f"Request error: {str(e)}") | |
| # Display active run status if it exists | |
| if active_run and active_run.get("status") not in ["completed", "failed"]: | |
| st.success(f"**Active Session:** {active_run.get('run_id')} **Status:** {active_run.get('status').upper()}") | |
| else: | |
| st.info("No active training session currently running.") | |
| st.markdown("---") | |
| # ========================================== | |
| # GLOBAL CONFIGURATION FORM | |
| # ========================================== | |
| st.subheader("Global Configuration") | |
| with st.form("server_config_form"): | |
| col_cfg1, col_cfg2 = st.columns(2) | |
| with col_cfg1: | |
| rounds_pretrain = st.number_input("Rounds Pretrain", min_value=1, | |
| value=int(current_setup.get("rounds_pretrain", 1))) | |
| rounds_ssl = st.number_input("Rounds SSL", min_value=0, value=int(current_setup.get("rounds_ssl", 1))) | |
| tau = st.number_input("Tau (Threshold)", min_value=0.0, max_value=1.0, | |
| value=float(current_setup.get("tau", 0.5)), step=0.05) | |
| ssl_weight = st.number_input("SSL Weight", min_value=0.0, value=float(current_setup.get("ssl_weight", 1.0)), | |
| step=0.1) | |
| with col_cfg2: | |
| seed = st.number_input("Random Seed", value=int(current_setup.get("seed", 42))) | |
| # Setup Aggregator selectbox | |
| agg_options = ["fedavg", "fedprox"] | |
| current_agg = current_setup.get("aggregator", "fedavg") | |
| agg_index = agg_options.index(current_agg) if current_agg in agg_options else 0 | |
| aggregator = st.selectbox("Aggregator", agg_options, index=agg_index) | |
| fedprox_mu = st.number_input("FedProx Mu", min_value=0.0, | |
| value=float(current_setup.get("fedprox_mu", 0.01)), step=0.01) | |
| # Setup Weighting method selectbox | |
| weight_options = ["num_samples", "equal"] | |
| current_weight = current_setup.get("weight_by", "num_samples") | |
| weight_index = weight_options.index(current_weight) if current_weight in weight_options else 0 | |
| weight_by = st.selectbox("Weight By", weight_options, index=weight_index) | |
| submit_cfg = st.form_submit_button("Save Configuration", use_container_width=True) | |
| if submit_cfg: | |
| payload = { | |
| "rounds_pretrain": rounds_pretrain, | |
| "rounds_ssl": rounds_ssl, | |
| "tau": tau, | |
| "ssl_weight": ssl_weight, | |
| "seed": seed, | |
| "aggregator": aggregator, | |
| "fedprox_mu": fedprox_mu, | |
| "weight_by": weight_by | |
| } | |
| with st.spinner("Saving configuration..."): | |
| try: | |
| cfg_res = requests.post(f"{API_URL}/config", json=payload) | |
| cfg_data = cfg_res.json() | |
| if cfg_data.get("status") == "success": | |
| st.success("Configuration updated successfully.") | |
| st.rerun() | |
| else: | |
| st.error(f"Update failed: {cfg_data.get('message')}") | |
| except Exception as e: | |
| st.error(f"Request failed: {str(e)}") | |
| def render_server_controls_manual(): | |
| """ | |
| Renders the Manual Server Control panel. | |
| Allows for step-by-step orchestration of the FL process. | |
| """ | |
| st.title("Server Control (Manual Mode)") | |
| st.info("Human-in-the-loop: start rounds, monitor progress, and aggregate models manually.") | |
| # 1. Fetch current system state | |
| try: | |
| res = requests.get(f"{API_URL}/config") | |
| system_state = res.json().get("data", {}) | |
| active_run = system_state.get("active_run") | |
| current_setup = system_state.get("current_setup", {}) | |
| except Exception as e: | |
| st.error(f"Backend connection error: {str(e)}") | |
| return | |
| # ========================================== | |
| # 1. EXECUTION CONTROLS (Manual Start/Stop) | |
| # ========================================== | |
| st.subheader("Orchestration") | |
| col1, col2 = st.columns(2) | |
| with col1: | |
| # Start manual mode and broadcast NEW_ROUND_MANUAL | |
| if st.button("Start / Broadcast Next Round", type="primary", use_container_width=True): | |
| with st.spinner("Broadcasting command..."): | |
| try: | |
| res = requests.post(f"{API_URL}/train_manual") | |
| data = res.json() | |
| if data.get("status") == "success": | |
| st.success(data.get("message")) | |
| st.rerun() | |
| else: | |
| st.error(f"Error: {data.get('message')}") | |
| except Exception as e: | |
| st.error(str(e)) | |
| with col2: | |
| # Stop manual mode | |
| if st.button("Stop Manual Session", type="secondary", use_container_width=True): | |
| with st.spinner("Stopping..."): | |
| try: | |
| res = requests.delete(f"{API_URL}/train_manual") | |
| if res.json().get("status") == "success": | |
| st.warning("Manual session stopped.") | |
| st.rerun() | |
| except Exception as e: | |
| st.error(str(e)) | |
| st.markdown("---") | |
| # ========================================== | |
| # 2. SUBMISSION MONITORING | |
| # ========================================== | |
| st.subheader("Submission Monitor") | |
| try: | |
| sub_res = requests.get(f"{API_URL}/submission_status") | |
| res_json = sub_res.json() | |
| if res_json.get("status") == "error": | |
| # If no run is active, backend returns an error state. | |
| st.warning(f"Warning: {res_json.get('message')}") | |
| st.info("Click 'Start Training' to begin a new session.") | |
| else: | |
| # Data exists only when status == "success". | |
| sub_status = res_json.get("data", {}) | |
| curr = sub_status.get("current_count", 0) | |
| total = sub_status.get("total_required", 0) | |
| progress_str = sub_status.get("progress", "0/0") | |
| is_ready = sub_status.get("is_ready", False) | |
| all_ids = sub_status.get("all_client_ids", []) | |
| submitted_ids = sub_status.get("submitted_clients", []) | |
| # --- Progress dashboard --- | |
| m_col1, m_col2 = st.columns([1, 3]) | |
| with m_col1: | |
| st.metric(label="Clients Ready", value=progress_str) | |
| with m_col2: | |
| st.write(f"**Pending round:** {sub_status.get('pending_round')}") | |
| st.progress(curr / total if total > 0 else 0) | |
| st.write("**Detailed Status:**") | |
| if all_ids: | |
| cols = st.columns(len(all_ids)) | |
| for i, cid in enumerate(all_ids): | |
| with cols[i]: | |
| if cid in submitted_ids: | |
| st.success(f"Submitted: {cid}") | |
| else: | |
| st.error(f"Missing: {cid}") | |
| st.markdown("---") | |
| # ========================================== | |
| # 4. Aggregation action area | |
| # ========================================== | |
| with st.container(): | |
| if is_ready: | |
| st.success(f"Ready: received all {progress_str} submissions.") | |
| if st.button("Aggregate and Start Next Round", type="primary", use_container_width=True): | |
| with st.spinner("Running aggregation..."): | |
| try: | |
| agg_res = requests.post(f"{API_URL}/aggregate_manual") | |
| agg_data = agg_res.json() | |
| if agg_data.get("status") == "success": | |
| st.success(f"{agg_data.get('message')}") | |
| import time | |
| time.sleep(1.5) | |
| st.rerun() | |
| else: | |
| st.error(f"Error: {agg_data.get('message')}") | |
| except Exception as e: | |
| st.error(str(e)) | |
| else: | |
| st.warning(f"Waiting for {total - curr} more client submissions...") | |
| st.button("Aggregate Submitted Models (Waiting...)", | |
| type="secondary", use_container_width=True, disabled=True) | |
| except Exception as e: | |
| st.error(f"Submission monitor connection error: {e}") | |
| st.markdown("---") | |
| # ========================================== | |
| # 3. GLOBAL CONFIGURATION (reuse existing form) | |
| # ========================================== | |
| st.subheader("Global Configuration") | |
| with st.form("server_config_form"): | |
| col_cfg1, col_cfg2 = st.columns(2) | |
| with col_cfg1: | |
| rounds_pretrain = st.number_input("Rounds Pretrain", min_value=1, | |
| value=int(current_setup.get("rounds_pretrain", 1))) | |
| rounds_ssl = st.number_input("Rounds SSL", min_value=0, value=int(current_setup.get("rounds_ssl", 1))) | |
| tau = st.number_input("Tau (Threshold)", min_value=0.0, max_value=1.0, | |
| value=float(current_setup.get("tau", 0.5)), step=0.05) | |
| ssl_weight = st.number_input("SSL Weight", min_value=0.0, value=float(current_setup.get("ssl_weight", 1.0)), | |
| step=0.1) | |
| with col_cfg2: | |
| seed = st.number_input("Random Seed", value=int(current_setup.get("seed", 42))) | |
| # Setup Aggregator selectbox | |
| agg_options = ["fedavg", "fedprox"] | |
| current_agg = current_setup.get("aggregator", "fedavg") | |
| agg_index = agg_options.index(current_agg) if current_agg in agg_options else 0 | |
| aggregator = st.selectbox("Aggregator", agg_options, index=agg_index) | |
| fedprox_mu = st.number_input("FedProx Mu", min_value=0.0, | |
| value=float(current_setup.get("fedprox_mu", 0.01)), step=0.01) | |
| # Setup Weighting method selectbox | |
| weight_options = ["num_samples", "equal"] | |
| current_weight = current_setup.get("weight_by", "num_samples") | |
| weight_index = weight_options.index(current_weight) if current_weight in weight_options else 0 | |
| weight_by = st.selectbox("Weight By", weight_options, index=weight_index) | |
| submit_cfg = st.form_submit_button("Save Configuration", use_container_width=True) | |
| if submit_cfg: | |
| payload = { | |
| "rounds_pretrain": rounds_pretrain, | |
| "rounds_ssl": rounds_ssl, | |
| "tau": tau, | |
| "ssl_weight": ssl_weight, | |
| "seed": seed, | |
| "aggregator": aggregator, | |
| "fedprox_mu": fedprox_mu, | |
| "weight_by": weight_by | |
| } | |
| with st.spinner("Saving configuration..."): | |
| try: | |
| cfg_res = requests.post(f"{API_URL}/config", json=payload) | |
| cfg_data = cfg_res.json() | |
| if cfg_data.get("status") == "success": | |
| st.success("Configuration updated successfully.") | |
| st.rerun() | |
| else: | |
| st.error(f"Update failed: {cfg_data.get('message')}") | |
| except Exception as e: | |
| st.error(f"Request failed: {str(e)}") | |
def render_server_validation_test_upload():
    """Render the Server Validation/Test Upload page.

    Shows whether the ``valid`` and ``test`` evaluation splits exist (read from
    the ``upload_status`` section of ``GET {API_URL}/config``), uploads a ``.pkl``
    file with ``POST /data/{split}``, and deletes splits with
    ``DELETE /data/{split}`` or ``DELETE /data``.
    """
    # (connect, read) seconds: an unreachable backend cannot hang the page, while
    # the server still has time to process an uploaded file.
    timeout = (5, 120)

    def _request_then_rerun(method, path, success_default, failure_prefix, **kwargs):
        """Send one backend request; on success toast and rerun, otherwise show the error."""
        try:
            data = requests.request(method, f"{API_URL}{path}", timeout=timeout, **kwargs).json()
        except Exception as e:
            st.error(f"Request error: {str(e)}")
            return
        if isinstance(data, dict) and data.get("status") == "success":
            # st.success would be cleared by the immediate rerun, so use a toast.
            st.toast(data.get("message") or success_default)
            st.rerun()
        else:
            message = data.get("message") if isinstance(data, dict) else data
            st.error(f"{failure_prefix}{message}")

    st.title("Server Validation/Test Upload")
    st.info("Upload `.pkl` datasets for server-side global model evaluation.")

    # 1. Fetch current system state to get upload_status.
    try:
        res_data = requests.get(f"{API_URL}/config", timeout=timeout).json()
        system_state = (res_data.get("data") or {}) if res_data.get("status") == "success" else {}
    except Exception as e:
        st.error(f"Cannot connect to the backend server. Details: {str(e)}")
        return

    # `or {}` guards against explicit nulls in the backend payload.
    upload_status = system_state.get("upload_status") or {}
    valid_exists = bool((upload_status.get("valid") or {}).get("uploaded", False))
    test_exists = bool((upload_status.get("test") or {}).get("uploaded", False))

    # ==========================================
    # DATASET STATUS CARDS
    # ==========================================
    st.subheader("Current Data Status")
    status_cards = (("Validation Dataset", valid_exists), ("Test Dataset", test_exists))
    for column, (label, exists) in zip(st.columns(2), status_cards):
        with column:
            if exists:
                st.success(f"**{label}:** Available")
            else:
                st.error(f"**{label}:** Missing")
    st.markdown("---")

    # ==========================================
    # UPLOAD FORM
    # ==========================================
    st.subheader("Upload Dataset")
    split_type = st.selectbox(
        "Select Split Type",
        ["valid", "test"],
        help="Choose whether this data is for validation or testing.",
    )
    uploaded_file = st.file_uploader(f"Choose a .pkl file for the '{split_type}' split", type=["pkl"])
    if st.button("Upload File", type="primary"):
        if uploaded_file is None:
            st.warning("Please select a file to upload first.")
        else:
            with st.spinner(f"Uploading {uploaded_file.name}..."):
                # multipart/form-data payload with a single "file" field.
                files = {"file": (uploaded_file.name, uploaded_file.getvalue(), "application/octet-stream")}
                _request_then_rerun(
                    "POST", f"/data/{split_type}", "File uploaded successfully.", "Upload failed: ", files=files
                )
    st.markdown("---")

    # ==========================================
    # DANGER ZONE (DELETION)
    # ==========================================
    st.subheader("Danger Zone")
    del_col1, del_col2, del_col3 = st.columns(3)
    with del_col1:
        if st.button("Delete Valid Data", disabled=not valid_exists, use_container_width=True):
            _request_then_rerun("DELETE", "/data/valid", "Validation data deleted.", "")
    with del_col2:
        if st.button("Delete Test Data", disabled=not test_exists, use_container_width=True):
            _request_then_rerun("DELETE", "/data/test", "Test data deleted.", "")
    with del_col3:
        if st.button(
            "Clear ALL Data",
            type="primary",
            disabled=not (valid_exists or test_exists),
            use_container_width=True,
        ):
            _request_then_rerun("DELETE", "/data", "All server data cleared.", "")
| def render_dataset_client_summary(): | |
| """ | |
| Renders the Dataset and Client Summary panel. | |
| Displays registered clients, data volumes, and class distributions. | |
| """ | |
| st.title("Dataset / Client Summary") | |
| st.info("Overview of registered clients, their data volume, and label distribution.") | |
| # 1. Fetch data from backend | |
| try: | |
| # Fetch profiles from API (returns List[ClientProfileResponse]) | |
| res_profiles = requests.get(f"{API_URL}/profiles") | |
| # Keep config fetch if other global metadata is needed | |
| res_config = requests.get(f"{API_URL}/config") | |
| profiles_data = res_profiles.json().get("data", []) if res_profiles.status_code == 200 else [] | |
| except Exception as e: | |
| st.error(f"Cannot connect to the backend server. Details: {str(e)}") | |
| return | |
| # Check whether any client has registered | |
| if not profiles_data: | |
| st.warning("No clients have registered their profiles yet. Start your client scripts to register them.") | |
| return | |
| # 2. Process data for visualization (profile data is a list of objects) | |
| total_clients = len(profiles_data) | |
| total_labeled = 0 | |
| total_unlabeled = 0 | |
| profile_list = [] | |
| class_hist_list = [] | |
| # Iterate through List[ClientProfileResponse] | |
| for info in profiles_data: | |
| cid = info.get("client_id") | |
| labeled = info.get("labeled_count", 0) | |
| unlabeled = info.get("unlabeled_count", 0) | |
| total_labeled += labeled | |
| total_unlabeled += unlabeled | |
| # Build table row using fields from ClientProfileResponse | |
| profile_list.append({ | |
| "Client ID": cid, | |
| "Labeled Data": labeled, | |
| "Unlabeled Data": unlabeled, | |
| "Feature Dim": info.get("feature_dim", "N/A"), | |
| "Model Name": info.get("model_name", "N/A"), | |
| # Keep N/A when a field is unavailable in response | |
| "Batch Size": info.get("batch_size", "N/A") | |
| }) | |
| # Build class distribution chart data | |
| hist = info.get("class_hist", {}) | |
| if hist: | |
| for label, count in hist.items(): | |
| class_hist_list.append({ | |
| "Client": cid, | |
| "Class": str(label), | |
| "Count": count | |
| }) | |
| # ========================================== | |
| # OVERALL METRICS | |
| # ========================================== | |
| st.subheader("Global Metrics") | |
| col_m1, col_m2, col_m3 = st.columns(3) | |
| col_m1.metric("Registered Clients", total_clients) | |
| col_m2.metric("Total Labeled Samples", total_labeled) | |
| col_m3.metric("Total Unlabeled Samples", total_unlabeled) | |
| st.markdown("---") | |
| # ========================================== | |
| # CLIENT PROFILE TABLE | |
| # ========================================== | |
| st.subheader("Client Data Overview") | |
| df_profiles = pd.DataFrame(profile_list) | |
| st.dataframe(df_profiles, use_container_width=True, hide_index=True) | |
| st.markdown("---") | |
| # ========================================== | |
| # DATA DISTRIBUTION CHARTS | |
| # ========================================== | |
| st.subheader("Data Distribution Analysis") | |
| col_chart1, col_chart2 = st.columns(2) | |
| with col_chart1: | |
| st.markdown("**Data Volume per Client**") | |
| if not df_profiles.empty: | |
| # Bar chart showing Labeled vs Unlabeled side by side | |
| df_volume = df_profiles.set_index("Client ID")[["Labeled Data", "Unlabeled Data"]] | |
| st.bar_chart(df_volume) | |
| with col_chart2: | |
| st.markdown("**Class Label Distribution**") | |
| if class_hist_list: | |
| df_hist = pd.DataFrame(class_hist_list) | |
| # Convert Class to string so chart X-axis is categorical | |
| df_hist['Class'] = df_hist['Class'].astype(str) | |
| # Pivot by class and client for grouped class count visualization | |
| df_pivot = df_hist.pivot(index='Class', columns='Client', values='Count').fillna(0) | |
| st.bar_chart(df_pivot) | |
| else: | |
| st.info("No class history (class_hist) provided by clients.") | |
| st.markdown("---") | |
| if st.button("Clear All Profiles", type="secondary"): | |
| # Optional: add a confirmation dialog before destructive action | |
| try: | |
| res = requests.delete(f"{API_URL}/profiles") | |
| if res.status_code == 200: | |
| st.success("All client profiles have been cleared.") | |
| st.rerun() | |
| else: | |
| st.error("Cannot clear profiles.") | |
| except Exception as e: | |
| st.error(f"Error: {e}") | |
| def render_client_panels(): | |
| """ | |
| Renders the Client Panels section. | |
| Allows monitoring of individual clients and manual submission of model updates. | |
| """ | |
| st.title("Client Panels") | |
| st.info("Monitor specific clients and manually submit model weights (.pt) on their behalf.") | |
| # 1. Fetch system state to get registered clients | |
| try: | |
| res = requests.get(f"{API_URL}/config") | |
| res_data = res.json() | |
| system_state = res_data.get("data", {}) if res_data.get("status") == "success" else {} | |
| except Exception as e: | |
| st.error(f"Cannot connect to the backend server. Details: {str(e)}") | |
| return | |
| profiles = system_state.get("profiles", {}) | |
| active_run = system_state.get("active_run") | |
| if not profiles: | |
| st.warning("No clients are currently registered. Please register clients first.") | |
| return | |
| # 2. Select a Client | |
| client_ids = list(profiles.keys()) | |
| selected_client = st.selectbox("Select a Client to View/Manage", client_ids) | |
| st.markdown("---") | |
| client_info = profiles[selected_client] | |
| # ========================================== | |
| # CLIENT INFORMATION | |
| # ========================================== | |
| st.subheader(f"Profile: {selected_client}") | |
| col_info1, col_info2, col_info3 = st.columns(3) | |
| col_info1.metric("Labeled Data", client_info.get("labeled_count", 0)) | |
| col_info2.metric("Unlabeled Data", client_info.get("unlabeled_count", 0)) | |
| col_info3.metric("Batch Size", client_info.get("batch_size", "N/A")) | |
| with st.expander("View Raw Profile Data"): | |
| st.json(client_info) | |
| st.markdown("---") | |
| # ========================================== | |
| # MANUAL MODEL SUBMISSION FORM | |
| # ========================================== | |
| st.subheader("Manual Model Submission") | |
| if active_run and active_run.get("status") == "waiting_responses": | |
| st.success( | |
| f"Server is currently waiting for responses for Run: **{active_run.get('run_id')}** (Round {active_run.get('pending_round')})") | |
| with st.form(f"submit_model_form_{selected_client}"): | |
| st.write("Submit weights directly to the server for this client.") | |
| # Form Inputs | |
| num_samples = st.number_input( | |
| "Number of Samples (Weight for FedAvg)", | |
| min_value=1, | |
| value=int(client_info.get("labeled_count", 100)), | |
| help="The number of data samples used to train this model (used for weighting in FedAvg)." | |
| ) | |
| metrics_input = st.text_area( | |
| "Training Metrics (JSON format)", | |
| value='{\n "loss": 0.45,\n "accuracy": 0.82\n}', | |
| help="Provide metrics like loss and accuracy as a JSON string." | |
| ) | |
| uploaded_model = st.file_uploader("Upload Model Weights (.pt)", type=["pt", "bin", "pth"]) | |
| submit_btn = st.form_submit_button("Submit Model Update", type="primary", use_container_width=True) | |
| if submit_btn: | |
| # 1. Validate JSON metrics | |
| try: | |
| parsed_metrics = json.loads(metrics_input) | |
| valid_json_str = json.dumps(parsed_metrics) | |
| except json.JSONDecodeError: | |
| st.error("Invalid JSON format in Training Metrics.") | |
| st.stop() | |
| # 2. Check File | |
| if uploaded_model is None: | |
| st.error("Please upload a model weights file (.pt).") | |
| else: | |
| with st.spinner("Submitting model to server..."): | |
| try: | |
| # Prepare form data | |
| data = { | |
| "client_id": selected_client, | |
| "num_samples": num_samples, | |
| "metrics": valid_json_str | |
| } | |
| # Prepare file | |
| files = { | |
| "file": (uploaded_model.name, uploaded_model.getvalue(), "application/octet-stream") | |
| } | |
| # Send request | |
| submit_res = requests.post(f"{API_URL}/model", data=data, files=files) | |
| res_json = submit_res.json() | |
| if res_json.get("status") == "success": | |
| st.success(res_json.get("message")) | |
| else: | |
| st.error(f"Submission failed: {res_json.get('message')}") | |
| except Exception as e: | |
| st.error(f"Request error: {str(e)}") | |
| else: | |
| st.info( | |
| "Manual submission is disabled. The server is not currently waiting for client responses (no active pending round).") | |
| def render_global_server_panel(): | |
| """ | |
| Renders the Global / Server Panel. | |
| Displays active run details, global model status, and aggregation history. | |
| """ | |
| st.title("Global / Server Panel") | |
| st.info("Monitor the central orchestrator state, active run details, and aggregation history.") | |
| # 1. Fetch system state from backend | |
| try: | |
| res = requests.get(f"{API_URL}/config") | |
| res_data = res.json() | |
| system_state = res_data.get("data", {}) if res_data.get("status") == "success" else {} | |
| except Exception as e: | |
| st.error(f"Cannot connect to the backend server. Details: {str(e)}") | |
| return | |
| active_run = system_state.get("active_run") | |
| history_data = system_state.get("history", {}) | |
| # ========================================== | |
| # ACTIVE RUN DETAILS | |
| # ========================================== | |
| st.subheader("Current Active Run") | |
| if active_run: | |
| status_color = "" if active_run.get("status") not in ["failed", "stopped"] else "" | |
| st.markdown(f"**Run ID:** `{active_run.get('run_id', 'N/A')}`") | |
| col1, col2, col3, col4 = st.columns(4) | |
| col1.metric("Status", f"{status_color} {active_run.get('status', 'N/A').upper()}") | |
| col2.metric("Current Stage", active_run.get("current_stage", active_run.get("stage", "N/A")).upper()) | |
| col3.metric("Current Round", active_run.get("current_round", 0)) | |
| col4.metric("Pending Round", active_run.get("pending_round", "N/A")) | |
| with st.expander("View Raw Active Run State"): | |
| st.json(active_run) | |
| else: | |
| st.info("No active run is currently in progress. Start training in the 'Server Control' panel.") | |
| st.markdown("---") | |
| # ========================================== | |
| # GLOBAL MODEL DOWNLOAD | |
| # ========================================== | |
| st.subheader("Global Model Status") | |
| st.write("Fetch and download the latest aggregated global model weights (`.pt` file).") | |
| if st.button("Fetch Latest Global Model"): | |
| with st.spinner("Connecting to server to fetch model weights..."): | |
| try: | |
| model_res = requests.get(f"{API_URL}/model") | |
| content_type = model_res.headers.get('Content-Type', '') | |
| if 'application/json' in content_type: | |
| err_data = model_res.json() | |
| st.error(f"Cannot fetch model: {err_data.get('message', 'Unknown error')}") | |
| elif model_res.status_code == 200: | |
| st.success("Model weights fetched successfully! Click below to save to disk.") | |
| st.download_button( | |
| label="Save global_model.pt", | |
| data=model_res.content, | |
| file_name="global_model.pt", | |
| mime="application/octet-stream", | |
| type="primary" | |
| ) | |
| else: | |
| st.error(f"Unexpected error: HTTP {model_res.status_code}") | |
| except Exception as e: | |
| st.error(f"Request error: {str(e)}") | |
| st.markdown("---") | |
| # ========================================== | |
| # AGGREGATION HISTORY TABLE | |
| # ========================================== | |
| st.subheader("Aggregation History") | |
| # Read round records from history_data | |
| rounds_list = history_data.get("rounds", []) | |
| if rounds_list: | |
| try: | |
| df_history = pd.DataFrame(rounds_list) | |
| # Hide client_metrics because long JSON blobs reduce table readability. | |
| if "client_metrics" in df_history.columns: | |
| df_history = df_history.drop(columns=["client_metrics"]) | |
| # Keep 'round' as the first column. | |
| cols = ['round'] + [c for c in df_history.columns if c != 'round'] | |
| df_history = df_history[cols] | |
| st.dataframe(df_history, use_container_width=True, hide_index=True) | |
| except Exception as e: | |
| st.error(f"Could not parse history table: {e}") | |
| else: | |
| st.info("No aggregation history available yet. The table will populate once rounds are completed.") | |
| def render_charts(): | |
| st.title("Training Performance Analytics") | |
| try: | |
| # Fetch chart data from API | |
| res = requests.get(f"{API_URL}/charts") | |
| history_data = res.json().get("data", {}) | |
| except Exception as e: | |
| st.error(f"Cannot connect to backend: {e}") | |
| return | |
| # Read list of rounds from 'rounds' key | |
| rounds_list = history_data.get("rounds", []) | |
| if not rounds_list: | |
| st.warning("No training rounds recorded yet. Start training to see results.") | |
| return | |
| # 1. Convert list to DataFrame | |
| df = pd.DataFrame(rounds_list) | |
| # Use round as chart X-axis index | |
| df.set_index("round", inplace=True) | |
| # 2. Tabs | |
| tab1, tab2, tab3 = st.tabs(["Global Accuracy", "Local Loss (Clients)", "Global F1 Score"]) | |
| # ========================================== | |
| # TAB 1: GLOBAL ACCURACY | |
| # ========================================== | |
| with tab1: | |
| st.subheader("Global Accuracy (Validation vs Test)") | |
| # Check columns before plotting to avoid missing-key errors | |
| cols_to_plot = [c for c in ["val_accuracy", "test_accuracy"] if c in df.columns] | |
| if cols_to_plot: | |
| st.line_chart(df[cols_to_plot]) | |
| else: | |
| st.info("Accuracy data is not available yet.") | |
| # ========================================== | |
| # TAB 2: LOCAL LOSS (multi-client) | |
| # ========================================== | |
| with tab2: | |
| st.subheader("Local Training Loss") | |
| all_clients = set() | |
| for r in rounds_list: | |
| c_metrics = r.get("client_metrics", {}) | |
| if isinstance(c_metrics, dict): | |
| all_clients.update(c_metrics.keys()) | |
| all_clients = sorted(list(all_clients)) | |
| if not all_clients: | |
| st.info("No client metrics found in history.") | |
| else: | |
| # Select client to inspect | |
| selected_client = st.selectbox("Select Client to inspect:", all_clients) | |
| # Extract selected client loss per round | |
| loss_data = {} | |
| for r_idx, row in df.iterrows(): | |
| c_metrics = row.get("client_metrics", {}).get(selected_client, {}) | |
| if c_metrics: | |
| p_loss = c_metrics.get("pretrain_loss", 0) | |
| s_loss = c_metrics.get("ssl_labeled_loss", 0) | |
| # Convert 0 to None so missing phase points are not connected | |
| loss_data[r_idx] = { | |
| "Pretrain Loss": p_loss if p_loss > 0 else None, | |
| "SSL Loss": s_loss if s_loss > 0 else None | |
| } | |
| if loss_data: | |
| df_loss = pd.DataFrame.from_dict(loss_data, orient='index') | |
| # Render as two separate charts | |
| st.markdown("#### Phase 1: Pre-training Loss") | |
| st.caption("Phase where the model learns from local labeled data.") | |
| # Plot only Pretrain Loss column | |
| st.line_chart(df_loss[["Pretrain Loss"]]) | |
| st.markdown("#### Phase 2: SSL Labeled Loss") | |
| st.caption("Semi-supervised learning phase.") | |
| # Plot only SSL Loss column | |
| st.line_chart(df_loss[["SSL Loss"]]) | |
| else: | |
| st.warning(f"No loss data available for {selected_client} in the selected rounds.") | |
| # ========================================== | |
| # TAB 3: GLOBAL F1 SCORE | |
| # ========================================== | |
| with tab3: | |
| st.subheader("Global Weighted F1 Score") | |
| cols_to_plot_f1 = [c for c in ["val_WF1", "test_WF1"] if c in df.columns] | |
| if cols_to_plot_f1: | |
| st.line_chart(df[cols_to_plot_f1]) | |
| else: | |
| st.info("F1 Score data is not available yet.") | |
| # ========================================== | |
| # 3. Show detailed raw metrics table | |
| # ========================================== | |
| st.markdown("---") | |
| with st.expander("View Full Raw Metrics Table"): | |
| # Copy DataFrame and hide verbose client_metrics column for readability | |
| df_display = df.copy() | |
| if "client_metrics" in df_display.columns: | |
| df_display.drop(columns=["client_metrics"], inplace=True) | |
| st.dataframe(df_display, use_container_width=True) | |
def _render_prediction_result(result):
    """Render the ``data`` object of a successful ``/predict`` response.

    Args:
        result: Mapping returned by the backend. Keys read here: ``emotion``,
            ``confidence``, ``transcript``, ``probabilities``, ``checkpoint``,
            ``audio_seconds`` and ``model_name``. Any of them may be missing.
    """
    st.success("Prediction completed.")
    st.divider()

    # Headline metrics.
    m1, m2 = st.columns(2)
    emotion_str = str(result.get("emotion", "Unknown")).title()
    # `or 0` also covers an explicit null, which would break the `%` format spec.
    confidence = result.get("confidence") or 0
    m1.metric(label="Predicted Emotion", value=emotion_str)
    m2.metric(label="Confidence", value=f"{confidence:.2%}")

    with st.expander("ASR Transcript", expanded=True):
        st.write(f"*{result.get('transcript', 'Speech not recognized')}*")

    with st.expander("Probability Distribution"):
        probs = result.get("probabilities", {})
        if probs:
            # Render a simple probability bar chart, one bar per emotion label.
            df_probs = pd.DataFrame(list(probs.items()), columns=["Emotion", "Probability"])
            df_probs.set_index("Emotion", inplace=True)
            st.bar_chart(df_probs)
        else:
            st.json(probs)

    with st.expander("System Metadata"):
        # Same null-guard as `confidence`: `:.2f` cannot format None.
        audio_seconds = result.get("audio_seconds") or 0
        st.json({
            "Loaded Checkpoint": result.get("checkpoint"),
            "Audio Duration": f"{audio_seconds:.2f} seconds",
            "Model": result.get("model_name"),
        })


def render_inference_demo():
    """Render the emotion-recognition inference page.

    The user picks a checkpoint configuration (FL method, label ratio, number of
    clients, model architecture, device) and uploads an audio file. Both are
    sent as a multipart form to ``{API_URL}/predict``. A successful response
    (HTTP 200 with ``status == "success"``) is rendered by
    ``_render_prediction_result``. HTTP or server-reported failures are shown
    with ``st.error``.
    """
    st.header("Emotion Recognition Inference")
    st.markdown("Upload an audio file to predict emotion using the trained global AI model.")

    col_config, col_main = st.columns([1, 2], gap="large")

    with col_config:
        st.subheader("Model Configuration")
        with st.container(border=True):
            fl_method = st.selectbox("FL Method", ["FedAvg", "FedProx"])
            # Available label-ratio tokens used by checkpoints.
            label_ratio = st.selectbox("Label Ratio", ["0p3", "0p5", "0p7", "1"], index=3)
            num_clients = st.number_input("Num Clients", min_value=1, max_value=100, value=5)
            model_name = st.selectbox("Model Architecture", ["fedalmer", "memocmt", "fleser", "threemser", "cemobam"])
            device_pref = st.radio("Processing Device", ["cpu", "cuda"], horizontal=True)

    with col_main:
        st.subheader("Audio Upload")
        uploaded_file = st.file_uploader("Drag and drop an audio file here (.wav, .mp3, .flac)",
                                         type=["wav", "mp3", "flac"])
        # Preview player appears only after upload.
        if uploaded_file is not None:
            # Use the reported MIME type so .mp3/.flac files are not labelled as WAV.
            st.audio(uploaded_file, format=uploaded_file.type or "audio/wav")

    st.markdown("<br>", unsafe_allow_html=True)

    if not st.button("Analyze and Predict Emotion", type="primary", use_container_width=True):
        return
    if uploaded_file is None:
        st.warning("Please upload an audio file before running prediction.")
        return

    with st.spinner("Analyzing audio... (first model load may take longer)"):
        try:
            # 1. Build form payload.
            data_payload = {
                "fl_method": fl_method,
                "label_ratio_token": label_ratio,
                "num_clients": num_clients,
                "model_name": model_name,
                "device_pref": device_pref,
            }
            # 2. Build file payload with the file's actual MIME type (WAV as fallback).
            files_payload = {
                "file": (uploaded_file.name, uploaded_file.getvalue(), uploaded_file.type or "audio/wav")
            }
            # 3. Send request to FastAPI.
            res = requests.post(f"{API_URL}/predict", data=data_payload, files=files_payload, timeout=30.00)

            # 4. Handle response: transport failure first, then application-level failure.
            if res.status_code != 200:
                st.error(f"HTTP Error {res.status_code}: {res.text}")
                return
            res_json = res.json()
            if res_json.get("status") != "success":
                st.error(f"Server Error: {res_json.get('message')}")
                return
            _render_prediction_result(res_json.get("data", {}))
        except Exception as e:
            # UI boundary: surface any failure (network, bad JSON, rendering) to the user.
            st.error(f"System connection error: {str(e)}")
# Sidebar label -> page renderer. Entries disabled in `menu_options` are kept
# here as comments so they can be re-enabled together.
_PAGE_RENDERERS = {
    "Server Control": render_server_controls,
    "Server Control Manual": render_server_controls_manual,
    "Server Validation/Test Upload": render_server_validation_test_upload,
    "Dataset / Client Summary": render_dataset_client_summary,
    # "Client Panels": render_client_panels,
    "Global / Server Panel": render_global_server_panel,
    "Charts": render_charts,
    # "Inference Demo": render_inference_demo,
}

# Render the selected page; an unrecognised label renders nothing.
_selected_page_renderer = _PAGE_RENDERERS.get(choice)
if _selected_page_renderer is not None:
    _selected_page_renderer()