Spaces:
Running
Running
| import os | |
| import time | |
| import tempfile | |
| import traceback | |
| import numpy as np | |
| import soundfile as sf | |
| import gradio as gr | |
| # --- Pipeline & Model Imports --- | |
| # (Kept intact to connect with your backend framework logic) | |
| from interface import ( | |
| load_patient_registry, save_patient_registry, separate_audio, monitoring_table_rows, _copy_audio_to_storage, | |
| run_end_to_end, search_history_records, search_reasoning_records, _initialize_models_for_live_processing, RESULTS_DIR, | |
| process_audio_chunk_for_separation, infer_on_separated_chunk, _live_sr, | |
| LIVE_PROCESSING_WINDOW_SECONDS, LIVE_OVERLAP_SECONDS, resolve_patient_names, save_audio_file, predict_sources | |
| ) | |
| from gnn import EnhancedPatientStateManager, ClinicalAlertSystem | |
| # ========================================== | |
| # 1. STYLING CONFIGURATION | |
| # ========================================== | |
def load_css(filename="style.css"):
    """Return the text of a stylesheet stored next to this module.

    Args:
        filename: Stylesheet name, resolved relative to this file's directory
            (an absolute path is used as-is by ``os.path.join``).
            Defaults to ``"style.css"``.

    Returns:
        The stylesheet contents, or ``""`` when no regular file exists at
        that location.
    """
    css_path = os.path.join(os.path.dirname(__file__), filename)
    # isfile (not exists) so a directory with that name does not make open() raise.
    if not os.path.isfile(css_path):
        return ""
    # Explicit UTF-8 so reading the CSS does not depend on the host's locale encoding.
    with open(css_path, "r", encoding="utf-8") as f:
        return f.read()


# Stylesheet text handed to app.launch(css=...) by the entry point.
css_styles = load_css()
| # ========================================== | |
| # 2. CORE BUSINESS & UTILITY LOGIC | |
| # ========================================== | |
def register_patients(ref_audio_1, ref_audio_2, ref_audio_3, patient_name_1, patient_name_2, patient_name_3):
    """Save three patient entries to the registry and refresh all patient dropdowns.

    Args:
        ref_audio_1, ref_audio_2, ref_audio_3: Reference-audio filepaths (or None)
            for patients 1-3.
        patient_name_1, patient_name_2, patient_name_3: Names typed by the user.
            Empty or whitespace-only names fall back to ``"Patient <n>"``.

    Returns:
        A 7-tuple: the status message, then ``gr.update`` objects for the three
        live-monitoring slot dropdowns followed by the three separation-source
        dropdowns (the order of the ``register_btn.click`` outputs).
    """
    raw_names = (patient_name_1, patient_name_2, patient_name_3)
    audio_paths = (ref_audio_1, ref_audio_2, ref_audio_3)
    patient_entries = []
    for idx, (name, audio_path) in enumerate(zip(raw_names, audio_paths), start=1):
        # Strip first, then fall back: a whitespace-only name must not be stored as "".
        clean_name = (name or "").strip() or f"Patient {idx}"
        patient_entries.append({
            "patient_id": f"patient_{idx}",
            "name": clean_name,
            "reference_audio": str(audio_path) if audio_path else None,
        })
    save_patient_registry(patient_entries)

    registered = [f"Patient {idx}: {entry['name']}" for idx, entry in enumerate(patient_entries, start=1)]
    message = (
        "Patients registered successfully. Reference audio files are saved for each patient.\n"
        + "\n".join(registered)
        + "\nSaved to pipeline_results/patient_registry.json"
    )

    choices = get_registered_patient_choices()
    # choices[0] is "Unassigned"; slot k defaults to the k-th registered name if present.
    defaults = [choices[k + 1] if k + 1 < len(choices) else "Unassigned" for k in range(3)]
    live_updates = [gr.update(choices=choices, value=value) for value in defaults]
    separation_updates = [gr.update(choices=choices, value=value) for value in defaults]
    return (message, *live_updates, *separation_updates)
