# Author: koesan
# Last commit: 616a31a "Update app.py" (verified)
import os
import cv2
import numpy as np
from flask import Flask, request, render_template, jsonify, send_from_directory
from werkzeug.utils import secure_filename
from datetime import datetime
import base64
from io import BytesIO
from PIL import Image
import SimpleITK as sitk
from skimage.transform import resize
# Suppress TensorFlow warnings
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
from tensorflow.keras.models import load_model
# Flask application plus upload-related settings.
app = Flask(__name__)
app.config.update(
    MAX_CONTENT_LENGTH=16 * 1024 * 1024,  # 16 MB cap on request size
    UPLOAD_FOLDER='uploads',
    ALLOWED_EXTENSIONS={'mha'},  # only MetaImage (.mha) MRI volumes are accepted
)

# Ensure the upload directory exists. mode=0o777 asks for a world-writable
# directory (the process umask may still narrow it).
# NOTE(review): 0o777 is broader than usual — confirm the deployment needs it.
os.makedirs(app.config['UPLOAD_FOLDER'], mode=0o777, exist_ok=True)
# ---------------------------------------------------------------------------
# Model loading — runs once, at import time. Any failure is reported with
# troubleshooting hints and then re-raised, so the app never starts without
# a model.
# ---------------------------------------------------------------------------
print("Loading Brain Segmentation Model with TensorFlow 2.15...")
import warnings
warnings.filterwarnings('ignore')  # silence Python warnings process-wide from here on

try:
    # compile=False: the app only calls model.predict, so no optimizer/loss
    # configuration needs to be restored.
    model = load_model('brain1.h5', compile=False)
    print("✓ Model loaded successfully with TensorFlow 2.15!")
except Exception as load_err:
    print(
        f"❌ Error loading model: {load_err}",
        "\n⚠️ If you see 'groups' parameter error:",
        " Model needs TensorFlow 2.15 (not 2.16+)",
        sep="\n",
    )
    import traceback
    traceback.print_exc()
    raise
def allowed_file(filename):
    """Tell whether *filename* carries an accepted upload extension.

    The extension is whatever follows the last '.', compared
    case-insensitively against ``app.config['ALLOWED_EXTENSIONS']``.
    A name with no '.' at all is rejected.
    """
    _, separator, extension = filename.rpartition('.')
    if not separator:
        return False
    return extension.lower() in app.config['ALLOWED_EXTENSIONS']
def preprocess_image(image_path, target_shape=(155, 160, 160), slice_index=95):
    """Load an MHA volume and turn one slice into a model-ready tensor.

    Follows the preprocessing the original comments describe as "same as
    training":
    1. resize the whole volume to ``target_shape``;
    2. take slice ``slice_index`` along axis 0;
    3. z-score normalise that slice.

    Args:
        image_path: Path to a volume readable by ``sitk.ReadImage`` (the app
            uploads .mha files).
        target_shape: 3-tuple the volume is resized to, in the axis order
            returned by ``sitk.GetArrayFromImage``. Default (155, 160, 160).
        slice_index: Index along axis 0 of the resized volume. Default 95,
            described as the middle of the 60-130 slice range used in training.

    Returns:
        ``(img_input, original_slice)``:
        - ``img_input`` is a float32 array of shape (1, 1, H, W), i.e. batch
          and channel axes first;
        - ``original_slice`` is the resized, un-normalised (H, W) slice, kept
          for visualisation.

    Raises:
        ValueError: on any failure (unreadable file, non-3-D volume, slice
            index out of range, resize error). The message starts with
            "Failed to read MHA file:" and the underlying exception is chained
            as ``__cause__``.
    """
    try:
        volume = sitk.GetArrayFromImage(sitk.ReadImage(image_path))
        print(f"Original MHA shape: {volume.shape}")
        if volume.ndim != 3:
            # The pipeline below slices a 3-D volume into a 2-D (H, W) image;
            # other ranks cannot produce a valid (1, 1, H, W) model input.
            raise ValueError(f"expected a 3-D volume, got shape {volume.shape}")

        volume_resized = resize(volume, target_shape, preserve_range=True)
        n_slices = volume_resized.shape[0]
        if not 0 <= slice_index < n_slices:
            raise ValueError(
                f"slice_index {slice_index} out of range for {n_slices} slices"
            )
        img_slice = volume_resized[slice_index, :, :]

        # Un-normalised copy for display
        original_slice = img_slice.copy()

        # Z-score normalisation; epsilon prevents division by zero on a flat slice
        img_normalized = (img_slice - img_slice.mean()) / (img_slice.std() + 1e-8)
        img_normalized = img_normalized.astype(np.float32)
        print(f"Slice shape: {img_normalized.shape}")

        # (H, W) -> (1, 1, H, W): add batch and channel axes, channels first
        img_input = img_normalized[np.newaxis, np.newaxis, :, :]
        print(f"Model input shape: {img_input.shape}")
        return img_input, original_slice
    except Exception as e:
        raise ValueError(f"Failed to read MHA file: {str(e)}") from e
