mkjaramillo commited on
Commit
d4ee5e3
·
1 Parent(s): c76a77a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -7
app.py CHANGED
@@ -2,15 +2,12 @@ import gradio as gr
2
  import torch
3
  import torchvision.transforms as transforms
4
  from PIL import Image
5
- from transformers import ViTForImageClassification, AutoTokenizer
6
 
7
 
8
  # Nombre del modelo en el repositorio de Hugging Face
9
  model_name = "mkjaramillo/cancer"
10
-
11
- # Cargar el tokenizer y el modelo desde el repositorio de Hugging Face
12
- tokenizer = AutoTokenizer.from_pretrained(model_name)
13
- model = ViTForImageClassification.from_pretrained(model_name)
14
 
15
  # Transformación de la imagen
16
  image_transform = transforms.Compose([
@@ -31,10 +28,10 @@ def classify_image(image):
31
  outputs = model(image)
32
 
33
  # Obtener las predicciones
34
- predictions = torch.argmax(outputs.logits, dim=1)
35
 
36
  # Obtener la etiqueta de la predicción
37
- label = tokenizer.decode(predictions.item())
38
 
39
  # Retornar la etiqueta de la predicción
40
  return label
 
2
  import torch
3
  import torchvision.transforms as transforms
4
  from PIL import Image
5
+
6
 
7
 
8
  # Nombre del modelo en el repositorio de Hugging Face
9
  model_name = "mkjaramillo/cancer"
10
+ model = torch.load("model_name")
 
 
 
11
 
12
  # Transformación de la imagen
13
  image_transform = transforms.Compose([
 
28
  outputs = model(image)
29
 
30
  # Obtener las predicciones
31
+ predictions = torch.argmax(outputs, dim=1)
32
 
33
  # Obtener la etiqueta de la predicción
34
+ label = predictions.item()
35
 
36
  # Retornar la etiqueta de la predicción
37
  return label