MAS-AI-0000 committed on
Commit
cae2130
·
verified ·
1 Parent(s): 025084b

Update text_embedder.py

Browse files
Files changed (1) hide show
  1. text_embedder.py +196 -179
text_embedder.py CHANGED
@@ -1,179 +1,196 @@
1
- """Text → detection-ready embedding.
2
-
3
- Loads the DETree ``TextEmbeddingModel`` and exposes ``get_text_embedding``,
4
- which tokenises a string, runs it through the model, and returns a single
5
- L2-normalised embedding vector ready to be passed to ``detect_embedding``.
6
-
7
- The layer extracted defaults to -1 (the last hidden layer), matching the
8
- default used in ``detector.py`` when building the KNN index. Override
9
- ``layer`` if your database was built with a different layer.
10
-
11
- Usage::
12
-
13
- from Apps.text_embedder import get_text_embedding
14
- from Apps.detector import detect_embedding
15
-
16
- emb = get_text_embedding("Was this written by a human?")
17
- result = detect_embedding(emb)
18
- # {"predicted_class": "Human"|"Ai", "confidence": 0.93}
19
- """
20
-
21
- from __future__ import annotations
22
-
23
- import os
24
- import sys
25
- from typing import Optional
26
-
27
- import numpy as np
28
- import torch
29
- import torch.nn.functional as F
30
-
31
# ---------------------------------------------------------------------------
# Make the local 'detree' package importable
# ---------------------------------------------------------------------------
# Directory holding this file; it is appended to sys.path (once) so that the
# ``detree`` import below can be resolved from there.
_current_dir = os.path.dirname(os.path.abspath(__file__))
if _current_dir not in sys.path:
    sys.path.append(_current_dir)

# The model class is optional at import time: on ImportError a warning is
# printed and the name is bound to None instead of aborting the import.
try:
    from detree.model.text_embedding import TextEmbeddingModel
except ImportError as _e:
    print(f"Warning: could not import TextEmbeddingModel ({_e}). Text embedding will return zeros.")
    TextEmbeddingModel = None

# ---------------------------------------------------------------------------
# Paths
# ---------------------------------------------------------------------------
# One level above this file's directory.
_BASE_DIR = os.path.dirname(_current_dir)
# Model directory: <base>/Lib/Models/Text
_TEXT_DIR = os.path.join(_BASE_DIR, "Lib", "Models", "Text")

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
MAX_LENGTH = 512  # default tokenisation length used by the public functions
POOLING = "max"  # must match what was used during database construction
# Prefer the GPU when torch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
56
-
57
- # ---------------------------------------------------------------------------
58
- # Module-level initialisation
59
- # ---------------------------------------------------------------------------
60
-
61
# Populated by _init(); both remain None when the model cannot be loaded.
_model: Optional[object] = None
_tokenizer: Optional[object] = None


def _init() -> None:
    """Load the text embedding model into the module-level globals.

    Nothing is raised on failure: when the model class is unavailable, the
    model directory does not exist, or construction fails, a message is
    printed and the function simply returns.
    """
    global _model, _tokenizer

    if TextEmbeddingModel is None:
        print("TextEmbedder: TextEmbeddingModel unavailable — embedding disabled.")
    elif not os.path.exists(_TEXT_DIR):
        print(f"TextEmbedder: model directory not found at {_TEXT_DIR!r} — embedding disabled.")
    else:
        try:
            built = TextEmbeddingModel(
                _TEXT_DIR,
                output_hidden_states=True,
                infer=True,
                use_pooling=POOLING,
            )
            _model = built.to(DEVICE)
            _model.eval()  # inference only
            _tokenizer = _model.tokenizer
            print(f"TextEmbedder: model loaded from {_TEXT_DIR!r}")
        except Exception as exc:
            # Best-effort load: report the problem and keep the module importable.
            print(f"TextEmbedder: error loading model: {exc}")


