File size: 13,706 Bytes
a8ee124
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
import html
import json
import math
import os
import re
from collections import Counter

import easyocr
import numpy as np
import streamlit as st
from huggingface_hub import InferenceClient
from PIL import Image

# Optional dependency: PDF parsing is only available when pypdf is installed.
# PDF_SUPPORT records the outcome so callers can check it before using PdfReader.
PDF_SUPPORT = True
try:
    from pypdf import PdfReader
except ImportError:
    # pypdf is missing; PdfReader stays undefined and the flag is switched off.
    PDF_SUPPORT = False

# --- 1. PAGE CONFIGURATION & THEME ---
# NOTE: page_icon must be a real emoji (or an image path/URL). The previous value
# was a mis-decoded byte sequence of the hospital emoji, which Streamlit cannot
# interpret as an emoji.
st.set_page_config(page_title="MediChat AI", layout="wide", page_icon="🏥")

# Custom CSS for Professional Medical Dashboard.
# Injected as raw HTML; the string is static (no user data is interpolated),
# so unsafe_allow_html is safe here.
st.markdown("""
    <style>
    .stApp { background-color: #f8f9fa; color: #2c3e50; }
    h1, h2, h3 { color: #1b5e20 !important; font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; font-weight: 600; }
    label[data-testid="stWidgetLabel"] { color: #2c3e50 !important; font-weight: 600; }
    [data-testid="stFileUploader"] { color: #2c3e50 !important; }
    [data-testid="stFileUploader"] * { color: #2c3e50 !important; }
    [data-testid="stFileUploader"] small { color: #555555 !important; }
    .medical-card { background-color: white; padding: 20px; border-radius: 12px; box-shadow: 0 4px 6px rgba(0,0,0,0.05); margin-bottom: 20px; border: 1px solid #e0e0e0; }
    .diagnosis { border-left: 6px solid #1976d2; }
    .advice { border-left: 6px solid #2e7d32; }
    .warning { border-left: 6px solid #d32f2f; }
    .card-label { color: #7f8c8d; font-size: 0.85rem; text-transform: uppercase; letter-spacing: 1px; margin-bottom: 8px; font-weight: bold; }
    .card-content { color: #2c3e50; font-size: 1.1rem; line-height: 1.6; }
    div[data-testid="stFileUploader"] { background-color: white; border: 2px dashed #a5d6a7; border-radius: 12px; padding: 30px; }
    div.stButton > button:first-child { background-color: #2e7d32; color: white; border-radius: 8px; border: none; padding: 0.6rem 1.2rem; font-size: 1rem; font-weight: 600; box-shadow: 0 2px 4px rgba(0,0,0,0.1); transition: all 0.2s ease; }
    div.stButton > button:first-child:hover { background-color: #1b5e20; box-shadow: 0 4px 8px rgba(0,0,0,0.2); transform: translateY(-1px); }
    div[data-testid="stChatMessage"] { background-color: white; border: 1px solid #e0e0e0; border-radius: 12px; padding: 15px; margin-bottom: 10px; box-shadow: 0 2px 4px rgba(0,0,0,0.02); }
    div[data-testid="stChatMessage"] p, div[data-testid="stChatMessage"] div { color: #2c3e50 !important; }
    div[data-testid="stChatMessage"][data-testid*="user"] { background-color: #f1f8e9; }
    </style>
""", unsafe_allow_html=True)

# --- 2. SECURE BACKGROUND SETUP ---
# The Hugging Face token is read from the process environment (on a Space it is
# configured under Settings -> Secrets), so it is never typed into the UI.
# An empty string means the token has not been configured.
api_key = os.getenv("HUGGINGFACEHUB_API_TOKEN", "")

# --- 3. LOGIC FUNCTIONS ---
@st.cache_resource
def load_ocr_reader():
    """Create the English EasyOCR reader in CPU mode.

    The function is wrapped in ``st.cache_resource`` so that the reader is
    built once and reused instead of being re-created on every call.
    """
    languages = ['en']
    ocr_reader = easyocr.Reader(languages, gpu=False)
    return ocr_reader

