# -*- coding: utf-8 -*-
"""app

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1_pOuPpNNnrEQCdGrxl0JnaQJXAIb4F-j
"""

# =========================================
# TabeebAI — Unified Streamlit App
# AI Medical Triage (Hugging Face Spaces)
# =========================================

import os
import json
import time
import base64
import tempfile
import traceback
import numpy as np
import pandas as pd
import streamlit as st
from io import BytesIO
from groq import Groq
from sentence_transformers import SentenceTransformer

# =========================================
# GROQ CLIENT
# =========================================

api_key = os.environ.get("GROQ_API_KEY")
# `client` stays None when no key is configured; the pipeline entrypoints
# check for that before making any API call.
client = Groq(api_key=api_key) if api_key else None

# =========================================
# DISEASE DATASET
# =========================================

_DATA_PATH = os.path.join(os.path.dirname(__file__), "data", "diseases_symptoms.csv")

disease_df = None
_SYMPTOM_COLS = []

try:
    disease_df = pd.read_csv(_DATA_PATH)
    # Normalise headers to snake_case so symptom matching can split on "_".
    disease_df.columns = (
        disease_df.columns.str.lower().str.strip().str.replace(" ", "_")
    )
    if "disease" not in disease_df.columns:
        raise ValueError("dataset has no 'disease' column")
    _SYMPTOM_COLS = [c for c in disease_df.columns if c != "disease"]
    print(f"[TabeebAI] Disease dataset: {len(disease_df)} diseases, {len(_SYMPTOM_COLS)} symptoms")
except Exception as e:
    # Best effort: the app runs without the dataset, lookups just return [].
    # Reset both globals so a half-loaded frame is never used.
    disease_df = None
    _SYMPTOM_COLS = []
    print(f"[TabeebAI] Disease dataset not found: {e}")

# Severity/onset qualifiers that are ignored when matching a symptom name
# against dataset column names.
_SYMPTOM_STOP_WORDS = frozenset({
    "severe", "mild", "moderate", "acute", "chronic",
    "sudden", "persistent", "intense", "sharp", "dull",
})


def _match_symptom_to_column(symptom_name: str):
    """Return the dataset column sharing the most words with *symptom_name*.

    Words are compared after lower-casing and treating "_" and "-" as spaces;
    qualifier words in ``_SYMPTOM_STOP_WORDS`` are dropped from the query.
    Ties keep the first column encountered. Returns ``None`` when no column
    shares a word, or when *symptom_name* is empty or not a string.
    """
    if not symptom_name or not isinstance(symptom_name, str):
        return None

    query_words = set(
        symptom_name.lower().replace("_", " ").replace("-", " ").split()
    )
    query_words -= _SYMPTOM_STOP_WORDS

    best_col, best_score = None, 0
    for col in _SYMPTOM_COLS:
        overlap = len(query_words & set(col.replace("_", " ").split()))
        if overlap > best_score:
            best_score, best_col = overlap, col
    return best_col if best_score > 0 else None


def lookup_diseases(symptoms_list: list, top_n: int = 5) -> list:
    """Rank dataset diseases by how many of the extracted symptoms they list.

    Args:
        symptoms_list: Symptom dicts as produced by ``extract_symptoms``;
            only the ``"name"`` key is read. Entries that are not dicts or
            have no usable name are ignored.
        top_n: Maximum number of diseases to return.

    Returns:
        A list of ``{"disease", "matched", "total"}`` records sorted by
        ``matched`` descending, or ``[]`` when the dataset is unavailable or
        nothing matches.
    """
    if disease_df is None or not symptoms_list:
        return []

    matched_cols = []
    for s in symptoms_list:
        # LLM output may contain null names or non-dict entries.
        name = s.get("name") if isinstance(s, dict) else None
        col = _match_symptom_to_column(name or "")
        if col and col not in matched_cols:
            matched_cols.append(col)

    if not matched_cols:
        return []

    scores = disease_df[matched_cols].sum(axis=1)
    total_cols = len(matched_cols)
    results = (
        disease_df[["disease"]]
        .assign(matched=scores, total=total_cols)
        .query("matched > 0")
        .sort_values("matched", ascending=False)
        .head(top_n)
    )
    return results.to_dict(orient="records")


# =========================================
# RAG SETUP
# =========================================

_rag_model = None
_rag_embeddings = None
_rag_docs = []


# Streamlit re-executes this script on every interaction; caching keeps the
# embedding model and document vectors from being rebuilt on each rerun.
# show_spinner=False so nothing is rendered before st.set_page_config().
@st.cache_resource(show_spinner=False)
def _load_rag_resources(knowledge_path: str):
    """Load the knowledge JSON and embed it; returns (model, embeddings, docs)."""
    with open(knowledge_path, "r", encoding="utf-8") as f:
        docs = json.load(f)
    print("[TabeebAI] Loading embedding model...")
    model = SentenceTransformer("all-MiniLM-L6-v2")
    texts = [f"{d['title']}. {d['text']}" for d in docs]
    embeddings = model.encode(
        texts, normalize_embeddings=True, show_progress_bar=False
    )
    return model, embeddings, docs


def setup_rag():
    """Populate the module-level RAG globals; on failure leave them untouched.

    The three globals are assigned together only after loading succeeded, so
    a failure can never leave documents loaded without a model.
    """
    global _rag_model, _rag_embeddings, _rag_docs
    knowledge_path = os.path.join(os.path.dirname(__file__), "data", "medical_knowledge.json")
    try:
        _rag_model, _rag_embeddings, _rag_docs = _load_rag_resources(knowledge_path)
        print(f"[TabeebAI] RAG ready: {len(_rag_docs)} documents embedded")
    except Exception as e:
        print(f"[TabeebAI] RAG setup failed: {e}")


def retrieve_knowledge(query: str, n: int = 3) -> list:
    """Return up to *n* knowledge documents most similar to *query*.

    Similarity is the dot product of normalised embeddings; documents scoring
    0.2 or below are dropped. Returns ``[]`` when RAG is not ready or
    ``n <= 0``.
    """
    if _rag_model is None or _rag_embeddings is None or not _rag_docs:
        return []
    if n <= 0:
        # scores.argsort()[-0:] would select every document, not none.
        return []

    query_emb = _rag_model.encode([query], normalize_embeddings=True, show_progress_bar=False)
    scores = np.dot(_rag_embeddings, query_emb.T).flatten()
    top_idx = scores.argsort()[-n:][::-1]
    return [
        {"title": _rag_docs[i]["title"], "text": _rag_docs[i]["text"], "score": float(scores[i])}
        for i in top_idx
        if scores[i] > 0.2
    ]


setup_rag()

# =========================================
# LANGUAGE DETECTION
# =========================================


