Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| import tensorflow as tf | |
| import numpy as np | |
| from PIL import Image | |
| import pandas as pd | |
| import matplotlib.pyplot as plt | |
# Load the trained model.
# Path of the saved Keras model file, relative to the app's working directory.
model_path = "pokemon_model_fahrnphi_transferlearning.keras"


@st.cache_resource
def _load_model(path):
    """Load the Keras model stored at *path* and cache it.

    Streamlit re-executes this entire script on every widget interaction.
    Without caching, the model would be deserialised from disk on every rerun.
    ``st.cache_resource`` (Streamlit 1.18+) keeps a single loaded instance
    shared across reruns and sessions.
    """
    return tf.keras.models.load_model(path)


# Module-level name read by predict_pokemon().
model = _load_model(model_path)
# Output class names, in the order of the model's output units.
# NOTE(review): this order must match the label order used at training
# time (e.g. alphabetical directory order) — confirm.
CLASS_NAMES = ['Charizard', 'Lapras', 'Machamp']


# Define the core prediction function
def predict_pokemon(image, target_size=(150, 150)):
    """Classify a PIL image and return per-class probabilities.

    Args:
        image: A PIL ``Image`` in any mode (RGB, RGBA, L, P, ...).
        target_size: ``(width, height)`` the image is resized to before
            inference. Defaults to 150x150, presumably the model's input size.

    Returns:
        dict mapping each name in ``CLASS_NAMES`` to its softmax
        probability, rounded to two decimal places.
    """
    # Convert to RGB *before* resizing. Pillow always resizes "P" (palette)
    # and "1" images with nearest-neighbour sampling, so converting first
    # lets the default, smoother filter apply. It also guarantees 3 channels.
    rgb_image = image.convert('RGB').resize(target_size)

    # (height, width, 3) array -> (1, height, width, 3): add the batch axis.
    batch = np.expand_dims(np.array(rgb_image), axis=0)

    # No pixel scaling is done here.
    # NOTE(review): assumes the model normalises its own inputs
    # (e.g. an internal Rescaling layer) — confirm.
    # verbose=0 keeps Keras from printing a progress bar on every request.
    prediction = model.predict(batch, verbose=0)

    # NOTE(review): if the model's final layer already applies softmax, this
    # second softmax flattens the distribution — confirm against the model.
    probabilities = tf.nn.softmax(prediction, axis=1).numpy()[0]

    # Map probabilities to Pokemon classes.
    return {
        pokemon_class: round(float(probability), 2)
        for pokemon_class, probability in zip(CLASS_NAMES, probabilities)
    }
# Streamlit interface
st.title("Pokemon Classifier")
st.write("A simple MLP classification model for image classification using a pretrained model.")

# Upload image
uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "png"])
if uploaded_image is not None:
    image = Image.open(uploaded_image)
    # NOTE(review): newer Streamlit releases deprecate `use_column_width` in
    # favour of `use_container_width`; switch once the pinned version allows.
    st.image(image, caption='Uploaded Image.', use_column_width=True)
    st.write("")
    st.write("Classifying...")
    predictions = predict_pokemon(image)

    # Display predictions as a DataFrame (one row per class).
    st.write("### Prediction Probabilities")
    df = pd.DataFrame(list(predictions.items()), columns=["Pokemon", "Probability"])
    st.dataframe(df)

    # Display predictions as a pie chart
    st.write("### Prediction Chart")
    fig, ax = plt.subplots()
    ax.pie(df["Probability"], labels=df["Pokemon"], autopct='%1.1f%%', colors=plt.cm.Paired.colors)
    ax.set_title('Prediction Probabilities')
    st.pyplot(fig)
    # Streamlit reruns this script on every interaction, and plt.subplots()
    # registers each new figure with pyplot. Close it once rendered so the
    # open-figure registry (and memory) doesn't grow with every upload.
    plt.close(fig)