# Spaces: Sleeping  (Hugging Face Space status header — non-code residue
# from the page capture, commented out so the module parses)
| import os | |
| import io | |
| import json | |
| import base64 | |
| import uuid | |
| from datetime import datetime | |
| import numpy as np | |
| from PIL import Image | |
| import torch | |
| import torch.nn as nn | |
| from torchvision import models | |
| from transformers import AutoModel | |
| import albumentations as A | |
| from albumentations.pytorch import ToTensorV2 | |
| from flask import Flask, render_template, request, jsonify, send_from_directory | |
# --- Flask app and global configuration ---
app = Flask(__name__)
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024  # 16MB max upload
# Run inference on GPU when available, otherwise fall back to CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# The five findings the ensemble predicts. Order matters: it must match the
# output order of the trained 5-way classifier heads.
TARGET_LABELS = ["Atelectasis", "Cardiomegaly", "Consolidation", "Edema", "Pleural Effusion"]
# Human-readable metadata per label, merged into API responses.
# NOTE(review): every 'icon' value is an empty string — presumably supplied by
# the front end or lost in transit; confirm against the templates.
LABEL_INFO = {
    "Atelectasis": {
        "description": "Partial or complete collapse of the lung or a section of the lung.",
        "icon": ""
    },
    "Cardiomegaly": {
        "description": "Enlargement of the heart, often indicating heart disease.",
        "icon": ""
    },
    "Consolidation": {
        "description": "Region of lung tissue filled with liquid instead of air.",
        "icon": ""
    },
    "Edema": {
        "description": "Excess fluid in the lungs, often due to heart failure.",
        "icon": ""
    },
    "Pleural Effusion": {
        "description": "Buildup of fluid between the lung and chest wall.",
        "icon": ""
    }
}
# Ensemble mixing weights: RAD-DINO contributes 60%, DenseNet 40%.
ENSEMBLE_WEIGHT_RD = 0.60
ENSEMBLE_WEIGHT_DN = 0.40
# Filesystem locations for model weights and bundled sample images.
MODEL_DIR = os.path.join(os.path.dirname(__file__), 'models')
SAMPLES_DIR = os.path.join(os.path.dirname(__file__), 'samples')
# In-memory store for analysis results (for report generation)
analysis_store = {}
# --- Model Definitions ---
class RADDINOClassifier(nn.Module):
    """Multi-label CXR classifier: RAD-DINO ViT backbone + two-layer MLP head.

    The backbone is built from its config only (randomly initialized);
    trained weights are expected to be loaded afterwards via
    ``load_state_dict`` (see ``load_models``).
    """

    def __init__(self, num_classes=5, dropout=0.3):
        super().__init__()
        from transformers import AutoConfig
        backbone_cfg = AutoConfig.from_pretrained("microsoft/rad-dino")
        self.backbone = AutoModel.from_config(backbone_cfg)
        self.hidden_dim = self.backbone.config.hidden_size
        head = [
            nn.LayerNorm(self.hidden_dim),
            nn.Dropout(dropout),
            nn.Linear(self.hidden_dim, 256),
            nn.GELU(),
            nn.Dropout(dropout / 2),
            nn.Linear(256, num_classes),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        # Use the CLS-token embedding (position 0) as the image representation.
        cls_embedding = self.backbone(x).last_hidden_state[:, 0]
        return self.classifier(cls_embedding)
class DenseNetClassifier(nn.Module):
    """Multi-label CXR classifier: DenseNet121 backbone with a custom MLP head.

    Built with ``weights=None`` (random init); trained weights are loaded
    afterwards via ``load_state_dict`` (see ``load_models``).
    """

    def __init__(self, num_classes=5, dropout=0.4):
        super().__init__()
        self.backbone = models.densenet121(weights=None)
        in_features = self.backbone.classifier.in_features
        # Replace the stock ImageNet classifier with a small two-layer head.
        self.backbone.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(in_features, 256),
            nn.ReLU(),
            nn.Dropout(dropout / 2),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        return self.backbone(x)
# --- Grad-CAM ---
class GradCAM:
    """Gradient-weighted Class Activation Mapping for the DenseNet branch.

    Hooks are registered on ``model.backbone.features.denseblock4`` (the last
    dense block), so ``model`` is expected to be a ``DenseNetClassifier``.
    ``generate`` returns a 2-D saliency map normalized to [0, 1] at the
    resolution of that layer's feature maps.
    """
    def __init__(self, model):
        self.model = model
        self.gradients = None    # filled by the backward hook on each backward()
        self.activations = None  # filled by the forward hook on each forward()
        target_layer = model.backbone.features.denseblock4
        target_layer.register_forward_hook(self._forward_hook)
        target_layer.register_full_backward_hook(self._backward_hook)
    def _forward_hook(self, module, input, output):
        # Cache the target layer's feature maps (detached: no grad needed here).
        self.activations = output.detach()
    def _backward_hook(self, module, grad_input, grad_output):
        # grad_output[0] is the gradient of the scalar target w.r.t. this layer's output.
        self.gradients = grad_output[0].detach()
    def generate(self, input_tensor, class_idx=None):
        """Return a [0, 1] CAM (numpy array) for ``class_idx``.

        If ``class_idx`` is None, the class with the highest mean sigmoid
        probability over the batch is used. Only the first batch element's
        activations/gradients are used to build the map.
        NOTE(review): mutates ``input_tensor`` by enabling requires_grad.
        """
        self.model.eval()
        input_tensor.requires_grad_(True)
        output = self.model(input_tensor)
        if class_idx is None:
            class_idx = output.sigmoid().mean(dim=0).argmax().item()
        self.model.zero_grad()
        target = output[0, class_idx]
        target.backward()
        gradients = self.gradients[0]
        activations = self.activations[0]
        # Global-average-pool the gradients -> one importance weight per channel.
        weights = gradients.mean(dim=(1, 2), keepdim=True)
        cam = (weights * activations).sum(dim=0)
        cam = torch.relu(cam)
        # Min-max normalize to [0, 1]; guard against an all-zero map.
        cam = cam - cam.min()
        if cam.max() > 0:
            cam = cam / cam.max()
        return cam.cpu().numpy()
# --- Image Preprocessing ---
def get_transform(size):
    """Build an eval-time pipeline: resize to size x size, ImageNet-normalize, to CHW tensor."""
    steps = [
        A.Resize(size, size),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ]
    return A.Compose(steps)

# Pre-built pipelines: 384px input for RAD-DINO, 320px for DenseNet.
transform_384 = get_transform(384)
transform_320 = get_transform(320)
def preprocess_image(image_bytes, transform):
    """Decode raw image bytes and prepare a model input.

    Returns a tuple of (1xCxHxW float tensor on DEVICE, original RGB image as
    an HxWx3 uint8 array — untouched by the transform, kept for overlays).
    """
    pil_img = Image.open(io.BytesIO(image_bytes)).convert('RGB')
    raw_rgb = np.array(pil_img)
    batch = transform(image=raw_rgb)['image'].unsqueeze(0).to(DEVICE)
    return batch, raw_rgb
# --- DICOM Support ---
def read_dicom_as_bytes(file_bytes):
    """Convert DICOM file bytes to standard image bytes.

    Scales the raw pixel data to 0-255, corrects MONOCHROME1 inversion, and
    returns PNG-encoded bytes. Raises ValueError on any read/decode failure.
    """
    try:
        import pydicom
        dataset = pydicom.dcmread(io.BytesIO(file_bytes))
        # Min-max scale the raw pixel data to the 0-255 byte range.
        raw = dataset.pixel_array.astype(float)
        if raw.max() != raw.min():
            raw = (raw - raw.min()) / (raw.max() - raw.min()) * 255
        raw = raw.astype(np.uint8)
        # MONOCHROME1 stores inverted intensities (0 = bright), so flip them.
        if getattr(dataset, 'PhotometricInterpretation', None) == 'MONOCHROME1':
            raw = 255 - raw
        out = io.BytesIO()
        Image.fromarray(raw).convert('RGB').save(out, format='PNG')
        return out.getvalue()
    except Exception as e:
        raise ValueError(f"Failed to read DICOM file: {str(e)}")
# --- Heatmap Generation ---
def create_heatmap_overlay(original_img, cam, alpha=0.4):
    """Blend a Grad-CAM map over the original image; return a base64 PNG string.

    original_img: HxWx3 uint8 RGB array; cam: 2-D map in [0, 1];
    alpha: heatmap opacity in the blend.
    """
    import cv2
    height, width = original_img.shape[:2]
    # Upscale the CAM to image resolution and colorize it (JET, converted to RGB).
    scaled = cv2.resize(cam, (width, height))
    colored = cv2.applyColorMap(np.uint8(255 * scaled), cv2.COLORMAP_JET)
    colored = cv2.cvtColor(colored, cv2.COLOR_BGR2RGB)
    blended = np.float32(colored) * alpha + np.float32(original_img) * (1 - alpha)
    blended = np.clip(blended, 0, 255).astype(np.uint8)
    out = io.BytesIO()
    Image.fromarray(blended).save(out, format='PNG')
    return base64.b64encode(out.getvalue()).decode('utf-8')
def image_to_base64(img_np):
    """Encode an HxWx3 uint8 RGB array as a base64 PNG string."""
    out = io.BytesIO()
    Image.fromarray(img_np).save(out, format='PNG')
    return base64.b64encode(out.getvalue()).decode('utf-8')
# --- Load Models ---
# Model handles are created lazily by load_models(). They are initialized to
# None so code that checks "is not None" (run_prediction, health) cannot hit
# a NameError when a weight file is missing or load_models hasn't run yet —
# previously these globals only came into existence inside load_models.
rd_model = None
dn_model = None
grad_cam = None
models_loaded = False

def load_models():
    """Load ensemble weights from MODEL_DIR (idempotent; safe to call per request).

    Missing weight files are tolerated: the corresponding model stays None and
    a warning is printed, so the app degrades to a single-model (or no-model)
    configuration instead of crashing.
    """
    global rd_model, dn_model, grad_cam, models_loaded
    if models_loaded:
        return
    # Set the flag up front so a failed attempt is not retried on every request.
    models_loaded = True
    rd_path = os.path.join(MODEL_DIR, 'rad_dino_best.pth')
    dn_path = os.path.join(MODEL_DIR, 'densenet_best.pth')
    if os.path.exists(dn_path):
        print("Loading DenseNet121...")
        dn_model = DenseNetClassifier(num_classes=5, dropout=0.4)
        state = torch.load(dn_path, map_location='cpu', weights_only=True)
        dn_model.load_state_dict(state)
        dn_model.to(DEVICE).eval()
        # Grad-CAM is only available for the DenseNet branch.
        grad_cam = GradCAM(dn_model)
        print("[OK] DenseNet121 loaded")
    else:
        print(f"[WARN] DenseNet weights not found at {dn_path}")
    if os.path.exists(rd_path):
        print("Loading RAD-DINO...")
        rd_model = RADDINOClassifier(num_classes=5, dropout=0.3)
        state = torch.load(rd_path, map_location='cpu', weights_only=True)
        rd_model.load_state_dict(state)
        rd_model.to(DEVICE).eval()
        print("[OK] RAD-DINO loaded")
    else:
        print(f"[WARN] RAD-DINO weights not found at {rd_path}")
# --- Core prediction logic ---
def run_prediction(image_bytes):
    """Run ensemble prediction and return results dict.

    Returns None when no model is loaded. On success the dict contains, per
    label: the ensemble probability (percent), a risk bucket, description,
    a Grad-CAM overlay (DenseNet branch only), and the per-model
    probabilities; plus the original image as base64 PNG and a summary of
    which models contributed.
    """
    heatmaps = {}
    img_np = None
    # DenseNet prediction + Grad-CAM
    dn_probs = None
    if dn_model is not None:
        tensor_320, img_np = preprocess_image(image_bytes, transform_320)
        with torch.no_grad():
            logits = dn_model(tensor_320)
            dn_probs = torch.sigmoid(logits).cpu().numpy()[0]
        # Fix: reuse the already-preprocessed tensor for every class instead of
        # re-decoding and re-transforming the raw bytes once per label (the
        # original performed 5 redundant decode+resize passes here). Each
        # generate() call runs its own forward pass, so the tensor is reusable.
        for i, label in enumerate(TARGET_LABELS):
            cam = grad_cam.generate(tensor_320, class_idx=i)
            heatmaps[label] = create_heatmap_overlay(img_np, cam, alpha=0.45)
    # RAD-DINO prediction
    rd_probs = None
    if rd_model is not None:
        tensor_384, img_np = preprocess_image(image_bytes, transform_384)
        with torch.no_grad():
            logits = rd_model(tensor_384)
            rd_probs = torch.sigmoid(logits).cpu().numpy()[0]
    # Ensemble: weighted average when both ran, else whichever is available.
    if rd_probs is not None and dn_probs is not None:
        ensemble_probs = ENSEMBLE_WEIGHT_RD * rd_probs + ENSEMBLE_WEIGHT_DN * dn_probs
    elif rd_probs is not None:
        ensemble_probs = rd_probs
    elif dn_probs is not None:
        ensemble_probs = dn_probs
    else:
        return None
    original_b64 = image_to_base64(img_np)
    results = []
    for i, label in enumerate(TARGET_LABELS):
        prob = float(ensemble_probs[i])
        # Risk buckets: >60% high, >30% medium, otherwise low.
        risk = 'high' if prob > 0.6 else ('medium' if prob > 0.3 else 'low')
        results.append({
            'label': label,
            'probability': round(prob * 100, 1),
            'risk': risk,
            'description': LABEL_INFO[label]['description'],
            'icon': LABEL_INFO[label]['icon'],
            'heatmap': heatmaps.get(label, ''),
            'rd_prob': round(float(rd_probs[i]) * 100, 1) if rd_probs is not None else None,
            'dn_prob': round(float(dn_probs[i]) * 100, 1) if dn_probs is not None else None,
        })
    # Most probable findings first.
    results.sort(key=lambda x: x['probability'], reverse=True)
    return {
        'success': True,
        'results': results,
        'original_image': original_b64,
        'models_used': {
            'rad_dino': rd_probs is not None,
            'densenet': dn_probs is not None,
            'ensemble': rd_probs is not None and dn_probs is not None,
        }
    }
# --- Routes ---
# NOTE(review): the @app.route decorators were missing, so none of these
# handlers were registered with Flask. Paths below are inferred from the
# handler names and the templates they render — TODO confirm against the
# front end's links.
@app.route('/')
def index():
    """Landing page."""
    return render_template('index.html')

@app.route('/analyze')
def analyze():
    """Upload-and-analyze page."""
    return render_template('analyze.html')

@app.route('/login')
def login():
    """Login page."""
    return render_template('login.html')

@app.route('/register')
def register():
    """Registration page."""
    return render_template('register.html')

@app.route('/about')
def about():
    """About page."""
    return render_template('about.html')

@app.route('/history')
def history():
    """Analysis history page."""
    return render_template('history.html')

@app.route('/compare')
def compare():
    """Side-by-side comparison page."""
    return render_template('compare.html')
@app.route('/report/<analysis_id>')  # NOTE(review): decorator was missing; path inferred — confirm
def report(analysis_id):
    """Render the report for a stored analysis.

    Shows an error page when the id is unknown (never stored, or evicted
    from the in-memory store); otherwise passes the analysis as JSON.
    """
    data = analysis_store.get(analysis_id)
    if not data:
        return render_template('report.html', error=True)
    return render_template('report.html', error=False, data=json.dumps(data))
@app.route('/samples')  # NOTE(review): decorator was missing; path inferred — confirm
def samples_page():
    """Analyze page with the sample-image gallery enabled."""
    return render_template('analyze.html', show_samples=True)
# --- API Endpoints ---
@app.route('/api/predict', methods=['POST'])  # NOTE(review): decorator was missing; path inferred — confirm
def predict():
    """Accept an uploaded image (or DICOM), run the ensemble, return JSON.

    Validates the file extension, converts DICOM to PNG bytes, stores the
    result under a short random id for later report rendering, and evicts
    the oldest entry once the in-memory store exceeds 50 analyses.
    """
    load_models()
    if 'file' not in request.files:
        return jsonify({'error': 'No file uploaded'}), 400
    file = request.files['file']
    if file.filename == '':
        return jsonify({'error': 'No file selected'}), 400
    allowed = {'png', 'jpg', 'jpeg', 'bmp', 'dcm', 'dicom'}
    ext = file.filename.rsplit('.', 1)[-1].lower() if '.' in file.filename else ''
    if ext not in allowed:
        return jsonify({'error': f'File type .{ext} not supported'}), 400
    image_bytes = file.read()
    # DICOM handling: normalize to standard PNG bytes before prediction.
    if ext in ('dcm', 'dicom'):
        try:
            image_bytes = read_dicom_as_bytes(image_bytes)
        except ValueError as e:
            return jsonify({'error': str(e)}), 400
    result = run_prediction(image_bytes)
    if result is None:
        return jsonify({'error': 'No models loaded'}), 500
    # Store result for report generation
    analysis_id = str(uuid.uuid4())[:8]
    result['analysis_id'] = analysis_id
    result['timestamp'] = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    result['filename'] = file.filename
    analysis_store[analysis_id] = result
    # Keep only last 50 analyses in memory (dicts preserve insertion order,
    # so the first key is the oldest entry).
    if len(analysis_store) > 50:
        oldest_key = next(iter(analysis_store))
        del analysis_store[oldest_key]
    return jsonify(result)
@app.route('/api/samples')  # NOTE(review): decorator was missing; path inferred — confirm
def api_samples():
    """List available sample X-ray images.

    Scans SAMPLES_DIR for common raster formats and returns, per file, its
    filename, a title-cased display name, and the URL served by serve_sample.
    """
    samples = []
    if os.path.exists(SAMPLES_DIR):
        for f in os.listdir(SAMPLES_DIR):
            if f.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp')):
                # Derive a display name from the filename: "left_lung-01" -> "Left Lung 01".
                name = os.path.splitext(f)[0].replace('_', ' ').replace('-', ' ').title()
                samples.append({
                    'filename': f,
                    'name': name,
                    'url': f'/samples/{f}'
                })
    return jsonify(samples)
@app.route('/samples/<filename>')  # NOTE(review): decorator was missing; path matches the URL built in api_samples
def serve_sample(filename):
    """Serve a bundled sample image (send_from_directory guards against path traversal)."""
    return send_from_directory(SAMPLES_DIR, filename)
@app.route('/api/health')  # NOTE(review): decorator was missing; path inferred — confirm
def health():
    """Liveness endpoint: reports which model weights are loaded and the device.

    Uses globals().get so the check is NameError-safe even before
    load_models() has ever run (the model globals are created lazily).
    """
    return jsonify({
        'status': 'ok',
        'models': {
            'rad_dino': globals().get('rd_model') is not None,
            'densenet': globals().get('dn_model') is not None,
        },
        'device': str(DEVICE)
    })
if __name__ == '__main__':
    # Ensure the expected working directories exist before serving.
    for directory in (MODEL_DIR, SAMPLES_DIR, 'uploads'):
        os.makedirs(directory, exist_ok=True)
    # Bind to all interfaces; port comes from the environment (default 7860).
    app.run(debug=False, host='0.0.0.0', port=int(os.environ.get('PORT', 7860)))