Spaces:
Sleeping
Sleeping
| import json | |
| import os | |
| import pandas as pd | |
| import requests | |
| import streamlit as st | |
# Base URL of the local client API; override with the API_URL environment variable.
API_URL = os.environ.get("API_URL", "http://127.0.0.1:8002/api")

# Display names in class-index order (0..3).
_CLASS_DISPLAY_NAMES = ("anger", "happiness/excitement", "sadness", "neutral")

# Maps every label spelling handled by the UI (short code, integer index, or
# the index as a string) to its display name. Note that both "hap" and "exc"
# share one display name.
LABEL_DISPLAY_MAP = {
    "ang": _CLASS_DISPLAY_NAMES[0],
    "hap": _CLASS_DISPLAY_NAMES[1],
    "sad": _CLASS_DISPLAY_NAMES[2],
    "neu": _CLASS_DISPLAY_NAMES[3],
    "exc": _CLASS_DISPLAY_NAMES[1],
    **dict(enumerate(_CLASS_DISPLAY_NAMES)),
    **{str(idx): name for idx, name in enumerate(_CLASS_DISPLAY_NAMES)},
}
def _map_label_display(label):
    """Translate a raw class label into its human-readable display name.

    The label is looked up as given first; on a miss, its string form is
    tried. If neither is known, the string form is returned unchanged.
    """
    try:
        return LABEL_DISPLAY_MAP[label]
    except KeyError:
        text = str(label)
        return LABEL_DISPLAY_MAP.get(text, text)
# ==========================================
# 1. PAGE SETTINGS
# ==========================================
st.set_page_config(
    page_title="Federated Learning Client",
    layout="wide",
)

# ==========================================
# 2. SIDEBAR (GLOBAL ACTIONS)
# ==========================================
client_id_display = "Unknown"
try:
    # This lookup runs on every script rerun, so bound the wait: without a
    # timeout an unresponsive local API would block the whole page.
    res_config = requests.get(f"{API_URL}/info", timeout=5).json()
    if res_config.get("status") == "success":
        # `data` may be present but null; treat that like a missing payload.
        client_id_display = (res_config.get("data") or {}).get("client_id", "Unknown")
except Exception:
    # Best-effort lookup: the sidebar keeps showing "Unknown" when the API is
    # unreachable, slow, or returns an unexpected payload.
    pass

st.sidebar.title("Client")
st.sidebar.success(f"ID: **{client_id_display}**")
st.sidebar.markdown("---")
if st.sidebar.button("Refresh Application", use_container_width=True):
    st.rerun()
def simple_block_header(text):
    """Render *text* as a styled section banner.

    The banner is raw HTML handed to ``st.markdown`` with
    ``unsafe_allow_html=True``; *text* is interpolated as-is, without any
    escaping.
    """
    banner_html = f"""
<div style="
background: linear-gradient(90deg, #1E293B 0%, #0F172A 100%);
padding: 12px 20px;
border-left: 6px solid #6366F1; /* Màu Indigo của Tab */
border-radius: 4px 15px 15px 4px;
margin-top: 20px;
margin-bottom: 25px;
box-shadow: 5px 5px 15px rgba(0,0,0,0.5);
">
<h3 style="
color: #FFFFFF;
margin: 0;
font-weight: 600;
font-size: 28px;
letter-spacing: 0.5px;
">
{text}
</h3>
</div>
"""
    st.markdown(banner_html, unsafe_allow_html=True)
# Global CSS overrides for the Streamlit tab bar: shared tab geometry, the
# selected / unselected / hover states, the hidden highlight bar, and the
# separator line under the tab list. Injected once at script level.
_TAB_STYLE_CSS = """
<style>
/* 1. Định dạng chung cho tất cả các nút Tab */
button[data-baseweb="tab"] {
padding: 12px 25px !important;
border-radius: 8px 8px 0px 0px !important;
margin-right: 5px !important;
border: none !important;
transition: background-color 0.3s ease !important;
}
/* 2. Khi Tab ĐƯỢC CHỌN (Màu Indigo rực rỡ) */
button[data-baseweb="tab"][aria-selected="true"] {
background-color: #475569 !important; /* Màu Slate (Xanh đá) rất dịu */
border-bottom: 2px solid #94A3B8 !important; /* Thêm vạch nhạt phía dưới để tạo điểm nhấn */
}
/* Chữ của Tab được chọn */
button[data-baseweb="tab"][aria-selected="true"] p {
font-size: 22px !important;
font-weight: 700 !important;
color: #FFFFFF !important;
}
/* 3. Khi Tab CHƯA chọn (Màu trùng với nền Simple Block Header) */
button[data-baseweb="tab"][aria-selected="false"] {
background-color: #1E1E1E !important; /* Cùng màu với cái header của bạn */
color: #666666 !important;
}
/* Chữ của Tab chưa chọn (cho mờ đi để nổi bật tab chính) */
button[data-baseweb="tab"][aria-selected="false"] p {
font-size: 20px !important;
font-weight: 500 !important;
color: #777777 !important;
}
/* Hiệu ứng khi di chuột qua */
button[data-baseweb="tab"][aria-selected="false"]:hover {
background-color: #2D2E3A !important;
color: #FFFFFF !important;
}
/* 4. Dọn dẹp các thành phần thừa của Streamlit */
div[data-baseweb="tab-highlight"] {
background-color: transparent !important;
}
/* Đường line mờ chạy ngang dưới toàn bộ thanh tab */
div[role="tablist"] {
border-bottom: 2px solid #1E1E1E !important;
}
</style>
"""
st.markdown(_TAB_STYLE_CSS, unsafe_allow_html=True)
def render_status_timeline(key_prefix: str = "status"):
    """Render the execution-timeline box with Refresh and Clear actions.

    The timeline is read from ``GET {API_URL}/history_status`` and cleared
    with ``DELETE {API_URL}/history_status``.

    Args:
        key_prefix: Prefix for the widget keys, so the helper can be rendered
            more than once without key collisions.
    """
    st.subheader("Execution Timeline")

    # The first (widest) column is only a spacer that pushes the buttons right.
    _, col_refresh, col_clear = st.columns([2, 1, 1])
    with col_refresh:
        if st.button("Refresh", use_container_width=True, key=f"{key_prefix}_refresh_timeline"):
            st.rerun()
    with col_clear:
        if st.button(
            "Clear",
            use_container_width=True,
            help="Clear the entire execution timeline",
            key=f"{key_prefix}_clear_timeline",
        ):
            try:
                res = requests.delete(f"{API_URL}/history_status").json()
                if res.get("status") == "success":
                    st.toast("Execution history cleared successfully.")
                    st.rerun()
                else:
                    st.error("Could not clear history.")
            except Exception as e:
                st.error(f"Clear action failed: {e}")

    # Only the network call and JSON decoding are reported as a connection
    # error; payload problems are handled separately below.
    try:
        res = requests.get(f"{API_URL}/history_status").json()
    except Exception as e:
        st.error(f"Connection error while fetching timeline: {e}")
        return

    if not isinstance(res, dict) or res.get("status") != "success":
        st.warning("Could not load the execution timeline.")
        return

    history = res.get("data") or []
    if not history:
        st.info("No events yet.")
        return

    with st.container(height=250):
        for item in history:
            if not isinstance(item, dict):
                # Skip malformed entries instead of failing the whole box.
                continue
            message = str(item.get("message") or "")
            # An error mention in the message takes precedence over the
            # training flag.
            if "error" in message.lower():
                status_tag = "[ERROR]"
            elif item.get("is_training"):
                status_tag = "[TRAINING]"
            else:
                status_tag = "[DONE]"
            round_idx = item.get("round_idx")
            round_prefix = f"**Round {round_idx}**: " if round_idx is not None else ""
            timestamp = item.get("timestamp", "?")
            st.markdown(f"`[{timestamp}]` {status_tag} {round_prefix}{message}")
