Spaces:
Runtime error
Runtime error
| from transformers import CLIPProcessor, CLIPModel | |
| import numpy as np | |
| from typing import Union | |
| import PIL | |
| import gradio as gr | |
| class ZeroShotCLF(): | |
| def __init__(self, model_id = "openai/clip-vit-base-patch32"): | |
| """ | |
| Clasificador de imágenes zero-shot usando CLIP. | |
| Métodos: | |
| - set_classes: sirve para especificar las posibles clases. Se debe usar antes de predict o evaluate. | |
| - predict: sirve para clasificar una o varias imágenes. Se debe usar antes set_classes. | |
| - evaluate: sirve para evaluar el modelo. Se debe usar antes set_classes. | |
| """ | |
| self.model_id = model_id | |
| self.processor = CLIPProcessor.from_pretrained(model_id) | |
| self.model = CLIPModel.from_pretrained(model_id) | |
| def set_classes(self, clases: list[str], tipos: Union[str, list[str]] = "image"): | |
| """ | |
| Establece las posibles clases en las que deberemos clasificar. | |
| Inputs: | |
| - clases: lista de strings con todas las clases a considerar. | |
| - tipos: string o lista de strings con los tipos de imagen. Este tipo sirve para especificar si buscas fotografías, dibujos... Valor por defecto: "image". | |
| Si la variable es string, se utilizará el mismo tipo para todas las clases, si es una lista, deberá tener la misma longitud que la variable clases. | |
| Outputs: | |
| - Lista de textos a usar para cada clase. Se usará el siguiente esquema: "a __ of a __", donde en el primer hueco se rellenará con el tipo y el segundo por la clase. | |
| Por ejemplo, si el tipo es "photo" y la clase es "dog", el texto será "a photo of a dog". | |
| """ | |
| if isinstance(tipos, str): | |
| tipos = [tipos] * len(clases) | |
| else: | |
| assert len(tipos) == len(clases), "¡clases y tipos deben tener la misma longitud!" | |
| self.clases = clases | |
| self.tipos = tipos | |
| self.text = list(map(lambda par: "a %s of a %s" % par, zip(tipos, clases))) | |
| return self.text | |
| def predict(self, images): | |
| """ | |
| Predice las clases de las imágenes. Es necesario usar previamente el método set_classes. | |
| Inputs: | |
| - images: imagen PIL o lista de imágenes PIL a clasificar. | |
| Outputs: | |
| - Nombre de la clase predicha (string) o lista de las clases predichas (lista de strings). | |
| - Índice de la clase predicha (int) o lista de los índices de las clases predichas (lista de ints). | |
| - Vector de probabilidades de las clases (array) o matriz de probabilidades de las clases (ndarray de dimensiones número de imágenes por número de clases). | |
| """ | |
| assert 'text' in dir(self), "¡No se han especificado las clases! Usa set_classes antes que predict." | |
| inputs = self.processor(text=self.text, images=images, return_tensors="pt", padding=True) | |
| outputs = self.model(**inputs) | |
| probs = outputs.logits_per_image.softmax(dim=1) | |
| ind_clases = probs.argmax(1) | |
| if isinstance(images, list): | |
| ind_clases = ind_clases.tolist() | |
| pred_clases = list(map(lambda ind: self.clases[ind], ind_clases)) | |
| return pred_clases, ind_clases, probs | |
| ind_clases = int(ind_clases) | |
| return self.clases[ind_clases], ind_clases, probs.ravel() | |
| clf = ZeroShotCLF() | |
| def clasifica(Image, Types, Classes): | |
| clases = Classes.split(", ") | |
| tipos = Types.split(", ") | |
| if len(tipos) == 1: | |
| tipos = tipos[0] | |
| textos = clf.set_classes(clases, tipos) | |
| probs = clf.predict(PIL.Image.fromarray(Image))[-1] | |
| return {textos[i]: probs[i] for i in range(len(textos))} | |
| gr.Interface(fn=clasifica, | |
| inputs=["image", "text", "text"], | |
| outputs=gr.components.Label(), | |
| examples=[ | |
| ['caballo.jpg', "photo", "tiger, zebra, horse, bear"], | |
| ['drawing.jpg', "picture, drawing, photo", "sunset, sunset, sunset"] | |
| ], | |
| title="Zero-Shot Classifier with CLIP", | |
| description=""" | |
| This is a zero-shot image classifier model that uses CLIP model from Open-AI. | |
| Select an image to classify and write the types and classes that the model should consider. | |
| Types and classes must be separated by ", ". If only one type is wrote, all the classes are | |
| matched with this type. | |
| Please, check the examples to understand how this space works. | |
| """).launch(share=False) |