import io
import tempfile
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import pandas as pd
import PIL.Image as Image
import seaborn as sns
import streamlit as st
import streamlit.components.v1 as components
import torch
import torch.nn as nn
from fpdf import FPDF
from transformers import pipeline

# NOTE(review): io, Image, FPDF and pipeline are not referenced anywhere in this
# module as shown — confirm before removing (the transformers import is heavy).

# ==========================================
# CONFIGURATION & THEME
# ==========================================
PROJECT_ROOT = Path(__file__).parent
MASTER_CSV = PROJECT_ROOT / "data" / "metadata" / "master_dataset.csv"
CHECKPOINT_PATH = PROJECT_ROOT / "models" / "checkpoints" / "best_custom_cnn.pth"

st.set_page_config(
    page_title="NeuroVision | Structural Analytics",
    layout="wide",
    initial_sidebar_state="expanded",
)

# Research-Grade Theme (UI Preserved)
# NOTE(review): the CSS payload of this block is empty in the current source —
# confirm the intended theme markup was not lost.
st.markdown(""" """, unsafe_allow_html=True)

# JavaScript injection to stop iframe height oscillation
# NOTE(review): the script payload is empty in the current source — confirm.
components.html(""" """, height=0)

# ------------------------------------------
# Analysis constants. Every intensity threshold below assumes an 8-bit
# (0-255) slice; perform_structural_analysis converts its input first.
# ------------------------------------------
BRAIN_MASK_THRESHOLD = 50       # pixels above this count as tissue
BRIGHT_ISLAND_THRESHOLD = 200   # pixels above this count as "bright"
MASS_AREA_FRACTION = 0.005      # bright component > 0.5% of the slice => focal mass
ASYMMETRY_THRESHOLD = 35        # mean abs quadrant difference that flags asymmetry
MODEL_INPUT_SIZE = 224          # CNN input side length (see CustomLightCNN)
CAM_PERCENTILE = 88             # keep only saliency above this in-brain percentile


# ==========================================
# AI ARCHITECTURE
# ==========================================
class CustomLightCNN(nn.Module):
    """Three-stage conv classifier for 1x224x224 slices with a Grad-CAM hook.

    The layer layout (and therefore the ``state_dict`` keys) must stay in sync
    with the checkpoint loaded from ``CHECKPOINT_PATH``.
    """

    # ``features[:CAM_SPLIT]`` ends at the third BatchNorm (index 9); Grad-CAM
    # reads activations and gradients there, before the final ReLU + max-pool.
    CAM_SPLIT = 10

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=3, padding=1), nn.BatchNorm2d(8), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(8, 16, kernel_size=3, padding=1), nn.BatchNorm2d(16), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1), nn.BatchNorm2d(32), nn.ReLU(), nn.MaxPool2d(2),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            # 224 -> 112 -> 56 -> 28 after the three 2x2 max-pools.
            nn.Linear(32 * 28 * 28, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 2),
        )
        # Gradient w.r.t. the CAM-layer activations, filled in by
        # ``activations_hook`` during a backward pass.
        self.gradients = None

    def activations_hook(self, grad):
        """Store the gradient flowing back through the CAM layer."""
        self.gradients = grad

    def forward(self, x):
        """Return the 2-class logits for ``x``, capturing CAM gradients when possible."""
        x = self.features[: self.CAM_SPLIT](x)
        # register_hook raises on tensors that do not require grad (e.g. under
        # torch.no_grad()), so only attach it when a backward pass can happen.
        if x.requires_grad:
            x.register_hook(self.activations_hook)
        x = self.features[self.CAM_SPLIT:](x)
        return self.classifier(x)

    def get_activations_gradient(self):
        """Return the gradient captured by the most recent backward pass (or None)."""
        return self.gradients

    def get_activations(self, x):
        """Return the CAM-layer activations for ``x`` (the layer the hook watches)."""
        return self.features[: self.CAM_SPLIT](x)


@st.cache_resource
def load_assets():
    """Build the CNN once (cached by Streamlit) and load weights if present.

    When ``CHECKPOINT_PATH`` does not exist the randomly initialised model is
    returned; ``main`` warns the user in that case. The model is in eval mode.
    """
    device = torch.device("cpu")
    model = CustomLightCNN()
    if CHECKPOINT_PATH.exists():
        model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))
    model.eval()
    return model


