MAS-AI-0000 committed on
Commit
1736cee
·
verified ·
1 Parent(s): 1b51bf6

Upload detector.py

Browse files
Files changed (1) hide show
  1. detree/inference/detector.py +251 -0
detree/inference/detector.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """High-level detector interface for running DETree inference."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass
6
+ from pathlib import Path
7
+ from typing import List, Optional, Sequence
8
+
9
+ import numpy as np
10
+ import torch
11
+ from torch.nn import functional as F
12
+ from torch.utils.data import DataLoader, Dataset
13
+ from tqdm import tqdm
14
+
15
+ from detree.model.text_embedding import TextEmbeddingModel
16
+ from detree.utils.index import Indexer
17
+
18
+ __all__ = ["Detector", "Prediction"]
19
+
20
+
21
+ def _to_numpy(value) -> np.ndarray:
22
+ if isinstance(value, np.ndarray):
23
+ return value
24
+ if torch.is_tensor(value):
25
+ return value.detach().cpu().numpy()
26
+ return np.asarray(value)
27
+
28
+
29
+ def _load_database(path: Path):
30
+ data = torch.load(path, map_location="cpu")
31
+ embeddings = data["embeddings"]
32
+ labels = data["labels"]
33
+ ids = data["ids"]
34
+ classes = data["classes"]
35
+ if not isinstance(embeddings, dict):
36
+ raise ValueError("Expected embeddings to be a dict keyed by layer index")
37
+ return embeddings, labels, ids, classes
38
+
39
+
40
class TextDataset(Dataset):
    """Dataset over a sequence of texts that also yields each item's position."""

    def __init__(self, texts: Sequence[str]):
        # Every entry is coerced to ``str`` once, up front.
        self._texts = list(map(str, texts))

    def __len__(self) -> int:
        """Return the number of stored texts."""
        return len(self._texts)

    def __getitem__(self, idx: int):
        """Return ``(text, idx)`` for the item at position *idx*."""
        item = self._texts[idx]
        return item, idx
49
+
50
+
51
@dataclass
class Prediction:
    """Result record for a single input text."""

    # The input text this prediction refers to.
    text: str
    # Probability assigned to the text being AI-generated.
    probability_ai: float
    # Probability assigned to the text being human-written.
    probability_human: float
    # Final decision label for the text.
    label: str
