client_001_gui / src /streamlit_app.py
TrungNt14's picture
Update src/streamlit_app.py
7033fe3 verified
import json
import os
import pandas as pd
import requests
import streamlit as st
# Base URL of the local client REST API; overridable via the API_URL env var.
API_URL = os.environ.get("API_URL", "http://127.0.0.1:8002/api")

# Display names indexed by numeric class id (0..3).
_EMOTION_NAMES = ("anger", "happiness/excitement", "sadness", "neutral")

# Maps each handled label representation (short string code, integer class
# id, stringified class id) to its display name. "exc" shares the
# "happiness/excitement" class with "hap".
LABEL_DISPLAY_MAP = {
    "ang": _EMOTION_NAMES[0],
    "hap": _EMOTION_NAMES[1],
    "sad": _EMOTION_NAMES[2],
    "neu": _EMOTION_NAMES[3],
    "exc": _EMOTION_NAMES[1],
}
LABEL_DISPLAY_MAP.update(enumerate(_EMOTION_NAMES))
LABEL_DISPLAY_MAP.update((str(class_id), name) for class_id, name in enumerate(_EMOTION_NAMES))


def _map_label_display(label):
    """Return the human-readable emotion name for ``label``.

    ``label`` is looked up as given first, then by its ``str()`` form.
    Unknown labels are returned as their ``str()`` form.
    """
    try:
        return LABEL_DISPLAY_MAP[label]
    except KeyError:
        text_label = str(label)
        return LABEL_DISPLAY_MAP.get(text_label, text_label)
# ==========================================
# 1. PAGE SETTINGS
# ==========================================
st.set_page_config(
    page_title="Federated Learning Client",
    layout="wide",
)
# ==========================================
# 2. SIDEBAR (GLOBAL ACTIONS)
# ==========================================
# Best-effort lookup of this client's id for the sidebar badge. Any failure
# (API down, non-JSON body, unexpected payload) leaves it as "Unknown".
client_id_display = "Unknown"
try:
    # This request runs on every script rerun. requests has no default timeout,
    # so bound it to keep an unresponsive local API from stalling the page.
    res_config = requests.get(f"{API_URL}/info", timeout=5).json()
    if res_config.get("status") == "success":
        # "data" may be null, and a null/empty client_id also falls back to "Unknown".
        client_id_display = (res_config.get("data") or {}).get("client_id") or "Unknown"
except Exception:
    pass  # deliberate: the sidebar must render even when the API is unavailable
st.sidebar.title("Client")
st.sidebar.success(f"ID: **{client_id_display}**")
st.sidebar.markdown("---")
if st.sidebar.button("Refresh Application", use_container_width=True):
    st.rerun()
def simple_block_header(text):
    """Render ``text`` as a large section title inside a styled dark banner.

    The banner is raw HTML (gradient background, indigo left border, drop
    shadow) passed to ``st.markdown`` with ``unsafe_allow_html=True``.
    ``text`` is interpolated into that HTML as-is, without escaping.
    """
    banner_html = f"""
<div style="
background: linear-gradient(90deg, #1E293B 0%, #0F172A 100%);
padding: 12px 20px;
border-left: 6px solid #6366F1; /* Màu Indigo của Tab */
border-radius: 4px 15px 15px 4px;
margin-top: 20px;
margin-bottom: 25px;
box-shadow: 5px 5px 15px rgba(0,0,0,0.5);
">
<h3 style="
color: #FFFFFF;
margin: 0;
font-weight: 600;
font-size: 28px;
letter-spacing: 0.5px;
">
{text}
</h3>
</div>
"""
    st.markdown(banner_html, unsafe_allow_html=True)
# Global CSS for st.tabs. It styles tab buttons in their selected, unselected
# and hover states, makes Streamlit's default tab-highlight bar transparent,
# and draws a thin line under the tab list. It is injected on every script run.
# NOTE(review): some CSS comments name colours that differ from the actual
# values (e.g. "Indigo" for the slate #475569); the values are authoritative.
_TAB_STYLE_CSS = """
<style>
/* 1. Định dạng chung cho tất cả các nút Tab */
button[data-baseweb="tab"] {
padding: 12px 25px !important;
border-radius: 8px 8px 0px 0px !important;
margin-right: 5px !important;
border: none !important;
transition: background-color 0.3s ease !important;
}
/* 2. Khi Tab ĐƯỢC CHỌN (Màu Indigo rực rỡ) */
button[data-baseweb="tab"][aria-selected="true"] {
background-color: #475569 !important; /* Màu Slate (Xanh đá) rất dịu */
border-bottom: 2px solid #94A3B8 !important; /* Thêm vạch nhạt phía dưới để tạo điểm nhấn */
}
/* Chữ của Tab được chọn */
button[data-baseweb="tab"][aria-selected="true"] p {
font-size: 22px !important;
font-weight: 700 !important;
color: #FFFFFF !important;
}
/* 3. Khi Tab CHƯA chọn (Màu trùng với nền Simple Block Header) */
button[data-baseweb="tab"][aria-selected="false"] {
background-color: #1E1E1E !important; /* Cùng màu với cái header của bạn */
color: #666666 !important;
}
/* Chữ của Tab chưa chọn (cho mờ đi để nổi bật tab chính) */
button[data-baseweb="tab"][aria-selected="false"] p {
font-size: 20px !important;
font-weight: 500 !important;
color: #777777 !important;
}
/* Hiệu ứng khi di chuột qua */
button[data-baseweb="tab"][aria-selected="false"]:hover {
background-color: #2D2E3A !important;
color: #FFFFFF !important;
}
/* 4. Dọn dẹp các thành phần thừa của Streamlit */
div[data-baseweb="tab-highlight"] {
background-color: transparent !important;
}
/* Đường line mờ chạy ngang dưới toàn bộ thanh tab */
div[role="tablist"] {
border-bottom: 2px solid #1E1E1E !important;
}
</style>
"""
st.markdown(_TAB_STYLE_CSS, unsafe_allow_html=True)
def _format_timeline_item(item):
    """Format one timeline event dict as a Markdown line.

    Produces ``[timestamp] TAG Round N: message``. TAG is ``[ERROR]`` when the
    message mentions "error", otherwise ``[TRAINING]`` / ``[DONE]`` from
    ``is_training``. Missing or null fields fall back to empty values, so one
    malformed event cannot break the whole timeline.
    """
    message = str(item.get("message") or "")
    if "error" in message.lower():
        status_tag = "[ERROR]"
    elif item.get("is_training"):
        status_tag = "[TRAINING]"
    else:
        status_tag = "[DONE]"
    round_idx = item.get("round_idx")
    round_prefix = f"**Round {round_idx}**: " if round_idx is not None else ""
    return f"`[{item.get('timestamp', '')}]` {status_tag} {round_prefix}{message}"


