Terence9's picture
Upload 4 files
eac85ba verified
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.image import load_img, img_to_array
import numpy as np
import json
import os
from PIL import Image
import io
# Load the class names written at training time. JSON is UTF-8 by
# specification, so decode it explicitly instead of relying on the platform's
# locale encoding (which would break non-ASCII class names on some systems).
try:
    with open("models/class_names.json", "r", encoding="utf-8") as f:
        CLASS_NAMES = json.load(f)
except FileNotFoundError:
    # Best-effort: the prediction helpers fall back to "Class_<idx>" labels
    # when no class names are available.
    print("Warning: class_names.json not found. Make sure to train the model first.")
    CLASS_NAMES = []
def load_and_prepare_image(img_file):
    """
    Load and preprocess an image for prediction.

    Args:
        img_file: Either a filesystem path given as ``str`` or a file-like
            object (e.g. an upload buffer) that ``PIL.Image.open`` can read.

    Returns:
        The image resized to 256x256, forced to three RGB channels, divided
        by 255.0 and with a leading batch axis added (``(1, 256, 256, 3)``
        assuming Keras' default ``channels_last`` data format).

    Raises:
        Exception: If the image cannot be read or converted. The message is
            prefixed with "Error processing image: " and the underlying
            error is attached as ``__cause__``.
    """
    try:
        if isinstance(img_file, str):
            # load_img defaults to color_mode="rgb", so this branch always
            # produces a 3-channel image.
            img = load_img(img_file, target_size=(256, 256))
        else:
            # Uploaded images may be RGBA, grayscale or palette based.
            # Convert to RGB so both branches hand the model the same
            # channel count instead of failing on a shape mismatch.
            img = Image.open(img_file).convert("RGB")
            img = img.resize((256, 256))

        # Convert to array and scale pixel values by 1/255.
        img_array = img_to_array(img)
        img_array = img_array / 255.0
        return np.expand_dims(img_array, axis=0)
    except Exception as e:
        # Keep the generic Exception type callers already catch, but chain
        # the original error explicitly for debugging.
        raise Exception(f"Error processing image: {str(e)}") from e
def predict(model_path, img_file):
    """
    Classify a single image with the model stored at ``model_path``.

    The model is loaded from disk on every call.

    Args:
        model_path: Path to the saved Keras model file.
        img_file: Image path or file-like object, as accepted by
            ``load_and_prepare_image``.

    Returns:
        A ``(predicted_class, confidence)`` tuple. ``predicted_class`` is the
        entry of ``CLASS_NAMES`` at the highest-scoring index, or
        ``"Class_<idx>"`` when no name is available for that index.
        ``confidence`` is the model's score for that index as a plain
        Python ``float``.

    Raises:
        Exception: On any failure (missing model file, unreadable image,
            model error). The message is prefixed with "Prediction error: "
            and the underlying error is attached as ``__cause__``.
    """
    try:
        # Fail early with a clear message when the model file is missing.
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file not found at {model_path}")

        model = load_model(model_path)
        image = load_and_prepare_image(img_file)

        # Single-image batch: take the score vector of the first sample.
        scores = model.predict(image, verbose=0)[0]

        # Convert NumPy scalars to builtins so the result matches
        # get_class_probabilities() and is JSON-serialisable.
        predicted_class_idx = int(np.argmax(scores))
        confidence = float(scores[predicted_class_idx])

        if CLASS_NAMES and predicted_class_idx < len(CLASS_NAMES):
            predicted_class = CLASS_NAMES[predicted_class_idx]
        else:
            # No name known for this index; fall back to a generic label.
            predicted_class = f"Class_{predicted_class_idx}"

        return predicted_class, confidence
    except Exception as e:
        # Keep the generic Exception type callers already catch, but chain
        # the original error explicitly for debugging.
        raise Exception(f"Prediction error: {str(e)}") from e
def get_class_probabilities(model_path, img_file):
    """
    Get the model's score for every class on a single image.

    The model is loaded from disk on every call.

    Args:
        model_path: Path to the saved Keras model file.
        img_file: Image path or file-like object, as accepted by
            ``load_and_prepare_image``.

    Returns:
        A dict mapping each class name to its score as a Python ``float``.
        Indices without an entry in ``CLASS_NAMES`` are keyed as
        ``"Class_<idx>"``.

    Raises:
        Exception: On any failure (missing model file, unreadable image,
            model error). The message is prefixed with
            "Error getting class probabilities: " and the underlying error
            is attached as ``__cause__``.
    """
    try:
        # Same early check as predict(), so both entry points report a
        # missing model file the same way.
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file not found at {model_path}")

        model = load_model(model_path)
        image = load_and_prepare_image(img_file)

        # Single-image batch: take the score vector of the first sample.
        preds = model.predict(image, verbose=0)[0]

        probabilities = {}
        for idx, prob in enumerate(preds):
            if CLASS_NAMES and idx < len(CLASS_NAMES):
                class_name = CLASS_NAMES[idx]
            else:
                # No name known for this index; fall back to a generic label.
                class_name = f"Class_{idx}"
            # float() turns the NumPy scalar into a builtin float.
            probabilities[class_name] = float(prob)
        return probabilities
    except Exception as e:
        # Keep the generic Exception type callers already catch, but chain
        # the original error explicitly for debugging.
        raise Exception(f"Error getting class probabilities: {str(e)}") from e