Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import numpy as np | |
| from PIL import Image | |
| from keras.models import load_model | |
| from keras.saving import register_keras_serializable | |
| import tensorflow as tf | |
def f1_score(y_true, y_pred):
    """Return the F1 score of ``y_pred`` against ``y_true`` as a backend scalar.

    Both tensors are clipped to [0, 1] and rounded to the nearest integer, so
    prediction values above 0.5 count as positive. Counts are summed over every
    element of the tensors (no axis is given), i.e. a micro-averaged F1 over
    the whole batch.

    This app does not call the function directly. It is passed to
    ``load_model`` as a custom object, presumably because the saved model was
    compiled with it as a metric (confirm against the training code).
    """
    backend = tf.keras.backend
    eps = backend.epsilon()  # guards every division against a zero denominator

    def _count_ones(x):
        # Number of entries that round to 1 after clipping into [0, 1].
        return backend.sum(backend.round(backend.clip(x, 0, 1)))

    # Computed once and shared by precision and recall.
    true_positives = _count_ones(y_true * y_pred)
    precision = true_positives / (_count_ones(y_pred) + eps)
    recall = true_positives / (_count_ones(y_true) + eps)
    return 2 * ((precision * recall) / (precision + recall + eps))
# Page header. Raw HTML is used so the title can be given a custom color.
st.markdown(
    "<h1 style='color: #522258;'>Origami style prediction</h1>",
    unsafe_allow_html=True,
)
st.write("This application shows which origamist your folding style is most similar to.")
| # Load your pre-trained model | |
| def load_keras_model(): | |
| custom_objects = {'f1_score': f1_score} | |
| model = load_model('resnet_50_ver2.keras', custom_objects=custom_objects) | |
| return model | |
try:
    classifier = load_keras_model()
except Exception as e:
    # Everything below needs ``classifier``. Continuing would only fail later
    # with a NameError, so report the real cause and end this script run here.
    st.error(f"Error loading the model: {e}")
    st.stop()
# Class labels, looked up by the index of a score in the model's prediction
# vector. NOTE(review): the order (alphabetical) presumably matches the label
# order used in training; confirm before adding or removing names.
origamists = [
    'Beth Johnson',
    'Chen Xiao',
    'Choi Ju Young',
    'Eric Joisel',
    'Gen Hagiwara',
    'Giang Dinh',
    'Hideo Komatsu',
    'Hojyo Takashi',
    'Kaede Nakamura',
    'Kamiya Satoshi',
    'Katsuta Kyohei',
    'Kei Watanabe',
    'Kota Imai',
    'Robert J Lang',
    'Shuki Kato',
    'Tran Trung Hieu',
]
# Per-session history that persists across Streamlit reruns. 'images' holds
# the uploaded images and 'predictions' their score vectors. The display code
# zips the two lists, so they are meant to be appended together.
for _state_key in ('images', 'predictions'):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = []
# Two-column page layout at a 2:1 width ratio.
# Left: one result row per uploaded image.
# Right: the uploader and the combined ranking.
left_column, right_column = st.columns([2, 1])

with right_column:
    image_file = st.file_uploader("Upload an image", type=['jpg', 'jpeg', 'png'])
    st.markdown("<h2 style='color: #C63C51;'>Combined Prediction:</h2>", unsafe_allow_html=True)

with left_column:
    st.markdown("<h2 style='color: #C63C51;'>Predictions for each image:</h2>", unsafe_allow_html=True)
import hashlib

if image_file:
    # Streamlit reruns the whole script on every interaction, and the uploader
    # keeps returning the same file until the user removes it. Remember which
    # upload was processed last so a rerun does not append it a second time.
    # Newer Streamlit versions give each upload a ``file_id``; otherwise fall
    # back to a hash of the file contents.
    upload_key = getattr(image_file, "file_id", None) or hashlib.sha256(
        image_file.getvalue()
    ).hexdigest()

    if st.session_state.get('last_upload_key') != upload_key:
        image = Image.open(image_file)

        # Resize to the model's 224x224 input, then force 3-channel RGB. This
        # drops an alpha channel (RGBA -> RGB). It also makes grayscale and
        # palette images usable: without it they become 2-D arrays, which
        # crash the channel check, or 2-channel (LA) arrays the model rejects.
        processed_image = np.asarray(image.resize((224, 224)).convert('RGB'))
        # NOTE(review): assumes training used plain [0, 1] scaling rather than
        # keras' resnet50.preprocess_input; confirm against the training code.
        processed_image = processed_image / 255.0
        processed_image = np.expand_dims(processed_image, axis=0)  # add batch dimension

        prediction = classifier.predict(processed_image)[0]

        # Record the image and its prediction only after predict succeeded,
        # so the two session lists stay index-aligned.
        st.session_state['images'].append(image)
        st.session_state['predictions'].append(prediction)
        st.session_state['last_upload_key'] = upload_key
else:
    # The uploader was cleared, so the next upload is new even if it is the
    # same file as before.
    st.session_state['last_upload_key'] = None
_TOP_K = 3  # how many best-matching origamists to list


def _top_k_indices(scores, k=_TOP_K):
    """Return the indices of the ``k`` largest entries of ``scores``, largest first.

    Returns an empty array when ``k`` is 0 or negative, and every index when
    ``k`` exceeds ``len(scores)``.
    """
    return np.argsort(scores)[::-1][:max(k, 0)]


if st.session_state['predictions']:
    # Average the per-image score vectors, then rescale the mean so it sums
    # to 1. NOTE(review): assumes each prediction is a non-negative per-class
    # score vector (e.g. softmax output); confirm against the model head.
    combined_prediction = np.mean(st.session_state['predictions'], axis=0)
    normalized_prediction = combined_prediction / np.sum(combined_prediction)

    with right_column:
        for i in _top_k_indices(normalized_prediction):
            st.write(f"{origamists[i]}: {normalized_prediction[i]:.2f}")

    with left_column:
        history = zip(st.session_state['images'], st.session_state['predictions'])
        for idx, (img, pred) in enumerate(history):
            # Narrow column for the thumbnail, wider column for the ranking.
            col1, col2 = st.columns([1, 2])
            with col1:
                st.image(img, caption=f"Image {idx + 1}", use_column_width=True)
            with col2:
                st.markdown(
                    f"<h4 style='color: #D95F59;'>Image {idx + 1} Predictions:</h4>",
                    unsafe_allow_html=True,
                )
                for i in _top_k_indices(pred):
                    st.write(f"{origamists[i]}, confidence: {pred[i]:.2f}")
# Injected stylesheet: turns the element matched by the selector into a flex
# container whose children are centered on the cross axis (vertically, for
# the default row direction).
# NOTE(review): ".css-1kyxreq.e1fqkh3o1" looks like a build-generated
# Streamlit class name. It may match nothing in other Streamlit versions, so
# verify the selector still applies after upgrades.
_alignment_css = """
<style>
.css-1kyxreq.e1fqkh3o1 {
display: flex;
align-items: center;
}
</style>
"""
st.markdown(_alignment_css, unsafe_allow_html=True)