mkjaramillo commited on
Commit
7d47552
·
1 Parent(s): c645b81

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -17
app.py CHANGED
@@ -1,29 +1,47 @@
1
  import gradio as gr
2
- import requests
3
- import json
 
 
4
 
5
- # Función para realizar la predicción
6
- def classify_image(image):
7
- # URL de la API de inferencia de Hugging Face
8
- api_url = "https://api-inference.huggingface.co/models/mkjaramillo/cancer"
9
 
10
- # Preparar la imagen para la API de inferencia
11
- files = {'file': image}
 
12
 
13
- # Realizar la solicitud de predicción a la API de inferencia
14
- response = requests.post(api_url, files=files)
 
 
 
 
15
 
16
- # Obtener el resultado de la predicción
17
- result = json.loads(response.content)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
- # Retornar el resultado de la predicción
20
- return result[0]["label"]
21
 
22
  # Configurar la interfaz de Gradio
23
- iface = gr.Interface(fn=classify_image,
24
- inputs="image",
25
  outputs="text",
26
- capture_session=True,
27
  title="Clasificador de Imágenes")
28
 
29
  # Ejecutar la interfaz de Gradio
 
1
  import gradio as gr
2
+ import torch
3
+ import torchvision.transforms as transforms
4
+ from PIL import Image
5
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
6
 
7
+ # Nombre del modelo en el repositorio de Hugging Face
8
+ model_name = "mkjaramillo/cancer"
 
 
9
 
10
+ # Cargar el tokenizer y el modelo desde el repositorio de Hugging Face
11
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
12
+ model = AutoModelForSequenceClassification.from_pretrained(model_name)
13
 
14
+ # Transformación de la imagen
15
+ image_transform = transforms.Compose([
16
+ transforms.Resize((50, 50)),
17
+ transforms.ToTensor(),
18
+
19
+ ])
20
 
21
+ # Función para realizar la predicción
22
+ def classify_image(image):
23
+ # Cargar la imagen
24
+ image = Image.fromarray(image)
25
+
26
+ # Preprocesar la imagen
27
+ image = image_transform(image).unsqueeze(0)
28
+
29
+ # Realizar la inferencia con el modelo
30
+ outputs = model(image)
31
+
32
+ # Obtener las predicciones
33
+ predictions = torch.argmax(outputs.logits, dim=1)
34
+
35
+ # Obtener la etiqueta de la predicción
36
+ label = tokenizer.decode(predictions.item())
37
 
38
+ # Retornar la etiqueta de la predicción
39
+ return label
40
 
41
  # Configurar la interfaz de Gradio
42
+ iface = gr.Interface(fn=classify_image,
43
+ inputs=gr.inputs.Image(label="Imagen de entrada"),
44
  outputs="text",
 
45
  title="Clasificador de Imágenes")
46
 
47
  # Ejecutar la interfaz de Gradio