Spaces:
Running
Running
| """High-level detector interface for running DETree inference.""" | |
| from __future__ import annotations | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import List, Optional, Sequence | |
| import numpy as np | |
| import torch | |
| from torch.nn import functional as F | |
| from torch.utils.data import DataLoader, Dataset | |
| from tqdm import tqdm | |
| from detree.model.text_embedding import TextEmbeddingModel | |
| from detree.utils.index import Indexer | |
| __all__ = ["Detector", "Prediction"] | |
| def _to_numpy(value) -> np.ndarray: | |
| if isinstance(value, np.ndarray): | |
| return value | |
| if torch.is_tensor(value): | |
| return value.detach().cpu().numpy() | |
| return np.asarray(value) | |
| def _load_database(path: Path): | |
| data = torch.load(path, map_location="cpu") | |
| embeddings = data["embeddings"] | |
| labels = data["labels"] | |
| ids = data["ids"] | |
| classes = data["classes"] | |
| if not isinstance(embeddings, dict): | |
| raise ValueError("Expected embeddings to be a dict keyed by layer index") | |
| return embeddings, labels, ids, classes | |
| class TextDataset(Dataset): | |
| def __init__(self, texts: Sequence[str]): | |
| self._texts = [str(text) for text in texts] | |
| def __len__(self) -> int: | |
| return len(self._texts) | |
| def __getitem__(self, idx: int): | |
| return self._texts[idx], idx | |
| class Prediction: | |
| text: str | |
| probability_ai: float | |
| probability_human: float | |
| label: str | |
| class Detector: | |
| """Wraps model + database logic for kNN predictions.""" | |
| def __init__( | |
| self, | |
| database_path: Path, | |
| model_name_or_path: str, | |
| *, | |
| pooling: str = "max", | |
| max_length: int = 512, | |
| batch_size: int = 8, | |
| num_workers: int = 0, | |
| top_k: int = 10, | |
| threshold: float = 0.97, | |
| layer: Optional[int] = None, | |
| device: Optional[str] = None, | |
| ) -> None: | |
| self.database_path = database_path | |
| self.model_name_or_path = model_name_or_path | |
| self.pooling = pooling | |
| self.max_length = max_length | |
| self.batch_size = batch_size | |
| self.num_workers = num_workers | |
| self.top_k = top_k | |
| if not 0.0 <= threshold <= 1.0: | |
| raise ValueError( | |
| "threshold must be a probability between 0 and 1 (inclusive)." | |
| ) | |
| self.threshold = threshold | |
| self.device = torch.device( | |
| device if device else ("cuda" if torch.cuda.is_available() else "cpu") | |
| ) | |
| embeddings, labels, ids, classes = _load_database(database_path) | |
| self.classes = list(classes) | |
| self.human_index = None | |
| if "human" in self.classes: | |
| self.human_index = self.classes.index("human") | |
| self._raw_labels = labels | |
| self._raw_ids = ids | |
| self.layer_embeddings = { | |
| int(layer): tensor.float() for layer, tensor in embeddings.items() | |
| } | |
| if isinstance(labels, dict): | |
| self.layer_labels = {int(layer): tensor for layer, tensor in labels.items()} | |
| else: | |
| self.layer_labels = None | |
| if isinstance(ids, dict): | |
| self.layer_ids = {int(layer): tensor for layer, tensor in ids.items()} | |
| else: | |
| self.layer_ids = None | |
| self.available_layers = sorted(self.layer_embeddings.keys()) | |
| if not self.available_layers: | |
| raise ValueError("No layers found in the embedding database") | |
| requested_layer = layer if layer is not None else self.available_layers[-1] | |
| if requested_layer not in self.available_layers: | |
| raise ValueError(f"Requested layer {layer} not present in database") | |
| self.model = TextEmbeddingModel( | |
| model_name_or_path, | |
| output_hidden_states=True, | |
| infer=True, | |
| use_pooling=self.pooling, | |
| ).to(self.device) | |
| self.model.eval() | |
| self.tokenizer = self.model.tokenizer | |
| if self.human_index is None: | |
| raise ValueError( | |
| "Database must include a 'human' entry in its classes list to compute probabilities." | |
| ) | |
| self._configure_layer(requested_layer) | |
| def _configure_layer(self, layer: int) -> None: | |
| if layer not in self.layer_embeddings: | |
| raise ValueError(f"Requested layer {layer} not present in database") | |
| layer_embeddings = self.layer_embeddings[layer] | |
| self.embedding_dim = layer_embeddings.shape[-1] | |
| if self.layer_labels is not None: | |
| layer_labels = self.layer_labels[layer] | |
| else: | |
| # Fall back to shared labels tensor when per-layer labels are unavailable. | |
| layer_labels = self._raw_labels | |
| if self.layer_ids is not None: | |
| layer_ids = self.layer_ids[layer] | |
| else: | |
| layer_ids = self._raw_ids | |
| train_embeddings = _to_numpy(layer_embeddings) | |
| train_labels = _to_numpy(layer_labels).astype(np.int64) | |
| train_ids = _to_numpy(layer_ids).astype(np.int64) | |
| self.index = Indexer(self.embedding_dim) | |
| label_dict = {} | |
| for idx, label in zip(train_ids.tolist(), train_labels.tolist()): | |
| label_dict[int(idx)] = 1 if int(label) == int(self.human_index) else 0 | |
| self.index.label_dict = label_dict | |
| self.index.index_data(train_ids.tolist(), train_embeddings.astype(np.float32)) | |
| self.layer = layer | |
| def set_layer(self, layer: int) -> None: | |
| """Switch the active database layer used for inference.""" | |
| if layer == self.layer: | |
| return | |
| self._configure_layer(layer) | |
| def get_available_layers(self) -> List[int]: | |
| return list(self.available_layers) | |
| def _encode(self, texts: Sequence[str]) -> np.ndarray: | |
| dataset = TextDataset(texts) | |
| if len(dataset) == 0: | |
| return np.zeros((0, self.embedding_dim), dtype=np.float32) | |
| dataloader = DataLoader( | |
| dataset, | |
| batch_size=self.batch_size, | |
| num_workers=self.num_workers, | |
| shuffle=False, | |
| collate_fn=lambda batch: tuple(zip(*batch)), | |
| ) | |
| all_embeddings: List[torch.Tensor] = [] | |
| all_indices: List[int] = [] | |
| for texts_batch, indices_batch in tqdm( | |
| dataloader, desc="Encoding", leave=False | |
| ): | |
| encoded_batch = self.tokenizer.batch_encode_plus( | |
| list(texts_batch), | |
| return_tensors="pt", | |
| max_length=self.max_length, | |
| padding="max_length", | |
| truncation=True, | |
| ) | |
| encoded_batch = {k: v.to(self.device) for k, v in encoded_batch.items()} | |
| embeddings = self.model(encoded_batch, hidden_states=True) | |
| embeddings = F.normalize(embeddings, dim=-1) | |
| all_embeddings.append(embeddings.cpu()) | |
| all_indices.extend(indices_batch) | |
| stacked = torch.cat(all_embeddings, dim=0) if all_embeddings else torch.empty(0) | |
| if stacked.numel() == 0: | |
| return np.zeros((0, self.embedding_dim), dtype=np.float32) | |
| order = torch.tensor(all_indices, dtype=torch.long) | |
| if order.numel() != stacked.shape[0]: | |
| raise RuntimeError("Index and embedding counts do not match.") | |
| sorted_indices = torch.argsort(order) | |
| stacked = stacked[sorted_indices] | |
| stacked = stacked.permute(1, 0, 2) | |
| selected_layer = stacked[self.layer] | |
| return selected_layer.numpy().astype(np.float32) | |
| def predict(self, texts: Sequence[str]) -> List[Prediction]: | |
| texts_list = [str(text) for text in texts] | |
| embeddings = self._encode(texts_list) | |
| if embeddings.shape[0] == 0: | |
| return [] | |
| results = self.index.search_knn( | |
| embeddings, | |
| self.top_k, | |
| index_batch_size=max(1, min(self.top_k, 128)), | |
| ) | |
| predictions: List[Prediction] = [] | |
| for text, (_ids, scores, labels) in zip(texts_list, results): | |
| scores_tensor = torch.from_numpy(np.asarray(scores)) | |
| weights = torch.softmax(scores_tensor, dim=0) | |
| label_tensor = torch.tensor(labels, dtype=torch.float32) | |
| probability_human = float(torch.dot(weights, label_tensor).item()) | |
| probability_human = max(0.0, min(1.0, probability_human)) | |
| probability_ai = float(max(0.0, min(1.0, 1.0 - probability_human))) | |
| label = "Human" if probability_human >= self.threshold else "AI" | |
| predictions.append( | |
| Prediction( | |
| text=text, | |
| probability_ai=probability_ai, | |
| probability_human=probability_human, | |
| label=label, | |
| ) | |
| ) | |
| return predictions | |