Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| import torch | |
| from transformers import ViTFeatureExtractor, ViTForImageClassification, ViTConfig | |
| from huggingface_hub import hf_hub_url, cached_download | |
| # URL del archivo de configuraci贸n en el espacio de Hugging Face | |
| config_url = hf_hub_url(filename="config.json", repo_id="adwod/Streamlite_ViT_2000") | |
| # Descargar el archivo de configuraci贸n y cargarlo en una instancia de ViTConfig | |
| config = ViTConfig.from_json_dict(cached_download(config_url)) | |
| # URL del archivo de pesos del modelo en el espacio de Hugging Face | |
| model_path = hf_hub_url(filename="pytorch_model.bin", repo_id="adwod/Streamlite_ViT_2000") | |
| # Descargar el archivo de pesos del modelo y cargarlo en el modelo | |
| model = ViTForImageClassification(config) | |
| model.load_state_dict(torch.load(cached_download(model_path))) | |
| # Cargar el extractor de caracter铆sticas | |
| feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224') | |
| # Funci贸n para hacer predicciones en una imagen de entrada | |
| def predict(image): | |
| # Preprocesar la imagen | |
| inputs = feature_extractor(image=image, return_tensors="pt") | |
| # Hacer predicciones | |
| outputs = model(**inputs) | |
| # Obtener las etiquetas predichas | |
| predicted_labels = torch.argmax(outputs.logits, dim=1) | |
| # Devolver las etiquetas como una lista de strings | |
| label_strings = ['inside', 'outside', 'food', 'drink', 'menu'] | |
| return [label_strings[label] for label in predicted_labels] | |
| # Interfaz de usuario para cargar una imagen y hacer predicciones | |
| st.title("ViT Image Classifier") | |
| uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"]) | |
| if uploaded_file is not None: | |
| image = Image.open(uploaded_file) | |
| st.image(image, caption='Uploaded image.', use_column_width=True) | |
| predictions = predict(image) | |
| st.write("Predicted labels:") | |
| for label in predictions: | |
| st.write(label) | |