# NeuroVision-App / app.py
# Uploaded by Prasannata with huggingface_hub (commit 94a95da, verified).
import streamlit as st
import pandas as pd
import numpy as np
import nibabel as nib
import cv2
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import PIL.Image as Image
import torch
import torch.nn as nn
import tempfile
from transformers import pipeline
from fpdf import FPDF
import io
# ==========================================
# CONFIGURATION & THEME
# ==========================================
# Data and model locations are resolved relative to the directory that
# contains this script.
PROJECT_ROOT = Path(__file__).parent
MASTER_CSV = PROJECT_ROOT.joinpath("data", "metadata", "master_dataset.csv")
CHECKPOINT_PATH = PROJECT_ROOT.joinpath("models", "checkpoints", "best_custom_cnn.pth")

# Page-level Streamlit settings: wide layout with the sidebar open by default.
st.set_page_config(
    page_title="NeuroVision | Structural Analytics",
    layout="wide",
    initial_sidebar_state="expanded",
)
# Research-Grade Theme (UI Preserved)
# Inject a global <style> block as raw HTML (unsafe_allow_html=True). It sets
# the dark palette for the page, sidebar, metric cards, headings,
# expanders/alerts, inputs, buttons and tables, and adds "anti-shake" rules
# (always-visible vertical scrollbar, full-width images, hidden fullscreen
# button) to keep the layout from jumping between reruns.
st.markdown("""
<style>
/* === CORE DARK THEME === */
html, body, .stApp, [data-testid="stAppViewContainer"] {
background-color: #0f1216 !important;
color: #e1e4e8 !important;
}
.main, [data-testid="stMain"] {
background-color: #0f1216 !important;
color: #e1e4e8 !important;
}
/* === SIDEBAR === */
[data-testid="stSidebar"], [data-testid="stSidebar"] > div {
background-color: #161b22 !important;
color: #e1e4e8 !important;
}
[data-testid="stSidebar"] .stRadio label,
[data-testid="stSidebar"] .stCheckbox label,
[data-testid="stSidebar"] .stSlider label,
[data-testid="stSidebar"] h1, [data-testid="stSidebar"] h2,
[data-testid="stSidebar"] h3, [data-testid="stSidebar"] p,
[data-testid="stSidebar"] span, [data-testid="stSidebar"] div {
color: #e1e4e8 !important;
}
/* === METRIC CARDS === */
[data-testid="stMetric"], [data-testid="metric-container"] {
background-color: #1c2128 !important;
padding: 20px !important;
border-radius: 12px !important;
border: 1px solid #30363d !important;
min-height: 80px;
}
[data-testid="stMetric"] label, [data-testid="stMetricLabel"] {
color: #8b949e !important;
}
[data-testid="stMetricValue"], [data-testid="stMetric"] [data-testid="stMetricValue"] {
color: #e1e4e8 !important;
}
/* === HEADINGS & TEXT === */
h1, h2, h3 { color: #58a6ff !important; font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; }
p, span, label, div { color: #e1e4e8; }
.disclaimer { font-size: 0.8rem; color: #8b949e; font-style: italic; margin-top: 30px; }
/* === CONTAINERS & BOXES === */
.findings-box { background-color: #161b22; border-radius: 10px; padding: 20px; border: 1px solid #30363d; margin-top: 20px; }
[data-testid="stExpander"] { background-color: #161b22 !important; border: 1px solid #30363d !important; border-radius: 8px; }
[data-testid="stExpander"] summary { color: #e1e4e8 !important; }
.stAlert, [data-testid="stAlert"] { background-color: #1c2128 !important; border-color: #30363d !important; color: #e1e4e8 !important; }
/* === INPUTS === */
.stSelectbox, .stSlider, .stFileUploader {
color: #e1e4e8 !important;
}
[data-testid="stFileUploader"] { background-color: #161b22 !important; border-color: #30363d !important; border-radius: 8px; }
.stButton > button { background-color: #238636 !important; color: #ffffff !important; border: none !important; border-radius: 6px; }
.stButton > button:hover { background-color: #2ea043 !important; }
/* === TABLES === */
.stTable, table, thead, tbody, th, td { background-color: #161b22 !important; color: #e1e4e8 !important; border-color: #30363d !important; }
/* === CHARTS === */
.vega-embed { width: 100% !important; }
/* === ANTI-SHAKE === */
html, body { overflow-y: scroll !important; overflow-x: hidden !important; }
.stApp { min-height: 100vh; }
.block-container { max-width: 100% !important; padding-top: 1rem !important; }
[data-testid="stImage"] img { display: block; width: 100%; height: auto; }
button[title="View fullscreen"] { display: none !important; }
</style>
""", unsafe_allow_html=True)
# JavaScript injection to stop iframe height oscillation.
# The script runs inside the zero-height component iframe created by
# components.html() and wraps window.parent.postMessage: any
# 'streamlit:setFrameHeight' message whose height is smaller than the largest
# height seen so far is dropped; every other message is forwarded unchanged.
# NOTE(review): because the wrapper replaces the *parent* window's
# postMessage, it presumably affects every caller of that function on the
# page (not only this iframe), and frames routed through it can never shrink
# — confirm this trade-off is intended.
import streamlit.components.v1 as components
components.html("""
<script>
// Override Streamlit's iframe resizer to prevent height oscillation
(function() {
let lastHeight = 0;
const orig = window.parent.postMessage;
window.parent.postMessage = function(msg, origin) {
if (typeof msg === 'object' && msg !== null && msg.type === 'streamlit:setFrameHeight') {
const newH = msg.height || 0;
// Only allow height INCREASES, never decreases (prevents oscillation)
if (newH < lastHeight) return;
lastHeight = newH;
}
return orig.call(window.parent, msg, origin);
};
})();
</script>
""", height=0)
# ==========================================
# AI ARCHITECTURE
# ==========================================
class CustomLightCNN(nn.Module):
    """Small three-stage CNN binary classifier with a Grad-CAM gradient hook.

    The layer layout (and therefore the ``state_dict`` keys ``features.*`` and
    ``classifier.*``) must stay exactly as is so saved checkpoints still load.

    Input is a single-channel batch; the first Linear layer expects
    ``32 * 28 * 28`` features, i.e. a 224x224 image after three 2x max-pools.
    ``forward`` returns raw logits for 2 classes.
    """

    # ``features[:_CAM_SPLIT]`` runs up to and including the third
    # BatchNorm2d; that output is the Grad-CAM target. The remaining
    # ReLU + MaxPool2d (indices 10 and 11) run afterwards.
    _CAM_SPLIT = 10

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=3, padding=1),
            nn.BatchNorm2d(8), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(8, 16, kernel_size=3, padding=1),
            nn.BatchNorm2d(16), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32), nn.ReLU(), nn.MaxPool2d(2),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(32 * 28 * 28, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 2),
        )
        # Gradient w.r.t. the Grad-CAM target activations, captured by
        # ``activations_hook`` during the most recent backward pass.
        self.gradients = None

    def activations_hook(self, grad):
        """Store the gradient flowing back into the Grad-CAM target layer."""
        self.gradients = grad

    def forward(self, x):
        """Return class logits for ``x``, hooking the Grad-CAM layer when autograd is active."""
        x = self.features[:self._CAM_SPLIT](x)
        # register_hook raises on tensors that do not require grad (e.g. under
        # torch.no_grad()), so only attach it when a backward pass is possible.
        if x.requires_grad:
            x.register_hook(self.activations_hook)
        x = self.features[self._CAM_SPLIT:](x)
        return self.classifier(x)

    def get_activations_gradient(self):
        """Return the gradient captured by the last backward pass (or ``None``)."""
        return self.gradients

    def get_activations(self, x):
        """Return the Grad-CAM target activations for ``x`` (features up to the split)."""
        return self.features[:self._CAM_SPLIT](x)
