""" SecureBERT+ embedder — extracted from MurshidUIPipeline.ipynb (cell 15). Produces a 768-dim float32 embedding for a text paragraph. Also provides build_text_for_embedding (cell 12). Original file is NOT modified. """ from __future__ import annotations import numpy as np from lxml import etree try: import torch from transformers import AutoModel, AutoTokenizer _TORCH_OK = True except (ImportError, OSError): _TORCH_OK = False from app.config import settings def _norm_spaces(s: str) -> str: return " ".join((s or "").split()).strip() def _strip_end_punct(s: str) -> str: return (s or "").rstrip(". ").strip() def build_text_for_embedding(clean_rule: str, summary: str) -> str: """Combine LLM summary with rule description — cell 12 of notebook.""" rule_elem = etree.fromstring(clean_rule.strip()) raw_desc = rule_elem.findtext("description") or "" description = _norm_spaces(raw_desc) summary = _norm_spaces(summary) description = _norm_spaces(description) if not summary and not description: return "" if summary and not description: return summary if description and not summary: return description s0 = _strip_end_punct(summary).lower() d0 = _strip_end_punct(description).lower() if s0 == d0: return _strip_end_punct(summary) + "." return f"{_strip_end_punct(summary)}. {_strip_end_punct(description)}." 
class SecureBERTEmbedder:
    """Mean-pooling embedder using ehsanaghaei/SecureBERT_Plus — cell 15.

    Text is tokenised once, split into windows of at most ``MAX_LEN`` tokens
    (including the CLS/SEP tokens added to every window), each window is
    mean-pooled over its non-padding tokens, and the window vectors are
    averaged and L2-normalised into a single paragraph embedding.
    """

    # Maximum tokens per chunk, counting the CLS and SEP tokens.
    MAX_LEN = 512
    # Number of chunks sent through the model per forward pass.
    BATCH_CHUNKS = 8

    def __init__(self, model_id: str | None = None, device: str | None = None):
        """Load the tokenizer and model.

        Args:
            model_id: Model identifier passed to ``from_pretrained``; falls
                back to ``settings.embed_model_id`` when falsy.
            device: Torch device string; defaults to ``"cuda"`` when
                available, otherwise ``"cpu"``.

        Raises:
            RuntimeError: If torch/transformers could not be imported.
        """
        if not _TORCH_OK:
            raise RuntimeError("torch/transformers not available — SecureBERTEmbedder cannot be initialised.")

        mid = model_id or settings.embed_model_id
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

        # NOTE: these are process-wide cuDNN switches, not per-instance ones.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

        self.tokenizer = AutoTokenizer.from_pretrained(mid, use_fast=True)
        self.model = AutoModel.from_pretrained(mid).to(self.device)
        self.model.eval()

        self.cls_id = self.tokenizer.cls_token_id
        self.sep_id = self.tokenizer.sep_token_id
        # Fall back to SEP for padding when the tokenizer defines no PAD token;
        # padded positions are masked out of attention and pooling anyway.
        self.pad_id = (
            self.tokenizer.pad_token_id
            if self.tokenizer.pad_token_id is not None
            else self.sep_id
        )

    def _chunk_text(self, text: str) -> list[list[int]]:
        """Tokenise *text* and split it into CLS/SEP-wrapped id windows.

        Returns one list of token ids per window, each at most ``MAX_LEN``
        long. Text that tokenises to nothing yields an empty list.
        """
        token_ids = self.tokenizer.encode(text, add_special_tokens=False)
        chunk_size = self.MAX_LEN - 2  # reserve room for CLS and SEP
        return [
            [self.cls_id] + token_ids[start : start + chunk_size] + [self.sep_id]
            for start in range(0, len(token_ids), chunk_size)
        ]

    def embed_text(self, text: str) -> np.ndarray:
        """Return the L2-normalised float32 embedding of *text*.

        Every chunk contributes equally to the final average, regardless of
        how many tokens it holds. Empty or whitespace-only text is embedded
        as a single ``[CLS][SEP]`` chunk instead of failing.
        """
        chunks = self._chunk_text(text)
        if not chunks:
            # Nothing to tokenise: without this fallback np.vstack([]) below
            # raises "need at least one array to concatenate".
            chunks = [[self.cls_id, self.sep_id]]

        all_embs: list[np.ndarray] = []
        for start in range(0, len(chunks), self.BATCH_CHUNKS):
            batch = chunks[start : start + self.BATCH_CHUNKS]
            width = max(len(ids) for ids in batch)

            # Right-pad every chunk to the widest one in this batch.
            input_ids = [ids + [self.pad_id] * (width - len(ids)) for ids in batch]
            masks = [[1] * len(ids) + [0] * (width - len(ids)) for ids in batch]

            ids_t = torch.tensor(input_ids, dtype=torch.long, device=self.device)
            mask_t = torch.tensor(masks, dtype=torch.long, device=self.device)

            with torch.no_grad():
                out = self.model(input_ids=ids_t, attention_mask=mask_t)
                tok_emb = out.last_hidden_state

                # Masked mean over the sequence axis: padding contributes
                # nothing to the sum or the count.
                mask_exp = mask_t.unsqueeze(-1).expand(tok_emb.size()).float()
                summed = torch.sum(tok_emb * mask_exp, dim=1)
                denom = torch.clamp(mask_exp.sum(dim=1), min=1e-9)
                mean_pooled = summed / denom

            all_embs.append(mean_pooled.cpu().numpy())

        all_embs_np = np.vstack(all_embs)
        para_emb = all_embs_np.mean(axis=0)
        # Epsilon guards against division by zero for an all-zero vector.
        para_emb /= np.linalg.norm(para_emb) + 1e-12
        return para_emb.astype(np.float32)