def detect_language(text: str) -> str:
    """Classify *text* as ``"ur"``, ``"en"``, ``"other"`` or ``"unknown"``.

    ``"unknown"`` means empty text or no alphabetic characters. ``"ur"`` is
    returned when more than 30% of the letters are in U+0600–U+06FF, else
    ``"en"`` when more than 50% are ASCII letters, else ``"other"``.
    """
    if not text or not text.strip():
        return "unknown"

    urdu_chars = sum(1 for c in text if '\u0600' <= c <= '\u06ff')
    latin_chars = sum(1 for c in text if c.isascii() and c.isalpha())
    total_alpha = sum(1 for c in text if c.isalpha())

    if total_alpha == 0:
        return "unknown"
    if urdu_chars / total_alpha > 0.3:
        return "ur"
    if latin_chars / total_alpha > 0.5:
        return "en"
    return "other"


# =========================================
# TRANSCRIPTION
# =========================================


def _call_whisper(audio_path: str, language=None):
    """Transcribe *audio_path*; return ``(text, detected_language)``.

    *language*, when given, is passed to the transcription API as a hint.
    """
    with open(audio_path, "rb") as f:
        kwargs = dict(file=f, model="whisper-large-v3-turbo")
        if language:
            kwargs["language"] = language
        response = client.audio.transcriptions.create(**kwargs)
    text = response.text.strip()
    return text, detect_language(text)


def transcribe_audio_file(audio_path: str):
    """Transcribe audio, retrying once with an Urdu hint.

    Returns ``(text, lang)``. When the transcript is still neither Urdu nor
    English after the retry, returns a user-facing message and ``"unknown"``.
    """
    text, lang = _call_whisper(audio_path)
    if lang == "other":
        text, lang = _call_whisper(audio_path, language="ur")
    if lang == "other":
        return "Unsupported language detected. Please speak in Urdu or English only.", "unknown"
    return text, lang


# =========================================
# TRANSLATION
# =========================================


def translate_if_needed(text: str, lang: str) -> str:
    """Translate *text* to English unless it is empty or already English.

    Never raises. If the translation call fails or returns nothing, the
    original text is returned unchanged, so an error message is never passed
    on to classification and triage as if the patient had said it.
    """
    if not text or lang == "en":
        return text
    try:
        response = client.chat.completions.create(
            model="llama-3.3-70b-versatile",
            messages=[{
                "role": "user",
                "content": (
                    "Translate the following Urdu medical text into clear English. "
                    "Return only the translation, no explanations:\n\n" + text
                )
            }],
            temperature=0.1
        )
        translated = (response.choices[0].message.content or "").strip()
        return translated or text
    except Exception as e:
        print(f"[TabeebAI] Translation failed, using original text: {e}")
        return text


# =========================================
# CLASSIFICATION
# =========================================


def classify_query(text: str) -> str:
    """Classify *text* as ``"medical"``, ``"crisis"`` or ``"non_medical"``.

    Any API failure yields ``"medical"`` so the statement is still triaged.
    """
    prompt = f"""You are a medical query classifier for a clinical triage system.

Classify the following patient statement into exactly ONE category:

MEDICAL — describes any symptom, pain, illness, injury, medication, or health concern
CRISIS — mentions self-harm, suicide, wanting to die, or harming others
NON_MEDICAL — anything unrelated to health (greetings, general questions, jokes, etc.)

Statement: "{text}"

Reply with ONLY one word: MEDICAL, CRISIS, or NON_MEDICAL"""
    try:
        response = client.chat.completions.create(
            model="llama-3.1-8b-instant",
            messages=[{"role": "user", "content": prompt}],
            max_tokens=5,
            temperature=0
        )
        result = response.choices[0].message.content.strip().upper()
        if "CRISIS" in result:
            return "crisis"
        # "NON" also covers a reply cut short by max_tokens.
        if "NON" in result:
            return "non_medical"
        return "medical"
    except Exception:
        return "medical"


# =========================================
# SYMPTOM EXTRACTION
# =========================================


def _parse_json_object(raw: str) -> dict:
    """Parse the JSON object contained in an LLM reply.

    Strips markdown code fences and any text outside the outermost braces.
    Raises ``ValueError`` (``json.JSONDecodeError`` is a subclass) when the
    reply is not valid JSON or is not a JSON object.
    """
    cleaned = raw.replace("```json", "").replace("```", "").strip()
    start, end = cleaned.find("{"), cleaned.rfind("}")
    if start != -1 and end > start:
        cleaned = cleaned[start:end + 1]
    parsed = json.loads(cleaned)
    if not isinstance(parsed, dict):
        raise ValueError(f"expected a JSON object, got {type(parsed).__name__}")
    return parsed


def extract_symptoms(text: str) -> dict:
    """Ask the LLM for a structured symptom analysis of *text*.

    Returns the parsed JSON object on success. On any failure returns
    ``{"error": "Parsing failed", "details": ..., "raw_output": ...}``;
    callers detect failure with ``"error" in result``.
    """
    prompt = f"""You are a clinical AI assistant.

STRICT RULES:
- Output MUST be ONLY in English
- Return ONLY valid JSON, no extra text, no markdown

Extract structured medical symptoms from this text:
{text}

FORMAT:
{{
  "chief_complaint": "",
  "symptoms": [
    {{
      "name": "",
      "severity": "mild/moderate/severe",
      "duration": ""
    }}
  ],
  "possible_conditions": [],
  "urgency": "low/medium/high",
  "language": "English"
}}"""
    output = ""
    try:
        response = client.chat.completions.create(
            model="llama-3.3-70b-versatile",
            messages=[{"role": "user", "content": prompt}],
            temperature=0
        )
        output = response.choices[0].message.content or ""
        return _parse_json_object(output)
    except Exception as e:
        return {"error": "Parsing failed", "details": str(e), "raw_output": output}


# =========================================
# RISK SCORING
# =========================================

_URGENCY_SCORES = {"low": 10, "medium": 40, "high": 75}
_SEVERITY_BONUS = {"mild": 3, "moderate": 8, "severe": 18}
_EMERGENCY_KEYWORDS = (
    "chest pain", "difficulty breathing", "shortness of breath",
    "unconscious", "unresponsive", "severe bleeding",
    "heart attack", "stroke", "choking", "not breathing", "no pulse",
    "سینے میں درد", "سانس لینے میں دشواری", "بے ہوش",
)