def render_status_timeline(key_prefix: str = "status"):
    """Render the execution-timeline box with Refresh and Clear controls.

    Events are read from ``GET {API_URL}/history_status`` and cleared with
    ``DELETE {API_URL}/history_status``.

    Args:
        key_prefix: Prefix for the widget keys, so the timeline can be shown in
            several places without duplicate Streamlit widget keys.
    """
    st.subheader("Execution Timeline")
    # The first (widest) column is only a spacer that pushes the buttons right.
    _, col_refresh, col_clear = st.columns([2, 1, 1])
    with col_refresh:
        if st.button("Refresh", use_container_width=True, key=f"{key_prefix}_refresh_timeline"):
            st.rerun()
    with col_clear:
        if st.button(
            "Clear",
            use_container_width=True,
            help="Clear the entire execution timeline",
            key=f"{key_prefix}_clear_timeline",
        ):
            try:
                res = requests.delete(f"{API_URL}/history_status").json()
                if res.get("status") == "success":
                    st.toast("Execution history cleared successfully.")
                    st.rerun()
                else:
                    st.error("Could not clear history.")
            except Exception as e:
                st.error(f"Clear action failed: {e}")
    try:
        # Load and display status timeline data.
        res = requests.get(f"{API_URL}/history_status").json()
        if res.get("status") == "success":
            history = res.get("data") or []
            if not history:
                st.info("No events yet.")
            else:
                with st.container(height=250):
                    for item in history:
                        st.markdown(_format_timeline_item(item))
    except Exception as e:
        st.error(f"Connection error while fetching timeline: {e}")
def trigger_disconnect(button_key: str = "disconnect_btn"):
    """Render a "Disconnect from Global Server" button and handle clicks.

    On click, POSTs to ``{API_URL}/disconnect``. On success, the API's message
    is shown as a toast and the app reruns. On failure, the API's message (or
    the raised exception) is shown as an error.

    Args:
        button_key: Widget key for the button, so the control can be placed on
            several pages without duplicate Streamlit widget keys.
    """
    if not st.button("Disconnect from Global Server", use_container_width=True, key=button_key):
        return
    with st.spinner("Disconnecting..."):
        try:
            data = requests.post(f"{API_URL}/disconnect").json()
            disconnected = data.get("status") == "success"
        except Exception as e:
            st.error(f"Failed to disconnect. Error: {e}")
            return
    if disconnected:
        # An st.success placed right before st.rerun() is cleared by the rerun
        # before it can be read, so a toast is used instead.
        st.toast(data.get("message") or "Disconnected successfully.")
        st.rerun()
    else:
        st.error(data.get("message", "Disconnect request failed."))
def _parse_int_field(name: str, raw: str, min_value: int = None, max_value: int = None) -> int:
    """Parse ``raw`` (typically text-input content) as an integer and range-check it.

    Args:
        name: Field label used in the error messages.
        raw: Value to parse. It is converted with ``str()`` and stripped first.
        min_value: Inclusive lower bound, or None for no lower bound.
        max_value: Inclusive upper bound, or None for no upper bound.

    Returns:
        The parsed integer.

    Raises:
        ValueError: With a user-facing message when ``raw`` is not an integer
            or lies outside ``[min_value, max_value]``.
    """
    try:
        value = int(str(raw).strip())
    except ValueError:
        # Replace int()'s low-level message; the original error adds nothing for the user.
        raise ValueError(f"{name} must be an integer.") from None
    if min_value is not None and value < min_value:
        raise ValueError(f"{name} must be >= {min_value}.")
    if max_value is not None and value > max_value:
        raise ValueError(f"{name} must be <= {max_value}.")
    return value
def _parse_float_field(name: str, raw: str, min_value: float = None, max_value: float = None) -> float:
    """Parse ``raw`` (typically text-input content) as a finite float and range-check it.

    Args:
        name: Field label used in the error messages.
        raw: Value to parse. It is converted with ``str()`` and stripped first.
        min_value: Inclusive lower bound, or None for no lower bound.
        max_value: Inclusive upper bound, or None for no upper bound.

    Returns:
        The parsed float.

    Raises:
        ValueError: With a user-facing message when ``raw`` is not a number,
            is NaN/infinite, or lies outside ``[min_value, max_value]``.
    """
    import math

    try:
        value = float(str(raw).strip())
    except ValueError:
        raise ValueError(f"{name} must be a number.") from None
    # float() accepts "nan" and "inf". NaN compares False against every bound,
    # so without this check it would slip through the range validation below.
    if not math.isfinite(value):
        raise ValueError(f"{name} must be a finite number.")
    if min_value is not None and value < min_value:
        raise ValueError(f"{name} must be >= {min_value}.")
    if max_value is not None and value > max_value:
        raise ValueError(f"{name} must be <= {max_value}.")
    return value
def render_local_training_configuration():
    """Render a form for viewing and editing the local training configuration.

    Current values come from ``GET {API_URL}/config``. If they cannot be
    loaded, only a warning (or nothing) is shown. On submit, the inputs are
    validated with ``_parse_int_field`` / ``_parse_float_field`` and saved with
    ``PUT {API_URL}/config``. ``model_name`` and ``lr`` are not editable here;
    their current values are sent back.
    """
    st.subheader("Local Training Configuration")
    current_config = {}
    try:
        config_res = requests.get(f"{API_URL}/config").json()
        if config_res.get("status") == "success":
            current_config = config_res.get("data") or {}
    except Exception:
        st.warning("Failed to fetch the current configuration from the Local Client.")
    if not current_config:
        return

    def current_value(key, fallback, cast):
        # Coerce a stored config value. A missing or malformed value (e.g. None
        # or "abc") falls back to the default instead of raising outside any
        # try block and breaking the whole page.
        try:
            return cast(current_config.get(key, fallback))
        except (TypeError, ValueError):
            return cast(fallback)

    with st.form("config_form_shared"):
        c1, c2 = st.columns(2)
        with c1:
            sample_limit_raw = st.text_input("Sample Limit", value=str(current_config.get("sample_limit", 150)), key="ui_sample_limit")
            labeled_ratio_raw = st.text_input("Labeled Ratio", value=str(current_value("labeled_ratio", 0.5, float)), key="ui_labeled_ratio")
        with c2:
            batch_size_raw = st.text_input("Batch Size", value=str(current_config.get("batch_size", 16)), key="ui_batch_size")
            seed_raw = st.text_input("Seed", value=str(current_config.get("seed", 42)), key="ui_seed")
        # Epoch settings are edited in the same form.
        e1, e2 = st.columns(2)
        with e1:
            pretrain_epochs_raw = st.text_input(
                "Local Pretrain Epochs",
                value=str(current_value("local_pretrain_epochs", 1, int)),
                key="ui_local_pretrain_epochs",
            )
        with e2:
            ssl_epochs_raw = st.text_input(
                "Local SSL Epochs",
                value=str(current_value("local_ssl_epochs", 1, int)),
                key="ui_local_ssl_epochs",
            )
        submitted = st.form_submit_button("Save Configuration", type="primary", use_container_width=True)
    if not submitted:
        return
    try:
        sample_limit = _parse_int_field("Sample Limit", sample_limit_raw, min_value=20, max_value=200000)
        labeled_ratio = _parse_float_field("Labeled Ratio", labeled_ratio_raw, min_value=0.05, max_value=0.95)
        batch_size = _parse_int_field("Batch Size", batch_size_raw, min_value=2, max_value=1024)
        seed = _parse_int_field("Seed", seed_raw, min_value=0, max_value=100000)
        local_pretrain_epochs = _parse_int_field("Local Pretrain Epochs", pretrain_epochs_raw, min_value=1, max_value=100)
        local_ssl_epochs = _parse_int_field("Local SSL Epochs", ssl_epochs_raw, min_value=1, max_value=100)
    except ValueError as ve:
        st.error(str(ve))
        return
    payload = {
        "sample_limit": sample_limit,
        "labeled_ratio": labeled_ratio,
        "batch_size": batch_size,
        "seed": seed,
        "model_name": current_config.get("model_name", "fedalmer"),
        # NOTE(review): a malformed stored lr is resent as the 1e-3 default.
        "lr": current_value("lr", 1e-3, float),
        "local_pretrain_epochs": local_pretrain_epochs,
        "local_ssl_epochs": local_ssl_epochs,
    }
    try:
        put_res = requests.put(f"{API_URL}/config", json=payload).json()
        updated = put_res.get("status") == "success"
        error_message = put_res.get("message")
    except Exception as e:
        st.error(f"API call error: {e}")
        return
    if updated:
        # An st.success right before st.rerun() is cleared by the rerun; use a toast.
        st.toast("Configuration updated successfully!")
        st.rerun()
    else:
        st.error(f"Update error: {error_message}")