def get_registered_patient_choices(default_count=3):
    """Build dropdown choices: ``"Unassigned"`` followed by registered patient names.

    A registry entry without a name is labelled ``"Patient <n>"`` using its
    1-based position in the registry. Duplicate labels are collapsed, keeping
    first-seen order. If the registry yields no labels, ``default_count``
    placeholders ``"Patient 1"``, ``"Patient 2"``, ... are offered instead.
    """
    ordered_labels = {}
    for position, record in enumerate(load_patient_registry(), start=1):
        ordered_labels.setdefault(record.get("name") or f"Patient {position}", None)
    labels = list(ordered_labels)
    if not labels:
        labels = [f"Patient {n}" for n in range(1, default_count + 1)]
    return ["Unassigned", *labels]
def normalize_live_patient_names(selected_names):
    """Convert dropdown selections to patient names, with ``None`` for empty slots.

    Falsy selections (``None``, ``""``) and the ``"Unassigned"`` placeholder all
    become ``None``; any other value is kept unchanged. Always returns a new
    list with one entry per input selection.
    """
    return [
        None if (not choice or choice == "Unassigned") else choice
        for choice in selected_names
    ]
def predict(mix_audio, p1, p2, p3):  # p1, p2, p3 are patient names
    """Run the offline separation + inference pipeline on an uploaded mixture.

    Args:
        mix_audio: Filepath of the uploaded mixture (``gr.Audio(type="filepath")``),
            or None when nothing was uploaded.
        p1, p2, p3: Dropdown selections for sources 1-3; ``"Unassigned"`` or an
            empty value means no patient is assigned to that source.

    Returns:
        A 9-item list matching the ``submit_btn.click`` outputs: the six audio /
        waveform values returned by ``run_end_to_end`` (all None on failure),
        the process status message, the findings-table rows and the
        history-logging message.
    """
    selected_names = normalize_live_patient_names([p1, p2, p3])

    if not mix_audio:
        return [None] * 6 + [
            "Please upload a mixture audio file first.",
            [],
            f"Nothing processed at {time.strftime('%H:%M:%S')}",
        ]

    # Map each selectable label to its reference audio. Unnamed entries use the
    # same "Patient <n>" label the dropdowns display; the first matching entry wins.
    reference_by_name = {}
    for position, entry in enumerate(load_patient_registry(), start=1):
        label = entry.get("name") or f"Patient {position}"
        reference_by_name.setdefault(
            label, entry.get("local_reference_audio") or entry.get("reference_audio")
        )
    reference_audio_paths = [reference_by_name.get(name) if name else None for name in selected_names]

    try:
        outputs, reasoning_summaries, history_record = run_end_to_end(
            mix_audio,
            patient_names=selected_names,
            reference_audio_paths=reference_audio_paths,
        )
        monitor_rows = monitoring_table_rows()
    except Exception as exc:
        # UI boundary: report the failure (with traceback) in the status box instead
        # of failing the event. Exactly 9 values, matching the bound outputs.
        error_message = f"Pipeline error: {exc}\n\n{traceback.format_exc()}"
        return [None] * 6 + [error_message, [], f"Error at {time.strftime('%H:%M:%S')}"]

    summaries = reasoning_summaries or []
    record = history_record or {}
    history_status = f"Completed at {record.get('timestamp', 'unknown')}. {len(summaries)} patient(s) processed."
    message = f"Full pipeline complete. {len(summaries)} reasoning summaries available."
    # list() so a tuple from run_end_to_end concatenates cleanly.
    return list(outputs) + [message, monitor_rows, history_status]
def refresh_monitoring():
    """Reload the patient-state rows from the latest reasoning summary.

    Returns:
        ``(rows, message)``: the rows from ``monitoring_table_rows()`` (a fresh
        empty list when it yields nothing) and a status string for the UI.
    """
    table = monitoring_table_rows()
    if table:
        return table, f"Loaded {len(table)} patient states from reasoning summary."
    return [], "No reasoning summary available. Run the pipeline and make sure pipeline_results/reasoning_summary.json exists."
def search_history(query):
    """Search stored clinical reasoning records by patient name or ID.

    Args:
        query: Search text from the history page. Surrounding whitespace is
            ignored, and a blank query is not sent to the backend.

    Returns:
        ``(rows, message)``: the rows from ``search_reasoning_records`` (``[]``
        when the query is blank or nothing matches) and a status message.
    """
    cleaned = (query or "").strip()
    if not cleaned:
        return [], "Enter a patient name or ID to search history."
    rows = search_reasoning_records(cleaned)
    if not rows:
        return [], f"No clinical findings found matching '{cleaned}'."
    return rows, f"Found {len(rows)} matching clinical record(s)."