def trigger_disconnect(button_key: str = "disconnect_btn"):
    """Render the disconnect button; when clicked, POST to the local API's ``/disconnect`` endpoint.

    Args:
        button_key: Streamlit widget key for the button.
    """
    pressed = st.button(
        "Disconnect from Global Server",
        use_container_width=True,
        key=button_key,
    )
    if not pressed:
        return

    with st.spinner("Disconnecting..."):
        try:
            payload = requests.post(f"{API_URL}/disconnect").json()
            if payload.get("status") != "success":
                st.error(payload.get("message", "Disconnect request failed."))
            else:
                st.success(payload.get("message", "Disconnected successfully."))
                st.rerun()
        except Exception as exc:
            st.error(f"Failed to disconnect. Error: {exc}")
def _parse_int_field(name: str, raw: str, min_value: int = None, max_value: int = None) -> int:
    """Parse a form field as an integer and check it against optional bounds.

    Args:
        name: Field label used in the error messages.
        raw: Raw input; converted with ``str()`` and stripped before parsing.
        min_value: Inclusive lower bound, or ``None`` for no lower bound.
        max_value: Inclusive upper bound, or ``None`` for no upper bound.

    Returns:
        The parsed integer.

    Raises:
        ValueError: If the input is not an integer or lies outside the bounds.
    """
    text = str(raw).strip()
    try:
        value = int(text)
    except ValueError as exc:
        # int() on a str only raises ValueError; keep the cause for debugging.
        raise ValueError(f"{name} must be an integer.") from exc
    if min_value is not None and value < min_value:
        raise ValueError(f"{name} must be >= {min_value}.")
    if max_value is not None and value > max_value:
        raise ValueError(f"{name} must be <= {max_value}.")
    return value
def _parse_float_field(name: str, raw: str, min_value: float = None, max_value: float = None) -> float:
    """Parse a form field as a finite float and check it against optional bounds.

    Args:
        name: Field label used in the error messages.
        raw: Raw input; converted with ``str()`` and stripped before parsing.
        min_value: Inclusive lower bound, or ``None`` for no lower bound.
        max_value: Inclusive upper bound, or ``None`` for no upper bound.

    Returns:
        The parsed float.

    Raises:
        ValueError: If the input is not a finite number or lies outside the
            bounds.
    """
    import math  # local import: this file does not import math at module level

    text = str(raw).strip()
    try:
        value = float(text)
    except ValueError as exc:
        raise ValueError(f"{name} must be a number.") from exc
    # float() accepts "nan" and "inf". NaN compares False against every bound,
    # so without this check it would slip through both range checks below.
    if not math.isfinite(value):
        raise ValueError(f"{name} must be a number.")
    if min_value is not None and value < min_value:
        raise ValueError(f"{name} must be >= {min_value}.")
    if max_value is not None and value > max_value:
        raise ValueError(f"{name} must be <= {max_value}.")
    return value
