File size: 8,037 Bytes
cae2130
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3ec9229
cae2130
 
 
 
 
 
3ec9229
 
 
 
 
cae2130
 
 
 
 
 
 
 
 
3ec9229
cae2130
3ec9229
cae2130
 
 
 
 
 
 
 
 
 
 
 
 
 
3ec9229
 
cae2130
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3ec9229
 
cae2130
3ec9229
cae2130
 
 
3ec9229
cae2130
 
3ec9229
cae2130
 
 
 
 
 
 
 
 
3ec9229
 
cae2130
3ec9229
cae2130
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3ec9229
cae2130
 
3ec9229
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cae2130
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e0fafa4
cae2130
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
"""Text → detection-ready embedding.

Loads the DETree ``TextEmbeddingModel`` and exposes ``get_text_embedding``,
which tokenises a string, runs it through the model, and returns a single
L2-normalised embedding vector ready to be passed to ``detect_embedding``.

The layer extracted defaults to -1 (the last hidden layer), matching the
default used in ``detector.py`` when building the KNN index.  Override
``layer`` if your database was built with a different layer.

Usage::

    from Apps.text_embedder import get_text_embedding
    from Apps.detector import detect_embedding

    emb    = get_text_embedding("Was this written by a human?")
    result = detect_embedding(emb)
    # {"predicted_class": "Human"|"Ai", "confidence": 0.93}
"""

from __future__ import annotations

import os
import sys
from typing import Optional
import logging
import numpy as np
import torch
import torch.nn.functional as F
from pathlib import Path
from huggingface_hub import snapshot_download



# Console logging for this module.  basicConfig configures the *root* logger
# and is a no-op if the root logger already has handlers.
logging.basicConfig(
    format="%(levelname)s [%(name)s] %(message)s",
    level=logging.INFO,
)
log = logging.getLogger("text_embedder")

# ---------------------------------------------------------------------------
# Make the local 'detree' package importable
# ---------------------------------------------------------------------------
_current_dir = os.path.dirname(os.path.abspath(__file__))
# Register this file's directory once so the bundled 'detree' package can be
# imported; the membership check keeps repeated imports from growing sys.path.
if _current_dir not in sys.path:
    sys.path.append(_current_dir)

try:
    from detree.model.text_embedding import TextEmbeddingModel
except ImportError as import_err:
    # Degrade instead of failing the import: _init() checks for None and
    # leaves embedding disabled.
    log.error(f"Could not import TextEmbeddingModel: {import_err}")
    TextEmbeddingModel = None
else:
    log.info("TextEmbeddingModel imported successfully.")

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Inference settings.  POOLING must match the pooling that was used when the
# detection database was constructed.
MAX_LENGTH = 512
POOLING = "max"
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Hugging Face repository that hosts the model files.
REPO_ID = "MAS-AI-0000/Authentica"
TEXT_SUBFOLDER = "Lib/Models/Text"       # repo folder holding config.json / model.safetensors
EMBEDDING_FILE = "priori1_center10k.pt"  # not used by the code visible here; presumably read by another module
_TEXT_DIR = None                         # local model directory; set by the snapshot download, stays None if it fails

log.info(f"[config] device={DEVICE!r}  max_length={MAX_LENGTH}  pooling={POOLING!r}")


try:
    # Fetch only the Text sub-folder of the repository and point _TEXT_DIR at
    # it.  This runs at import time.  Failures are logged rather than raised,
    # so the module still imports; _TEXT_DIR then keeps its previous value.
    # Uses the module logger (not print) so the messages honour the logging
    # configuration and the error path keeps its traceback.
    log.info(f"Downloading/Checking model from {REPO_ID}...")
    _snapshot_dir = snapshot_download(
        repo_id=REPO_ID,
        allow_patterns=[f"{TEXT_SUBFOLDER}/*"],
    )
    _TEXT_DIR = os.path.join(_snapshot_dir, TEXT_SUBFOLDER)
    log.info(f"Model directory set to: {_TEXT_DIR}")
except Exception as e:
    log.exception(f"Error downloading model from Hugging Face: {e}")



# ---------------------------------------------------------------------------
# Module-level initialisation
# ---------------------------------------------------------------------------

# Set by _init().  Both stay None when the model could not be loaded; the
# public functions check for that and return zero arrays.
_model:     Optional[object] = None
_tokenizer: Optional[object] = None


def _init() -> None:
    """Load the DETree model and its tokenizer into the module globals.

    Best-effort: a missing ``detree`` package, a failed download (``_TEXT_DIR``
    is None), a missing model directory or a load error is logged and leaves
    ``_model`` / ``_tokenizer`` as None instead of raising, so importing this
    module does not fail.  The globals are assigned together, only after both
    the model and the tokenizer were obtained.
    """
    global _model, _tokenizer

    log.info("_init: starting TextEmbedder initialisation.")

    if TextEmbeddingModel is None:
        log.error("_init: TextEmbeddingModel is None β€” check import error above. Embedding disabled.")
        return

    # _TEXT_DIR is None when the Hugging Face download failed.  Passing None to
    # os.path.exists raises TypeError on POSIX (only OSError/ValueError are
    # swallowed), which would crash the module import, so check it first.
    if _TEXT_DIR is None:
        log.error("_init: model directory is not set (Hugging Face download failed?); embedding disabled.")
        return

    if not os.path.isdir(_TEXT_DIR):
        log.error(f"_init: model directory not found at {_TEXT_DIR!r} β€” embedding disabled.")
        return

    log.info(f"_init: loading TextEmbeddingModel from {_TEXT_DIR!r} on device={DEVICE!r} ...")
    try:
        model = TextEmbeddingModel(
            _TEXT_DIR,
            output_hidden_states=True,
            infer=True,
            use_pooling=POOLING,
        ).to(DEVICE)
        model.eval()  # switch to evaluation mode for inference
        tokenizer = model.tokenizer
    except Exception as exc:
        log.exception(f"_init: error loading model: {exc}")
        return

    # Publish both globals at once so callers never see a half-initialised state.
    _model, _tokenizer = model, tokenizer
    log.info(f"_init: model loaded OK. tokenizer type={type(tokenizer).__name__!r}")

    # next() with a default avoids an unhandled StopIteration for a model
    # that exposes no parameters.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else "n/a (no parameters)"
    log.info(f"_init: model device={device}")