# ==========================================
# ANALYTICS CORE (DEEP STRUCTURAL)
# ==========================================
def get_age_norm(age):
    """Return ``(expected_mean, std)`` of normalised brain volume for ``age``.

    The mean is 0.90 below age 20, otherwise 0.88 minus 0.002 per year over 20,
    floored at 0.65. The standard deviation is a fixed 0.02.
    """
    if age < 20:
        expected_mean = 0.90
    else:
        expected_mean = 0.88 - (0.002 * (age - 20))
    return max(expected_mean, 0.65), 0.02


def generate_evidence_summary(metrics, issues, age_desc):
    """
    Evidence-Grounded Dynamic Interpretation.

    Builds a markdown summary from ``metrics`` (reads 'health' and 'symmetry')
    and ``issues`` (reads 'tumor_risk', 'asymmetry' and 'conf_label').
    """
    summary = [f"### **Structural Analysis Portfolio | Baseline: {age_desc}**"]

    # 1. Structural Observations
    if issues['tumor_risk'] or issues['asymmetry']:
        overview = f"Analysis demonstrates focal structural irregularity associated with concentrated saliency activation. Hemispheric balance ({metrics['symmetry']}/100) is moderately to significantly disrupted."
    elif metrics['health'] > 80:
        overview = "Structural morphology demonstrates relative preservation with stable hemispheric organization."
    else:
        overview = f"Diffuse volumetric variance detected ({metrics['health']}/100) consistent with normative aging patterns."
    summary.append(f"**Structural Overview:** {overview}")

    # 2. Localized Morphology
    findings = []
    if issues['tumor_risk']:
        findings.append("Localized mass-like structural deviation identified; anomalous signal concentration detected in focal regions.")
    if issues['asymmetry']:
        findings.append("Significant hemispheric imbalance observed; morphological contours show unilateral deviation.")
    if not findings:
        findings.append("Morphology remains broadly consistent with stable structural patterns.")
    summary.append(f"**Morphological Findings:** {' '.join(findings)}")

    # 3. Confidence Note
    summary.append(f"\n*Analysis Certainty: {issues['conf_label']} based on saliency-guided structural reasoning.*")
    return "\n".join(summary)


def _to_uint8(img):
    """Return ``img`` as uint8; non-uint8 input is NaN-cleaned and min-max scaled to 0-255."""
    img = np.asarray(img)
    if img.dtype == np.uint8:
        return img
    # NIfTI data arrives as float in arbitrary units (possibly with NaNs); the
    # fixed thresholds and the /255 model scaling need a 0-255 range.
    clean = np.nan_to_num(img.astype(np.float32))
    return cv2.normalize(clean, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)


def _brain_mask(img_u8):
    """Return a 0/255 tissue mask: threshold at BRAIN_MASK_THRESHOLD, then 5x5 closing."""
    _, mask = cv2.threshold(img_u8, BRAIN_MASK_THRESHOLD, 255, cv2.THRESH_BINARY)
    kernel = np.ones((5, 5), np.uint8)
    return cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)


def _upper_quadrant_asymmetry(masked_img):
    """Mean absolute difference between the upper-left and mirrored upper-right quadrants."""
    h, w = masked_img.shape
    mid_h, mid_w = h // 2, w // 2
    # NOTE(review): "upper" rows are assumed to be the frontal region; this
    # depends on the slice orientation of the input.
    left = masked_img[:mid_h, :mid_w]
    right = cv2.flip(masked_img[:mid_h, mid_w:], 1)
    # With an odd width the right half is one column wider; resample to match.
    right = cv2.resize(right, (left.shape[1], left.shape[0]))
    return np.mean(np.abs(left.astype(float) - right.astype(float)))


def _detect_focal_masses(masked_img):
    """Return areas of bright connected components larger than MASS_AREA_FRACTION of the slice."""
    h, w = masked_img.shape
    _, bright_mask = cv2.threshold(masked_img, BRIGHT_ISLAND_THRESHOLD, 255, cv2.THRESH_BINARY)
    num_labels, _, stats, _ = cv2.connectedComponentsWithStats(bright_mask)
    min_area = h * w * MASS_AREA_FRACTION
    # Label 0 is the background component, so start at 1.
    return [
        stats[i, cv2.CC_STAT_AREA]
        for i in range(1, num_labels)
        if stats[i, cv2.CC_STAT_AREA] > min_area
    ]