@st.cache_resource
def load_assets():
    """Build the CNN, load trained weights when available, and return it in eval mode.

    Wrapped in ``st.cache_resource`` so the model is built once and reused
    across reruns.

    Returns:
        CustomLightCNN: the model on CPU in evaluation mode. If
        ``CHECKPOINT_PATH`` does not exist, the weights stay randomly
        initialised. A warning is shown in that case, because every
        downstream score would otherwise be silently meaningless.
    """
    device = torch.device("cpu")
    model = CustomLightCNN()
    if CHECKPOINT_PATH.exists():
        # The checkpoint is a plain state_dict, so restrict unpickling to
        # tensors/containers instead of arbitrary objects.
        state_dict = torch.load(CHECKPOINT_PATH, map_location=device, weights_only=True)
        model.load_state_dict(state_dict)
    else:
        st.warning(
            f"Model checkpoint not found at {CHECKPOINT_PATH}; running with "
            "untrained (random) weights, so model-based scores are not meaningful."
        )
    model.eval()
    return model
# ==========================================
# ANALYTICS CORE (DEEP STRUCTURAL)
# ==========================================
def get_age_norm(age):
    """Return the normative ``(expected_mean, std)`` nWBV reference for ``age``.

    Below 20 the expected mean is a flat 0.90. From 20 onwards it starts at
    0.88 and drops by 0.002 per year, floored at 0.65. The standard
    deviation is always 0.02.
    """
    baseline = 0.90 if age < 20 else 0.88 - (0.002 * (age - 20))
    floored_mean = max(baseline, 0.65)
    return floored_mean, 0.02
