"""Streamlit app that classifies an uploaded image with a pretrained ViT model.

Run with ``streamlit run <this file>``.
"""

import streamlit as st
import torch
from PIL import Image
from transformers import ViTForImageClassification, ViTImageProcessor

MODEL_NAME = "google/vit-base-patch16-224"


@st.cache_resource
def load_model():
    """Load the ViT classifier and its image processor.

    Wrapped in ``st.cache_resource`` so the weights are not reloaded on every
    Streamlit script rerun.

    Returns:
        A ``(model, feature_extractor)`` tuple. The second element is a
        ``ViTImageProcessor`` (the replacement for the deprecated
        ``ViTFeatureExtractor``); the name is kept for compatibility.
    """
    model = ViTForImageClassification.from_pretrained(MODEL_NAME)
    feature_extractor = ViTImageProcessor.from_pretrained(MODEL_NAME)
    return model, feature_extractor


def classify_image(image, model, feature_extractor):
    """Return the label of the highest-scoring class for *image*.

    Args:
        image: A PIL image in RGB mode.
        model: The ``ViTForImageClassification`` returned by ``load_model``.
        feature_extractor: The matching image processor from ``load_model``.

    Returns:
        The ``model.config.id2label`` entry for the arg-max logit.

    NOTE(review): the checkpoint is a general-purpose image classifier
    (presumably ImageNet-1k labels - confirm), so non-animal labels can be
    returned even though the UI talks about animals.
    """
    inputs = feature_extractor(images=image, return_tensors="pt")
    # Inference only: no gradients needed.
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_class_idx = logits.argmax(-1).item()
    return model.config.id2label[predicted_class_idx]


model, feature_extractor = load_model()

st.title("Animal Recognition App 🐾")
st.write("Upload an image to detect the animal.")

uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    try:
        image = Image.open(uploaded_file)
        # PNG uploads may be RGBA, palette or grayscale; the ViT preprocessing
        # expects 3 channels. convert() also forces a full decode, so corrupt
        # or truncated files fail here instead of deep inside the model call.
        rgb_image = image.convert("RGB")
    except OSError:  # PIL.UnidentifiedImageError is a subclass of OSError
        st.error("Could not read the uploaded file as an image.")
    else:
        # NOTE(review): newer Streamlit releases deprecate use_column_width in
        # favour of use_container_width; switch once the minimum supported
        # Streamlit version allows it.
        st.image(image, caption="Uploaded Image.", use_column_width=True)
        with st.spinner("Classifying..."):
            predicted_label = classify_image(rgb_image, model, feature_extractor)
        st.success(f"Prediction: {predicted_label}")