# Source: team-classifier / team.py
# Uploaded by 7gonzalodm (commit cbcaa4f, verified).
from typing import Generator, Iterable, List, Optional, TypeVar

import numpy as np
import supervision as sv
import torch
import umap
from sklearn.cluster import KMeans
from tqdm import tqdm
from transformers import AutoProcessor, SiglipVisionModel
# Generic element type used by create_batches' signature.
V = TypeVar("V")

# Checkpoint id handed to SiglipVisionModel / AutoProcessor.from_pretrained.
SIGLIP_MODEL_PATH = "google/siglip-base-patch16-224"
def create_batches(
    sequence: Iterable[V], batch_size: int
) -> Generator[List[V], None, None]:
    """
    Split an iterable into consecutive lists of at most ``batch_size`` items.

    Args:
        sequence (Iterable[V]): Items to group; consumed lazily and in order.
        batch_size (int): Desired number of items per batch. Values below 1
            are treated as 1.

    Yields:
        List[V]: Successive batches. Every batch is full except possibly the
        last one; nothing is yielded when ``sequence`` is empty.
    """
    limit = max(batch_size, 1)
    pending = []
    for item in sequence:
        # A full batch is flushed only once the next item arrives; the
        # trailing (possibly partial) batch is emitted after the loop.
        if len(pending) == limit:
            ready, pending = pending, []
            yield ready
        pending.append(item)
    if pending:
        yield pending
class TeamClassifier:
    """
    A classifier that uses a pre-trained SiglipVisionModel for feature extraction,
    UMAP for dimensionality reduction, and KMeans for clustering.

    Call ``fit`` on a sample of crops, then ``predict`` on new crops. Cluster
    ids are arbitrary: which group receives label 0 depends on the KMeans
    initialisation, so pass ``random_state`` when the assignment must be
    reproducible across runs.
    """

    def __init__(
        self,
        device: str = 'cpu',
        batch_size: int = 32,
        n_clusters: int = 2,
        n_components: int = 3,
        random_state: Optional[int] = None,
    ):
        """
        Initialize the TeamClassifier with device and batch size.

        Args:
            device (str): The device to run the model on ('cpu' or 'cuda').
            batch_size (int): The batch size for processing images.
            n_clusters (int): Number of clusters KMeans partitions crops into.
            n_components (int): Dimensionality UMAP projects embeddings to.
            random_state (Optional[int]): Seed forwarded to both UMAP and
                KMeans. ``None`` (the default) leaves both unseeded, as before.
        """
        self.device = device
        self.batch_size = batch_size
        self.features_model = SiglipVisionModel.from_pretrained(
            SIGLIP_MODEL_PATH).to(device)
        self.processor = AutoProcessor.from_pretrained(SIGLIP_MODEL_PATH)
        self.reducer = umap.UMAP(
            n_components=n_components, random_state=random_state)
        self.cluster_model = KMeans(
            n_clusters=n_clusters, random_state=random_state)

    def extract_features(self, crops: List[np.ndarray]) -> np.ndarray:
        """
        Extract features from a list of image crops using the pre-trained
        SiglipVisionModel.

        Args:
            crops (List[np.ndarray]): List of image crops (converted to PIL
                via ``sv.cv2_to_pillow``).

        Returns:
            np.ndarray: One mean-pooled embedding per crop, stacked row-wise.
            An empty ``crops`` yields an empty ``(0, hidden_size)`` array.
        """
        if len(crops) == 0:
            # np.concatenate([]) raises, so short-circuit with an empty array
            # that still carries the embedding width.
            return np.empty(
                (0, self.features_model.config.hidden_size), dtype=np.float32)
        pil_crops = [sv.cv2_to_pillow(crop) for crop in crops]
        # Same clamping as create_batches, so the progress bar total is exact.
        batch_count = -(-len(pil_crops) // max(self.batch_size, 1))
        data = []
        with torch.no_grad():
            for batch in tqdm(
                create_batches(pil_crops, self.batch_size),
                total=batch_count,
                desc='Embedding extraction',
            ):
                inputs = self.processor(
                    images=batch, return_tensors="pt").to(self.device)
                outputs = self.features_model(**inputs)
                # Average over the token axis to get one vector per image.
                embeddings = torch.mean(
                    outputs.last_hidden_state, dim=1).cpu().numpy()
                data.append(embeddings)
        return np.concatenate(data)

    def fit(self, crops: List[np.ndarray]) -> None:
        """
        Fit the classifier model on a list of image crops.

        Args:
            crops (List[np.ndarray]): List of image crops.

        Raises:
            ValueError: If ``crops`` is empty.
        """
        if len(crops) == 0:
            raise ValueError(
                "Cannot fit TeamClassifier on an empty list of crops.")
        data = self.extract_features(crops)
        projections = self.reducer.fit_transform(data)
        self.cluster_model.fit(projections)

    def predict(self, crops: List[np.ndarray]) -> np.ndarray:
        """
        Predict the cluster labels for a list of image crops.

        Args:
            crops (List[np.ndarray]): List of image crops.

        Returns:
            np.ndarray: Predicted cluster labels, one per crop. An empty
            ``crops`` yields an empty integer array.
        """
        if len(crops) == 0:
            # Integer dtype, matching the labels KMeans.predict produces.
            return np.empty(0, dtype=np.int32)
        data = self.extract_features(crops)
        projections = self.reducer.transform(data)
        return self.cluster_model.predict(projections)