57
+
58
+
59
class Detector:
    """Wraps model + database logic for kNN predictions.

    A detector loads a precomputed embedding database, builds a
    nearest-neighbour index for one layer of that database, embeds incoming
    texts with the same model, and labels each text ``"Human"`` or ``"AI"``
    from the labels of its nearest database neighbours.
    """

    def __init__(
        self,
        database_path: Path,
        model_name_or_path: str,
        *,
        pooling: str = "max",
        max_length: int = 512,
        batch_size: int = 8,
        num_workers: int = 0,
        top_k: int = 10,
        threshold: float = 0.97,
        layer: Optional[int] = None,
        device: Optional[str] = None,
    ) -> None:
        """Load the database and embedding model and build the kNN index.

        Args:
            database_path: File passed to ``_load_database``.
            model_name_or_path: Identifier forwarded to ``TextEmbeddingModel``.
            pooling: Forwarded to ``TextEmbeddingModel`` as ``use_pooling``.
            max_length: Token length every input is padded/truncated to.
            batch_size: Number of texts encoded per forward pass.
            num_workers: Worker processes used by the encoding ``DataLoader``.
            top_k: Number of neighbours retrieved per text; must be >= 1.
            threshold: Minimum human probability (in ``[0, 1]``) required for
                the ``"Human"`` label.
            layer: Database layer to use; defaults to the largest layer key
                present in the database.
            device: Torch device string; defaults to CUDA when available,
                otherwise CPU.

        Raises:
            ValueError: If ``threshold`` or ``top_k`` is out of range, the
                database has no layers, the requested layer is absent, or
                the database classes do not include ``"human"``.
        """
        # Validate cheap arguments first so bad input fails before any I/O.
        if not 0.0 <= threshold <= 1.0:
            raise ValueError(
                "threshold must be a probability between 0 and 1 (inclusive)."
            )
        if top_k < 1:
            raise ValueError("top_k must be a positive integer.")

        self.database_path = database_path
        self.model_name_or_path = model_name_or_path
        self.pooling = pooling
        self.max_length = max_length
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.top_k = top_k
        self.threshold = threshold
        self.device = torch.device(
            device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        )

        embeddings, labels, ids, classes = _load_database(database_path)
        self.classes = list(classes)
        self.human_index = (
            self.classes.index("human") if "human" in self.classes else None
        )

        self._raw_labels = labels
        self._raw_ids = ids

        self.layer_embeddings = {
            int(key): tensor.float() for key, tensor in embeddings.items()
        }
        # Labels/ids may be stored per layer (dict) or shared by all layers.
        self.layer_labels = (
            {int(key): tensor for key, tensor in labels.items()}
            if isinstance(labels, dict)
            else None
        )
        self.layer_ids = (
            {int(key): tensor for key, tensor in ids.items()}
            if isinstance(ids, dict)
            else None
        )

        self.available_layers = sorted(self.layer_embeddings.keys())
        if not self.available_layers:
            raise ValueError("No layers found in the embedding database")
        requested_layer = layer if layer is not None else self.available_layers[-1]
        if requested_layer not in self.available_layers:
            raise ValueError(f"Requested layer {layer} not present in database")

        # Checked before the (expensive) model load so an unusable database
        # is rejected without paying for it.
        if self.human_index is None:
            raise ValueError(
                "Database must include a 'human' entry in its classes list to compute probabilities."
            )

        self.model = TextEmbeddingModel(
            model_name_or_path,
            output_hidden_states=True,
            infer=True,
            use_pooling=self.pooling,
        ).to(self.device)
        self.model.eval()
        self.tokenizer = self.model.tokenizer

        self._configure_layer(requested_layer)

    def _configure_layer(self, layer: int) -> None:
        """Build the kNN index for *layer* and make it the active layer.

        All work is done on local variables and committed at the end, so a
        failure leaves the previously active layer and index untouched.

        Raises:
            ValueError: If *layer* has no embeddings in the database.
        """
        if layer not in self.layer_embeddings:
            raise ValueError(f"Requested layer {layer} not present in database")

        layer_embeddings = self.layer_embeddings[layer]
        embedding_dim = layer_embeddings.shape[-1]

        # Fall back to the shared labels/ids when per-layer ones are unavailable.
        layer_labels = (
            self.layer_labels[layer]
            if self.layer_labels is not None
            else self._raw_labels
        )
        layer_ids = (
            self.layer_ids[layer] if self.layer_ids is not None else self._raw_ids
        )

        train_embeddings = _to_numpy(layer_embeddings)
        train_labels = _to_numpy(layer_labels).astype(np.int64)
        train_ids = _to_numpy(layer_ids).astype(np.int64)

        # Binary labels for the index: 1 = human, 0 = any other class.
        human_label = int(self.human_index)
        label_dict = {
            int(sample_id): 1 if int(label) == human_label else 0
            for sample_id, label in zip(train_ids.tolist(), train_labels.tolist())
        }

        index = Indexer(embedding_dim)
        index.label_dict = label_dict
        index.index_data(train_ids.tolist(), train_embeddings.astype(np.float32))

        # Commit only after everything above succeeded.
        self.embedding_dim = embedding_dim
        self.index = index
        self.layer = layer

    def set_layer(self, layer: int) -> None:
        """Switch the active database layer used for inference."""
        if layer == self.layer:
            return
        self._configure_layer(layer)

    def get_available_layers(self) -> List[int]:
        """Return a copy of the sorted layer keys present in the database."""
        return list(self.available_layers)

    @staticmethod
    def _collate_batch(batch):
        """Transpose ``[(text, idx), ...]`` into ``(texts, indices)`` tuples.

        Defined as a named function rather than a lambda so it can be
        pickled when the ``DataLoader`` spawns worker processes.
        """
        return tuple(zip(*batch))

    @staticmethod
    def _human_probability(scores, labels) -> float:
        """Return the softmax-weighted share of human neighbours, in [0, 1].

        *scores* are the neighbour similarity scores and *labels* the
        matching binary labels (1 = human). Both are cast to float32 so the
        dot product does not fail on mismatched dtypes (e.g. float64 scores).
        """
        scores_tensor = torch.as_tensor(np.asarray(scores, dtype=np.float32))
        weights = torch.softmax(scores_tensor, dim=0)
        label_tensor = torch.as_tensor(np.asarray(labels, dtype=np.float32))
        probability = float(torch.dot(weights, label_tensor).item())
        return max(0.0, min(1.0, probability))

    @torch.no_grad()
    def _encode(self, texts: Sequence[str]) -> np.ndarray:
        """Embed *texts* and return the active layer's normalised vectors.

        Returns:
            A float32 array with one row per input text, in input order.
            An empty input yields an array of shape ``(0, embedding_dim)``.

        Raises:
            RuntimeError: If the model output is not 3-dimensional or the
                number of embeddings does not match the number of inputs.
        """
        dataset = TextDataset(texts)
        if len(dataset) == 0:
            return np.zeros((0, self.embedding_dim), dtype=np.float32)

        dataloader = DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
            collate_fn=self._collate_batch,
        )

        layer_batches: List[torch.Tensor] = []
        all_indices: List[int] = []
        for texts_batch, indices_batch in tqdm(
            dataloader, desc="Encoding", leave=False
        ):
            encoded_batch = self.tokenizer.batch_encode_plus(
                list(texts_batch),
                return_tensors="pt",
                max_length=self.max_length,
                padding="max_length",
                truncation=True,
            )
            encoded_batch = {k: v.to(self.device) for k, v in encoded_batch.items()}
            embeddings = self.model(encoded_batch, hidden_states=True)
            if embeddings.dim() != 3:
                raise RuntimeError(
                    "Expected model output of shape (batch, layers, dim), "
                    f"got {tuple(embeddings.shape)}."
                )
            # Keep only the active layer; holding every layer for every
            # text would multiply memory use for no benefit.
            selected = F.normalize(embeddings[:, self.layer], dim=-1)
            layer_batches.append(selected.cpu())
            all_indices.extend(indices_batch)

        stacked = torch.cat(layer_batches, dim=0) if layer_batches else torch.empty(0)
        if stacked.numel() == 0:
            return np.zeros((0, self.embedding_dim), dtype=np.float32)
        order = torch.tensor(all_indices, dtype=torch.long)
        if order.numel() != stacked.shape[0]:
            raise RuntimeError("Index and embedding counts do not match.")
        # Restore the original input order using the recorded positions.
        stacked = stacked[torch.argsort(order)]
        return stacked.numpy().astype(np.float32)

    def predict(self, texts: Sequence[str]) -> "List[Prediction]":
        """Classify each text as ``"Human"`` or ``"AI"``.

        For every text the ``top_k`` nearest database entries are retrieved;
        the human probability is the softmax-weighted share of human-labelled
        neighbours, and the text is labelled ``"Human"`` when that
        probability reaches ``threshold``.

        Returns:
            One ``Prediction`` per input text, in input order; an empty list
            for empty input.
        """
        texts_list = [str(text) for text in texts]
        if not texts_list:
            return []

        embeddings = self._encode(texts_list)
        if embeddings.shape[0] == 0:
            return []

        results = self.index.search_knn(
            embeddings,
            self.top_k,
            index_batch_size=max(1, min(self.top_k, 128)),
        )

        predictions: List[Prediction] = []
        for text, (_ids, scores, labels) in zip(texts_list, results):
            probability_human = self._human_probability(scores, labels)
            probability_ai = float(max(0.0, min(1.0, 1.0 - probability_human)))
            label = "Human" if probability_human >= self.threshold else "AI"
            predictions.append(
                Prediction(
                    text=text,
                    probability_ai=probability_ai,
                    probability_human=probability_human,
                    label=label,
                )
            )
        return predictions
251
+