# MediChatAI / src/streamlit_app.py
# Author: Eslam Waleed
# Initial Release: MediChat AI Dashboard (commit a8ee124)
import streamlit as st
import os
import json
import numpy as np
import easyocr
from PIL import Image
from huggingface_hub import InferenceClient
import re
from collections import Counter
import math
# Try to import pypdf for PDF support
try:
    from pypdf import PdfReader
except ImportError:
    # pypdf not installed: PDF uploads will be rejected with a helpful message.
    PDF_SUPPORT = False
else:
    PDF_SUPPORT = True
# --- 1. PAGE CONFIGURATION & THEME ---
st.set_page_config(page_title="MediChat AI", layout="wide", page_icon="πŸ₯")
# Custom CSS for Professional Medical Dashboard
# NOTE(review): the selectors below target Streamlit's data-testid hooks,
# which are not a stable public API -- re-verify after Streamlit upgrades.
# unsafe_allow_html is required for the <style> injection.
st.markdown("""
<style>
.stApp { background-color: #f8f9fa; color: #2c3e50; }
h1, h2, h3 { color: #1b5e20 !important; font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; font-weight: 600; }
label[data-testid="stWidgetLabel"] { color: #2c3e50 !important; font-weight: 600; }
[data-testid="stFileUploader"] { color: #2c3e50 !important; }
[data-testid="stFileUploader"] * { color: #2c3e50 !important; }
[data-testid="stFileUploader"] small { color: #555555 !important; }
.medical-card { background-color: white; padding: 20px; border-radius: 12px; box-shadow: 0 4px 6px rgba(0,0,0,0.05); margin-bottom: 20px; border: 1px solid #e0e0e0; }
.diagnosis { border-left: 6px solid #1976d2; }
.advice { border-left: 6px solid #2e7d32; }
.warning { border-left: 6px solid #d32f2f; }
.card-label { color: #7f8c8d; font-size: 0.85rem; text-transform: uppercase; letter-spacing: 1px; margin-bottom: 8px; font-weight: bold; }
.card-content { color: #2c3e50; font-size: 1.1rem; line-height: 1.6; }
div[data-testid="stFileUploader"] { background-color: white; border: 2px dashed #a5d6a7; border-radius: 12px; padding: 30px; }
div.stButton > button:first-child { background-color: #2e7d32; color: white; border-radius: 8px; border: none; padding: 0.6rem 1.2rem; font-size: 1rem; font-weight: 600; box-shadow: 0 2px 4px rgba(0,0,0,0.1); transition: all 0.2s ease; }
div.stButton > button:first-child:hover { background-color: #1b5e20; box-shadow: 0 4px 8px rgba(0,0,0,0.2); transform: translateY(-1px); }
div[data-testid="stChatMessage"] { background-color: white; border: 1px solid #e0e0e0; border-radius: 12px; padding: 15px; margin-bottom: 10px; box-shadow: 0 2px 4px rgba(0,0,0,0.02); }
div[data-testid="stChatMessage"] p, div[data-testid="stChatMessage"] div { color: #2c3e50 !important; }
div[data-testid="stChatMessage"][data-testid*="user"] { background-color: #f1f8e9; }
</style>
""", unsafe_allow_html=True)
# --- 2. SECURE BACKGROUND SETUP ---
# The token is read from Hugging Face Settings -> Secrets; it never appears in the UI.
# An empty string means "not configured", which the sidebar reports below.
api_key = os.getenv("HUGGINGFACEHUB_API_TOKEN", "")
# --- 3. LOGIC FUNCTIONS ---
@st.cache_resource
def load_ocr_reader():
    """Build and cache a CPU-only English EasyOCR reader.

    The reader is expensive to construct (model download/load), so
    st.cache_resource keeps a single instance across Streamlit reruns.
    """
    return easyocr.Reader(['en'], gpu=False)
def extract_text(uploaded_file):
    """Return the raw text content of an uploaded medical record.

    Images (jpeg/png) are run through EasyOCR; PDFs are read with pypdf.
    Unknown file types produce an empty string.  Any processing failure is
    reported as a human-readable "Error ..." string rather than raised.
    """
    try:
        kind = uploaded_file.type
        if kind in ("image/jpeg", "image/png", "image/jpg"):
            # OCR path: easyocr expects a numpy array, not a PIL image.
            pixels = np.array(Image.open(uploaded_file))
            fragments = load_ocr_reader().readtext(pixels, detail=0)
            return " ".join(fragments)
        if kind == "application/pdf":
            if not PDF_SUPPORT:
                return "Error: PDF support requires 'pypdf'."
            pdf = PdfReader(uploaded_file)
            # page.extract_text() may return None for image-only pages.
            return "".join(page.extract_text() or "" for page in pdf.pages)
    except Exception as e:
        # Never raise into the Streamlit UI -- surface the failure as text.
        return f"Error processing file: {e}"
    return ""
