Spaces:
Runtime error
Runtime error
File size: 3,547 Bytes
674666d 5bf3fbe 674666d 89bd32e 674666d 5687849 674666d 89bd32e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 |
from transformers import CLIPProcessor, CLIPModel
import numpy as np
from typing import Union
import PIL
import gradio as gr
class ZeroShotCLF():
def __init__(self, model_id = "openai/clip-vit-base-patch32"):
"""
Clasificador de imágenes zero-shot usando CLIP.
Métodos:
- set_classes: sirve para especificar las posibles clases. Se debe usar antes de predict o evaluate.
- predict: sirve para clasificar una o varias imágenes. Se debe usar antes set_classes.
- evaluate: sirve para evaluar el modelo. Se debe usar antes set_classes.
"""
self.model_id = model_id
self.processor = CLIPProcessor.from_pretrained(model_id)
self.model = CLIPModel.from_pretrained(model_id)
def set_classes(self, clases: list[str], tipos: Union[str, list[str]] = "image"):
"""
Establece las posibles clases en las que deberemos clasificar.
Inputs:
- clases: lista de strings con todas las clases a considerar.
- tipos: string o lista de strings con los tipos de imagen. Este tipo sirve para especificar si buscas fotografías, dibujos... Valor por defecto: "image".
Si la variable es string, se utilizará el mismo tipo para todas las clases, si es una lista, deberá tener la misma longitud que la variable clases.
Outputs:
- Lista de textos a usar para cada clase. Se usará el siguiente esquema: "a __ of a __", donde en el primer hueco se rellenará con el tipo y el segundo por la clase.
Por ejemplo, si el tipo es "photo" y la clase es "dog", el texto será "a photo of a dog".
"""
if isinstance(tipos, str):
tipos = [tipos] * len(clases)
else:
assert len(tipos) == len(clases), "¡clases y tipos deben tener la misma longitud!"
self.clases = clases
self.tipos = tipos
self.text = list(map(lambda par: "a %s of a %s" % par, zip(tipos, clases)))
return self.text
def predict(self, images):
"""
Predice las clases de las imágenes. Es necesario usar previamente el método set_classes.
Inputs:
- images: imagen PIL o lista de imágenes PIL a clasificar.
Outputs:
- Nombre de la clase predicha (string) o lista de las clases predichas (lista de strings).
- Índice de la clase predicha (int) o lista de los índices de las clases predichas (lista de ints).
- Vector de probabilidades de las clases (array) o matriz de probabilidades de las clases (ndarray de dimensiones número de imágenes por número de clases).
"""
assert 'text' in dir(self), "¡No se han especificado las clases! Usa set_classes antes que predict."
inputs = self.processor(text=self.text, images=images, return_tensors="pt", padding=True)
outputs = self.model(**inputs)
probs = outputs.logits_per_image.softmax(dim=1)
ind_clases = probs.argmax(1)
if isinstance(images, list):
ind_clases = ind_clases.tolist()
pred_clases = list(map(lambda ind: self.clases[ind], ind_clases))
return pred_clases, ind_clases, probs
ind_clases = int(ind_clases)
return self.clases[ind_clases], ind_clases, probs.ravel()
clf = ZeroShotCLF()
def clasifica(Image, Types, Classes):
clases = Classes.split(", ")
tipos = Types.split(", ")
if len(tipos) == 1:
tipos = tipos[0]
textos = clf.set_classes(clases, tipos)
probs = clf.predict(PIL.Image.fromarray(Image))[-1]
return {textos[i]: probs[i] for i in range(len(textos))}
gr.Interface(fn=clasifica, inputs=["image", "text", "text"], outputs=gr.components.Label()).launch(share=False) |