def _connect_to_global_auto():
    """Call ``GET {API_URL}/connect_to_global`` and show the API's message.

    The message is shown as success or error depending on the returned
    status. Request or JSON failures are reported as an error.
    """
    with st.spinner("Connecting..."):
        try:
            reply = requests.get(f"{API_URL}/connect_to_global").json()
            show = st.success if reply.get("status") == "success" else st.error
            show(reply.get("message"))
        except Exception as e:
            st.error(f"Failed to connect to API. Error: {e}")


def render_federated_ssl_client():
    """Render the auto-mode page: connect/disconnect controls plus the status timeline.

    "Connect to Server" asks the local API to start listening to the Global
    Server. Disconnecting is delegated to ``trigger_disconnect``.
    """
    st.subheader("Federated Client (Auto Mode)")
    st.markdown("Manage training configuration and connection to the Global Server.")

    # --- 1. CONNECT TO GLOBAL SERVER ---
    st.subheader("Connectivity")
    guidance = (
        "Click the button to activate the Background Task. The client will continuously listen for 'NEW_ROUND' commands from the Global Server."
    )
    st.info(guidance)

    # The connect and disconnect actions sit side by side.
    connect_col, disconnect_col = st.columns(2)
    with connect_col:
        connect_clicked = st.button("Connect to Server", use_container_width=True, type="primary", key="connect_auto_btn")
        if connect_clicked:
            _connect_to_global_auto()
    with disconnect_col:
        trigger_disconnect("disconnect_auto_btn")

    render_status_timeline("auto")
def _manual_submit_state():
    """Return ``(can_submit, round_idx)`` based on ``GET {API_URL}/last_status``.

    Submission is offered only when the latest status message contains
    "Waiting for Manual Submit". Any network or parse problem yields
    ``(False, 0)`` so the rest of the page still renders.
    """
    try:
        status_res = requests.get(f"{API_URL}/last_status").json()
        status_data = status_res.get("data") if status_res.get("status") == "success" else None
    except Exception:  # was a bare `except:`, which also caught KeyboardInterrupt/SystemExit
        return False, 0
    if not isinstance(status_data, dict):
        return False, 0
    # A null message is treated as empty instead of raising a TypeError.
    if "Waiting for Manual Submit" not in str(status_data.get("message") or ""):
        return False, 0
    return True, status_data.get("round_idx") or 0


def _submit_manual_model():
    """POST ``{API_URL}/manual_submit`` and report the result.

    On success, the page reruns so the submission section is re-evaluated
    against the updated status.
    """
    with st.spinner("Compressing and uploading model to Global Server..."):
        try:
            data = requests.post(f"{API_URL}/manual_submit").json()
            submitted = data.get("status") == "success"
        except Exception as e:
            st.error(f"Submission error: {e}")
            return
    if submitted:
        # An st.success right before st.rerun() is cleared by the rerun; use a toast.
        st.toast(data.get("message") or "Model submitted successfully.")
        st.rerun()
    else:
        st.error(data.get("message"))


def render_federated_ssl_client_manual():
    """Render the manual-mode page: connect/disconnect, timeline and model submission.

    In manual mode the client waits for admin commands. Once local training
    for a round finishes, the user confirms uploading the model.
    """
    st.subheader("Federated Client (Manual Mode)")
    st.markdown("Manual training mode: receive commands from Admin and submit model updates manually.")
    # --- 1. CONNECT TO GLOBAL SERVER (MANUAL) ---
    st.subheader("Connectivity")
    st.info("Enable manual listening mode. The client will wait for 'NEW_ROUND_MANUAL' commands from Admin Dashboard.")
    col_connect_manual, col_disconnect_manual = st.columns(2)
    with col_connect_manual:
        if st.button("Connect to Server", use_container_width=True, type="primary", key="connect_manual_btn"):
            with st.spinner("Connecting to manual stream..."):
                try:
                    res = requests.get(f"{API_URL}/connect_to_global_manual")
                    data = res.json()
                    if data.get("status") == "success":
                        st.success(data.get("message"))
                    else:
                        st.error(data.get("message"))
                except Exception as e:
                    st.error(f"Connection error: {e}")
    with col_disconnect_manual:
        trigger_disconnect("disconnect_manual_btn")
    st.divider()
    # --- SHOW TIMELINE ---
    render_status_timeline("manual")
    st.divider()
    # --- 2. MANUAL SUBMISSION ---
    st.subheader("Model Submission")
    can_submit, current_round = _manual_submit_state()
    if can_submit:
        st.success(f"Round {current_round} training is complete. Client is ready to submit to Global Server.")
        if st.button(f"Confirm Submission for Round {current_round}", type="primary", use_container_width=True):
            _submit_manual_model()
    else:
        st.warning(
            "No new trained model is ready for submission yet. Please wait for the next server training command.")
        # A disabled button still communicates which action will become available.
        st.button("Submit Trained Model to Global", disabled=True, use_container_width=True)
