# wardrobe-ai: src/utils/classifier.py
# Author: elalber2000 (commit 59830d4, "first commit")
from PIL import Image
import pandas as pd
import numpy as np
from sentence_transformers import SentenceTransformer
from utils.utils import encode_images, encode_texts
class ClothesClassifier():
def __init__(
self,
model_name: str,
minimum_similarity: float = 0.20,
minimum_margin: float = 0.015,
):
self.model_name = model_name
self.model = SentenceTransformer(
self.model_name,
device="cpu",
)
self.category_prompts = {
"dress": [
"a photo of a dress",
"a one-piece garment covering the upper and lower body",
"a casual or formal dress worn by a person",
],
"pants": [
"a photo of pants or trousers",
"a lower-body garment with two trouser legs",
"jeans, chinos, trousers, or sweatpants",
],
"top": [
"a photo of a t-shirt, shirt, blouse, or tank top",
"a lightweight garment worn on the upper body",
"a shirt that is not a sweater, hoodie, jacket, or coat",
],
"sweater_hoodie": [
"a photo of a sweater or hoodie",
"a knitted sweater, sweatshirt, or hooded sweatshirt",
"a warm pullover garment worn on the upper body",
],
"jacket_coat": [
"a photo of a jacket or coat",
"an outerwear garment worn over other clothes",
"a blazer, jacket, raincoat, or winter coat",
],
"shoes": [
"a photo of footwear",
"a pair of shoes, sneakers, boots, sandals, or heels",
"something worn on the feet",
],
"accessories": [
"a photo of a fashion accessory",
"a bag, belt, hat, scarf, jewelry, watch, or sunglasses",
"an accessory worn with an outfit",
],
"skirt": [
"a photo of a skirt",
"a lower-body garment that hangs from the waist without separate trouser legs",
"a mini skirt, midi skirt, maxi skirt, or pleated skirt",
],
}
self.labels, self.prototypes = self.build_prototypes()
self.minimum_similarity = minimum_similarity
self.minimum_margin = minimum_margin
def build_prototypes(
self,
) -> tuple[list[str], np.ndarray]:
labels = []
prototypes = []
for label, prompts in self.category_prompts.items():
prompt_embeddings = encode_texts(self.model, prompts)
prototype = prompt_embeddings.mean(axis=0)
prototype /= np.linalg.norm(prototype)
labels.append(label)
prototypes.append(prototype)
return labels, np.stack(prototypes)
def classify(
self,
images: list[Image.Image] | Image.Image,
) -> pd.DataFrame:
if not isinstance(images, list):
images = [images]
image_embeddings = encode_images(self.model, images)
similarities = image_embeddings @ self.prototypes.T
best_indices = similarities.argmax(axis=1)
sorted_scores = np.sort(similarities, axis=1)
best_scores = sorted_scores[:, -1]
second_best_scores = sorted_scores[:, -2]
margins = best_scores - second_best_scores
predictions = []
for best_index, best_score, margin in zip(
best_indices,
best_scores,
margins,
):
if best_score < self.minimum_similarity or margin < self.minimum_margin:
predictions.append("other")
else:
predictions.append(self.labels[best_index])
result = pd.DataFrame(
similarities,
columns=self.labels,
)
result["prediction"] = predictions
result["best_similarity"] = best_scores
result["margin"] = margins
return result