Spaces:
Sleeping
Sleeping
import html
import json
import math
import os
import re
from collections import Counter

import easyocr
import numpy as np
import streamlit as st
from huggingface_hub import InferenceClient
from PIL import Image

# Try to import pypdf for PDF support
try:
    from pypdf import PdfReader
    PDF_SUPPORT = True
except ImportError:
    PDF_SUPPORT = False
# --- 1. PAGE CONFIGURATION & THEME ---
st.set_page_config(page_title="MediChat AI", layout="wide", page_icon="🏥")
# Custom CSS for Professional Medical Dashboard.
# The user-message rule highlights chat bubbles that contain the user avatar.
# The previous selector required data-testid to equal "stChatMessage" AND
# contain "user" at the same time, so it could never match.
# NOTE(review): assumes Streamlit tags the user avatar with
# data-testid="stChatMessageAvatarUser" (true in recent releases) — confirm for
# the pinned Streamlit version.
st.markdown("""
<style>
.stApp { background-color: #f8f9fa; color: #2c3e50; }
h1, h2, h3 { color: #1b5e20 !important; font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; font-weight: 600; }
label[data-testid="stWidgetLabel"] { color: #2c3e50 !important; font-weight: 600; }
[data-testid="stFileUploader"] { color: #2c3e50 !important; }
[data-testid="stFileUploader"] * { color: #2c3e50 !important; }
[data-testid="stFileUploader"] small { color: #555555 !important; }
.medical-card { background-color: white; padding: 20px; border-radius: 12px; box-shadow: 0 4px 6px rgba(0,0,0,0.05); margin-bottom: 20px; border: 1px solid #e0e0e0; }
.diagnosis { border-left: 6px solid #1976d2; }
.advice { border-left: 6px solid #2e7d32; }
.warning { border-left: 6px solid #d32f2f; }
.card-label { color: #7f8c8d; font-size: 0.85rem; text-transform: uppercase; letter-spacing: 1px; margin-bottom: 8px; font-weight: bold; }
.card-content { color: #2c3e50; font-size: 1.1rem; line-height: 1.6; }
div[data-testid="stFileUploader"] { background-color: white; border: 2px dashed #a5d6a7; border-radius: 12px; padding: 30px; }
div.stButton > button:first-child { background-color: #2e7d32; color: white; border-radius: 8px; border: none; padding: 0.6rem 1.2rem; font-size: 1rem; font-weight: 600; box-shadow: 0 2px 4px rgba(0,0,0,0.1); transition: all 0.2s ease; }
div.stButton > button:first-child:hover { background-color: #1b5e20; box-shadow: 0 4px 8px rgba(0,0,0,0.2); transform: translateY(-1px); }
div[data-testid="stChatMessage"] { background-color: white; border: 1px solid #e0e0e0; border-radius: 12px; padding: 15px; margin-bottom: 10px; box-shadow: 0 2px 4px rgba(0,0,0,0.02); }
div[data-testid="stChatMessage"] p, div[data-testid="stChatMessage"] div { color: #2c3e50 !important; }
div[data-testid="stChatMessage"]:has([data-testid="stChatMessageAvatarUser"]) { background-color: #f1f8e9; }
</style>
""", unsafe_allow_html=True)
# --- 2. SECURE BACKGROUND SETUP ---
# The Hugging Face token is taken from the process environment (a Space fills
# it from Settings -> Secrets), so no token-entry widget is needed.
# An empty string means the token is not configured.
api_key = os.getenv("HUGGINGFACEHUB_API_TOKEN", "")
# --- 3. LOGIC FUNCTIONS ---
@st.cache_resource
def load_ocr_reader():
    """Return an English-only, CPU-only EasyOCR reader, built once and reused.

    Constructing an ``easyocr.Reader`` loads (and, on first use, may download)
    its models. Previously a new reader was built for every uploaded image.
    ``st.cache_resource`` keeps a single instance alive across reruns instead.

    NOTE(review): the cached reader is shared across user sessions, which
    Streamlit runs on separate threads. Confirm that concurrent ``readtext``
    calls on one reader are safe.
    """
    return easyocr.Reader(['en'], gpu=False)
