# Author: OlejnikM
# Commit: Create app.py
# Revision: 5966cb5 (verified)
import tensorflow as tf
import numpy as np
import gradio as gr
import json
from PIL import Image, ImageDraw, ImageFont
import os
def load_model(model_path):
    """Load the saved Keras model located at ``model_path`` and return it."""
    keras_model = tf.keras.models.load_model(model_path)
    return keras_model
def preprocess_image(image, target_size):
    """Turn an image into a normalized single-item batch.

    Args:
        image: Image object exposing ``mode``, ``convert`` and ``resize``
            (a PIL image in this app).
        target_size: Size tuple handed straight to ``image.resize``.

    Returns:
        A NumPy array with a leading batch axis of length 1 and pixel
        values divided by 255.
    """
    # RGBA, grayscale or palette uploads would otherwise yield a different
    # channel count than the RGB images this app normally feeds the model.
    if image.mode != "RGB":
        image = image.convert("RGB")
    image = image.resize(target_size)
    pixels = np.array(image) / 255.0  # Normalize to [0, 1]
    return np.expand_dims(pixels, axis=0)  # Add batch dimension
def predict_top_10(image, model, class_indices):
    """Return the highest-scoring classes for ``image``, best first.

    Args:
        image: Image object accepted by ``preprocess_image``.
        model: Object whose ``predict`` method takes the preprocessed batch;
            the first row of its output is used as the per-class scores.
        class_indices: Mapping from model output index to class name.

    Returns:
        A list of ``(class_name, score)`` tuples sorted by descending score.
        It holds at most ten entries, and fewer when the model has fewer
        than ten outputs.
    """
    input_image = preprocess_image(image, target_size=(224, 224))
    predictions = model.predict(input_image)[0]  # Assuming single image batch

    # Sort descending and keep at most ten indices.
    top_indices = np.argsort(predictions)[::-1][:10]

    # Iterate over the indices actually selected instead of a fixed
    # range(10), which raised IndexError for models with < 10 classes.
    output = [(class_indices[i], predictions[i]) for i in top_indices]
    print("Top 10 predictions:", output)  # Debugging line
    return output
def _placeholder_tile(caption, font, size=(300, 300)):
    """Return a black tile of ``size`` with ``caption`` drawn in white."""
    tile = Image.new("RGB", size, "black")
    # White text: the original black-on-black caption was invisible.
    ImageDraw.Draw(tile).text((10, 10), caption, fill="white", font=font)
    return tile


def visualize_predictions(predictions, class_image_dir):
    """Build one captioned 300x300 image per prediction.

    Args:
        predictions: Iterable of ``(class_name, prob)`` pairs.
        class_image_dir: Directory searched for ``<class_name>.jpg``.

    Returns:
        A list of PIL images in the same order as ``predictions``. When the
        class image is missing or cannot be processed, a black placeholder
        tile carrying the caption and the reason is returned in its place.
    """
    print("Received predictions:", predictions)  # Debugging to check what's coming in
    output_images = []
    # Load the font once outside the loop to avoid redundant loading
    font = ImageFont.load_default()  # or ImageFont.truetype("./arial.ttf", 20)
    for class_name, prob in predictions:
        label = f"{class_name}: {prob:.2%}"
        class_image_path = os.path.join(class_image_dir, f"{class_name}.jpg")
        try:
            # The context manager closes the file handle; convert() returns
            # an independent in-memory copy that stays valid afterwards.
            with Image.open(class_image_path) as source:
                tile = source.convert("RGB").resize((300, 300))
            ImageDraw.Draw(tile).text((10, 10), label, fill="black", font=font)
        except FileNotFoundError:
            tile = _placeholder_tile(f"{label} (Image not found)", font)
        except Exception as e:
            # Best effort: one unreadable class image must not break the gallery.
            print(f"Error processing {class_name}: {e}")
            tile = _placeholder_tile(f"{label} (Error: {e})", font)
        output_images.append(tile)
    return output_images
# Load the model and the label mapping once, at import time.
model = load_model("owned_coins.keras")  # Update with your model's file path

# The JSON mapping is inverted so it can be looked up by model output index
# (presumably the file stores class name -> index; verify against the data).
# An explicit UTF-8 encoding keeps non-ASCII class names intact regardless of
# the platform's locale encoding.
with open("class_indices.json", "r", encoding="utf-8") as f:
    class_indices = {v: k for k, v in json.load(f).items()}

# Specify the directory containing class images
class_image_dir = "downloaded_images"  # Update with the actual path to your image directory
def gradio_predict(image):
    """Gradio callback: classify ``image`` and return captioned class images.

    Args:
        image: The uploaded image, or ``None`` when the form is submitted
            without one.

    Returns:
        A list of images for the gallery output; empty when no image was
        supplied.
    """
    # Without this guard an empty submission crashes inside preprocessing.
    if image is None:
        return []
    # Get top 10 predictions
    predictions = predict_top_10(image, model, class_indices)
    print("Predictions:", predictions)  # Debugging line
    # Visualize predictions with class images
    return visualize_predictions(predictions, class_image_dir)
# Wire the prediction callback into an upload-image -> gallery UI.
interface = gr.Interface(
    gradio_predict,
    title="Multiclass Image Classifier",
    description="Upload an image to see the top 10 predictions with class images and probabilities.",
    inputs=gr.Image(type="pil", label="Upload an Image"),
    outputs=gr.Gallery(label="Top 10 Predictions"),
)

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    interface.launch(share=True)