def _parse_api_response(res):
    """Decode a local-API response into ``(ok, payload)``.

    ``ok`` is True only for HTTP 200 with a JSON object whose ``status`` is
    ``"success"``. ``payload`` is that JSON object, or ``{}`` when the body is
    not JSON or not an object. Callers can therefore always use
    ``payload.get(...)`` without risking a second exception while reporting
    an error.
    """
    try:
        payload = res.json()
    except ValueError:  # body is not valid JSON
        payload = {}
    if not isinstance(payload, dict):
        payload = {}
    ok = res.status_code == 200 and payload.get("status") == "success"
    return ok, payload


def _is_dataset_locked():
    """Return True while ``GET {API_URL}/last_status`` reports ``is_training``.

    Any fetch or parse failure is treated as "not locked".
    """
    try:
        lock_res = requests.get(f"{API_URL}/last_status").json()
        if lock_res.get("status") == "success":
            return bool((lock_res.get("data") or {}).get("is_training", False))
    except Exception:
        return False
    return False


def _render_bulk_upload_tab(dataset_locked):
    """Upload a whole ``.pkl`` file to the chosen split via ``PUT {API_URL}/data/<split>``."""
    st.subheader("Upload Batch Data (.pkl)")
    split_name = st.radio(
        "Select split:", ["labeled", "unlabeled"], horizontal=True, key="bulk_split"
    )
    # NOTE(review): the local API presumably unpickles this file; pickle must
    # only ever be loaded from trusted sources.
    uploaded_file = st.file_uploader(f"Upload {split_name} file (.pkl)", type=["pkl"])
    if not st.button("Upload .pkl", type="primary", disabled=dataset_locked):
        return
    if not uploaded_file:
        st.warning("Please select a file.")
        return
    with st.spinner("Processing..."):
        files = {"file": (uploaded_file.name, uploaded_file.getvalue(), "application/octet-stream")}
        try:
            res = requests.put(f"{API_URL}/data/{split_name}", files=files)
        except requests.RequestException as e:
            st.error(f"Upload failed: {e}")
            return
        ok, payload = _parse_api_response(res)
        if ok:
            st.success(payload.get("message"))
        else:
            st.error(payload.get("message") or "Upload failed.")


def _transcribe_sample(file_key, filename, audio_bytes, state):
    """Run ``POST {API_URL}/inference/transcribe`` and load the result into the editor."""
    with st.spinner("Running speech recognition..."):
        files = {"audio_file": (filename, audio_bytes, "audio/wav")}
        try:
            res = requests.post(f"{API_URL}/inference/transcribe", files=files)
        except requests.RequestException as e:
            st.error(f"Transcription API call failed: {e}")
            return
        ok, payload = _parse_api_response(res)
    if not ok:
        st.error("Transcription API call failed.")
        return
    transcript_result = (payload.get("data") or {}).get("transcript", "")
    state["transcript"] = transcript_result
    # Also write the text area's own key. Once the widget has state, its
    # value= argument is ignored, so this is what makes the new text appear.
    st.session_state[f"text_{file_key}"] = transcript_result
    st.rerun()


def _save_sample(filename, audio_bytes, edited_text, target_split, emotion_label, state):
    """Append one sample to ``target_split`` via ``POST {API_URL}/data/<split>/append``."""
    if not edited_text.strip():
        st.warning("Transcript cannot be empty.")
        return
    with st.spinner(f"Extracting features & saving {filename}..."):
        files = {"audio_file": (filename, audio_bytes, "audio/wav")}
        data_payload = {
            "text": edited_text,
            # Samples for the unlabeled split are sent with an empty label.
            "label": emotion_label if target_split == "labeled" else "",
        }
        try:
            res = requests.post(f"{API_URL}/data/{target_split}/append", files=files,
                                data=data_payload)
        except requests.RequestException as e:
            st.error(f"Save failed: {e}")
            return
        ok, payload = _parse_api_response(res)
    if not ok:
        st.error(f"Save failed: {payload.get('message')}")
        return
    # Mark as saved so the next run shows only a collapsed confirmation.
    state["saved"] = True
    state["transcript"] = edited_text
    st.rerun()


def _render_sample_editor(index, audio_file, dataset_locked):
    """Render the transcribe / label / save controls for one audio sample.

    Per-file progress is kept in ``st.session_state.samples_state[file_key]``
    as ``{"transcript": str, "saved": bool}``. A saved sample is shown
    collapsed with a confirmation only.
    """
    # Fall back to a generated name when the audio object has no name.
    filename = getattr(audio_file, "name", f"recorded_audio_{index}.wav")
    file_key = f"{index}_{filename}"  # unique widget/state key per file
    state = st.session_state.samples_state.setdefault(file_key, {"transcript": "", "saved": False})
    status_icon = "(Saved)" if state["saved"] else "(Pending)"
    # Saved samples are collapsed to keep the page compact.
    with st.expander(f"Audio: {filename} - {status_icon}", expanded=not state["saved"]):
        if state["saved"]:
            st.success(f"Sample '{filename}' has been successfully added to the dataset!")
            return
        audio_bytes = audio_file.getvalue()
        st.audio(audio_bytes, format="audio/wav")
        # 1. Speech recognition fills the editable transcript.
        if st.button("Transcribe Audio", key=f"transcribe_{file_key}"):
            _transcribe_sample(file_key, filename, audio_bytes, state)
        # 2. Editable transcript.
        edited_text = st.text_area(
            "Transcript (editable):",
            value=state["transcript"],
            key=f"text_{file_key}",
        )
        # 3. Target split and label (the label is disabled for the unlabeled split).
        col1, col2 = st.columns(2)
        with col1:
            target_split = st.selectbox(
                "Add to Split:", ["labeled", "unlabeled"],
                key=f"split_{file_key}",
            )
        with col2:
            emotion_label = st.selectbox(
                "Emotion Label:",
                ["ang", "hap", "sad", "neu"],
                disabled=(target_split == "unlabeled"),
                key=f"label_{file_key}",
            )
        # 4. Save.
        if st.button(
            "Save Sample",
            type="primary",
            key=f"save_{file_key}",
            use_container_width=True,
            disabled=dataset_locked,
        ):
            _save_sample(filename, audio_bytes, edited_text, target_split, emotion_label, state)


def _render_add_samples_tab(dataset_locked):
    """Collect WAV samples (uploaded or recorded) and render an editor for each."""
    st.subheader("Add Samples (WAV / Record)")
    audio_source = st.radio("Choose audio source:", ["Upload .wav (Multiple)", "Record Audio"], horizontal=True)
    # Uploaded files and a recording both end up in one list.
    raw_audios = []
    if audio_source == "Upload .wav (Multiple)":
        raw_audios = st.file_uploader("Upload .wav files", type=["wav"], accept_multiple_files=True)
    elif hasattr(st, "audio_input"):
        recorded = st.audio_input("Record your voice")
        if recorded is not None:
            raw_audios = [recorded]
    else:
        st.warning("Your Streamlit version does not support audio recording.")
    if not raw_audios:
        st.info("Please select or record audio files to begin.")
        return
    st.divider()
    st.markdown(f"**Processing {len(raw_audios)} sample(s):**")
    for index, audio_file in enumerate(raw_audios):
        _render_sample_editor(index, audio_file, dataset_locked)