def extract_text(uploaded_file):
    """Extract plain text from an uploaded image or PDF.

    JPEG/PNG images are run through EasyOCR. PDFs are read page by page with
    pypdf when it is installed. Any other MIME type yields an empty string.

    Args:
        uploaded_file: A file-like object with a ``type`` attribute holding its
            MIME type (such as the object returned by ``st.file_uploader``).

    Returns:
        The extracted text. If the file cannot be processed, returns a message
        starting with ``"Error"`` instead of raising.
    """
    text = ""
    try:
        # The file object may already have been read (e.g. on an earlier
        # rerun); rewind so extraction always starts at the beginning.
        if hasattr(uploaded_file, "seek"):
            uploaded_file.seek(0)
        if uploaded_file.type in ("image/jpeg", "image/png", "image/jpg"):
            with Image.open(uploaded_file) as image:
                # Palette ("P"), RGBA and greyscale PNGs would otherwise reach
                # the OCR engine as palette indices or 4-channel arrays.
                # Normalise to 3-channel RGB pixels.
                image_np = np.array(image.convert("RGB"))
            result = load_ocr_reader().readtext(image_np, detail=0)
            text = " ".join(result)
        elif uploaded_file.type == "application/pdf":
            if not PDF_SUPPORT:
                return "Error: PDF support requires 'pypdf'."
            pdf = PdfReader(uploaded_file)
            # Separate pages with a newline so the last word of one page is
            # not glued to the first word of the next.
            text = "\n".join(page.extract_text() or "" for page in pdf.pages)
    except Exception as e:
        return f"Error processing file: {e}"
    return text
_ANALYSIS_SYSTEM_PROMPT = """
Analyze this medical report. Return valid JSON only.
Keys: "severity" (CRITICAL, MODERATE, NORMAL), "diagnosis_hypothesis", "abnormal_values" (list of strings), "medical_advice".
IMPORTANT INSTRUCTION:
For "abnormal_values", list EVERY abnormal finding found in the document.
FORMAT: "Condition/Test Name: Value (Reference Range if available)" or "Specific Finding".
EXAMPLE: ["Hemoglobin: 7.2 g/dL (Low)", "ECG: Sinus Tachycardia", "Potassium: 5.8 mmol/L (High)"].
Do not summarize vaguely. Be exact and exhaustive.
"""


def _loads_dict(candidate):
    """Parse *candidate* as JSON; return it only if it is a JSON object, else None."""
    try:
        parsed = json.loads(candidate)
    except json.JSONDecodeError:
        return None
    # A top-level list/string/number is not a usable analysis; the dashboard
    # calls .get() on the result.
    return parsed if isinstance(parsed, dict) else None


def _regex_salvage(clean):
    """Recover individual fields from malformed JSON-like text, or return None.

    Returns None unless at least ``diagnosis_hypothesis`` could be recovered.
    """
    fallback = {
        "severity": "UNKNOWN",
        "diagnosis_hypothesis": "Not found",
        "abnormal_values": [],
        "medical_advice": "Not found",
    }
    for key in ("severity", "diagnosis_hypothesis", "medical_advice"):
        match = re.search(rf'"{key}":\s*"([^"]+)"', clean, re.IGNORECASE)
        if match:
            fallback[key] = match.group(1)
    abnormal_match = re.search(r'"abnormal_values":\s*\[(.*?)\]', clean, re.DOTALL)
    if abnormal_match:
        fallback["abnormal_values"] = re.findall(r'"([^"]+)"', abnormal_match.group(1))
    if fallback["diagnosis_hypothesis"] != "Not found":
        return fallback
    return None


def _parse_analysis_response(raw):
    """Turn the model's raw reply into the analysis dict rendered by the dashboard.

    Tries, in order:
    1. the whole reply, minus Markdown code fences, as JSON;
    2. the outermost ``{...}`` span as JSON;
    3. regex extraction of individual fields.
    Falls back to an UNKNOWN-severity error dict that previews the raw output.
    """
    clean = (raw or "").replace("```json", "").replace("```", "").strip()
    parsed = _loads_dict(clean)
    if parsed is None:
        start = clean.find("{")
        end = clean.rfind("}") + 1
        if start != -1 and end > start:
            parsed = _loads_dict(clean[start:end])
    if parsed is None:
        parsed = _regex_salvage(clean)
    if parsed is not None:
        return parsed
    return {
        "severity": "UNKNOWN",
        "diagnosis_hypothesis": "AI Analysis Error: JSON formatting failed.",
        "abnormal_values": ["Raw output could not be parsed"],
        "medical_advice": f"Raw Output Preview: {clean[:100]}...",
    }


