File size: 1,803 Bytes
59830d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from pathlib import Path

import numpy as np
import pandas as pd
from PIL import Image
from sentence_transformers import SentenceTransformer
from utils.classifier import ClothesClassifier
from utils.utils import SRC_PATH, encode_images, encode_texts, load_image_paths
from config import cnf


def load_images(paths: list[Path]) -> list[Image.Image]:
    """Open every file in *paths* and return it as an RGB ``PIL.Image``.

    Each file is opened inside a ``with`` block, so its handle is released
    as soon as the RGB copy has been produced — even if ``convert`` raises —
    instead of lingering until garbage collection.

    Args:
        paths: Image file paths, loaded in the given order.

    Returns:
        One RGB image per path, in the same order as *paths*.
    """
    images: list[Image.Image] = []
    for path in paths:
        with Image.open(path) as source:
            # convert() returns a new image object, so the appended image
            # stays usable after the source file is closed on block exit.
            images.append(source.convert("RGB"))
    return images


def classify_categories(
    image_embeddings: np.ndarray,
    labels: list[str],
    prototypes: np.ndarray,
    minimum_similarity: float = 0.20,
    minimum_margin: float = 0.015,
) -> pd.DataFrame:
    """Assign each image embedding to its most similar category prototype.

    The score is the dot product of each embedding row with each prototype
    row. This is a cosine similarity only if both inputs are L2-normalised,
    which is assumed here and not checked. An image is labelled ``"other"``
    when its best score is below *minimum_similarity*, or when the gap
    between its best and second-best score is below *minimum_margin*.

    Args:
        image_embeddings: Matrix of shape ``(n_images, dim)``.
        labels: Category names, one per prototype row.
        prototypes: Matrix of shape ``(len(labels), dim)``.
        minimum_similarity: Lowest accepted best-match score.
        minimum_margin: Lowest accepted best-minus-second-best gap. With a
            single category there is no runner-up; the margin is then
            reported as ``inf`` and never causes a rejection.

    Returns:
        DataFrame with one row per image: one score column per label,
        followed by ``prediction``, ``best_similarity`` and ``margin``.

    Raises:
        ValueError: If *labels* is empty or its length differs from the
            number of prototype rows.
    """
    if len(labels) == 0:
        raise ValueError("at least one label/prototype is required")
    if len(labels) != prototypes.shape[0]:
        raise ValueError(
            f"got {len(labels)} labels but {prototypes.shape[0]} prototypes"
        )

    similarities = image_embeddings @ prototypes.T

    best_indices = similarities.argmax(axis=1)
    sorted_scores = np.sort(similarities, axis=1)
    best_scores = sorted_scores[:, -1]

    if len(labels) > 1:
        margins = best_scores - sorted_scores[:, -2]
    else:
        # Only one category: there is no second-best score to compare with.
        margins = np.full(best_scores.shape, np.inf)

    rejected = (best_scores < minimum_similarity) | (margins < minimum_margin)
    # object dtype keeps the original label objects (and avoids fixed-width
    # unicode truncation) when indexing by the argmax positions.
    label_array = np.asarray(labels, dtype=object)
    predictions = np.where(rejected, "other", label_array[best_indices])

    result = pd.DataFrame(similarities, columns=labels)
    result["prediction"] = predictions.tolist()
    result["best_similarity"] = best_scores
    result["margin"] = margins

    return result


if __name__ == "__main__":
    # Script entry point:
    #   1. Gather the image paths that `load_image_paths` reports under
    #      SRC_PATH/data/images.
    #   2. Load them as RGB images.
    #   3. Classify them with a ClothesClassifier configured from `cnf`.
    #   4. Print whatever the classifier returns.
    images_root = SRC_PATH / "data" / "images"
    rgb_images = load_images(load_image_paths(images_root))

    clothes_classifier = ClothesClassifier(
        model_name=cnf.emb_model_name,
        minimum_similarity=cnf.minimum_similarity,
        minimum_margin=cnf.minimum_margin,
    )
    classification = clothes_classifier.classify(rgb_images)
    print(classification)