7gonzalodm commited on
Commit
cbcaa4f
·
verified ·
1 Parent(s): 7c8fd5e

Upload team.py

Browse files
Files changed (1) hide show
  1. team.py +112 -0
team.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Generator, Iterable, List, TypeVar
2
+
3
+ import numpy as np
4
+ import supervision as sv
5
+ import torch
6
+ import umap
7
+ from sklearn.cluster import KMeans
8
+ from tqdm import tqdm
9
+ from transformers import AutoProcessor, SiglipVisionModel
10
+
11
+ V = TypeVar("V")
12
+
13
+ SIGLIP_MODEL_PATH = 'google/siglip-base-patch16-224'
14
+
15
+
16
def create_batches(
    sequence: Iterable[V], batch_size: int
) -> Generator[List[V], None, None]:
    """Lazily split *sequence* into consecutive batches.

    A ``batch_size`` below 1 is treated as 1. All yielded batches contain
    exactly ``batch_size`` elements except possibly the last, which holds
    the remainder.

    Args:
        sequence (Iterable[V]): Elements to batch, consumed lazily.
        batch_size (int): Desired number of elements per batch.

    Yields:
        Generator[List[V], None, None]: Lists of consecutive elements.
    """
    size = max(batch_size, 1)
    batch: List[V] = []
    for item in sequence:
        batch.append(item)
        # Emit as soon as the batch is full, then start a fresh one.
        if len(batch) == size:
            yield batch
            batch = []
    # Flush whatever is left over (the final, possibly partial, batch).
    if batch:
        yield batch
39
+
40
+
41
class TeamClassifier:
    """Cluster player image crops into two teams.

    Pipeline: embed crops with a pre-trained SiglipVisionModel, project
    the embeddings to 3-D with UMAP, then assign labels with 2-cluster
    KMeans.
    """
    def __init__(self, device: str = 'cpu', batch_size: int = 32):
        """Load the vision backbone and set up the reducer/clusterer.

        Args:
            device (str): The device to run the model on ('cpu' or 'cuda').
            batch_size (int): The batch size for processing images.
        """
        self.device = device
        self.batch_size = batch_size
        self.features_model = SiglipVisionModel.from_pretrained(
            SIGLIP_MODEL_PATH).to(device)
        self.processor = AutoProcessor.from_pretrained(SIGLIP_MODEL_PATH)
        # 3 UMAP components feed the 2-cluster KMeans below.
        self.reducer = umap.UMAP(n_components=3)
        self.cluster_model = KMeans(n_clusters=2)

    def extract_features(self, crops: List[np.ndarray]) -> np.ndarray:
        """Embed image crops with the pre-trained SiglipVisionModel.

        Crops are converted to PIL images, processed in batches, and each
        crop's embedding is the mean of the model's last hidden state over
        the token dimension.

        Args:
            crops (List[np.ndarray]): Image crops (OpenCV arrays).

        Returns:
            np.ndarray: One embedding row per crop.
        """
        pillow_crops = [sv.cv2_to_pillow(crop) for crop in crops]
        chunks = []
        with torch.no_grad():
            for batch in tqdm(
                create_batches(pillow_crops, self.batch_size),
                desc='Embedding extraction',
            ):
                model_inputs = self.processor(
                    images=batch, return_tensors="pt").to(self.device)
                hidden = self.features_model(**model_inputs).last_hidden_state
                chunks.append(hidden.mean(dim=1).cpu().numpy())
        return np.concatenate(chunks)

    def fit(self, crops: List[np.ndarray]) -> None:
        """Fit the UMAP projection and KMeans clusters on image crops.

        Args:
            crops (List[np.ndarray]): Image crops (OpenCV arrays).
        """
        embeddings = self.extract_features(crops)
        self.cluster_model.fit(self.reducer.fit_transform(embeddings))

    def predict(self, crops: List[np.ndarray]) -> np.ndarray:
        """Predict a cluster (team) label for each image crop.

        Args:
            crops (List[np.ndarray]): Image crops (OpenCV arrays).

        Returns:
            np.ndarray: One cluster label per crop; empty array when no
                crops are given.
        """
        if not crops:
            return np.array([])
        embeddings = self.extract_features(crops)
        return self.cluster_model.predict(self.reducer.transform(embeddings))