def generate_evidence_summary(metrics, issues, age_desc):
    """Compose the markdown evidence summary for one structural analysis.

    Args:
        metrics: dict providing ``'symmetry'`` and ``'health'`` scores.
        issues: dict providing the ``'tumor_risk'`` and ``'asymmetry'`` flags
            and a ``'conf_label'`` string.
        age_desc: short age-bracket label used in the heading.

    Returns:
        Newline-joined markdown: a heading, the structural overview, the
        morphological findings, and a blank-line-separated certainty note.
    """
    has_mass = issues['tumor_risk']
    has_imbalance = issues['asymmetry']

    # Overview: focal findings take precedence over the global health score.
    if has_mass or has_imbalance:
        overview = (
            "Analysis demonstrates focal structural irregularity associated with "
            "concentrated saliency activation. Hemispheric balance "
            f"({metrics['symmetry']}/100) is moderately to significantly disrupted."
        )
    elif metrics['health'] > 80:
        overview = "Structural morphology demonstrates relative preservation with stable hemispheric organization."
    else:
        overview = f"Diffuse volumetric variance detected ({metrics['health']}/100) consistent with normative aging patterns."

    # Findings: one sentence per raised flag, or a neutral fallback sentence.
    candidate_findings = (
        (has_mass, "Localized mass-like structural deviation identified; anomalous signal concentration detected in focal regions."),
        (has_imbalance, "Significant hemispheric imbalance observed; morphological contours show unilateral deviation."),
    )
    finding_texts = [text for flagged, text in candidate_findings if flagged]
    if not finding_texts:
        finding_texts = ["Morphology remains broadly consistent with stable structural patterns."]

    sections = [
        f"### **Structural Analysis Portfolio | Baseline: {age_desc}**",
        f"**Structural Overview:** {overview}",
        f"**Morphological Findings:** {' '.join(finding_texts)}",
        f"\n*Analysis Certainty: {issues['conf_label']} based on saliency-guided structural reasoning.*",
    ]
    return "\n".join(sections)