def extract_text(uploaded_file):
    """Extract plain text from an uploaded image or PDF.

    Images (JPEG/PNG) are OCR'd with the reader from ``load_ocr_reader``;
    PDFs are read through pypdf's text extraction. Any other MIME type yields
    an empty string.

    Args:
        uploaded_file: File-like object exposing a ``type`` attribute holding
            the MIME type (e.g. the object returned by ``st.file_uploader``).

    Returns:
        The extracted text. Failures are reported in the return value rather
        than raised: the string then starts with ``"Error processing file:"``
        or ``"Error: PDF support"``.
    """
    try:
        if uploaded_file.type in ("image/jpeg", "image/png", "image/jpg"):
            # Normalise to 3-channel RGB before OCR. Without this, palette ("P")
            # images become an array of palette indices and CMYK / 16-bit modes
            # produce arrays the OCR engine does not interpret as intended.
            image = Image.open(uploaded_file).convert("RGB")
            fragments = load_ocr_reader().readtext(np.array(image), detail=0)
            return " ".join(fragments)

        if uploaded_file.type == "application/pdf":
            if not PDF_SUPPORT:
                return "Error: PDF support requires 'pypdf'."
            page_texts = [
                page.extract_text() or "" for page in PdfReader(uploaded_file).pages
            ]
            # Join with a newline so the last word of one page does not fuse
            # with the first word of the next; pages without a text layer are
            # skipped so an image-only PDF still yields an empty result.
            return "\n".join(t for t in page_texts if t.strip())
    except Exception as e:
        # Deliberately broad: any decoding/OCR/PDF failure is surfaced to the
        # caller as a message instead of crashing the app.
        return f"Error processing file: {e}"
    return ""

def _coerce_json_object(candidate):
    """Parse *candidate* as JSON and return it only if it is a JSON object (dict)."""
    try:
        parsed = json.loads(candidate)
    except (json.JSONDecodeError, TypeError):
        return None
    # Lists, strings and numbers are valid JSON but not the shape callers expect.
    return parsed if isinstance(parsed, dict) else None


def _salvage_fields_with_regex(clean):
    """Recover the expected keys from malformed JSON text on a best-effort basis."""
    fallback = {
        "severity": "UNKNOWN",
        "diagnosis_hypothesis": "Not found",
        "abnormal_values": [],
        "medical_advice": "Not found"
    }

    sev_match = re.search(r'"severity":\s*"([^"]+)"', clean, re.IGNORECASE)
    diag_match = re.search(r'"diagnosis_hypothesis":\s*"([^"]+)"', clean, re.IGNORECASE)
    advice_match = re.search(r'"medical_advice":\s*"([^"]+)"', clean, re.IGNORECASE)

    if sev_match:
        fallback["severity"] = sev_match.group(1)
    if diag_match:
        fallback["diagnosis_hypothesis"] = diag_match.group(1)
    if advice_match:
        fallback["medical_advice"] = advice_match.group(1)

    abnormal_match = re.search(r'"abnormal_values":\s*\[(.*?)\]', clean, re.DOTALL)
    if abnormal_match:
        fallback["abnormal_values"] = re.findall(r'"([^"]+)"', abnormal_match.group(1))

    return fallback


