Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import keras_hub | |
| from PIL import Image | |
| import numpy as np | |
# Display name shown in the UI -> KerasHub preset identifier passed to `from_preset`.
classification_models = {
    "ResNet18": "resnet_18_imagenet",
    "ResNet50": "resnet_50_imagenet",
    # Currently disabled presets, kept for reference:
    #   "ViT-B16-224": "vit_base_patch16_224_imagenet",
    #   "ViT-B16-384": "vit_base_patch16_384_imagenet",
    #   "ViT-L16-224": "vit_large_patch16_224_imagenet",
}
@st.cache_resource
def load_preprocessor(model_name):
    """Load the image-classifier preprocessor for a KerasHub preset.

    Streamlit re-executes the whole script on every widget interaction, so
    the result is cached with ``st.cache_resource``: the preprocessor is
    built once per distinct ``model_name`` and the same object is reused on
    later reruns.

    Args:
        model_name: KerasHub preset identifier, e.g. ``"resnet_50_imagenet"``.

    Returns:
        The object returned by
        ``keras_hub.models.ImageClassifierPreprocessor.from_preset``.
    """
    return keras_hub.models.ImageClassifierPreprocessor.from_preset(model_name)
@st.cache_resource
def load_model(model_name):
    """Load a pre-trained image classifier from KerasHub.

    Streamlit re-executes the whole script on every widget interaction, so
    the result is cached with ``st.cache_resource``: the model is loaded
    once per distinct ``model_name`` and the same object is reused on later
    reruns instead of being rebuilt each time.

    Args:
        model_name: KerasHub preset identifier, e.g. ``"resnet_50_imagenet"``.

    Returns:
        The object returned by ``keras_hub.models.ImageClassifier.from_preset``.
    """
    return keras_hub.models.ImageClassifier.from_preset(model_name)
def upload_image():
    """Show an image uploader and return the upload as a float32 batch.

    Returns:
        ``None`` when no file has been uploaded yet; otherwise a NumPy array
        of dtype ``float32`` with a leading batch axis of size 1, holding the
        uploaded image converted to RGB.
    """
    uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        return None

    # PNGs may be RGBA or palette-based and JPEGs may be grayscale; normalise
    # to 3-channel RGB so the array always has a consistent channel layout.
    with Image.open(uploaded_file) as img:
        rgb_image = img.convert("RGB")

    pixels = np.asarray(rgb_image, dtype="float32")
    return np.expand_dims(pixels, axis=0)
def vision_page():
    """Render the "Vision Models" page with one tab per vision task.

    The classification tab lets the user pick an entry of
    ``classification_models``, loads the matching preprocessor and model,
    and, once an image is uploaded, shows the image next to a numbered list
    of predicted class names. The detection and segmentation tabs only show
    a placeholder message.
    """
    st.header("Vision Models")
    st.write("Explore Vision Models including Image Classification, Object Detection, and Segmentation.")

    # One tab per task; only classification is implemented so far.
    classification_tab, detection_tab, segmentation_tab = st.tabs(
        ["Image Classification", "Object Detection", "Segmentation"]
    )

    with classification_tab:
        st.subheader("Image Classification")
        model_name = st.selectbox("Choose a pre-trained model:", list(classification_models))
        preset = classification_models[model_name]
        preprocessor = load_preprocessor(preset)
        model = load_model(preset)

        batch = upload_image()
        if batch is not None:
            # Preprocess explicitly, then call the model directly on the result.
            raw_predictions = model(preprocessor(batch))
            decoded = keras_hub.utils.decode_imagenet_predictions(raw_predictions)

            image_col, results_col = st.columns([1, 1])
            with image_col:
                # Drop the batch axis and cast back to uint8 for display.
                st.image(batch[0].astype("uint8"), caption="Uploaded Image", use_container_width=True)
            with results_col:
                st.write("##### Top Predictions:")
                # Each entry unpacks as (class_name, score); only the name is shown.
                for rank, (class_name, _score) in enumerate(decoded[0], start=1):
                    st.write(f"{rank}: {class_name}")

    with detection_tab:
        st.subheader("Object Detection")
        st.write("Object Detection functionality is under development.")

    with segmentation_tab:
        st.subheader("Segmentation")
        st.write("Segmentation functionality is under development.")