# Hugging Face Spaces status at time of capture: Runtime error
from pathlib import Path

import gradio as gr
from PIL import Image
import numpy as np
from tensorflow.keras.preprocessing import image as keras_image
from tensorflow.keras.applications.resnet50 import preprocess_input as resnet_preprocess_input
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input as mobilenet_preprocess_input
from tensorflow.keras.models import load_model
# Load the trained models once, at module import, so every request handled by
# predict_anime_character reuses the same model objects.
# Paths are resolved relative to this file instead of the hard-coded Hugging
# Face Spaces checkout (/home/user/app): on Spaces this resolves to the same
# files, and the app also runs from any other directory.
_APP_DIR = Path(__file__).resolve().parent
resnet_model = load_model(str(_APP_DIR / 'resnet50_model.keras'))
mobilenet_model = load_model(str(_APP_DIR / 'mobilenetv2_model.keras'))
# Output labels, in the order of the models' output units.
# NOTE(review): assumed to match the label encoding used at training time — confirm.
CLASSES = ['Goku', 'Killua', 'Naruto', 'Ruffy', 'Sasuke']


def _to_model_batch(img):
    """Turn an H x W (x C) image array into a (1, 224, 224, 3) float batch."""
    # .convert('RGB') normalises grayscale / RGBA uploads. Passing mode='RGB'
    # to fromarray instead would misread or reject arrays that are not
    # already H x W x 3.
    pil_img = Image.fromarray(np.asarray(img).astype('uint8')).convert('RGB')
    pil_img = pil_img.resize((224, 224))
    # Add the leading batch axis the models expect.
    return np.expand_dims(keras_image.img_to_array(pil_img), axis=0)


def predict_anime_character(img, model_type):
    """Classify an uploaded image with the selected model.

    Args:
        img: Image array from ``gr.Image(type="numpy")``, or ``None`` when
            nothing was uploaded.
        model_type: ``'ResNet50'`` or ``'MobileNetV2'``.

    Returns:
        dict mapping each name in ``CLASSES`` to the model's score for it,
        the format ``gr.Label`` renders.

    Raises:
        gr.Error: for a missing image, an unknown model type, an unexpected
            output shape, or any failure during preprocessing/inference.
            Gradio shows the message in the UI. Returning ``{"error": msg}``
            would not work, because ``gr.Label`` expects numeric confidences.
    """
    if img is None:
        raise gr.Error("Please upload an image.")

    # Each supported model paired with its matching Keras preprocess_input.
    models = {
        'ResNet50': (resnet_model, resnet_preprocess_input),
        'MobileNetV2': (mobilenet_model, mobilenet_preprocess_input),
    }
    if model_type not in models:
        raise gr.Error("Invalid model type selected")
    model, preprocess = models[model_type]

    try:
        batch = preprocess(_to_model_batch(img))
        # verbose=0 keeps Keras from printing a progress bar on every request.
        prediction = np.asarray(model.predict(batch, verbose=0))
    except Exception as exc:  # request boundary: report in the UI, don't crash
        raise gr.Error(str(exc)) from exc

    # Expect one row of per-class scores for the single input image.
    if prediction.ndim != 2 or prediction.shape[1] != len(CLASSES):
        raise gr.Error(f"Unexpected prediction shape: {prediction.shape}")
    return {name: float(score) for name, score in zip(CLASSES, prediction[0])}
# Dark-mode stylesheet passed to gr.Interface(css=...). Each entry is one CSS
# rule. The rules are joined by newlines and wrapped in a leading and trailing
# newline, giving the same text as a triple-quoted block.
_CSS_RULES = (
    "body {background-color: #121212; color: #e0e0e0; font-family: 'Arial', sans-serif;}",
    "h1 {color: #ff5722;}",
    "label {color: #ff9800;}",
    "input[type=radio] {accent-color: #ff5722;}",
    "button:hover {background-color: #e64a19;}",
    # Hides elements with class "footer". Intended to hide the Gradio page
    # footer — NOTE(review): verify the selector matches your Gradio version.
    ".footer {display: none !important;}",
)
custom_css = "\n" + "\n".join(_CSS_RULES) + "\n"
# Gradio UI: an image upload plus a model selector go in, and a top-5 label
# with per-class scores comes out.
interface = gr.Interface(
    fn=predict_anime_character,
    inputs=[
        gr.Image(type="numpy", label="Upload an image of an anime character"),
        # Pre-select a model so submitting without touching the radio does not
        # pass None as model_type.
        gr.Radio(['ResNet50', 'MobileNetV2'], value='ResNet50', label="Choose Model"),
    ],
    outputs=gr.Label(num_top_classes=5, label="Prediction"),
    title="Anime Character Classifier",
    description="Upload an image of an anime character and the classifier will predict its character.",
    css=custom_css,  # apply the custom dark-mode CSS
)

if __name__ == "__main__":
    # Start the server only when run as a script (e.g. `python app.py`), so
    # importing this module does not launch the app as a side effect.
    interface.launch()