def _to_uint8_intensity(img):
    """Return ``img`` as a 0-255 ``uint8`` array suitable for the fixed 8-bit thresholds.

    ``uint8`` input (e.g. decoded PNG/JPG) is returned unchanged. Any other
    dtype (e.g. float NIfTI data) has NaN/inf replaced by 0 and is min-max
    rescaled to 0-255. The masking thresholds (50, 200) and the ``/ 255``
    model scaling in ``perform_structural_analysis`` assume 8-bit intensities.
    """
    if img.dtype == np.uint8:
        return img
    img = np.nan_to_num(np.asarray(img, dtype=np.float64), nan=0.0, posinf=0.0, neginf=0.0)
    img = np.ascontiguousarray(img)
    return cv2.normalize(img, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)


def perform_structural_analysis(img_slice, model, age, mmse, cdr):
    """Run the heuristic + CNN structural analysis on one 2-D slice.

    Args:
        img_slice: 2-D grayscale array. ``uint8`` input is used directly;
            other dtypes are min-max rescaled to 0-255 first.
        model: ``CustomLightCNN``-style model that returns 2-class logits and
            exposes ``get_activations_gradient`` / ``get_activations`` for
            Grad-CAM. Its parameter gradients are zeroed and then repopulated.
        age: reference age in years; drives the normative baseline.
        mmse: MMSE score (0-30), or ``None`` to leave it out.
        cdr: CDR rating, or ``None`` to leave it out.

    Returns:
        dict containing:
        - display scores: ``health``, ``integrity``, ``symmetry``,
          ``severity``, ``confidence``;
        - per-factor ``contributions``;
        - the markdown summary (``layman``) and attention ``regions``;
        - a Grad-CAM ``heatmap`` scaled to [0, 1];
        - raw trace values: ``nwbv``, ``sym_mse``, ``intensity_peak``,
          ``atrophy_idx``, ``z_score``.
    """
    img_slice = _to_uint8_intensity(img_slice)

    # --- 1. BRAIN MASKING (skull suppression) ---
    _, mask = cv2.threshold(img_slice, 50, 255, cv2.THRESH_BINARY)
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, np.ones((5, 5), np.uint8))
    masked_img = cv2.bitwise_and(img_slice, img_slice, mask=mask)

    # --- 2. REGIONAL QUADRANT SCAN ---
    h, w = masked_img.shape
    mid_h, mid_w = h // 2, w // 2
    # Compare the top-left quadrant with the mirrored top-right quadrant
    # (anatomical meaning depends on slice orientation). The resize absorbs
    # the one-column difference for odd widths.
    left = masked_img[:mid_h, :mid_w]
    right = cv2.flip(masked_img[:mid_h, mid_w:], 1)
    right = cv2.resize(right, (left.shape[1], left.shape[0]))
    # Mean absolute intensity difference (reported downstream as "sym_mse").
    quad_diff = np.mean(np.abs(left.astype(float) - right.astype(float)))
    # Symmetry score: scaled difference, clamped to [10, 95].
    symmetry_score = int(max(10, min(95, 100 - (quad_diff * 1.5))))
    asymmetry_flag = bool(quad_diff > 35)

    # --- 3. FOCAL ANOMALY (bright-island detection) ---
    _, bright_mask = cv2.threshold(masked_img, 200, 255, cv2.THRESH_BINARY)
    num_labels, _, stats, _ = cv2.connectedComponentsWithStats(bright_mask)
    # Label 0 is the background; flag any bright island larger than 0.5% of the image.
    min_island_area = h * w * 0.005
    tumor_flag = any(
        stats[label, cv2.CC_STAT_AREA] > min_island_area for label in range(1, num_labels)
    )

    # --- 4. AI MODEL INFERENCE ---
    img_resized = cv2.resize(masked_img, (224, 224))
    img_tensor = torch.tensor(img_resized).float().unsqueeze(0).unsqueeze(0) / 255.0
    model.zero_grad()
    outputs = model(img_tensor)
    probs = torch.softmax(outputs, dim=1)
    atrophy_idx = probs[0, 1].item()

    # --- 5. COHERENT SCORING ---
    target_mean, target_std = get_age_norm(age)
    current_nwbv = 0.85 - (atrophy_idx * 0.15)
    z_score = (current_nwbv - target_mean) / target_std
    pres_base = 100 * (1 / (1 + np.exp(-z_score)))

    # Weighting: 90% MRI-derived. Optional metadata adds up to 5 points from
    # MMSE and 5 - 2.5 * CDR points from CDR; ``None`` skips a term.
    meta_ref = 0
    if mmse is not None:
        meta_ref += mmse / 30 * 5
    if cdr is not None:
        meta_ref += 5 - cdr * 2.5

    structural_integrity = int(pres_base * 0.9 + meta_ref)
    # Integrity is penalised when a focal anomaly or asymmetry was detected.
    if tumor_flag or asymmetry_flag:
        structural_integrity = int(structural_integrity * 0.6)
    structural_integrity = max(min(structural_integrity, 95), 5)
    morph_consistency = int(structural_integrity * 0.5 + symmetry_score * 0.5)

    if tumor_flag or asymmetry_flag or z_score < -2.0:
        severity = "Elevated"
    elif z_score < -1.0:
        severity = "Mild"
    else:
        severity = "Stable"

    # --- 6. FOCAL SALIENCY (Grad-CAM on the class-1 logit) ---
    outputs[0, 1].backward()
    gradients = model.get_activations_gradient().detach()
    # Channel weights = spatially averaged gradients, broadcast over H and W.
    channel_weights = gradients.mean(dim=(0, 2, 3)).view(1, -1, 1, 1)
    with torch.no_grad():
        activations = model.get_activations(img_tensor) * channel_weights
    heatmap = torch.mean(activations, dim=1).squeeze().numpy()
    heatmap = np.maximum(heatmap, 0)
    # Keep only the strongest ~12% of activations, then scale to [0, 1].
    heatmap[heatmap < np.percentile(heatmap, 88)] = 0
    heatmap /= (np.max(heatmap) + 1e-8)

    # --- 7. SUMMARY METADATA ---
    age_desc = f"~{int(age // 5 * 5)}s Pattern"
    conf_label = "Moderate" if (tumor_flag or asymmetry_flag) else "Preliminary"
    issues = {'tumor_risk': tumor_flag, 'asymmetry': asymmetry_flag, 'balance': symmetry_score, 'conf_label': conf_label, 'confidence': 72}
    metrics = {'health': structural_integrity, 'symmetry': symmetry_score, 'severity': severity}
    summary = generate_evidence_summary(metrics, issues, age_desc)

    regions = []
    if asymmetry_flag:
        regions.append("Hemispheric imbalance zone")
    if tumor_flag:
        regions.append("Local signal irregularity focus")
    if not regions:
        regions.append("Diffuse structural background")

    return {
        "health": structural_integrity,
        "integrity": morph_consistency,
        "symmetry": symmetry_score,
        "severity": severity,
        "confidence": conf_label,
        "contributions": {"Asymmetry": quad_diff / 100, "Voxel Distribution": atrophy_idx, "Metadata Context": meta_ref / 10},
        "layman": summary,
        "regions": ", ".join(regions),
        "heatmap": heatmap,
        "nwbv": current_nwbv,
        "sym_mse": quad_diff,
        "intensity_peak": float(np.max(masked_img)),
        "atrophy_idx": atrophy_idx,
        "z_score": z_score,
    }
