ismaeltorres00 commited on
Commit
00f86b7
verified
1 Parent(s): 5e95917

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -31
app.py CHANGED
@@ -1,37 +1,17 @@
1
- import gradio as gr
2
- from transformers import ViTFeatureExtractor, ViTForImageClassification
3
- from PIL import Image
4
- import torch
5
 
6
  def classify_image(image):
7
- # Cargar el feature extractor y el modelo
8
- feature_extractor = ViTFeatureExtractor.from_pretrained("ismaeltorres00/ModeloFinalEuroSat")
9
- model = ViTForImageClassification.from_pretrained("ismaeltorres00/ModeloFinalEuroSat")
 
 
 
10
 
11
- # Preprocesar la imagen
12
  inputs = feature_extractor(images=image, return_tensors="pt")
13
-
14
- # Mover el modelo a GPU si est谩 disponible
15
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
- model.to(device)
17
- inputs = {k: v.to(device) for k, v in inputs.items()}
18
-
19
- # Realizar la predicci贸n
20
- with torch.no_grad():
21
- outputs = model(**inputs)
22
-
23
- # Obtener el label de la predicci贸n
24
  logits = outputs.logits
25
- predicted_class_idx = logits.argmax(-1).item()
26
- return model.config.id2label[predicted_class_idx]
27
-
28
- # Crear la interfaz de Gradio
29
- demo = gr.Interface(
30
- fn=classify_image,
31
- inputs=gr.Image(type="pil"),
32
- outputs="text",
33
- title="Clasificaci贸n de Im谩genes con ViT"
34
- )
35
 
36
- # Lanzar la aplicaci贸n
37
- demo.launch()
 
1
+ from transformers import ViTFeatureExtractor, AutoModelForImageClassification
 
 
 
2
 
3
  def classify_image(image):
4
+ try:
5
+ feature_extractor = ViTFeatureExtractor.from_pretrained("ismaeltorres00/ModeloFinalEuroSat")
6
+ except OSError as e:
7
+ # Manejo del error si el archivo no existe
8
+ print("No se pudo encontrar el archivo preprocessor_config.json. Verifica el repositorio.")
9
+ raise e
10
 
11
+ model = AutoModelForImageClassification.from_pretrained("ismaeltorres00/ModeloFinalEuroSat")
12
  inputs = feature_extractor(images=image, return_tensors="pt")
13
+ outputs = model(**inputs)
 
 
 
 
 
 
 
 
 
 
14
  logits = outputs.logits
 
 
 
 
 
 
 
 
 
 
15
 
16
+ # Aqu铆 procesar铆as los logits para obtener la clasificaci贸n
17
+ return logits.argmax(-1).item()