def _delete_split(split):
    """Call ``DELETE {API_URL}/data/<split>`` and report the outcome."""
    try:
        res = requests.delete(f"{API_URL}/data/{split}")
    except requests.RequestException as e:
        st.error(f"Delete {split} data failed: {e}")
        return
    ok, payload = _parse_api_response(res)
    if ok:
        st.success(f"Deleted {split} data.")
    else:
        st.error(payload.get("message") or f"Delete {split} data failed.")


def _render_delete_split_buttons(dataset_locked):
    """Render the "Delete Labeled" / "Delete Unlabeled" buttons side by side."""
    col_labeled, col_unlabeled = st.columns(2)
    for column, split, button_label in (
        (col_labeled, "labeled", "Delete Labeled"),
        (col_unlabeled, "unlabeled", "Delete Unlabeled"),
    ):
        with column:
            if st.button(button_label, use_container_width=True, disabled=dataset_locked):
                _delete_split(split)


def render_client_data_upload():
    """Render the "Client Data Upload" page.

    The page has two tabs: bulk ``.pkl`` upload, and per-sample WAV
    upload/recording with transcription and labeling. Below them are buttons
    that delete the labeled or unlabeled split. Upload, save and delete
    controls are disabled while the local client reports that training is
    running. Network and non-JSON API failures are shown as errors instead
    of aborting the page.
    """
    simple_block_header("Client Data Upload")
    st.markdown("Manage and extend local training data for this client.")
    dataset_locked = _is_dataset_locked()
    if dataset_locked:
        st.warning("Dataset is locked while local training is running. Upload/Add/Delete is temporarily disabled.")
    # Per-file state for the multi-file "Add Samples" workflow.
    if "samples_state" not in st.session_state:
        st.session_state.samples_state = {}
    tab_bulk, tab_single = st.tabs([
        "Upload Batch (.pkl)",
        "Add Samples",
    ])
    with tab_bulk:
        _render_bulk_upload_tab(dataset_locked)
    with tab_single:
        _render_add_samples_tab(dataset_locked)
    st.divider()
    _render_delete_split_buttons(dataset_locked)
def _aggregate_class_hist(class_hist, label_fn):
    """Re-key ``class_hist`` with ``label_fn``, summing counts whose keys collide.

    Several raw labels can share one display name (e.g. "hap" and "exc", or 0
    and "0"). A plain dict comprehension would keep only the last count for
    such a name and silently drop the others.
    """
    merged = {}
    for raw_label, count in class_hist.items():
        display_label = label_fn(raw_label)
        merged[display_label] = merged.get(display_label, 0) + count
    return merged


def render_private_dataset_summary():
    """Render sample counts and the labeled-class histogram from ``GET {API_URL}/profile``.

    When the API returns no profile data, a hint to build/publish the
    profile first is shown instead.
    """
    simple_block_header("Private Dataset Summary")
    st.markdown("Overview of your local dataset, including data distribution and sample counts.")
    try:
        with st.spinner("Fetching dataset summary from Local Server..."):
            res = requests.get(f"{API_URL}/profile")
        if res.status_code != 200:
            st.error(f"API Error (HTTP {res.status_code}): {res.text}")
            return
        data = res.json().get("data")
        # No data means the profile has not been built yet.
        if not data:
            st.warning("Dataset profile has not been built yet.")
            st.info(
                "Please go to the 'Publish Client Profile To Server' tab to generate and publish your data profile first.")
            return
        st.subheader("Sample Overview")
        col1, col2, col3 = st.columns(3)
        # A null count is shown as 0 instead of breaking the addition below.
        labeled_count = data.get("labeled_count") or 0
        unlabeled_count = data.get("unlabeled_count") or 0
        total_count = labeled_count + unlabeled_count
        col1.metric("Labeled Samples", f"{labeled_count:,}")
        col2.metric("Unlabeled Samples", f"{unlabeled_count:,}")
        col3.metric("Total Samples", f"{total_count:,}")
        st.divider()
        # Bar chart of the class distribution of the labeled split.
        st.subheader("Labeled Data Distribution")
        st.markdown("Histogram of classes available in the labeled dataset.")
        class_hist = data.get("class_hist") or {}
        if class_hist:
            mapped_hist = _aggregate_class_hist(class_hist, _map_label_display)
            df = pd.DataFrame.from_dict(mapped_hist, orient="index", columns=["Sample Count"])
            st.bar_chart(df)
        else:
            st.info("No class distribution data available. The dataset might be empty.")
        # Raw profile payload for inspection.
        with st.expander("View Full Profile Metadata"):
            st.json(data)
    except Exception as e:
        st.error(f"Connection Error to Local Server: {e}")
def render_publish_client_profile():
    """Render the current profile status and the publish / unpublish actions.

    Status comes from ``GET {API_URL}/profile``; non-empty ``data`` means a
    profile is published. "Generate and Publish Profile" calls
    ``POST {API_URL}/profile``. "Unpublish Profile" calls
    ``DELETE {API_URL}/profile`` and is enabled only while a profile exists.
    """
    simple_block_header("Publish Client Profile To Server")
    st.markdown(
        "Calculate local dataset characteristics and publish the profile. The Global Server requires this information to assign appropriate training tasks and aggregate weights properly.")
    # ---------------------------------------------------------
    # PART 1: CURRENT STATUS
    # ---------------------------------------------------------
    st.subheader("1. Current Profile Status")
    is_published = False
    try:
        res = requests.get(f"{API_URL}/profile")
        if res.status_code == 200:
            data = res.json().get("data")
            if data:
                is_published = True
                st.success("A profile is currently published and visible to the Global Server.")
                st.write(f"**Last Published At:** {data.get('published_at', 'Unknown time')}")
            else:
                st.info(
                    "No profile is currently published. The Global Server cannot include this client in training rounds.")
        else:
            st.error(f"Failed to fetch profile status (HTTP {res.status_code}): {res.text}")
    except Exception as e:
        st.error(f"Connection Error to Local Server: {e}")
    st.divider()
    # ---------------------------------------------------------
    # PART 2: ACTIONS
    # ---------------------------------------------------------
    st.subheader("2. Profile Actions")
    col1, col2 = st.columns(2)
    with col1:
        if st.button("Generate and Publish Profile", type="primary", use_container_width=True):
            success_message = None
            with st.spinner("Analyzing local dataset and building profile..."):
                try:
                    post_res = requests.post(f"{API_URL}/profile")
                    if post_res.status_code == 200:
                        post_data = post_res.json()
                        if post_data.get("status") == "success":
                            success_message = post_data.get("message") or "Profile published."
                        else:
                            st.error(f"System Error: {post_data.get('message')}")
                    else:
                        st.error(f"API Error (HTTP {post_res.status_code}): {post_res.text}")
                except Exception as e:
                    st.error(f"Connection Error: {e}")
            if success_message is not None:
                # An st.success right before st.rerun() is cleared by the rerun; use a toast.
                st.toast(success_message)
                st.rerun()  # refresh so the status section reflects the new profile
    with col2:
        # Unpublishing only makes sense while a profile exists.
        if st.button("Unpublish Profile", use_container_width=True, disabled=not is_published):
            unpublished = False
            try:
                del_res = requests.delete(f"{API_URL}/profile")
                if del_res.status_code == 200:
                    del_data = del_res.json()
                    if del_data.get("status") == "success":
                        unpublished = True
                    else:
                        st.error(f"System Error: {del_data.get('message')}")
                else:
                    st.error(f"API Error (HTTP {del_res.status_code}): {del_res.text}")
            except Exception as e:
                st.error(f"Connection Error: {e}")
            if unpublished:
                st.toast("Profile has been successfully unpublished.")
                st.rerun()  # refresh so the status section reflects the removal