# ==========================================
# MAIN APP
# ==========================================
def _load_uploaded_slice(uploaded_file):
    """Decode an uploaded file into a single 2-D grayscale slice.

    NIfTI uploads (``.nii`` / ``.nii.gz``, case-insensitive) return the middle
    slice along the third axis. Extra trailing dimensions (e.g. a 4-D series)
    are reduced to their first volume, and 2-D volumes are returned as is.
    Other uploads are decoded in memory as 8-bit grayscale images.

    Raises:
        ValueError: if the data cannot be decoded into a 2-D slice.
    """
    name = uploaded_file.name.lower()
    data = uploaded_file.getvalue()

    if not name.endswith((".nii", ".nii.gz")):
        image = cv2.imdecode(np.frombuffer(data, dtype=np.uint8), cv2.IMREAD_GRAYSCALE)
        if image is None:
            raise ValueError("unsupported or corrupt image data")
        return image

    # nibabel decides between gzip and plain reading from the file extension,
    # so the temp file must carry the upload's real suffix.
    suffix = ".nii.gz" if name.endswith(".nii.gz") else ".nii"
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
        tmp.write(data)
        tmp_path = Path(tmp.name)
    try:
        # mmap=False reads the voxels into memory so the temp file can be removed.
        vol = nib.load(str(tmp_path), mmap=False).get_fdata()
    finally:
        try:
            tmp_path.unlink()
        except OSError:
            pass  # best-effort cleanup; a leftover temp file is harmless

    while vol.ndim > 3:
        vol = vol[..., 0]
    if vol.ndim == 3:
        vol = vol[:, :, vol.shape[2] // 2]
    if vol.ndim != 2:
        raise ValueError(f"expected a 2-D or 3-D volume, got shape {vol.shape}")
    return vol


def _render_results(img_slice, results):
    """Render the image panels, metric cards, evidence/interpretation and trace table."""
    # --- Visuals ---
    # 8-bit min-max normalised copy of the slice, shared by both panels.
    disp_img = cv2.normalize(
        np.ascontiguousarray(np.nan_to_num(img_slice)), None, 0, 255, cv2.NORM_MINMAX
    ).astype(np.uint8)
    c1, c2 = st.columns(2)
    with c1:
        st.markdown("### **Structural Anatomy**")
        # A 2-D array is rendered as grayscale; st.image's ``channels`` only
        # accepts "RGB"/"BGR", so it is not passed here.
        st.image(disp_img, use_container_width=True)
    with c2:
        st.markdown("### **Structural Attention Areas**")
        heatmap_resized = cv2.resize(results['heatmap'], (img_slice.shape[1], img_slice.shape[0]))
        heatmap_norm = cv2.normalize(heatmap_resized, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
        heatmap_color = cv2.applyColorMap(heatmap_norm, cv2.COLORMAP_MAGMA)
        disp_img_bgr = cv2.cvtColor(disp_img, cv2.COLOR_GRAY2BGR)
        # Tint only where the heatmap is active so the background stays clean.
        active = heatmap_norm > 15
        alpha = 0.55
        blended = disp_img_bgr.copy()
        blended[active] = cv2.addWeighted(heatmap_color, alpha, disp_img_bgr, 1 - alpha, 0)[active]
        st.image(cv2.cvtColor(blended, cv2.COLOR_BGR2RGB), use_container_width=True)

    # --- Metrics ---
    st.markdown("### **Quantitative Analytics**")
    m1, m2, m3, m4 = st.columns(4)
    m1.metric("Structural Integrity", f"{results['health']} / 100")
    m2.metric("Morphology Consistency", f"{results['integrity']} / 100")
    m3.metric("Tissue Symmetry", f"{results['symmetry']} / 100")
    m4.metric("Severity Index", results['severity'])

    # --- Analysis ---
    st.markdown("---")
    ex1, ex2 = st.columns([1, 1.2])
    with ex1:
        st.subheader("πŸ’‘ Evidence Analysis")
        contributions = pd.DataFrame(list(results['contributions'].items()), columns=['F', 'I']).set_index('F')
        st.bar_chart(contributions)
        st.info(f"**Attention Regions:** {results['regions']}")
        st.caption(f"**Confidence Profile:** {results['confidence']}")
    with ex2:
        st.subheader("πŸ“ Clinical Interpretation")
        st.info(results['layman'])
        # NOTE(review): no report file is produced yet (FPDF is imported but unused).
        if st.button("Generate Portfolio Report"):
            st.success("Analysis Ready.")

    # Pipeline trace table and glossary.
    with st.expander("πŸ” Pipeline Trace & Technical Data"):
        st.markdown("### **Technical Analytics Portfolio**")
        trace_data = {
            # 'sym_mse' is a mean absolute quadrant difference, not a squared error.
            "Metric": ["nWBV", "Symmetry Mean Abs. Diff.", "Intensity Peak", "Atrophy Probability", "Peer Deviation (Z)"],
            "Value": [
                f"{results['nwbv']:.3f}",
                f"{results['sym_mse']:.1f}",
                f"{results['intensity_peak']:.1f}",
                f"{results['atrophy_idx']:.3f}",
                f"{results['z_score']:.2f}",
            ],
            "Status": ["Extracted", "Region-Scanned", "Island-Detected", "Computed", "Calibrated"],
        }
        st.table(pd.DataFrame(trace_data))
        st.markdown("---")
        st.markdown("### **Biomarker Glossary**")
        st.write("**Structural Integrity:** Brain volume preservation relative to skull size.")
        st.write("**Tissue Symmetry:** Hemispheric voxel comparison (mean absolute intensity difference).")


def _render_inference_page(model):
    """Render the context controls, file upload, analysis results and disclaimer."""
    # NOTE(review): the original file's indentation was lost, so which widgets
    # sat inside the sidebar/expander was reconstructed — confirm the layout.
    st.title("⚑ Structural Neuro-Analytics")
    with st.sidebar:
        st.header("Context Controls")
        age_input = st.slider("Reference Age", 1, 120, 72)
        has_meta = st.checkbox("Include Cognitive Reference", value=False)
        mmse = st.slider("MMSE Baseline", 0, 30, 27) if has_meta else None
        # None (not 0) when disabled: the analysis only skips the CDR term for
        # None, so 0 would silently add 5 metadata points.
        cdr = st.selectbox("CDR Assessment", [0, 0.5, 1, 2], index=0) if has_meta else None
    uploaded_file = st.file_uploader("Upload MRI Volume / Slice", type=["nii.gz", "nii", "png", "jpg"])

    if uploaded_file:
        try:
            img_slice = _load_uploaded_slice(uploaded_file)
        except Exception as exc:  # user-supplied upload: report it instead of a traceback
            st.error(f"Could not read the uploaded file: {exc}")
        else:
            with st.spinner("Executing Structural Analytics Pipeline..."):
                results = perform_structural_analysis(img_slice, model, age_input, mmse, cdr)
            _render_results(img_slice, results)

    st.markdown('<p class="disclaimer">Disclaimer: Research-use only. Not a medical device. Clinical correlation required.</p>', unsafe_allow_html=True)


def _render_population_page(df):
    """Plot age vs. nWBV with a regression line from the population metadata."""
    st.title("πŸ”¬ Population Morphometry")
    if df.empty or not {"age", "nWBV"}.issubset(df.columns):
        st.info(f"Population metadata unavailable: expected a non-empty {MASTER_CSV.name} with 'age' and 'nWBV' columns.")
        return
    fig, ax = plt.subplots(figsize=(10, 5))
    try:
        sns.regplot(data=df, x='age', y='nWBV', ax=ax, scatter_kws={'alpha': 0.2}, line_kws={'color': 'red'})
        st.pyplot(fig)
    finally:
        plt.close(fig)


def main():
    """Streamlit entry point: sidebar navigation between the two app pages."""
    st.sidebar.title("🧬 NeuroVision")
    st.sidebar.markdown("---")
    model = load_assets()

    # NOTE(review): these labels look like mis-encoded (mojibake) emoji. They
    # are defined once here and compared below, so fix them in this one place.
    inference_label = "⚑ Structural Inference"
    population_label = "πŸ”¬ Population Insights"
    menu = st.sidebar.radio("Navigation", [inference_label, population_label])

    if menu == inference_label:
        _render_inference_page(model)
    elif menu == population_label:
        # Only read the population CSV when its page is actually shown.
        df = pd.read_csv(MASTER_CSV) if MASTER_CSV.exists() else pd.DataFrame()
        _render_population_page(df)
if __name__ == "__main__":
    # Launch the app only when this file is executed as the main script.
    main()