# Author: amanneo
# Last change: "Updated app.py" (commit 6f474e3)
import torch
from transformers import AutoModelForImageClassification, AutoFeatureExtractor
import streamlit as st
from PIL import Image
# Hugging Face Hub model id passed to ``from_pretrained`` for both the
# classifier and its feature extractor.
model_id = 'amanneo/vit-base-patch16-224-finetuned-flower'
# Class names, in the same order as the model's output logits (position i of
# the softmax output is reported under labels[i]).
# NOTE(review): this order must match the checkpoint's config.id2label — confirm.
labels = ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
# Streamlit re-executes this whole script on every interaction, so a model
# loaded inside classify_image would be re-read from disk for each upload.
# st.cache_resource (Streamlit >= 1.18) keeps one loaded copy across reruns;
# older releases lack it, in which case we fall back to no caching, which is
# the previous behaviour.
_cache_resource = getattr(st, "cache_resource", lambda func: func)


@_cache_resource
def _load_model():
    """Load the classifier and its feature extractor for ``model_id``."""
    model = AutoModelForImageClassification.from_pretrained(model_id)
    feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
    return model, feature_extractor


def classify_image(image):
    """Classify *image* and return per-label probabilities.

    Args:
        image: An input accepted by the feature extractor; the app passes a
            PIL image opened from the user's upload.

    Returns:
        dict mapping each name in ``labels`` to the softmax probability
        (a Python float) of the corresponding model output.
    """
    model, feature_extractor = _load_model()
    # Uploads may be RGBA (PNG with alpha), palette or grayscale; the
    # processor's 3-channel normalisation expects RGB.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")
    inputs = feature_extractor(image, return_tensors='pt')
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        logits = model(**inputs).logits
    # Row 0 is the single image in this batch.
    probs = torch.nn.functional.softmax(logits, dim=-1)[0].cpu().numpy()
    return {label: float(prob) for label, prob in zip(labels, probs)}
# Top-level Streamlit UI. Despite its name, ``file_name`` holds the uploaded
# file object returned by st.file_uploader (or None before anything is uploaded).
file_name = st.file_uploader("Upload flower image")
if file_name is not None:
    try:
        image = Image.open(file_name)
    except Image.UnidentifiedImageError:
        # A non-image upload would otherwise surface as a raw traceback.
        st.error("Could not read the uploaded file as an image.")
    else:
        # Left column: the uploaded picture; right column: class probabilities.
        col1, col2 = st.columns(2)
        col1.image(image, use_column_width=True)
        predictions = classify_image(image)
        col2.header("Probabilities")
        for label, prob in predictions.items():
            col2.subheader(f"{label} : {prob}")