def render_local_training_configuration():
    """Render the editable local-training configuration form.

    The current configuration is read from ``GET {API_URL}/config``. On
    submit, the fields are validated and the configuration is written back
    with ``PUT {API_URL}/config``. ``model_name`` and ``lr`` are not editable
    here; they are passed through from the fetched configuration.
    """
    st.subheader("Local Training Configuration")

    current_config = {}
    try:
        config_res = requests.get(f"{API_URL}/config").json()
        if config_res.get("status") == "success":
            current_config = config_res.get("data") or {}
    except Exception:
        st.warning("Failed to fetch the current configuration from the Local Client.")
        return

    if not current_config:
        # The API answered but returned no usable configuration; say so rather
        # than leaving the section blank.
        st.info("No configuration is available from the Local Client yet.")
        return

    def _as_text(key, fallback, cast):
        """Return the stored value for *key* as text, or *fallback* if it cannot be cast."""
        try:
            return str(cast(current_config.get(key, fallback)))
        except (TypeError, ValueError):
            # A null or malformed stored value must not crash the page; the
            # user can still correct it through the form.
            return str(fallback)

    with st.form("config_form_shared"):
        c1, c2 = st.columns(2)
        with c1:
            sample_limit_raw = st.text_input(
                "Sample Limit",
                value=str(current_config.get("sample_limit", 150)),
                key="ui_sample_limit",
            )
            labeled_ratio_raw = st.text_input(
                "Labeled Ratio",
                value=_as_text("labeled_ratio", 0.5, float),
                key="ui_labeled_ratio",
            )
        with c2:
            batch_size_raw = st.text_input(
                "Batch Size",
                value=str(current_config.get("batch_size", 16)),
                key="ui_batch_size",
            )
            seed_raw = st.text_input(
                "Seed",
                value=str(current_config.get("seed", 42)),
                key="ui_seed",
            )

        # Epoch settings are shown and editable on the same form.
        e1, e2 = st.columns(2)
        with e1:
            pretrain_epochs_raw = st.text_input(
                "Local Pretrain Epochs",
                value=_as_text("local_pretrain_epochs", 1, int),
                key="ui_local_pretrain_epochs",
            )
        with e2:
            ssl_epochs_raw = st.text_input(
                "Local SSL Epochs",
                value=_as_text("local_ssl_epochs", 1, int),
                key="ui_local_ssl_epochs",
            )

        submitted = st.form_submit_button("Save Configuration", type="primary", use_container_width=True)

    if not submitted:
        return

    try:
        sample_limit = _parse_int_field("Sample Limit", sample_limit_raw, min_value=20, max_value=200000)
        labeled_ratio = _parse_float_field("Labeled Ratio", labeled_ratio_raw, min_value=0.05, max_value=0.95)
        batch_size = _parse_int_field("Batch Size", batch_size_raw, min_value=2, max_value=1024)
        seed = _parse_int_field("Seed", seed_raw, min_value=0, max_value=100000)
        local_pretrain_epochs = _parse_int_field("Local Pretrain Epochs", pretrain_epochs_raw, min_value=1, max_value=100)
        local_ssl_epochs = _parse_int_field("Local SSL Epochs", ssl_epochs_raw, min_value=1, max_value=100)
        # The pass-through learning rate comes from the server; validate it
        # here so a null or malformed value is reported instead of raising.
        lr = _parse_float_field("Learning Rate", current_config.get("lr", 1e-3))
    except ValueError as ve:
        st.error(str(ve))
        return

    payload = {
        "sample_limit": sample_limit,
        "labeled_ratio": labeled_ratio,
        "batch_size": batch_size,
        "seed": seed,
        "model_name": current_config.get("model_name", "fedalmer"),
        "lr": lr,
        "local_pretrain_epochs": local_pretrain_epochs,
        "local_ssl_epochs": local_ssl_epochs,
    }
    try:
        put_res = requests.put(f"{API_URL}/config", json=payload).json()
        if put_res.get("status") == "success":
            st.success("Configuration updated successfully!")
            st.rerun()
        else:
            st.error(f"Update error: {put_res.get('message')}")
    except Exception as e:
        st.error(f"API call error: {e}")
def render_federated_ssl_client():
    """Render the auto-mode client panel: connect/disconnect controls followed by the status timeline."""
    st.subheader("Federated Client (Auto Mode)")
    st.markdown("Manage training configuration and connection to the Global Server.")

    # --- 1. CONNECT TO GLOBAL SERVER ---
    st.subheader("Connectivity")
    # Guidance text sits above the action buttons.
    st.info(
        "Click the button to activate the Background Task. The client will continuously listen for 'NEW_ROUND' commands from the Global Server."
    )

    # Connect and disconnect live side by side.
    connect_col, disconnect_col = st.columns(2)
    with connect_col:
        connect_clicked = st.button(
            "Connect to Server",
            use_container_width=True,
            type="primary",
            key="connect_auto_btn",
        )
        if connect_clicked:
            with st.spinner("Connecting..."):
                try:
                    reply = requests.get(f"{API_URL}/connect_to_global").json()
                    notify = st.success if reply.get("status") == "success" else st.error
                    notify(reply.get("message"))
                except Exception as exc:
                    st.error(f"Failed to connect to API. Error: {exc}")
    with disconnect_col:
        trigger_disconnect("disconnect_auto_btn")

    render_status_timeline("auto")
def render_federated_ssl_client_manual():
    """Render the manual-mode client panel.

    Shows connect/disconnect controls, the status timeline, and a submission
    section. Submission is enabled only when the latest status message from
    ``GET {API_URL}/last_status`` contains "Waiting for Manual Submit".
    """
    st.subheader("Federated Client (Manual Mode)")
    st.markdown("Manual training mode: receive commands from Admin and submit model updates manually.")

    # --- 1. CONNECT TO GLOBAL SERVER (MANUAL) ---
    st.subheader("Connectivity")
    st.info("Enable manual listening mode. The client will wait for 'NEW_ROUND_MANUAL' commands from Admin Dashboard.")
    col_connect_manual, col_disconnect_manual = st.columns(2)
    with col_connect_manual:
        if st.button("Connect to Server", use_container_width=True, type="primary", key="connect_manual_btn"):
            with st.spinner("Connecting to manual stream..."):
                try:
                    res = requests.get(f"{API_URL}/connect_to_global_manual")
                    data = res.json()
                    if data.get("status") == "success":
                        st.success(data.get("message"))
                    else:
                        st.error(data.get("message"))
                except Exception as e:
                    st.error(f"Connection error: {e}")
    with col_disconnect_manual:
        trigger_disconnect("disconnect_manual_btn")

    st.divider()
    # --- SHOW TIMELINE ---
    render_status_timeline("manual")
    st.divider()

    # --- 2. MANUAL SUBMISSION ---
    st.subheader("Model Submission")

    # 2.1 Check the current status to decide whether submission is allowed.
    can_submit = False
    current_round = 0
    try:
        status_res = requests.get(f"{API_URL}/last_status").json()
        if status_res.get("status") == "success":
            status_data = status_res.get("data") or {}
            # The service signals readiness through its status text; a null
            # message is treated as empty.
            if "Waiting for Manual Submit" in str(status_data.get("message") or ""):
                can_submit = True
                current_round = status_data.get("round_idx") or 0
    except Exception as e:
        # Best-effort check: submission stays disabled, but say why instead of
        # failing silently.
        st.caption(f"Could not read the client status: {e}")

    # 2.2 Render state-specific UI.
    if can_submit:
        st.success(f"Round {current_round} training is complete. Client is ready to submit to Global Server.")
        if st.button(f"Confirm Submission for Round {current_round}", type="primary", use_container_width=True):
            with st.spinner("Compressing and uploading model to Global Server..."):
                try:
                    res = requests.post(f"{API_URL}/manual_submit")
                    data = res.json()
                    if data.get("status") == "success":
                        st.success(data.get("message"))
                        st.rerun()  # Refresh so the submit button hides automatically.
                    else:
                        st.error(data.get("message"))
                except Exception as e:
                    st.error(f"Submission error: {e}")
    else:
        st.warning(
            "No new trained model is ready for submission yet. Please wait for the next server training command.")
        # A disabled button communicates the action that will become available.
        st.button("Submit Trained Model to Global", disabled=True, use_container_width=True)
