import logging

import gradio as gr
import numpy as np
from PIL import Image
from tensorflow.keras.applications.mobilenet_v2 import (
    preprocess_input as mobilenet_preprocess_input,
)
from tensorflow.keras.applications.resnet50 import (
    preprocess_input as resnet_preprocess_input,
)
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image as keras_image

logger = logging.getLogger(__name__)

# Load your trained models once at import time so every request reuses them.
resnet_model = load_model('/home/user/app/resnet50_model.keras')
mobilenet_model = load_model('/home/user/app/mobilenetv2_model.keras')

# Class labels, indexed by the position of the corresponding model output unit.
# NOTE(review): assumed to match the label order used during training — confirm.
CLASSES = ['Goku', 'Killua', 'Naruto', 'Ruffy', 'Sasuke']

# Spatial size images are resized to before being fed to either model.
INPUT_SIZE = (224, 224)

# UI model name -> (loaded model, the preprocess_input that model expects).
MODELS = {
    'ResNet50': (resnet_model, resnet_preprocess_input),
    'MobileNetV2': (mobilenet_model, mobilenet_preprocess_input),
}


def predict_anime_character(img, model_type):
    """Classify an uploaded image with the selected model.

    Args:
        img: Image array as delivered by ``gr.Image(type="numpy")``, or
            ``None`` when nothing was uploaded.
        model_type: One of the keys of ``MODELS`` (``'ResNet50'`` or
            ``'MobileNetV2'``).

    Returns:
        A dict mapping every label in ``CLASSES`` to its predicted score as a
        ``float``. This is the format ``gr.Label`` renders.

    Raises:
        gr.Error: If no image was given, the model name is unknown, inference
            fails, or the model output does not have one score per class.
            Gradio shows the message to the user. A ``{"error": ...}`` dict
            would instead be rejected by ``gr.Label``, which expects float
            confidences.
    """
    if img is None:
        raise gr.Error("Please upload an image first.")
    if model_type not in MODELS:
        raise gr.Error("Invalid model type selected")
    model, preprocess = MODELS[model_type]

    try:
        # Let PIL infer the mode from the array, then normalise to RGB. This
        # accepts grayscale (2-D) and RGBA (4-channel) uploads as well.
        pil_img = Image.fromarray(np.asarray(img).astype('uint8')).convert('RGB')
        pil_img = pil_img.resize(INPUT_SIZE)
        img_array = keras_image.img_to_array(pil_img)
        # Add a leading batch axis of size 1.
        img_array = np.expand_dims(img_array, axis=0)
        img_array = preprocess(img_array)
        # verbose=0: avoid printing a progress bar for every single request.
        prediction = model.predict(img_array, verbose=0)
    except Exception as e:
        # Surface the failure in the UI instead of returning a payload that
        # gr.Label cannot render.
        raise gr.Error(str(e)) from e

    logger.debug("Prediction shape: %s", prediction.shape)

    # Expect a (batch, num_classes) output with exactly one score per label.
    if prediction.ndim != 2 or prediction.shape[1] != len(CLASSES):
        raise gr.Error(f"Unexpected prediction shape: {prediction.shape}")
    return {label: float(score) for label, score in zip(CLASSES, prediction[0])}


# Custom CSS for Dark Mode and Elegant Design
custom_css = """
body {background-color: #121212; color: #e0e0e0; font-family: 'Arial', sans-serif;}
h1 {color: #ff5722;}
label {color: #ff9800;}
input[type=radio] {accent-color: #ff5722;}
button:hover {background-color: #e64a19;}
footer, .footer {display: none !important;}
"""
# Gradio renders its footer as a <footer> element. The element selector above
# is what actually hides it; `.footer` is kept for compatibility.

# Define the Gradio interface
interface = gr.Interface(
    fn=predict_anime_character,
    inputs=[
        gr.Image(type="numpy", label="Upload an image of an anime character"),
        # Pre-select a model so submitting without touching the radio still
        # produces a prediction instead of an "invalid model" error.
        gr.Radio(['ResNet50', 'MobileNetV2'], value='ResNet50', label="Choose Model"),
    ],
    outputs=gr.Label(num_top_classes=5, label="Prediction"),
    title="Anime Character Classifier",
    description="Upload an image of an anime character and the classifier will predict its character.",
    css=custom_css,  # Apply the custom CSS
)

# Launch the interface only when run as a script, not when imported.
if __name__ == "__main__":
    interface.launch()