# -*- coding: utf-8 -*-
"""app
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1_pOuPpNNnrEQCdGrxl0JnaQJXAIb4F-j
"""
# =========================================
# TabeebAI — Unified Streamlit App
# AI Medical Triage (Hugging Face Spaces)
# =========================================
import base64
import html
import json
import os
import re
import tempfile
import time
import traceback
from io import BytesIO

import numpy as np
import pandas as pd
import streamlit as st
from groq import Groq
from sentence_transformers import SentenceTransformer
# =========================================
# GROQ CLIENT
# =========================================
# Read the Groq API key from the environment. `client` is None when the key is
# missing or empty, which the pipeline entry points check before doing any work.
api_key = os.environ.get("GROQ_API_KEY")
client = Groq(api_key=api_key) if api_key else None
# =========================================
# DISEASE DATASET
# =========================================
# Location of the disease/symptom CSV, resolved relative to this file.
_DATA_PATH = os.path.join(os.path.dirname(__file__), "data", "diseases_symptoms.csv")
# Defaults that stay in place when the CSV cannot be read.
disease_df = None
_SYMPTOM_COLS = []
try:
    disease_df = pd.read_csv(_DATA_PATH)
    # Normalise headers: lower-case, trim, then turn spaces into underscores.
    disease_df.columns = (
        disease_df.columns.str.lower().str.strip().str.replace(" ", "_")
    )
    # Every column other than "disease" is treated as a symptom column.
    _SYMPTOM_COLS = [c for c in disease_df.columns if c != "disease"]
    print(f"[TabeebAI] Disease dataset: {len(disease_df)} diseases, {len(_SYMPTOM_COLS)} symptoms")
except Exception as e:
    # NOTE(review): this handler catches every failure in the block above, not
    # only a missing file, although the message says "not found".
    print(f"[TabeebAI] Disease dataset not found: {e}")
def _match_symptom_to_column(symptom_name: str):
    """Return the dataset symptom column sharing the most words with *symptom_name*.

    The name is lower-cased and split on whitespace, underscores and hyphens;
    intensity/duration qualifiers are ignored. Returns None when no column
    shares at least one word. Ties go to the earliest column.
    """
    qualifiers = {"severe", "mild", "moderate", "acute", "chronic",
                  "sudden", "persistent", "intense", "sharp", "dull"}
    normalised = symptom_name.lower().replace("_", " ").replace("-", " ")
    tokens = set(normalised.split()) - qualifiers

    def shared_words(column):
        return len(tokens & set(column.replace("_", " ").split()))

    # max() keeps the first maximal element, matching a strict ">" scan.
    winner = max(_SYMPTOM_COLS, key=shared_words, default=None)
    if winner is None or shared_words(winner) == 0:
        return None
    return winner
def lookup_diseases(symptoms_list: list, top_n: int = 5) -> list:
    """Rank diseases in the dataset by how many of the given symptoms they list.

    Args:
        symptoms_list: Symptom entries, normally dicts with a "name" key. Plain
            strings are accepted as names; entries without a usable name are
            skipped instead of raising.
        top_n: Maximum number of diseases to return.

    Returns:
        A list of ``{"disease", "matched", "total"}`` records sorted by
        ``matched`` descending, where ``total`` is the number of distinct
        dataset columns the symptoms mapped to. Empty when the dataset is not
        loaded, no symptom maps to a column, or no disease matches.
    """
    if disease_df is None or not symptoms_list:
        return []
    matched_cols = []
    for entry in symptoms_list:
        # Entries may be malformed (string instead of dict, null name), so
        # validate before calling string methods on the name.
        name = entry.get("name") if isinstance(entry, dict) else entry
        if not isinstance(name, str) or not name.strip():
            continue
        col = _match_symptom_to_column(name)
        if col and col not in matched_cols:
            matched_cols.append(col)
    if not matched_cols:
        return []
    # Row-wise sum over the matched symptom columns.
    # NOTE(review): assumes those columns hold numeric 0/1 flags — confirm.
    scores = disease_df[matched_cols].sum(axis=1)
    results = (
        disease_df[["disease"]]
        .assign(matched=scores, total=len(matched_cols))
        .query("matched > 0")
        .sort_values("matched", ascending=False)
        .head(top_n)
    )
    return results.to_dict(orient="records")
# =========================================
# RAG SETUP
# =========================================
# Module-level RAG state, filled in by setup_rag(). The defaults mean
# "retrieval unavailable" and are what retrieve_knowledge() checks for.
_rag_model = None
_rag_embeddings = None
_rag_docs = []


def _load_rag_resources(knowledge_path: str):
    """Load the knowledge base and embed it; returns (docs, model, embeddings).

    Raises whatever the file read, JSON parse, model load or encode step raises.
    """
    with open(knowledge_path, "r", encoding="utf-8") as f:
        docs = json.load(f)
    print("[TabeebAI] Loading embedding model...")
    model = SentenceTransformer("all-MiniLM-L6-v2")
    texts = [f"{d['title']}. {d['text']}" for d in docs]
    embeddings = model.encode(
        texts, normalize_embeddings=True, show_progress_bar=False
    )
    return docs, model, embeddings


def setup_rag():
    """Populate the module-level RAG state from ``data/medical_knowledge.json``.

    Streamlit re-executes this script on every interaction, so the expensive
    load/embed step goes through ``st.cache_resource`` and runs once per server
    process rather than once per rerun. On failure the error is printed and all
    three globals keep their previous values (no partially initialised state).
    """
    global _rag_model, _rag_embeddings, _rag_docs
    knowledge_path = os.path.join(os.path.dirname(__file__), "data", "medical_knowledge.json")
    try:
        # Same effect as decorating the loader. show_spinner=False keeps the
        # cached call from emitting a UI element before st.set_page_config().
        cached_loader = st.cache_resource(show_spinner=False)(_load_rag_resources)
        docs, model, embeddings = cached_loader(knowledge_path)
    except Exception as e:
        print(f"[TabeebAI] RAG setup failed: {e}")
        return
    _rag_docs, _rag_model, _rag_embeddings = docs, model, embeddings
    print(f"[TabeebAI] RAG ready: {len(_rag_docs)} documents embedded")
def retrieve_knowledge(query: str, n: int = 3, min_score: float = 0.2) -> list:
    """Return up to *n* knowledge-base documents most similar to *query*.

    Args:
        query: Free-text query to embed.
        n: Maximum number of documents to return; values <= 0 yield ``[]``.
        min_score: Documents scoring at or below this similarity are dropped.

    Returns:
        A list of ``{"title", "text", "score"}`` dicts, best match first. Empty
        when RAG is not initialised or nothing clears ``min_score``.
    """
    # Guard first: scores.argsort()[-0:] would select every document.
    if n <= 0:
        return []
    if _rag_model is None or _rag_embeddings is None or not _rag_docs:
        return []
    query_emb = _rag_model.encode([query], normalize_embeddings=True, show_progress_bar=False)
    # Both sides are encoded with normalize_embeddings=True, so the dot product
    # is the cosine similarity.
    scores = np.dot(_rag_embeddings, query_emb.T).flatten()
    top_idx = scores.argsort()[-n:][::-1]
    return [
        {"title": _rag_docs[i]["title"], "text": _rag_docs[i]["text"], "score": float(scores[i])}
        for i in top_idx if scores[i] > min_score
    ]
# Module-level call: runs whenever this script is executed or imported.
setup_rag()
# =========================================
# LANGUAGE DETECTION
# =========================================
def detect_language(text: str) -> str:
    """Classify *text* as Urdu, English, other, or unknown by script.

    Returns:
        "unknown" for empty/blank text or text with no alphabetic characters;
        "ur" when characters in U+0600–U+06FF exceed 30% of the alphabetic
        count; "en" when ASCII letters exceed 50% of it; otherwise "other".
        The Urdu check takes precedence over the English one.
    """
    if not text or not text.strip():
        return "unknown"
    urdu_chars = latin_chars = total_alpha = 0
    # Single pass over the text instead of one pass per counter.
    for ch in text:
        # Counted regardless of isalpha(), so digits and punctuation in this
        # block contribute to the Urdu numerator as well.
        if "\u0600" <= ch <= "\u06ff":
            urdu_chars += 1
        if ch.isalpha():
            total_alpha += 1
            if ch.isascii():
                latin_chars += 1
    if total_alpha == 0:
        return "unknown"
    if urdu_chars / total_alpha > 0.3:
        return "ur"
    if latin_chars / total_alpha > 0.5:
        return "en"
    return "other"
# =========================================
# TRANSCRIPTION
# =========================================
def _call_whisper(audio_path: str, language=None):
    """Transcribe the audio file at *audio_path* with Groq Whisper.

    A truthy *language* is forwarded to the API as its ``language`` argument;
    otherwise the argument is omitted. Returns ``(transcript, detected_lang)``
    where the language comes from ``detect_language`` on the transcript.
    """
    request = {"model": "whisper-large-v3-turbo"}
    if language:
        request["language"] = language
    with open(audio_path, "rb") as audio_file:
        response = client.audio.transcriptions.create(file=audio_file, **request)
    transcript = response.text.strip()
    return transcript, detect_language(transcript)