def _api_succeeded(response):
    """Return True when *response* is HTTP 200 with a JSON object whose ``status`` is ``"success"``."""
    if response.status_code != 200:
        return False
    try:
        body = response.json()
    except ValueError:
        # Non-JSON body (e.g. an HTML error page).
        return False
    return isinstance(body, dict) and body.get("status") == "success"


def _api_message(response, default):
    """Return the ``message`` field of a JSON response, or *default* when it is missing or the body is not JSON."""
    try:
        body = response.json()
    except ValueError:
        return default
    if isinstance(body, dict):
        return body.get("message") or default
    return default


def render_client_data_upload():
    """Render the data-management page: batch upload, per-sample add, and delete.

    The upload, save and delete buttons are disabled while
    ``GET {API_URL}/last_status`` reports ``is_training``. Per-file transcript
    and saved state is kept in ``st.session_state.samples_state``, keyed by
    ``"<index>_<filename>"``.
    """
    simple_block_header("Client Data Upload")
    st.markdown("Manage and extend local training data for this client.")

    dataset_locked = False
    try:
        lock_res = requests.get(f"{API_URL}/last_status").json()
        if lock_res.get("status") == "success":
            dataset_locked = bool((lock_res.get("data") or {}).get("is_training", False))
    except Exception:
        # If the status cannot be read, leave the dataset unlocked.
        dataset_locked = False
    if dataset_locked:
        st.warning("Dataset is locked while local training is running. Upload/Add/Delete is temporarily disabled.")

    # Per-file state, so several files can be handled at once.
    if "samples_state" not in st.session_state:
        st.session_state.samples_state = {}

    # Tabs organize the two upload workflows.
    tab_bulk, tab_single = st.tabs([
        "Upload Batch (.pkl)",
        "Add Samples",
    ])

    # ==========================================
    # TAB 1: BATCH UPLOAD
    # ==========================================
    with tab_bulk:
        st.subheader("Upload Batch Data (.pkl)")
        split_name = st.radio(
            "Select split:", ["labeled", "unlabeled"], horizontal=True, key="bulk_split"
        )
        # NOTE(review): how the server loads the uploaded .pkl is not visible
        # here - confirm it does not unpickle untrusted uploads.
        uploaded_file = st.file_uploader(f"Upload {split_name} file (.pkl)", type=["pkl"])
        if st.button("Upload .pkl", type="primary", disabled=dataset_locked):
            if uploaded_file:
                with st.spinner("Processing..."):
                    files = {"file": (uploaded_file.name, uploaded_file.getvalue(), "application/octet-stream")}
                    try:
                        res = requests.put(f"{API_URL}/data/{split_name}", files=files)
                    except requests.RequestException as e:
                        st.error(f"Upload failed: {e}")
                    else:
                        if _api_succeeded(res):
                            st.success(_api_message(res, "Upload completed."))
                        else:
                            st.error(_api_message(res, "Upload failed."))
            else:
                st.warning("Please select a file.")

    # ==========================================
    # TAB 2: ADD SAMPLES (multi-file upload)
    # ==========================================
    with tab_single:
        st.subheader("Add Samples (WAV / Record)")
        audio_source = st.radio("Choose audio source:", ["Upload .wav (Multiple)", "Record Audio"], horizontal=True)

        # Both sources feed one list, so the per-sample UI below is shared.
        raw_audios = []
        if audio_source == "Upload .wav (Multiple)":
            raw_audios = st.file_uploader("Upload .wav files", type=["wav"], accept_multiple_files=True)
        else:
            if hasattr(st, "audio_input"):
                recorded = st.audio_input("Record your voice")
                if recorded is not None:
                    raw_audios = [recorded]
            else:
                st.warning("Your Streamlit version does not support audio recording.")

        if raw_audios:
            st.divider()
            st.markdown(f"**Processing {len(raw_audios)} sample(s):**")
            for i, audio_file in enumerate(raw_audios):
                # Use the file name, or a generated one for recordings.
                filename = getattr(audio_file, "name", f"recorded_audio_{i}.wav")
                file_key = f"{i}_{filename}"  # key for this file's widgets and state

                # Initialize this file's state on first sight.
                if file_key not in st.session_state.samples_state:
                    st.session_state.samples_state[file_key] = {
                        "transcript": "",
                        "saved": False,
                    }
                state = st.session_state.samples_state[file_key]
                status_icon = "(Saved)" if state["saved"] else "(Pending)"

                # One expander per file, collapsed once the sample is saved.
                with st.expander(f"Audio: {filename} - {status_icon}", expanded=not state["saved"]):
                    if state["saved"]:
                        st.success(f"Sample '{filename}' has been successfully added to the dataset!")
                        continue  # no action widgets for an already-saved sample

                    # --- UI for a sample that is not saved yet ---
                    audio_bytes = audio_file.getvalue()
                    st.audio(audio_bytes, format="audio/wav")

                    # 1. Transcribe button
                    if st.button("Transcribe Audio", key=f"transcribe_{file_key}"):
                        with st.spinner("Running speech recognition..."):
                            files = {"audio_file": (filename, audio_bytes, "audio/wav")}
                            try:
                                res = requests.post(f"{API_URL}/inference/transcribe", files=files)
                            except requests.RequestException as e:
                                st.error(f"Transcription API call failed: {e}")
                            else:
                                if _api_succeeded(res):
                                    data = res.json().get("data") or {}
                                    transcript_result = data.get("transcript", "")
                                    st.session_state.samples_state[file_key]["transcript"] = transcript_result
                                    # Also write the widget's own key so the
                                    # text area shows the transcript after the rerun.
                                    st.session_state[f"text_{file_key}"] = transcript_result
                                    st.rerun()
                                else:
                                    st.error("Transcription API call failed.")

                    # 2. Text area (editable)
                    edited_text = st.text_area(
                        "Transcript (editable):",
                        value=state["transcript"],
                        key=f"text_{file_key}",
                    )

                    # 3. Labeling
                    col1, col2 = st.columns(2)
                    with col1:
                        target_split = st.selectbox(
                            "Add to Split:", ["labeled", "unlabeled"],
                            key=f"split_{file_key}",
                        )
                    with col2:
                        emotion_label = st.selectbox(
                            "Emotion Label:",
                            ["ang", "hap", "sad", "neu"],
                            disabled=(target_split == "unlabeled"),
                            key=f"label_{file_key}",
                        )

                    # 4. Save button
                    if st.button(
                        "Save Sample",
                        type="primary",
                        key=f"save_{file_key}",
                        use_container_width=True,
                        disabled=dataset_locked,
                    ):
                        if not edited_text.strip():
                            st.warning("Transcript cannot be empty.")
                        else:
                            with st.spinner(f"Extracting features & saving {filename}..."):
                                files = {"audio_file": (filename, audio_bytes, "audio/wav")}
                                data_payload = {
                                    "text": edited_text,
                                    # Unlabeled samples are sent with an empty label.
                                    "label": emotion_label if target_split == "labeled" else "",
                                }
                                try:
                                    res = requests.post(
                                        f"{API_URL}/data/{target_split}/append",
                                        files=files,
                                        data=data_payload,
                                    )
                                except requests.RequestException as e:
                                    st.error(f"Save failed: {e}")
                                else:
                                    if _api_succeeded(res):
                                        # Mark as saved and keep the final transcript.
                                        st.session_state.samples_state[file_key]["saved"] = True
                                        st.session_state.samples_state[file_key]["transcript"] = edited_text
                                        st.rerun()
                                    else:
                                        st.error(f"Save failed: {_api_message(res, None)}")
        else:
            st.info("Please select or record audio files to begin.")

    st.divider()

    # Manage existing data: one delete button per split.
    col1, col2 = st.columns(2)
    for column, split, title in ((col1, "labeled", "Labeled"), (col2, "unlabeled", "Unlabeled")):
        with column:
            if st.button(f"Delete {title}", use_container_width=True, disabled=dataset_locked):
                try:
                    res = requests.delete(f"{API_URL}/data/{split}").json()
                    if res.get("status") == "success":
                        st.success(f"Deleted {split} data.")
                    else:
                        st.error(res.get("message", f"Delete {split} data failed."))
                except Exception as e:
                    st.error(f"Delete {split} data failed: {e}")
