from transformers import CLIPProcessor, CLIPModel import numpy as np from typing import Union import PIL class ZeroShotCLF(): def __init__(self, model_id = "openai/clip-vit-base-patch32"): """ Clasificador de imágenes zero-shot usando CLIP. Métodos: - set_classes: sirve para especificar las posibles clases. Se debe usar antes de predict o evaluate. - predict: sirve para clasificar una o varias imágenes. Se debe usar antes set_classes. - evaluate: sirve para evaluar el modelo. Se debe usar antes set_classes. """ self.model_id = model_id self.processor = CLIPProcessor.from_pretrained(model_id) self.model = CLIPModel.from_pretrained(model_id) def set_classes(self, clases: list[str], tipos: Union[str, list[str]] = "image"): """ Establece las posibles clases en las que deberemos clasificar. Inputs: - clases: lista de strings con todas las clases a considerar. - tipos: string o lista de strings con los tipos de imagen. Este tipo sirve para especificar si buscas fotografías, dibujos... Valor por defecto: "image". Si la variable es string, se utilizará el mismo tipo para todas las clases, si es una lista, deberá tener la misma longitud que la variable clases. Outputs: - Lista de textos a usar para cada clase. Se usará el siguiente esquema: "a __ of a __", donde en el primer hueco se rellenará con el tipo y el segundo por la clase. Por ejemplo, si el tipo es "photo" y la clase es "dog", el texto será "a photo of a dog". """ if isinstance(tipos, str): tipos = [tipos] * len(clases) else: assert len(tipos) == len(clases), "¡clases y tipos deben tener la misma longitud!" self.clases = clases self.tipos = tipos self.text = list(map(lambda par: "a %s of a %s" % par, zip(tipos, clases))) return self.text def predict(self, images): """ Predice las clases de las imágenes. Es necesario usar previamente el método set_classes. Inputs: - images: imagen PIL o lista de imágenes PIL a clasificar. Outputs: - Nombre de la clase predicha (string) o lista de las clases predichas (lista de strings). - Índice de la clase predicha (int) o lista de los índices de las clases predichas (lista de ints). - Vector de probabilidades de las clases (array) o matriz de probabilidades de las clases (ndarray de dimensiones número de imágenes por número de clases). """ assert 'text' in dir(self), "¡No se han especificado las clases! Usa set_classes antes que predict." inputs = self.processor(text=self.text, images=images, return_tensors="pt", padding=True) outputs = self.model(**inputs) probs = outputs.logits_per_image.softmax(dim=1) ind_clases = probs.argmax(1) if isinstance(images, list): ind_clases = ind_clases.tolist() pred_clases = list(map(lambda ind: self.clases[ind], ind_clases)) return pred_clases, ind_clases, probs ind_clases = int(ind_clases) return self.clases[ind_clases], ind_clases, probs.ravel() clf = ZeroShotCLF() def clasifica(imagen, clases_, tipos_): clases = clases_.split(", ") tipos = tipos_.split(", ") if len(tipos) == 1: tipos = tipos[0] textos = clf.set_classes(clases, tipos) probs = clf.predict(PIL.Image.fromarray(imagen))[-1][0] return {textos[i]: probs[i] for i in range(len(textos))} gr.Interface(fn=clasifica, inputs=["image", "text", "text"], outputs="text").launch(share=False)