def postprocess_mask(mask, original_shape):
    """Convert one raw model output into a binary uint8 mask.

    Args:
        mask: Model output for a single sample. After dropping every size-1
            axis, exactly two axes (height, width) must remain. The original
            comments describe a channels-first (1, H, W) output.
        original_shape: Target size; only ``original_shape[0]`` (height) and
            ``original_shape[1]`` (width) are used.

    Returns:
        uint8 array of shape (original_shape[0], original_shape[1]): 255 where
        the resized value is > 0.5, else 0.

    Raises:
        ValueError: if the squeezed mask is not 2-D, e.g. a multi-channel
            output. Squeezing again cannot reduce such a mask, so it is
            rejected instead of retried.
    """
    # np.squeeze removes all size-1 axes in one call; a second call is a no-op.
    mask = np.squeeze(np.asarray(mask))
    if mask.ndim != 2:
        raise ValueError(
            f"Expected a single-channel 2-D mask after squeezing, got shape {mask.shape}"
        )

    # cv2.resize takes dsize as (width, height), hence the swapped indices.
    mask_resized = cv2.resize(mask, (original_shape[1], original_shape[0]))

    # Threshold probabilities at 0.5 and scale to 0/255 for display
    return (mask_resized > 0.5).astype(np.uint8) * 255
def create_overlay(original, mask):
    """Blend a red segmentation mask over an image.

    Args:
        original: uint8 image. Either 2-D grayscale, which is converted to RGB
            here, or an already 3-channel image assumed to be in RGB order.
        mask: 2-D array with the same height and width as ``original``. Its
            values are written into the red channel; the app passes a 0/255
            uint8 mask.

    Returns:
        RGB uint8 image equal to ``0.7 * original + 0.3 * red_mask``
        (computed by ``cv2.addWeighted``).
    """
    if original.ndim == 2:
        original_rgb = cv2.cvtColor(original, cv2.COLOR_GRAY2RGB)
    else:
        original_rgb = original.copy()

    # The working image is in RGB order (COLOR_GRAY2RGB), so red is channel 0.
    # Channel 2 is red only in OpenCV's BGR order; here it would be blue.
    colored_mask = np.zeros_like(original_rgb)
    colored_mask[:, :, 0] = mask

    return cv2.addWeighted(original_rgb, 0.7, colored_mask, 0.3, 0)
def img_to_base64(img_array):
    """Encode an image array as a PNG data URI.

    Args:
        img_array: 2-D (grayscale) or (H, W, 3) array.
            - uint8 input is encoded as-is.
            - Any other dtype is treated as values in [0, 1]: it is scaled by
              255 and clipped to 0-255 before the uint8 cast, so out-of-range
              values saturate instead of wrapping around.

    Returns:
        A ``"data:image/png;base64,..."`` string.
    """
    if img_array.dtype != np.uint8:
        img_array = np.clip(img_array * 255, 0, 255).astype(np.uint8)

    # Pillow infers the mode from shape/dtype: uint8 (H, W) -> 'L',
    # uint8 (H, W, 3) -> 'RGB'. Passing mode= explicitly is deprecated in
    # recent Pillow releases.
    image = Image.fromarray(img_array)

    buffer = BytesIO()
    image.save(buffer, format='PNG')
    # getvalue() returns the whole buffer regardless of the stream position.
    encoded = base64.b64encode(buffer.getvalue()).decode('utf-8')
    return f"data:image/png;base64,{encoded}"
@app.route('/')
def index():
    """Serve the landing page by rendering the ``index.html`` template."""
    page_template = 'index.html'
    return render_template(page_template)