# Runs once, at import time.
_init()
91
-
92
-
93
- # ---------------------------------------------------------------------------
94
- # Public API
95
- # ---------------------------------------------------------------------------
96
-
97
@torch.no_grad()
def get_text_embedding(
    text: str,
    *,
    layer: int = -1,  # which hidden-state layer to use (-1 = last)
    max_length: int = MAX_LENGTH,
) -> np.ndarray:
    """Embed one string and return the normalised vector of a single layer.

    The text is tokenised (padded and truncated to ``max_length``), passed
    through the module-level model, L2-normalised along the last axis and
    then sliced as ``[:, layer, :]``.

    Args:
        text: The input string to embed.
        layer: Index into the layer axis of the model output; negative
            values count from the end, so -1 selects the last layer.
        max_length: Length the tokeniser pads and truncates to.

    Returns:
        A float32 ``np.ndarray`` with one row. When the model or tokenizer
        is not loaded, a ``(1, 1)`` array of zeros is returned instead.
    """
    model_ready = _model is not None and _tokenizer is not None
    if not model_ready:
        return np.zeros((1, 1), dtype=np.float32)

    tokens = _tokenizer.batch_encode_plus(
        [text],
        return_tensors="pt",
        max_length=max_length,
        padding="max_length",
        truncation=True,
    )
    # Move every tokenizer output tensor onto the model's device.
    model_inputs = {}
    for key, tensor in tokens.items():
        model_inputs[key] = tensor.to(DEVICE)

    # The output is indexed below as (batch, layer, feature).
    per_layer = F.normalize(_model(model_inputs, hidden_states=True), dim=-1)
    chosen = per_layer[:, layer, :]

    return chosen.cpu().numpy().astype(np.float32)