def _history_metric_series(df):
    """Return ``[(title, frame), ...]`` for each plottable metric in a history frame.

    ``df`` must contain a ``round`` column. Each returned ``frame`` is indexed
    by ``round`` and holds a single numeric metric column. Non-numeric values
    are dropped. For the loss metrics, values <= 0 are treated as "no loss
    recorded" and dropped. Metrics that are absent or end up empty are
    omitted.
    """
    metric_specs = (
        ("local_accuracy", "Local Accuracy Progression"),
        ("pretrain_loss", "Pretrain Loss"),
        ("ssl_labeled_loss", "SSL Labeled Loss"),
    )
    loss_metrics = {"pretrain_loss", "ssl_labeled_loss"}
    available_series = []
    for metric_name, metric_title in metric_specs:
        if metric_name not in df.columns:
            continue
        series_df = df[["round", metric_name]].copy()
        series_df[metric_name] = pd.to_numeric(series_df[metric_name], errors="coerce")
        series_df = series_df.dropna(subset=["round", metric_name])
        if metric_name in loss_metrics:
            series_df = series_df[series_df[metric_name] > 0]
        if not series_df.empty:
            available_series.append((metric_title, series_df.set_index("round")[[metric_name]]))
    return available_series


def render_local_round_history():
    """Render charts of per-round training metrics from ``GET {API_URL}/history``.

    It also offers a "Clear History" action (``DELETE {API_URL}/history``).
    """
    simple_block_header("Local Round History")
    st.markdown("Review the performance metrics and configuration details of completed training rounds.")
    # ---------------------------------------------------------
    # PART 1: CLEAR HISTORY ACTION
    # ---------------------------------------------------------
    _, col_clear = st.columns([4, 1])
    with col_clear:
        if st.button("Clear History", type="secondary", use_container_width=True):
            cleared = False
            try:
                res = requests.delete(f"{API_URL}/history").json()
                cleared = res.get("status") == "success"
                if not cleared:
                    st.error(f"System Error: {res.get('message')}")
            except Exception as e:
                st.error(f"Connection Error: {e}")
            if cleared:
                # An st.success right before st.rerun() is cleared by the rerun; use a toast.
                st.toast("Training history cleared successfully.")
                st.rerun()
    # ---------------------------------------------------------
    # PART 2: FETCH AND DISPLAY HISTORY
    # ---------------------------------------------------------
    try:
        with st.spinner("Fetching training history from Local Server..."):
            res = requests.get(f"{API_URL}/history")
        if res.status_code != 200:
            st.error(f"API Error (HTTP {res.status_code}): {res.text}")
            return
        data = res.json().get("data") or {}
        rounds = data.get("rounds") or []
        if not rounds:
            st.info("No training history available. The client has not completed any training rounds yet.")
            return
        # One row per round record.
        df = pd.DataFrame(rounds)
        st.subheader("1. Training Trends")
        if "round" not in df.columns:
            st.warning("Missing 'round' field in history data, cannot render charts.")
        else:
            available_series = _history_metric_series(df)
            if not available_series:
                st.info("No valid metric values available to plot yet.")
            else:
                # Two charts per row.
                cols = st.columns(2)
                for idx, (metric_title, metric_df) in enumerate(available_series):
                    with cols[idx % 2]:
                        st.markdown(f"**{metric_title}**")
                        st.line_chart(metric_df)
        st.divider()
    except Exception as e:
        st.error(f"Connection Error to Local Server: {e}")
def render_client_local_logs():
    """Render the client's event log with Refresh and Clear actions.

    Log lines come from ``GET {API_URL}/logs`` (``data.events``) and are shown
    in a read-only text area. "Clear Logs" calls ``DELETE {API_URL}/logs``.
    """
    simple_block_header("Client Local Logs")
    st.markdown("Monitor system events, background training progress, and Global Server connection status.")
    # ---------------------------------------------------------
    # PART 1: LOG ACTIONS (REFRESH & CLEAR)
    # ---------------------------------------------------------
    col1, col2, _ = st.columns([1, 1, 3])
    with col1:
        # Rerunning the script re-fetches the logs below.
        if st.button("Refresh Logs", type="primary", use_container_width=True):
            st.rerun()
    with col2:
        if st.button("Clear Logs", use_container_width=True):
            cleared = False
            try:
                res = requests.delete(f"{API_URL}/logs").json()
                cleared = res.get("status") == "success"
                if not cleared:
                    st.error(f"System Error: {res.get('message')}")
            except Exception as e:
                st.error(f"Connection Error: {e}")
            if cleared:
                # An st.success right before st.rerun() is cleared by the rerun; use a toast.
                st.toast("Logs cleared successfully.")
                st.rerun()
    st.divider()
    # ---------------------------------------------------------
    # PART 2: FETCH AND DISPLAY LOGS
    # ---------------------------------------------------------
    try:
        with st.spinner("Fetching latest logs from Local Server..."):
            res = requests.get(f"{API_URL}/logs")
        if res.status_code != 200:
            st.error(f"API Error (HTTP {res.status_code}): {res.text}")
            return
        data = res.json().get("data") or {}
        events = data.get("events") or []
        if not events:
            st.info("No logs available. The system is currently idle or logs have been cleared.")
            return
        # str() guards against non-string entries, which "\n".join would reject.
        log_text = "\n".join(str(event) for event in events)
        # A disabled text area serves as a read-only terminal view.
        st.text_area(
            label="Terminal Output",
            value=log_text,
            height=500,
            disabled=True,
            label_visibility="collapsed",
        )
    except Exception as e:
        st.error(f"Connection Error to Local Server: {e}")
