Spaces:
Build error
Build error
| import streamlit as st | |
| from transformers import ViTForImageClassification, ViTFeatureExtractor | |
| from PIL import Image | |
| import torch | |
| import matplotlib.pyplot as plt | |
# Hugging Face Hub repository holding the fine-tuned ViT flower classifier.
repo_id = "Hammad712/5-Flower-Types-Classification-VIT-Model"


@st.cache_resource
def _load_model_and_feature_extractor(model_repo_id):
    """Load the classifier and its feature extractor once per process.

    Streamlit re-executes this whole script on every widget interaction; without
    ``st.cache_resource`` the checkpoint would be re-instantiated on each rerun.
    The cached objects are shared across sessions and reused for inference only.

    Args:
        model_repo_id: Hub repository ID passed to ``from_pretrained``.

    Returns:
        A ``(model, feature_extractor)`` tuple.
    """
    loaded_model = ViTForImageClassification.from_pretrained(model_repo_id)
    loaded_extractor = ViTFeatureExtractor.from_pretrained(model_repo_id)
    return loaded_model, loaded_extractor


model, feature_extractor = _load_model_and_feature_extractor(repo_id)

# Display names for the classifier's output indices.
# NOTE(review): assumed to match the label order the checkpoint was trained
# with -- confirm against model.config.id2label.
class_names = {0: 'Lilly', 1: 'Lotus', 2: 'Orchid', 3: 'Sunflower', 4: 'Tulip'}
def predict(image):
    """Classify a single flower image.

    Args:
        image: The image to classify; the app passes an RGB PIL image.

    Returns:
        A ``(probabilities, predicted_class_name)`` tuple. ``probabilities`` is
        a list of softmax scores, one per model output index.
        ``predicted_class_name`` is the ``class_names`` entry for the
        highest-scoring index, or that index as a string if ``class_names``
        has no entry for it.
    """
    inputs = feature_extractor(images=image, return_tensors="pt")
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(**inputs).logits
    # Take the single batch row explicitly rather than calling squeeze():
    # squeeze() would also drop the class axis for a one-output model and
    # return a bare float instead of a list.
    probabilities = torch.nn.functional.softmax(logits, dim=-1)[0].tolist()
    predicted_class_idx = logits[0].argmax(-1).item()
    # Don't crash with KeyError if the model emits an index the hard-coded
    # mapping doesn't know about.
    predicted_class_name = class_names.get(predicted_class_idx, str(predicted_class_idx))
    return probabilities, predicted_class_name
# Streamlit app: runs top to bottom on every page load / interaction.
st.title("Flower Type Classification")
st.write("Upload an image of a flower to classify its type.")

# Sidebar upload widget; returns None until the user picks a file.
uploaded_file = st.sidebar.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    try:
        # Normalize to 3-channel RGB (e.g. PNGs with alpha or palette mode).
        image = Image.open(uploaded_file).convert("RGB")
    except OSError:
        # PIL raises UnidentifiedImageError (an OSError subclass) for files
        # that pass the extension filter but are not decodable images.
        st.error("Could not read the uploaded file as an image.")
    else:
        # NOTE(review): newer Streamlit releases deprecate use_column_width;
        # kept for compatibility with older versions.
        st.image(image, caption='Uploaded Image.', use_column_width=True)

        probabilities, predicted_class = predict(image)

        # Label bars by output index so each label lines up with its
        # probability, and hand matplotlib a plain list rather than a dict view.
        labels = [class_names.get(i, str(i)) for i in range(len(probabilities))]
        fig, ax = plt.subplots()
        ax.bar(labels, probabilities)
        ax.set_ylabel('Probability')
        ax.set_xlabel('Class')
        ax.set_title('Class Probabilities')
        st.pyplot(fig)
        # st.pyplot renders the figure but does not close it; close it so
        # pyplot's global figure registry doesn't grow on every rerun.
        plt.close(fig)

        st.write(f"Predicted class: **{predicted_class}**")