ZedLow commited on
Commit
4792259
·
verified ·
1 Parent(s): fb6bf03

Create retrieval.py

Browse files
Files changed (1) hide show
  1. rag/retrieval.py +134 -0
rag/retrieval.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+ import json
3
+ from pathlib import Path
4
+ from typing import List, Tuple
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+
9
+ from rag.config import Settings
10
+ from rag.data import Doc
11
+ from rag.logging_utils import get_logger
12
+
13
+ logger = get_logger(__name__)
14
+
15
+ def last_token_pool(last_hidden_states, attention_mask):
16
+ left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
17
+ if left_padding:
18
+ return last_hidden_states[:, -1]
19
+ sequence_lengths = attention_mask.sum(dim=1) - 1
20
+ batch_size = last_hidden_states.shape[0]
21
+ return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
22
+
23
+ def _fingerprint(docs: List[Doc], settings: Settings) -> str:
24
+ h = hashlib.sha256()
25
+ h.update(settings.embed_model_id.encode("utf-8"))
26
+ h.update(str(settings.embed_max_len).encode("utf-8"))
27
+ for d in docs:
28
+ h.update(d.doc_name.encode("utf-8"))
29
+ h.update(d.company.encode("utf-8"))
30
+ h.update(d.text.encode("utf-8"))
31
+ return h.hexdigest()
32
+
33
+ def ensure_index_dir(settings: Settings):
34
+ Path(settings.index_dir).mkdir(parents=True, exist_ok=True)
35
+
36
+ @torch.no_grad()
37
+ def build_or_load_doc_embeddings(
38
+ docs: List[Doc],
39
+ embed_tokenizer,
40
+ embed_model,
41
+ settings: Settings,
42
+ ) -> Tuple[torch.Tensor, str]:
43
+ """
44
+ Returns (doc_embeddings [N, D] on CPU, fingerprint)
45
+ Caches to data/index/doc_embeds.pt
46
+ """
47
+ ensure_index_dir(settings)
48
+ fp = _fingerprint(docs, settings)
49
+ cache_file = settings.doc_embeds_file()
50
+ meta_file = settings.doc_meta_file()
51
+
52
+ if cache_file.exists() and meta_file.exists():
53
+ try:
54
+ meta = json.loads(meta_file.read_text(encoding="utf-8"))
55
+ if meta.get("fingerprint") == fp:
56
+ logger.info("Loading cached doc embeddings: %s", str(cache_file))
57
+ payload = torch.load(cache_file, map_location="cpu")
58
+ return payload["embeddings"], fp
59
+ except Exception as e:
60
+ logger.warning("Failed to load cache, rebuilding. Reason: %s", e)
61
+
62
+ logger.info("Building doc embeddings cache (%d docs)...", len(docs))
63
+ doc_texts = [d.text for d in docs]
64
+ embs = []
65
+
66
+ for i in range(0, len(doc_texts), settings.embed_batch_size):
67
+ batch = doc_texts[i : i + settings.embed_batch_size]
68
+ d_inputs = embed_tokenizer(
69
+ batch,
70
+ max_length=settings.embed_max_len,
71
+ padding=True,
72
+ truncation=True,
73
+ return_tensors="pt",
74
+ ).to(embed_model.device)
75
+
76
+ d_outputs = embed_model(**d_inputs)
77
+ batch_emb = last_token_pool(d_outputs.last_hidden_state, d_inputs["attention_mask"])
78
+ batch_emb = F.normalize(batch_emb, p=2, dim=1)
79
+ embs.append(batch_emb.detach().to("cpu"))
80
+
81
+ doc_embs = torch.cat(embs, dim=0)
82
+
83
+ torch.save({"embeddings": doc_embs}, cache_file)
84
+ meta_file.write_text(json.dumps({"fingerprint": fp, "n_docs": len(docs)}, indent=2), encoding="utf-8")
85
+ logger.info("Saved embeddings cache: %s", str(cache_file))
86
+ return doc_embs, fp
87
+
88
+ @torch.no_grad()
89
+ def embed_query(query: str, embed_tokenizer, embed_model, settings: Settings) -> torch.Tensor:
90
+ query_text = (
91
+ "Instruct: Given a user query, retrieve relevant passages that answer the query.\n"
92
+ f"Query: {query}"
93
+ )
94
+ q_inputs = embed_tokenizer(
95
+ [query_text],
96
+ max_length=settings.embed_max_len,
97
+ padding=True,
98
+ truncation=True,
99
+ return_tensors="pt",
100
+ ).to(embed_model.device)
101
+
102
+ q_outputs = embed_model(**q_inputs)
103
+ q_emb = last_token_pool(q_outputs.last_hidden_state, q_inputs["attention_mask"])
104
+ q_emb = F.normalize(q_emb, p=2, dim=1)
105
+ return q_emb.detach().to("cpu") # keep retrieval ops on CPU
106
+
107
+ def topk_retrieval(q_emb_cpu: torch.Tensor, doc_embs_cpu: torch.Tensor, k: int) -> List[int]:
108
+ # q_emb: [1, D], doc_embs: [N, D]
109
+ scores = (q_emb_cpu @ doc_embs_cpu.T).squeeze(0)
110
+ k = min(k, scores.shape[0])
111
+ return torch.topk(scores, k=k).indices.tolist()
112
+
113
+ @torch.no_grad()
114
+ def rerank(
115
+ query: str,
116
+ candidate_docs: List[Doc],
117
+ rerank_tokenizer,
118
+ rerank_model,
119
+ settings: Settings,
120
+ k: int,
121
+ ) -> Tuple[List[int], torch.Tensor]:
122
+ pairs = [[query, d.text] for d in candidate_docs]
123
+ r_inputs = rerank_tokenizer(
124
+ pairs,
125
+ padding=True,
126
+ truncation=True,
127
+ return_tensors="pt",
128
+ max_length=settings.rerank_max_len,
129
+ ).to(rerank_model.device)
130
+
131
+ r_scores = rerank_model(**r_inputs, return_dict=True).logits.view(-1).float().detach().to("cpu")
132
+ k = min(k, len(candidate_docs))
133
+ top_idx = torch.topk(r_scores, k=k).indices.tolist()
134
+ return top_idx, r_scores