def calculate_risk_score(analysis: dict) -> int:
    """Compute a 0–100 risk score from an ``extract_symptoms`` result.

    The score adds a base for urgency, a bonus per symptom severity, +30 once
    if any emergency keyword appears in the chief complaint or a symptom
    name, and +3 per symptom (capped at 15). Unrecognised or missing urgency
    and severity fall back to the lowest tier. Returns 0 for an empty or
    failed analysis.
    """
    if not analysis or "error" in analysis:
        return 0

    # LLM JSON may carry nulls or non-dict entries; tolerate both.
    symptoms = [s for s in (analysis.get("symptoms") or []) if isinstance(s, dict)]

    urgency = str(analysis.get("urgency") or "low").strip().lower()
    score = _URGENCY_SCORES.get(urgency, 10)

    for symptom in symptoms:
        severity = str(symptom.get("severity") or "mild").strip().lower()
        score += _SEVERITY_BONUS.get(severity, 3)

    # Lower-case the complaint AND the symptom names: the keywords are
    # lower-case, so "Chest Pain" as a symptom name must still match.
    parts = [str(analysis.get("chief_complaint") or "")]
    parts.extend(str(s.get("name") or "") for s in symptoms)
    full_text = " ".join(parts).lower()
    if any(keyword in full_text for keyword in _EMERGENCY_KEYWORDS):
        score += 30

    score += min(len(symptoms) * 3, 15)
    return min(score, 100)


def get_risk_level(score: int) -> str:
    """Map a risk score to ``"RED"`` (>= 71), ``"YELLOW"`` (>= 31) or ``"GREEN"``."""
    if score >= 71:
        return "RED"
    elif score >= 31:
        return "YELLOW"
    else:
        return "GREEN"


# =========================================
# SOAP
# REPORT
# =========================================


def generate_soap_report(analysis: dict, english_text: str, risk_score: int,
                         retrieved_context: str = "") -> str:
    """Generate a SOAP-format report from the structured analysis via the LLM.

    Args:
        analysis: Result of ``extract_symptoms`` (possibly enriched by the
            pipeline).
        english_text: The patient's statement in English.
        risk_score: Score from ``calculate_risk_score``.
        retrieved_context: Optional retrieved knowledge text to include.

    Returns:
        The report text, a fixed message when *analysis* is empty or failed,
        or ``"Report generation error: ..."`` when the API call fails.
    """
    if not analysis or "error" in analysis:
        return "Cannot generate report — symptom extraction failed."

    # Tolerate null lists and non-dict / non-string entries in LLM output.
    symptoms_text = "\n".join(
        f" - {s.get('name','?')} | severity: {s.get('severity','?')} | duration: {s.get('duration','?')}"
        for s in (analysis.get("symptoms") or [])
        if isinstance(s, dict)
    )
    conditions = ", ".join(
        str(c) for c in (analysis.get("possible_conditions") or [])
    ) or "None identified"
    context_section = f"\nRETRIEVED MEDICAL KNOWLEDGE:\n{retrieved_context}\n" if retrieved_context else ""

    prompt = f"""You are a clinical documentation assistant.
Generate a concise SOAP format medical report based on the data below.
Return ONLY the report — no extra commentary, no markdown headings with #.

PATIENT DATA:
Chief Complaint : {analysis.get('chief_complaint', 'N/A')}
Symptoms :
{symptoms_text}
Possible Conditions: {conditions}
Urgency : {analysis.get('urgency', 'N/A')}
Risk Score : {risk_score}/100
Original Statement: {english_text}
{context_section}
FORMAT TO USE:
SUBJECTIVE:
[what the patient reports]

OBJECTIVE:
[observable findings from the speech/statement]

ASSESSMENT:
[clinical interpretation, possible diagnoses]

PLAN:
[recommended next steps for the treating doctor]

NOTE: This report is AI-generated and must be reviewed by a qualified health professional before any clinical decision is made."""
    try:
        response = client.chat.completions.create(
            model="llama-3.3-70b-versatile",
            messages=[{"role": "user", "content": prompt}],
            temperature=0.2
        )
        return response.choices[0].message.content.strip()
    except Exception as e:
        return f"Report generation error: {str(e)}"


# =========================================
# FULL PIPELINE
# =========================================


def _ms_since(start: float) -> int:
    """Return whole milliseconds elapsed since *start* (a ``time.time()`` value)."""
    return round((time.time() - start) * 1000)


def run_full_pipeline(text: str, lang: str) -> dict:
    """Run translation, classification, extraction, scoring, lookup, RAG and SOAP.

    Returns a dict whose ``"status"`` is one of:

    * ``"crisis"`` / ``"non_medical"`` — classification short-circuited; only
      language, translated text and timings are included.
    * ``"error"`` — symptom extraction failed; ``"message"`` explains why.
      No risk score is returned, so a failed analysis cannot be shown as
      low risk.
    * ``"medical"`` — full result with analysis, risk, matches, retrieved
      chunks, SOAP report, model names and timings.

    ``timings["total_ms"]`` is set on every path; the short-circuit paths also
    keep the top-level ``"total_ms"`` key.
    """
    timings = {}
    t0 = time.time()

    # Translation
    t = time.time()
    english_text = translate_if_needed(text, lang)
    timings["translation_ms"] = _ms_since(t)

    # Classification
    t = time.time()
    category = classify_query(english_text)
    timings["classification_ms"] = _ms_since(t)

    lang_label = "Urdu" if lang == "ur" else ("English" if lang == "en" else "Unknown")

    if category in ("crisis", "non_medical"):
        total_ms = _ms_since(t0)
        timings["total_ms"] = total_ms
        return {
            "status": category,
            "english_text": english_text,
            "lang": lang,
            "lang_label": lang_label,
            "timings": timings,
            "total_ms": total_ms,
        }

    # Symptom extraction
    t = time.time()
    analysis = extract_symptoms(english_text)
    timings["extraction_ms"] = _ms_since(t)

    if not isinstance(analysis, dict) or "error" in analysis:
        details = analysis.get("details", "") if isinstance(analysis, dict) else ""
        timings["total_ms"] = _ms_since(t0)
        return {
            "status": "error",
            "message": f"Symptom extraction failed: {details or 'unexpected model output'}",
            "english_text": english_text,
            "lang": lang,
            "lang_label": lang_label,
            "timings": timings,
        }

    # Keep only well-formed symptom entries so every later step (and the UI)
    # can call .get() on them safely.
    symptoms_list = [s for s in (analysis.get("symptoms") or []) if isinstance(s, dict)]
    analysis["symptoms"] = symptoms_list

    # Risk scoring
    t = time.time()
    risk_score = calculate_risk_score(analysis)
    risk_level = get_risk_level(risk_score)
    timings["risk_scoring_ms"] = _ms_since(t)

    # Disease lookup
    t = time.time()
    dataset_matches = lookup_diseases(symptoms_list, top_n=5)
    timings["disease_lookup_ms"] = _ms_since(t)

    if dataset_matches:
        # Dataset matches replace the LLM's own guesses.
        analysis["possible_conditions"] = [d["disease"] for d in dataset_matches]
        analysis["dataset_match_detail"] = [
            f"{d['disease']} ({d['matched']}/{d['total']} symptoms matched)"
            for d in dataset_matches
        ]
    else:
        analysis["dataset_match_detail"] = ["Dataset lookup returned no matches"]

    # RAG retrieval
    t = time.time()
    rag_query = english_text + " " + " ".join(
        str(s.get("name") or "") for s in symptoms_list
    )
    retrieved = retrieve_knowledge(rag_query, n=3)
    timings["rag_retrieval_ms"] = _ms_since(t)

    retrieved_context = "\n\n".join(
        f"[{r['title']}]\n{r['text']}" for r in retrieved
    ) if retrieved else ""

    # SOAP report
    t = time.time()
    soap_report = generate_soap_report(analysis, english_text, risk_score, retrieved_context)
    timings["soap_generation_ms"] = _ms_since(t)

    timings["total_ms"] = _ms_since(t0)

    return {
        "status": "medical",
        "lang": lang,
        "lang_label": lang_label,
        "english_text": english_text,
        "analysis": analysis,
        "risk_score": risk_score,
        "risk_level": risk_level,
        "dataset_matches": dataset_matches,
        "retrieved_chunks": retrieved,
        "soap_report": soap_report,
        "models_used": {
            "transcription": "whisper-large-v3-turbo",
            "classification": "llama-3.1-8b-instant",
            "extraction": "llama-3.3-70b-versatile",
            "soap": "llama-3.3-70b-versatile",
            "embeddings": "all-MiniLM-L6-v2"
        },
        "timings": timings
    }


