import os import time import tempfile import traceback import numpy as np import soundfile as sf import gradio as gr # --- Pipeline & Model Imports --- # (Kept intact to connect with your backend framework logic) from interface import ( load_patient_registry, save_patient_registry, separate_audio, monitoring_table_rows, _copy_audio_to_storage, run_end_to_end, search_history_records, search_reasoning_records, _initialize_models_for_live_processing, RESULTS_DIR, process_audio_chunk_for_separation, infer_on_separated_chunk, _live_sr, LIVE_PROCESSING_WINDOW_SECONDS, LIVE_OVERLAP_SECONDS, resolve_patient_names, save_audio_file, predict_sources ) from gnn import EnhancedPatientStateManager, ClinicalAlertSystem # ========================================== # 1. STYLING CONFIGURATION # ========================================== def load_css(): css_path = os.path.join(os.path.dirname(__file__), "style.css") if os.path.exists(css_path): with open(css_path, "r") as f: return f.read() return "" css_styles = load_css() # ========================================== # 2. CORE BUSINESS & UTILITY LOGIC # ========================================== def register_patients(ref_audio_1, ref_audio_2, ref_audio_3, patient_name_1, patient_name_2, patient_name_3): raw_names = [patient_name_1, patient_name_2, patient_name_3] audio_paths = [ref_audio_1, ref_audio_2, ref_audio_3] patient_entries = [] for idx, (name, audio_path) in enumerate(zip(raw_names, audio_paths), start=1): entry = { "patient_id": f"patient_{idx}", "name": name.strip() if name else f"Patient {idx}", "reference_audio": str(audio_path) if audio_path else None, } patient_entries.append(entry) save_patient_registry(patient_entries) registered = [f"Patient {idx}: {entry['name']}" for idx, entry in enumerate(patient_entries, start=1)] message = ( "Patients registered successfully. Reference audio files are saved for each patient.\n" + "\n".join(registered) + "\nSaved to pipeline_results/patient_registry.json" ) choices = get_registered_patient_choices() selected_values = [choices[i + 1] if i + 1 < len(choices) else "Unassigned" for i in range(6)] return ( message, gr.update(choices=choices, value=selected_values[0]), gr.update(choices=choices, value=selected_values[1]), gr.update(choices=choices, value=selected_values[2]), gr.update(choices=choices, value=selected_values[0]), gr.update(choices=choices, value=selected_values[1]), gr.update(choices=choices, value=selected_values[2]), ) def get_registered_patient_choices(default_count=3): registry = load_patient_registry() names = [entry.get("name") or f"Patient {idx + 1}" for idx, entry in enumerate(registry)] names = list(dict.fromkeys(names)) if not names: names = [f"Patient {i}" for i in range(1, default_count + 1)] return ["Unassigned"] + names def normalize_live_patient_names(selected_names): normalized = [] for name in selected_names: if name and name != "Unassigned": normalized.append(name) else: normalized.append(None) return normalized def predict(mix_audio, p1, p2, p3): # p1, p2, p3 are patient names selected_names = [n if n != "Unassigned" else None for n in [p1, p2, p3]] registry = load_patient_registry() reference_audio_paths = [None, None, None] for i, name in enumerate(selected_names): if name: for entry in registry: if entry.get("name") == name: reference_audio_paths[i] = entry.get("local_reference_audio") or entry.get("reference_audio") break try: outputs, reasoning_summaries, history_record = run_end_to_end( mix_audio, patient_names=selected_names, reference_audio_paths=reference_audio_paths ) except Exception as exc: error_message = f"Pipeline error: {exc}" error_trace = traceback.format_exc() empty = [None, None, None, None, None, None] return empty + [error_message, [[]], f"Error at {time.strftime('%H:%M:%S')}", error_trace] monitor_rows = monitoring_table_rows() history_status = f"Completed at {history_record.get('timestamp', 'unknown')}. {len(reasoning_summaries)} patient(s) processed." message = f"Full pipeline complete. {len(reasoning_summaries)} reasoning summaries available." return outputs + [message] + [monitor_rows] + [history_status] def refresh_monitoring(): rows = monitoring_table_rows() if not rows: return [], "No reasoning summary available. Run the pipeline and make sure pipeline_results/reasoning_summary.json exists." return rows, f"Loaded {len(rows)} patient states from reasoning summary." def search_history(query): if not query: return [], "Enter a patient name or ID to search history." rows = search_reasoning_records(query) if not rows: return [], f"No clinical findings found matching '{query}'." return rows, f"Found {len(rows)} matching clinical record(s)." # ========================================== # 3. LIVE STREAMING CONTROLLERS # ========================================== def start_live_monitoring_session(selected_patient_1, selected_patient_2, selected_patient_3): global _live_processor, _live_wav2vec_model, _live_gnn_model, _live_device _live_processor, _live_wav2vec_model, _live_gnn_model, _live_device = _initialize_models_for_live_processing() selected_names = normalize_live_patient_names([selected_patient_1, selected_patient_2, selected_patient_3]) managers = [ EnhancedPatientStateManager(), EnhancedPatientStateManager(), EnhancedPatientStateManager(), ] empty_audio = np.array([], dtype=np.float32) empty_buffers = [np.array([], dtype=np.float32) for _ in range(3)] empty_table = [] status_names = [name for name in selected_names if name is not None] if status_names: status = f"Live monitoring initialized for: {', '.join(status_names)}. Click microphone to start." else: status = "No patients selected. Select at least one patient to monitor." return ( empty_audio, empty_buffers, managers, [0.0, 0.0, 0.0], empty_table, selected_names, status, gr.update(value=None, interactive=True), gr.update(interactive=False), ) def process_live_audio_stream(audio_chunk, live_audio_buffer, live_separated_buffers, live_patient_managers, live_current_timestamps, live_patient_names): if audio_chunk is None: return live_audio_buffer, live_separated_buffers, live_patient_managers, live_current_timestamps, [], "No audio received." import librosa import torch sr, np_audio = audio_chunk if sr is None or np_audio is None: return live_audio_buffer, live_separated_buffers, live_patient_managers, live_current_timestamps, [], "Invalid audio chunk received." mono = np.mean(np_audio, axis=-1) if np_audio.ndim > 1 else np_audio target_sr = _live_sr if sr != target_sr: mono = librosa.resample(mono.astype(np.float32), orig_sr=sr, target_sr=target_sr) mono = mono.astype(np.float32) current_audio_base = live_audio_buffer if live_audio_buffer is not None else np.array([], dtype=np.float32) current_audio = np.concatenate([current_audio_base, mono]) if current_audio_base.size > 0 else mono max_buffer = int(LIVE_PROCESSING_WINDOW_SECONDS * target_sr) if current_audio.size > max_buffer: current_audio = current_audio[-max_buffer:] window_samples = int(LIVE_PROCESSING_WINDOW_SECONDS * target_sr) overlap_samples = int(LIVE_OVERLAP_SECONDS * target_sr) step = window_samples - overlap_samples new_buffers = [buf.copy() for buf in live_separated_buffers] if live_separated_buffers else [np.array([], dtype=np.float32) for _ in range(3)] new_timestamps = list(live_current_timestamps) audios_out = [None, None, None] active_slots = [i for i, name in enumerate(live_patient_names) if name is not None] if live_patient_names else [] if current_audio.size >= window_samples and active_slots: chunk = current_audio[-window_samples:] prediction, _ = predict_sources(torch.from_numpy(chunk).unsqueeze(0), _live_sr) separated = prediction.cpu().numpy() for i in range(min(3, separated.shape[0])): selected_name = live_patient_names[i] if i < len(live_patient_names) else None if selected_name is None: new_buffers[i] = np.array([], dtype=np.float32) audios_out[i] = None continue separated_i = separated[i].astype(np.float32) hop = step if step > 0 else window_samples if separated_i.size > hop: new_buffers[i] = separated_i[-hop:] else: new_buffers[i] = separated_i new_timestamps[i] = new_timestamps[i] + hop / target_sr if new_timestamps[i] > 0 else hop / target_sr # Performance Note: Disk I/O (sf.write/_copy_audio_to_storage) removed to reduce live latency try: audios_out[i] = (target_sr, separated_i) except Exception: audios_out[i] = None rows = [] for i, manager in enumerate(live_patient_managers): selected_name = live_patient_names[i] if i < len(live_patient_names) else None if selected_name is None or manager is None: continue separated_chunk = new_buffers[i] timestamp = new_timestamps[i] if new_timestamps[i] > 0 else 0.0 if separated_chunk.size >= target_sr: state = infer_on_separated_chunk( separated_chunk, _live_gnn_model, _live_processor, _live_wav2vec_model, _live_device, manager, f"live_patient_{i+1}", timestamp, ) rows.append([ selected_name, manager.patient_data.get(f"live_patient_{i+1}", {}).get("wheeze_ema"), manager.patient_data.get(f"live_patient_{i+1}", {}).get("crackle_ema"), state.get("breathing_rate_mean"), state.get("comment", ""), ]) return current_audio, new_buffers, live_patient_managers, new_timestamps, rows, "Processing live audio...", audios_out[0], audios_out[1], audios_out[2], live_patient_names def stop_live_monitoring_session(): empty_audio = np.array([], dtype=np.float32) empty_buffers = [np.array([], dtype=np.float32) for _ in range(3)] cleared_managers = [None, None, None] cleared_timestamps = [0.0, 0.0, 0.0] return empty_audio, empty_buffers, cleared_managers, cleared_timestamps, [], "Live monitoring stopped.", gr.update(value=None, interactive=False), gr.update(interactive=True), None, None, None # ========================================== # 4. INTERFACE BUILDING METHOD # ========================================== def create_ui(): # Pass structural embedded CSS string variable safely inside Blocks with gr.Blocks() as demo: gr.HTML("