# ChestXpert / app.py
# Last change: fix lazy model loading and remove HF Hub download (commit 19519bf, 31puneet)
import os
import io
import json
import base64
import uuid
from datetime import datetime
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
from torchvision import models
from transformers import AutoModel
import albumentations as A
from albumentations.pytorch import ToTensorV2
from flask import Flask, render_template, request, jsonify, send_from_directory
app = Flask(__name__)
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024 # 16MB max upload

# Compute device: GPU when available, otherwise CPU. All models and input
# tensors are moved to this device before inference.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The five findings predicted by the app (multi-label; each gets its own
# sigmoid probability, see run_prediction()).
TARGET_LABELS = ["Atelectasis", "Cardiomegaly", "Consolidation", "Edema", "Pleural Effusion"]

# Per-label UI metadata: a plain-language description plus an icon slot
# (icon strings are currently empty placeholders).
LABEL_INFO = {
    "Atelectasis": {
        "description": "Partial or complete collapse of the lung or a section of the lung.",
        "icon": ""
    },
    "Cardiomegaly": {
        "description": "Enlargement of the heart, often indicating heart disease.",
        "icon": ""
    },
    "Consolidation": {
        "description": "Region of lung tissue filled with liquid instead of air.",
        "icon": ""
    },
    "Edema": {
        "description": "Excess fluid in the lungs, often due to heart failure.",
        "icon": ""
    },
    "Pleural Effusion": {
        "description": "Buildup of fluid between the lung and chest wall.",
        "icon": ""
    }
}

# Weighted-average ensemble coefficients (RAD-DINO weighted higher).
ENSEMBLE_WEIGHT_RD = 0.60
ENSEMBLE_WEIGHT_DN = 0.40

# Local directories for model weights and bundled sample images.
MODEL_DIR = os.path.join(os.path.dirname(__file__), 'models')
SAMPLES_DIR = os.path.join(os.path.dirname(__file__), 'samples')

