File size: 4,121 Bytes
26e1c2e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
"""
SecureBERT+ embedder — extracted from MurshidUIPipeline.ipynb (cell 15).
Produces a 768-dim float32 embedding for a text paragraph.
Also provides build_text_for_embedding (cell 12).
Original file is NOT modified.
"""

from __future__ import annotations

import numpy as np
from lxml import etree

try:
    import torch
    from transformers import AutoModel, AutoTokenizer
    _TORCH_OK = True
except (ImportError, OSError):
    _TORCH_OK = False

from app.config import settings


def _norm_spaces(s: str) -> str:
    """Collapse each run of whitespace in *s* into a single space.

    Leading and trailing whitespace disappear as a consequence of the split;
    ``None`` or an empty string yields ``""``.
    """
    words = s.split() if s else []
    return " ".join(words)


def _strip_end_punct(s: str) -> str:
    """Remove trailing periods and spaces, then trim surrounding whitespace.

    Only ``"."`` and ``" "`` are removed from the end before trimming; other
    punctuation (``!``, ``?``) is kept. ``None`` or ``""`` yields ``""``.
    """
    if not s:
        return ""
    return s.rstrip(" .").strip()


def build_text_for_embedding(clean_rule: str, summary: str) -> str:
    """Combine the LLM summary with the rule's description — cell 12 of notebook.

    Args:
        clean_rule: XML text whose root element's direct ``<description>``
            child (if any) supplies the description.
        summary: LLM-generated summary of the rule.

    Returns:
        - ``""`` when both parts are empty after whitespace normalisation.
        - The single non-empty part, whitespace-normalised, when only one is
          present.
        - ``"<summary>."`` when summary and description are equal
          (case-insensitively, ignoring trailing periods/spaces).
        - Otherwise ``"<summary>. <description>."``.

    Any parse error raised by ``etree.fromstring`` for malformed XML
    propagates to the caller.
    """
    rule_elem = etree.fromstring(clean_rule.strip())
    # findtext returns None when there is no <description> child.
    description = _norm_spaces(rule_elem.findtext("description") or "")
    summary = _norm_spaces(summary)

    if not (summary and description):
        # At most one part is present: return it, or "" if neither is.
        return summary or description

    summary_core = _strip_end_punct(summary)
    description_core = _strip_end_punct(description)
    # Don't repeat the same sentence when the summary echoes the description.
    if summary_core.lower() == description_core.lower():
        return summary_core + "."
    return f"{summary_core}. {description_core}."


class SecureBERTEmbedder:
    """Mean-pooling embedder using ehsanaghaei/SecureBERT_Plus — cell 15."""

    MAX_LEN = 512
    BATCH_CHUNKS = 8

    def __init__(self, model_id: str | None = None, device: str | None = None):
        if not _TORCH_OK:
            raise RuntimeError("torch/transformers not available — SecureBERTEmbedder cannot be initialised.")
        mid = model_id or settings.embed_model_id
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        self.tokenizer = AutoTokenizer.from_pretrained(mid, use_fast=True)
        self.model = AutoModel.from_pretrained(mid).to(self.device)
        self.model.eval()
        self.cls_id = self.tokenizer.cls_token_id
        self.sep_id = self.tokenizer.sep_token_id
        self.pad_id = (
            self.tokenizer.pad_token_id
            if self.tokenizer.pad_token_id is not None
            else self.sep_id
        )

    def _chunk_text(self, text: str) -> list[list[int]]:
        token_ids = self.tokenizer.encode(text, add_special_tokens=False)
        chunk_size = self.MAX_LEN - 2
        chunks = []
        for i in range(0, len(token_ids), chunk_size):
            piece = token_ids[i : i + chunk_size]
            chunks.append([self.cls_id] + piece + [self.sep_id])
        return chunks

    def embed_text(self, text: str) -> np.ndarray:
        chunks = self._chunk_text(text)
        all_embs: list[np.ndarray] = []

        for i in range(0, len(chunks), self.BATCH_CHUNKS):
            batch = chunks[i : i + self.BATCH_CHUNKS]
            max_len = max(len(x) for x in batch)
            input_ids, masks = [], []
            for x in batch:
                pad = max_len - len(x)
                input_ids.append(x + [self.pad_id] * pad)
                masks.append([1] * len(x) + [0] * pad)

            ids_t = torch.tensor(input_ids).to(self.device)
            mask_t = torch.tensor(masks).to(self.device)

            with torch.no_grad():
                out = self.model(input_ids=ids_t, attention_mask=mask_t)
                tok_emb = out.last_hidden_state
                mask_exp = mask_t.unsqueeze(-1).expand(tok_emb.size()).float()
                summed = torch.sum(tok_emb * mask_exp, dim=1)
                denom = torch.clamp(mask_exp.sum(dim=1), min=1e-9)
                mean_pooled = summed / denom

            all_embs.append(mean_pooled.cpu().numpy())

        all_embs_np = np.vstack(all_embs)
        para_emb = all_embs_np.mean(axis=0)
        para_emb /= np.linalg.norm(para_emb) + 1e-12
        return para_emb.astype(np.float32)