File size: 3,539 Bytes
674666d
 
 
 
5bf3fbe
674666d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
from transformers import CLIPProcessor, CLIPModel
import numpy as np
from typing import Union
import PIL
import gradio as gr

class ZeroShotCLF():
  def __init__(self, model_id = "openai/clip-vit-base-patch32"):
    """
    Clasificador de imágenes zero-shot usando CLIP.

    Métodos:
      - set_classes: sirve para especificar las posibles clases. Se debe usar antes de predict o evaluate.
      - predict: sirve para clasificar una o varias imágenes. Se debe usar antes set_classes.
      - evaluate: sirve para evaluar el modelo. Se debe usar antes set_classes.
    """
    self.model_id = model_id
    self.processor = CLIPProcessor.from_pretrained(model_id)
    self.model = CLIPModel.from_pretrained(model_id)

  def set_classes(self, clases: list[str], tipos: Union[str, list[str]] = "image"):
    """
    Establece las posibles clases en las que deberemos clasificar.

    Inputs:
      - clases: lista de strings con todas las clases a considerar.
      - tipos: string o lista de strings con los tipos de imagen. Este tipo sirve para especificar si buscas fotografías, dibujos... Valor por defecto: "image".
      Si la variable es string, se utilizará el mismo tipo para todas las clases, si es una lista, deberá tener la misma longitud que la variable clases.

    Outputs:
      - Lista de textos a usar para cada clase. Se usará el siguiente esquema: "a __ of a __", donde en el primer hueco se rellenará con el tipo y el segundo por la clase.
      Por ejemplo, si el tipo es "photo" y la clase es "dog", el texto será "a photo of a dog".
    """
    if isinstance(tipos, str):
      tipos = [tipos] * len(clases)
    else:
      assert len(tipos) == len(clases), "¡clases y tipos deben tener la misma longitud!"

    self.clases = clases
    self.tipos = tipos
    self.text = list(map(lambda par: "a %s of a %s" % par, zip(tipos, clases)))
    return self.text

  def predict(self, images):
    """
    Predice las clases de las imágenes. Es necesario usar previamente el método set_classes.

    Inputs:
      - images: imagen PIL o lista de imágenes PIL a clasificar.

    Outputs:
      - Nombre de la clase predicha (string) o lista de las clases predichas (lista de strings).
      - Índice de la clase predicha (int) o lista de los índices de las clases predichas (lista de ints).
      - Vector de probabilidades de las clases (array) o matriz de probabilidades de las clases (ndarray de dimensiones número de imágenes por número de clases).
    """
    assert 'text' in dir(self), "¡No se han especificado las clases! Usa set_classes antes que predict."

    inputs = self.processor(text=self.text, images=images, return_tensors="pt", padding=True)
    outputs = self.model(**inputs)

    probs = outputs.logits_per_image.softmax(dim=1)

    ind_clases = probs.argmax(1)

    if isinstance(images, list):
      ind_clases = ind_clases.tolist()
      pred_clases = list(map(lambda ind: self.clases[ind], ind_clases))
      return pred_clases, ind_clases, probs

    ind_clases = int(ind_clases)
    return self.clases[ind_clases], ind_clases, probs.ravel()

clf = ZeroShotCLF()

def clasifica(imagen, clases_, tipos_):
    clases = clases_.split(", ")
    tipos = tipos_.split(", ")
    if len(tipos) == 1:
        tipos = tipos[0]
    textos = clf.set_classes(clases, tipos)
    probs = clf.predict(PIL.Image.fromarray(imagen))[-1][0]
    return {textos[i]: probs[i] for i in range(len(textos))}

gr.Interface(fn=clasifica, inputs=["image", "text", "text"], outputs="text").launch(share=False)