def transcribe_audio_file(audio_path: str):
    """Transcribe an audio file, retrying once with an Urdu hint if needed.

    Returns:
        ``(text, lang)`` with ``lang`` in {"ur", "en"} on success, or
        ``(message, "unknown")`` where *message* is a user-facing explanation:
        either the language is unsupported or no speech could be recognised.
    """
    text, lang = _call_whisper(audio_path)
    if lang == "other":
        # Retry with an explicit Urdu hint when the first transcript is
        # neither Urdu nor English.
        text, lang = _call_whisper(audio_path, language="ur")
    if lang == "other":
        return "Unsupported language detected. Please speak in Urdu or English only.", "unknown"
    if lang == "unknown":
        # An empty or letter-free transcript would otherwise be returned as the
        # message itself, leaving the caller with a blank error.
        return (
            "No speech could be recognised in the audio. "
            "Please try again, speaking clearly in Urdu or English.",
            "unknown",
        )
    return text, lang
# =========================================
# TRANSLATION
# =========================================
def translate_if_needed(text: str, lang: str) -> str:
    """Translate *text* to English unless it is empty or already English.

    Args:
        text: Patient statement.
        lang: Language code; "en" (or empty *text*) returns *text* unchanged.

    Returns:
        The English translation with surrounding whitespace removed.

    Raises:
        RuntimeError: The translation request failed or came back empty.
            Previously the error text was returned as if it were the
            translation and went on to be classified and triaged as the
            patient's statement; callers that wrap the pipeline in
            ``except Exception`` now surface it as an error instead.
    """
    if not text or lang == "en":
        return text
    try:
        response = client.chat.completions.create(
            model="llama-3.3-70b-versatile",
            messages=[{
                "role": "user",
                "content": (
                    "Translate the following Urdu medical text into clear English. "
                    "Return only the translation, no explanations:\n\n" + text
                )
            }],
            temperature=0.1
        )
        translated = response.choices[0].message.content
    except Exception as e:
        raise RuntimeError(f"Translation Error: {e}") from e
    if not translated or not translated.strip():
        raise RuntimeError("Translation Error: empty response from translation model")
    return translated.strip()
# =========================================
# CLASSIFICATION
# =========================================
def classify_query(text: str) -> str:
    """Classify a patient statement as "medical", "crisis" or "non_medical".

    Asks the model for a one-word verdict. "CRISIS" is checked before "NON";
    anything else, including a failed request, is treated as "medical".
    """
    prompt = f"""You are a medical query classifier for a clinical triage system.
Classify the following patient statement into exactly ONE category:
MEDICAL — describes any symptom, pain, illness, injury, medication, or health concern
CRISIS — mentions self-harm, suicide, wanting to die, or harming others
NON_MEDICAL — anything unrelated to health (greetings, general questions, jokes, etc.)
Statement: "{text}"
Reply with ONLY one word: MEDICAL, CRISIS, or NON_MEDICAL"""
    try:
        reply = client.chat.completions.create(
            model="llama-3.1-8b-instant",
            messages=[{"role": "user", "content": prompt}],
            max_tokens=5,
            temperature=0
        )
        verdict = reply.choices[0].message.content.strip().upper()
    except Exception:
        return "medical"
    # "NON" also covers "NON_MEDICAL" and truncated replies.
    for marker, category in (("CRISIS", "crisis"), ("NON", "non_medical")):
        if marker in verdict:
            return category
    return "medical"
# =========================================
# SYMPTOM EXTRACTION
# =========================================
def _parse_json_object(raw: str) -> dict:
    """Parse a JSON object out of a model reply.

    Strips Markdown code fences, then falls back to the outermost ``{...}``
    span when the reply carries extra prose around the JSON.

    Raises:
        json.JSONDecodeError: No parseable JSON was found.
        ValueError: The JSON parsed but is not an object.
    """
    cleaned = raw.replace("```json", "").replace("```", "").strip()
    try:
        parsed = json.loads(cleaned)
    except json.JSONDecodeError:
        start, end = cleaned.find("{"), cleaned.rfind("}")
        if start == -1 or end <= start:
            raise
        parsed = json.loads(cleaned[start:end + 1])
    if not isinstance(parsed, dict):
        raise ValueError(f"expected a JSON object, got {type(parsed).__name__}")
    return parsed


def extract_symptoms(text: str) -> dict:
    """Ask the model for a structured symptom summary of *text*.

    Returns:
        The parsed JSON object (keys as laid out in the prompt's FORMAT
        section), or ``{"error": "Parsing failed", "details": ...,
        "raw_output": ...}`` when the request or the parsing fails.
        ``raw_output`` is the unmodified model reply, or "" if none arrived.
    """
    prompt = f"""You are a clinical AI assistant.
STRICT RULES:
- Output MUST be ONLY in English
- Return ONLY valid JSON, no extra text, no markdown
Extract structured medical symptoms from this text:
{text}
FORMAT:
{{
"chief_complaint": "",
"symptoms": [
{{
"name": "",
"severity": "mild/moderate/severe",
"duration": ""
}}
],
"possible_conditions": [],
"urgency": "low/medium/high",
"language": "English"
}}"""
    output = ""
    try:
        response = client.chat.completions.create(
            model="llama-3.3-70b-versatile",
            messages=[{"role": "user", "content": prompt}],
            temperature=0
        )
        # content can be None; normalise so the error payload stays a string.
        output = response.choices[0].message.content or ""
        return _parse_json_object(output)
    except Exception as e:
        return {"error": "Parsing failed", "details": str(e), "raw_output": output}
# =========================================
# RISK SCORING
# =========================================
def calculate_risk_score(analysis: dict) -> int:
    """Compute a 0–100 triage risk score from a symptom analysis.

    The score is the sum of: an urgency base (low 10 / medium 40 / high 75),
    a per-symptom severity bonus (mild 3 / moderate 8 / severe 18), +30 if any
    emergency keyword occurs in the chief complaint or a symptom name, and
    3 per symptom capped at 15. The total is capped at 100.

    Args:
        analysis: Dict with optional "urgency", "symptoms" and
            "chief_complaint" keys. Null values and non-dict symptom entries
            are tolerated; unknown urgency/severity fall back to low/mild.

    Returns:
        0 when *analysis* is empty or carries an "error" key, otherwise the
        capped score.
    """
    if not analysis or "error" in analysis:
        return 0
    urgency_scores = {"low": 10, "medium": 40, "high": 75}
    severity_bonus = {"mild": 3, "moderate": 8, "severe": 18}
    emergency_keywords = [
        "chest pain", "difficulty breathing", "shortness of breath",
        "unconscious", "unresponsive", "severe bleeding", "heart attack",
        "stroke", "choking", "not breathing", "no pulse",
        "سینے میں درد", "سانس لینے میں دشواری", "بے ہوش"
    ]
    # A key can be present with a null value, so `or` supplies the default.
    symptoms = [s for s in (analysis.get("symptoms") or []) if isinstance(s, dict)]
    urgency = str(analysis.get("urgency") or "low").strip().lower()
    score = urgency_scores.get(urgency, 10)
    for symptom in symptoms:
        severity = str(symptom.get("severity") or "mild").strip().lower()
        score += severity_bonus.get(severity, 3)
    # Lower-case the complaint AND the symptom names: the keywords are lower
    # case, so a name such as "Chest Pain" must not slip past the check.
    full_text = " ".join(
        [str(analysis.get("chief_complaint") or "")]
        + [str(s.get("name") or "") for s in symptoms]
    ).lower()
    if any(keyword in full_text for keyword in emergency_keywords):
        score += 30
    score += min(len(symptoms) * 3, 15)
    return min(score, 100)
def get_risk_level(score: int) -> str:
    """Map a risk score to a band: "RED" (>= 71), "YELLOW" (>= 31) or "GREEN"."""
    for floor, band in ((71, "RED"), (31, "YELLOW")):
        if score >= floor:
            return band
    return "GREEN"