def render_private_dataset_summary():
    """Render sample counts and the labeled class histogram from ``GET {API_URL}/profile``.

    When the response carries no profile data, the user is pointed to the
    publish tab instead.
    """
    simple_block_header("Private Dataset Summary")
    st.markdown("Overview of your local dataset, including data distribution and sample counts.")

    try:
        with st.spinner("Fetching dataset summary from Local Server..."):
            response = requests.get(f"{API_URL}/profile")

        if response.status_code != 200:
            st.error(f"API Error (HTTP {response.status_code}): {response.text}")
            return

        profile = response.json().get("data")
        if not profile:
            # Profile not built yet.
            st.warning("Dataset profile has not been built yet.")
            st.info(
                "Please go to the 'Publish Client Profile To Server' tab to generate and publish your data profile first.")
            return

        # Profile exists: show the counters first.
        st.subheader("Sample Overview")
        labeled_col, unlabeled_col, total_col = st.columns(3)
        n_labeled = profile.get("labeled_count", 0)
        n_unlabeled = profile.get("unlabeled_count", 0)
        n_total = n_labeled + n_unlabeled
        labeled_col.metric("Labeled Samples", f"{n_labeled:,}")
        unlabeled_col.metric("Unlabeled Samples", f"{n_unlabeled:,}")
        total_col.metric("Total Samples", f"{n_total:,}")

        st.divider()

        # Bar chart of the class distribution.
        st.subheader("Labeled Data Distribution")
        st.markdown("Histogram of classes available in the labeled dataset.")
        raw_hist = profile.get("class_hist", {})
        if not raw_hist:
            st.info("No class distribution data available. The dataset might be empty.")
        else:
            display_hist = {_map_label_display(label): count for label, count in raw_hist.items()}
            # A one-column DataFrame indexed by display name feeds the chart.
            chart_frame = pd.DataFrame.from_dict(display_hist, orient="index", columns=["Sample Count"])
            st.bar_chart(chart_frame)

        # Raw profile payload for inspection.
        with st.expander("View Full Profile Metadata"):
            st.json(profile)
    except Exception as e:
        st.error(f"Connection Error to Local Server: {e}")
def render_publish_client_profile():
    """Render the profile status plus the publish / unpublish actions.

    Status comes from ``GET {API_URL}/profile``; publishing uses ``POST`` and
    unpublishing uses ``DELETE`` on the same endpoint.
    """
    simple_block_header("Publish Client Profile To Server")
    st.markdown(
        "Calculate local dataset characteristics and publish the profile. The Global Server requires this information to assign appropriate training tasks and aggregate weights properly.")

    # ---------------------------------------------------------
    # PART 1: CURRENT STATUS
    # ---------------------------------------------------------
    st.subheader("1. Current Profile Status")
    is_published = False
    try:
        status_resp = requests.get(f"{API_URL}/profile")
        if status_resp.status_code != 200:
            st.error(f"Failed to fetch profile status (HTTP {status_resp.status_code}): {status_resp.text}")
        else:
            profile = status_resp.json().get("data")
            if profile:
                is_published = True
                st.success("A profile is currently published and visible to the Global Server.")
                st.write(f"**Last Published At:** {profile.get('published_at', 'Unknown time')}")
            else:
                st.info(
                    "No profile is currently published. The Global Server cannot include this client in training rounds.")
    except Exception as e:
        st.error(f"Connection Error to Local Server: {e}")

    st.divider()

    # ---------------------------------------------------------
    # PART 2: ACTIONS
    # ---------------------------------------------------------
    st.subheader("2. Profile Actions")
    publish_col, unpublish_col = st.columns(2)

    with publish_col:
        # Compute and publish the profile.
        publish_clicked = st.button("Generate and Publish Profile", type="primary", use_container_width=True)
        if publish_clicked:
            with st.spinner("Analyzing local dataset and building profile..."):
                try:
                    post_res = requests.post(f"{API_URL}/profile")
                    if post_res.status_code != 200:
                        st.error(f"API Error (HTTP {post_res.status_code}): {post_res.text}")
                    else:
                        post_data = post_res.json()
                        if post_data.get("status") != "success":
                            st.error(f"System Error: {post_data.get('message')}")
                        else:
                            st.success(post_data.get("message"))
                            st.rerun()  # Refresh so the status section updates.
                except Exception as e:
                    st.error(f"Connection Error: {e}")

    with unpublish_col:
        # Only enabled while a profile is currently published.
        unpublish_clicked = st.button("Unpublish Profile", use_container_width=True, disabled=not is_published)
        if unpublish_clicked:
            try:
                del_res = requests.delete(f"{API_URL}/profile")
                if del_res.status_code != 200:
                    st.error(f"API Error (HTTP {del_res.status_code}): {del_res.text}")
                else:
                    del_data = del_res.json()
                    if del_data.get("status") != "success":
                        st.error(f"System Error: {del_data.get('message')}")
                    else:
                        st.success("Profile has been successfully unpublished.")
                        st.rerun()  # Refresh so the status section updates.
            except Exception as e:
                st.error(f"Connection Error: {e}")