def _salvage_fields(clean):
    """Regex-based salvage of the expected JSON keys from malformed LLM output.

    Returns a populated result dict when a diagnosis hypothesis could be
    recovered from the raw text, otherwise None.
    """
    fallback = {
        "severity": "UNKNOWN",
        "diagnosis_hypothesis": "Not found",
        "abnormal_values": [],
        "medical_advice": "Not found"
    }
    sev_match = re.search(r'"severity":\s*"([^"]+)"', clean, re.IGNORECASE)
    diag_match = re.search(r'"diagnosis_hypothesis":\s*"([^"]+)"', clean, re.IGNORECASE)
    advice_match = re.search(r'"medical_advice":\s*"([^"]+)"', clean, re.IGNORECASE)
    if sev_match:
        fallback["severity"] = sev_match.group(1)
    if diag_match:
        fallback["diagnosis_hypothesis"] = diag_match.group(1)
    if advice_match:
        fallback["medical_advice"] = advice_match.group(1)
    abnormal_match = re.search(r'"abnormal_values":\s*\[(.*?)\]', clean, re.DOTALL)
    if abnormal_match:
        fallback["abnormal_values"] = re.findall(r'"([^"]+)"', abnormal_match.group(1))
    # Only useful if we at least recovered a diagnosis.
    return fallback if fallback["diagnosis_hypothesis"] != "Not found" else None

def analyze_report(text, client):
    """Run the LLM analysis over extracted report text.

    Parameters:
        text: raw text pulled from the uploaded document ("" aborts early).
        client: a huggingface_hub.InferenceClient (or compatible object
            exposing chat_completion(...)).

    Returns a dict with keys "severity", "diagnosis_hypothesis",
    "abnormal_values" (list of strings) and "medical_advice".  On API
    failure returns {"error": <message>}; when the model output cannot be
    parsed at all, returns a placeholder dict with a raw-output preview.
    """
    if not text:
        return {"error": "No text extracted"}
    system_prompt = """
    Analyze this medical report. Return valid JSON only.
    Keys: "severity" (CRITICAL, MODERATE, NORMAL), "diagnosis_hypothesis", "abnormal_values" (list of strings), "medical_advice".
    IMPORTANT INSTRUCTION:
    For "abnormal_values", list EVERY abnormal finding found in the document.
    FORMAT: "Condition/Test Name: Value (Reference Range if available)" or "Specific Finding".
    EXAMPLE: ["Hemoglobin: 7.2 g/dL (Low)", "ECG: Sinus Tachycardia", "Potassium: 5.8 mmol/L (High)"].
    Do not summarize vaguely. Be exact and exhaustive.
    """
    try:
        response = client.chat_completion(
            messages=[
                {"role": "system", "content": system_prompt},
                # Cap the prompt to keep within context limits.
                {"role": "user", "content": f"REPORT: {text[:3000]}"}
            ],
            model="Qwen/Qwen2.5-7B-Instruct",  # Swap to "aaditya/Llama3-OpenBioLLM-8B" here if you want to test it!
            max_tokens=1500,
            temperature=0.1  # near-deterministic output for structured extraction
        )
        raw = response.choices[0].message.content
        # Strip markdown code fences the model often wraps JSON in.
        clean = raw.replace("```json", "").replace("```", "").strip()
        try:
            return json.loads(clean)
        except json.JSONDecodeError:
            pass
        # Second chance: parse just the outermost {...} span -- models
        # frequently add prose around the JSON payload.
        start = clean.find('{')
        end = clean.rfind('}') + 1
        if start != -1 and end > start:
            try:
                return json.loads(clean[start:end])
            except json.JSONDecodeError:  # was a bare `except:` -- keep it narrow
                pass
        salvaged = _salvage_fields(clean)
        if salvaged is not None:
            return salvaged
        return {
            "severity": "UNKNOWN",
            "diagnosis_hypothesis": "AI Analysis Error: JSON formatting failed.",
            "abnormal_values": ["Raw output could not be parsed"],
            "medical_advice": f"Raw Output Preview: {clean[:100]}..."
        }
    except Exception as e:
        # Network/auth/model errors surface as a structured error for the UI.
        return {"error": str(e)}
