mkjaramillo commited on
Commit
2332cfc
·
1 Parent(s): 55d7d00

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -6
app.py CHANGED
@@ -2,12 +2,14 @@ import gradio as gr
2
  import torch
3
  import torchvision.transforms as transforms
4
  from PIL import Image
5
-
6
-
7
 
8
  # Nombre del modelo en el repositorio de Hugging Face
9
  model_name = "mkjaramillo/cancer2"
10
- model = torch.load("model_name")
 
 
 
11
 
12
  # Transformación de la imagen
13
  image_transform = transforms.Compose([
@@ -16,7 +18,6 @@ image_transform = transforms.Compose([
16
 
17
  ])
18
 
19
- # Función para realizar la predicción
20
  def classify_image(image):
21
  # Cargar la imagen
22
  image = Image.fromarray(image)
@@ -28,10 +29,10 @@ def classify_image(image):
28
  outputs = model(image)
29
 
30
  # Obtener las predicciones
31
- predictions = torch.argmax(outputs, dim=1)
32
 
33
  # Obtener la etiqueta de la predicción
34
- label = predictions.item()
35
 
36
  # Retornar la etiqueta de la predicción
37
  return label
 
2
  import torch
3
  import torchvision.transforms as transforms
4
  from PIL import Image
5
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
 
6
 
7
  # Nombre del modelo en el repositorio de Hugging Face
8
  model_name = "mkjaramillo/cancer2"
9
+
10
+ # Cargar el tokenizer y el modelo desde el repositorio de Hugging Face
11
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
12
+ model = AutoModelForSequenceClassification.from_pretrained(model_name)
13
 
14
  # Transformación de la imagen
15
  image_transform = transforms.Compose([
 
18
 
19
  ])
20
 
 
21
  def classify_image(image):
22
  # Cargar la imagen
23
  image = Image.fromarray(image)
 
29
  outputs = model(image)
30
 
31
  # Obtener las predicciones
32
+ predictions = torch.argmax(outputs.logits, dim=1)
33
 
34
  # Obtener la etiqueta de la predicción
35
+ label = tokenizer.decode(predictions.item())
36
 
37
  # Retornar la etiqueta de la predicción
38
  return label