def render_local_round_history():
    """Render per-round training trend charts from ``GET {API_URL}/history``, with a clear-history action."""
    simple_block_header("Local Round History")
    st.markdown("Review the performance metrics and configuration details of completed training rounds.")

    # ---------------------------------------------------------
    # PART 1: CLEAR HISTORY ACTION
    # ---------------------------------------------------------
    _spacer, clear_col = st.columns([4, 1])
    with clear_col:
        if st.button("Clear History", type="secondary", use_container_width=True):
            try:
                outcome = requests.delete(f"{API_URL}/history").json()
                if outcome.get("status") != "success":
                    st.error(f"System Error: {outcome.get('message')}")
                else:
                    st.success("Training history cleared successfully.")
                    st.rerun()
            except Exception as e:
                st.error(f"Connection Error: {e}")

    # ---------------------------------------------------------
    # PART 2: FETCH AND DISPLAY HISTORY
    # ---------------------------------------------------------
    try:
        with st.spinner("Fetching training history from Local Server..."):
            response = requests.get(f"{API_URL}/history")

        if response.status_code != 200:
            st.error(f"API Error (HTTP {response.status_code}): {response.text}")
            return

        rounds = response.json().get("data", {}).get("rounds", [])
        if not rounds:
            st.info("No training history available. The client has not completed any training rounds yet.")
            return

        # One row per round record.
        df = pd.DataFrame(rounds)

        # 2A. VISUALIZATIONS (charts only)
        st.subheader("1. Training Trends")
        if "round" in df.columns:
            loss_metrics = ("pretrain_loss", "ssl_labeled_loss")
            plots = []
            for metric_name, metric_title in (
                ("local_accuracy", "Local Accuracy Progression"),
                ("pretrain_loss", "Pretrain Loss"),
                ("ssl_labeled_loss", "SSL Labeled Loss"),
            ):
                if metric_name not in df.columns:
                    continue
                # Keep only the rounds where this metric is numeric.
                series_df = df[["round", metric_name]].copy()
                series_df[metric_name] = pd.to_numeric(series_df[metric_name], errors="coerce")
                series_df = series_df.dropna(subset=["round", metric_name])
                # Loss charts only show rounds with a strictly positive loss.
                if metric_name in loss_metrics:
                    series_df = series_df[series_df[metric_name] > 0]
                if not series_df.empty:
                    plots.append((metric_title, series_df.set_index("round")[[metric_name]]))

            if plots:
                # Lay the charts out in two alternating columns.
                chart_cols = st.columns(2)
                for position, (title, frame) in enumerate(plots):
                    with chart_cols[position % 2]:
                        st.markdown(f"**{title}**")
                        st.line_chart(frame)
            else:
                st.info("No valid metric values available to plot yet.")
        else:
            st.warning("Missing 'round' field in history data, cannot render charts.")

        st.divider()
    except Exception as e:
        st.error(f"Connection Error to Local Server: {e}")
def render_client_local_logs():
    """Render the client log viewer: refresh / clear actions and a read-only dump of ``GET {API_URL}/logs``."""
    simple_block_header("Client Local Logs")
    st.markdown("Monitor system events, background training progress, and Global Server connection status.")

    # ---------------------------------------------------------
    # PART 1: LOG ACTIONS (REFRESH & CLEAR)
    # ---------------------------------------------------------
    refresh_col, clear_col, _ = st.columns([1, 1, 3])
    with refresh_col:
        # A plain rerun re-fetches the latest logs.
        refresh_clicked = st.button("Refresh Logs", type="primary", use_container_width=True)
        if refresh_clicked:
            st.rerun()
    with clear_col:
        # Delete the stored log history.
        clear_clicked = st.button("Clear Logs", use_container_width=True)
        if clear_clicked:
            try:
                outcome = requests.delete(f"{API_URL}/logs").json()
                if outcome.get("status") != "success":
                    st.error(f"System Error: {outcome.get('message')}")
                else:
                    st.success("Logs cleared successfully.")
                    st.rerun()
            except Exception as e:
                st.error(f"Connection Error: {e}")

    st.divider()

    # ---------------------------------------------------------
    # PART 2: FETCH AND DISPLAY LOGS
    # ---------------------------------------------------------
    try:
        with st.spinner("Fetching latest logs from Local Server..."):
            response = requests.get(f"{API_URL}/logs")

        if response.status_code != 200:
            st.error(f"API Error (HTTP {response.status_code}): {response.text}")
            return

        # `events` is joined with newlines below, so it is treated as a list
        # of log strings.
        events = response.json().get("data", {}).get("events", [])
        if not events:
            st.info("No logs available. The system is currently idle or logs have been cleared.")
            return

        # A disabled text area gives a read-only, terminal-like view.
        st.text_area(
            label="Terminal Output",
            value="\n".join(events),
            height=500,
            disabled=True,
            label_visibility="collapsed",
        )
    except Exception as e:
        st.error(f"Connection Error to Local Server: {e}")
def _inference_audio_input(prefix):
    """Draw the audio source picker for one tab and cache the chosen audio.

    When an upload or recording is present, its bytes and file name are
    stored in ``st.session_state`` under ``{prefix}_audio_bytes`` and
    ``{prefix}_audio_name`` so the button handlers can read them.

    Args:
        prefix: Widget/session key prefix of the calling tab ("reg" or "last").
    """
    source = st.radio(
        "Audio source",
        ["Upload file", "Record microphone"],
        horizontal=True,
        key=f"{prefix}_audio_source",
    )
    audio = None
    if source == "Upload file":
        audio = st.file_uploader(
            "Upload audio",
            type=["wav", "flac", "ogg", "mp3", "m4a"],
            label_visibility="collapsed",
            key=f"{prefix}_uploader",
        )
    elif hasattr(st, "audio_input"):
        audio = st.audio_input("Record audio", key=f"{prefix}_recorder")
    else:
        st.warning("Your Streamlit version does not support audio recording. Please upload a file instead.")

    if audio is not None:
        audio_bytes = audio.getvalue()
        st.audio(audio_bytes)
        st.session_state[f"{prefix}_audio_bytes"] = audio_bytes
        st.session_state[f"{prefix}_audio_name"] = getattr(audio, "name", "recorded_audio.wav")