| # ========================================== | |
| # 3. LIVE STREAMING CONTROLLERS | |
| # ========================================== | |
def start_live_monitoring_session(selected_patient_1, selected_patient_2, selected_patient_3):
    """Initialise per-session live-monitoring state for the selected slots.

    Loads the live models into the module globals that
    ``process_live_audio_stream`` reads. Models are only loaded when at least
    one slot has a patient. If loading fails, the error is reported in the
    status and the microphone stays disabled.

    Args:
        selected_patient_1, selected_patient_2, selected_patient_3: Slot dropdown
            values (``"Unassigned"`` or empty means the slot is unused).

    Returns:
        A 9-tuple matching the ``start_live_btn.click`` outputs: empty audio
        buffer, three empty separated buffers, per-slot state managers,
        zeroed timestamps, an empty findings table, the normalised slot names,
        a status message, a microphone update and a start-button update.
    """
    global _live_processor, _live_wav2vec_model, _live_gnn_model, _live_device
    selected_names = normalize_live_patient_names([selected_patient_1, selected_patient_2, selected_patient_3])
    status_names = [name for name in selected_names if name is not None]

    def _outputs(managers, status, mic_enabled):
        # The Start button is enabled exactly when the microphone is not.
        return (
            np.array([], dtype=np.float32),
            [np.array([], dtype=np.float32) for _ in range(3)],
            managers,
            [0.0, 0.0, 0.0],
            [],
            selected_names,
            status,
            gr.update(value=None, interactive=mic_enabled),
            gr.update(interactive=not mic_enabled),
        )

    if not status_names:
        # Nothing to monitor: skip the model load and keep the mic locked until a slot is chosen.
        return _outputs([None, None, None], "No patients selected. Select at least one patient to monitor.", False)

    try:
        _live_processor, _live_wav2vec_model, _live_gnn_model, _live_device = _initialize_models_for_live_processing()
    except Exception as exc:
        # UI boundary: surface the failure and leave Start enabled so the user can retry.
        return _outputs([None, None, None], f"Failed to load live monitoring models: {exc}", False)

    managers = [EnhancedPatientStateManager() for _ in range(3)]
    status = f"Live monitoring initialized for: {', '.join(status_names)}. Click microphone to start."
    return _outputs(managers, status, True)
