File size: 3,860 Bytes
29fdac9
 
 
 
 
 
 
 
 
 
 
d76ef9a
 
 
29fdac9
 
 
 
 
 
d76ef9a
 
 
29fdac9
 
 
 
 
 
d76ef9a
 
 
29fdac9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d76ef9a
 
 
29fdac9
 
 
 
 
 
 
 
 
 
 
 
 
 
d76ef9a
 
 
29fdac9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import os
from dataclasses import dataclass
from typing import List, Optional

import numpy as np
import torch
from sentence_transformers import CrossEncoder, SentenceTransformer, util


@dataclass
class RunbookDoc:
    """
    A single Markdown runbook from the local knowledge base.
    """
    path: str  # Filesystem path the file was loaded from.
    title: str  # First line of the file with any leading '#'/' ' stripped; the filename when the file is empty.
    content: str  # Full raw Markdown text of the runbook.


class RunbookRetriever:
    """
    Loads a local Markdown knowledge base and performs semantic search over it.

    Retrieval is two-stage: a bi-encoder (SentenceTransformer) scores every
    document by cosine similarity, then an optional cross-encoder reranks the
    top candidates for better precision.
    """
    def __init__(
        self,
        kb_dir: str = "kb",
        embed_model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
        reranker_name: Optional[str] = "cross-encoder/ms-marco-MiniLM-L-6-v2",
    ):
        """
        Load all runbooks and prepare the models (embedder + optional reranker).

        Args:
            kb_dir: Directory containing the ``.md`` runbook files.
            embed_model_name: SentenceTransformer model used for embeddings.
            reranker_name: CrossEncoder model used for reranking, or ``None``
                to disable reranking entirely.
        """
        self.kb_dir = kb_dir
        # Force CPU to avoid CUDA capability mismatches in WSL/GPUs.
        self.device = torch.device("cpu")
        self.embed_model = SentenceTransformer(embed_model_name, device=self.device)
        self.reranker: Optional[CrossEncoder] = None
        if reranker_name:
            try:
                self.reranker = CrossEncoder(reranker_name, device=self.device)
            except Exception:
                # Best-effort: reranker load can fail (download/IO/model
                # errors); the retriever still works on cosine scores alone.
                self.reranker = None
        self.docs = self._load_docs()
        if self.docs:
            # Embed every document once up front; only the query is embedded
            # per search() call.
            self.doc_embeddings = self.embed_model.encode(
                [doc.content for doc in self.docs],
                convert_to_tensor=True,
                device=self.device,
            )
        else:
            self.doc_embeddings = None

    def _load_docs(self) -> List[RunbookDoc]:
        """
        Read the Markdown files in ``kb_dir`` into a list of RunbookDoc.

        Returns:
            The loaded docs, in sorted filename order; an empty list when
            ``kb_dir`` does not exist.
        """
        docs: List[RunbookDoc] = []
        if not os.path.isdir(self.kb_dir):
            return docs
        # sorted() makes doc order — and therefore embedding row order —
        # reproducible across runs; os.listdir order is filesystem-dependent.
        for fname in sorted(os.listdir(self.kb_dir)):
            if not fname.endswith(".md"):
                continue
            path = os.path.join(self.kb_dir, fname)
            with open(path, "r", encoding="utf-8") as f:
                content = f.read()
            # Title = first line minus any Markdown heading marker; fall back
            # to the filename for empty files AND for blank/'#'-only first
            # lines (the original produced an empty title in the latter case).
            title = fname
            if content:
                first_line = content.splitlines()[0].lstrip("# ").strip()
                if first_line:
                    title = first_line
            docs.append(RunbookDoc(path=path, title=title, content=content))
        return docs

    def search(self, query: str, top_k: int = 3):
        """
        Find the top-k relevant runbooks by cosine similarity (refined by the
        reranker when one is available).

        Args:
            query: Free-text search query.
            top_k: Maximum number of results to return.

        Returns:
            A list of dicts with ``title``, ``score``, ``path`` and a 500-char
            ``excerpt``; empty when the knowledge base is empty.
        """
        if not self.docs or self.doc_embeddings is None:
            return []
        query_emb = self.embed_model.encode(query, convert_to_tensor=True, device=self.device)
        scores = util.cos_sim(query_emb, self.doc_embeddings)[0]
        # Over-fetch 4x candidates so the reranker has room to reorder.
        top_results = np.argsort(-scores.cpu().numpy())[: top_k * 4]
        candidates = [
            {"doc": self.docs[idx], "score": float(scores[idx])} for idx in top_results
        ]
        if self.reranker:
            pairs = [[query, c["doc"].content] for c in candidates]
            rerank_scores = self.reranker.predict(pairs)
            for cand, rscore in zip(candidates, rerank_scores):
                cand["rerank_score"] = float(rscore)
            candidates = sorted(candidates, key=lambda x: x.get("rerank_score", x["score"]), reverse=True)
        else:
            candidates = sorted(candidates, key=lambda x: x["score"], reverse=True)
        return [
            {
                "title": cand["doc"].title,
                "score": cand.get("rerank_score", cand["score"]),
                "path": cand["doc"].path,
                "excerpt": cand["doc"].content[:500],
            }
            for cand in candidates[:top_k]
        ]