# app.py — last commit 5e5b44c ("Update app.py", verified) by koesan, 5.03 kB
import os
import cv2
import numpy as np
from flask import Flask, request, render_template, jsonify
from werkzeug.utils import secure_filename
import tensorflow as tf
from tensorflow.keras.models import load_model
app = Flask(__name__)
# Request bodies above 16 MB are rejected by Flask; uploads are staged in a
# folder relative to the current working directory.
app.config.update(
    MAX_CONTENT_LENGTH=16 * 1024 * 1024,  # 16MB max file size
    UPLOAD_FOLDER='uploads',
)
# Make sure the staging folder exists (no error if it already does)
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
# Model loading starts here; silence warnings and configure TF memory first.
print("Loading model...")
import warnings
# Suppress every warning category process-wide
warnings.simplefilter('ignore')
# Ask TensorFlow to allocate GPU memory incrementally (allow-growth)
os.environ.update(TF_FORCE_GPU_ALLOW_GROWTH='true')
# NOTE: h5py is not referenced directly below; presumably imported so a missing HDF5 backend (used by tf.keras to read .h5 files) fails early
import h5py
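# Custom InputLayer that accepts saved configs using a 'batch_shape' key, which this TensorFlow version's InputLayer does not accept (hence the conversion below)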
from tensorflow.python.keras.engine.input_layer import InputLayer as OriginalInputLayer
class CustomInputLayer(OriginalInputLayer):
    """InputLayer that tolerates a ``batch_shape`` key in a saved layer config.

    ``batch_shape`` (full shape including the leading batch dimension) is
    translated to ``input_shape`` by dropping the batch dimension.
    ``batch_input_shape`` is left for the base class to handle natively and
    is only discarded when an ``input_shape`` is already present, since the
    base class rejects receiving both.
    """

    def __init__(self, **kwargs):
        if 'batch_shape' in kwargs:
            batch_shape = kwargs.pop('batch_shape')
            # Any non-empty shape is usable: a length-1 shape (batch dim only)
            # yields input_shape=() instead of being silently dropped.
            if batch_shape:
                kwargs['input_shape'] = tuple(batch_shape[1:])
        # Previously batch_input_shape was always removed, which discarded the
        # input shape of configs that carry only that key. Now it is removed
        # only when it would conflict with an explicit input_shape.
        if 'input_shape' in kwargs:
            kwargs.pop('batch_input_shape', None)
        super(CustomInputLayer, self).__init__(**kwargs)
# Map the 'InputLayer' class name in the saved config to the tolerant subclass above
custom_objects = {'InputLayer': CustomInputLayer}
# Model file path, resolved relative to the current working directory
MODEL_PATH = 'cancer_model.h5'
# Fail clearly on a missing file; otherwise load_model's error would be
# reported below as a TensorFlow version incompatibility.
if not os.path.exists(MODEL_PATH):
    raise FileNotFoundError(f"Model file not found: {os.path.abspath(MODEL_PATH)}")
# compile=False: only model.predict is used, so no optimizer/loss is restored
try:
    model = tf.keras.models.load_model(MODEL_PATH,
                                       custom_objects=custom_objects,
                                       compile=False)
    print("✓ Model loaded successfully!")
except Exception as e:
    print(f"Error loading model: {e}")
    # NOTE(review): these hints are printed for any load failure; the actual
    # cause in `e` is not inspected.
    print("\n⚠️ Model file is incompatible with TensorFlow 2.4.1")
    print("Your model appears to be from TensorFlow 1.x or very early 2.x")
    raise
def resize_with_padding(img, target_size):
    """Resize image while maintaining aspect ratio and add padding.

    The image is scaled by the largest factor that fits inside the target
    (aspect ratio preserved, cv2.INTER_AREA) and centred on a black border.
    When the padding is odd, the extra pixel goes to the bottom/right.

    Args:
        img: Image array (height first, then width), e.g. from cv2.imread.
        target_size: Output size as (width, height), in OpenCV order.

    Returns:
        The padded image, with height/width equal to target_size.

    Raises:
        ValueError: If ``img`` has zero width or height.
    """
    height, width = img.shape[:2]
    if width == 0 or height == 0:
        # Would otherwise fail with ZeroDivisionError in the scale below
        raise ValueError("Cannot resize an empty image")
    target_width, target_height = target_size
    # Calculate scaling factor
    scale = min(target_width / width, target_height / height)
    # Truncation is kept unchanged so output matches the previous
    # preprocessing; the clamp only affects extreme aspect ratios, where the
    # short side would otherwise truncate to 0 and cv2.resize would fail.
    new_width = max(1, int(width * scale))
    new_height = max(1, int(height * scale))
    # Resize image
    resized_image = cv2.resize(img, (new_width, new_height), interpolation=cv2.INTER_AREA)
    # Split the leftover space evenly on both sides
    pad_width = target_width - new_width
    pad_height = target_height - new_height
    top = pad_height // 2
    bottom = pad_height - top
    left = pad_width // 2
    right = pad_width - left
    # Add black padding
    padded_image = cv2.copyMakeBorder(resized_image, top, bottom, left, right,
                                      cv2.BORDER_CONSTANT, value=[0, 0, 0])
    return padded_image
def left_or_right(img):
    """Return the image oriented so that its brighter half is on the left.

    The left half is the first ``width // 2`` columns; the right half is the
    rest (it includes the middle column when the width is odd). If the left
    half's pixel sum is strictly larger, the same array object is returned.
    Otherwise a horizontally mirrored copy (cv2.flip) is returned.
    """
    _, cols = img.shape[:2]
    mid = cols // 2
    left_sum = np.sum(img[:, :mid])
    right_sum = np.sum(img[:, mid:])
    if not left_sum > right_sum:
        return cv2.flip(img, 1)
    return img
def predict_image(image_path, target_size=(256, 256), threshold=0.5):
    """Make prediction on an uploaded image file.

    Args:
        image_path: Path of the image to classify.
        target_size: (width, height) the image is letterboxed to before
            inference; the default is the original hard-coded 256x256.
        threshold: Scores strictly above this are labelled 'Malignant'.

    Returns:
        dict with 'class' ('Malignant' or 'Benign'), 'confidence' (score of
        the chosen class), 'malignant_prob' and 'benign_prob', each scaled by
        100. Assumes the model emits a probability in [0, 1]; confirm.

    Raises:
        ValueError: If the file cannot be decoded as an image.
    """
    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread reports unreadable/non-image files by returning None;
        # fail with a clear message instead of an AttributeError later on.
        raise ValueError(
            f"Could not read '{os.path.basename(image_path)}' as an image")
    # Preprocess: letterbox, normalize orientation, scale pixel values by 1/255
    img = resize_with_padding(img, target_size)
    img = left_or_right(img)
    batch = np.expand_dims(img / 255.0, axis=0)

    # The score is element [0][0]: first output value for the only image
    # in the batch. It is read once so every field derives from the same value.
    malignant_prob = float(model.predict(batch, verbose=0)[0][0])
    benign_prob = 1 - malignant_prob
    is_malignant = malignant_prob > threshold
    confidence = malignant_prob if is_malignant else benign_prob
    return {
        'class': 'Malignant' if is_malignant else 'Benign',
        'confidence': confidence * 100,
        'malignant_prob': malignant_prob * 100,
        'benign_prob': benign_prob * 100,
    }
@app.route('/')
def index():
    """Serve the upload page rendered from the index.html template."""
    page = 'index.html'
    return render_template(page)
@app.route('/predict', methods=['POST'])
def predict():
    """Classify an uploaded image and return the result as JSON.

    Expects a multipart form field named ``file``. Responses:
    - 200: the result dict from ``predict_image``;
    - 400: ``{'error': ...}`` when no file was sent or its name is empty;
    - 500: ``{'error': <exception message>}`` when saving or prediction fails.

    The upload is written under ``UPLOAD_FOLDER`` only for the duration of
    the request and is always deleted afterwards.
    """
    import uuid  # local import keeps this handler self-contained

    if 'file' not in request.files:
        return jsonify({'error': 'No file uploaded'}), 400
    file = request.files['file']
    # `not` also covers a None filename, which secure_filename cannot accept
    if not file.filename:
        return jsonify({'error': 'No file selected'}), 400

    # secure_filename() can reduce a name to '' (e.g. non-ASCII-only names),
    # which would make the path the folder itself. Concurrent uploads with
    # the same name would also overwrite each other's file. A random
    # per-request prefix makes the path unique and never empty.
    safe_name = secure_filename(file.filename) or 'upload'
    filepath = os.path.join(app.config['UPLOAD_FOLDER'],
                            f"{uuid.uuid4().hex}_{safe_name}")
    try:
        file.save(filepath)
        result = predict_image(filepath)
    except Exception as e:
        return jsonify({'error': str(e)}), 500
    finally:
        # Remove the temporary upload on success and on failure alike
        if os.path.exists(filepath):
            os.remove(filepath)
    return jsonify(result)
if __name__ == '__main__':
    # Listen on all interfaces, port 7860, with Flask debug mode off
    server_options = {'host': '0.0.0.0', 'port': 7860, 'debug': False}
    app.run(**server_options)