# anime_searcher / app.py
# Author: premrmuk — commit e80fabf ("Update app.py")
import gradio as gr
from PIL import Image
import numpy as np
from tensorflow.keras.preprocessing import image as keras_image
from tensorflow.keras.applications.resnet50 import preprocess_input as resnet_preprocess_input
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input as mobilenet_preprocess_input
from tensorflow.keras.models import load_model
# Load both trained classifiers once at import time; the prediction function
# below reads these module-level models on every request.
resnet_model, mobilenet_model = (
    load_model(model_path)
    for model_path in (
        '/home/user/app/resnet50_model.keras',
        '/home/user/app/mobilenetv2_model.keras',
    )
)
# Class names; entry i labels column i of each model's output row.
CLASSES = ['Goku', 'Killua', 'Naruto', 'Ruffy', 'Sasuke']

# Selectable model name -> (loaded model, its matching preprocess function).
_MODELS = {
    'ResNet50': (resnet_model, resnet_preprocess_input),
    'MobileNetV2': (mobilenet_model, mobilenet_preprocess_input),
}


def predict_anime_character(img, model_type):
    """Classify an uploaded anime-character image with the selected model.

    Args:
        img: NumPy image array as delivered by ``gr.Image(type="numpy")``;
            ``None`` when nothing has been uploaded.
        model_type: Name of the model to use, one of the keys of ``_MODELS``
            (``'ResNet50'`` or ``'MobileNetV2'``).

    Returns:
        dict mapping every name in ``CLASSES`` to its model score as a
        float, which is the format ``gr.Label`` displays.

    Raises:
        gr.Error: if no image was given, the model name is unknown,
            preprocessing/inference fails, or the model output does not
            have one column per class. Gradio shows the message in the UI
            (returning an ``{"error": ...}`` dict cannot be rendered by
            ``gr.Label``, whose values must be numeric).
    """
    if img is None:
        raise gr.Error("Please upload an image first.")

    selected = _MODELS.get(model_type)
    if selected is None:
        raise gr.Error("Invalid model type selected")
    model, preprocess = selected

    try:
        # convert('RGB') normalizes grayscale / RGBA arrays to 3 channels;
        # forcing mode='RGB' in fromarray would misread such buffers.
        pil_img = Image.fromarray(img.astype('uint8')).convert('RGB')
        pil_img = pil_img.resize((224, 224))
        # Add a leading batch axis of size 1 for Keras.
        batch = np.expand_dims(keras_image.img_to_array(pil_img), axis=0)
        # verbose=0 keeps Keras from printing a progress bar per request.
        prediction = model.predict(preprocess(batch), verbose=0)
    except Exception as exc:
        # Report any preprocessing/inference failure to the user.
        raise gr.Error(str(exc)) from exc

    if prediction.ndim != 2 or prediction.shape[1] != len(CLASSES):
        raise gr.Error(f"Unexpected prediction shape: {prediction.shape}")

    # The batch holds a single image, so row 0 contains its scores.
    return {name: float(score) for name, score in zip(CLASSES, prediction[0])}
# Custom CSS for Dark Mode and Elegant Design, passed to gr.Interface(css=...).
# Sets a dark page background with light text, orange accents on headings,
# labels, radio buttons and button hover, and hides elements with class
# "footer".
# NOTE(review): the active Gradio theme may override some of these rules
# (e.g. body colors, the footer selector) — confirm they take effect in the
# deployed Gradio version.
custom_css = """
body {background-color: #121212; color: #e0e0e0; font-family: 'Arial', sans-serif;}
h1 {color: #ff5722;}
label {color: #ff9800;}
input[type=radio] {accent-color: #ff5722;}
button:hover {background-color: #e64a19;}
.footer {display: none !important;}
"""
# Define the Gradio interface: an image upload and a model selector feed
# predict_anime_character, whose class -> score dict is shown by gr.Label.
interface = gr.Interface(
    fn=predict_anime_character,
    inputs=[
        gr.Image(type="numpy", label="Upload an image of an anime character"),
        # Pre-select a model: without a default, submitting before touching
        # the radio sends model_type=None and the prediction always fails.
        gr.Radio(['ResNet50', 'MobileNetV2'], value='ResNet50', label="Choose Model"),
    ],
    # Display up to the 5 highest-scoring classes.
    outputs=gr.Label(num_top_classes=5, label="Prediction"),
    title="Anime Character Classifier",
    description="Upload an image of an anime character and the classifier will predict its character.",
    css=custom_css,  # Apply the custom CSS
)
# Launch the interface (starts the Gradio web server).
interface.launch()