Authentica / text_embedder.py
MAS-AI-0000's picture
Update text_embedder.py
3ec9229 verified
"""Text β†’ detection-ready embedding.
Loads the DETree ``TextEmbeddingModel`` and exposes ``get_text_embedding``,
which tokenises a string, runs it through the model, and returns a single
L2-normalised embedding vector ready to be passed to ``detect_embedding``.
The layer extracted defaults to -1 (the last hidden layer), matching the
default used in ``detector.py`` when building the KNN index. Override
``layer`` if your database was built with a different layer.
Usage::
from Apps.text_embedder import get_text_embedding
from Apps.detector import detect_embedding
emb = get_text_embedding("Was this written by a human?")
result = detect_embedding(emb)
# {"predicted_class": "Human"|"Ai", "confidence": 0.93}
"""
from __future__ import annotations
import os
import sys
from typing import Optional
import logging
import numpy as np
import torch
import torch.nn.functional as F
from pathlib import Path
from huggingface_hub import snapshot_download
# Configure root logging at import time. ``basicConfig`` is a no-op when the
# root logger already has handlers, so a host application's own logging setup
# takes precedence if it ran first.
logging.basicConfig(
    format="%(levelname)s [%(name)s] %(message)s",
    level=logging.INFO,
)
log = logging.getLogger("text_embedder")

# ---------------------------------------------------------------------------
# Make the local 'detree' package importable
# ---------------------------------------------------------------------------
# The directory containing this file is appended (not prepended) to sys.path,
# and only if it is not already listed.
_current_dir = os.path.dirname(os.path.abspath(__file__))
if _current_dir not in sys.path:
    sys.path.append(_current_dir)

# A failed import is tolerated: TextEmbeddingModel is bound to None so the
# rest of the module can detect the situation and disable embedding.
try:
    from detree.model.text_embedding import TextEmbeddingModel
except ImportError as _e:
    log.error(f"Could not import TextEmbeddingModel: {_e}")
    TextEmbeddingModel = None
else:
    log.info("TextEmbeddingModel imported successfully.")
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Default tokenisation length; used as the ``max_length`` default of the
# embedding functions in this module.
MAX_LENGTH: int = 512

# Pooling mode handed to TextEmbeddingModel.
# Must match what was used during database construction.
POOLING: str = "max"

# Run on the GPU when torch reports CUDA as available, otherwise on the CPU.
DEVICE: str
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# hugging face
REPO_ID: str = "MAS-AI-0000/Authentica"
# Where config.json/model.safetensors live in the repo.
TEXT_SUBFOLDER: str = "Lib/Models/Text"
# NOTE(review): not referenced by the code visible in this module --
# presumably consumed elsewhere (e.g. by the detector); verify before removing.
EMBEDDING_FILE: str = "priori1_center10k.pt"

# Local directory holding the text model. Filled in by the snapshot download
# further down; remains None if that download fails.
_TEXT_DIR: Optional[str] = None

log.info(
    f"[config] device={DEVICE!r} max_length={MAX_LENGTH} pooling={POOLING!r}"
)
# Fetch a local snapshot restricted to the text-model subfolder and point
# _TEXT_DIR at it. This is best-effort: any failure is reported on stdout and
# _TEXT_DIR keeps whatever value it had before.
try:
    print(f"Downloading/Checking model from {REPO_ID}...")
    _snapshot_dir = snapshot_download(
        allow_patterns=[TEXT_SUBFOLDER + "/*"],
        repo_id=REPO_ID,
    )
    # The snapshot mirrors the repo layout, so the model files sit under the
    # same subfolder path inside the snapshot directory.
    _TEXT_DIR = os.path.join(_snapshot_dir, TEXT_SUBFOLDER)
    print(f"Model directory set to: {_TEXT_DIR}")
except Exception as download_error:
    print(f"Error downloading model from Hugging Face: {download_error}")
# ---------------------------------------------------------------------------
# Module-level initialisation
# ---------------------------------------------------------------------------
# Module-level singletons. Both stay None when initialisation fails; the
# public functions check for that before attempting inference.
_model: Optional[object] = None
_tokenizer: Optional[object] = None


def _init() -> None:
    """Load the embedding model and its tokenizer into the module globals.

    Every failure path (model class not importable, model directory missing,
    error while loading) is logged and leaves ``_model`` and ``_tokenizer``
    as ``None`` rather than raising, so importing this module does not fail
    because of the model.
    """
    global _model, _tokenizer
    log.info("_init: starting TextEmbedder initialisation.")

    if TextEmbeddingModel is None:
        log.error("_init: TextEmbeddingModel is None — check import error above. Embedding disabled.")
        return

    # _TEXT_DIR is still None when the snapshot download failed, and
    # os.path.exists(None) raises TypeError instead of returning False,
    # so the None case must be tested explicitly first.
    if _TEXT_DIR is None or not os.path.exists(_TEXT_DIR):
        log.error(f"_init: model directory not found at {_TEXT_DIR!r} — embedding disabled.")
        return

    log.info(f"_init: loading TextEmbeddingModel from {_TEXT_DIR!r} on device={DEVICE!r} ...")
    try:
        model = TextEmbeddingModel(
            _TEXT_DIR,
            output_hidden_states=True,
            infer=True,
            use_pooling=POOLING,
        ).to(DEVICE)
        model.eval()
        tokenizer = model.tokenizer
        # Publish both globals together, only once both objects exist, so a
        # failure above never leaves a model without its tokenizer.
        _model, _tokenizer = model, tokenizer
        log.info(f"_init: model loaded OK. tokenizer type={type(tokenizer).__name__!r}")
        log.info(f"_init: model device={next(model.parameters()).device}")
    except Exception as exc:
        # Top-level boundary: log with traceback and carry on with embedding
        # disabled instead of breaking the import.
        log.exception(f"_init: error loading model: {exc}")


