# NOTE: the original capture began with Hugging Face Spaces page chrome
# ("Spaces: / Sleeping / Sleeping"), which is not source code; it is
# preserved here only as this comment so the module parses.
| import os | |
| import cv2 | |
| import numpy as np | |
| from flask import Flask, request, render_template, jsonify, send_from_directory | |
| from werkzeug.utils import secure_filename | |
| from datetime import datetime | |
| import base64 | |
| from io import BytesIO | |
| from PIL import Image | |
| import SimpleITK as sitk | |
| from skimage.transform import resize | |
| # Suppress TensorFlow warnings | |
| os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' | |
| import tensorflow as tf | |
| from tensorflow.keras.models import load_model | |
# Flask application setup and upload configuration.
app = Flask(__name__)
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024  # Reject request bodies larger than 16 MB
app.config['UPLOAD_FOLDER'] = 'uploads'
app.config['ALLOWED_EXTENSIONS'] = {'mha'}  # Only MRI .mha files
# Create uploads folder with proper permissions
# NOTE(review): mode=0o777 makes the folder world-writable — presumably needed
# for the containerized host; confirm this is intended before reusing elsewhere.
os.makedirs(app.config['UPLOAD_FOLDER'], mode=0o777, exist_ok=True)
# ---- Model loading (import-time side effect) ----
# The segmentation model is loaded once at startup; a failure re-raises so
# the server never starts without a usable model.
print("Loading Brain Segmentation Model with TensorFlow 2.15...")
import warnings
warnings.filterwarnings('ignore')
try:
    # compile=False: inference only, so the optimizer/loss config stored in
    # the HDF5 file is not restored (and need not deserialize cleanly).
    model = load_model('brain1.h5', compile=False)
    print("✓ Model loaded successfully with TensorFlow 2.15!")
except Exception as e:
    # Print a diagnostic plus the full traceback, then abort startup.
    print(f"❌ Error loading model: {e}")
    print("\n⚠️ If you see 'groups' parameter error:")
    print(" Model needs TensorFlow 2.15 (not 2.16+)")
    import traceback
    traceback.print_exc()
    raise
def allowed_file(filename):
    """Return True if *filename* carries an extension listed in ALLOWED_EXTENSIONS."""
    if '.' not in filename:
        return False
    extension = filename.rsplit('.', 1)[1].lower()
    return extension in app.config['ALLOWED_EXTENSIONS']
def preprocess_image(image_path, slice_index=95):
    """Load and preprocess an MHA volume for the segmentation model.

    Mirrors the training pipeline: resize the whole volume to
    (155, 160, 160), extract one axial slice, and z-score normalize it.

    Args:
        image_path: Path to a .mha file readable by SimpleITK.
        slice_index: Axial slice to extract after resizing. Defaults to 95,
            the middle of the 60-130 slice range used during training.
            (Generalized from the previously hard-coded constant.)

    Returns:
        Tuple of (model input of shape (1, 1, 160, 160) float32 in
        channels-first NCHW layout, un-normalized (160, 160) slice kept
        for visualization).

    Raises:
        ValueError: If the file cannot be read/processed or *slice_index*
            is out of range.
    """
    try:
        # Read MHA file using SimpleITK; array comes back as (depth, H, W).
        volume = sitk.GetArrayFromImage(sitk.ReadImage(image_path))
        print(f"Original MHA shape: {volume.shape}")

        # Resize to (155, 160, 160) - same as training.
        volume_resized = resize(volume, (155, 160, 160), preserve_range=True)

        if not 0 <= slice_index < volume_resized.shape[0]:
            raise ValueError(
                f"slice_index {slice_index} out of range for depth "
                f"{volume_resized.shape[0]}")

        img_slice = volume_resized[slice_index, :, :]  # (160, 160)
        original_slice = img_slice.copy()  # keep raw values for display

        # Z-score normalization (same as training); epsilon guards a
        # constant slice from dividing by zero.
        img_normalized = (img_slice - img_slice.mean()) / (img_slice.std() + 1e-8)
        img_normalized = img_normalized.astype(np.float32)
        print(f"Slice shape: {img_normalized.shape}")

        # Model expects (batch, channels, height, width) = (None, 1, 160, 160).
        img_input = img_normalized[np.newaxis, np.newaxis, :, :]
        print(f"Model input shape: {img_input.shape}")
        return img_input, original_slice
    except ValueError:
        raise  # already a meaningful message; do not re-wrap
    except Exception as e:
        raise ValueError(f"Failed to read MHA file: {str(e)}")