def analyze_report(text, client):
    """Ask the chat model for a structured analysis of a medical report.

    Args:
        text: Report text; only the first 3000 characters are sent.
        client: Object exposing ``chat_completion(messages=..., model=...,
            max_tokens=..., temperature=...)`` whose result provides
            ``choices[0].message.content``.

    Returns:
        Always a dict. On success it is the JSON object produced by the model
        (expected keys: ``severity``, ``diagnosis_hypothesis``,
        ``abnormal_values``, ``medical_advice``). If the reply is not a JSON
        object, the fields are salvaged with regular expressions; if that also
        fails, a placeholder dict with ``severity == "UNKNOWN"`` is returned.
        ``{"error": ...}`` is returned for empty input or when the request
        itself fails.
    """
    if not text or not text.strip():
        return {"error": "No text extracted"}

    system_prompt = """
    Analyze this medical report. Return valid JSON only.
    Keys: "severity" (CRITICAL, MODERATE, NORMAL), "diagnosis_hypothesis", "abnormal_values" (list of strings), "medical_advice".
    
    IMPORTANT INSTRUCTION:
    For "abnormal_values", list EVERY abnormal finding found in the document. 
    FORMAT: "Condition/Test Name: Value (Reference Range if available)" or "Specific Finding".
    EXAMPLE: ["Hemoglobin: 7.2 g/dL (Low)", "ECG: Sinus Tachycardia", "Potassium: 5.8 mmol/L (High)"].
    Do not summarize vaguely. Be exact and exhaustive.
    """

    try:
        response = client.chat_completion(
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": f"REPORT: {text[:3000]}"}
            ],
            model="Qwen/Qwen2.5-7B-Instruct",  # Swap to "aaditya/Llama3-OpenBioLLM-8B" here if you want to test it!
            max_tokens=1500,
            temperature=0.1
        )
        raw = response.choices[0].message.content
        # Models often wrap JSON in a Markdown code fence; drop the fence markers.
        clean = raw.replace("```json", "").replace("```", "").strip()

        # 1) The whole reply is a JSON object.
        parsed = _coerce_json_object(clean)
        if parsed is None:
            # 2) A JSON object embedded in surrounding prose or an outer array.
            start = clean.find('{')
            end = clean.rfind('}') + 1
            if start != -1 and end > start:
                parsed = _coerce_json_object(clean[start:end])
        if parsed is not None:
            return parsed

        # 3) Malformed JSON: pull out whatever fields can still be recognised.
        salvaged = _salvage_fields_with_regex(clean)
        if salvaged["diagnosis_hypothesis"] != "Not found":
            return salvaged

        return {
            "severity": "UNKNOWN",
            "diagnosis_hypothesis": "AI Analysis Error: JSON formatting failed.",
            "abnormal_values": ["Raw output could not be parsed"],
            "medical_advice": f"Raw Output Preview: {clean[:100]}..."
        }

    except Exception as e:
        # Network/API failures (or an unexpected response shape) are reported
        # to the caller through the "error" key instead of being raised.
        return {"error": str(e)}