def _segment_mha(mha_path, progress_message):
    """Run the segmentation pipeline on one MHA file and build the JSON payload.

    Steps: preprocess_image -> model.predict -> postprocess_mask. Then the
    original slice, the mask and a red overlay are rendered as PNG data URIs,
    and the share of mask pixels above 127 is computed.

    Args:
        mha_path: Path of the .mha volume to segment.
        progress_message: Printed right before model inference, so each
            route keeps its own log wording.

    Returns:
        dict with keys:
        - 'original', 'mask', 'overlay': PNG data URIs;
        - 'tumor_percentage': float in 0-100;
        - 'image_size': "WIDTHxHEIGHT".
    """
    img_input, original_slice = preprocess_image(mha_path)

    print(progress_message)
    prediction = model.predict(img_input, verbose=0)

    # Binary 0/255 mask at the slice's resolution
    mask = postprocess_mask(prediction[0], original_slice.shape)

    # Min-max scale the raw slice to 0-255 for display; epsilon avoids /0 on flat slices
    original_display = ((original_slice - original_slice.min()) /
                        (original_slice.max() - original_slice.min() + 1e-8) * 255).astype(np.uint8)

    overlay = create_overlay(original_display, mask)

    original_base64 = img_to_base64(original_display)
    mask_base64 = img_to_base64(mask)
    overlay_base64 = img_to_base64(overlay)

    tumor_pixels = np.sum(mask > 127)
    total_pixels = mask.shape[0] * mask.shape[1]
    tumor_percentage = (tumor_pixels / total_pixels) * 100

    return {
        'original': original_base64,
        'mask': mask_base64,
        'overlay': overlay_base64,
        'tumor_percentage': float(tumor_percentage),
        'image_size': f"{mask.shape[1]}x{mask.shape[0]}",
    }


def _upload_path(original_name):
    """Build a unique, sanitised path under UPLOAD_FOLDER for an uploaded file.

    A random token follows the timestamp. Two uploads with the same name in
    the same second therefore cannot overwrite (or delete) each other's file.
    """
    import uuid
    timestamp = datetime.now().strftime('%Y%m%d_%Hh%Mm%Ss')
    filename = secure_filename(f"{timestamp}_{uuid.uuid4().hex[:8]}_{original_name}")
    return os.path.join(app.config['UPLOAD_FOLDER'], filename)


def _discard_upload(path):
    """Delete a per-request upload; a failure is logged, never raised."""
    try:
        os.remove(path)
    except FileNotFoundError:
        pass  # e.g. file.save() failed before creating the file
    except OSError as err:
        print(f"Warning: could not delete upload {path}: {err}")


@app.route('/predict', methods=['POST'])
def predict():
    """Segment an uploaded .mha MRI volume.

    Expects a multipart form field named 'file'.

    Returns:
        - the JSON payload from ``_segment_mha`` on success;
        - ``{'error': ...}`` with status 400 for a missing or invalid upload;
        - ``{'error': ...}`` with status 500 for any processing failure.

    The saved upload is deleted when the request finishes, so the upload
    folder does not grow without bound.
    """
    filepath = None
    try:
        if 'file' not in request.files:
            return jsonify({'error': 'No file uploaded'}), 400
        file = request.files['file']
        if not file.filename:
            return jsonify({'error': 'No file selected'}), 400
        if not allowed_file(file.filename):
            return jsonify({'error': 'Invalid file type. Please upload .mha MRI file'}), 400

        filepath = _upload_path(file.filename)
        file.save(filepath)

        result = _segment_mha(filepath, "Making prediction...")
        print(f"✓ Prediction completed: {result['tumor_percentage']:.2f}% tumor detected")
        return jsonify(result)
    except Exception as e:
        print(f"Error during prediction: {e}")
        import traceback
        traceback.print_exc()
        return jsonify({'error': str(e)}), 500
    finally:
        if filepath is not None:
            _discard_upload(filepath)


@app.route('/test-example', methods=['POST'])
def test_example():
    """Run the segmentation pipeline on the bundled example volume.

    Returns:
        - the same JSON payload as /predict on success;
        - 404 if the example file is missing;
        - 500 on any processing failure.
    """
    try:
        example_path = 'image/VSD.Brain.XX.O.MR_Flair.35796.mha'
        if not os.path.exists(example_path):
            return jsonify({'error': 'Example MHA file not found. Please add VSD.Brain.XX.O.MR_Flair.35796.mha to image/ folder'}), 404

        print(f"Testing with example file: {example_path}")
        result = _segment_mha(example_path, "Making prediction on example...")
        print(f"✓ Example prediction completed: {result['tumor_percentage']:.2f}% tumor detected")
        return jsonify(result)
    except Exception as e:
        print(f"Error during example prediction: {e}")
        import traceback
        traceback.print_exc()
        return jsonify({'error': str(e)}), 500
if __name__ == '__main__':
    # Startup banner, one entry per line, then the development server on
    # all interfaces at port 7860.
    print(
        "\n" + "=" * 60,
        "🧠 Brain Tumor Segmentation App",
        "=" * 60,
        "✓ Model loaded and ready!",
        "✓ Server starting on port 7860...",
        "=" * 60 + "\n",
        sep="\n",
    )
    app.run(host='0.0.0.0', port=7860, debug=False)