Spaces:
Sleeping
Sleeping
File size: 3,039 Bytes
eac85ba | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 | from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.image import load_img, img_to_array
import numpy as np
import json
import os
from PIL import Image
import io
# Load the class labels used to turn prediction indices into names.
# The file is optional: when it is missing, a warning is printed and
# CLASS_NAMES stays empty, in which case the prediction helpers fall back to
# generic "Class_<index>" labels.
# NOTE(review): the path is relative, so it resolves against the current
# working directory of the process, not this file's location.
try:
    # JSON text is UTF-8; pass the encoding explicitly so labels containing
    # non-ASCII characters decode the same way on every platform instead of
    # depending on the locale's default encoding.
    with open("models/class_names.json", "r", encoding="utf-8") as f:
        CLASS_NAMES = json.load(f)
except FileNotFoundError:
    print("Warning: class_names.json not found. Make sure to train the model first.")
    CLASS_NAMES = []
def load_and_prepare_image(img_file):
    """
    Load and preprocess an image for prediction.

    Args:
        img_file: Either a filesystem path given as ``str`` (loaded with
            Keras ``load_img``) or any other object that ``PIL.Image.open``
            accepts, such as a binary file-like object from an upload widget.

    Returns:
        The image resized to 256x256, converted to an array, scaled by
        1/255 and given a leading batch axis of size 1.

    Raises:
        Exception: If no image is supplied or any step of loading or
            preprocessing fails. The message is prefixed with
            "Error processing image: " and the original error is attached
            as ``__cause__``.
    """
    try:
        if img_file is None:
            # An upload widget typically yields None until a file is chosen;
            # fail with a clear message instead of an opaque attribute error.
            raise ValueError("No image provided")

        if isinstance(img_file, str):
            # Path on disk: load_img decodes to RGB by default and resizes.
            img = load_img(img_file, target_size=(256, 256))
        else:
            # File-like object (e.g., from a Streamlit upload).
            img = Image.open(img_file)
            # Uploads may be RGBA, greyscale or palette images. Normalise to
            # RGB so the channel count matches what the path branch produces.
            img = img.convert("RGB")
            img = img.resize((256, 256))
            # NOTE(review): Pillow's default resize filter may differ from
            # load_img's default interpolation - confirm which one matches
            # the preprocessing used during training.

        # Convert to array and normalize pixel values.
        img_array = img_to_array(img)
        img_array = img_array / 255.0
        return np.expand_dims(img_array, axis=0)
    except Exception as e:
        # Keep the wrapper type and message prefix callers already rely on,
        # but chain the cause so the original traceback is not lost.
        raise Exception(f"Error processing image: {str(e)}") from e
def predict(model_path, img_file):
    """
    Predict the class of an image with the model stored at ``model_path``.

    The model is loaded from disk on every call.

    Args:
        model_path: Path to the saved Keras model.
        img_file: Image path or file-like object, forwarded unchanged to
            ``load_and_prepare_image``.

    Returns:
        A ``(predicted_class, confidence)`` tuple. ``predicted_class`` is the
        entry of ``CLASS_NAMES`` at the highest-scoring index, or
        ``"Class_<index>"`` when no name is available for that index.
        ``confidence`` is the model output at that index as a built-in
        ``float``.

    Raises:
        Exception: If the model file does not exist or any step fails. The
            message is prefixed with "Prediction error: " and the original
            error is attached as ``__cause__``.
    """
    try:
        # Check up front so a missing model yields a clear message.
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file not found at {model_path}")

        model = load_model(model_path)
        image = load_and_prepare_image(img_file)

        # The batch holds a single image, so only the first row is relevant.
        scores = model.predict(image, verbose=0)[0]

        # Convert NumPy scalars to built-in types so the result matches the
        # plain floats returned by get_class_probabilities and is safe to
        # serialise or pass to UI widgets.
        predicted_class_idx = int(np.argmax(scores))
        confidence = float(scores[predicted_class_idx])

        if CLASS_NAMES and predicted_class_idx < len(CLASS_NAMES):
            predicted_class = CLASS_NAMES[predicted_class_idx]
        else:
            # No label known for this index: fall back to a generic name.
            predicted_class = f"Class_{predicted_class_idx}"

        return predicted_class, confidence
    except Exception as e:
        # Same wrapper type and prefix as before; chain the cause so the
        # underlying traceback stays available for debugging.
        raise Exception(f"Prediction error: {str(e)}") from e
def get_class_probabilities(model_path, img_file):
    """
    Get the model's output for every class for a single image.

    The model is loaded from disk on every call.

    Args:
        model_path: Path to the saved Keras model.
        img_file: Image path or file-like object, forwarded unchanged to
            ``load_and_prepare_image``.

    Returns:
        A dictionary mapping each class name to its model output as a
        built-in ``float``. Indices without an entry in ``CLASS_NAMES`` are
        keyed as ``"Class_<index>"``.

    Raises:
        Exception: If the model file does not exist or any step fails. The
            message is prefixed with "Error getting class probabilities: "
            and the original error is attached as ``__cause__``.
    """
    try:
        # Same up-front check and message as predict(), so a missing model
        # is reported consistently by both entry points.
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file not found at {model_path}")

        model = load_model(model_path)
        image = load_and_prepare_image(img_file)

        # The batch holds a single image, so only the first row is relevant.
        preds = model.predict(image, verbose=0)[0]

        return {
            (
                CLASS_NAMES[idx]
                if CLASS_NAMES and idx < len(CLASS_NAMES)
                else f"Class_{idx}"
            ): float(prob)
            for idx, prob in enumerate(preds)
        }
    except Exception as e:
        # Same wrapper type and prefix as before; chain the cause so the
        # underlying traceback stays available for debugging.
        raise Exception(f"Error getting class probabilities: {str(e)}") from e
|