def _inference_audio_files(prefix):
    """Build the multipart ``files`` mapping from the audio cached for ``prefix``."""
    return {
        "audio_file": (
            st.session_state[f"{prefix}_audio_name"],
            st.session_state[f"{prefix}_audio_bytes"],
            "audio/wav",
        )
    }


def _inference_action_buttons(prefix, predict_label):
    """Draw the transcribe/predict button row.

    Args:
        prefix: Widget key prefix of the calling tab.
        predict_label: Caption of the second (predict) button.

    Returns:
        Tuple ``(transcribe_clicked, predict_clicked)``.
    """
    left, right = st.columns(2)
    with left:
        transcribe_clicked = st.button(
            "1) Transcribe Audio",
            use_container_width=True,
            type="primary",
            key=f"{prefix}_transcribe_btn",
        )
    with right:
        predict_clicked = st.button(
            predict_label,
            use_container_width=True,
            type="primary",
            key=f"{prefix}_predict_btn",
        )
    return transcribe_clicked, predict_clicked


def _inference_transcribe(prefix):
    """Send the cached audio to the transcribe endpoint and fill the transcript box.

    On success the transcript is written to ``{prefix}_asr_transcript`` and to
    ``{prefix}_transcript_box`` (the key of the editable text area), then the
    script is rerun. Failures are reported with ``st.error``.
    """
    if f"{prefix}_audio_bytes" not in st.session_state:
        st.error("Please provide an audio file first.")
        return

    with st.spinner("Running ASR transcription via API..."):
        # UI boundary: any request/parse failure is reported, not raised.
        try:
            res = requests.post(
                f"{API_URL}/inference/transcribe",
                files=_inference_audio_files(prefix),
            ).json()
            if res.get("status") != "success":
                st.error(f"Transcription failed: {res.get('message')}")
                return
            transcript = (res.get("data") or {}).get("transcript", "")
        except Exception as e:
            st.error(f"API Connection Error: {e}")
            return

    st.session_state[f"{prefix}_asr_transcript"] = transcript
    # The text area is bound to this key, so the value must be injected here,
    # before the widget is instantiated in this run; the rerun then redraws
    # the page with the new transcript.
    st.session_state[f"{prefix}_transcript_box"] = transcript
    st.rerun()


def _render_registry_inference_tab():
    """Render the "Registry Inference" tab: model selection, audio, prediction.

    Returns early (after showing a message) when metadata cannot be loaded or
    no model is available; the caller keeps rendering the other tab.
    """
    try:
        meta_res = requests.get(f"{API_URL}/inference/metadata").json()
        if meta_res.get("status") != "success":
            st.error(f"Failed to load metadata: {meta_res.get('message')}")
            return
        meta_data = meta_res.get("data", {})
        methods = meta_data.get("methods", ["FedAvg", "FedProx"])
        clients_opts = meta_data.get("num_clients_options", [1, 5, 10])
        models = meta_data.get("available_models", ["fedalmer"])
    except Exception as e:
        st.error(f"Cannot connect to API to fetch metadata: {e}")
        return

    # NOTE(review): the previous code read "label_ratios" from the metadata and
    # immediately overwrote it with this fixed list. The fixed list is kept to
    # preserve behaviour -- confirm whether the server value should be used.
    ratios = ["0p5", "1"]
    ratio_display_labels = {"0p5": "0.5", "1": "1.0"}

    if not models:
        st.warning("No models available.")
        return

    with st.expander("Inference Configuration", expanded=True):
        col1, col2, col3 = st.columns(3)
        with col1:
            selected_method = st.selectbox("FL method", methods, key="reg_method")
        with col2:
            selected_ratio = st.selectbox(
                "Label ratio",
                ratios,
                key="reg_ratio",
                format_func=lambda x: ratio_display_labels.get(x, x),
            )
        with col3:
            selected_clients = st.selectbox("Number of clients", clients_opts, key="reg_clients")

        st.divider()

        col4, col5 = st.columns(2)
        with col4:
            target_model = st.selectbox("Target Model", models, index=0, key="reg_target_model")
        with col5:
            baseline_candidates = [m for m in models if m != target_model]
            selected_baselines = []
            if baseline_candidates:
                compare_enabled = st.checkbox(
                    "Compare baseline models", value=False, key="reg_compare_enabled"
                )
                if compare_enabled:
                    selected_baselines = st.multiselect(
                        "Baseline models",
                        baseline_candidates,
                        default=baseline_candidates[:2],
                        key="reg_baselines",
                    )

    _inference_audio_input("reg")
    transcribe_clicked, predict_clicked = _inference_action_buttons("reg", "2) Predict Emotion")

    # --- Handle transcription (tab 1) ---
    if transcribe_clicked:
        _inference_transcribe("reg")

    # No `value=` here: the content comes from session state via the key.
    reg_transcript = st.text_area(
        "Edit transcript before prediction",
        height=100,
        key="reg_transcript_box",
    )

    # --- Handle prediction (tab 1) ---
    if predict_clicked:
        if "reg_audio_bytes" not in st.session_state:
            st.error("Please provide an audio file first.")
        elif not reg_transcript.strip():
            st.error("Transcript is empty. Please transcribe or type manually first.")
        else:
            with st.spinner("Running emotion prediction via API..."):
                try:
                    payload = {
                        "transcript": reg_transcript,
                        "method": selected_method,
                        "ratio": str(selected_ratio),
                        "clients": selected_clients,
                    }
                    is_comparison = bool(selected_baselines)
                    if is_comparison:
                        payload["target_model"] = target_model
                        payload["baselines"] = ",".join(selected_baselines)
                        endpoint = "compare"
                    else:
                        payload["model_name"] = target_model
                        endpoint = "predict"

                    res = requests.post(
                        f"{API_URL}/inference/{endpoint}",
                        files=_inference_audio_files("reg"),
                        data=payload,
                    ).json()

                    if res.get("status") == "success":
                        # Store the result and its shape flag together, and only
                        # on success, so a failed request can never leave the
                        # flag out of sync with the previously stored result.
                        st.session_state["reg_prediction_results"] = res.get("data")
                        st.session_state["reg_is_comparison"] = is_comparison
                    else:
                        st.error(f"Prediction failed: {res.get('message')}")
                except Exception as e:
                    st.error(f"API Connection Error: {e}")

    if "reg_prediction_results" in st.session_state:
        st.divider()
        st.subheader("Emotion Output")
        reg_results = st.session_state["reg_prediction_results"]

        if not st.session_state.get("reg_is_comparison", False):
            # Single-model result: one mapping with the prediction fields.
            st.metric(
                label=f"Predicted Emotion ({reg_results.get('model_name', 'unknown')})",
                value=reg_results.get("emotion", "Unknown"),
                delta=f"Conf: {reg_results.get('confidence', 0):.4f}",
                delta_color="normal",
            )
            st.caption(
                f"FL Method: {reg_results.get('fl_method')} | Audio Length: {reg_results.get('audio_seconds', 0):.2f}s")
        else:
            # Comparison result: one entry per model, shown as a table.
            table_data = [
                {
                    "Model": item.get("model", "Unknown"),
                    "Emotion": item.get("emotion", "ERROR"),
                    "Confidence": f"{item.get('confidence', 0):.4f}" if "confidence" in item else "",
                    "Status": item.get("status", "ok"),
                }
                for item in reg_results
            ]
            st.dataframe(table_data, use_container_width=True)