_init()
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
@torch.no_grad()
def get_text_embedding(
    text: str,
    *,
    layer: int = -1,  # which hidden-state layer to use (-1 = last)
    max_length: int = MAX_LENGTH,
) -> np.ndarray:
    """Return a (1, embedding_dim) float32 numpy array for the given text.

    The embedding is L2-normalised and projected into the same space as the
    DETree database so it can be passed directly to ``detect_embedding``.

    Inference is best-effort: if the model is unavailable or anything fails
    while tokenising or running the model, the problem is logged and a
    placeholder is returned instead of raising.

    Args:
        text: The input string to embed.
        layer: Hidden-state layer index. -1 selects the last layer,
            matching the default used when building the database.
        max_length: Tokenisation truncation length. The input is also padded
            to this length.

    Returns:
        ``np.ndarray`` of shape ``(1, embedding_dim)`` and dtype float32.
        On failure (model not loaded, or an exception during inference) the
        result is an all-zero array of shape ``(1, 1)`` -- callers can detect
        this by checking the second dimension.
    """
    if _model is None or _tokenizer is None:
        log.error("get_text_embedding: model or tokenizer is None — returning zeros. Check _init logs.")
        return np.zeros((1, 1), dtype=np.float32)

    log.info(f"get_text_embedding: input text length={len(text)} chars, layer={layer}")
    try:
        # The tokenizer is given a one-element batch so the tensors carry a
        # leading batch dimension of 1.
        encoded = _tokenizer(
            [text],
            return_tensors="pt",
            max_length=max_length,
            padding="max_length",
            truncation=True,
        )
        log.info(f"get_text_embedding: tokenised keys={list(encoded.keys())} "
                 f"input_ids shape={encoded['input_ids'].shape}")
        encoded = {k: v.to(DEVICE) for k, v in encoded.items()}

        # Shape returned by model with hidden_states=True: (batch, num_layers, dim)
        embeddings = _model(encoded, hidden_states=True)
        log.info(f"get_text_embedding: raw embeddings shape={tuple(embeddings.shape)}")
        embeddings = F.normalize(embeddings, dim=-1)  # normalise feature dim

        # embeddings: (1, num_layers, dim) → select layer → (1, dim)
        selected = embeddings[:, layer, :]  # supports negative indexing
        log.info(f"get_text_embedding: selected layer={layer} output shape={tuple(selected.shape)} "
                 f"norm={selected.norm(dim=-1).item():.4f}")
    except Exception as exc:
        # Boundary handler: keep the caller running and signal failure via
        # the (1, 1) zero placeholder documented above.
        log.exception(f"get_text_embedding: failed during inference: {exc}")
        return np.zeros((1, 1), dtype=np.float32)

    return selected.cpu().numpy().astype(np.float32)
@torch.no_grad()
def get_text_embeddings_batch(
    texts: list[str],
    *,
    layer: int = -1,
    max_length: int = MAX_LENGTH,
    batch_size: int = 8,
) -> np.ndarray:
    """Return an (N, embedding_dim) float32 array for a list of strings.

    Unlike ``get_text_embedding``, errors raised while tokenising or running
    the model are not caught here; they propagate to the caller.

    Args:
        texts: List of input strings. Each item is passed through ``str()``
            before tokenisation.
        layer: Hidden-state layer index (-1 = last).
        max_length: Tokenisation truncation length. Inputs are also padded
            to this length.
        batch_size: Number of strings to encode per forward pass. Must be
            at least 1.

    Returns:
        ``np.ndarray`` of shape ``(N, embedding_dim)`` and dtype float32.
        If the model is not loaded, an all-zero array of shape ``(N, 1)`` is
        returned; for an empty ``texts`` the result has shape ``(0, 1)``.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1 (only checked when
            the model is loaded).
    """
    if _model is None or _tokenizer is None:
        # Same placeholder convention as the single-text function, and the
        # same log line so the condition is not silent.
        log.error("get_text_embeddings_batch: model or tokenizer is None — returning zeros. Check _init logs.")
        return np.zeros((len(texts), 1), dtype=np.float32)

    # A negative step would make the range below empty and silently drop
    # every input; a zero step makes range() raise an unhelpful error.
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

    all_embeddings: list[np.ndarray] = []
    for start in range(0, len(texts), batch_size):
        batch = [str(t) for t in texts[start : start + batch_size]]
        encoded = _tokenizer(
            batch,
            return_tensors="pt",
            max_length=max_length,
            padding="max_length",
            truncation=True,
        )
        encoded = {k: v.to(DEVICE) for k, v in encoded.items()}
        embeddings = _model(encoded, hidden_states=True)
        embeddings = F.normalize(embeddings, dim=-1)  # (B, num_layers, dim)
        selected = embeddings[:, layer, :]  # (B, dim)
        all_embeddings.append(selected.cpu().numpy().astype(np.float32))

    if not all_embeddings:
        return np.zeros((0, 1), dtype=np.float32)
    return np.concatenate(all_embeddings, axis=0)