# =========================================
# DIRECT PIPELINE ENTRYPOINTS
# (replaces HTTP API calls)
# =========================================


def call_text_pipeline(text: str) -> dict:
    """Validate typed input and run the pipeline.

    Returns the pipeline result, or ``{"status": "error", "message": ...}``
    when the API key is missing, the text is empty, the language is
    unsupported, or the pipeline raises.
    """
    if not client:
        return {"status": "error", "message": "GROQ_API_KEY not configured"}
    if not text or not text.strip():
        return {"status": "error", "message": "Text cannot be empty"}

    lang = detect_language(text)
    if lang == "other":
        return {"status": "error", "message": "Unsupported language. Only Urdu and English are supported."}

    try:
        return run_full_pipeline(text, lang)
    except Exception as e:
        return {"status": "error", "message": str(e)}


def call_audio_pipeline(audio_bytes: bytes) -> dict:
    """Transcribe *audio_bytes* and run the pipeline on the transcript.

    The audio is written to a temporary file that is removed on every exit
    path. Returns the pipeline result with ``"original_transcript"`` and
    ``timings["transcription_ms"]`` added, or an error dict.
    """
    if not client:
        return {"status": "error", "message": "GROQ_API_KEY not configured"}
    if not audio_bytes:
        return {"status": "error", "message": "Audio cannot be empty"}

    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            tmp.write(audio_bytes)
            tmp_path = tmp.name

        t0 = time.time()
        text, lang = transcribe_audio_file(tmp_path)
        transcription_ms = round((time.time() - t0) * 1000)

        if lang == "unknown":
            # text is a user-facing message, or empty when nothing was heard.
            return {
                "status": "error",
                "message": text or "No speech could be transcribed from the audio.",
            }

        result = run_full_pipeline(text, lang)
        result["original_transcript"] = text
        result.setdefault("timings", {})["transcription_ms"] = transcription_ms
        return result
    except Exception as e:
        return {"status": "error", "message": str(e)}
    finally:
        # delete=False above means cleanup is ours, including on failure.
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError:
                pass


def get_system_health() -> dict:
    """Report which backend components were initialised at import time."""
    return {
        "status": "healthy",
        "groq_configured": bool(api_key),
        "rag_ready": _rag_model is not None,
        "disease_db_ready": disease_df is not None,
        "disease_count": len(disease_df) if disease_df is not None else 0,
        "rag_doc_count": len(_rag_docs)
    }


# =========================================
# STREAMLIT PAGE CONFIG
# =========================================

# Resolve assets next to this file, so they are found regardless of the
# working directory the app is launched from.
_ASSETS_DIR = os.path.join(os.path.dirname(__file__), "assets")
_favicon_path = os.path.join(_ASSETS_DIR, "favicon.ico")

st.set_page_config(
    page_title="TabeebAI — Clinical Triage",
    page_icon=_favicon_path if os.path.exists(_favicon_path) else "🏥",
    layout="wide",
    initial_sidebar_state="expanded"
)


def _read_file_b64(path: str) -> str:
    """Return the base64 text of the file at *path*, or ``""`` if it is missing."""
    if not os.path.exists(path):
        return ""
    with open(path, "rb") as fh:
        return base64.b64encode(fh.read()).decode()


_logo_path = os.path.join(_ASSETS_DIR, "logo.png")
_logo_b64 = _read_file_b64(_logo_path)
# NOTE(review): the literal below is empty in this copy of the file — the
# image markup that presumably embeds _logo_b64 appears to be missing.
_logo_img = f'' if _logo_b64 else "+"

_agahi_path = os.path.join(_ASSETS_DIR, "agahi_logo.png")
_agahi_b64 = _read_file_b64(_agahi_path)
# NOTE(review): same as above — literal is empty in this copy; confirm.
_agahi_img = f'' if _agahi_b64 else "Aagahi Labs"