# =========================================
# SOAP REPORT
# =========================================
def generate_soap_report(analysis: dict, english_text: str, risk_score: int, retrieved_context: str = "") -> str:
    """Generate a SOAP-format report from the extracted analysis.

    Returns the model's report text; a fixed notice when *analysis* is empty or
    carries an "error" key; or "Report generation error: ..." when the model
    request fails. A non-empty *retrieved_context* is added to the prompt under
    a "RETRIEVED MEDICAL KNOWLEDGE" heading.
    """
    if not analysis or "error" in analysis:
        return "Cannot generate report — symptom extraction failed."

    symptom_rows = (
        f" - {s.get('name','?')} | severity: {s.get('severity','?')} | duration: {s.get('duration','?')}"
        for s in analysis.get("symptoms", [])
    )
    symptoms_text = "\n".join(symptom_rows)
    conditions = ", ".join(analysis.get("possible_conditions", []))
    if not conditions:
        conditions = "None identified"
    if retrieved_context:
        context_section = f"\nRETRIEVED MEDICAL KNOWLEDGE:\n{retrieved_context}\n"
    else:
        context_section = ""
    prompt = f"""You are a clinical documentation assistant.
Generate a concise SOAP format medical report based on the data below.
Return ONLY the report — no extra commentary, no markdown headings with #.
PATIENT DATA:
Chief Complaint : {analysis.get('chief_complaint', 'N/A')}
Symptoms :
{symptoms_text}
Possible Conditions: {conditions}
Urgency : {analysis.get('urgency', 'N/A')}
Risk Score : {risk_score}/100
Original Statement: {english_text}
{context_section}
FORMAT TO USE:
SUBJECTIVE:
[what the patient reports]
OBJECTIVE:
[observable findings from the speech/statement]
ASSESSMENT:
[clinical interpretation, possible diagnoses]
PLAN:
[recommended next steps for the treating doctor]
NOTE: This report is AI-generated and must be reviewed by a qualified health professional before any clinical decision is made."""
    try:
        reply = client.chat.completions.create(
            model="llama-3.3-70b-versatile",
            messages=[{"role": "user", "content": prompt}],
            temperature=0.2
        )
        return reply.choices[0].message.content.strip()
    except Exception as e:
        return f"Report generation error: {str(e)}"
# =========================================
# FULL PIPELINE
# =========================================
def run_full_pipeline(text: str, lang: str) -> dict:
    """Run translation → classification → extraction → scoring → lookup → RAG → SOAP.

    Args:
        text: Patient statement (typed, or a transcript).
        lang: Language code of *text*; "ur" and "en" get a readable label,
            anything else is labelled "Unknown".

    Returns:
        A result dict whose "status" is one of:
        - "crisis" / "non_medical": stops after classification; carries
          english_text, lang, lang_label, timings and a top-level total_ms.
        - "error": symptom extraction failed. Returned instead of continuing,
          because a failed extraction scores 0 and would otherwise be presented
          as a GREEN / low-risk triage result. Carries "message" plus the raw
          "analysis" payload for debugging.
        - "medical": the full result (analysis, risk_score, risk_level,
          dataset_matches, retrieved_chunks, soap_report, models_used, timings
          with total_ms inside).
    """
    timings = {}
    t0 = time.time()

    def _elapsed_ms(since: float) -> int:
        """Milliseconds elapsed since *since*, rounded to an int."""
        return round((time.time() - since) * 1000)

    # Translation
    t = time.time()
    english_text = translate_if_needed(text, lang)
    timings["translation_ms"] = _elapsed_ms(t)
    # Classification
    t = time.time()
    category = classify_query(english_text)
    timings["classification_ms"] = _elapsed_ms(t)
    lang_label = {"ur": "Urdu", "en": "English"}.get(lang, "Unknown")
    if category in ("crisis", "non_medical"):
        return {
            "status": category,
            "english_text": english_text,
            "lang": lang,
            "lang_label": lang_label,
            "timings": timings,
            "total_ms": _elapsed_ms(t0)
        }
    # Symptom extraction
    t = time.time()
    analysis = extract_symptoms(english_text)
    timings["extraction_ms"] = _elapsed_ms(t)
    if not isinstance(analysis, dict) or "error" in analysis:
        details = analysis.get("details", "") if isinstance(analysis, dict) else ""
        message = "Symptom extraction failed; no triage result could be produced."
        if details:
            message += f" ({details})"
        return {
            "status": "error",
            "message": message,
            "english_text": english_text,
            "lang": lang,
            "lang_label": lang_label,
            "analysis": analysis,
            "timings": timings,
            "total_ms": _elapsed_ms(t0)
        }
    # Risk scoring
    t = time.time()
    risk_score = calculate_risk_score(analysis)
    risk_level = get_risk_level(risk_score)
    timings["risk_scoring_ms"] = _elapsed_ms(t)
    # Disease lookup
    t = time.time()
    # "symptoms" may be present but null in model output.
    symptoms_list = analysis.get("symptoms") or []
    dataset_matches = lookup_diseases(symptoms_list, top_n=5)
    timings["disease_lookup_ms"] = _elapsed_ms(t)
    if dataset_matches:
        # Dataset matches replace the model's own condition guesses.
        analysis["possible_conditions"] = [d["disease"] for d in dataset_matches]
        analysis["dataset_match_detail"] = [
            f"{d['disease']} ({d['matched']}/{d['total']} symptoms matched)"
            for d in dataset_matches
        ]
    else:
        analysis["dataset_match_detail"] = ["Dataset lookup returned no matches"]
    # RAG retrieval
    t = time.time()
    rag_query = english_text + " " + " ".join(
        str(s.get("name") or "") for s in symptoms_list if isinstance(s, dict)
    )
    retrieved = retrieve_knowledge(rag_query, n=3)
    timings["rag_retrieval_ms"] = _elapsed_ms(t)
    retrieved_context = "\n\n".join(
        f"[{r['title']}]\n{r['text']}" for r in retrieved
    ) if retrieved else ""
    # SOAP report
    t = time.time()
    soap_report = generate_soap_report(analysis, english_text, risk_score, retrieved_context)
    timings["soap_generation_ms"] = _elapsed_ms(t)
    timings["total_ms"] = _elapsed_ms(t0)
    return {
        "status": "medical",
        "lang": lang,
        "lang_label": lang_label,
        "english_text": english_text,
        "analysis": analysis,
        "risk_score": risk_score,
        "risk_level": risk_level,
        "dataset_matches": dataset_matches,
        "retrieved_chunks": retrieved,
        "soap_report": soap_report,
        "models_used": {
            "transcription": "whisper-large-v3-turbo",
            "classification": "llama-3.1-8b-instant",
            "extraction": "llama-3.3-70b-versatile",
            "soap": "llama-3.3-70b-versatile",
            "embeddings": "all-MiniLM-L6-v2"
        },
        "timings": timings
    }
# =========================================
# DIRECT PIPELINE ENTRYPOINTS
# (replaces HTTP API calls)
# =========================================
def call_text_pipeline(text: str) -> dict:
    """Validate typed input and run the full pipeline on it.

    Returns the pipeline result, or ``{"status": "error", "message": ...}``
    when the API key is missing, the text is blank, the language is neither
    Urdu nor English, or the pipeline raises.
    """
    def _failure(message: str) -> dict:
        return {"status": "error", "message": message}

    if not client:
        return _failure("GROQ_API_KEY not configured")
    if not (text and text.strip()):
        return _failure("Text cannot be empty")
    detected = detect_language(text)
    if detected == "other":
        return _failure("Unsupported language. Only Urdu and English are supported.")
    try:
        outcome = run_full_pipeline(text, detected)
    except Exception as e:
        return _failure(str(e))
    return outcome
def call_audio_pipeline(audio_bytes: bytes) -> dict:
    """Transcribe recorded/uploaded audio and run the full pipeline on it.

    Args:
        audio_bytes: Raw audio file content.

    Returns:
        The pipeline result extended with "original_transcript" and
        ``timings["transcription_ms"]``, or ``{"status": "error",
        "message": ...}`` when the API key is missing, the audio is empty,
        transcription yields no usable language, or any step raises.
    """
    if not client:
        return {"status": "error", "message": "GROQ_API_KEY not configured"}
    if not audio_bytes:
        return {"status": "error", "message": "Audio cannot be empty"}
    try:
        # delete=False lets the file be reopened by path for transcription;
        # it is removed explicitly below.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            tmp.write(audio_bytes)
            tmp_path = tmp.name
        t0 = time.time()
        try:
            text, lang = transcribe_audio_file(tmp_path)
        finally:
            # Remove the temp file even when transcription raises. Cleanup is
            # best-effort and must not mask the transcription outcome.
            try:
                os.unlink(tmp_path)
            except OSError:
                pass
        transcription_ms = round((time.time() - t0) * 1000)
        if lang == "unknown":
            # In this case `text` holds the user-facing explanation.
            return {"status": "error", "message": text}
        result = run_full_pipeline(text, lang)
        result["original_transcript"] = text
        result["timings"]["transcription_ms"] = transcription_ms
        return result
    except Exception as e:
        return {"status": "error", "message": str(e)}
def get_system_health() -> dict:
    """Report which backing resources are configured and loaded.

    "status" is always "healthy"; the remaining fields reflect the API key,
    the RAG model, and the disease dataset.
    """
    dataset_loaded = disease_df is not None
    report = {"status": "healthy"}
    report["groq_configured"] = bool(api_key)
    report["rag_ready"] = _rag_model is not None
    report["disease_db_ready"] = dataset_loaded
    report["disease_count"] = len(disease_df) if dataset_loaded else 0
    report["rag_doc_count"] = len(_rag_docs)
    return report
