CLIP_Classifier / app.py
pamunarr's picture
Update app.py
5bf3fbe verified
raw
history blame
3.54 kB
from transformers import CLIPProcessor, CLIPModel
import numpy as np
from typing import Union
import PIL
import gradio as gr
class ZeroShotCLF():
def __init__(self, model_id = "openai/clip-vit-base-patch32"):
"""
Clasificador de imágenes zero-shot usando CLIP.
Métodos:
- set_classes: sirve para especificar las posibles clases. Se debe usar antes de predict o evaluate.
- predict: sirve para clasificar una o varias imágenes. Se debe usar antes set_classes.
- evaluate: sirve para evaluar el modelo. Se debe usar antes set_classes.
"""
self.model_id = model_id
self.processor = CLIPProcessor.from_pretrained(model_id)
self.model = CLIPModel.from_pretrained(model_id)
def set_classes(self, clases: list[str], tipos: Union[str, list[str]] = "image"):
"""
Establece las posibles clases en las que deberemos clasificar.
Inputs:
- clases: lista de strings con todas las clases a considerar.
- tipos: string o lista de strings con los tipos de imagen. Este tipo sirve para especificar si buscas fotografías, dibujos... Valor por defecto: "image".
Si la variable es string, se utilizará el mismo tipo para todas las clases, si es una lista, deberá tener la misma longitud que la variable clases.
Outputs:
- Lista de textos a usar para cada clase. Se usará el siguiente esquema: "a __ of a __", donde en el primer hueco se rellenará con el tipo y el segundo por la clase.
Por ejemplo, si el tipo es "photo" y la clase es "dog", el texto será "a photo of a dog".
"""
if isinstance(tipos, str):
tipos = [tipos] * len(clases)
else:
assert len(tipos) == len(clases), "¡clases y tipos deben tener la misma longitud!"
self.clases = clases
self.tipos = tipos
self.text = list(map(lambda par: "a %s of a %s" % par, zip(tipos, clases)))
return self.text
def predict(self, images):
"""
Predice las clases de las imágenes. Es necesario usar previamente el método set_classes.
Inputs:
- images: imagen PIL o lista de imágenes PIL a clasificar.
Outputs:
- Nombre de la clase predicha (string) o lista de las clases predichas (lista de strings).
- Índice de la clase predicha (int) o lista de los índices de las clases predichas (lista de ints).
- Vector de probabilidades de las clases (array) o matriz de probabilidades de las clases (ndarray de dimensiones número de imágenes por número de clases).
"""
assert 'text' in dir(self), "¡No se han especificado las clases! Usa set_classes antes que predict."
inputs = self.processor(text=self.text, images=images, return_tensors="pt", padding=True)
outputs = self.model(**inputs)
probs = outputs.logits_per_image.softmax(dim=1)
ind_clases = probs.argmax(1)
if isinstance(images, list):
ind_clases = ind_clases.tolist()
pred_clases = list(map(lambda ind: self.clases[ind], ind_clases))
return pred_clases, ind_clases, probs
ind_clases = int(ind_clases)
return self.clases[ind_clases], ind_clases, probs.ravel()
clf = ZeroShotCLF()
def clasifica(imagen, clases_, tipos_):
clases = clases_.split(", ")
tipos = tipos_.split(", ")
if len(tipos) == 1:
tipos = tipos[0]
textos = clf.set_classes(clases, tipos)
probs = clf.predict(PIL.Image.fromarray(imagen))[-1][0]
return {textos[i]: probs[i] for i in range(len(textos))}
gr.Interface(fn=clasifica, inputs=["image", "text", "text"], outputs="text").launch(share=False)