pamunarr commited on
Commit
674666d
·
verified ·
1 Parent(s): 6235eed

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -0
app.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import CLIPProcessor, CLIPModel
2
+ import numpy as np
3
+ from typing import Union
4
+ import PIL
5
+
6
+ class ZeroShotCLF():
7
+ def __init__(self, model_id = "openai/clip-vit-base-patch32"):
8
+ """
9
+ Clasificador de imágenes zero-shot usando CLIP.
10
+
11
+ Métodos:
12
+ - set_classes: sirve para especificar las posibles clases. Se debe usar antes de predict o evaluate.
13
+ - predict: sirve para clasificar una o varias imágenes. Se debe usar antes set_classes.
14
+ - evaluate: sirve para evaluar el modelo. Se debe usar antes set_classes.
15
+ """
16
+ self.model_id = model_id
17
+ self.processor = CLIPProcessor.from_pretrained(model_id)
18
+ self.model = CLIPModel.from_pretrained(model_id)
19
+
20
+ def set_classes(self, clases: list[str], tipos: Union[str, list[str]] = "image"):
21
+ """
22
+ Establece las posibles clases en las que deberemos clasificar.
23
+
24
+ Inputs:
25
+ - clases: lista de strings con todas las clases a considerar.
26
+ - tipos: string o lista de strings con los tipos de imagen. Este tipo sirve para especificar si buscas fotografías, dibujos... Valor por defecto: "image".
27
+ Si la variable es string, se utilizará el mismo tipo para todas las clases, si es una lista, deberá tener la misma longitud que la variable clases.
28
+
29
+ Outputs:
30
+ - Lista de textos a usar para cada clase. Se usará el siguiente esquema: "a __ of a __", donde en el primer hueco se rellenará con el tipo y el segundo por la clase.
31
+ Por ejemplo, si el tipo es "photo" y la clase es "dog", el texto será "a photo of a dog".
32
+ """
33
+ if isinstance(tipos, str):
34
+ tipos = [tipos] * len(clases)
35
+ else:
36
+ assert len(tipos) == len(clases), "¡clases y tipos deben tener la misma longitud!"
37
+
38
+ self.clases = clases
39
+ self.tipos = tipos
40
+ self.text = list(map(lambda par: "a %s of a %s" % par, zip(tipos, clases)))
41
+ return self.text
42
+
43
+ def predict(self, images):
44
+ """
45
+ Predice las clases de las imágenes. Es necesario usar previamente el método set_classes.
46
+
47
+ Inputs:
48
+ - images: imagen PIL o lista de imágenes PIL a clasificar.
49
+
50
+ Outputs:
51
+ - Nombre de la clase predicha (string) o lista de las clases predichas (lista de strings).
52
+ - Índice de la clase predicha (int) o lista de los índices de las clases predichas (lista de ints).
53
+ - Vector de probabilidades de las clases (array) o matriz de probabilidades de las clases (ndarray de dimensiones número de imágenes por número de clases).
54
+ """
55
+ assert 'text' in dir(self), "¡No se han especificado las clases! Usa set_classes antes que predict."
56
+
57
+ inputs = self.processor(text=self.text, images=images, return_tensors="pt", padding=True)
58
+ outputs = self.model(**inputs)
59
+
60
+ probs = outputs.logits_per_image.softmax(dim=1)
61
+
62
+ ind_clases = probs.argmax(1)
63
+
64
+ if isinstance(images, list):
65
+ ind_clases = ind_clases.tolist()
66
+ pred_clases = list(map(lambda ind: self.clases[ind], ind_clases))
67
+ return pred_clases, ind_clases, probs
68
+
69
+ ind_clases = int(ind_clases)
70
+ return self.clases[ind_clases], ind_clases, probs.ravel()
71
+
72
+ clf = ZeroShotCLF()
73
+
74
+ def clasifica(imagen, clases_, tipos_):
75
+ clases = clases_.split(", ")
76
+ tipos = tipos_.split(", ")
77
+ if len(tipos) == 1:
78
+ tipos = tipos[0]
79
+ textos = clf.set_classes(clases, tipos)
80
+ probs = clf.predict(PIL.Image.fromarray(imagen))[-1][0]
81
+ return {textos[i]: probs[i] for i in range(len(textos))}
82
+
83
+ gr.Interface(fn=clasifica, inputs=["image", "text", "text"], outputs="text").launch(share=False)