adwod commited on
Commit
f593fb6
·
1 Parent(s): 50214bc

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -0
app.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ from transformers import ViTFeatureExtractor, ViTForImageClassification
4
+
5
+ # Cargar el modelo preentrenado
6
+ model_path = "output/model.bin"
7
+ model = ViTForImageClassification.from_pretrained(model_path)
8
+
9
+ # Cargar el extractor de características
10
+ feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
11
+
12
+ # Función para hacer predicciones en una imagen de entrada
13
+ def predict(image):
14
+ # Preprocesar la imagen
15
+ inputs = feature_extractor(image=image, return_tensors="pt")
16
+ # Hacer predicciones
17
+ outputs = model(**inputs)
18
+ # Obtener las etiquetas predichas
19
+ predicted_labels = torch.argmax(outputs.logits, dim=1)
20
+ # Devolver las etiquetas como una lista de strings
21
+ label_strings = ['inside', 'outside', 'food', 'drink', 'menu']
22
+ return [label_strings[label] for label in predicted_labels]
23
+
24
+ # Interfaz de usuario para cargar una imagen y hacer predicciones
25
+ st.title("ViT Image Classifier")
26
+ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
27
+ if uploaded_file is not None:
28
+ image = Image.open(uploaded_file)
29
+ st.image(image, caption='Uploaded image.', use_column_width=True)
30
+ predictions = predict(image)
31
+ st.write("Predicted labels:")
32
+ for label in predictions:
33
+ st.write(label)