# =========================================
# STREAMLIT PAGE CONFIG
# =========================================
# NOTE(review): the favicon check uses a path relative to the working
# directory, whereas the logo paths below are resolved relative to this file.
st.set_page_config(
    page_title="TabeebAI — Clinical Triage",
    page_icon="assets/favicon.ico" if os.path.exists("assets/favicon.ico") else "🏥",
    layout="wide",
    initial_sidebar_state="expanded"
)
# Base64-encode the app logo when the file exists; empty string otherwise.
_logo_b64 = ""
_logo_path = os.path.join(os.path.dirname(__file__), "assets", "logo.png")
if os.path.exists(_logo_path):
    with open(_logo_path, "rb") as _f:
        _logo_b64 = base64.b64encode(_f.read()).decode()
# NOTE(review): as it appears here this single-quoted f-string spans two
# physical lines, which is not valid Python; presumably inline markup was lost.
# Confirm the literal before running. Falls back to "+" when there is no logo.
_logo_img = f'
' if _logo_b64 else "+"
# Same pattern for the Aagahi logo, with a text fallback.
_agahi_b64 = ""
_agahi_path = os.path.join(os.path.dirname(__file__), "assets", "agahi_logo.png")
if os.path.exists(_agahi_path):
    with open(_agahi_path, "rb") as _f:
        _agahi_b64 = base64.b64encode(_f.read()).decode()
_agahi_img = f'
' if _agahi_b64 else "Aagahi Labs"
# =========================================
# GLOBAL CSS
# =========================================
# NOTE(review): both injected blocks are empty as they appear here — confirm
# the stylesheet / script content is intact.
st.markdown("""
""", unsafe_allow_html=True)
# Ctrl+Enter in the symptom textarea → click Analyze
st.markdown("""
""", unsafe_allow_html=True)
# =========================================
# SESSION STATE
# =========================================
# Seed every session key this script reads; keys that already exist are left
# untouched so state survives reruns.
_SESSION_DEFAULTS = (
    ("result", None),
    ("processing", False),
    ("input_mode", "text"),
    # HITL review state
    ("soap_source", None),          # fingerprint to detect new reports
    ("soap_edited", None),          # doctor-edited SOAP text
    ("soap_confirmed", False),
    ("soap_reviewer_name", ""),
    ("soap_reviewer_role", ""),
    ("soap_confirmed_at", None),
)
for _key, _default in _SESSION_DEFAULTS:
    if _key not in st.session_state:
        st.session_state[_key] = _default
# Accessibility
FONT_SCALES = [0.85, 1.0, 1.15, 1.3, 1.5]
FONT_LABELS = ["Small", "Normal", "Large", "X-Large", "XX-Large"]
if "font_idx" not in st.session_state:
    st.session_state["font_idx"] = 2  # default: Large
