adwod commited on
Commit
9ab3705
1 Parent(s): e8f8664

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -0
app.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ from transformers import ViTFeatureExtractor, ViTForImageClassification, ViTConfig
4
+ from huggingface_hub import hf_hub_url, cached_download
5
+
6
+ # URL del archivo de configuraci贸n en el espacio de Hugging Face
7
+ config_url = hf_hub_url(space_name="adwod", filename="config.json", repo_id="Streamlite_ViT_2000")
8
+ # Descargar el archivo de configuraci贸n y cargarlo en una instancia de ViTConfig
9
+ config = ViTConfig.from_json_dict(cached_download(config_url))
10
+
11
+ # URL del archivo de pesos del modelo en el espacio de Hugging Face
12
+ model_path = hf_hub_url(space_name="adwod", filename="pytorch_model.bin", repo_id="Streamlite_ViT_2000")
13
+ # Descargar el archivo de pesos del modelo y cargarlo en el modelo
14
+ model.load_state_dict(torch.load(cached_download(model_path)))
15
+ model = ViTForImageClassification(config)
16
+ model.load_state_dict(torch.load(model_path))
17
+
18
+ # Cargar el extractor de caracter铆sticas
19
+ feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
20
+
21
+ # Funci贸n para hacer predicciones en una imagen de entrada
22
+ def predict(image):
23
+ # Preprocesar la imagen
24
+ inputs = feature_extractor(image=image, return_tensors="pt")
25
+ # Hacer predicciones
26
+ outputs = model(**inputs)
27
+ # Obtener las etiquetas predichas
28
+ predicted_labels = torch.argmax(outputs.logits, dim=1)
29
+ # Devolver las etiquetas como una lista de strings
30
+ label_strings = ['inside', 'outside', 'food', 'drink', 'menu']
31
+ return [label_strings[label] for label in predicted_labels]
32
+
33
+ # Interfaz de usuario para cargar una imagen y hacer predicciones
34
+ st.title("ViT Image Classifier")
35
+ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
36
+ if uploaded_file is not None:
37
+ image = Image.open(uploaded_file)
38
+ st.image(image, caption='Uploaded image.', use_column_width=True)
39
+ predictions = predict(image)
40
+ st.write("Predicted labels:")
41
+ for label in predictions:
42
+ st.write(label)
43
+