| from PIL import Image |
| import pandas as pd |
| import numpy as np |
| from sentence_transformers import SentenceTransformer |
|
|
| from utils.utils import encode_images, encode_texts |
|
|
class ClothesClassifier:
    """Zero-shot clothing classifier built on a SentenceTransformer model.

    Every category is described by a handful of text prompts. The prompt
    embeddings of a category are L2-normalised, averaged and re-normalised
    into a single prototype vector. Each image embedding is L2-normalised and
    scored against every prototype. The image gets the best-scoring category,
    or ``"other"`` when the best score is below ``minimum_similarity`` or beats
    the runner-up by less than ``minimum_margin``.

    NOTE(review): this assumes ``model_name`` is a joint image/text model
    (e.g. a CLIP checkpoint) so that ``encode_images`` and ``encode_texts``
    embed into the same space — confirm.
    """

    # Label assigned when no category is a confident match.
    OTHER_LABEL = "other"

    # Columns appended to the result after the per-category scores. A category
    # label must not reuse one of these names, or its scores would be
    # overwritten in the result frame.
    RESULT_COLUMNS = ("prediction", "best_similarity", "margin")

    # Built-in categories and their describing prompts. The values are tuples
    # so this shared class-level default cannot be mutated through an
    # instance. Each instance works on its own list-valued copy.
    DEFAULT_CATEGORY_PROMPTS = {
        "dress": (
            "a photo of a dress",
            "a one-piece garment covering the upper and lower body",
            "a casual or formal dress worn by a person",
        ),
        "pants": (
            "a photo of pants or trousers",
            "a lower-body garment with two trouser legs",
            "jeans, chinos, trousers, or sweatpants",
        ),
        "top": (
            "a photo of a t-shirt, shirt, blouse, or tank top",
            "a lightweight garment worn on the upper body",
            "a shirt that is not a sweater, hoodie, jacket, or coat",
        ),
        "sweater_hoodie": (
            "a photo of a sweater or hoodie",
            "a knitted sweater, sweatshirt, or hooded sweatshirt",
            "a warm pullover garment worn on the upper body",
        ),
        "jacket_coat": (
            "a photo of a jacket or coat",
            "an outerwear garment worn over other clothes",
            "a blazer, jacket, raincoat, or winter coat",
        ),
        "shoes": (
            "a photo of footwear",
            "a pair of shoes, sneakers, boots, sandals, or heels",
            "something worn on the feet",
        ),
        "accessories": (
            "a photo of a fashion accessory",
            "a bag, belt, hat, scarf, jewelry, watch, or sunglasses",
            "an accessory worn with an outfit",
        ),
        "skirt": (
            "a photo of a skirt",
            "a lower-body garment that hangs from the waist without separate trouser legs",
            "a mini skirt, midi skirt, maxi skirt, or pleated skirt",
        ),
    }

    def __init__(
        self,
        model_name: str,
        minimum_similarity: float = 0.20,
        minimum_margin: float = 0.015,
        device: str = "cpu",
        category_prompts: dict[str, list[str]] | None = None,
    ):
        """Load the model and precompute one prototype per category.

        Args:
            model_name: Name or path passed to ``SentenceTransformer``.
            minimum_similarity: Best cosine similarity an image must reach to
                receive a category instead of ``"other"``.
            minimum_margin: Minimum gap between the best and second-best
                similarity. Smaller gaps are treated as ambiguous and yield
                ``"other"``.
            device: Device the model is loaded on. Defaults to ``"cpu"``, the
                previously hard-coded value.
            category_prompts: Optional ``label -> prompts`` mapping that
                replaces ``DEFAULT_CATEGORY_PROMPTS``. Dict order determines
                the order of the score columns in ``classify`` results.

        Raises:
            ValueError: If no categories are given, a category has no prompts,
                or a label collides with a result column name.
        """
        if category_prompts is None:
            category_prompts = self.DEFAULT_CATEGORY_PROMPTS
        self.category_prompts = self._validate_category_prompts(category_prompts)
        self.model_name = model_name
        self.minimum_similarity = minimum_similarity
        self.minimum_margin = minimum_margin
        self.model = SentenceTransformer(self.model_name, device=device)
        self.labels, self.prototypes = self.build_prototypes()

    @classmethod
    def _validate_category_prompts(
        cls,
        category_prompts: dict[str, list[str]],
    ) -> dict[str, list[str]]:
        """Return a validated copy of ``category_prompts`` with list values.

        A bare string is accepted as a single prompt. Otherwise iterating it
        would silently turn every character into its own "prompt".

        Raises:
            ValueError: On an empty mapping, an empty prompt list, or a label
                equal to one of ``RESULT_COLUMNS``.
        """
        if not category_prompts:
            raise ValueError("category_prompts must contain at least one category")
        validated = {}
        for label, prompts in category_prompts.items():
            if label in cls.RESULT_COLUMNS:
                raise ValueError(
                    f"category label {label!r} clashes with a result column name"
                )
            prompt_list = [prompts] if isinstance(prompts, str) else list(prompts)
            if not prompt_list:
                raise ValueError(f"category {label!r} has no prompts")
            validated[label] = prompt_list
        return validated

    @staticmethod
    def _l2_normalize(vectors: np.ndarray) -> np.ndarray:
        """Scale each row of ``vectors`` to unit L2 norm; all-zero rows stay zero."""
        norms = np.linalg.norm(vectors, axis=-1, keepdims=True)
        # Divide zero rows by 1 instead of 0 so they stay zero rather than NaN.
        return vectors / np.where(norms == 0, 1, norms)

    def build_prototypes(
        self,
    ) -> tuple[list[str], np.ndarray]:
        """Compute one unit-length prototype embedding per category.

        Each prompt embedding is normalised before averaging, so every prompt
        contributes equally regardless of its raw norm. The mean is then
        re-normalised so dot products against it are cosine similarities.
        If ``encode_texts`` already returns unit vectors, the first
        normalisation is a no-op.

        Returns:
            ``(labels, prototypes)``, where row ``i`` of ``prototypes`` is the
            prototype of ``labels[i]``.

        Raises:
            ValueError: If a category's prompt embeddings average to the zero
                vector. Without this check, NaN would leak into every
                similarity.
        """
        labels = list(self.category_prompts)
        prototypes = []
        for label in labels:
            prompt_embeddings = np.atleast_2d(
                np.asarray(encode_texts(self.model, self.category_prompts[label]))
            )
            prototype = self._l2_normalize(prompt_embeddings).mean(axis=0)
            norm = np.linalg.norm(prototype)
            if norm == 0:
                raise ValueError(f"prototype for category {label!r} has zero norm")
            prototypes.append(prototype / norm)
        return labels, np.stack(prototypes)

    def classify(
        self,
        images: list[Image.Image] | Image.Image,
    ) -> pd.DataFrame:
        """Classify a single image or a batch of images.

        Args:
            images: One image, or a list/tuple of images.

        Returns:
            One row per image, in input order, with these columns:

            * one column per category label, holding the cosine similarity to
              that category's prototype;
            * ``prediction``: the winning label, or ``"other"``;
            * ``best_similarity``: the highest similarity in the row;
            * ``margin``: best minus second-best similarity.

            An empty batch yields an empty frame with the same columns.
        """
        if isinstance(images, (list, tuple)):
            batch = list(images)
        else:
            batch = [images]
        if not batch:
            # Nothing to embed: skip the encoder and return an empty frame.
            return self._summarize(np.empty((0, len(self.labels))))
        image_embeddings = np.atleast_2d(np.asarray(encode_images(self.model, batch)))
        # Prototypes are unit vectors. Normalising the image side as well makes
        # these true cosine similarities, which the thresholds assume.
        similarities = self._l2_normalize(image_embeddings) @ self.prototypes.T
        return self._summarize(similarities)

    def _summarize(self, similarities: np.ndarray) -> pd.DataFrame:
        """Build the result frame from an (images x labels) similarity matrix."""
        similarities = np.asarray(similarities)
        n_images, n_labels = similarities.shape
        best_indices = similarities.argmax(axis=1)
        ranked = np.sort(similarities, axis=1)
        best_scores = ranked[:, -1]
        if n_labels > 1:
            margins = best_scores - ranked[:, -2]
        else:
            # With a single category there is no runner-up, so the margin
            # test can never reject a prediction.
            margins = np.full(n_images, np.inf)
        is_other = (best_scores < self.minimum_similarity) | (
            margins < self.minimum_margin
        )
        label_array = np.asarray(self.labels, dtype=object)
        predictions = np.where(is_other, self.OTHER_LABEL, label_array[best_indices])

        result = pd.DataFrame(similarities, columns=self.labels)
        result["prediction"] = predictions.tolist()
        result["best_similarity"] = best_scores
        result["margin"] = margins
        return result