# =========================================
# COMPONENT HELPERS
# =========================================
def render_risk_bar(score: int, risk_level: str):
    """Render the risk meter: a label chosen from *risk_level* plus ``score``/100.

    "RED" maps to "Emergency", "YELLOW" to "Moderate", anything else to
    "Low Risk".
    """
    if risk_level == "RED":
        color, label = "#c0392b", "Emergency"
    elif risk_level == "YELLOW":
        color, label = "#c97f0a", "Moderate"
    else:
        color, label = "#0f7a5a", "Low Risk"
    # NOTE(review): `color` and `pct` are not referenced by the template as it
    # appears here — presumably styling markup was lost; confirm before removing.
    pct = score
    st.markdown(f"""
Risk Level
{label}
{score}/100
""", unsafe_allow_html=True)
def render_alert(status: str, result: dict):
    """Render the banner that matches the pipeline outcome.

    Args:
        status: Result status. "crisis" and "non_medical" get fixed notices;
            "error" renders nothing (there is no triage result to summarise);
            any other status gets a banner chosen by ``result["risk_level"]``.
        result: Pipeline result; "risk_score", "risk_level" and "analysis" are
            read, with defaults when absent.
    """
    if status == "error":
        # An error result has no score or level. Without this guard the
        # defaults below (0 / "GREEN") would render "LOW RISK — Home Care
        # Advised" for a failed analysis.
        return
    if status == "crisis":
        st.markdown("""
You Are Not Alone
It sounds like you or someone around you may be in emotional distress.
TabeebAI cannot provide mental health support, but trained counsellors are available right now.
- Umang — 0317-4288665 (24/7)
- Rozan Counselling — 051-2890505
- Rescue — 1122
یہ ایپ ذہنی صحت کی مدد کے لیے نہیں ہے۔ براہ کرم اوپر دیے گئے نمبروں پر کال کریں۔
""", unsafe_allow_html=True)
        return
    if status == "non_medical":
        st.markdown("""
Non-Medical Input Detected
TabeebAI is designed to assist with patient symptom triage only.
Please describe the patient's medical symptoms or health concerns.
براہ کرم مریض کی علامات یا صحت کے مسائل بیان کریں۔
""", unsafe_allow_html=True)
        return
    score = result.get("risk_score", 0)
    level = result.get("risk_level", "GREEN")
    analysis = result.get("analysis", {})
    if analysis:
        # These values are rendered with unsafe_allow_html=True, so escape
        # them; str() also tolerates non-string entries.
        complaint = html.escape(str(analysis.get("chief_complaint", "N/A")))
        conditions = html.escape(
            ", ".join(str(c) for c in (analysis.get("possible_conditions") or []))
        )
    else:
        complaint = conditions = "N/A"
    if level == "RED":
        st.markdown(f"""
EMERGENCY ALERT — Immediate Action Required
Chief Complaint: {complaint}
Risk Score: {score}/100
Possible Conditions: {conditions}
- Call emergency services immediately — Rescue 1122 / Edhi 115
- Do not leave the patient alone
- Keep patient calm and still
- Prepare for immediate hospital transfer
""", unsafe_allow_html=True)
    elif level == "YELLOW":
        st.markdown(f"""
CAUTION — Medical Attention Recommended
Chief Complaint: {complaint}
Risk Score: {score}/100
Possible Conditions: {conditions}
- Schedule a doctor's appointment within 24 hours
- Monitor symptoms closely for worsening
- Seek urgent care if new symptoms develop
""", unsafe_allow_html=True)
    else:
        st.markdown(f"""
LOW RISK — Home Care Advised
Chief Complaint: {complaint}
Risk Score: {score}/100
Possible Conditions: {conditions}
- Rest and home care
- Stay hydrated and monitor temperature
- Visit a clinic if symptoms persist beyond 3 days
""", unsafe_allow_html=True)
def render_symptoms_chips(symptoms: list):
    """Render one chip per symptom showing its name and lower-cased severity.

    A style is looked up per severity ("severe", "moderate"), with a default
    for anything else.
    """
    _chip_styles = {
        "severe": {"color": "#c0392b", "bg": "rgba(192,57,43,0.08)", "border": "rgba(192,57,43,0.30)"},
        "moderate": {"color": "#c97f0a", "bg": "rgba(201,127,10,0.10)", "border": "rgba(201,127,10,0.30)"},
    }
    _chip_default = {"color": "#0d7494", "bg": "rgba(13,116,148,0.10)", "border": "rgba(13,116,148,0.25)"}
    chips_html = ""
    for s in symptoms:
        name = s.get("name", "")
        sev = s.get("severity", "mild").lower()
        # NOTE(review): `dur` and `cs` are not referenced by the template as it
        # appears here — presumably markup was lost; confirm before removing.
        dur = s.get("duration", "")
        cs = _chip_styles.get(sev, _chip_default)
        # NOTE(review): `name` and `sev` are interpolated unescaped and the
        # result is rendered with unsafe_allow_html=True.
        chips_html += f"""{name}|{sev}"""
    # NOTE(review): as it appears here this single-quoted f-string spans two
    # physical lines, which is not valid Python — confirm the literal.
    st.markdown(f'{chips_html}
', unsafe_allow_html=True)
# A SOAP heading: optional Markdown decoration, one of the four keywords (not
# followed by another letter, so "Planning" is not a heading), then the rest.
_SOAP_HEADING = re.compile(
    r"^[\s*#_]*(SUBJECTIVE|OBJECTIVE|ASSESSMENT|PLAN)(?![A-Za-z])(.*)$",
    re.IGNORECASE,
)


def render_soap_report(soap: str):
    """Render a SOAP report as one card per non-empty section.

    Lines are assigned to the most recent SUBJECTIVE / OBJECTIVE / ASSESSMENT /
    PLAN heading; lines before the first heading are ignored. Headings may be
    decorated (e.g. ``**PLAN:**``), and text after a heading's colon on the
    same line ("PLAN: rest at home") is kept as section content. Content is
    HTML-escaped because it is rendered with ``unsafe_allow_html=True``.
    """
    sections = {"SUBJECTIVE:": [], "OBJECTIVE:": [], "ASSESSMENT:": [], "PLAN:": []}
    current = None
    for line in soap.splitlines():
        heading = _SOAP_HEADING.match(line)
        if heading:
            current = heading.group(1).upper() + ":"
            # Keep inline content that follows the heading's colon.
            _, colon, inline = heading.group(2).partition(":")
            inline = inline.strip(" \t*_")
            if colon and inline:
                sections[current].append(inline)
        elif current:
            sections[current].append(line)
    icons = {"SUBJECTIVE:": "S", "OBJECTIVE:": "O", "ASSESSMENT:": "A", "PLAN:": "P"}
    colors = {"SUBJECTIVE:": "#0d7494", "OBJECTIVE:": "#0f7a5a", "ASSESSMENT:": "#c97f0a", "PLAN:": "#7c3aed"}
    color_bg = {"SUBJECTIVE:": "rgba(13,116,148,0.10)", "OBJECTIVE:": "rgba(15,122,90,0.08)", "ASSESSMENT:": "rgba(201,127,10,0.10)", "PLAN:": "rgba(124,58,237,0.08)"}
    color_border = {"SUBJECTIVE:": "rgba(13,116,148,0.25)", "OBJECTIVE:": "rgba(15,122,90,0.20)", "ASSESSMENT:": "rgba(201,127,10,0.25)", "PLAN:": "rgba(124,58,237,0.20)"}
    for key, lines in sections.items():
        content = html.escape("\n".join(l for l in lines if l.strip()))
        if not content:
            continue
        # NOTE(review): c / bg / bd are not referenced by the template as it
        # appears here — presumably markup was lost; confirm before removing.
        c = colors[key]
        bg = color_bg[key]
        bd = color_border[key]
        st.markdown(f"""
{icons[key]}
{key.rstrip(':')}
{content}
""", unsafe_allow_html=True)
# =========================================
# SIDEBAR
# =========================================
# Sidebar: system status line, dashboard selector, input-mode selector,
# font-size controls, disclaimer, and a "Clear Results" button.
# NOTE(review): the nesting below is reconstructed from the statement order;
# confirm which statements belong inside `with st.sidebar`.
with st.sidebar:
    st.markdown(f"""
""", unsafe_allow_html=True)
    health = get_system_health()
    if health:
        st.markdown(f"""
{health.get('disease_count', 0)} diseases · {health.get('rag_doc_count', 0)} RAG docs
""", unsafe_allow_html=True)
    st.markdown('', unsafe_allow_html=True)
    # `dashboard` selects which results view is rendered further down.
    dashboard = st.radio(
        "Select View",
        ["Patient Dashboard", "Doctor Dashboard", "Developer Dashboard"],
        label_visibility="collapsed"
    )
    st.markdown('', unsafe_allow_html=True)
    st.markdown('', unsafe_allow_html=True)
    # `input_mode` is read by render_input_section() to pick text vs audio.
    input_mode = st.radio(
        "Input",
        ["Text Input", "Audio Input"],
        label_visibility="collapsed",
        key="input_mode_radio"
    )
    st.markdown('', unsafe_allow_html=True)
    with st.expander("Accessibility", expanded=False):
        # Index into FONT_SCALES / FONT_LABELS; the buttons are disabled at
        # either end of the range.
        _fi = st.session_state.font_idx
        _fa_col, _fl_col, _fp_col = st.columns([1, 2, 1])
        with _fa_col:
            if st.button("A−", key="font_dec", disabled=(_fi == 0)):
                st.session_state.font_idx -= 1
                st.rerun()
        with _fl_col:
            # NOTE(review): as it appears here this single-quoted f-string
            # spans two physical lines, which is not valid Python — confirm.
            st.markdown(
                f'{FONT_LABELS[_fi]}
',
                unsafe_allow_html=True,
            )
        with _fp_col:
            if st.button("A+", key="font_inc", disabled=(_fi == len(FONT_SCALES) - 1)):
                st.session_state.font_idx += 1
                st.rerun()
    st.markdown('', unsafe_allow_html=True)
    st.markdown("""
Disclaimer
TabeebAI is a clinical decision support tool. Not a substitute for professional medical judgement.
All outputs must be reviewed by a licensed clinician.
""", unsafe_allow_html=True)
    # Offer to clear only when a result is stored.
    if st.session_state.result:
        st.markdown('', unsafe_allow_html=True)
        if st.button("Clear Results", use_container_width=True):
            st.session_state.result = None
            st.rerun()
# =========================================
# MAIN CONTENT AREA
# =========================================
# Dynamic font scale — injected fresh every render so sidebar controls take effect immediately
# Base size of 16 multiplied by the selected scale factor.
_font_px = round(16 * FONT_SCALES[st.session_state.font_idx], 2)
# NOTE(review): the injected string is empty as it appears here and `_font_px`
# is not referenced — presumably the style markup was lost; confirm.
st.markdown(
    f"",
    unsafe_allow_html=True,
)
# Header
st.markdown(f"""
""", unsafe_allow_html=True)
# Disclaimer
st.markdown("""
Important: TabeebAI is a
clinical decision support tool intended to assist qualified healthcare professionals.
It is
not a diagnostic tool and must not replace professional medical judgement. All AI-generated reports must be reviewed by a licensed medical professional before any action is taken.
یہ ایپ صرف ڈاکٹروں کی مدد کے لیے ہے۔ کوئی بھی طبی فیصلہ کرنے سے پہلے ڈاکٹر سے مشورہ کریں۔
""", unsafe_allow_html=True)
# =========================================
# INPUT SECTION
# =========================================
def render_input_section():
    """Render the patient input area and run the pipeline on submit.

    Reads the module-level ``input_mode``: "Text Input" shows a text area with
    an Analyze button; anything else shows live recording plus file upload
    with a "Transcribe & Analyze" button. The result is stored in
    ``st.session_state.result`` and the script is rerun.

    NOTE(review): the nesting below is reconstructed from the statement order,
    and several single-quoted literals span two physical lines as they appear
    here (not valid Python; presumably markup was lost) — confirm both.
    """
    st.markdown('Patient Input
', unsafe_allow_html=True)
    if input_mode == "Text Input":
        col_in, col_btn = st.columns([5, 1])
        with col_in:
            user_text = st.text_area(
                "Enter symptoms in Urdu or English",
                placeholder=(
                    "Example: I have severe chest pain and difficulty breathing since morning...\n"
                    "مثال: مجھے سینے میں درد ہے اور سانس لینے میں دشواری ہو رہی ہے..."
                ),
                height=120,
                label_visibility="collapsed",
                key="text_input"
            )
        with col_btn:
            st.markdown("""
Ctrl+Enter
""", unsafe_allow_html=True)
            analyze_clicked = st.button("Analyze", use_container_width=True, type="primary")
        if analyze_clicked:
            if not user_text or not user_text.strip():
                st.warning("Please enter some text before analyzing.")
                return
            with st.spinner("Running clinical pipeline..."):
                result = call_text_pipeline(user_text)
                # Keep the raw input alongside the result for later display.
                result["original_input"] = user_text
                st.session_state.result = result
            st.rerun()
    # =========================
    # AUDIO INPUT MODE
    # =========================
    else:
        st.markdown("""
Microphone not working?
Browsers block mic access inside embedded frames.
Open the app in a new tab
to enable live recording, or use
Upload File below.
""", unsafe_allow_html=True)
        col_rec, col_up = st.columns(2)
        audio_bytes = None
        source_label = None
        # ── COLUMN 1: Live microphone recording ──
        with col_rec:
            st.markdown(
                'Live Recording
',
                unsafe_allow_html=True
            )
            audio_value = st.audio_input(
                "Record patient voice",
                label_visibility="collapsed",
                key="live_recorder"
            )
            if audio_value is not None:
                audio_bytes = audio_value.read()
                source_label = "live recording"
        # ── COLUMN 2: File upload fallback ──
        with col_up:
            st.markdown(
                'Upload File
',
                unsafe_allow_html=True
            )
            uploaded = st.file_uploader(
                "Upload audio file",
                type=["wav", "mp3", "m4a", "ogg", "flac"],
                label_visibility="collapsed",
                key="audio_upload"
            )
            # A live recording takes precedence over an uploaded file.
            if uploaded is not None and audio_bytes is None:
                audio_bytes = uploaded.read()
                source_label = f"uploaded — {uploaded.name}"
        # ── Playback + Analyze button ──
        if audio_bytes:
            st.markdown("
", unsafe_allow_html=True)
            # NOTE(review): playback is declared as audio/wav even though the
            # uploader above also accepts mp3/m4a/ogg/flac.
            st.audio(audio_bytes, format="audio/wav")
            st.markdown(
                f''
                f'Source: {source_label}
',
                unsafe_allow_html=True
            )
        else:
            st.markdown("""
Record using the microphone above, or upload an audio file
""", unsafe_allow_html=True)
        # ── Submit button always visible ──
        st.markdown("
", unsafe_allow_html=True)
        if st.button(
            "Transcribe & Analyze",
            use_container_width=True,
            type="primary",
            key="audio_analyze_btn",
            disabled=(audio_bytes is None)
        ):
            if audio_bytes is None:
                st.warning("Please record or upload audio first.")
            else:
                with st.spinner("Transcribing audio and running clinical pipeline..."):
                    result = call_audio_pipeline(audio_bytes)
                    st.session_state.result = result
                st.rerun()
render_input_section()
# =========================================
# NO RESULT STATE
# =========================================
# With no stored result, show the placeholder and halt this run so nothing
# below executes.
if st.session_state.result is None:
    st.markdown("""
+
No Analysis Yet
Enter patient symptoms above to begin clinical triage
""", unsafe_allow_html=True)
    st.stop()
# =========================================
# RESULTS
# =========================================
# Unpack the stored result with defaults so every view below can read these
# names regardless of which keys the result actually carries.
result = st.session_state.result
status = result.get("status", "error")
analysis = result.get("analysis", {}) or {}
risk_score = result.get("risk_score", 0)
risk_level = result.get("risk_level", "GREEN")
soap = result.get("soap_report", "")
retrieved = result.get("retrieved_chunks", [])
timings = result.get("timings", {})
lang_label = result.get("lang_label", "Unknown")
# Typed input is stored as "original_input"; "original_transcript" is the
# fallback key.
english_text = result.get("english_text", "")
original_input = result.get("original_input", result.get("original_transcript", ""))
dataset_matches = result.get("dataset_matches", [])
# NOTE(review): as it appears here this double-quoted literal spans two
# physical lines, which is not valid Python — confirm the literal.
st.markdown("
", unsafe_allow_html=True)
# ── Alert banner always shown ──────────────────────────────
# NOTE(review): render_alert() runs before the status check below, so it is
# also invoked when status == "error".
render_alert(status, result)
if status in ("crisis", "non_medical", "error"):
    if status == "error":
        st.error(f"Pipeline Error: {result.get('message', 'Unknown error')}")
    # These statuses have no dashboard content; stop rendering here.
    st.stop()
# =========================================
# DASHBOARD TABS
# =========================================
if dashboard == "Patient Dashboard":
st.markdown('Patient Overview
', unsafe_allow_html=True)
# Key metrics row
m1, m2, m3, m4 = st.columns(4)
with m1:
st.metric("Risk Score", f"{risk_score}/100")
with m2:
lvl_map = {"RED": "Emergency", "YELLOW": "Moderate", "GREEN": "Low Risk"}
st.metric("Urgency", lvl_map.get(risk_level, risk_level))
with m3:
st.metric("Language", lang_label)
with m4:
n_sym = len(analysis.get("symptoms", []))
st.metric("Symptoms Found", str(n_sym))
st.markdown("
", unsafe_allow_html=True)
# Urgency meter
render_risk_bar(risk_score, risk_level)
st.markdown("
", unsafe_allow_html=True)
# Two column layout
col_left, col_right = st.columns([1, 1], gap="large")
with col_left:
st.markdown('Reported Symptoms
', unsafe_allow_html=True)
symptoms = analysis.get("symptoms", [])
if symptoms:
render_symptoms_chips(symptoms)
else:
st.markdown('No specific symptoms extracted.
', unsafe_allow_html=True)
st.markdown("
", unsafe_allow_html=True)
st.markdown('Chief Complaint
', unsafe_allow_html=True)
complaint = analysis.get("chief_complaint", "Not specified")
st.markdown(f"""
{complaint}
""", unsafe_allow_html=True)
with col_right:
st.markdown('What You Should Do Next
', unsafe_allow_html=True)
if risk_level == "RED":
steps = [
("Call 1122 / 115 immediately", "red"),
("Do not leave the patient alone", "red"),
("Keep patient calm and still", "yellow"),
("Prepare for hospital transfer", "yellow"),
]
elif risk_level == "YELLOW":
steps = [
("See a doctor within 24 hours", "yellow"),
("Monitor symptoms for worsening", "yellow"),
("Seek urgent care if new symptoms develop", "cyan"),
("Keep track of symptom duration", "cyan"),
]
else:
steps = [
("Rest and home care is appropriate", "green"),
("Stay hydrated and monitor temperature", "green"),
("Visit a clinic if symptoms persist > 3 days", "cyan"),
("Avoid strenuous activity until recovered", "cyan"),
]
for step, color in steps:
dot_class = f"dot-{color}"
st.markdown(f"""
""", unsafe_allow_html=True)
if original_input and lang_label == "Urdu":
st.markdown("
", unsafe_allow_html=True)
st.markdown('Translation
', unsafe_allow_html=True)
col_ur, col_en = st.columns(2)
with col_ur:
st.markdown(f"""
Original (Urdu)
{original_input}
""", unsafe_allow_html=True)
with col_en:
st.markdown(f"""
""", unsafe_allow_html=True)
elif dashboard == "Doctor Dashboard":
st.markdown('Clinical Overview
', unsafe_allow_html=True)
# Clinical metrics
cm1, cm2, cm3, cm4 = st.columns(4)
with cm1:
st.metric("Risk Score", f"{risk_score}/100")
with cm2:
urgency = analysis.get("urgency", "N/A").upper()
st.metric("Triage Urgency", urgency)
with cm3:
st.metric("Input Language", lang_label)
with cm4:
st.metric("Symptoms Extracted", len(analysis.get("symptoms", [])))
st.markdown("
", unsafe_allow_html=True)
render_risk_bar(risk_score, risk_level)
st.markdown("
", unsafe_allow_html=True)
# Main clinical area
tab_soap, tab_symptoms, tab_diseases, tab_rag = st.tabs([
"SOAP Report", "Symptom Analysis", "Disease Matching", "Retrieved Knowledge"
])
# SOAP tab: shows the generated SOAP report, lets the doctor edit and confirm
# it (human-in-the-loop review), and offers .txt / .md downloads. When `soap`
# is falsy only an info message is shown.
# NOTE(review): markup literals in this block are kept verbatim as they appear
# here (including line breaks inside quotes) — confirm against the deployed file.
# NOTE(review): several f-strings interpolate user-typed text (reviewer name
# and role) into st.markdown(..., unsafe_allow_html=True) without escaping —
# consider html.escape() on those values.
with tab_soap:
if soap:
# Reset HITL state whenever a brand-new report arrives
# A new report is detected by comparing the stored source text with `soap`;
# the edited copy, confirmation flag, reviewer fields and timestamp are cleared.
if st.session_state.soap_source != soap:
st.session_state.soap_source = soap
st.session_state.soap_edited = soap
st.session_state.soap_confirmed = False
st.session_state.soap_reviewer_name = ""
st.session_state.soap_reviewer_role = ""
st.session_state.soap_confirmed_at = None
# ── Confirmation banner ────────────────────────────────────
# Only rendered once the report has been confirmed; shows reviewer name,
# role and confirmation timestamp from session state.
if st.session_state.soap_confirmed:
st.markdown(f"""
✅
Report Confirmed
Reviewed by
{st.session_state.soap_reviewer_name}
·
{st.session_state.soap_reviewer_role}
·
{st.session_state.soap_confirmed_at}
""", unsafe_allow_html=True)
# ── Attractive rendered SOAP view (always shown) ──────────
# The section heading changes once the report is confirmed.
section_label = (
"SOAP Report — confirmed" if st.session_state.soap_confirmed
else "SOAP Report"
)
st.markdown(f'{section_label}
',
unsafe_allow_html=True)
# Renders the doctor-edited text, falling back to the generated report when
# the edited copy is empty.
# NOTE(review): this reads soap_edited before the text area below writes the
# latest edit back in the same run, so the rendered view presumably lags one
# rerun behind an edit — confirm.
render_soap_report(st.session_state.soap_edited or soap)
# ── Editable textarea in expander (only when not confirmed) ──
if not st.session_state.soap_confirmed:
with st.expander("✏️ Edit Report", expanded=False):
edited_soap = st.text_area(
"SOAP Report",
value=st.session_state.soap_edited,
height=300,
label_visibility="collapsed",
key="soap_edit_area",
)
# The text area's current content is written back to session state on every run.
st.session_state.soap_edited = edited_soap
# ── Doctor Review row ──────────────────────────────────────
st.markdown("
", unsafe_allow_html=True)
st.markdown("""
""", unsafe_allow_html=True)
# Three columns: doctor name, role, and the confirm / reopen button.
rev_c1, rev_c2, rev_c3 = st.columns([2, 2, 1])
if st.session_state.soap_confirmed:
# Show confirmed values as readable text, not disabled inputs
with rev_c1:
st.markdown(f"""
Doctor Name
{st.session_state.soap_reviewer_name}
""", unsafe_allow_html=True)
with rev_c2:
st.markdown(f"""
Role / Specialisation
{st.session_state.soap_reviewer_role}
""", unsafe_allow_html=True)
# Both branches define reviewer_name / reviewer_role; here they mirror the
# values already stored in session state.
reviewer_name = st.session_state.soap_reviewer_name
reviewer_role = st.session_state.soap_reviewer_role
else:
# Not yet confirmed: collect the reviewer's details through text inputs.
with rev_c1:
reviewer_name = st.text_input(
"Doctor Name",
placeholder="Dr. Ahmed Khan",
key="reviewer_name_input",
)
with rev_c2:
reviewer_role = st.text_input(
"Role / Specialisation",
placeholder="e.g. General Practitioner",
key="reviewer_role_input",
)
with rev_c3:
st.markdown("
", unsafe_allow_html=True)
if not st.session_state.soap_confirmed:
if st.button("✔ Confirm Report", type="primary",
use_container_width=True, key="confirm_btn"):
# A non-blank name is required; the role is optional and defaults to "Physician".
if not reviewer_name.strip():
st.warning("Please enter the reviewing doctor's name.")
else:
st.session_state.soap_confirmed = True
st.session_state.soap_reviewer_name = reviewer_name.strip()
st.session_state.soap_reviewer_role = (
reviewer_role.strip() or "Physician"
)
# time.strftime without a time tuple formats the current local time of the
# process; no timezone is recorded in the stamp.
st.session_state.soap_confirmed_at = time.strftime(
"%Y-%m-%d %H:%M:%S"
)
# Rerun so the page is redrawn in its confirmed state.
st.rerun()
else:
if st.button("↩ Reopen for Edit", use_container_width=True,
key="reopen_btn"):
# Reopening only clears the confirmed flag; the stored reviewer name, role
# and timestamp are left in session state.
st.session_state.soap_confirmed = False
st.rerun()
# ── Download buttons (use doctor-edited text) ──────────────
final_soap = st.session_state.soap_edited or soap
# Footer appended to both downloads: reviewer details when confirmed,
# otherwise a "pending review" marker.
reviewer_line = (
f"\n\nReviewed by: {st.session_state.soap_reviewer_name}"
f" ({st.session_state.soap_reviewer_role})"
f" — {st.session_state.soap_confirmed_at}"
if st.session_state.soap_confirmed else
"\n\n[Pending doctor review]"
)
st.markdown("
", unsafe_allow_html=True)
# Two button columns plus a wider third column that is left unused.
col_dl1, col_dl2, col_space = st.columns([1, 1, 3])
with col_dl1:
# Plain-text download: report text followed by the reviewer footer.
st.download_button(
"Download Report (.txt)",
data=final_soap + reviewer_line,
file_name="tabeebai_soap_report.txt",
mime="text/plain",
use_container_width=True,
)
with col_dl2:
# Markdown download: header with date, risk score/level and language, the
# chief complaint, the report text, the reviewer footer and a disclaimer.
md_report = f"""# TabeebAI Clinical Report
**Date:** {time.strftime('%Y-%m-%d %H:%M')}
**Risk Score:** {risk_score}/100
**Risk Level:** {risk_level}
**Language:** {lang_label}
---
## Chief Complaint
{analysis.get("chief_complaint", "N/A")}
---
{final_soap}
---
{reviewer_line}
*AI-generated report — must be reviewed by a licensed medical professional.*
"""
st.download_button(
"Download Report (.md)",
data=md_report,
file_name="tabeebai_report.md",
mime="text/markdown",
use_container_width=True,
)
else:
st.info("No SOAP report generated.")
# Symptom tab: chip summary plus one card per extracted symptom, followed by
# the original input and its English translation side by side.
# NOTE(review): markup literals in this block are kept verbatim as they appear
# here (including line breaks inside quotes) — confirm against the deployed file.
with tab_symptoms:
symptoms = analysis.get("symptoms", [])
if symptoms:
st.markdown('Extracted Symptoms
', unsafe_allow_html=True)
# render_symptoms_chips is defined elsewhere in this file.
render_symptoms_chips(symptoms)
st.markdown("
", unsafe_allow_html=True)
for s in symptoms:
# NOTE(review): dict.get() only uses "mild" when the key is missing; a
# present-but-None "severity" would raise AttributeError on .lower() — confirm
# the upstream schema or guard with `or "mild"`.
sev = s.get("severity", "mild").lower()
# Colour palette per severity. The mapping does not depend on the loop
# variable, so it is rebuilt unchanged on every iteration.
_sev_map = {
"severe": {"color": "#c0392b", "bg": "rgba(192,57,43,0.08)", "border": "rgba(192,57,43,0.30)"},
"moderate": {"color": "#c97f0a", "bg": "rgba(201,127,10,0.10)", "border": "rgba(201,127,10,0.30)"},
"mild": {"color": "#0f7a5a", "bg": "rgba(15,122,90,0.08)", "border": "rgba(15,122,90,0.25)"},
}
# Unknown severities fall back to a fourth, neutral palette.
# NOTE(review): `sc` is not referenced by the template as shown here;
# presumably the styling markup that used it is missing from this view.
sc = _sev_map.get(sev, {"color": "#0d7494", "bg": "rgba(13,116,148,0.10)", "border": "rgba(13,116,148,0.25)"})
st.markdown(f"""
{s.get("name", "—")}
Duration: {s.get("duration", "N/A")}
{sev}
""", unsafe_allow_html=True)
else:
st.info("No symptoms extracted.")
st.markdown("
", unsafe_allow_html=True)
st.markdown('Transcript
', unsafe_allow_html=True)
# Two columns: original input on the left, English text on the right;
# a dash is shown when either value is empty.
t1, t2 = st.columns(2)
with t1:
st.markdown(f"""
Original Input
{original_input or "—"}
""", unsafe_allow_html=True)
with t2:
st.markdown(f"""
English Translation
{english_text or "—"}
""", unsafe_allow_html=True)
# Disease tab: one row per dataset match with its matched/total symptom
# counts, plus an optional expander listing the match detail strings.
# NOTE(review): markup literals in this block are kept verbatim as they appear
# here (including line breaks inside quotes) — confirm against the deployed file.
with tab_diseases:
st.markdown('Dataset-Matched Diseases
', unsafe_allow_html=True)
if dataset_matches:
# Largest "matched" count across all rows; `or 1` avoids dividing by zero
# when every row reports 0.
max_matched = max(d.get("matched", 0) for d in dataset_matches) or 1
for d in dataset_matches:
matched = d.get("matched", 0)
total = d.get("total", 1)
# pct is scaled against the best match (not against `total`), so the
# best-matching row reaches 100 whenever any row matched at all.
pct = int((matched / max_matched) * 100)
# Colour thresholds: above 70, above 40, otherwise the default colour.
# NOTE(review): `pct` and `bar_color` are not referenced by the template as
# shown here; presumably the bar markup that used them is missing from this view.
bar_color = "#c0392b" if pct > 70 else "#c97f0a" if pct > 40 else "#0d7494"
st.markdown(f"""
{d.get("disease","—")}
{matched}/{total} symptoms
""", unsafe_allow_html=True)
else:
st.info("No disease matches found. Ensure diseases_symptoms.csv is in the data/ folder.")
# Optional per-match detail strings supplied by the analysis result.
detail = analysis.get("dataset_match_detail", [])
if detail:
with st.expander("Match Detail"):
for item in detail:
st.markdown(f'{item}
', unsafe_allow_html=True)
# Retrieved-knowledge tab: one expander per retrieved chunk, showing its
# text and similarity score.
# NOTE(review): markup literals in this block are kept verbatim as they appear
# here (including line breaks inside quotes) — confirm against the deployed file.
with tab_rag:
st.markdown('Retrieved Medical Knowledge
', unsafe_allow_html=True)
if retrieved:
for chunk in retrieved:
# Score as a whole-number percentage; int() truncates rather than rounds.
score_pct = int(chunk.get("score", 0) * 100)
# Colour thresholds: above 70, above 50, otherwise the default colour.
# NOTE(review): `score_color` is not referenced by the template as shown here;
# presumably the markup that used it is missing from this view.
score_color = "#0f7a5a" if score_pct > 70 else "#c97f0a" if score_pct > 50 else "#6b7a91"
# The expander label carries the chunk title and its relevance percentage;
# the body shows the chunk text and the raw score to four decimals.
with st.expander(f"{chunk.get('title', 'Document')} — Relevance: {score_pct}%"):
st.markdown(f"""
{chunk.get("text", "")}
Cosine Similarity
{chunk.get('score',0):.4f}
""", unsafe_allow_html=True)
else:
st.info("No RAG chunks retrieved. Ensure medical_knowledge.json is in the data/ folder.")
# Developer view: this branch renders pipeline observability when the
# selected dashboard is "Developer Dashboard".
# NOTE(review): markup literals in this block are kept verbatim as they appear
# here (including line breaks inside quotes) — confirm against the deployed file.
elif dashboard == "Developer Dashboard":
st.markdown('Pipeline Observability
', unsafe_allow_html=True)
# One tab handle per label, in the same order as the labels.
dev_tab1, dev_tab2, dev_tab3, dev_tab4, dev_tab5 = st.tabs([
"Pipeline Status", "Timings", "Raw JSON", "RAG Debug", "Models"
])
# First tab: execution pipeline status followed by an input summary.
with dev_tab1:
st.markdown('Execution Pipeline
', unsafe_allow_html=True)
# Status rows for the pipeline panel: (step name, detail text, colour key).
# Each row is unpacked as (step_name, step_detail, step_color) by the
# rendering loop that follows.
_is_urdu = lang_label == "Urdu"
# A present-but-None "symptoms" value must count as zero rather than crash len().
_symptom_count = len(analysis.get("symptoms") or [])
# The SOAP row reflects whether a report is actually present in the pipeline
# result instead of unconditionally claiming completion (the Doctor view
# already handles the "no report" case).
_soap_ready = bool(result.get("soap_report"))
pipeline_steps = [
    ("Language Detection", lang_label, "green"),
    ("Query Classification", "medical", "green"),
    (
        "Translation",
        "Urdu → English" if _is_urdu else "Skipped (English)",
        "cyan" if _is_urdu else "yellow",
    ),
    ("Symptom Extraction", f"{_symptom_count} symptoms extracted", "green"),
    (
        "Risk Scoring",
        f"{risk_score}/100 — {risk_level}",
        # GREEN/YELLOW map to their own colour; anything else is shown as red.
        {"GREEN": "green", "YELLOW": "yellow"}.get(risk_level, "red"),
    ),
    (
        "Disease Lookup",
        f"{len(dataset_matches)} diseases matched",
        "green" if dataset_matches else "yellow",
    ),
    (
        "RAG Retrieval",
        f"{len(retrieved)} chunks retrieved",
        "green" if retrieved else "yellow",
    ),
    (
        "SOAP Generation",
        "Complete" if _soap_ready else "Not generated",
        "green" if _soap_ready else "yellow",
    ),
]
# Render one status row per pipeline step, then an input summary with the
# original input and the post-translation English text side by side.
# NOTE(review): markup literals in this block are kept verbatim as they appear
# here (including line breaks inside quotes) — confirm against the deployed file.
# NOTE(review): `step_color` is not referenced by the template as shown here;
# presumably the markup that used it is missing from this view.
for step_name, step_detail, step_color in pipeline_steps:
st.markdown(f"""
{step_name}
— {step_detail}
""", unsafe_allow_html=True)
st.markdown("
", unsafe_allow_html=True)
st.markdown('Input Summary
', unsafe_allow_html=True)
# A dash is shown when either value is empty.
col_a, col_b = st.columns(2)
with col_a:
st.markdown(f"""
Original Input
{original_input or "—"}
""", unsafe_allow_html=True)
with col_b:
st.markdown(f"""
English (post-translation)
{english_text or "—"}
""", unsafe_allow_html=True)
# Timings tab: one card per recorded stage timing, then a per-stage share of
# the total time.
# NOTE(review): markup literals in this block are kept verbatim as they appear
# here (including line breaks inside quotes) — confirm against the deployed file.
with dev_tab2:
st.markdown('Pipeline Timing Breakdown
', unsafe_allow_html=True)
# Maps timing keys to display labels; iteration order is the display order.
timing_labels = {
"transcription_ms": "Transcription",
"translation_ms": "Translation",
"classification_ms": "Classification",
"extraction_ms": "Extraction",
"risk_scoring_ms": "Risk Scoring",
"disease_lookup_ms": "Disease Lookup",
"rag_retrieval_ms": "RAG Retrieval",
"soap_generation_ms": "SOAP Report",
"total_ms": "TOTAL"
}
if timings:
# Cards are laid out across four columns; `i` picks the column via i % 4.
cols = st.columns(4)
i = 0
for key, label in timing_labels.items():
# Stages without a recorded timing are skipped.
if key in timings:
val = timings[key]
with cols[i % 4]:
# The total card gets a different border variable from the stage cards.
# NOTE(review): `border_color` is not referenced by the template as shown
# here; presumably the markup that used it is missing from this view.
border_color = "var(--cyan)" if key == "total_ms" else "var(--border-subtle)"
st.markdown(f"""
{val}ms
{label}
""", unsafe_allow_html=True)
# NOTE(review): presumably advanced only for rendered cards; the nesting
# level is not visible in this view — confirm.
i += 1
total = timings.get("total_ms", 0)
st.markdown("
", unsafe_allow_html=True)
# The distribution is only drawn for a positive total, which also keeps the
# division below safe.
if total > 0:
st.markdown('Time Distribution
', unsafe_allow_html=True)
# Every labelled stage except the total that has a recorded timing.
render_keys = [k for k in timing_labels if k != "total_ms" and k in timings]
for key in render_keys:
label = timing_labels[key]
val = timings[key]
# Share of the total as a whole-number percentage, capped at 100.
pct = min(int((val / total) * 100), 100)
st.markdown(f"""
""", unsafe_allow_html=True)
else:
st.info("No timing data available.")
# Raw JSON tab: dumps the analysis dict and the pipeline result, and offers
# the full result as a JSON download.
# NOTE(review): markup literals in this block are kept verbatim as they appear
# here (including line breaks inside quotes) — confirm against the deployed file.
# NOTE(review): json.dumps raises TypeError on values that are not JSON
# serialisable — verify that `analysis` and `result` hold only plain Python types.
with dev_tab3:
st.markdown('Raw Analysis JSON
', unsafe_allow_html=True)
# ensure_ascii=False emits non-ASCII characters as-is instead of \u escapes.
st.code(json.dumps(analysis, indent=2, ensure_ascii=False), language="json")
st.markdown('Full Pipeline Response
', unsafe_allow_html=True)
# The on-screen dump omits the "soap_report" entry; the download below uses
# the unfiltered result.
display_result = {k: v for k, v in result.items() if k != "soap_report"}
st.code(json.dumps(display_result, indent=2, ensure_ascii=False), language="json")
st.download_button(
"Download Full JSON",
data=json.dumps(result, indent=2, ensure_ascii=False),
file_name="tabeebai_debug.json",
mime="application/json"
)
# RAG debug tab: lists each retrieved chunk with its score and text, then
# dumps the disease-lookup matches and any match detail strings.
# NOTE(review): markup literals in this block are kept verbatim as they appear
# here (including line breaks inside quotes) — confirm against the deployed file.
with dev_tab4:
st.markdown('RAG Retrieval Debug
', unsafe_allow_html=True)
if retrieved:
# enumerate supplies a zero-based index; the label shows it as 1-based.
for i, chunk in enumerate(retrieved):
st.markdown(f"""
Chunk {i+1}: {chunk.get("title", "Untitled")}
score: {chunk.get("score", 0):.4f}
{chunk.get("text", "")}
""", unsafe_allow_html=True)
else:
st.info("No RAG chunks retrieved.")
st.markdown("
", unsafe_allow_html=True)
st.markdown('Disease Lookup Debug
', unsafe_allow_html=True)
if dataset_matches:
# NOTE(review): unlike the other json.dumps calls in this file, this one does
# not pass ensure_ascii=False, so non-ASCII text would be shown as \u escapes.
st.code(json.dumps(dataset_matches, indent=2), language="json")
else:
st.info("No disease matches.")
# Optional per-match detail strings supplied by the analysis result.
detail = analysis.get("dataset_match_detail", [])
if detail:
st.markdown('Match Detail Strings
', unsafe_allow_html=True)
for item in detail:
st.markdown(f'{item}
', unsafe_allow_html=True)
# Models tab: lists the models recorded in the pipeline result and dumps the
# system health report.
# NOTE(review): markup literals in this block are kept verbatim as they appear
# here (including line breaks inside quotes) — confirm against the deployed file.
with dev_tab5:
st.markdown('Models Used
', unsafe_allow_html=True)
# Mapping of role -> model name; empty when the result carries no model info.
models = result.get("models_used", {})
if models:
for role, model_name in models.items():
# Underscores in the role key are shown as spaces.
st.markdown(f"""
{role.replace("_"," ")}
{model_name}
""", unsafe_allow_html=True)
else:
st.info("Model info not available.")
st.markdown("
", unsafe_allow_html=True)
st.markdown('System Health
', unsafe_allow_html=True)
# get_system_health is defined elsewhere in this file; its return value is
# serialised as JSON for display.
st.code(json.dumps(get_system_health(), indent=2), language="json")