murshid / murshid_backend /app /ml /embedder.py
devorbit's picture
Initial deployment - secrets removed
26e1c2e
"""
SecureBERT+ embedder — extracted from MurshidUIPipeline.ipynb (cell 15).
Produces a 768-dim float32 embedding for a text paragraph.
Also provides build_text_for_embedding (cell 12).
Original file is NOT modified.
"""
from __future__ import annotations
import numpy as np
from lxml import etree
try:
import torch
from transformers import AutoModel, AutoTokenizer
_TORCH_OK = True
except (ImportError, OSError):
_TORCH_OK = False
from app.config import settings
def _norm_spaces(s: str) -> str:
return " ".join((s or "").split()).strip()
def _strip_end_punct(s: str) -> str:
return (s or "").rstrip(". ").strip()
def build_text_for_embedding(clean_rule: str, summary: str) -> str:
    """Merge an LLM summary with a rule's description — notebook cell 12.

    *clean_rule* is parsed as XML and the text of its ``description`` child
    is extracted. Both that text and *summary* are whitespace-normalised.

    Returns:
        * ``""`` when both parts are empty;
        * the single non-empty part, unchanged, when only one is present;
        * ``"<summary>."`` when both parts are equal ignoring case and
          trailing periods/spaces;
        * ``"<summary>. <description>."`` otherwise.

    Any parsing error raised by ``etree.fromstring`` propagates to the caller.
    """
    root = etree.fromstring(clean_rule.strip())
    desc_text = _norm_spaces(root.findtext("description") or "")
    summary_text = _norm_spaces(summary)

    # With at most one non-empty part there is nothing to merge: hand back
    # whichever one exists (or "" when neither does).
    if not (summary_text and desc_text):
        return summary_text or desc_text

    summary_core = _strip_end_punct(summary_text)
    desc_core = _strip_end_punct(desc_text)

    # Avoid repeating the same sentence when the summary merely restates
    # the description.
    if summary_core.lower() == desc_core.lower():
        return summary_core + "."
    return f"{summary_core}. {desc_core}."
class SecureBERTEmbedder:
    """Mean-pooling embedder using ehsanaghaei/SecureBERT_Plus — cell 15.

    The text is tokenised, split into chunks of at most ``MAX_LEN`` tokens
    (including the CLS/SEP pair added to each chunk), every chunk is
    mean-pooled over its non-padding tokens, and the chunk vectors are
    averaged and L2-normalised into a single float32 vector.
    """

    # Maximum tokens per chunk, counting the CLS and SEP tokens.
    MAX_LEN = 512
    # Number of chunks pushed through the model in one forward pass.
    BATCH_CHUNKS = 8

    def __init__(self, model_id: str | None = None, device: str | None = None):
        """Load tokenizer and model.

        Args:
            model_id: Identifier passed to ``from_pretrained``; falls back to
                ``settings.embed_model_id`` when falsy.
            device: Torch device string; when falsy, ``"cuda"`` is used if
                available, otherwise ``"cpu"``.

        Raises:
            RuntimeError: torch/transformers could not be imported.
        """
        if not _TORCH_OK:
            raise RuntimeError("torch/transformers not available — SecureBERTEmbedder cannot be initialised.")

        mid = model_id or settings.embed_model_id
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

        # NOTE: these are process-wide cuDNN switches, not per-instance.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

        self.tokenizer = AutoTokenizer.from_pretrained(mid, use_fast=True)
        self.model = AutoModel.from_pretrained(mid).to(self.device)
        self.model.eval()

        self.cls_id = self.tokenizer.cls_token_id
        self.sep_id = self.tokenizer.sep_token_id
        # Fall back to SEP for padding when the tokenizer defines no PAD token;
        # padded positions are masked out before pooling anyway.
        self.pad_id = (
            self.tokenizer.pad_token_id
            if self.tokenizer.pad_token_id is not None
            else self.sep_id
        )

    def _chunk_text(self, text: str) -> list[list[int]]:
        """Tokenise *text* and split it into CLS…SEP-wrapped id chunks.

        Each chunk holds at most ``MAX_LEN - 2`` content tokens. The result
        always contains at least one chunk: text that produces no tokens
        yields a single bare ``[CLS, SEP]`` chunk so callers never receive
        an empty list.
        """
        token_ids = self.tokenizer.encode(text, add_special_tokens=False)
        body_len = self.MAX_LEN - 2  # leave room for CLS and SEP
        chunks = [
            [self.cls_id, *token_ids[start : start + body_len], self.sep_id]
            for start in range(0, len(token_ids), body_len)
        ]
        return chunks or [[self.cls_id, self.sep_id]]

    def embed_text(self, text: str) -> np.ndarray:
        """Return one L2-normalised float32 embedding for *text*.

        Chunks are processed ``BATCH_CHUNKS`` at a time. Within a batch,
        shorter chunks are right-padded and the padding is excluded from the
        mean via the attention mask. The final vector is the unweighted mean
        of all chunk vectors, divided by its Euclidean norm.

        Empty text no longer fails inside ``np.vstack``; it is embedded as a
        bare ``[CLS, SEP]`` sequence (see ``_chunk_text``).
        """
        chunks = self._chunk_text(text)
        pooled: list[np.ndarray] = []

        for start in range(0, len(chunks), self.BATCH_CHUNKS):
            batch = chunks[start : start + self.BATCH_CHUNKS]
            width = max(len(chunk) for chunk in batch)

            # Right-pad every chunk to the widest one in this batch.
            input_ids = [
                chunk + [self.pad_id] * (width - len(chunk)) for chunk in batch
            ]
            masks = [
                [1] * len(chunk) + [0] * (width - len(chunk)) for chunk in batch
            ]

            ids_t = torch.tensor(input_ids).to(self.device)
            mask_t = torch.tensor(masks).to(self.device)

            with torch.no_grad():
                out = self.model(input_ids=ids_t, attention_mask=mask_t)

            tok_emb = out.last_hidden_state
            # Broadcast the mask over the hidden dimension so padded
            # positions contribute nothing to the sum.
            mask_exp = mask_t.unsqueeze(-1).expand(tok_emb.size()).float()
            summed = torch.sum(tok_emb * mask_exp, dim=1)
            # Clamp guards against division by zero for an all-zero mask.
            denom = torch.clamp(mask_exp.sum(dim=1), min=1e-9)
            pooled.append((summed / denom).cpu().numpy())

        para_emb = np.vstack(pooled).mean(axis=0)
        # Epsilon keeps the division finite for an all-zero vector.
        para_emb /= np.linalg.norm(para_emb) + 1e-12
        return para_emb.astype(np.float32)