def process_live_audio_stream(audio_chunk, live_audio_buffer, live_separated_buffers, live_patient_managers, live_current_timestamps, live_patient_names):
    """Handle one streamed microphone chunk: buffer it, separate sources, infer per slot.

    Args:
        audio_chunk: ``(sample_rate, samples)`` from the streaming
            ``gr.Audio(type="numpy")`` input, or None.
        live_audio_buffer: Rolling mono float32 buffer at ``_live_sr``
            (None before the first chunk).
        live_separated_buffers: Per-slot separated audio from the last separation pass.
        live_patient_managers: Per-slot ``EnhancedPatientStateManager`` objects;
            None entries are skipped.
        live_current_timestamps: Per-slot running timestamps in seconds.
        live_patient_names: Per-slot patient names, None for unassigned slots.

    Returns:
        A 10-tuple matching the ``live_mic_input.stream`` outputs: audio buffer,
        separated buffers, managers, timestamps, findings rows, status message,
        the three live separated-audio values and the (unchanged) slot names.
    """

    def _unchanged(status):
        # Early exit: pass state through and leave the three audio players as they
        # are (gr.update()). Must have exactly 10 items, one per bound output.
        return (
            live_audio_buffer,
            live_separated_buffers,
            live_patient_managers,
            live_current_timestamps,
            [],
            status,
            gr.update(),
            gr.update(),
            gr.update(),
            live_patient_names,
        )

    if audio_chunk is None:
        return _unchanged("No audio received.")
    sr, np_audio = audio_chunk
    if sr is None or np_audio is None:
        return _unchanged("Invalid audio chunk received.")

    import librosa
    import torch

    samples = np.asarray(np_audio)
    if np.issubdtype(samples.dtype, np.signedinteger):
        # Integer PCM (Gradio typically delivers int16 microphone samples) -> float in [-1, 1).
        # NOTE(review): assumes predict_sources expects normalised float audio.
        samples = samples.astype(np.float32) / -float(np.iinfo(samples.dtype).min)
    mono = np.mean(samples, axis=-1) if samples.ndim > 1 else samples
    target_sr = _live_sr
    if sr != target_sr:
        mono = librosa.resample(mono.astype(np.float32), orig_sr=sr, target_sr=target_sr)
    mono = mono.astype(np.float32)

    window_samples = int(LIVE_PROCESSING_WINDOW_SECONDS * target_sr)
    overlap_samples = int(LIVE_OVERLAP_SECONDS * target_sr)
    # Samples advanced per separation pass; fall back to a whole window if overlap >= window.
    hop = window_samples - overlap_samples
    if hop <= 0:
        hop = window_samples

    previous = live_audio_buffer if live_audio_buffer is not None else np.array([], dtype=np.float32)
    current_audio = np.concatenate([previous, mono]) if previous.size > 0 else mono
    # Keep only the most recent processing window.
    if current_audio.size > window_samples:
        current_audio = current_audio[-window_samples:]

    if live_separated_buffers:
        new_buffers = [buf.copy() for buf in live_separated_buffers]
    else:
        new_buffers = [np.array([], dtype=np.float32) for _ in range(3)]
    new_timestamps = list(live_current_timestamps or [0.0, 0.0, 0.0])
    names = list(live_patient_names or [])
    audios_out = [None, None, None]

    active_slots = [i for i, name in enumerate(names) if name is not None]
    if current_audio.size >= window_samples and active_slots:
        chunk = current_audio[-window_samples:]
        prediction, _ = predict_sources(torch.from_numpy(chunk).unsqueeze(0), _live_sr)
        separated = prediction.cpu().numpy()
        for i in range(min(3, separated.shape[0])):
            if (names[i] if i < len(names) else None) is None:
                new_buffers[i] = np.array([], dtype=np.float32)
                continue
            source = separated[i].astype(np.float32)
            # Only the newest hop of separated audio is handed to inference.
            new_buffers[i] = source[-hop:] if source.size > hop else source
            new_timestamps[i] = new_timestamps[i] + hop / target_sr if new_timestamps[i] > 0 else hop / target_sr
            # Disk I/O (sf.write/_copy_audio_to_storage) intentionally skipped to keep live latency low.
            audios_out[i] = (target_sr, source)

    rows = []
    for i, manager in enumerate(live_patient_managers or []):
        name = names[i] if i < len(names) else None
        if name is None or manager is None:
            continue
        separated_chunk = new_buffers[i]
        # Require at least one second of separated audio before inferring.
        if separated_chunk.size < target_sr:
            continue
        patient_key = f"live_patient_{i + 1}"
        timestamp = new_timestamps[i] if new_timestamps[i] > 0 else 0.0
        state = infer_on_separated_chunk(
            separated_chunk,
            _live_gnn_model,
            _live_processor,
            _live_wav2vec_model,
            _live_device,
            manager,
            patient_key,
            timestamp,
        ) or {}
        patient_data = manager.patient_data.get(patient_key, {})
        # One value per column of the live findings table, in header order:
        # Patient Name, overall_state, mean_wheeze_prob, mean_crackle_prob,
        # breathing_rate_mean, comment.
        rows.append([
            name,
            state.get("overall_state", ""),
            patient_data.get("wheeze_ema"),
            patient_data.get("crackle_ema"),
            state.get("breathing_rate_mean"),
            state.get("comment", ""),
        ])

    return (
        current_audio,
        new_buffers,
        live_patient_managers,
        new_timestamps,
        rows,
        "Processing live audio...",
        audios_out[0],
        audios_out[1],
        audios_out[2],
        live_patient_names,
    )