# =========================================
# GLOBAL CSS
#
========================================= st.markdown(""" """, unsafe_allow_html=True) # Ctrl+Enter in the symptom textarea → click Analyze st.markdown(""" """, unsafe_allow_html=True) # ========================================= # SESSION STATE # ========================================= if "result" not in st.session_state: st.session_state.result = None if "processing" not in st.session_state: st.session_state.processing = False if "input_mode" not in st.session_state: st.session_state.input_mode = "text" # HITL review state if "soap_source" not in st.session_state: st.session_state.soap_source = None # fingerprint to detect new reports if "soap_edited" not in st.session_state: st.session_state.soap_edited = None # doctor-edited SOAP text if "soap_confirmed" not in st.session_state: st.session_state.soap_confirmed = False if "soap_reviewer_name" not in st.session_state: st.session_state.soap_reviewer_name = "" if "soap_reviewer_role" not in st.session_state: st.session_state.soap_reviewer_role = "" if "soap_confirmed_at" not in st.session_state: st.session_state.soap_confirmed_at = None # Accessibility FONT_SCALES = [0.85, 1.0, 1.15, 1.3, 1.5] FONT_LABELS = ["Small", "Normal", "Large", "X-Large", "XX-Large"] if "font_idx" not in st.session_state: st.session_state.font_idx = 2 # default: Large # ========================================= # COMPONENT HELPERS # ========================================= def render_risk_bar(score: int, risk_level: str): if risk_level == "RED": color, label = "#c0392b", "Emergency" elif risk_level == "YELLOW": color, label = "#c97f0a", "Moderate" else: color, label = "#0f7a5a", "Low Risk" pct = score st.markdown(f"""
Risk Level {label}
{score}/100
""", unsafe_allow_html=True) def render_alert(status: str, result: dict): score = result.get("risk_score", 0) level = result.get("risk_level", "GREEN") analysis = result.get("analysis", {}) complaint = analysis.get("chief_complaint", "N/A") if analysis else "N/A" conditions = ", ".join(analysis.get("possible_conditions", [])) if analysis else "N/A" if status == "crisis": st.markdown("""
You Are Not Alone
It sounds like you or someone around you may be in emotional distress. TabeebAI cannot provide mental health support, but trained counsellors are available right now.
یہ ایپ ذہنی صحت کی مدد کے لیے نہیں ہے۔ براہ کرم اوپر دیے گئے نمبروں پر کال کریں۔
""", unsafe_allow_html=True) return if status == "non_medical": st.markdown("""
Non-Medical Input Detected
TabeebAI is designed to assist with patient symptom triage only. Please describe the patient's medical symptoms or health concerns.
براہ کرم مریض کی علامات یا صحت کے مسائل بیان کریں۔
""", unsafe_allow_html=True) return if level == "RED": st.markdown(f"""
EMERGENCY ALERT — Immediate Action Required
Chief Complaint: {complaint}
Risk Score: {score}/100    Possible Conditions: {conditions}
""", unsafe_allow_html=True) elif level == "YELLOW": st.markdown(f"""
CAUTION — Medical Attention Recommended
Chief Complaint: {complaint}
Risk Score: {score}/100    Possible Conditions: {conditions}
""", unsafe_allow_html=True) else: st.markdown(f"""
LOW RISK — Home Care Advised
Chief Complaint: {complaint}
Risk Score: {score}/100    Possible Conditions: {conditions}
""", unsafe_allow_html=True) def render_symptoms_chips(symptoms: list): _chip_styles = { "severe": {"color": "#c0392b", "bg": "rgba(192,57,43,0.08)", "border": "rgba(192,57,43,0.30)"}, "moderate": {"color": "#c97f0a", "bg": "rgba(201,127,10,0.10)", "border": "rgba(201,127,10,0.30)"}, } _chip_default = {"color": "#0d7494", "bg": "rgba(13,116,148,0.10)", "border": "rgba(13,116,148,0.25)"} chips_html = "" for s in symptoms: name = s.get("name", "") sev = s.get("severity", "mild").lower() dur = s.get("duration", "") cs = _chip_styles.get(sev, _chip_default) chips_html += f"""{name}|{sev}""" st.markdown(f'
{chips_html}
', unsafe_allow_html=True) def render_soap_report(soap: str): sections = {"SUBJECTIVE:": [], "OBJECTIVE:": [], "ASSESSMENT:": [], "PLAN:": []} current = None for line in soap.splitlines(): stripped = line.strip() if stripped.upper().startswith("SUBJECTIVE"): current = "SUBJECTIVE:" elif stripped.upper().startswith("OBJECTIVE"): current = "OBJECTIVE:" elif stripped.upper().startswith("ASSESSMENT"): current = "ASSESSMENT:" elif stripped.upper().startswith("PLAN"): current = "PLAN:" elif current: sections[current].append(line) icons = {"SUBJECTIVE:": "S", "OBJECTIVE:": "O", "ASSESSMENT:": "A", "PLAN:": "P"} colors = {"SUBJECTIVE:": "#0d7494", "OBJECTIVE:": "#0f7a5a", "ASSESSMENT:": "#c97f0a", "PLAN:": "#7c3aed"} color_bg = {"SUBJECTIVE:": "rgba(13,116,148,0.10)", "OBJECTIVE:": "rgba(15,122,90,0.08)", "ASSESSMENT:": "rgba(201,127,10,0.10)", "PLAN:": "rgba(124,58,237,0.08)"} color_border = {"SUBJECTIVE:": "rgba(13,116,148,0.25)", "OBJECTIVE:": "rgba(15,122,90,0.20)", "ASSESSMENT:": "rgba(201,127,10,0.25)", "PLAN:": "rgba(124,58,237,0.20)"} for key, lines in sections.items(): content = "\n".join(l for l in lines if l.strip()) if not content: continue c = colors[key] bg = color_bg[key] bd = color_border[key] st.markdown(f"""
{icons[key]}
{key.rstrip(':')}
{content}
""", unsafe_allow_html=True) # ========================================= # SIDEBAR # ========================================= with st.sidebar: st.markdown(f"""
{_agahi_img}
Aagahi Labs
Presents
""", unsafe_allow_html=True) health = get_system_health() if health: st.markdown(f"""
System Online
{health.get('disease_count', 0)} diseases · {health.get('rag_doc_count', 0)} RAG docs
""", unsafe_allow_html=True) st.markdown('', unsafe_allow_html=True) dashboard = st.radio( "Select View", ["Patient Dashboard", "Doctor Dashboard", "Developer Dashboard"], label_visibility="collapsed" ) st.markdown('', unsafe_allow_html=True) st.markdown('', unsafe_allow_html=True) input_mode = st.radio( "Input", ["Text Input", "Audio Input"], label_visibility="collapsed", key="input_mode_radio" ) st.markdown('', unsafe_allow_html=True) with st.expander("Accessibility", expanded=False): _fi = st.session_state.font_idx _fa_col, _fl_col, _fp_col = st.columns([1, 2, 1]) with _fa_col: if st.button("A−", key="font_dec", disabled=(_fi == 0)): st.session_state.font_idx -= 1 st.rerun() with _fl_col: st.markdown( f'
{FONT_LABELS[_fi]}
', unsafe_allow_html=True, ) with _fp_col: if st.button("A+", key="font_inc", disabled=(_fi == len(FONT_SCALES) - 1)): st.session_state.font_idx += 1 st.rerun() st.markdown('', unsafe_allow_html=True) st.markdown("""
Disclaimer
TabeebAI is a clinical decision support tool. Not a substitute for professional medical judgement. All outputs must be reviewed by a licensed clinician.
""", unsafe_allow_html=True) if st.session_state.result: st.markdown('', unsafe_allow_html=True) if st.button("Clear Results", use_container_width=True): st.session_state.result = None st.rerun() # ========================================= # MAIN CONTENT AREA # ========================================= # Dynamic font scale — injected fresh every render so sidebar controls take effect immediately _font_px = round(16 * FONT_SCALES[st.session_state.font_idx], 2) st.markdown( f"", unsafe_allow_html=True, ) # Header st.markdown(f"""
TabeebAI
AI-Powered Clinical Triage — Urdu / English
System Active
""", unsafe_allow_html=True) # Disclaimer st.markdown("""
Important: TabeebAI is a clinical decision support tool intended to assist qualified healthcare professionals. It is not a diagnostic tool and must not replace professional medical judgement. All AI-generated reports must be reviewed by a licensed medical professional before any action is taken.
یہ ایپ صرف ڈاکٹروں کی مدد کے لیے ہے۔ کوئی بھی طبی فیصلہ کرنے سے پہلے ڈاکٹر سے مشورہ کریں۔
""", unsafe_allow_html=True) # ========================================= # INPUT SECTION # ========================================= def render_input_section(): st.markdown('
Patient Input
', unsafe_allow_html=True) if input_mode == "Text Input": col_in, col_btn = st.columns([5, 1]) with col_in: user_text = st.text_area( "Enter symptoms in Urdu or English", placeholder=( "Example: I have severe chest pain and difficulty breathing since morning...\n" "مثال: مجھے سینے میں درد ہے اور سانس لینے میں دشواری ہو رہی ہے..." ), height=120, label_visibility="collapsed", key="text_input" ) with col_btn: st.markdown("""
Ctrl+Enter
""", unsafe_allow_html=True) analyze_clicked = st.button("Analyze", use_container_width=True, type="primary") if analyze_clicked: if not user_text or not user_text.strip(): st.warning("Please enter some text before analyzing.") return with st.spinner("Running clinical pipeline..."): result = call_text_pipeline(user_text) result["original_input"] = user_text st.session_state.result = result st.rerun() # ========================= # AUDIO INPUT MODE # ========================= else: st.markdown("""
Microphone not working? Browsers block mic access inside embedded frames. Open the app in a new tab to enable live recording, or use Upload File below.
""", unsafe_allow_html=True) col_rec, col_up = st.columns(2) audio_bytes = None source_label = None # ── COLUMN 1: Live microphone recording ── with col_rec: st.markdown( '
Live Recording
', unsafe_allow_html=True ) audio_value = st.audio_input( "Record patient voice", label_visibility="collapsed", key="live_recorder" ) if audio_value is not None: audio_bytes = audio_value.read() source_label = "live recording" # ── COLUMN 2: File upload fallback ── with col_up: st.markdown( '
Upload File
', unsafe_allow_html=True ) uploaded = st.file_uploader( "Upload audio file", type=["wav", "mp3", "m4a", "ogg", "flac"], label_visibility="collapsed", key="audio_upload" ) if uploaded is not None and audio_bytes is None: audio_bytes = uploaded.read() source_label = f"uploaded — {uploaded.name}" # ── Playback + Analyze button ── if audio_bytes: st.markdown("
", unsafe_allow_html=True) st.audio(audio_bytes, format="audio/wav") st.markdown( f'
' f'Source: {source_label}
', unsafe_allow_html=True ) else: st.markdown("""
Record using the microphone above, or upload an audio file
""", unsafe_allow_html=True) # ── Submit button always visible ── st.markdown("
", unsafe_allow_html=True) if st.button( "Transcribe & Analyze", use_container_width=True, type="primary", key="audio_analyze_btn", disabled=(audio_bytes is None) ): if audio_bytes is None: st.warning("Please record or upload audio first.") else: with st.spinner("Transcribing audio and running clinical pipeline..."): result = call_audio_pipeline(audio_bytes) st.session_state.result = result st.rerun() render_input_section() # ========================================= # NO RESULT STATE # ========================================= if st.session_state.result is None: st.markdown("""
+
No Analysis Yet
Enter patient symptoms above to begin clinical triage
""", unsafe_allow_html=True) st.stop() # ========================================= # RESULTS # ========================================= result = st.session_state.result status = result.get("status", "error") analysis = result.get("analysis", {}) or {} risk_score = result.get("risk_score", 0) risk_level = result.get("risk_level", "GREEN") soap = result.get("soap_report", "") retrieved = result.get("retrieved_chunks", []) timings = result.get("timings", {}) lang_label = result.get("lang_label", "Unknown") english_text = result.get("english_text", "") original_input = result.get("original_input", result.get("original_transcript", "")) dataset_matches = result.get("dataset_matches", []) st.markdown("
", unsafe_allow_html=True) # ── Alert banner always shown ────────────────────────────── render_alert(status, result) if status in ("crisis", "non_medical", "error"): if status == "error": st.error(f"Pipeline Error: {result.get('message', 'Unknown error')}") st.stop() # ========================================= # DASHBOARD TABS # ========================================= if dashboard == "Patient Dashboard": st.markdown('
Patient Overview
', unsafe_allow_html=True) # Key metrics row m1, m2, m3, m4 = st.columns(4) with m1: st.metric("Risk Score", f"{risk_score}/100") with m2: lvl_map = {"RED": "Emergency", "YELLOW": "Moderate", "GREEN": "Low Risk"} st.metric("Urgency", lvl_map.get(risk_level, risk_level)) with m3: st.metric("Language", lang_label) with m4: n_sym = len(analysis.get("symptoms", [])) st.metric("Symptoms Found", str(n_sym)) st.markdown("
", unsafe_allow_html=True) # Urgency meter render_risk_bar(risk_score, risk_level) st.markdown("
", unsafe_allow_html=True) # Two column layout col_left, col_right = st.columns([1, 1], gap="large") with col_left: st.markdown('
Reported Symptoms
', unsafe_allow_html=True) symptoms = analysis.get("symptoms", []) if symptoms: render_symptoms_chips(symptoms) else: st.markdown('
No specific symptoms extracted.
', unsafe_allow_html=True) st.markdown("
", unsafe_allow_html=True) st.markdown('
Chief Complaint
', unsafe_allow_html=True) complaint = analysis.get("chief_complaint", "Not specified") st.markdown(f"""
{complaint}
""", unsafe_allow_html=True) with col_right: st.markdown('
What You Should Do Next
', unsafe_allow_html=True) if risk_level == "RED": steps = [ ("Call 1122 / 115 immediately", "red"), ("Do not leave the patient alone", "red"), ("Keep patient calm and still", "yellow"), ("Prepare for hospital transfer", "yellow"), ] elif risk_level == "YELLOW": steps = [ ("See a doctor within 24 hours", "yellow"), ("Monitor symptoms for worsening", "yellow"), ("Seek urgent care if new symptoms develop", "cyan"), ("Keep track of symptom duration", "cyan"), ] else: steps = [ ("Rest and home care is appropriate", "green"), ("Stay hydrated and monitor temperature", "green"), ("Visit a clinic if symptoms persist > 3 days", "cyan"), ("Avoid strenuous activity until recovered", "cyan"), ] for step, color in steps: dot_class = f"dot-{color}" st.markdown(f"""
{step}
""", unsafe_allow_html=True) if original_input and lang_label == "Urdu": st.markdown("
", unsafe_allow_html=True) st.markdown('
Translation
', unsafe_allow_html=True) col_ur, col_en = st.columns(2) with col_ur: st.markdown(f"""
Original (Urdu)
{original_input}
""", unsafe_allow_html=True) with col_en: st.markdown(f"""
English
{english_text}
""", unsafe_allow_html=True) elif dashboard == "Doctor Dashboard": st.markdown('
Clinical Overview
', unsafe_allow_html=True) # Clinical metrics cm1, cm2, cm3, cm4 = st.columns(4) with cm1: st.metric("Risk Score", f"{risk_score}/100") with cm2: urgency = analysis.get("urgency", "N/A").upper() st.metric("Triage Urgency", urgency) with cm3: st.metric("Input Language", lang_label) with cm4: st.metric("Symptoms Extracted", len(analysis.get("symptoms", []))) st.markdown("
", unsafe_allow_html=True) render_risk_bar(risk_score, risk_level) st.markdown("
", unsafe_allow_html=True) # Main clinical area tab_soap, tab_symptoms, tab_diseases, tab_rag = st.tabs([ "SOAP Report", "Symptom Analysis", "Disease Matching", "Retrieved Knowledge" ]) with tab_soap: if soap: # Reset HITL state whenever a brand-new report arrives if st.session_state.soap_source != soap: st.session_state.soap_source = soap st.session_state.soap_edited = soap st.session_state.soap_confirmed = False st.session_state.soap_reviewer_name = "" st.session_state.soap_reviewer_role = "" st.session_state.soap_confirmed_at = None # ── Confirmation banner ──────────────────────────────────── if st.session_state.soap_confirmed: st.markdown(f"""
Report Confirmed
Reviewed by {st.session_state.soap_reviewer_name}  ·  {st.session_state.soap_reviewer_role}  ·  {st.session_state.soap_confirmed_at}
""", unsafe_allow_html=True) # ── Attractive rendered SOAP view (always shown) ────────── section_label = ( "SOAP Report — confirmed" if st.session_state.soap_confirmed else "SOAP Report" ) st.markdown(f'
{section_label}
', unsafe_allow_html=True) render_soap_report(st.session_state.soap_edited or soap) # ── Editable textarea in expander (only when not confirmed) ── if not st.session_state.soap_confirmed: with st.expander("✏️ Edit Report", expanded=False): edited_soap = st.text_area( "SOAP Report", value=st.session_state.soap_edited, height=300, label_visibility="collapsed", key="soap_edit_area", ) st.session_state.soap_edited = edited_soap # ── Doctor Review row ────────────────────────────────────── st.markdown("
", unsafe_allow_html=True) st.markdown("""
Doctor Review
""", unsafe_allow_html=True) rev_c1, rev_c2, rev_c3 = st.columns([2, 2, 1]) if st.session_state.soap_confirmed: # Show confirmed values as readable text, not disabled inputs with rev_c1: st.markdown(f"""
Doctor Name
{st.session_state.soap_reviewer_name}
""", unsafe_allow_html=True) with rev_c2: st.markdown(f"""
Role / Specialisation
{st.session_state.soap_reviewer_role}
""", unsafe_allow_html=True) reviewer_name = st.session_state.soap_reviewer_name reviewer_role = st.session_state.soap_reviewer_role else: with rev_c1: reviewer_name = st.text_input( "Doctor Name", placeholder="Dr. Ahmed Khan", key="reviewer_name_input", ) with rev_c2: reviewer_role = st.text_input( "Role / Specialisation", placeholder="e.g. General Practitioner", key="reviewer_role_input", ) with rev_c3: st.markdown("
", unsafe_allow_html=True) if not st.session_state.soap_confirmed: if st.button("✔ Confirm Report", type="primary", use_container_width=True, key="confirm_btn"): if not reviewer_name.strip(): st.warning("Please enter the reviewing doctor's name.") else: st.session_state.soap_confirmed = True st.session_state.soap_reviewer_name = reviewer_name.strip() st.session_state.soap_reviewer_role = ( reviewer_role.strip() or "Physician" ) st.session_state.soap_confirmed_at = time.strftime( "%Y-%m-%d %H:%M:%S" ) st.rerun() else: if st.button("↩ Reopen for Edit", use_container_width=True, key="reopen_btn"): st.session_state.soap_confirmed = False st.rerun() # ── Download buttons (use doctor-edited text) ────────────── final_soap = st.session_state.soap_edited or soap reviewer_line = ( f"\n\nReviewed by: {st.session_state.soap_reviewer_name}" f" ({st.session_state.soap_reviewer_role})" f" — {st.session_state.soap_confirmed_at}" if st.session_state.soap_confirmed else "\n\n[Pending doctor review]" ) st.markdown("
", unsafe_allow_html=True) col_dl1, col_dl2, col_space = st.columns([1, 1, 3]) with col_dl1: st.download_button( "Download Report (.txt)", data=final_soap + reviewer_line, file_name="tabeebai_soap_report.txt", mime="text/plain", use_container_width=True, ) with col_dl2: md_report = f"""# TabeebAI Clinical Report **Date:** {time.strftime('%Y-%m-%d %H:%M')} **Risk Score:** {risk_score}/100 **Risk Level:** {risk_level} **Language:** {lang_label} --- ## Chief Complaint {analysis.get("chief_complaint", "N/A")} --- {final_soap} --- {reviewer_line} *AI-generated report — must be reviewed by a licensed medical professional.* """ st.download_button( "Download Report (.md)", data=md_report, file_name="tabeebai_report.md", mime="text/markdown", use_container_width=True, ) else: st.info("No SOAP report generated.") with tab_symptoms: symptoms = analysis.get("symptoms", []) if symptoms: st.markdown('
Extracted Symptoms
', unsafe_allow_html=True) render_symptoms_chips(symptoms) st.markdown("
", unsafe_allow_html=True) for s in symptoms: sev = s.get("severity", "mild").lower() _sev_map = { "severe": {"color": "#c0392b", "bg": "rgba(192,57,43,0.08)", "border": "rgba(192,57,43,0.30)"}, "moderate": {"color": "#c97f0a", "bg": "rgba(201,127,10,0.10)", "border": "rgba(201,127,10,0.30)"}, "mild": {"color": "#0f7a5a", "bg": "rgba(15,122,90,0.08)", "border": "rgba(15,122,90,0.25)"}, } sc = _sev_map.get(sev, {"color": "#0d7494", "bg": "rgba(13,116,148,0.10)", "border": "rgba(13,116,148,0.25)"}) st.markdown(f"""
{s.get("name", "—")}
Duration: {s.get("duration", "N/A")} {sev}
""", unsafe_allow_html=True) else: st.info("No symptoms extracted.") st.markdown("
", unsafe_allow_html=True) st.markdown('
Transcript
', unsafe_allow_html=True) t1, t2 = st.columns(2) with t1: st.markdown(f"""
Original Input
{original_input or "—"}
""", unsafe_allow_html=True) with t2: st.markdown(f"""
English Translation
{english_text or "—"}
""", unsafe_allow_html=True) with tab_diseases: st.markdown('
Dataset-Matched Diseases
', unsafe_allow_html=True) if dataset_matches: max_matched = max(d.get("matched", 0) for d in dataset_matches) or 1 for d in dataset_matches: matched = d.get("matched", 0) total = d.get("total", 1) pct = int((matched / max_matched) * 100) bar_color = "#c0392b" if pct > 70 else "#c97f0a" if pct > 40 else "#0d7494" st.markdown(f"""
{d.get("disease","—")} {matched}/{total} symptoms
""", unsafe_allow_html=True) else: st.info("No disease matches found. Ensure diseases_symptoms.csv is in the data/ folder.") detail = analysis.get("dataset_match_detail", []) if detail: with st.expander("Match Detail"): for item in detail: st.markdown(f'
{item}
', unsafe_allow_html=True) with tab_rag: st.markdown('
Retrieved Medical Knowledge
', unsafe_allow_html=True) if retrieved: for chunk in retrieved: score_pct = int(chunk.get("score", 0) * 100) score_color = "#0f7a5a" if score_pct > 70 else "#c97f0a" if score_pct > 50 else "#6b7a91" with st.expander(f"{chunk.get('title', 'Document')} — Relevance: {score_pct}%"): st.markdown(f"""
{chunk.get("text", "")}
Cosine Similarity {chunk.get('score',0):.4f}
""", unsafe_allow_html=True) else: st.info("No RAG chunks retrieved. Ensure medical_knowledge.json is in the data/ folder.") elif dashboard == "Developer Dashboard": st.markdown('
Pipeline Observability
', unsafe_allow_html=True) dev_tab1, dev_tab2, dev_tab3, dev_tab4, dev_tab5 = st.tabs([ "Pipeline Status", "Timings", "Raw JSON", "RAG Debug", "Models" ]) with dev_tab1: st.markdown('
Execution Pipeline
', unsafe_allow_html=True) pipeline_steps = [ ("Language Detection", lang_label, "green"), ("Query Classification", "medical", "green"), ("Translation", "Urdu → English" if lang_label == "Urdu" else "Skipped (English)", "cyan" if lang_label == "Urdu" else "yellow"), ("Symptom Extraction", f"{len(analysis.get('symptoms',[]))} symptoms extracted", "green"), ("Risk Scoring", f"{risk_score}/100 — {risk_level}", "green" if risk_level == "GREEN" else ("yellow" if risk_level == "YELLOW" else "red")), ("Disease Lookup", f"{len(dataset_matches)} diseases matched", "green" if dataset_matches else "yellow"), ("RAG Retrieval", f"{len(retrieved)} chunks retrieved", "green" if retrieved else "yellow"), ("SOAP Generation", "Complete", "green"), ] for step_name, step_detail, step_color in pipeline_steps: st.markdown(f"""
{step_name} — {step_detail}
""", unsafe_allow_html=True) st.markdown("
", unsafe_allow_html=True) st.markdown('
Input Summary
', unsafe_allow_html=True) col_a, col_b = st.columns(2) with col_a: st.markdown(f"""
Original Input
{original_input or "—"}
""", unsafe_allow_html=True) with col_b: st.markdown(f"""
English (post-translation)
{english_text or "—"}
""", unsafe_allow_html=True) with dev_tab2: st.markdown('
Pipeline Timing Breakdown
', unsafe_allow_html=True) timing_labels = { "transcription_ms": "Transcription", "translation_ms": "Translation", "classification_ms": "Classification", "extraction_ms": "Extraction", "risk_scoring_ms": "Risk Scoring", "disease_lookup_ms": "Disease Lookup", "rag_retrieval_ms": "RAG Retrieval", "soap_generation_ms": "SOAP Report", "total_ms": "TOTAL" } if timings: cols = st.columns(4) i = 0 for key, label in timing_labels.items(): if key in timings: val = timings[key] with cols[i % 4]: border_color = "var(--cyan)" if key == "total_ms" else "var(--border-subtle)" st.markdown(f"""
{val}ms {label}
""", unsafe_allow_html=True) i += 1 total = timings.get("total_ms", 0) st.markdown("
", unsafe_allow_html=True) if total > 0: st.markdown('
Time Distribution
', unsafe_allow_html=True) render_keys = [k for k in timing_labels if k != "total_ms" and k in timings] for key in render_keys: label = timing_labels[key] val = timings[key] pct = min(int((val / total) * 100), 100) st.markdown(f"""
{label} {val}ms ({pct}%)
""", unsafe_allow_html=True) else: st.info("No timing data available.") with dev_tab3: st.markdown('
Raw Analysis JSON
', unsafe_allow_html=True) st.code(json.dumps(analysis, indent=2, ensure_ascii=False), language="json") st.markdown('
Full Pipeline Response
', unsafe_allow_html=True) display_result = {k: v for k, v in result.items() if k != "soap_report"} st.code(json.dumps(display_result, indent=2, ensure_ascii=False), language="json") st.download_button( "Download Full JSON", data=json.dumps(result, indent=2, ensure_ascii=False), file_name="tabeebai_debug.json", mime="application/json" ) with dev_tab4: st.markdown('
RAG Retrieval Debug
', unsafe_allow_html=True) if retrieved: for i, chunk in enumerate(retrieved): st.markdown(f"""
Chunk {i+1}: {chunk.get("title", "Untitled")} score: {chunk.get("score", 0):.4f}
{chunk.get("text", "")}
""", unsafe_allow_html=True) else: st.info("No RAG chunks retrieved.") st.markdown("
", unsafe_allow_html=True) st.markdown('
Disease Lookup Debug
', unsafe_allow_html=True) if dataset_matches: st.code(json.dumps(dataset_matches, indent=2), language="json") else: st.info("No disease matches.") detail = analysis.get("dataset_match_detail", []) if detail: st.markdown('
Match Detail Strings
', unsafe_allow_html=True) for item in detail: st.markdown(f'
{item}
', unsafe_allow_html=True) with dev_tab5: st.markdown('
Models Used
', unsafe_allow_html=True) models = result.get("models_used", {}) if models: for role, model_name in models.items(): st.markdown(f"""
{role.replace("_"," ")} {model_name}
""", unsafe_allow_html=True) else: st.info("Model info not available.") st.markdown("
", unsafe_allow_html=True) st.markdown('
System Health
', unsafe_allow_html=True) st.code(json.dumps(get_system_health(), indent=2), language="json")