adwod's picture
Update app.py
1dcd51d
raw
history blame
1.89 kB
import streamlit as st
import torch
from transformers import ViTFeatureExtractor, ViTForImageClassification, ViTConfig
from huggingface_hub import hf_hub_url, cached_download
# URL del archivo de configuraci贸n en el espacio de Hugging Face
config_url = hf_hub_url(filename="config.json", repo_id="adwod/Streamlite_ViT_2000")
# Descargar el archivo de configuraci贸n y cargarlo en una instancia de ViTConfig
config = ViTConfig.from_json_dict(cached_download(config_url))
# URL del archivo de pesos del modelo en el espacio de Hugging Face
model_path = hf_hub_url(filename="pytorch_model.bin", repo_id="adwod/Streamlite_ViT_2000")
# Descargar el archivo de pesos del modelo y cargarlo en el modelo
model = ViTForImageClassification(config)
model.load_state_dict(torch.load(cached_download(model_path)))
# Cargar el extractor de caracter铆sticas
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
# Funci贸n para hacer predicciones en una imagen de entrada
def predict(image):
# Preprocesar la imagen
inputs = feature_extractor(image=image, return_tensors="pt")
# Hacer predicciones
outputs = model(**inputs)
# Obtener las etiquetas predichas
predicted_labels = torch.argmax(outputs.logits, dim=1)
# Devolver las etiquetas como una lista de strings
label_strings = ['inside', 'outside', 'food', 'drink', 'menu']
return [label_strings[label] for label in predicted_labels]
# Interfaz de usuario para cargar una imagen y hacer predicciones
st.title("ViT Image Classifier")
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
image = Image.open(uploaded_file)
st.image(image, caption='Uploaded image.', use_column_width=True)
predictions = predict(image)
st.write("Predicted labels:")
for label in predictions:
st.write(label)