def stop_live_monitoring_session():
    """Clear live-monitoring session state and reset the live controls.

    Returns:
        An 11-tuple matching the ``stop_live_btn.click`` outputs: an empty audio
        buffer, three empty separated buffers, managers reset to None (which
        the stream handler skips), zeroed timestamps, an empty findings table,
        the status message, a microphone update (cleared and disabled), a
        start-button update (re-enabled) and None for each live audio player.
    """
    cleared_players = (None, None, None)
    return (
        np.array([], dtype=np.float32),
        [np.array([], dtype=np.float32) for _ in range(3)],
        [None] * 3,
        [0.0] * 3,
        [],
        "Live monitoring stopped.",
        gr.update(value=None, interactive=False),
        gr.update(interactive=True),
        *cleared_players,
    )
| # ========================================== | |
| # 4. INTERFACE BUILDING METHOD | |
| # ========================================== | |
def create_ui():
    """Build the Gradio Blocks app: sidebar navigation plus four pages.

    Pages: patient registration, offline source separation, history search and
    live microphone monitoring. The theme and CSS are applied by the entry point
    through ``launch()``, not here.

    Returns:
        The constructed (not yet launched) ``gr.Blocks`` instance.
    """
    with gr.Blocks() as demo:
        gr.HTML("<div class='header-box'><h1>Patient Monitoring System</h1></div>")
        with gr.Row():
            # Sidebar menu. Registration is the page shown on load, so its button starts highlighted.
            with gr.Column(scale=1, variant="panel"):
                gr.Markdown("### Navigation")
                btn_register = gr.Button("Register Patients", variant="primary", elem_classes="sidebar-btn")
                btn_separation = gr.Button("Audio Separation", variant="secondary", elem_classes="sidebar-btn")
                btn_live_mon = gr.Button("Live Monitoring", variant="secondary", elem_classes="sidebar-btn")
                btn_history = gr.Button("View History", variant="secondary", elem_classes="sidebar-btn")
            # Content area
            with gr.Column(scale=4):
                # Per-session live-monitoring state (see process_live_audio_stream).
                live_audio_buffer_state = gr.State(value=None)
                live_separated_buffers_state = gr.State(value=[])
                live_patient_managers_state = gr.State(value=[None, None, None])
                live_current_timestamps_state = gr.State(value=[0.0, 0.0, 0.0])
                live_patient_names_state = gr.State(value=[None, None, None])

                # Registration page
                with gr.Column(visible=True) as reg_page:
                    gr.Markdown("### Patient Registration")
                    with gr.Row():
                        with gr.Column(variant="panel"):
                            gr.Markdown("#### Patient 1")
                            patient_name_1 = gr.Textbox(label="Name", placeholder="Enter name", elem_classes="vibrant-status")
                            ref_audio_1 = gr.Audio(label="Ref Audio", type="filepath")
                        with gr.Column(variant="panel"):
                            gr.Markdown("#### Patient 2")
                            patient_name_2 = gr.Textbox(label="Name", placeholder="Enter name", elem_classes="vibrant-status")
                            ref_audio_2 = gr.Audio(label="Ref Audio", type="filepath")
                        with gr.Column(variant="panel"):
                            gr.Markdown("#### Patient 3")
                            patient_name_3 = gr.Textbox(label="Name", placeholder="Enter name", elem_classes="vibrant-status")
                            ref_audio_3 = gr.Audio(label="Ref Audio", type="filepath")
                    register_btn = gr.Button("Submit Registration", variant="primary", size="lg")
                    register_status = gr.Textbox(label="Status", interactive=False, elem_classes="vibrant-status")

                # Separation page
                with gr.Column(visible=False) as sep_page:
                    gr.Markdown("### Source Separation & Inference")
                    with gr.Row():
                        with gr.Column(scale=2, variant="panel"):
                            mix_audio = gr.Audio(label="Upload Mixture (Multiple Patients)", type="filepath")
                            gr.Markdown("#### Patient Assignment")
                            with gr.Row():
                                sep_p1 = gr.Dropdown(label="Source 1", choices=get_registered_patient_choices(), value="Unassigned")
                                sep_p2 = gr.Dropdown(label="Source 2", choices=get_registered_patient_choices(), value="Unassigned")
                                sep_p3 = gr.Dropdown(label="Source 3", choices=get_registered_patient_choices(), value="Unassigned")
                            submit_btn = gr.Button("Run Separation Pipeline", variant="primary")
                        with gr.Column(scale=1):
                            status_text = gr.Textbox(label="Process Status", interactive=False, elem_classes="vibrant-status")
                            history_status_text = gr.Textbox(label="History Logging", interactive=False, elem_classes="vibrant-status")
                    gr.Markdown("#### Separated Patient Data")
                    with gr.Row():
                        with gr.Column(variant="panel"):
                            out_audio_1 = gr.Audio(label="Patient 1 Audio", type="filepath")
                            out_wave_1 = gr.Image(label="Waveform 1", type="filepath")
                        with gr.Column(variant="panel"):
                            out_audio_2 = gr.Audio(label="Patient 2 Audio", type="filepath")
                            out_wave_2 = gr.Image(label="Waveform 2", type="filepath")
                        with gr.Column(variant="panel"):
                            out_audio_3 = gr.Audio(label="Patient 3 Audio", type="filepath")
                            out_wave_3 = gr.Image(label="Waveform 3", type="filepath")
                    gr.Markdown("#### Immediate Findings")
                    monitor_table_small = gr.Dataframe(
                        headers=["Patient Name", "overall_state", "mean_wheeze_prob", "mean_crackle_prob", "breathing_rate_mean", "comment"],
                        datatype=["str", "str", "number", "number", "number", "str"],
                        interactive=False,
                    )

                # History page
                with gr.Column(visible=False) as history_page:
                    gr.Markdown('<div class="page-title">Patient History Search</div>', elem_classes="page-container")
                    with gr.Row():
                        search_query = gr.Textbox(label="Search by Patient Name or Audio ID", placeholder="Enter name...", elem_classes="vibrant-status")
                        search_button = gr.Button("Search History", variant="primary")
                    history_results = gr.Dataframe(
                        headers=["Timestamp", "Audio File", "Patient Names", "Reasoning Count"],
                        datatype=["str", "str", "str", "number"],
                        interactive=False,
                    )
                    history_msg = gr.Textbox(label="Search Results", interactive=False, elem_classes="vibrant-status")

                # Live monitoring page
                with gr.Column(visible=False) as live_mon_page:
                    gr.Markdown("## Live Audio Monitoring")
                    # Start/Stop buttons on their own row at the top
                    with gr.Row():
                        start_live_btn = gr.Button("Start Live Monitoring", variant="primary")
                        stop_live_btn = gr.Button("Stop Live Monitoring", variant="stop")
                    # Monitor slot dropdowns on one row
                    with gr.Row():
                        select_patient_1 = gr.Dropdown(
                            label="Monitor Slot 1",
                            choices=get_registered_patient_choices(),
                            value="Unassigned",
                            interactive=True,
                        )
                        select_patient_2 = gr.Dropdown(
                            label="Monitor Slot 2",
                            choices=get_registered_patient_choices(),
                            value="Unassigned",
                            interactive=True,
                        )
                        select_patient_3 = gr.Dropdown(
                            label="Monitor Slot 3",
                            choices=get_registered_patient_choices(),
                            value="Unassigned",
                            interactive=True,
                        )
                    with gr.Row():
                        with gr.Column(scale=1):
                            live_mic_input = gr.Audio(
                                sources=["microphone"],
                                streaming=True,
                                label="Live Microphone Input",
                                type="numpy",
                                interactive=True,
                            )
                        with gr.Column(scale=2):
                            gr.Markdown("#### Live Separated Sources")
                            with gr.Row():
                                live_out_audio_1 = gr.Audio(label="Live Patient 1", interactive=False, type="numpy")
                                live_out_audio_2 = gr.Audio(label="Live Patient 2", interactive=False, type="numpy")
                                live_out_audio_3 = gr.Audio(label="Live Patient 3", interactive=False, type="numpy")
                    gr.Markdown("#### Live Findings")
                    live_monitor_table = gr.Dataframe(
                        headers=["Patient Name", "overall_state", "mean_wheeze_prob", "mean_crackle_prob", "breathing_rate_mean", "comment"],
                        datatype=["str", "str", "number", "number", "number", "str"],
                        interactive=False,
                    )
                    live_monitor_status = gr.Textbox(label="Live Status", interactive=False, elem_classes="vibrant-status")

        # --- Tab routing ---
        # nav_buttons[i] opens pages[i]. Both lists share one order so page
        # visibility and button highlighting cannot drift apart.
        pages = [reg_page, sep_page, history_page, live_mon_page]
        nav_buttons = [btn_register, btn_separation, btn_history, btn_live_mon]
        nav_outputs = pages + nav_buttons

        def _nav_updates(active):
            """Show only pages[active] and highlight its sidebar button."""
            visibility = [gr.update(visible=(i == active)) for i in range(len(pages))]
            variants = [gr.update(variant="primary" if i == active else "secondary") for i in range(len(nav_buttons))]
            return visibility + variants

        def _choice_updates(count):
            """Return ``count`` separate dropdown updates carrying fresh registry choices."""
            choices = get_registered_patient_choices()
            return [gr.update(choices=choices) for _ in range(count)]

        def nav_reg():
            return _nav_updates(0)

        def nav_sep():
            return _nav_updates(1) + _choice_updates(3)

        def nav_his():
            return _nav_updates(2)

        def nav_live():
            # Refresh the slot choices too, mirroring what nav_sep does for the source dropdowns.
            return _nav_updates(3) + _choice_updates(3)

        btn_register.click(
            nav_reg,
            outputs=nav_outputs,
            queue=False,
        )
        btn_separation.click(
            nav_sep,
            outputs=nav_outputs + [sep_p1, sep_p2, sep_p3],
            queue=False,
        )
        btn_history.click(
            nav_his,
            outputs=nav_outputs,
            queue=False,
        )
        btn_live_mon.click(
            nav_live,
            outputs=nav_outputs + [select_patient_1, select_patient_2, select_patient_3],
            queue=False,
        )

        # --- Interactive trigger bindings ---
        register_btn.click(
            fn=register_patients,
            inputs=[ref_audio_1, ref_audio_2, ref_audio_3, patient_name_1, patient_name_2, patient_name_3],
            outputs=[register_status, select_patient_1, select_patient_2, select_patient_3, sep_p1, sep_p2, sep_p3],
        )
        submit_btn.click(
            fn=predict,
            inputs=[mix_audio, sep_p1, sep_p2, sep_p3],
            outputs=[out_audio_1, out_audio_2, out_audio_3, out_wave_1, out_wave_2, out_wave_3, status_text, monitor_table_small, history_status_text],
        )
        search_button.click(
            fn=search_history,
            inputs=[search_query],
            outputs=[history_results, history_msg],
        )
        start_live_btn.click(
            fn=start_live_monitoring_session,
            inputs=[select_patient_1, select_patient_2, select_patient_3],
            outputs=[
                live_audio_buffer_state,
                live_separated_buffers_state,
                live_patient_managers_state,
                live_current_timestamps_state,
                live_monitor_table,
                live_patient_names_state,
                live_monitor_status,
                live_mic_input,
                start_live_btn,
            ],
            queue=False,
        )
        live_mic_input.stream(
            fn=process_live_audio_stream,
            inputs=[
                live_mic_input,
                live_audio_buffer_state,
                live_separated_buffers_state,
                live_patient_managers_state,
                live_current_timestamps_state,
                live_patient_names_state,
            ],
            outputs=[live_audio_buffer_state, live_separated_buffers_state, live_patient_managers_state, live_current_timestamps_state, live_monitor_table, live_monitor_status, live_out_audio_1, live_out_audio_2, live_out_audio_3, live_patient_names_state],
            concurrency_limit=5,
        )
        stop_live_btn.click(
            fn=stop_live_monitoring_session,
            inputs=[],
            outputs=[live_audio_buffer_state, live_separated_buffers_state, live_patient_managers_state, live_current_timestamps_state, live_monitor_table, live_monitor_status, live_mic_input, start_live_btn, live_out_audio_1, live_out_audio_2, live_out_audio_3],
            queue=False,
        )
    return demo
# Module-level Blocks instance, built when the module is imported.
app = create_ui()

if __name__ == "__main__":
    # Listen on all interfaces. The port comes from $PORT when set, otherwise 7860.
    port = int(os.environ.get("PORT", 7860))
    app.launch(
        share=False,
        server_name="0.0.0.0",
        server_port=port,
        theme=gr.themes.Soft(),
        css=css_styles,
        # Let Gradio serve generated files from the pipeline results directory.
        allowed_paths=[str(RESULTS_DIR.resolve())],
    )