def analyze_report(text, client, *, model="Qwen/Qwen2.5-7B-Instruct", max_chars=3000):
    """Ask the LLM for a structured summary of a medical report.

    Args:
        text: Extracted report text. Only the first ``max_chars`` characters
            are sent.
        client: An object exposing ``chat_completion(messages=..., model=...,
            max_tokens=..., temperature=...)``, e.g. ``InferenceClient``.
        model: Model id to query (e.g. "aaditya/Llama3-OpenBioLLM-8B" to try
            a biomedical model).
        max_chars: Truncation limit for the report text.

    Returns:
        A dict with ``severity``, ``diagnosis_hypothesis``, ``abnormal_values``
        and ``medical_advice`` keys. Returns ``{"error": message}`` when *text*
        is empty or the request fails. Never raises.
    """
    if not text:
        return {"error": "No text extracted"}
    try:
        response = client.chat_completion(
            messages=[
                {"role": "system", "content": _ANALYSIS_SYSTEM_PROMPT},
                {"role": "user", "content": f"REPORT: {text[:max_chars]}"},
            ],
            model=model,
            max_tokens=1500,
            temperature=0.1,
        )
        return _parse_analysis_response(response.choices[0].message.content)
    except Exception as e:
        # Remote-service boundary: report the failure to the UI rather than
        # crashing the Streamlit script.
        return {"error": str(e)}
def simple_text_splitter(text, chunk_size=500, overlap=50):
    """Split *text* into overlapping fixed-width character windows.

    Consecutive chunks share ``overlap`` characters, so text near a boundary
    appears in both neighbouring chunks. Splitting stops as soon as a chunk
    reaches the end of *text*. No trailing chunk is emitted that merely
    repeats the tail of the previous one.

    Args:
        text: The string to split.
        chunk_size: Maximum length of each chunk; must be positive.
        overlap: Characters shared between consecutive chunks; must satisfy
            ``0 <= overlap < chunk_size``.

    Returns:
        A list of chunks; empty for empty *text*.

    Raises:
        ValueError: If ``chunk_size``/``overlap`` are out of range. Previously
            ``overlap >= chunk_size`` made the loop never advance.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    if not 0 <= overlap < chunk_size:
        raise ValueError("overlap must satisfy 0 <= overlap < chunk_size")
    step = chunk_size - overlap
    chunks = []
    for start in range(0, len(text), step):
        end = start + chunk_size
        chunks.append(text[start:end])
        if end >= len(text):
            break
    return chunks
_WORD_RE = re.compile(r'\w+')


def text_to_vector(text):
    """Return a bag-of-words ``Counter`` of the lower-cased ``\\w+`` tokens in *text*."""
    tokens = _WORD_RE.findall(text.lower())
    return Counter(tokens)
def get_cosine_similarity(vec1, vec2):
    """Cosine similarity between two term-count mappings.

    Args:
        vec1, vec2: Mappings from token to count (e.g. from ``text_to_vector``).

    Returns:
        ``dot(vec1, vec2) / (|vec1| * |vec2|)`` as a float, or ``0.0`` when
        either vector has zero magnitude.
    """
    shared = vec1.keys() & vec2.keys()
    dot = sum(vec1[token] * vec2[token] for token in shared)
    norm1 = math.sqrt(sum(vec1[token] ** 2 for token in vec1))
    norm2 = math.sqrt(sum(vec2[token] ** 2 for token in vec2))
    magnitude = norm1 * norm2
    if not magnitude:
        return 0.0
    return float(dot) / magnitude
def ask_chatbot(question, chunks, client):
    """Answer *question* about the report with a tiny bag-of-words retriever.

    Each chunk is scored by cosine similarity between its word counts and the
    question's. The three best-scoring chunks are joined into the prompt
    context; ties keep their original order.

    Returns:
        The model's reply text, or ``"Error: ..."`` if the request or response
        handling fails.
    """
    query = text_to_vector(question)
    # sorted() is stable even with reverse=True, so equal scores keep chunk order.
    ranked = sorted(
        range(len(chunks)),
        key=lambda idx: get_cosine_similarity(query, text_to_vector(chunks[idx])),
        reverse=True,
    )
    context = "\n...\n".join(chunks[idx] for idx in ranked[:3])

    system_prompt = """