def _grad_cam(model, outputs, img_tensor, brain_mask):
    """Return a [0, 1] Grad-CAM map for class 1, restricted to the brain mask.

    ``outputs`` must come from ``model(img_tensor)`` with autograd enabled.
    """
    outputs[0, 1].backward()
    gradients = model.get_activations_gradient()
    channel_weights = torch.mean(gradients, dim=[0, 2, 3]).detach()
    with torch.no_grad():
        activations = model.get_activations(img_tensor)
        weighted = activations * channel_weights.view(1, -1, 1, 1)
    cam = torch.mean(weighted, dim=1).squeeze().numpy()
    cam = np.maximum(cam, 0)

    # Zero saliency outside tissue: bring the full-resolution mask onto the
    # CAM grid (the CNN input is a plain resize of the slice, so they align).
    in_brain = cv2.resize(brain_mask, (cam.shape[1], cam.shape[0]), interpolation=cv2.INTER_NEAREST) > 0
    cam[~in_brain] = 0

    # Focus only on top activations INSIDE brain
    inside_values = cam[in_brain]
    if inside_values.size:
        cam[cam < np.percentile(inside_values, CAM_PERCENTILE)] = 0
    cam /= (np.max(cam) + 1e-8)
    return cam


def perform_structural_analysis(img_slice, model, age, mmse, cdr):
    """Run masking, asymmetry/mass heuristics, CNN inference and Grad-CAM on one slice.

    Args:
        img_slice: 2-D slice. uint8 input is used as-is; any other dtype is
            min-max scaled to 0-255 first.
        model: a CustomLightCNN in eval mode.
        age: reference age used for the normative brain-volume baseline.
        mmse: MMSE score (0-30) or None when not supplied.
        cdr: CDR rating or None when not supplied.

    Returns:
        dict of scores, flags, the summary text and the saliency heatmap
        (keys consumed by ``main``).
    """
    img_u8 = _to_uint8(img_slice)

    # --- 1. BRAIN MASKING (Skull Suppression) ---
    mask = _brain_mask(img_u8)
    masked_img = cv2.bitwise_and(img_u8, img_u8, mask=mask)

    # --- 2. REGIONAL QUADRANT SCAN ---
    quad_diff = _upper_quadrant_asymmetry(masked_img)
    # Symmetry Score (Scaled for realism)
    symmetry_score = int(max(10, min(95, 100 - (quad_diff * 1.5))))
    asymmetry_flag = bool(quad_diff > ASYMMETRY_THRESHOLD)

    # --- 3. FOCAL ANOMALY (Mass Detection) ---
    tumor_flag = len(_detect_focal_masses(masked_img)) > 0

    # --- 4. AI MODEL INFERENCE ---
    img_resized = cv2.resize(masked_img, (MODEL_INPUT_SIZE, MODEL_INPUT_SIZE))
    img_tensor = torch.tensor(img_resized).float().unsqueeze(0).unsqueeze(0) / 255.0
    model.zero_grad()
    outputs = model(img_tensor)
    probs = torch.softmax(outputs, dim=1)
    atrophy_idx = probs[0, 1].item()

    # --- 5. COHERENT SCORING ---
    target_mean, target_std = get_age_norm(age)
    current_nwbv = 0.85 - (atrophy_idx * 0.15)
    z_score = (current_nwbv - target_mean) / target_std
    pres_base = 100 * (1 / (1 + np.exp(-z_score)))

    # Weighting: 90% MRI, 10% Metadata when cognitive metadata is supplied;
    # without metadata the MRI term carries the full weight.
    meta_ref = 0
    if mmse is not None:
        meta_ref += (mmse / 30 * 5)
    if cdr is not None:
        meta_ref += (5 - cdr * 2.5)
    has_meta = mmse is not None or cdr is not None
    mri_weight = 0.9 if has_meta else 1.0

    # Integrity drops if anomaly or asymmetry exists
    structural_integrity = int(pres_base * mri_weight + meta_ref)
    if tumor_flag or asymmetry_flag:
        structural_integrity = int(structural_integrity * 0.6)
    structural_integrity = max(min(structural_integrity, 95), 5)
    morph_consistency = int(structural_integrity * 0.5 + symmetry_score * 0.5)

    # Severity Logic
    if tumor_flag or asymmetry_flag or z_score < -2.0:
        severity = "Elevated"
    elif z_score < -1.0:
        severity = "Mild"
    else:
        severity = "Stable"

    # --- 6. FOCAL SALIENCY (Skull-Masked) ---
    heatmap = _grad_cam(model, outputs, img_tensor, mask)

    # Metadata for summary
    age_desc = f"~{int(age // 5 * 5)}s Pattern"
    conf_label = "Moderate" if (tumor_flag or asymmetry_flag) else "Preliminary"
    issues = {'tumor_risk': tumor_flag, 'asymmetry': asymmetry_flag, 'balance': symmetry_score, 'conf_label': conf_label, 'confidence': 72}
    metrics = {'health': structural_integrity, 'symmetry': symmetry_score, 'severity': severity}
    summary = generate_evidence_summary(metrics, issues, age_desc)

    # Regions Focus
    regions = []
    if asymmetry_flag:
        regions.append("Hemispheric imbalance zone")
    if tumor_flag:
        regions.append("Local signal irregularity focus")
    if not regions:
        regions.append("Diffuse structural background")

    return {
        "health": structural_integrity,
        "integrity": morph_consistency,
        "symmetry": symmetry_score,
        "severity": severity,
        "confidence": conf_label,
        "contributions": {"Asymmetry": quad_diff / 100, "Voxel Distribution": atrophy_idx, "Metadata Context": meta_ref / 10},
        "layman": summary,
        "regions": ", ".join(regions),
        "heatmap": heatmap,
        "nwbv": current_nwbv,
        "sym_mse": quad_diff,
        "intensity_peak": np.max(masked_img),
        "atrophy_idx": atrophy_idx,
        "z_score": z_score,
    }


