# Uploaded by Nouur123 in commit ec36337 ("Upload 7 files", verified).
from flask import Flask, request, jsonify
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.image import img_to_array
import numpy as np
from PIL import Image
import io
import logging
import json
import sys
# === CONFIGURATION ===
# NOTE: both paths below are relative, so they are resolved against the
# process's current working directory, not against this file's location.
MODEL_PATH = "model_resnet152v2.keras"  # Keras model file loaded at startup
CLASS_INDICES_PATH = "class_indices.json"  # JSON object mapping label -> class index
IMAGE_SIZE = 224  # Side length (pixels) images are resized to before inference
# === Logging and application setup ===
# Root logging: INFO and above, each line formatted as
# "<time> - <LEVEL> - <message>".
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# The Flask application object; the HTTP routes are registered on it.
app = Flask(__name__)
# === LOAD MODEL ===
# The model is loaded once, at import time. Predictions are impossible
# without it, so a failure terminates the process with exit status 1.
try:
    model = load_model(MODEL_PATH)
except Exception as e:
    # logger.exception records the full traceback as well as the message,
    # which matters when diagnosing why a fatal startup step failed.
    logger.exception("❌ Failed to load model: %s", e)
    sys.exit(1)
else:
    logger.info("✅ Model loaded successfully.")
# === LOAD CLASS LABELS ===
# The JSON file maps label -> class index. Invert it so the index chosen by
# the model (argmax of its output) can be turned back into a label string.
# Labels are required to answer requests, so any failure here exits with 1.
try:
    with open(CLASS_INDICES_PATH, "r", encoding="utf-8") as f:
        class_indices = json.load(f)
    idx_to_label = {int(v): k for k, v in class_indices.items()}
except Exception as e:
    # Keep the traceback (e.g. a JSON decode position or a bad index value).
    logger.exception("❌ Failed to load class indices: %s", e)
    sys.exit(1)
def predict_image_bytes(img_bytes):
    """Classify an encoded image and return its normalized label.

    The bytes are decoded with Pillow, converted to RGB, resized to
    IMAGE_SIZE x IMAGE_SIZE, scaled from 0..255 to [0, 1], and passed to the
    model as a batch of one.

    Args:
        img_bytes: Raw bytes of an image file in any format Pillow can decode.

    Returns:
        The highest-scoring class label with underscores replaced by spaces
        and lower-cased. If decoding, inference or label lookup fails, the
        sentinel string "error" is returned instead and the failure is logged
        with its traceback.
    """
    try:
        img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
        img = img.resize((IMAGE_SIZE, IMAGE_SIZE))
        # NOTE(review): assumes the model was trained on inputs rescaled by
        # 1/255 -- confirm against the training pipeline.
        img_array = img_to_array(img) / 255.0
        batch = np.expand_dims(img_array, axis=0)  # add the leading batch axis
        # verbose=0 stops Keras from printing a progress bar on every request.
        prediction = model.predict(batch, verbose=0)
        # np.argmax yields a NumPy integer; cast to int to match the plain int
        # keys of idx_to_label explicitly.
        top_index = int(np.argmax(prediction[0]))
        label = idx_to_label[top_index]
        return label.replace('_', ' ').lower()  # Return lowercase to match Spring's expectation
    except Exception as e:
        # Best-effort contract: callers receive "error" rather than an
        # exception, but the traceback is preserved in the log.
        logger.exception("Prediction error: %s", e)
        return "error"
@app.route('/predict', methods=['POST'])
def predict():
    """Handle POST /predict: classify the image uploaded in the ``file`` field.

    Returns:
        400 with a JSON ``error`` body when the ``file`` part is missing or
        has no filename. Otherwise, the plain-text result of
        predict_image_bytes. If reading the upload raises unexpectedly, the
        response is 500 with a JSON ``error`` body.
    """
    if 'file' not in request.files:
        return jsonify({'error': 'No file part'}), 400
    file = request.files['file']
    # `not` rejects both an empty and a missing (None) filename.
    if not file.filename:
        return jsonify({'error': 'No selected file'}), 400
    try:
        img_bytes = file.read()
        prediction = predict_image_bytes(img_bytes)
    except Exception as e:
        logger.exception("Error processing image: %s", e)
        return jsonify({'error': str(e)}), 500
    logger.info("Prediction result: %s", prediction)
    # NOTE(review): predict_image_bytes reports failure with the string
    # "error", which is returned here with status 200. Confirm whether the
    # Spring client relies on that before switching to an error status code.
    return prediction
if __name__ == "__main__":
    # Start Flask's built-in (development) server on port 5003. Binding to
    # 0.0.0.0 accepts connections on every network interface.
    app.run(port=5003, host="0.0.0.0")