def postprocess_mask(mask, original_shape, threshold=0.5):
    """Convert a raw model probability map into a binary uint8 mask.

    Args:
        mask: Model output for one sample in channels-first layout with
            possible singleton axes, e.g. (1, 1, H, W) or (1, H, W).
        original_shape: (height, width) to resize the mask back to.
        threshold: Probability cutoff separating tumor from background.

    Returns:
        uint8 array of shape *original_shape* with values 0 or 255.

    Raises:
        ValueError: If the mask is not 2-D after removing singleton axes.
    """
    # np.squeeze drops ALL singleton axes in a single call, so one call
    # suffices. The previous `while ndim > 2: squeeze` loop would spin
    # forever on non-singleton extra axes (e.g. a (3, H, W) multi-class
    # output, where squeeze is a no-op) — fail loudly instead.
    mask = np.squeeze(mask)
    if mask.ndim != 2:
        raise ValueError(f"Expected a 2-D mask after squeeze, got shape {mask.shape}")

    # cv2.resize takes (width, height), i.e. reversed numpy order.
    mask_resized = cv2.resize(mask, (original_shape[1], original_shape[0]))

    # Threshold to a 0/255 binary mask.
    return (mask_resized > threshold).astype(np.uint8) * 255
def create_overlay(original, mask):
    """Blend the binary tumor mask over the original slice in red.

    Args:
        original: Grayscale (H, W) or RGB (H, W, 3) uint8 image.
        mask: uint8 mask (0/255) with the same height/width.

    Returns:
        RGB uint8 image: 70% original + 30% colored mask.
    """
    # Ensure original is RGB.
    if len(original.shape) == 2:
        original_rgb = cv2.cvtColor(original, cv2.COLOR_GRAY2RGB)
    else:
        original_rgb = original.copy()

    # Create colored mask (red for tumor). Downstream encoding interprets
    # the array as RGB (PIL mode='RGB' in img_to_base64), so red is channel
    # 0 — the previous code wrote channel 2 (a BGR habit), which rendered
    # the tumor in blue.
    colored_mask = np.zeros_like(original_rgb)
    colored_mask[:, :, 0] = mask  # Red channel (RGB order)

    # Blend mask over the image.
    return cv2.addWeighted(original_rgb, 0.7, colored_mask, 0.3, 0)
def img_to_base64(img_array):
    """Encode a numpy image (grayscale or RGB) as a base64 PNG data URI."""
    # Scale assumed-[0,1] floats up to uint8 range.
    if img_array.dtype != np.uint8:
        img_array = (img_array * 255).astype(np.uint8)

    # 2-D arrays are grayscale ('L'); anything else is treated as RGB.
    mode = 'L' if len(img_array.shape) == 2 else 'RGB'
    pil_image = Image.fromarray(img_array, mode=mode)

    # Render to an in-memory PNG and base64-encode the bytes.
    buffer = BytesIO()
    pil_image.save(buffer, format='PNG')
    encoded = base64.b64encode(buffer.getvalue()).decode('utf-8')

    return f"data:image/png;base64,{encoded}"
@app.route('/')
def index():
    """Serve the upload page (templates/index.html)."""
    # NOTE(review): the captured source had lost all @app.route decorators,
    # leaving no registered endpoints; '/' is restored as the root route.
    return render_template('index.html')