# ==========================================
# MAIN APP
# ==========================================
def _read_uploaded_slice(uploaded_file):
    """Decode an uploaded NIfTI volume or raster image into one 2-D slice.

    NIfTI data is written to a temporary file whose suffix matches the upload
    (nibabel chooses gzip handling from the extension) and removed afterwards.
    Extra trailing axes (e.g. time) are reduced to index 0, and for 3-D volumes
    the middle slice along the last axis is returned. Raster images are decoded
    in memory as grayscale.

    Raises:
        ValueError: the upload cannot be decoded into a 2-D slice.
    """
    name = uploaded_file.name.lower()
    data = uploaded_file.getvalue()

    if name.endswith((".nii", ".nii.gz")):
        suffix = ".nii.gz" if name.endswith(".gz") else ".nii"
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            tmp.write(data)
            tmp_path = Path(tmp.name)
        try:
            # np.array copies, so nothing references the file after unlink.
            vol = np.array(nib.load(str(tmp_path)).get_fdata())
        finally:
            tmp_path.unlink(missing_ok=True)
        vol = np.nan_to_num(vol)
        while vol.ndim > 3:
            vol = vol[..., 0]
        if vol.ndim == 3:
            return vol[:, :, vol.shape[2] // 2]
        if vol.ndim == 2:
            return vol
        raise ValueError(f"expected a 2-D or 3-D volume, got shape {vol.shape}")

    img = cv2.imdecode(np.frombuffer(data, dtype=np.uint8), cv2.IMREAD_GRAYSCALE)
    if img is None:
        raise ValueError("unsupported or corrupt image data")
    return img


def _render_inference_page(model):
    """Sidebar controls, upload handling and result panels for the inference page."""
    st.title("⚡ Structural Neuro-Analytics")
    with st.sidebar:
        st.header("Context Controls")
        age_input = st.slider("Reference Age", 1, 120, 72)
        has_meta = st.checkbox("Include Cognitive Reference", value=False)
        # Both cognitive inputs are None when the reference is excluded, so
        # neither contributes to the integrity score.
        mmse = st.slider("MMSE Baseline", 0, 30, 27) if has_meta else None
        cdr = st.selectbox("CDR Assessment", [0, 0.5, 1, 2], index=0) if has_meta else None

    uploaded_file = st.file_uploader("Upload MRI Volume / Slice", type=["nii.gz", "nii", "png", "jpg"])
    if uploaded_file is None:
        return

    try:
        img_slice = _read_uploaded_slice(uploaded_file)
    except Exception as exc:  # UI boundary: report any decode failure to the user
        st.error(f"Could not read the uploaded file: {exc}")
        return

    with st.spinner("Executing Structural Analytics Pipeline..."):
        results = perform_structural_analysis(img_slice, model, age_input, mmse, cdr)

    # --- Visuals ---
    c1, c2 = st.columns(2)
    # Normalize base image for display
    disp_img = cv2.normalize(img_slice, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
    with c1:
        st.markdown("### **Structural Anatomy**")
        # A 2-D array is rendered as grayscale with the default channels.
        st.image(disp_img, use_container_width=True)
    with c2:
        st.markdown("### **Structural Attention Areas**")
        heatmap_resized = cv2.resize(results['heatmap'], (img_slice.shape[1], img_slice.shape[0]))
        # Convert heatmap to magma color map
        heatmap_norm = cv2.normalize(heatmap_resized, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
        heatmap_color = cv2.applyColorMap(heatmap_norm, cv2.COLORMAP_MAGMA)
        # Blend with original image
        disp_img_bgr = cv2.cvtColor(disp_img, cv2.COLOR_GRAY2BGR)
        # Only apply color where heatmap is active to keep background clean
        overlay_mask = heatmap_norm > 15
        alpha = 0.55
        blended = disp_img_bgr.copy()
        blended[overlay_mask] = cv2.addWeighted(heatmap_color, alpha, disp_img_bgr, 1 - alpha, 0)[overlay_mask]
        # Convert to RGB for Streamlit
        st.image(cv2.cvtColor(blended, cv2.COLOR_BGR2RGB), use_container_width=True)

    # --- Metrics ---
    st.markdown("### **Quantitative Analytics**")
    m1, m2, m3, m4 = st.columns(4)
    m1.metric("Structural Integrity", f"{results['health']} / 100")
    m2.metric("Morphology Consistency", f"{results['integrity']} / 100")
    m3.metric("Tissue Symmetry", f"{results['symmetry']} / 100")
    m4.metric("Severity Index", results['severity'])

    # --- Analysis ---
    st.markdown("---")
    ex1, ex2 = st.columns([1, 1.2])
    with ex1:
        st.subheader("💡 Evidence Analysis")
        st.bar_chart(pd.DataFrame(list(results['contributions'].items()), columns=['F', 'I']).set_index('F'))
        st.info(f"**Attention Regions:** {results['regions']}")
        st.caption(f"**Confidence Profile:** {results['confidence']}")
    with ex2:
        st.subheader("📝 Clinical Interpretation")
        st.info(results['layman'])
        if st.button("Generate Portfolio Report"):
            st.success("Analysis Ready.")

    # Pipeline Trace Table
    with st.expander("🔍 Pipeline Trace & Technical Data"):
        st.markdown("### **Technical Analytics Portfolio**")
        trace_data = {
            "Metric": ["nWBV", "Symmetry MSE", "Intensity Peak", "Atrophy Probability", "Peer Deviation (Z)"],
            "Value": [
                f"{results['nwbv']:.3f}",
                f"{results['sym_mse']:.1f}",
                f"{results['intensity_peak']:.1f}",
                f"{results['atrophy_idx']:.3f}",
                f"{results['z_score']:.2f}",
            ],
            "Status": ["Extracted", "Region-Scanned", "Island-Detected", "Computed", "Calibrated"],
        }
        st.table(pd.DataFrame(trace_data))
        st.markdown("---")
        st.markdown("### **Biomarker Glossary**")
        st.write("**Structural Integrity:** Brain volume preservation relative to skull size.")
        st.write("**Tissue Symmetry:** Hemispheric voxel comparison (MSE).")

    # NOTE(review): the HTML wrapper around this disclaimer appears to have been
    # stripped from the source, and its original nesting level could not be
    # recovered; confirm placement and markup.
    st.markdown(
        "\n\nDisclaimer: Research-use only. Not a medical device. Clinical correlation required.\n\n",
        unsafe_allow_html=True,
    )


def _render_population_page():
    """Age vs. nWBV regression plot from the master dataset CSV."""
    st.title("🔬 Population Morphometry")
    # Read only when this page is shown instead of on every rerun.
    df = pd.read_csv(MASTER_CSV) if MASTER_CSV.exists() else pd.DataFrame()
    if df.empty:
        st.info(f"No population data available ({MASTER_CSV.name}).")
        return
    missing = {"age", "nWBV"} - set(df.columns)
    if missing:
        st.error(f"Population dataset is missing column(s): {', '.join(sorted(missing))}")
        return
    fig, ax = plt.subplots(figsize=(10, 5))
    sns.regplot(data=df, x='age', y='nWBV', ax=ax, scatter_kws={'alpha': 0.2}, line_kws={'color': 'red'})
    st.pyplot(fig)
    plt.close(fig)


def main():
    """Streamlit entry point: sidebar navigation between the two pages."""
    st.sidebar.title("🧬 NeuroVision")
    st.sidebar.markdown("---")

    model = load_assets()
    if not CHECKPOINT_PATH.exists():
        st.sidebar.warning("Model checkpoint not found — inference uses untrained weights.")

    menu = st.sidebar.radio("Navigation", ["⚡ Structural Inference", "🔬 Population Insights"])

    if menu == "⚡ Structural Inference":
        _render_inference_page(model)
    elif menu == "🔬 Population Insights":
        _render_population_page()


if __name__ == "__main__":
    main()