_init()


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------

@torch.no_grad()
def get_text_embedding(
    text: str,
    *,
    layer:      int = -1,           # which hidden-state layer to use (-1 = last)
    max_length: int = MAX_LENGTH,
) -> np.ndarray:
    """Embed a single string with the DETree model.

    The text is tokenised (padded and truncated to ``max_length``), run
    through the model with ``hidden_states=True``, and L2-normalised along
    the last (feature) axis.  Only the hidden state of ``layer`` is kept, so
    the result can be handed straight to ``detect_embedding``.

    Args:
        text:       String to embed.  Its ``len()`` is logged before
                    tokenisation, so it must be a sized object.
        layer:      Hidden-state layer to keep.  Negative values count from
                    the end; -1 (the last layer) matches the default used
                    when building the database.
        max_length: Token length used for both padding and truncation.

    Returns:
        float32 ``np.ndarray`` of shape ``(1, embedding_dim)``.  When the
        model/tokenizer is not loaded, or when any exception occurs during
        tokenisation or inference (logged with traceback), a ``(1, 1)``
        array of zeros is returned instead.
    """
    model_ready = _model is not None and _tokenizer is not None
    if not model_ready:
        log.error("get_text_embedding: model or tokenizer is None β€” returning zeros. Check _init logs.")
        return np.zeros((1, 1), dtype=np.float32)

    log.info(f"get_text_embedding: input text length={len(text)} chars, layer={layer}")
    try:
        tokens = _tokenizer(
            [text],
            return_tensors="pt",
            max_length=max_length,
            padding="max_length",
            truncation=True,
        )
        log.info(f"get_text_embedding: tokenised keys={list(tokens.keys())}  "
                 f"input_ids shape={tokens['input_ids'].shape}")
        model_inputs = {name: tensor.to(DEVICE) for name, tensor in tokens.items()}

        # With hidden_states=True the model yields one vector per layer,
        # indexed below as (batch, layer, feature).
        raw = _model(model_inputs, hidden_states=True)
        log.info(f"get_text_embedding: raw embeddings shape={tuple(raw.shape)}")
        per_layer = F.normalize(raw, dim=-1)

        chosen = per_layer[:, layer, :]                      # negative layer counts from the end
        log.info(f"get_text_embedding: selected layer={layer}  output shape={tuple(chosen.shape)}  "
                 f"norm={chosen.norm(dim=-1).item():.4f}")
    except Exception as exc:
        log.exception(f"get_text_embedding: failed during inference: {exc}")
        return np.zeros((1, 1), dtype=np.float32)
    else:
        return chosen.cpu().numpy().astype(np.float32)


@torch.no_grad()
def get_text_embeddings_batch(
    texts: list[str],
    *,
    layer:      int = -1,
    max_length: int = MAX_LENGTH,
    batch_size: int = 8,
) -> np.ndarray:
    """Return an (N, embedding_dim) float32 array for a list of strings.

    Every element of *texts* is passed through ``str()`` before tokenisation.
    Texts are encoded ``batch_size`` at a time.  Each output row is the
    L2-normalised hidden state of ``layer`` for the corresponding input, in
    the same order as *texts*.

    Args:
        texts:      List of N input strings.
        layer:      Hidden-state layer index (-1 = last).
        max_length: Token length used for both padding and truncation.
        batch_size: Number of strings to encode per forward pass; must be >= 1.

    Returns:
        ``np.ndarray`` of shape ``(N, embedding_dim)`` and dtype float32.
        If the model/tokenizer is not loaded, an all-zero ``(N, 1)`` array is
        returned instead.  An empty *texts* yields a ``(0, 1)`` array.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1 (checked once the
            model is loaded).  Unlike ``get_text_embedding``, errors raised
            during tokenisation or inference are not caught here.
    """
    if _model is None or _tokenizer is None:
        log.error("get_text_embeddings_batch: model or tokenizer is None; returning zeros. Check _init logs.")
        return np.zeros((len(texts), 1), dtype=np.float32)

    # A zero step makes range() fail with an unrelated message, and a negative
    # step silently skips the loop, returning a (0, 1) array for any input.
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    all_embeddings: list[np.ndarray] = []
    for start in range(0, len(texts), batch_size):
        batch = [str(t) for t in texts[start : start + batch_size]]
        encoded = _tokenizer(
            batch,
            return_tensors="pt",
            max_length=max_length,
            padding="max_length",
            truncation=True,
        )
        encoded = {k: v.to(DEVICE) for k, v in encoded.items()}

        embeddings = _model(encoded, hidden_states=True)
        embeddings = F.normalize(embeddings, dim=-1)         # (B, num_layers, dim)
        selected   = embeddings[:, layer, :]                 # (B, dim); negative layer counts from the end
        all_embeddings.append(selected.cpu().numpy().astype(np.float32))

    if not all_embeddings:
        return np.zeros((0, 1), dtype=np.float32)
    return np.concatenate(all_embeddings, axis=0)