from typing import Generator, Iterable, List, TypeVar import numpy as np import supervision as sv import torch import umap from sklearn.cluster import KMeans from tqdm import tqdm from transformers import AutoProcessor, SiglipVisionModel V = TypeVar("V") SIGLIP_MODEL_PATH = 'google/siglip-base-patch16-224' def create_batches( sequence: Iterable[V], batch_size: int ) -> Generator[List[V], None, None]: """ Generate batches from a sequence with a specified batch size. Args: sequence (Iterable[V]): The input sequence to be batched. batch_size (int): The size of each batch. Yields: Generator[List[V], None, None]: A generator yielding batches of the input sequence. """ batch_size = max(batch_size, 1) current_batch = [] for element in sequence: if len(current_batch) == batch_size: yield current_batch current_batch = [] current_batch.append(element) if current_batch: yield current_batch class TeamClassifier: """ A classifier that uses a pre-trained SiglipVisionModel for feature extraction, UMAP for dimensionality reduction, and KMeans for clustering. """ def __init__(self, device: str = 'cpu', batch_size: int = 32): """ Initialize the TeamClassifier with device and batch size. Args: device (str): The device to run the model on ('cpu' or 'cuda'). batch_size (int): The batch size for processing images. """ self.device = device self.batch_size = batch_size self.features_model = SiglipVisionModel.from_pretrained( SIGLIP_MODEL_PATH).to(device) self.processor = AutoProcessor.from_pretrained(SIGLIP_MODEL_PATH) self.reducer = umap.UMAP(n_components=3) self.cluster_model = KMeans(n_clusters=2) def extract_features(self, crops: List[np.ndarray]) -> np.ndarray: """ Extract features from a list of image crops using the pre-trained SiglipVisionModel. Args: crops (List[np.ndarray]): List of image crops. Returns: np.ndarray: Extracted features as a numpy array. """ crops = [sv.cv2_to_pillow(crop) for crop in crops] batches = create_batches(crops, self.batch_size) data = [] with torch.no_grad(): for batch in tqdm(batches, desc='Embedding extraction'): inputs = self.processor( images=batch, return_tensors="pt").to(self.device) outputs = self.features_model(**inputs) embeddings = torch.mean(outputs.last_hidden_state, dim=1).cpu().numpy() data.append(embeddings) return np.concatenate(data) def fit(self, crops: List[np.ndarray]) -> None: """ Fit the classifier model on a list of image crops. Args: crops (List[np.ndarray]): List of image crops. """ data = self.extract_features(crops) projections = self.reducer.fit_transform(data) self.cluster_model.fit(projections) def predict(self, crops: List[np.ndarray]) -> np.ndarray: """ Predict the cluster labels for a list of image crops. Args: crops (List[np.ndarray]): List of image crops. Returns: np.ndarray: Predicted cluster labels. """ if len(crops) == 0: return np.array([]) data = self.extract_features(crops) projections = self.reducer.transform(data) return self.cluster_model.predict(projections)