138
-
139
-
140
@torch.no_grad()
def get_text_embeddings_batch(
    texts: list[str],
    *,
    layer: int = -1,
    max_length: int = MAX_LENGTH,
    batch_size: int = 8,
) -> np.ndarray:
    """Embed a list of strings, ``batch_size`` items per forward pass.

    Each item is converted with ``str()`` before tokenisation. Every chunk
    is tokenised (padded and truncated to ``max_length``), run through the
    module-level model, L2-normalised along the last axis and sliced as
    ``[:, layer, :]``; the per-chunk results are concatenated row-wise.

    Args:
        texts: List of input strings.
        layer: Index into the layer axis of the model output (-1 = last).
        max_length: Length the tokeniser pads and truncates to.
        batch_size: Number of strings to encode per forward pass.

    Returns:
        A float32 ``np.ndarray`` with one row per input. When the model or
        tokenizer is not loaded the result is ``(len(texts), 1)`` zeros;
        when no chunk was processed it is a ``(0, 1)`` array.
    """
    if _model is None or _tokenizer is None:
        return np.zeros((len(texts), 1), dtype=np.float32)

    chunks: list[np.ndarray] = []
    for start in range(0, len(texts), batch_size):
        stop = start + batch_size
        strings = list(map(str, texts[start:stop]))
        tokens = _tokenizer.batch_encode_plus(
            strings,
            return_tensors="pt",
            max_length=max_length,
            padding="max_length",
            truncation=True,
        )
        # Move every tokenizer output tensor onto the model's device.
        model_inputs = {}
        for key, tensor in tokens.items():
            model_inputs[key] = tensor.to(DEVICE)

        # The output is indexed below as (batch, layer, feature).
        per_layer = F.normalize(_model(model_inputs, hidden_states=True), dim=-1)
        chosen = per_layer[:, layer, :]
        chunks.append(chosen.cpu().numpy().astype(np.float32))

    if not chunks:
        return np.zeros((0, 1), dtype=np.float32)
    return np.concatenate(chunks, axis=0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Text → detection-ready embedding.
2
+
3
+ Loads the DETree ``TextEmbeddingModel`` and exposes ``get_text_embedding``,
4
+ which tokenises a string, runs it through the model, and returns a single
5
+ L2-normalised embedding vector ready to be passed to ``detect_embedding``.
6
+
7
+ The layer extracted defaults to -1 (the last hidden layer), matching the
8
+ default used in ``detector.py`` when building the KNN index. Override
9
+ ``layer`` if your database was built with a different layer.
10
+
11
+ Usage::
12
+
13
+ from Apps.text_embedder import get_text_embedding
14
+ from Apps.detector import detect_embedding
15
+
16
+ emb = get_text_embedding("Was this written by a human?")
17
+ result = detect_embedding(emb)
18
+ # {"predicted_class": "Human"|"Ai", "confidence": 0.93}
19
+ """
20
+
21
+ from __future__ import annotations
22
+
23
+ import os
24
+ import sys
25
+ from typing import Optional
26
+
27
+ import numpy as np
28
+ import torch
29
+ import torch.nn.functional as F
30
+ from pathlib import Path
31
+ from huggingface_hub import snapshot_download
32
+
33
# ---------------------------------------------------------------------------
# Make the local 'detree' package importable
# ---------------------------------------------------------------------------
# Directory holding this file; it is appended to sys.path (once) so that the
# ``detree`` import below can be resolved from there.
_current_dir = os.path.dirname(os.path.abspath(__file__))
if _current_dir not in sys.path:
    sys.path.append(_current_dir)

# The model class is optional at import time: on ImportError a warning is
# printed and the name is bound to None instead of aborting the import.
try:
    from detree.model.text_embedding import TextEmbeddingModel
except ImportError as _e:
    print(f"Warning: could not import TextEmbeddingModel ({_e}). Text embedding will return zeros.")
    TextEmbeddingModel = None


# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
MAX_LENGTH = 512  # default tokenisation length used by the public functions
POOLING = "max"  # must match what was used during database construction
# Prefer the GPU when torch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Hugging Face Hub location of the model files
REPO_ID = "MAS-AI-0000/Authentica"
TEXT_SUBFOLDER = "Lib/Models/Text"  # where config.json/model.safetensors live in the repo
EMBEDDING_FILE = "priori1_center10k.pt"  # NOTE(review): not referenced in the visible code
# Stays None when the download below fails.
_TEXT_DIR = None

try:
    # Fetch a local snapshot restricted to the Text sub-folder, then point
    # _TEXT_DIR at that sub-folder inside the snapshot.
    print(f"Downloading/Checking model from {REPO_ID}...")
    _snapshot_dir = snapshot_download(
        repo_id=REPO_ID,
        allow_patterns=[TEXT_SUBFOLDER + "/*"],
    )
    _TEXT_DIR = os.path.join(_snapshot_dir, TEXT_SUBFOLDER)
    print(f"Model directory set to: {_TEXT_DIR}")
except Exception as e:
    # Best-effort: report the failure and leave _TEXT_DIR as None.
    print(f"Error downloading model from Hugging Face: {e}")
71
+
72
+
73
+
74
+ # ---------------------------------------------------------------------------
75
+ # Module-level initialisation
76
+ # ---------------------------------------------------------------------------
77
+
78
# Populated by _init(); both remain None when the model cannot be loaded.
_model: Optional[object] = None
_tokenizer: Optional[object] = None


def _init() -> None:
    """Load the text embedding model into the module-level globals.

    Nothing is raised on failure: when the model class is unavailable, the
    model directory is unknown or missing, or construction fails, a message
    is printed and ``_model`` / ``_tokenizer`` are left as ``None``.
    """
    global _model, _tokenizer

    if TextEmbeddingModel is None:
        print("TextEmbedder: TextEmbeddingModel unavailable — embedding disabled.")
        return

    # _TEXT_DIR is still None when the Hugging Face download failed.
    # os.path.exists(None) raises TypeError, which would abort the import
    # instead of disabling embedding, so check for None explicitly.
    if _TEXT_DIR is None:
        print("TextEmbedder: model directory unavailable (download failed) — embedding disabled.")
        return

    if not os.path.exists(_TEXT_DIR):
        print(f"TextEmbedder: model directory not found at {_TEXT_DIR!r} — embedding disabled.")
        return

    try:
        model = TextEmbeddingModel(
            _TEXT_DIR,
            output_hidden_states=True,
            infer=True,
            use_pooling=POOLING,
        ).to(DEVICE)
        model.eval()  # inference only
        tokenizer = model.tokenizer
    except Exception as exc:
        # Best-effort load: report the problem and keep the module importable.
        print(f"TextEmbedder: error loading model: {exc}")
        return

    # Publish both globals together so a failure part-way through loading
    # never leaves a model without its tokenizer.
    _model, _tokenizer = model, tokenizer
    print(f"TextEmbedder: model loaded from {_TEXT_DIR!r}")


# Runs once, at import time.
_init()
108
+
109
+
110
+ # ---------------------------------------------------------------------------
111
+ # Public API
112
+ # ---------------------------------------------------------------------------
113
+
114
@torch.no_grad()
def get_text_embedding(
    text: str,
    *,
    layer: int = -1,  # which hidden-state layer to use (-1 = last)
    max_length: int = MAX_LENGTH,
) -> np.ndarray:
    """Embed one string and return the normalised vector of a single layer.

    The text is tokenised (padded and truncated to ``max_length``), passed
    through the module-level model, L2-normalised along the last axis and
    then sliced as ``[:, layer, :]``.

    Args:
        text: The input string to embed.
        layer: Index into the layer axis of the model output; negative
            values count from the end, so -1 selects the last layer.
        max_length: Length the tokeniser pads and truncates to.

    Returns:
        A float32 ``np.ndarray`` with one row. When the model or tokenizer
        is not loaded, a ``(1, 1)`` array of zeros is returned instead.
    """
    unavailable = _model is None or _tokenizer is None
    if unavailable:
        return np.zeros((1, 1), dtype=np.float32)

    tokenizer_kwargs = dict(
        return_tensors="pt",
        max_length=max_length,
        padding="max_length",
        truncation=True,
    )
    tokens = _tokenizer.batch_encode_plus([text], **tokenizer_kwargs)

    # Move every tokenizer output tensor onto the model's device.
    model_inputs = {}
    for key, tensor in tokens.items():
        model_inputs[key] = tensor.to(DEVICE)

    # The output is indexed below as (batch, layer, feature).
    raw = _model(model_inputs, hidden_states=True)
    unit = F.normalize(raw, dim=-1)

    return unit[:, layer, :].cpu().numpy().astype(np.float32)
155
+
156
+
157
@torch.no_grad()
def get_text_embeddings_batch(
    texts: list[str],
    *,
    layer: int = -1,
    max_length: int = MAX_LENGTH,
    batch_size: int = 8,
) -> np.ndarray:
    """Embed a list of strings, ``batch_size`` items per forward pass.

    Each item is converted with ``str()`` before tokenisation. Every chunk
    is tokenised (padded and truncated to ``max_length``), run through the
    module-level model, L2-normalised along the last axis and sliced as
    ``[:, layer, :]``; the per-chunk results are concatenated row-wise.

    Args:
        texts: List of input strings.
        layer: Index into the layer axis of the model output (-1 = last).
        max_length: Length the tokeniser pads and truncates to.
        batch_size: Number of strings to encode per forward pass; must be
            at least 1.

    Returns:
        A float32 ``np.ndarray`` with one row per input. When the model or
        tokenizer is not loaded the result is ``(len(texts), 1)`` zeros;
        for an empty ``texts`` it is a ``(0, 1)`` array.

    Raises:
        ValueError: If the model is loaded and ``batch_size`` is below 1.
    """
    if _model is None or _tokenizer is None:
        return np.zeros((len(texts), 1), dtype=np.float32)

    # A negative step would make the range below empty and silently drop
    # every input, and a zero step raises an unrelated-looking range error,
    # so reject both with a clear message.
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    chunks: list[np.ndarray] = []
    for start in range(0, len(texts), batch_size):
        strings = [str(item) for item in texts[start : start + batch_size]]
        tokens = _tokenizer.batch_encode_plus(
            strings,
            return_tensors="pt",
            max_length=max_length,
            padding="max_length",
            truncation=True,
        )
        # Move every tokenizer output tensor onto the model's device.
        model_inputs = {key: tensor.to(DEVICE) for key, tensor in tokens.items()}

        # The output is indexed below as (batch, layer, feature).
        per_layer = F.normalize(_model(model_inputs, hidden_states=True), dim=-1)
        chosen = per_layer[:, layer, :]
        chunks.append(chosen.cpu().numpy().astype(np.float32))

    if not chunks:
        return np.zeros((0, 1), dtype=np.float32)
    return np.concatenate(chunks, axis=0)