# Provenance (scraped Hugging Face Space page header, preserved as a comment):
# author: pavan10504 — "changed a few files to handle the PIL" — commit 6f66313
import os
import logging

# Set TensorFlow environment variables BEFORE importing TensorFlow.
# TF_CPP_MIN_LOG_LEVEL and TF_ENABLE_ONEDNN_OPTS are read when the TF C++
# runtime initializes at import time, so setting them after the import
# (as the original code did) has no effect on the emitted warnings.
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import gradio as gr
import numpy as np
import tensorflow as tf
from PIL import Image
from tensorflow.keras.applications.resnet50 import preprocess_input
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image

# Configure logging so the INFO-level progress messages below are emitted.
logging.basicConfig(level=logging.INFO)

# Blood group classes. Index order must match the model's output units
# (np.argmax result indexes into this list) — assumed to be the training
# class order; confirm against the training pipeline.
blood_groups = ['A+', 'A-', 'AB+', 'AB-', 'B+', 'B-', 'O+', 'O-']
# Load model
def load_ml_model(model_path="model_blood_group_detection_resnet.h5"):
    """Load the Keras blood-group classifier from disk.

    Args:
        model_path: Path to the .h5 model file. Defaults to the bundled
            ResNet model so existing zero-argument callers are unaffected.

    Returns:
        The loaded Keras model, or None when loading fails — the error is
        logged instead of raised so the app can still start and report
        "Model not loaded" per request.
    """
    try:
        logging.info("Loading model...")
        model = load_model(model_path)
        logging.info("Model loaded successfully!")
        return model
    except Exception as e:
        logging.error(f"Error loading model: {e}")
        return None
# Initialize model
# Loaded once at import time; may be None if the .h5 file is missing or
# unreadable — the prediction functions below check for that before use.
model = load_ml_model()
def predict_blood_group(input_image):
    """
    Predict blood group from uploaded image
    """
    try:
        # Bail out immediately if the model failed to load at startup.
        if model is None:
            logging.error("Model is not loaded")
            return "Error: Model not loaded", "0.0%"
        logging.info(f"Processing image: {type(input_image)}")
        # Accept either a filesystem path or an already-open PIL image.
        if isinstance(input_image, str):
            logging.info(f"Loading image from path: {input_image}")
            pil_img = Image.open(input_image)
        else:
            logging.info("Using PIL Image object")
            pil_img = input_image
        logging.info(f"Original image size: {pil_img.size}, mode: {pil_img.mode}")
        # Normalise to the 256x256 RGB input the network was fed here.
        pil_img = pil_img.convert('RGB').resize((256, 256))
        logging.info(f"Processed image size: {pil_img.size}, mode: {pil_img.mode}")
        # PIL image -> float array -> batch of one -> ResNet50 preprocessing.
        batch = image.img_to_array(pil_img)
        logging.info(f"Image array shape: {batch.shape}")
        batch = np.expand_dims(batch, axis=0)
        logging.info(f"Expanded array shape: {batch.shape}")
        batch = preprocess_input(batch)
        logging.info("Image preprocessing completed")
        logging.info("Running model prediction...")
        scores = model.predict(batch, verbose=0)
        logging.info(f"Prediction shape: {scores.shape}")
        # Pick the highest-scoring class and report its score as confidence.
        best_idx = int(np.argmax(scores, axis=1)[0])
        confidence = float(scores[0][best_idx])
        predicted_group = blood_groups[best_idx]
        logging.info(f"Prediction: {predicted_group} with confidence: {confidence:.2%}")
        return predicted_group, f"{confidence:.1%}"
    except Exception as e:
        logging.error(f"Prediction error: {str(e)}", exc_info=True)
        return f"Error: {str(e)}", "0.0%"
def gradio_predict(image_path):
    """
    Gradio wrapper for prediction - handles file paths
    """
    try:
        logging.info(f"Gradio predict called with image path: {image_path}")
        # Guard clauses: reject missing input and stale/nonexistent paths
        # before touching the model.
        if image_path is None:
            logging.warning("No image path provided")
            return "Please upload an image", "No confidence"
        if not os.path.exists(image_path):
            logging.error(f"Image file not found: {image_path}")
            return "Error: Image file not found", "0.0%"
        # Delegate to the shared predictor and forward its result.
        group, conf = predict_blood_group(image_path)
        if group.startswith("Error"):
            logging.error(f"Prediction failed: {group}")
            return group, "N/A"
        logging.info(f"Gradio returning: {group}, {conf}")
        return group, conf
    except Exception as e:
        logging.error(f"Gradio wrapper error: {str(e)}", exc_info=True)
        return f"Gradio Error: {str(e)}", "0.0%"
# Create Gradio interface
iface = gr.Interface(
    fn=gradio_predict,
    # "filepath" delivers the upload as a temp-file path, matching the
    # str path handling in gradio_predict / predict_blood_group.
    inputs=gr.Image(type="filepath", label="Upload Blood Sample Image"),
    outputs=[
        gr.Textbox(label="Predicted Blood Group"),
        gr.Textbox(label="Confidence")
    ],
    title="Blood Group Prediction Model",
    description="Upload an image of a blood sample to predict the blood group",
    examples=None,
    # NOTE(review): allow_flagging is deprecated in Gradio 4.x in favour of
    # flagging_mode — confirm against the installed Gradio version.
    allow_flagging="never"
)
# For API access
def predict_api(image):
    """
    API endpoint that returns JSON response
    """
    # NOTE(review): the parameter name shadows the module-level keras
    # `image` import; kept as-is because it is part of the public signature.
    group, conf = predict_blood_group(image)
    # Repackage the (group, confidence) tuple as a JSON-serialisable dict.
    return {"blood_group": group, "confidence": conf}
# Launch the app
if __name__ == "__main__":
    # Launch with public link for API access and verbose error reporting
    iface.launch(
        server_name="0.0.0.0",  # bind all interfaces (container/Space friendly)
        server_port=7860,       # Gradio's conventional port
        share=True,             # also request a public share tunnel
        show_api=True,          # expose the auto-generated API docs page
        show_error=True,  # Enable verbose error reporting
        debug=True  # Enable debug mode
    )