@app.route('/predict', methods=['POST'])
def predict():
    """Handle an uploaded .mha MRI file and return segmentation results.

    Expects a multipart form with a 'file' field. On success responds with
    JSON containing base64 PNG data URIs for the original slice, the binary
    mask and the overlay, plus tumor statistics; on failure responds with
    an {'error': ...} payload and a 4xx/5xx status.
    """
    # NOTE(review): the captured source had lost all @app.route decorators;
    # the '/predict' POST route is restored to match the upload flow.
    try:
        # --- Validate the upload ---
        if 'file' not in request.files:
            return jsonify({'error': 'No file uploaded'}), 400
        file = request.files['file']
        if file.filename == '':
            return jsonify({'error': 'No file selected'}), 400
        if not allowed_file(file.filename):
            return jsonify({'error': 'Invalid file type. Please upload .mha MRI file'}), 400

        # --- Persist with a timestamped, sanitized filename ---
        timestamp = datetime.now().strftime('%Y%m%d_%Hh%Mm%Ss')
        filename = secure_filename(f"{timestamp}_{file.filename}")
        filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
        file.save(filepath)

        # --- Preprocess, predict, postprocess ---
        img_input, original_slice = preprocess_image(filepath)
        print("Making prediction...")
        prediction = model.predict(img_input, verbose=0)
        mask = postprocess_mask(prediction[0], original_slice.shape)

        # Min-max scale the raw slice to 0-255 for display.
        original_display = ((original_slice - original_slice.min()) /
                            (original_slice.max() - original_slice.min() + 1e-8) * 255).astype(np.uint8)
        overlay = create_overlay(original_display, mask)

        # --- Statistics: percentage of pixels classified as tumor ---
        tumor_pixels = np.sum(mask > 127)
        total_pixels = mask.shape[0] * mask.shape[1]
        tumor_percentage = (tumor_pixels / total_pixels) * 100

        result = {
            'original': img_to_base64(original_display),
            'mask': img_to_base64(mask),
            'overlay': img_to_base64(overlay),
            'tumor_percentage': float(tumor_percentage),
            'image_size': f"{mask.shape[1]}x{mask.shape[0]}"
        }
        print(f"✓ Prediction completed: {tumor_percentage:.2f}% tumor detected")
        return jsonify(result)
    except Exception as e:
        print(f"Error during prediction: {e}")
        import traceback
        traceback.print_exc()
        return jsonify({'error': str(e)}), 500
@app.route('/test_example')
def test_example():
    """Run the full segmentation pipeline on a bundled example MHA file.

    Returns the same JSON payload as /predict, or 404 when the example
    file is missing from the image/ folder.
    """
    # NOTE(review): the captured source had lost all @app.route decorators;
    # '/test_example' is restored as a best-guess path — confirm against the
    # front-end code.
    try:
        example_path = 'image/VSD.Brain.XX.O.MR_Flair.35796.mha'
        if not os.path.exists(example_path):
            return jsonify({'error': 'Example MHA file not found. Please add VSD.Brain.XX.O.MR_Flair.35796.mha to image/ folder'}), 404
        print(f"Testing with example file: {example_path}")

        # --- Preprocess, predict, postprocess ---
        img_input, original_slice = preprocess_image(example_path)
        print("Making prediction on example...")
        prediction = model.predict(img_input, verbose=0)
        mask = postprocess_mask(prediction[0], original_slice.shape)

        # Min-max scale the raw slice to 0-255 for display.
        original_display = ((original_slice - original_slice.min()) /
                            (original_slice.max() - original_slice.min() + 1e-8) * 255).astype(np.uint8)
        overlay = create_overlay(original_display, mask)

        # --- Statistics: percentage of pixels classified as tumor ---
        tumor_pixels = np.sum(mask > 127)
        total_pixels = mask.shape[0] * mask.shape[1]
        tumor_percentage = (tumor_pixels / total_pixels) * 100

        result = {
            'original': img_to_base64(original_display),
            'mask': img_to_base64(mask),
            'overlay': img_to_base64(overlay),
            'tumor_percentage': float(tumor_percentage),
            'image_size': f"{mask.shape[1]}x{mask.shape[0]}"
        }
        print(f"✓ Example prediction completed: {tumor_percentage:.2f}% tumor detected")
        return jsonify(result)
    except Exception as e:
        print(f"Error during example prediction: {e}")
        import traceback
        traceback.print_exc()
        return jsonify({'error': str(e)}), 500
if __name__ == '__main__':
    # Startup banner, then serve on all interfaces at port 7860.
    banner = "=" * 60
    print(f"\n{banner}")
    print("🧠 Brain Tumor Segmentation App")
    print(banner)
    print("✓ Model loaded and ready!")
    print("✓ Server starting on port 7860...")
    print(f"{banner}\n")
    app.run(host='0.0.0.0', port=7860, debug=False)