def _render_last_session_tab():
    """Render the "Last Session Only" tab: audio, transcript, single prediction."""
    st.info("This tab only calls /inference/predict_last_session and does not support baseline comparison.")

    _inference_audio_input("last")
    transcribe_clicked, predict_clicked = _inference_action_buttons("last", "2) Predict Last Session")

    # --- Handle transcription (tab 2) ---
    if transcribe_clicked:
        _inference_transcribe("last")

    # No `value=` here: the content comes from session state via the key.
    last_transcript = st.text_area(
        "Edit transcript before prediction",
        height=100,
        key="last_transcript_box",
    )

    # --- Handle prediction (tab 2) ---
    if predict_clicked:
        if "last_audio_bytes" not in st.session_state:
            st.error("Please provide an audio file first.")
        elif not last_transcript.strip():
            st.error("Transcript is empty. Please transcribe or type manually first.")
        else:
            with st.spinner("Running emotion prediction via API..."):
                try:
                    res = requests.post(
                        f"{API_URL}/inference/predict_last_session",
                        files=_inference_audio_files("last"),
                        data={"transcript": last_transcript},
                    ).json()
                    if res.get("status") == "success":
                        st.session_state["last_prediction_results"] = res.get("data")
                    else:
                        st.error(f"Prediction failed: {res.get('message')}")
                except Exception as e:
                    st.error(f"API Connection Error: {e}")

    if "last_prediction_results" in st.session_state:
        st.divider()
        st.subheader("Last Session Emotion Output")
        last_results = st.session_state["last_prediction_results"]
        st.metric(
            label=f"Predicted Emotion ({last_results.get('model_name', 'last_session')})",
            value=last_results.get("emotion", "Unknown"),
            delta=f"Conf: {last_results.get('confidence', 0):.4f}",
            delta_color="normal",
        )
        st.caption(
            f"Source: {last_results.get('checkpoint_source', 'final_model/final_global.pt')} | "
            f"Audio Length: {last_results.get('audio_seconds', 0):.2f}s"
        )


def render_inference_speech_emotion():
    """Render the speech-emotion inference section with its two tabs.

    The "Registry Inference" tab lets the user pick a model configuration
    (optionally with baselines to compare) and the "Last Session Only" tab
    predicts through ``/inference/predict_last_session``. Each tab is drawn
    by its own helper, so an early exit in the registry tab (for example when
    metadata cannot be fetched) does not prevent the second tab from rendering.
    """
    simple_block_header("FedalMER Inference: Speech Emotion")
    st.caption("Input: Audio | Output: Emotion label with confidence (Powered by Federated Server APIs)")

    tab_registry, tab_last_session = st.tabs(["Registry Inference", "Last Session Only"])

    # ==============================================================
    # TAB 1: REGISTRY INFERENCE
    # ==============================================================
    with tab_registry:
        _render_registry_inference_tab()

    # ==============================================================
    # TAB 2: LAST SESSION ONLY
    # ==============================================================
    with tab_last_session:
        _render_last_session_tab()
def render_client_section():
    """Render the "Client" block: read-only identity fields and a disconnect control.

    Shows the client ID and the local API URL as disabled text inputs and
    invokes ``trigger_disconnect`` inside the third, narrower column.
    """
    simple_block_header("Client")
    st.markdown("Client identity and quick connection control.")

    id_col, api_col, action_col = st.columns([2, 2, 1])
    id_col.text_input("Client ID", value=client_id_display, disabled=True)
    api_col.text_input("Local API", value=API_URL, disabled=True)
    with action_col:
        trigger_disconnect("disconnect_client_section_btn")
def render_fl_combined_section():
    """Render the "Federated Learning" block.

    Draws the local training configuration, a divider, and then one tab per
    client mode ("FL Auto" and "FL Manual"), each filled by its own renderer.
    """
    simple_block_header("Federated Learning")
    render_local_training_configuration()
    st.divider()

    tab_renderers = (
        ("FL Auto", render_federated_ssl_client),
        ("FL Manual", render_federated_ssl_client_manual),
    )
    tabs = st.tabs([title for title, _ in tab_renderers])
    for tab, (_, render_tab) in zip(tabs, tab_renderers):
        with tab:
            render_tab()
def render_single_page_layout():
    """Render the full single-page UI.

    Emits a centered HTML title, then each page section in order, with a
    divider drawn before every section.
    """
    # Custom HTML heading used in place of st.title.
    title_html = (
        "\n"
        "<h1 style='text-align: center; font-size: 80px; font-weight: 800; color: white; margin-bottom: 30px;'>\n"
        "Federated Learning Client\n"
        "</h1>\n"
    )
    st.markdown(title_html, unsafe_allow_html=True)

    page_sections = (
        render_client_data_upload,
        render_private_dataset_summary,
        render_fl_combined_section,
        render_local_round_history,
        render_inference_speech_emotion,
    )
    for render_section in page_sections:
        st.divider()
        render_section()
# Script entry point: draw the whole page when this file is executed.
render_single_page_layout()