def _capture_inference_audio(prefix):
    """Render the audio-source widgets for one inference tab and cache the clip.

    Widget and session-state keys are namespaced with ``prefix`` ("reg" for the
    registry tab, "last" for the last-session tab) so the tabs keep independent
    state. When a clip is present, its bytes and file name are stored under
    ``f"{prefix}_audio_bytes"`` / ``f"{prefix}_audio_name"``; when the widget is
    empty, those keys are removed.
    """
    source = st.radio("Audio source", ["Upload file", "Record microphone"], horizontal=True,
                      key=f"{prefix}_audio_source")
    clip = None
    if source == "Upload file":
        clip = st.file_uploader("Upload audio", type=["wav", "flac", "ogg", "mp3", "m4a"],
                                label_visibility="collapsed", key=f"{prefix}_uploader")
    elif hasattr(st, "audio_input"):
        # st.audio_input is only available on newer Streamlit releases.
        clip = st.audio_input("Record audio", key=f"{prefix}_recorder")
    else:
        st.warning("Your Streamlit version does not support audio recording. Please upload a file instead.")

    if clip is None:
        # The widget is empty (file removed or source switched): drop the cached
        # clip so a later button click cannot silently reuse the previous audio.
        st.session_state.pop(f"{prefix}_audio_bytes", None)
        st.session_state.pop(f"{prefix}_audio_name", None)
        return
    clip_bytes = clip.getvalue()
    st.audio(clip_bytes)
    st.session_state[f"{prefix}_audio_bytes"] = clip_bytes
    st.session_state[f"{prefix}_audio_name"] = getattr(clip, "name", "recorded_audio.wav")


def _inference_audio_files(prefix):
    """Build the multipart ``files`` mapping for the cached clip of tab ``prefix``."""
    # NOTE(review): the MIME type is always "audio/wav", even for mp3/flac/ogg/m4a
    # uploads; confirm the server decodes by content rather than by this header.
    return {
        "audio_file": (
            st.session_state[f"{prefix}_audio_name"],
            st.session_state[f"{prefix}_audio_bytes"],
            "audio/wav",
        )
    }


def _inference_inputs_ready(prefix, transcript):
    """Return True when tab ``prefix`` has a cached clip and a non-blank transcript.

    Shows an ``st.error`` explaining what is missing otherwise.
    """
    if f"{prefix}_audio_bytes" not in st.session_state:
        st.error("Please provide an audio file first.")
        return False
    if not transcript.strip():
        st.error("Transcript is empty. Please transcribe or type manually first.")
        return False
    return True


def _transcribe_cached_audio(prefix):
    """Send the cached clip to ``{API_URL}/inference/transcribe`` and load the result.

    On success the transcript is written to ``f"{prefix}_asr_transcript"`` and to
    the transcript text area's key, then the script is rerun.
    """
    if f"{prefix}_audio_bytes" not in st.session_state:
        st.error("Please provide an audio file first.")
        return
    with st.spinner("Running ASR transcription via API..."):
        try:
            res = requests.post(f"{API_URL}/inference/transcribe",
                                files=_inference_audio_files(prefix)).json()
            if res.get("status") != "success":
                st.error(f"Transcription failed: {res.get('message')}")
                return
            transcript = (res.get("data") or {}).get("transcript") or ""
        except Exception as e:
            st.error(f"API Connection Error: {e}")
            return
    st.session_state[f"{prefix}_asr_transcript"] = transcript
    # Write straight into the text area's key: this runs before that widget is
    # created in the current script run, which Streamlit allows.
    st.session_state[f"{prefix}_transcript_box"] = transcript
    # Kept outside the try above so the broad `except Exception` can never
    # intercept the control-flow exception st.rerun() raises.
    st.rerun()


def _inference_transcript_step(prefix, predict_label):
    """Render the two action buttons and the editable transcript for one tab.

    Returns:
        (predict_clicked, transcript): whether the predict button was pressed in
        this run, and the current text-area contents (never None).
    """
    transcribe_col, predict_col = st.columns(2)
    with transcribe_col:
        transcribe_clicked = st.button("1) Transcribe Audio", use_container_width=True, type="primary",
                                       key=f"{prefix}_transcribe_btn")
    with predict_col:
        predict_clicked = st.button(predict_label, use_container_width=True, type="primary",
                                    key=f"{prefix}_predict_btn")
    if transcribe_clicked:
        _transcribe_cached_audio(prefix)
    # No value= argument: the content lives in session state under the widget key,
    # which _transcribe_cached_audio fills in.
    transcript = st.text_area("Edit transcript before prediction", height=100,
                              key=f"{prefix}_transcript_box")
    return predict_clicked, transcript or ""


def _request_emotion_prediction(prefix, endpoint, payload):
    """POST the cached clip plus form ``payload`` to ``{API_URL}/inference/{endpoint}``.

    Returns:
        The parsed JSON response when it reports ``status == "success"``;
        otherwise ``None`` after showing an ``st.error``.
    """
    with st.spinner("Running emotion prediction via API..."):
        try:
            res = requests.post(f"{API_URL}/inference/{endpoint}", files=_inference_audio_files(prefix),
                                data=payload).json()
            if res.get("status") == "success":
                return res
            st.error(f"Prediction failed: {res.get('message')}")
        except Exception as e:
            st.error(f"API Connection Error: {e}")
    return None


def _confidence_text(value):
    """Format a confidence score with 4 decimals; a missing/None score counts as 0."""
    return f"{(value or 0):.4f}"


