File size: 4,285 Bytes
674666d
 
 
 
5bf3fbe
674666d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89bd32e
 
 
674666d
 
 
5687849
674666d
 
4768faa
 
 
 
 
 
a2e4494
 
4e67fd1
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
from transformers import CLIPProcessor, CLIPModel
import numpy as np
from typing import Union
import PIL
import gradio as gr

class ZeroShotCLF():
  def __init__(self, model_id = "openai/clip-vit-base-patch32"):
    """
    Clasificador de imágenes zero-shot usando CLIP.

    Métodos:
      - set_classes: sirve para especificar las posibles clases. Se debe usar antes de predict o evaluate.
      - predict: sirve para clasificar una o varias imágenes. Se debe usar antes set_classes.
      - evaluate: sirve para evaluar el modelo. Se debe usar antes set_classes.
    """
    self.model_id = model_id
    self.processor = CLIPProcessor.from_pretrained(model_id)
    self.model = CLIPModel.from_pretrained(model_id)

  def set_classes(self, clases: list[str], tipos: Union[str, list[str]] = "image"):
    """
    Establece las posibles clases en las que deberemos clasificar.

    Inputs:
      - clases: lista de strings con todas las clases a considerar.
      - tipos: string o lista de strings con los tipos de imagen. Este tipo sirve para especificar si buscas fotografías, dibujos... Valor por defecto: "image".
      Si la variable es string, se utilizará el mismo tipo para todas las clases, si es una lista, deberá tener la misma longitud que la variable clases.

    Outputs:
      - Lista de textos a usar para cada clase. Se usará el siguiente esquema: "a __ of a __", donde en el primer hueco se rellenará con el tipo y el segundo por la clase.
      Por ejemplo, si el tipo es "photo" y la clase es "dog", el texto será "a photo of a dog".
    """
    if isinstance(tipos, str):
      tipos = [tipos] * len(clases)
    else:
      assert len(tipos) == len(clases), "¡clases y tipos deben tener la misma longitud!"

    self.clases = clases
    self.tipos = tipos
    self.text = list(map(lambda par: "a %s of a %s" % par, zip(tipos, clases)))
    return self.text

  def predict(self, images):
    """
    Predice las clases de las imágenes. Es necesario usar previamente el método set_classes.

    Inputs:
      - images: imagen PIL o lista de imágenes PIL a clasificar.

    Outputs:
      - Nombre de la clase predicha (string) o lista de las clases predichas (lista de strings).
      - Índice de la clase predicha (int) o lista de los índices de las clases predichas (lista de ints).
      - Vector de probabilidades de las clases (array) o matriz de probabilidades de las clases (ndarray de dimensiones número de imágenes por número de clases).
    """
    assert 'text' in dir(self), "¡No se han especificado las clases! Usa set_classes antes que predict."

    inputs = self.processor(text=self.text, images=images, return_tensors="pt", padding=True)
    outputs = self.model(**inputs)

    probs = outputs.logits_per_image.softmax(dim=1)

    ind_clases = probs.argmax(1)

    if isinstance(images, list):
      ind_clases = ind_clases.tolist()
      pred_clases = list(map(lambda ind: self.clases[ind], ind_clases))
      return pred_clases, ind_clases, probs

    ind_clases = int(ind_clases)
    return self.clases[ind_clases], ind_clases, probs.ravel()

clf = ZeroShotCLF()

def clasifica(Image, Types, Classes):
    clases = Classes.split(", ")
    tipos = Types.split(", ")
    if len(tipos) == 1:
        tipos = tipos[0]
    textos = clf.set_classes(clases, tipos)
    probs = clf.predict(PIL.Image.fromarray(Image))[-1]
    return {textos[i]: probs[i] for i in range(len(textos))}

gr.Interface(fn=clasifica,
             inputs=["image", "text", "text"],
             outputs=gr.components.Label(),
             examples=[
                 ['caballo.jpg', "photo", "tiger, zebra, horse, bear"],
                 ['drawing.jpg', "picture, drawing, photo", "sunset, sunset, sunset"]
             ],
             title="Zero-Shot Classifier with CLIP",
             description="""
             This is a zero-shot image classifier model that uses CLIP model from Open-AI.
             Select an image to classify and write the types and classes that the model should consider.
             Types and classes must be separated by ", ". If only one type is wrote, all the classes are
             matched with this type.
             Please, check the examples to understand how this space works.
             """).launch(share=False)