You are a Medical Assistant. Your job is to answer the user's question using EITHER the report data OR general medical knowledge.
INSTRUCTIONS:
1. Check the "REPORT CONTEXT". Does it have the answer?
- YES -> Answer strictly based on the report. Start with "Based on your report..."
- NO -> Do NOT say "I don't know." Instead, switch to GENERAL ADVICE MODE. Start with "Your report doesn't specify this, but generally..." and provide ONE concise sentence of standard medical advice for the conditions found in the text.
CONSTRAINT: Keep your answer to ONE or TWO sentences max.
"""
    user_prompt = f"REPORT CONTEXT:\n{context}\n\nUSER QUESTION:\n{question}"
    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]

    try:
        reply = client.chat_completion(
            messages=conversation,
            model="Qwen/Qwen2.5-7B-Instruct",  # "aaditya/Llama3-OpenBioLLM-8B" is an alternative to try
            max_tokens=600,
            temperature=0.7,
        )
        return reply.choices[0].message.content
    except Exception as e:
        return f"Error: {e}"
# --- 5. MAIN UI LAYOUT ---
# Emoji below were mis-encoded (e.g. "π₯", "β") in the previous revision.
st.title("🏥 MediChat AI: Analysis Dashboard")
# The sidebar only shows a small connection-status indicator. The token comes
# from the environment (api_key), so there is no token-entry widget.
with st.sidebar:
    if api_key:
        st.success("✅ Secure AI Connection Ready")
        st.caption("Engine: EasyOCR + Qwen 2.5 AI")
    else:
        st.error("❌ Setup Required: Please add your Hugging Face Token to the Space Settings -> Secrets tab.")
# Prefixes of the messages extract_text() returns (instead of raising) on failure.
_EXTRACT_ERROR_PREFIXES = ("Error processing file:", "Error: PDF support requires")

# (severity substring, banner colour, icon), checked in order.
_SEVERITY_STYLES = (
    ("CRITICAL", "#d32f2f", "🚨"),
    ("MODERATE", "#f57c00", "⚠️"),
    ("NORMAL", "#2e7d32", "✅"),
)


def _esc(value):
    """HTML-escape model output before embedding it in unsafe_allow_html markup."""
    return html.escape(str(value))


def _reset_analysis_state():
    """Forget the stored analysis, retrieval chunks and chat history."""
    st.session_state.analysis = None
    st.session_state.chunks = []
    st.session_state.messages = []


def _run_analysis(uploaded_file, client):
    """Extract text from *uploaded_file*, analyse it and store the results in session state."""
    _reset_analysis_state()
    st.session_state.analyzed_file = (uploaded_file.name, uploaded_file.size)
    with st.spinner("Processing medical data..."):
        raw_text = extract_text(uploaded_file)
        # extract_text() reports failures as text. Never send those messages
        # to the model as though they were the report's contents.
        if raw_text.startswith(_EXTRACT_ERROR_PREFIXES):
            st.error(raw_text)
            return
        if len(raw_text) < 50:
            st.warning("⚠️ Low text quality detected. If this is a scanned PDF, please convert it to an Image (JPG/PNG) for accurate results.")
        if len(raw_text) < 5:
            st.error("Could not read text. Please try a clearer file or convert PDF to Image.")
        else:
            st.session_state.analysis = analyze_report(raw_text, client)
            st.session_state.chunks = simple_text_splitter(raw_text)


def _severity_style(sev):
    """Return (colour, icon) for the result banner of severity string *sev*."""
    for label, color, icon in _SEVERITY_STYLES:
        if label in sev:
            return color, icon
    # UNKNOWN (e.g. the model's JSON could not be parsed) must not look like an
    # all-clear, so it no longer falls through to the green NORMAL style.
    return "#607d8b", "❔"


def _render_card(css_class, label, content_html, label_style=""):
    """Render one dashboard card. *content_html* must already be escaped."""
    style_attr = f' style="{label_style}"' if label_style else ""
    st.markdown(
        f'<div class="medical-card {css_class}">'
        f'<div class="card-label"{style_attr}>{label}</div>'
        f"{content_html}"
        "</div>",
        unsafe_allow_html=True,
    )


def _render_results(res):
    """Render the severity banner, diagnosis/advice cards and abnormal findings."""
    if not isinstance(res, dict):
        st.error("🚨 AI API Error: unexpected analysis format.")
        return
    if "error" in res:
        st.error(f"🚨 AI API Error: {res['error']}")
        return
    # The model may omit severity or return null / non-string values.
    sev = str(res.get("severity") or "UNKNOWN").upper()
    status_color, status_icon = _severity_style(sev)
    st.markdown("---")
    st.markdown(
        f"<h3 style='color:{status_color}!important'>{status_icon} Analysis Result: {_esc(sev)}</h3>",
        unsafe_allow_html=True,
    )
    col1, col2 = st.columns(2)
    with col1:
        diagnosis = _esc(res.get("diagnosis_hypothesis", "Analysis pending..."))
        _render_card("diagnosis", "Primary Diagnosis", f'<div class="card-content">{diagnosis}</div>')
    with col2:
        advice = _esc(res.get("medical_advice", "Consult a doctor."))
        _render_card("advice", "Recommended Action", f'<div class="card-content">{advice}</div>')
    findings = res.get("abnormal_values")
    if isinstance(findings, list) and findings:
        items = "".join(f"<li style='margin-bottom:5px;'>{_esc(val)}</li>" for val in findings)
        _render_card(
            "warning",
            "⚠️ Critical Findings",
            f'<ul style="margin-bottom:0; padding-left:20px;">{items}</ul>',
            label_style="color:#d32f2f;",
        )
    else:
        st.success("No specific abnormal values detected in the extraction.")


def _render_chat(client):
    """Replay the chat history and answer a newly submitted question."""
    st.markdown("---")
    st.subheader("💬 Doctor's Companion Chat")
    if "messages" not in st.session_state:
        st.session_state.messages = []
    for msg in st.session_state.messages:
        with st.chat_message(msg["role"]):
            st.write(msg["content"])
    if prompt := st.chat_input("Ask specific questions about this patient..."):
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.write(prompt)
        with st.chat_message("assistant"):
            with st.spinner("Reviewing case notes..."):
                ans = ask_chatbot(prompt, st.session_state.chunks, client)
            st.write(ans)
        st.session_state.messages.append({"role": "assistant", "content": ans})


st.markdown("### Upload Report for Instant AI Diagnosis")
uploaded_file = st.file_uploader("Upload Medical Record (PDF/Image)", type=["jpg", "png", "jpeg", "pdf"])

if uploaded_file and api_key:
    client = InferenceClient(token=api_key)
    if "analysis" not in st.session_state:
        st.session_state.analysis = None
    if "chunks" not in st.session_state:
        st.session_state.chunks = []

    # Results and chat belong to the file they were computed from. Once a
    # different file is uploaded, drop them instead of showing stale output.
    # NOTE(review): (name, size) cannot tell apart two different files with the
    # same name and byte size.
    file_key = (uploaded_file.name, uploaded_file.size)
    analyzed = st.session_state.get("analyzed_file")
    if analyzed is not None and analyzed != file_key:
        _reset_analysis_state()
        st.session_state.analyzed_file = None

    if st.button("Run Analysis ⚡"):
        _run_analysis(uploaded_file, client)

    # --- RESULTS DASHBOARD ---
    if st.session_state.analysis:
        _render_results(st.session_state.analysis)

    # --- CHAT SECTION ---
    # Kept inside this branch because answering questions needs `client`.
    if st.session_state.chunks:
        _render_chat(client)