# In-memory store for analysis results (for report generation)
analysis_store = {}
# --- Model Definitions ---
class RADDINOClassifier(nn.Module):
    """Multi-label classifier head on top of a RAD-DINO ViT backbone.

    The backbone is instantiated from config only (no pretrained weights
    downloaded); trained weights are expected to be loaded afterwards via
    load_state_dict. Classification uses the [CLS] token embedding.
    """

    def __init__(self, num_classes=5, dropout=0.3):
        super().__init__()
        from transformers import AutoConfig
        config = AutoConfig.from_pretrained("microsoft/rad-dino")
        self.backbone = AutoModel.from_config(config)
        self.hidden_dim = self.backbone.config.hidden_size
        # Two-layer head: LayerNorm -> 256-d GELU bottleneck -> logits.
        head_layers = [
            nn.LayerNorm(self.hidden_dim),
            nn.Dropout(dropout),
            nn.Linear(self.hidden_dim, 256),
            nn.GELU(),
            nn.Dropout(dropout / 2),
            nn.Linear(256, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        # Take the [CLS] token (position 0) as the image representation.
        cls_embedding = self.backbone(x).last_hidden_state[:, 0]
        return self.classifier(cls_embedding)
class DenseNetClassifier(nn.Module):
    """DenseNet121 with its stock classifier swapped for a two-layer
    multi-label head (256-d ReLU bottleneck -> logits).

    Constructed with weights=None; trained weights are expected to be
    loaded afterwards via load_state_dict.
    """

    def __init__(self, num_classes=5, dropout=0.4):
        super().__init__()
        self.backbone = models.densenet121(weights=None)
        in_features = self.backbone.classifier.in_features
        head = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(in_features, 256),
            nn.ReLU(),
            nn.Dropout(dropout / 2),
            nn.Linear(256, num_classes),
        )
        self.backbone.classifier = head

    def forward(self, x):
        return self.backbone(x)
# --- Grad-CAM ---
class GradCAM:
    """Grad-CAM visualiser bound to the DenseNet classifier.

    Hooks model.backbone.features.denseblock4 to capture its forward
    activations and the gradients flowing back into it, then combines
    them into a class activation map.
    """

    def __init__(self, model):
        self.model = model
        self.gradients = None    # gradient of the target logit w.r.t. the hooked layer's output
        self.activations = None  # hooked layer's forward output
        target_layer = model.backbone.features.denseblock4
        # Hooks stay registered for the lifetime of the model instance.
        target_layer.register_forward_hook(self._forward_hook)
        target_layer.register_full_backward_hook(self._backward_hook)

    def _forward_hook(self, module, input, output):
        self.activations = output.detach()

    def _backward_hook(self, module, grad_input, grad_output):
        self.gradients = grad_output[0].detach()

    def generate(self, input_tensor, class_idx=None):
        """Return a CAM for `class_idx` as a 2-D numpy array normalised to [0, 1].

        If class_idx is None, the class with the highest mean sigmoid
        probability across the batch is chosen. NOTE: sets requires_grad
        on the caller's input_tensor (mutates it in place).
        """
        self.model.eval()
        input_tensor.requires_grad_(True)
        output = self.model(input_tensor)
        if class_idx is None:
            class_idx = output.sigmoid().mean(dim=0).argmax().item()
        self.model.zero_grad()
        # Backprop only the first sample's logit for the chosen class.
        target = output[0, class_idx]
        target.backward()
        gradients = self.gradients[0]
        activations = self.activations[0]
        # Grad-CAM: global-average-pool the gradients into per-channel
        # weights, then take the weighted sum over channels.
        weights = gradients.mean(dim=(1, 2), keepdim=True)
        cam = (weights * activations).sum(dim=0)
        cam = torch.relu(cam)
        # Min-max normalise to [0, 1]; guard against an all-zero map.
        cam = cam - cam.min()
        if cam.max() > 0:
            cam = cam / cam.max()
        return cam.cpu().numpy()
# --- Image Preprocessing ---
def get_transform(size):
    """Build the inference preprocessing pipeline for a square input of `size`.

    Resizes, normalises with ImageNet mean/std, and converts to a CHW tensor.
    """
    steps = [
        A.Resize(size, size),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ]
    return A.Compose(steps)
# Pre-built pipelines for each model's input resolution:
# 384px feeds RAD-DINO, 320px feeds DenseNet (see run_prediction()).
transform_384 = get_transform(384)
transform_320 = get_transform(320)
def preprocess_image(image_bytes, transform):
    """Decode raw image bytes and apply `transform`.

    Returns (tensor, ndarray): a batched (1, C, H, W) tensor on DEVICE and
    the original decoded RGB image as a numpy array.
    """
    pil_img = Image.open(io.BytesIO(image_bytes)).convert('RGB')
    rgb = np.array(pil_img)
    tensor = transform(image=rgb)['image'].unsqueeze(0).to(DEVICE)
    return tensor, rgb
# --- DICOM Support ---
def read_dicom_as_bytes(file_bytes):
    """Convert DICOM file bytes to standard PNG image bytes.

    Windows the pixel data to 0-255, inverts MONOCHROME1 images (where a
    smaller stored value means brighter), and returns PNG-encoded bytes.

    Raises:
        ValueError: if the bytes cannot be parsed or decoded as DICOM.
    """
    try:
        import pydicom
        ds = pydicom.dcmread(io.BytesIO(file_bytes))
        pixel_array = ds.pixel_array
        # Normalize to 0-255
        arr = pixel_array.astype(float)
        if arr.max() != arr.min():
            arr = (arr - arr.min()) / (arr.max() - arr.min()) * 255
        else:
            # Flat (constant-valued) image: the raw values could exceed 255
            # and would wrap when cast to uint8 below, so map to zeros.
            arr = np.zeros_like(arr)
        arr = arr.astype(np.uint8)
        # Handle MONOCHROME1 (inverted)
        if hasattr(ds, 'PhotometricInterpretation'):
            if ds.PhotometricInterpretation == 'MONOCHROME1':
                arr = 255 - arr
        img = Image.fromarray(arr).convert('RGB')
        buffer = io.BytesIO()
        img.save(buffer, format='PNG')
        return buffer.getvalue()
    except Exception as e:
        # Chain the original cause so the traceback shows the real failure.
        raise ValueError(f"Failed to read DICOM file: {str(e)}") from e
# --- Heatmap Generation ---
def create_heatmap_overlay(original_img, cam, alpha=0.4):
    """Blend a Grad-CAM map over the original image.

    The CAM is resized to the image, colour-mapped (JET), alpha-blended,
    and returned as a base64-encoded PNG string.
    """
    import cv2
    height, width = original_img.shape[:2]
    resized_cam = cv2.resize(cam, (width, height))
    # applyColorMap emits BGR; convert to RGB before blending.
    colored = cv2.applyColorMap(np.uint8(255 * resized_cam), cv2.COLORMAP_JET)
    colored = cv2.cvtColor(colored, cv2.COLOR_BGR2RGB)
    blended = np.float32(colored) * alpha + np.float32(original_img) * (1 - alpha)
    blended = np.clip(blended, 0, 255).astype(np.uint8)
    out = io.BytesIO()
    Image.fromarray(blended).save(out, format='PNG')
    return base64.b64encode(out.getvalue()).decode('utf-8')
def image_to_base64(img_np):
    """Encode an RGB numpy image as a base64 PNG string."""
    buf = io.BytesIO()
    Image.fromarray(img_np).save(buf, format='PNG')
    return base64.b64encode(buf.getvalue()).decode('utf-8')
# --- Load Models ---
# Lazily-initialised model handles; these stay None until load_models()
# successfully loads the corresponding weight file.
rd_model = None
dn_model = None
grad_cam = None
models_loaded = False


def load_models():
    """Load the DenseNet and RAD-DINO classifiers from MODEL_DIR (idempotent).

    Binds the module globals rd_model / dn_model / grad_cam. Pre-binding
    them to None above fixes a latent NameError: run_prediction() and the
    /health route test `x is not None`, but when a weight file was missing
    the corresponding global was never assigned at all.
    """
    global rd_model, dn_model, grad_cam, models_loaded
    if models_loaded:
        return
    # Set the flag first so repeated requests don't rescan the disk when
    # weights are absent.
    models_loaded = True
    rd_path = os.path.join(MODEL_DIR, 'rad_dino_best.pth')
    dn_path = os.path.join(MODEL_DIR, 'densenet_best.pth')
    if os.path.exists(dn_path):
        print("Loading DenseNet121...")
        dn_model = DenseNetClassifier(num_classes=5, dropout=0.4)
        state = torch.load(dn_path, map_location='cpu', weights_only=True)
        dn_model.load_state_dict(state)
        dn_model.to(DEVICE).eval()
        # Grad-CAM is only available for the DenseNet model.
        grad_cam = GradCAM(dn_model)
        print("[OK] DenseNet121 loaded")
    else:
        print(f"[WARN] DenseNet weights not found at {dn_path}")
    if os.path.exists(rd_path):
        print("Loading RAD-DINO...")
        rd_model = RADDINOClassifier(num_classes=5, dropout=0.3)
        state = torch.load(rd_path, map_location='cpu', weights_only=True)
        rd_model.load_state_dict(state)
        rd_model.to(DEVICE).eval()
        print("[OK] RAD-DINO loaded")
    else:
        print(f"[WARN] RAD-DINO weights not found at {rd_path}")
# --- Core prediction logic ---
def run_prediction(image_bytes):
    """Run ensemble prediction and return results dict.

    Uses the module globals rd_model / dn_model / grad_cam bound by
    load_models(). Returns None when neither model is loaded; otherwise a
    dict with per-label results (sorted by descending probability), the
    original image as base64, and which models contributed.
    """
    heatmaps = {}
    # DenseNet prediction + Grad-CAM
    dn_probs = None
    if dn_model is not None:
        tensor_320, img_np = preprocess_image(image_bytes, transform_320)
        with torch.no_grad():
            logits = dn_model(tensor_320)
            dn_probs = torch.sigmoid(logits).cpu().numpy()[0]
        # One Grad-CAM per label. A fresh tensor is preprocessed each pass
        # because generate() flips requires_grad and backprops through it.
        for i, label in enumerate(TARGET_LABELS):
            tensor_for_cam, _ = preprocess_image(image_bytes, transform_320)
            cam = grad_cam.generate(tensor_for_cam, class_idx=i)
            heatmaps[label] = create_heatmap_overlay(img_np, cam, alpha=0.45)
    # RAD-DINO prediction
    rd_probs = None
    if rd_model is not None:
        tensor_384, img_np = preprocess_image(image_bytes, transform_384)
        with torch.no_grad():
            logits = rd_model(tensor_384)
            rd_probs = torch.sigmoid(logits).cpu().numpy()[0]
    # Ensemble: weighted average when both ran, else whichever is available.
    if rd_probs is not None and dn_probs is not None:
        ensemble_probs = ENSEMBLE_WEIGHT_RD * rd_probs + ENSEMBLE_WEIGHT_DN * dn_probs
    elif rd_probs is not None:
        ensemble_probs = rd_probs
    elif dn_probs is not None:
        ensemble_probs = dn_probs
    else:
        return None
    # img_np is guaranteed bound here: at least one model branch above ran.
    original_b64 = image_to_base64(img_np)
    results = []
    for i, label in enumerate(TARGET_LABELS):
        prob = float(ensemble_probs[i])
        # Risk buckets: >60% high, >30% medium, otherwise low.
        risk = 'high' if prob > 0.6 else ('medium' if prob > 0.3 else 'low')
        results.append({
            'label': label,
            'probability': round(prob * 100, 1),
            'risk': risk,
            'description': LABEL_INFO[label]['description'],
            'icon': LABEL_INFO[label]['icon'],
            'heatmap': heatmaps.get(label, ''),
            'rd_prob': round(float(rd_probs[i]) * 100, 1) if rd_probs is not None else None,
            'dn_prob': round(float(dn_probs[i]) * 100, 1) if dn_probs is not None else None,
        })
    # Highest-probability findings first.
    results.sort(key=lambda x: x['probability'], reverse=True)
    return {
        'success': True,
        'results': results,
        'original_image': original_b64,
        'models_used': {
            'rad_dino': rd_probs is not None,
            'densenet': dn_probs is not None,
            'ensemble': rd_probs is not None and dn_probs is not None,
        }
    }
# --- Routes ---
@app.route('/')
def index():
    """Landing page."""
    return render_template('index.html')


@app.route('/analyze')
def analyze():
    """Upload-and-analyze page."""
    return render_template('analyze.html')


@app.route('/login')
def login():
    """Login page (template only; no auth logic visible here)."""
    return render_template('login.html')


@app.route('/register')
def register():
    """Registration page (template only; no auth logic visible here)."""
    return render_template('register.html')


@app.route('/about')
def about():
    """About page."""
    return render_template('about.html')


@app.route('/history')
def history():
    """Analysis history page."""
    return render_template('history.html')


@app.route('/compare')
def compare():
    """Comparison page."""
    return render_template('compare.html')


@app.route('/report/<analysis_id>')
def report(analysis_id):
    """Render a stored analysis as a report.

    Shows the error variant when the id is unknown (e.g. evicted from the
    in-memory store or the server restarted).
    """
    data = analysis_store.get(analysis_id)
    if not data:
        return render_template('report.html', error=True)
    return render_template('report.html', error=False, data=json.dumps(data))


@app.route('/samples')
def samples_page():
    """Analyze page with the sample gallery pre-opened."""
    return render_template('analyze.html', show_samples=True)
# --- API Endpoints ---
@app.route('/predict', methods=['POST'])
def predict():
    """Handle an uploaded X-ray: validate, run the ensemble, cache the result.

    Returns the prediction JSON (tagged with an analysis id for /report),
    or a 400/500 error payload.
    """
    load_models()
    if 'file' not in request.files:
        return jsonify({'error': 'No file uploaded'}), 400
    upload = request.files['file']
    if upload.filename == '':
        return jsonify({'error': 'No file selected'}), 400
    allowed = {'png', 'jpg', 'jpeg', 'bmp', 'dcm', 'dicom'}
    ext = ''
    if '.' in upload.filename:
        ext = upload.filename.rsplit('.', 1)[-1].lower()
    if ext not in allowed:
        return jsonify({'error': f'File type .{ext} not supported'}), 400
    image_bytes = upload.read()
    # DICOM files are converted to PNG bytes before the common pipeline.
    if ext in ('dcm', 'dicom'):
        try:
            image_bytes = read_dicom_as_bytes(image_bytes)
        except ValueError as e:
            return jsonify({'error': str(e)}), 400
    result = run_prediction(image_bytes)
    if result is None:
        return jsonify({'error': 'No models loaded'}), 500
    # Tag the result so /report/<analysis_id> can render it later.
    analysis_id = str(uuid.uuid4())[:8]
    result['analysis_id'] = analysis_id
    result['timestamp'] = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    result['filename'] = upload.filename
    analysis_store[analysis_id] = result
    # Evict the oldest entry once more than 50 analyses are cached
    # (dict preserves insertion order, so next(iter(...)) is the oldest).
    while len(analysis_store) > 50:
        del analysis_store[next(iter(analysis_store))]
    return jsonify(result)
@app.route('/api/samples')
def api_samples():
    """List available sample X-ray images."""
    samples = []
    if os.path.exists(SAMPLES_DIR):
        for fname in os.listdir(SAMPLES_DIR):
            if not fname.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp')):
                continue
            # Turn "some_file-name.png" into "Some File Name" for display.
            pretty = os.path.splitext(fname)[0].replace('_', ' ').replace('-', ' ').title()
            samples.append({
                'filename': fname,
                'name': pretty,
                'url': f'/samples/{fname}'
            })
    return jsonify(samples)
@app.route('/samples/<path:filename>')
def serve_sample(filename):
    """Serve a bundled sample image from SAMPLES_DIR."""
    return send_from_directory(SAMPLES_DIR, filename)
@app.route('/health')
def health():
    """Liveness probe: reports which models are loaded and the compute device.

    Reads the model handles via globals().get() because rd_model/dn_model
    are only bound inside load_models(); hitting /health before the first
    /predict would otherwise raise a NameError.
    """
    return jsonify({
        'status': 'ok',
        'models': {
            'rad_dino': globals().get('rd_model') is not None,
            'densenet': globals().get('dn_model') is not None,
        },
        'device': str(DEVICE)
    })
if __name__ == '__main__':
    # Make sure the expected working directories exist before serving.
    for directory in (MODEL_DIR, SAMPLES_DIR, 'uploads'):
        os.makedirs(directory, exist_ok=True)
    # Bind to all interfaces; PORT defaults to 7860 (Hugging Face Spaces convention).
    app.run(debug=False, host='0.0.0.0', port=int(os.environ.get('PORT', 7860)))