def simple_text_splitter(text, chunk_size=500, overlap=50):
    """Split text into fixed-size chunks with overlapping edges.

    Parameters:
        text: the string to split (empty input yields []).
        chunk_size: maximum characters per chunk; must be positive.
        overlap: characters shared between consecutive chunks; must satisfy
            0 <= overlap < chunk_size.

    Returns a list of chunk strings covering the whole input.

    Raises:
        ValueError: on a non-positive chunk_size or an overlap that would
            make the step non-positive (which previously looped forever)
            or negative (which silently skipped characters).
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    if not 0 <= overlap < chunk_size:
        raise ValueError("overlap must satisfy 0 <= overlap < chunk_size")
    step = chunk_size - overlap
    # Slicing clamps at the end of the string, so no explicit min() needed.
    return [text[i:i + chunk_size] for i in range(0, len(text), step)]
def text_to_vector(text):
    """Return a bag-of-words Counter over lowercase word tokens of *text*."""
    return Counter(re.findall(r'\w+', text.lower()))
def get_cosine_similarity(vec1, vec2):
    """Cosine similarity between two sparse word-count vectors (Counters).

    Returns 0.0 when either vector is empty (zero norm).
    """
    shared = set(vec1.keys()) & set(vec2.keys())
    dot = sum(vec1[word] * vec2[word] for word in shared)
    norm1 = math.sqrt(sum(count ** 2 for count in vec1.values()))
    norm2 = math.sqrt(sum(count ** 2 for count in vec2.values()))
    denominator = norm1 * norm2
    return float(dot) / denominator if denominator else 0.0
def ask_chatbot(question, chunks, client):
    """Answer a user question grounded in the 3 report chunks most similar to it.

    Similarity is plain bag-of-words cosine; the top chunks become the
    "REPORT CONTEXT" for the LLM.  Returns the model's reply, or an
    "Error: ..." string if the API call fails.
    """
    question_vec = text_to_vector(question)
    # Rank chunk indices by similarity; Python's sort is stable, so ties
    # keep their original document order.
    ranked = sorted(
        range(len(chunks)),
        key=lambda idx: get_cosine_similarity(question_vec, text_to_vector(chunks[idx])),
        reverse=True,
    )
    context = "\n...\n".join(chunks[idx] for idx in ranked[:3])
    sys_p = """
    You are a Medical Assistant. Your job is to answer the user's question using EITHER the report data OR general medical knowledge.
    INSTRUCTIONS:
    1. Check the "REPORT CONTEXT". Does it have the answer?
    - YES -> Answer strictly based on the report. Start with "Based on your report..."
    - NO -> Do NOT say "I don't know." Instead, switch to GENERAL ADVICE MODE. Start with "Your report doesn't specify this, but generally..." and provide ONE concise sentence of standard medical advice for the conditions found in the text.
    CONSTRAINT: Keep your answer to ONE or TWO sentences max.
    """
    user_p = f"REPORT CONTEXT:\n{context}\n\nUSER QUESTION:\n{question}"
    try:
        reply = client.chat_completion(
            messages=[{"role": "system", "content": sys_p}, {"role": "user", "content": user_p}],
            model="Qwen/Qwen2.5-7B-Instruct",  # Swap to "aaditya/Llama3-OpenBioLLM-8B" here if you want to test it!
            max_tokens=600,
            temperature=0.7,
        )
        return reply.choices[0].message.content
    except Exception as e:
        return f"Error: {e}"
# --- 5. MAIN UI LAYOUT ---
st.title("πŸ₯ MediChat AI: Analysis Dashboard")
# Put a tiny, clean status indicator in the sidebar instead of a giant drop-down menu
with st.sidebar:
    if api_key:
        # Token found in the environment: the analysis flow below is enabled.
        st.success("βœ… Secure AI Connection Ready")
        st.caption("Engine: EasyOCR + Qwen 2.5 AI")
    else:
        st.error("❌ Setup Required: Please add your Hugging Face Token to the Space Settings -> Secrets tab.")
st.markdown("### Upload Report for Instant AI Diagnosis")
# Only these four extensions are accepted; extract_text() dispatches on MIME type.
uploaded_file = st.file_uploader("Upload Medical Record (PDF/Image)", type=["jpg", "png", "jpeg", "pdf"])
# The whole dashboard (analysis button, results, chat) only renders once a
# file is uploaded AND the API token is configured.
if uploaded_file and api_key:
    client = InferenceClient(token=api_key)
    # Session-state slots survive Streamlit reruns (every widget interaction
    # re-executes this script top-to-bottom).
    if "analysis" not in st.session_state: st.session_state.analysis = None
    if "chunks" not in st.session_state: st.session_state.chunks = []
    if st.button("Run Analysis ⚑"):
        # Clear previous results before re-analyzing the (possibly new) file.
        st.session_state.analysis = None
        st.session_state.chunks = []
        with st.spinner("Processing medical data..."):
            raw_text = extract_text(uploaded_file)
            # Heuristic quality gates: <50 chars hints at a poor scan;
            # <5 chars is treated as unreadable and aborts the analysis.
            if len(raw_text) < 50:
                st.warning("⚠️ Low text quality detected. If this is a scanned PDF, please convert it to an Image (JPG/PNG) for accurate results.")
            if len(raw_text) < 5:
                st.error("Could not read text. Please try a clearer file or convert PDF to Image.")
            else:
                st.session_state.analysis = analyze_report(raw_text, client)
                st.session_state.chunks = simple_text_splitter(raw_text)
    # --- RESULTS DASHBOARD ---
    if st.session_state.analysis:
        res = st.session_state.analysis
        if "error" in res:
            st.error(f"🚨 AI API Error: {res['error']}")
        else:
            # Map the model's severity label to a traffic-light color/icon.
            # Substring match tolerates variants like "CRITICAL CONDITION".
            sev = res.get("severity", "UNKNOWN").upper()
            if "CRITICAL" in sev:
                status_color = "#d32f2f"
                status_icon = "🚨"
            elif "MODERATE" in sev:
                status_color = "#f57c00"
                status_icon = "⚠️"
            else:
                status_color = "#2e7d32"
                status_icon = "βœ…"
            st.markdown("---")
            st.markdown(f"<h3 style='color:{status_color}!important'>{status_icon} Analysis Result: {sev}</h3>", unsafe_allow_html=True)
            # Two side-by-side cards: diagnosis (left) and advice (right).
            col1, col2 = st.columns(2)
            with col1:
                st.markdown(f"""
                <div class="medical-card diagnosis">
                <div class="card-label">Primary Diagnosis</div>
                <div class="card-content">{res.get('diagnosis_hypothesis', 'Analysis pending...')}</div>
                </div>
                """, unsafe_allow_html=True)
            with col2:
                st.markdown(f"""
                <div class="medical-card advice">
                <div class="card-label">Recommended Action</div>
                <div class="card-content">{res.get('medical_advice', 'Consult a doctor.')}</div>
                </div>
                """, unsafe_allow_html=True)
            # Abnormal findings list, rendered as a red "warning" card.
            if "abnormal_values" in res and isinstance(res["abnormal_values"], list) and len(res["abnormal_values"]) > 0:
                findings_html = "".join([f"<li style='margin-bottom:5px;'>{val}</li>" for val in res["abnormal_values"]])
                st.markdown(f"""
                <div class="medical-card warning">
                <div class="card-label" style="color:#d32f2f;">⚠️ Critical Findings</div>
                <ul style="margin-bottom:0; padding-left:20px;">
                {findings_html}
                </ul>
                </div>
                """, unsafe_allow_html=True)
            else:
                st.success("No specific abnormal values detected in the extraction.")
    # --- CHAT SECTION ---
    # Only available once chunks exist, i.e. after a successful analysis run.
    if st.session_state.chunks:
        st.markdown("---")
        st.subheader("πŸ’¬ Doctor's Companion Chat")
        # Chat history also lives in session state so it persists across reruns.
        if "messages" not in st.session_state: st.session_state.messages = []
        for msg in st.session_state.messages:
            with st.chat_message(msg["role"]): st.write(msg["content"])
        if prompt := st.chat_input("Ask specific questions about this patient..."):
            st.session_state.messages.append({"role": "user", "content": prompt})
            with st.chat_message("user"): st.write(prompt)
            with st.chat_message("assistant"):
                with st.spinner("Reviewing case notes..."):
                    ans = ask_chatbot(prompt, st.session_state.chunks, client)
                    st.write(ans)
                    st.session_state.messages.append({"role": "assistant", "content": ans})