Spaces:
Runtime error
Runtime error
Create app.py
Browse files
app.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import CLIPProcessor, CLIPModel
|
| 2 |
+
import numpy as np
|
| 3 |
+
from typing import Union
|
| 4 |
+
import PIL
|
| 5 |
+
|
| 6 |
+
class ZeroShotCLF():
|
| 7 |
+
def __init__(self, model_id = "openai/clip-vit-base-patch32"):
|
| 8 |
+
"""
|
| 9 |
+
Clasificador de imágenes zero-shot usando CLIP.
|
| 10 |
+
|
| 11 |
+
Métodos:
|
| 12 |
+
- set_classes: sirve para especificar las posibles clases. Se debe usar antes de predict o evaluate.
|
| 13 |
+
- predict: sirve para clasificar una o varias imágenes. Se debe usar antes set_classes.
|
| 14 |
+
- evaluate: sirve para evaluar el modelo. Se debe usar antes set_classes.
|
| 15 |
+
"""
|
| 16 |
+
self.model_id = model_id
|
| 17 |
+
self.processor = CLIPProcessor.from_pretrained(model_id)
|
| 18 |
+
self.model = CLIPModel.from_pretrained(model_id)
|
| 19 |
+
|
| 20 |
+
def set_classes(self, clases: list[str], tipos: Union[str, list[str]] = "image"):
|
| 21 |
+
"""
|
| 22 |
+
Establece las posibles clases en las que deberemos clasificar.
|
| 23 |
+
|
| 24 |
+
Inputs:
|
| 25 |
+
- clases: lista de strings con todas las clases a considerar.
|
| 26 |
+
- tipos: string o lista de strings con los tipos de imagen. Este tipo sirve para especificar si buscas fotografías, dibujos... Valor por defecto: "image".
|
| 27 |
+
Si la variable es string, se utilizará el mismo tipo para todas las clases, si es una lista, deberá tener la misma longitud que la variable clases.
|
| 28 |
+
|
| 29 |
+
Outputs:
|
| 30 |
+
- Lista de textos a usar para cada clase. Se usará el siguiente esquema: "a __ of a __", donde en el primer hueco se rellenará con el tipo y el segundo por la clase.
|
| 31 |
+
Por ejemplo, si el tipo es "photo" y la clase es "dog", el texto será "a photo of a dog".
|
| 32 |
+
"""
|
| 33 |
+
if isinstance(tipos, str):
|
| 34 |
+
tipos = [tipos] * len(clases)
|
| 35 |
+
else:
|
| 36 |
+
assert len(tipos) == len(clases), "¡clases y tipos deben tener la misma longitud!"
|
| 37 |
+
|
| 38 |
+
self.clases = clases
|
| 39 |
+
self.tipos = tipos
|
| 40 |
+
self.text = list(map(lambda par: "a %s of a %s" % par, zip(tipos, clases)))
|
| 41 |
+
return self.text
|
| 42 |
+
|
| 43 |
+
def predict(self, images):
|
| 44 |
+
"""
|
| 45 |
+
Predice las clases de las imágenes. Es necesario usar previamente el método set_classes.
|
| 46 |
+
|
| 47 |
+
Inputs:
|
| 48 |
+
- images: imagen PIL o lista de imágenes PIL a clasificar.
|
| 49 |
+
|
| 50 |
+
Outputs:
|
| 51 |
+
- Nombre de la clase predicha (string) o lista de las clases predichas (lista de strings).
|
| 52 |
+
- Índice de la clase predicha (int) o lista de los índices de las clases predichas (lista de ints).
|
| 53 |
+
- Vector de probabilidades de las clases (array) o matriz de probabilidades de las clases (ndarray de dimensiones número de imágenes por número de clases).
|
| 54 |
+
"""
|
| 55 |
+
assert 'text' in dir(self), "¡No se han especificado las clases! Usa set_classes antes que predict."
|
| 56 |
+
|
| 57 |
+
inputs = self.processor(text=self.text, images=images, return_tensors="pt", padding=True)
|
| 58 |
+
outputs = self.model(**inputs)
|
| 59 |
+
|
| 60 |
+
probs = outputs.logits_per_image.softmax(dim=1)
|
| 61 |
+
|
| 62 |
+
ind_clases = probs.argmax(1)
|
| 63 |
+
|
| 64 |
+
if isinstance(images, list):
|
| 65 |
+
ind_clases = ind_clases.tolist()
|
| 66 |
+
pred_clases = list(map(lambda ind: self.clases[ind], ind_clases))
|
| 67 |
+
return pred_clases, ind_clases, probs
|
| 68 |
+
|
| 69 |
+
ind_clases = int(ind_clases)
|
| 70 |
+
return self.clases[ind_clases], ind_clases, probs.ravel()
|
| 71 |
+
|
| 72 |
+
clf = ZeroShotCLF()
|
| 73 |
+
|
| 74 |
+
def clasifica(imagen, clases_, tipos_):
|
| 75 |
+
clases = clases_.split(", ")
|
| 76 |
+
tipos = tipos_.split(", ")
|
| 77 |
+
if len(tipos) == 1:
|
| 78 |
+
tipos = tipos[0]
|
| 79 |
+
textos = clf.set_classes(clases, tipos)
|
| 80 |
+
probs = clf.predict(PIL.Image.fromarray(imagen))[-1][0]
|
| 81 |
+
return {textos[i]: probs[i] for i in range(len(textos))}
|
| 82 |
+
|
| 83 |
+
gr.Interface(fn=clasifica, inputs=["image", "text", "text"], outputs="text").launch(share=False)
|