def _render_registry_inference_tab():
    """Tab 1: configure a registry model (optionally vs. baselines) and predict."""
    try:
        meta_res = requests.get(f"{API_URL}/inference/metadata").json()
        if meta_res.get("status") != "success":
            st.error(f"Failed to load metadata: {meta_res.get('message')}")
            return
        meta_data = meta_res.get("data") or {}
        methods = meta_data.get("methods", ["FedAvg", "FedProx"])
        clients_opts = meta_data.get("num_clients_options", [1, 5, 10])
        models = meta_data.get("available_models", ["fedalmer"])
    except Exception as e:
        st.error(f"Cannot connect to API to fetch metadata: {e}")
        return
    # NOTE(review): the server's "label_ratios" metadata is ignored (the previous
    # code read it and immediately overwrote it); the selector always offers these
    # two values. Confirm this is intended.
    ratios = ["0p5", "1"]
    ratio_display_labels = {"0p5": "0.5", "1": "1.0"}
    if not models:
        # Abandon only this tab (previously st.stop() halted the whole script,
        # so the "Last Session Only" tab never rendered).
        st.warning("No models available.")
        return

    with st.expander("Inference Configuration", expanded=True):
        col1, col2, col3 = st.columns(3)
        with col1:
            selected_method = st.selectbox("FL method", methods, key="reg_method")
        with col2:
            selected_ratio = st.selectbox("Label ratio", ratios, key="reg_ratio",
                                          format_func=lambda x: ratio_display_labels.get(x, x))
        with col3:
            selected_clients = st.selectbox("Number of clients", clients_opts, key="reg_clients")
        st.divider()
        col4, col5 = st.columns(2)
        with col4:
            target_model = st.selectbox("Target Model", models, index=0, key="reg_target_model")
        with col5:
            baseline_candidates = [m for m in models if m != target_model]
            selected_baselines = []
            if baseline_candidates:
                compare_enabled = st.checkbox("Compare baseline models", value=False, key="reg_compare_enabled")
                if compare_enabled:
                    selected_baselines = st.multiselect(
                        "Baseline models",
                        baseline_candidates,
                        default=baseline_candidates[:min(2, len(baseline_candidates))],
                        key="reg_baselines",
                    )

    _capture_inference_audio("reg")
    predict_clicked, transcript = _inference_transcript_step("reg", "2) Predict Emotion")

    if predict_clicked and _inference_inputs_ready("reg", transcript):
        payload = {
            "transcript": transcript,
            "method": selected_method,
            "ratio": str(selected_ratio),
            "clients": selected_clients,
        }
        is_comparison = bool(selected_baselines)
        if is_comparison:
            payload["target_model"] = target_model
            payload["baselines"] = ",".join(selected_baselines)
        else:
            payload["model_name"] = target_model
        res = _request_emotion_prediction("reg", "compare" if is_comparison else "predict", payload)
        if res is not None:
            # Store the results and their shape flag together, and only on success,
            # so the renderer below never pairs a dict with the comparison flag
            # (or a list without it) after a failed request.
            st.session_state["reg_prediction_results"] = res.get("data")
            st.session_state["reg_is_comparison"] = is_comparison

    if "reg_prediction_results" not in st.session_state:
        return
    st.divider()
    st.subheader("Emotion Output")
    reg_results = st.session_state["reg_prediction_results"]
    if st.session_state.get("reg_is_comparison", False):
        table_data = [
            {
                "Model": item.get("model", "Unknown"),
                "Emotion": item.get("emotion", "ERROR"),
                "Confidence": _confidence_text(item["confidence"]) if item.get("confidence") is not None else "",
                "Status": item.get("status", "ok"),
            }
            for item in (reg_results or [])
            if isinstance(item, dict)
        ]
        st.dataframe(table_data, use_container_width=True)
    else:
        reg_results = reg_results or {}
        st.metric(
            label=f"Predicted Emotion ({reg_results.get('model_name', 'unknown')})",
            value=reg_results.get("emotion", "Unknown"),
            delta=f"Conf: {_confidence_text(reg_results.get('confidence'))}",
            delta_color="normal",
        )
        st.caption(
            f"FL Method: {reg_results.get('fl_method')} | "
            f"Audio Length: {(reg_results.get('audio_seconds') or 0):.2f}s")


def _render_last_session_inference_tab():
    """Tab 2: predict with the last-session checkpoint (no baseline comparison)."""
    st.info("This tab only calls /inference/predict_last_session and does not support baseline comparison.")
    _capture_inference_audio("last")
    predict_clicked, transcript = _inference_transcript_step("last", "2) Predict Last Session")

    if predict_clicked and _inference_inputs_ready("last", transcript):
        res = _request_emotion_prediction("last", "predict_last_session", {"transcript": transcript})
        if res is not None:
            st.session_state["last_prediction_results"] = res.get("data")

    if "last_prediction_results" not in st.session_state:
        return
    st.divider()
    st.subheader("Last Session Emotion Output")
    last_results = st.session_state["last_prediction_results"] or {}
    st.metric(
        label=f"Predicted Emotion ({last_results.get('model_name', 'last_session')})",
        value=last_results.get("emotion", "Unknown"),
        delta=f"Conf: {_confidence_text(last_results.get('confidence'))}",
        delta_color="normal",
    )
    st.caption(
        f"Source: {last_results.get('checkpoint_source', 'final_model/final_global.pt')} | "
        f"Audio Length: {(last_results.get('audio_seconds') or 0):.2f}s"
    )


def render_inference_speech_emotion():
    """Render the speech-emotion inference section as two independent tabs.

    "Registry Inference" lets the user pick an FL method / label ratio / client
    count / model (optionally comparing baselines) and calls
    ``/inference/predict`` or ``/inference/compare``. "Last Session Only" calls
    ``/inference/predict_last_session``. A failure in one tab (e.g. metadata
    unavailable) no longer prevents the other tab from rendering.
    """
    simple_block_header("FedalMER Inference: Speech Emotion")
    st.caption("Input: Audio | Output: Emotion label with confidence (Powered by Federated Server APIs)")
    tab_registry, tab_last_session = st.tabs(["Registry Inference", "Last Session Only"])
    with tab_registry:
        _render_registry_inference_tab()
    with tab_last_session:
        _render_last_session_inference_tab()
def render_client_section():
    """Show the client ID and local API URL (read-only) next to a disconnect control."""
    simple_block_header("Client")
    st.markdown("Client identity and quick connection control.")
    id_col, api_col, action_col = st.columns([2, 2, 1])
    read_only_fields = (
        (id_col, "Client ID", client_id_display),
        (api_col, "Local API", API_URL),
    )
    for column, field_label, shown_value in read_only_fields:
        with column:
            st.text_input(field_label, value=shown_value, disabled=True)
    with action_col:
        trigger_disconnect("disconnect_client_section_btn")
def render_fl_combined_section():
    """Render the Federated Learning section: local training config, then Auto/Manual FL tabs."""
    simple_block_header("Federated Learning")
    render_local_training_configuration()
    st.divider()
    fl_tabs = st.tabs(["FL Auto", "FL Manual"])
    tab_renderers = (render_federated_ssl_client, render_federated_ssl_client_manual)
    for fl_tab, render_tab in zip(fl_tabs, tab_renderers):
        with fl_tab:
            render_tab()
def render_single_page_layout():
    """Render the full single-page client UI.

    A large centered HTML heading comes first, followed by each section with an
    ``st.divider()`` placed before it.
    """
    # Custom HTML heading used instead of st.title for a larger, centered title.
    st.markdown(
        """
<h1 style='text-align: center; font-size: 80px; font-weight: 800; color: white; margin-bottom: 30px;'>
Federated Learning Client
</h1>
""",
        unsafe_allow_html=True,
    )
    page_sections = (
        render_client_data_upload,
        render_private_dataset_summary,
        render_fl_combined_section,
        render_local_round_history,
        render_inference_speech_emotion,
    )
    for render_section in page_sections:
        st.divider()
        render_section()
# Module-level entry point: build the whole page every time this script is executed.
render_single_page_layout()