def simple_text_splitter(text, chunk_size=500, overlap=50):
    """Split *text* into character windows of ``chunk_size`` that overlap.

    Consecutive chunks start ``chunk_size - overlap`` characters apart.
    Splitting stops as soon as a chunk reaches the end of the text, so no
    trailing chunk made up solely of already-covered characters is emitted.

    Args:
        text: String to split.
        chunk_size: Maximum number of characters per chunk; must be positive.
        overlap: Number of characters shared by neighbouring chunks; must be
            smaller than ``chunk_size``.

    Returns:
        List of chunk strings (empty for empty input).

    Raises:
        ValueError: If ``chunk_size`` is not positive or ``overlap`` is not
            smaller than ``chunk_size`` (the window would never advance).
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    step = chunk_size - overlap
    if step <= 0:
        raise ValueError("overlap must be smaller than chunk_size")

    total = len(text)
    chunks = []
    for start in range(0, total, step):
        chunks.append(text[start:start + chunk_size])
        if start + chunk_size >= total:
            # This chunk already reaches the end; a further window would only
            # repeat its tail.
            break
    return chunks

def text_to_vector(text):
    """Return a bag-of-words ``Counter`` for *text*.

    The text is lower-cased and every run of ``\\w`` characters counts as one
    term; punctuation and whitespace are discarded.
    """
    lowered = text.lower()
    return Counter(match.group(0) for match in re.finditer(r'\w+', lowered))

def get_cosine_similarity(vec1, vec2):
    """Compute the cosine similarity of two term-count mappings.

    Args:
        vec1: Mapping of term -> count.
        vec2: Mapping of term -> count.

    Returns:
        The dot product over shared terms divided by the product of the two
        Euclidean norms, or ``0.0`` when either vector has zero norm.
    """
    shared_terms = set(vec1) & set(vec2)
    dot_product = sum(vec1[term] * vec2[term] for term in shared_terms)

    squares1 = sum(vec1[term] ** 2 for term in vec1)
    squares2 = sum(vec2[term] ** 2 for term in vec2)
    magnitude = math.sqrt(squares1) * math.sqrt(squares2)

    if not magnitude:
        return 0.0
    return float(dot_product) / magnitude

def ask_chatbot(question, chunks, client):
    """Answer *question* using the report chunks most similar to it.

    Every chunk is scored against the question with bag-of-words cosine
    similarity; the three best-scoring chunks (earlier chunks win ties) are
    joined into the context that is sent to the chat model together with the
    question.

    Args:
        question: The user's question.
        chunks: Sequence of report text chunks.
        client: Object exposing ``chat_completion(...)`` whose result provides
            ``choices[0].message.content``.

    Returns:
        The model's reply, or ``"Error: <message>"`` if the request fails.
    """
    query_vector = text_to_vector(question)
    similarities = [
        get_cosine_similarity(query_vector, text_to_vector(chunk))
        for chunk in chunks
    ]

    # sorted() is stable, so chunks with equal scores keep their document order.
    ranked = sorted(range(len(chunks)), key=lambda idx: similarities[idx], reverse=True)
    context = "\n...\n".join(chunks[idx] for idx in ranked[:3])

    sys_p = """
    You are a Medical Assistant. Your job is to answer the user's question using EITHER the report data OR general medical knowledge.
    
    INSTRUCTIONS:
    1. Check the "REPORT CONTEXT". Does it have the answer?
       - YES -> Answer strictly based on the report. Start with "Based on your report..."
       - NO -> Do NOT say "I don't know." Instead, switch to GENERAL ADVICE MODE. Start with "Your report doesn't specify this, but generally..." and provide ONE concise sentence of standard medical advice for the conditions found in the text.
    
    CONSTRAINT: Keep your answer to ONE or TWO sentences max.
    """

    user_p = f"REPORT CONTEXT:\n{context}\n\nUSER QUESTION:\n{question}"
    conversation = [
        {"role": "system", "content": sys_p},
        {"role": "user", "content": user_p},
    ]

    try:
        reply = client.chat_completion(
            messages=conversation,
            model="Qwen/Qwen2.5-7B-Instruct",  # Swap to "aaditya/Llama3-OpenBioLLM-8B" here if you want to test it!
            max_tokens=600,
            temperature=0.7
        )
    except Exception as e:
        return f"Error: {e}"

    try:
        return reply.choices[0].message.content
    except Exception as e:
        return f"Error: {e}"

# --- 5. MAIN UI LAYOUT ---
# Top-level Streamlit script: upload a report, run the analysis, render the
# result cards and offer a chat about the uploaded report.
#
# SECURITY: everything derived from the model or the uploaded document is
# passed through html.escape() before being interpolated into markup rendered
# with unsafe_allow_html=True, so report/model text cannot inject HTML or
# scripts into the page.

st.title("🏥 MediChat AI: Analysis Dashboard")

# Put a tiny, clean status indicator in the sidebar instead of a giant drop-down menu
with st.sidebar:
    if api_key:
        st.success("✅ Secure AI Connection Ready")
        st.caption("Engine: EasyOCR + Qwen 2.5 AI")
    else:
        st.error("❌ Setup Required: Please add your Hugging Face Token to the Space Settings -> Secrets tab.")

st.markdown("### Upload Report for Instant AI Diagnosis")

uploaded_file = st.file_uploader("Upload Medical Record (PDF/Image)", type=["jpg", "png", "jpeg", "pdf"])

if uploaded_file and api_key:
    client = InferenceClient(token=api_key)

    if "analysis" not in st.session_state:
        st.session_state.analysis = None
    if "chunks" not in st.session_state:
        st.session_state.chunks = []
    if "messages" not in st.session_state:
        st.session_state.messages = []

    if st.button("Run Analysis ⚡"):
        st.session_state.analysis = None
        st.session_state.chunks = []
        # A new analysis starts a new conversation; otherwise answers about a
        # previous report would be shown next to the new one.
        st.session_state.messages = []

        with st.spinner("Processing medical data..."):
            raw_text = extract_text(uploaded_file)

            # extract_text reports failures through its return value; show them
            # instead of sending the error message to the model as a "report".
            if raw_text.startswith(("Error processing file:", "Error: PDF support")):
                st.error(raw_text)
            else:
                if len(raw_text) < 50:
                    st.warning("⚠️ Low text quality detected. If this is a scanned PDF, please convert it to an Image (JPG/PNG) for accurate results.")

                if len(raw_text) < 5:
                    st.error("Could not read text. Please try a clearer file or convert PDF to Image.")
                else:
                    st.session_state.analysis = analyze_report(raw_text, client)
                    st.session_state.chunks = simple_text_splitter(raw_text)

    # --- RESULTS DASHBOARD ---
    if st.session_state.analysis:
        res = st.session_state.analysis

        if not isinstance(res, dict):
            # Defensive: the cards below need a mapping.
            st.error("🚨 AI API Error: unexpected response format.")

        elif "error" in res:
            st.error(f"🚨 AI API Error: {res['error']}")

        else:
            # The model may return null or a non-string severity.
            sev = str(res.get("severity") or "UNKNOWN").upper()

            if "CRITICAL" in sev:
                status_color = "#d32f2f"
                status_icon = "🚨"
            elif "MODERATE" in sev:
                status_color = "#f57c00"
                status_icon = "⚠️"
            elif "NORMAL" in sev:
                status_color = "#2e7d32"
                status_icon = "✅"
            else:
                # Unknown/unparsed severity must not look like an all-clear.
                status_color = "#616161"
                status_icon = "❔"

            diagnosis_html = html.escape(str(res.get('diagnosis_hypothesis') or 'Analysis pending...'))
            advice_html = html.escape(str(res.get('medical_advice') or 'Consult a doctor.'))

            st.markdown("---")
            st.markdown(f"<h3 style='color:{status_color}!important'>{status_icon} Analysis Result: {html.escape(sev)}</h3>", unsafe_allow_html=True)

            col1, col2 = st.columns(2)

            with col1:
                st.markdown(f"""
                <div class="medical-card diagnosis">
                    <div class="card-label">Primary Diagnosis</div>
                    <div class="card-content">{diagnosis_html}</div>
                </div>
                """, unsafe_allow_html=True)

            with col2:
                st.markdown(f"""
                <div class="medical-card advice">
                    <div class="card-label">Recommended Action</div>
                    <div class="card-content">{advice_html}</div>
                </div>
                """, unsafe_allow_html=True)

            abnormal_values = res.get("abnormal_values")
            if isinstance(abnormal_values, list) and abnormal_values:
                findings_html = "".join(
                    f"<li style='margin-bottom:5px;'>{html.escape(str(val))}</li>"
                    for val in abnormal_values
                )
                st.markdown(f"""
                <div class="medical-card warning">
                    <div class="card-label" style="color:#d32f2f;">⚠️ Critical Findings</div>
                    <ul style="margin-bottom:0; padding-left:20px;">
                        {findings_html}
                    </ul>
                </div>
                """, unsafe_allow_html=True)
            else:
                st.success("No specific abnormal values detected in the extraction.")

    # --- CHAT SECTION ---
    if st.session_state.chunks:
        st.markdown("---")
        st.subheader("💬 Doctor's Companion Chat")

        # Replay the stored conversation on every rerun of the script.
        for msg in st.session_state.messages:
            with st.chat_message(msg["role"]):
                st.write(msg["content"])

        if prompt := st.chat_input("Ask specific questions about this patient..."):
            st.session_state.messages.append({"role": "user", "content": prompt})
            with st.chat_message("user"):
                st.write(prompt)

            with st.chat_message("assistant"):
                with st.spinner("Reviewing case notes..."):
                    ans = ask_chatbot(prompt, st.session_state.chunks, client)
                    st.write(ans)
            st.session_state.messages.append({"role": "assistant", "content": ans})