hash-map committed on
Commit dff5c6e · verified · 1 Parent(s): 835c0b3

Upload 5 files

context_retreiver.py ADDED
@@ -0,0 +1,75 @@
+ # context_retreiver.py
+ import os, re, json, pickle, logging
+ import faiss
+ from sentence_transformers import SentenceTransformer
+ from langchain_community.retrievers import BM25Retriever
+ from langchain.docstore.document import Document
+
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+ logger = logging.getLogger(__name__)
+
+ WORK = "context"
+ JSONL = f"{WORK}/rag_documents.jsonl"
+ FAISS_INDEX = f"{WORK}/faiss_ivf.index"
+ BM25_PICKLE = f"{WORK}/bm25_retriever.pkl"
+
+ logger.info("Loading all RAG documents...")
+ with open(JSONL, encoding='utf-8') as f:
+     ALL_DOCS = [json.loads(line) for line in f]
+
+ # Map JSONL line number -> text / metadata for O(1) lookup after FAISS search
+ LINE_TO_TEXT = {i: doc["text"] for i, doc in enumerate(ALL_DOCS)}
+ LINE_TO_META = {i: doc["metadata"] for i, doc in enumerate(ALL_DOCS)}
+
+ class HybridRetriever:
+     def __init__(self):
+         # FAISS (CPU) dense index
+         self.faiss_index = faiss.read_index(FAISS_INDEX)
+         logger.info(f"FAISS loaded ({self.faiss_index.ntotal:,} vectors)")
+
+         # SentenceTransformer (GPU if available)
+         self.model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2",
+                                          device="cuda" if os.environ.get("CUDA_VISIBLE_DEVICES") else "cpu")
+
+         # BM25 sparse retriever: load from cache, else build from the documents
+         if os.path.exists(BM25_PICKLE):
+             with open(BM25_PICKLE, "rb") as f:
+                 self.bm25 = pickle.load(f)
+             logger.info("BM25 loaded")
+         else:
+             logger.info("Building BM25...")
+             # Strip the "Filename:/FullPath:" header block so BM25 scores only the body text
+             docs = [Document(page_content=re.sub(r"^Filename:.*\nFullPath:.*\n\n", "",
+                                                  doc["text"], flags=re.M),
+                              metadata=doc["metadata"]) for doc in ALL_DOCS]
+             self.bm25 = BM25Retriever.from_documents(docs)
+             self.bm25.k = 30
+             with open(BM25_PICKLE, "wb") as f:
+                 pickle.dump(self.bm25, f)
+             logger.info("BM25 built and saved")
+
+     def batch_retrieve(self, queries, top_k=3, faiss_k=10, bm25_k=3):
+         qvecs = self.model.encode(queries, show_progress_bar=False, normalize_embeddings=True).astype("float32")
+         D, I = self.faiss_index.search(qvecs, faiss_k)
+
+         batch_results = []
+         for qi, (scores, indices) in enumerate(zip(D, I)):
+             results = []
+             seen = set()
+             # Dense hits first, deduplicated, capped at top_k
+             for score, idx in zip(scores, indices):
+                 if idx == -1 or idx in seen:
+                     continue
+                 results.append({"score": float(score), "text": LINE_TO_TEXT[idx],
+                                 "metadata": LINE_TO_META[idx], "source": "FAISS"})
+                 seen.add(idx)
+                 if len(results) >= top_k:
+                     break
+
+             # BM25 fills any remaining slots with sparse hits
+             if len(results) < top_k:
+                 for doc in self.bm25.invoke(queries[qi])[:bm25_k]:
+                     ln = doc.metadata.get("line_no")
+                     if ln in seen:
+                         continue
+                     results.append({"score": 0.0, "text": LINE_TO_TEXT.get(ln, ""),
+                                     "metadata": LINE_TO_META.get(ln, doc.metadata), "source": "BM25"})
+                     seen.add(ln)
+                     if len(results) >= top_k:
+                         break
+             batch_results.append(results)
+         return batch_results
+
+ # Module-level singleton so importers share one loaded index
+ retriever = HybridRetriever()
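
A minimal usage sketch for the hybrid retriever above (not part of the commit), assuming the `context/` artifacts exist and each JSONL record has `"text"` and `"metadata"` keys; the query strings are made-up examples:

```python
from context_retreiver import retriever

queries = ["Who holds Winterfell?", "What oath does the Night's Watch swear?"]
batch = retriever.batch_retrieve(queries, top_k=3)

for q, results in zip(queries, batch):
    print(q)
    for r in results:
        # "source" marks whether the chunk came from FAISS (dense) or BM25 (sparse)
        print(f"  [{r['source']}] score={r['score']:.3f} meta={r['metadata']}")
```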
full_rag.zip ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e9c92278e3df812534acaa211928b76a888453c81cfbe6b70bdea5d5cb330c61
+ size 1597083267
prompter.py ADDED
@@ -0,0 +1,170 @@
+ #!/usr/bin/env python3
+ """
+ prompter.py
+
+ - Uses batch retrieval for Context, QA, and Relationships
+ - Saves in batches with checkpointing
+ - Pads contexts and QA to fixed sizes
+ - Appends metadata at the end
+ """
+
+ import json, torch
+ from pathlib import Path
+ from tqdm import tqdm
+ from sentence_transformers import SentenceTransformer
+ from concurrent.futures import ThreadPoolExecutor
+
+ from context_retreiver import retriever as context_retriever
+ from qa_retreiver import search_topk as qa_retreiver
+ from relationships_retreiver import batch_relationships
+
+ QA_FILE = Path("got_all_qa_final.json")
+ OUT_DIR = Path("prompts_out")
+ CHECKPOINT_FILE = OUT_DIR / "checkpoint.json"
+ SAVE_BATCH_SIZE = 512   # prompts per output JSON file
+ EMBED_BATCH_SIZE = 32   # GPU batch size
+
+ DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
+ print(f"[INFO] Using device: {DEVICE}")
+
+ EMBED_MODEL = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", device=DEVICE)
+
+ STRUCTURAL_TOKENS = [
+     "<|CTX_QA|>", "<|/CTX_QA|>",
+     "<|CTX_REL|>", "<|/CTX_REL|>",
+     "<|INSTR|>", "<|/INSTR|>",
+     "<|QUESTION|>", "<|/QUESTION|>",
+     "<|ANSWER|>", "<|/ANSWER|>",
+     "<|QA_SIM_1|>", "<|/QA_SIM_1|>",
+     "<|QA_SIM_2|>", "<|/QA_SIM_2|>",
+     "<|QA_SIM_3|>", "<|/QA_SIM_3|>",
+     "<|QA_SIM_4|>", "<|/QA_SIM_4|>",
+     "<|QA_SIM_5|>", "<|/QA_SIM_5|>"
+ ]
+
+ def read_checkpoint():
+     if CHECKPOINT_FILE.exists():
+         try:
+             return int(json.loads(CHECKPOINT_FILE.read_text())["next_index"])
+         except Exception:
+             return 0
+     return 0
+
+ def write_checkpoint(idx):
+     OUT_DIR.mkdir(parents=True, exist_ok=True)
+     CHECKPOINT_FILE.write_text(json.dumps({"next_index": idx}))
+
+ def metadata_to_str(meta):
+     if not meta:
+         return ""
+     return "; ".join(f"{k}={v}" for k, v in meta.items() if isinstance(v, (str, int, float, bool)))
+
+ def append_metadata_at_end(answer, context1_text, context1_meta):
+     """Build the gold answer: original answer + best context chunk + its metadata."""
+     parts = []
+     if answer: parts.append(answer.strip())
+     if context1_text: parts.append(f"[Context1: {context1_text.strip()}]")
+     meta_str = metadata_to_str(context1_meta)
+     if meta_str: parts.append(f"(meta: {meta_str})")
+     return " ".join(parts)
+
+ def build_prompt(ctx_texts, rel_text, sim_qas, question):
+     parts = []
+     # rel_text arrives as a list of sentences from batch_relationships
+     if isinstance(rel_text, list):
+         rel_text = " ".join(rel_text)
+     # ctx_texts = [ctx2, ctx3]; context1 is held back for the gold answer
+     for ctx in ctx_texts:
+         if ctx: parts.append(f"<|CTX_QA|> {ctx} <|/CTX_QA|>")
+     if rel_text: parts.append(f"<|CTX_REL|> {rel_text} <|/CTX_REL|>")
+     # Always emit all five QA_SIM slots so the prompt layout stays fixed
+     for i in range(5):
+         if i < len(sim_qas):
+             qa = sim_qas[i]
+             parts.append(f"<|QA_SIM_{i+1}|> Q: {qa['question']} A: {qa['answer']} <|/QA_SIM_{i+1}|>")
+         else:
+             parts.append(f"<|QA_SIM_{i+1}|> <|/QA_SIM_{i+1}|>")
+     parts.append("<|INSTR|> Use the contexts above to answer concisely. <|/INSTR|>")
+     parts.append(f"<|QUESTION|> {question} <|/QUESTION|>")
+     parts.append("<|ANSWER|>")
+     return "\n\n".join(parts)
+
+ def retrieve_contexts(questions, top_k=3):
+     """Batch-retrieve context texts + metadata, padded to exactly top_k per question."""
+     batch_res = context_retriever.batch_retrieve(questions, top_k=top_k)
+     contexts = []
+     for res_list in batch_res:
+         ctx_texts = [r["text"] for r in res_list[:top_k]]
+         ctx_metas = [r["metadata"] for r in res_list[:top_k]]
+         # pad to top_k so callers can always unpack three contexts
+         while len(ctx_texts) < top_k:
+             ctx_texts.append("")
+             ctx_metas.append({})
+         contexts.append((ctx_texts, ctx_metas))
+     return contexts
+
+ def retrieve_qas_and_rels(questions, max_workers=20):
+     """Threaded retrieval of similar QA pairs and relationship sentences."""
+     with ThreadPoolExecutor(max_workers=max_workers) as ex:
+         sim_qas_list = list(ex.map(lambda q: qa_retreiver([q], k=5), questions))
+         rel_list = list(ex.map(lambda q: batch_relationships([q], top_k=1)[0], questions))
+     return sim_qas_list, rel_list
+
+ def main():
+     OUT_DIR.mkdir(parents=True, exist_ok=True)
+     with open(QA_FILE, 'r', encoding='utf-8') as f:
+         qas = json.load(f)
+     total = len(qas)
+     start_idx = read_checkpoint()
+     if start_idx >= total:
+         print("[INFO] Checkpoint beyond dataset length.")
+         return
+
+     prompts_accum = []
+     batch_count = start_idx // SAVE_BATCH_SIZE
+
+     for batch_start in tqdm(range(start_idx, total, EMBED_BATCH_SIZE)):
+         batch_end = min(batch_start + EMBED_BATCH_SIZE, total)
+         batch_items = qas[batch_start:batch_end]
+         questions = [it.get("question") or it.get("q") or it.get("Question") for it in batch_items]
+         orig_answers = [it.get("answer") or it.get("a") or it.get("Answer", "") for it in batch_items]
+
+         # --- retrieve contexts ---
+         contexts = retrieve_contexts(questions, top_k=3)
+         # --- QA & relationships ---
+         sim_qas_list, rel_list = retrieve_qas_and_rels(questions)
+
+         for i, q in enumerate(questions):
+             if not q:
+                 continue  # skip malformed items; the end-of-batch checkpoint covers them
+             ctx_texts, ctx_metas = contexts[i]
+             context1, context2, context3 = ctx_texts
+             meta1 = ctx_metas[0]
+             prompt_text = build_prompt([context2, context3], rel_list[i], sim_qas_list[i], q)
+             gold = append_metadata_at_end(orig_answers[i], context1, meta1)
+
+             obj = {
+                 "id": batch_start + i,
+                 "question": q,
+                 "prompt": prompt_text,
+                 "gold_answer": gold,
+                 "context1": context1,
+                 "retrieved_qas": sim_qas_list[i],
+                 "relation_text": rel_list[i]
+             }
+             prompts_accum.append(obj)
+
+             # --- Save batch ---
+             if len(prompts_accum) >= SAVE_BATCH_SIZE:
+                 out_path = OUT_DIR / f"prompts_batch_{batch_count:03d}.json"
+                 out_path.write_text(json.dumps(prompts_accum, ensure_ascii=False, indent=2), encoding='utf-8')
+                 batch_count += 1
+                 prompts_accum = []
+
+         # NOTE: the checkpoint can run ahead of prompts still buffered in
+         # prompts_accum; a crash between file saves loses that buffer on resume.
+         write_checkpoint(batch_end)
+
+     # save remaining
+     if prompts_accum:
+         out_path = OUT_DIR / f"prompts_batch_{batch_count:03d}.json"
+         out_path.write_text(json.dumps(prompts_accum, ensure_ascii=False, indent=2), encoding='utf-8')
+
+     OUT_DIR.joinpath("special_tokens_used.txt").write_text("\n".join(STRUCTURAL_TOKENS))
+     print("[DONE] All prompts processed.")
+
+ if __name__ == "__main__":
+     main()
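
An illustrative call to `build_prompt` with made-up inputs, showing the fixed-slot prompt layout (note that importing `prompter` loads all three retrievers and the embedding model at import time):

```python
from prompter import build_prompt

prompt = build_prompt(
    ctx_texts=["Second-best context chunk.", "Third-best context chunk."],
    rel_text=["Jon Snow is the bastard son of Eddard Stark."],
    sim_qas=[{"question": "Who is Jon's father?", "answer": "Officially, Eddard Stark."}],
    question="Who raised Jon Snow?",
)
print(prompt)
# <|CTX_QA|> Second-best context chunk. <|/CTX_QA|>
# ...
# <|QA_SIM_1|> Q: Who is Jon's father? A: Officially, Eddard Stark. <|/QA_SIM_1|>
# <|QA_SIM_2|> <|/QA_SIM_2|>   <- empty slots are still emitted, keeping the layout fixed
# ...
# <|QUESTION|> Who raised Jon Snow? <|/QUESTION|>
# <|ANSWER|>
```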
qa_retreiver.py ADDED
@@ -0,0 +1,68 @@
+ # qa_retreiver.py
+ import os, pickle
+ import faiss
+ from sentence_transformers import SentenceTransformer
+ from typing import List, Dict, Any, Optional, Union
+
+ MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
+ CLEAN_JSON = "qa_pairs/asoiaf_qa_clean.json"
+ INDEX_FILE = "qa_pairs/faiss_index.index"
+ QA_DATA_FILE = "qa_pairs/qa_data.pkl"
+
+ EMBED_MODEL: Optional[SentenceTransformer] = None
+ INDEX = None
+ QA_PAIRS: List[Dict[str, Any]] = []
+
+ def _load_embed_model():
+     global EMBED_MODEL
+     if EMBED_MODEL is None:
+         EMBED_MODEL = SentenceTransformer(MODEL_NAME,
+             device="cuda" if os.environ.get("CUDA_VISIBLE_DEVICES") else "cpu")
+     return EMBED_MODEL
+
+ def build_or_load_index():
+     """Lazily load the FAISS index, QA pairs, and embedding model (cached in module globals)."""
+     global INDEX, QA_PAIRS
+     if INDEX and QA_PAIRS:
+         return INDEX, QA_PAIRS, EMBED_MODEL
+     INDEX = faiss.read_index(INDEX_FILE)
+     with open(QA_DATA_FILE, "rb") as f:
+         QA_PAIRS = pickle.load(f)
+     _load_embed_model()
+     return INDEX, QA_PAIRS, EMBED_MODEL
+
+ def search_topk(query: Union[str, List[str]], index=None, qa_pairs=None, model=None, k: int = 5):
+     """
+     Returns up to `k` similar Q&A entries per query as a list of dicts.
+     Accepts a single query string or a list of queries; returns one result
+     list for a single query, a list of result lists otherwise.
+     """
+     query_list = query if isinstance(query, list) else [query]
+
+     if model is None:
+         model = _load_embed_model()
+     if index is None or qa_pairs is None:
+         index, qa_pairs, model = build_or_load_index()
+
+     q_vecs = model.encode(query_list, convert_to_numpy=True, normalize_embeddings=True,
+                           show_progress_bar=False).astype("float32")
+
+     results = []
+     for q_vec in q_vecs:
+         # Over-fetch (k*3) so deduplication by question text can still fill k slots
+         scores, indices = index.search(q_vec[None, :], k * 3)
+         seen = set()
+         q_results = []
+         for score, idx in zip(scores[0], indices[0]):
+             if len(q_results) >= k:
+                 break
+             if idx < 0 or idx >= len(qa_pairs):
+                 continue
+             q_text = qa_pairs[idx].get("question", "")
+             if q_text in seen:
+                 continue
+             seen.add(q_text)
+             # Drop the trailing "Reference:" block from stored answers
+             raw_ans = qa_pairs[idx].get("answer", "")
+             clean_ans = raw_ans.split("\n\nReference:")[0].strip()
+             q_results.append({
+                 "similarity": float(score),
+                 "question": q_text,
+                 "answer": clean_ans
+             })
+         results.append(q_results)
+
+     return results[0] if len(results) == 1 else results
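
A minimal sketch of `search_topk` usage, assuming the `qa_pairs/` index and pickle are present; the query strings are invented examples:

```python
from qa_retreiver import search_topk

# Single query string -> one list of result dicts
hits = search_topk("Who killed the Mad King?", k=3)
for h in hits:
    print(f"{h['similarity']:.3f}  Q: {h['question']}  A: {h['answer']}")

# List of queries -> list of result lists
batch_hits = search_topk(["Who is Azor Ahai?", "What is Valyrian steel?"], k=2)
```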
relationships_retreiver.py ADDED
@@ -0,0 +1,49 @@
+ # relationships_retreiver.py
+ import os, pickle, logging
+ import faiss
+ from sentence_transformers import SentenceTransformer
+
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ RELATIONS = "relations"
+ REL_INDEX = f"{RELATIONS}/got_rels.faiss"
+ REL_DATA = f"{RELATIONS}/got_rels_meta.pkl"
+
+ logger.info("Loading relationship FAISS index...")
+ rel_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2",
+                                 device="cuda" if os.environ.get("CUDA_VISIBLE_DEVICES") else "cpu")
+ rel_index = faiss.read_index(REL_INDEX)
+ with open(REL_DATA, "rb") as f:
+     rel_data = pickle.load(f)
+ name_map = rel_data["name_map"]  # uppercase name variant -> canonical display name
+
+ def batch_relationships(questions, top_k=3):
+     batch_results = []
+     for q in questions:
+         # Match known character-name variants against the question
+         # (case-insensitive, tolerant of missing spaces)
+         q_upper = q.upper()
+         candidates = []
+         for variant in name_map.keys():
+             if len(variant) < 3:
+                 continue  # skip variants too short to match reliably
+             if variant in q_upper or variant.replace(" ", "") in q_upper.replace(" ", ""):
+                 candidates.append(name_map[variant])
+         candidates = list(dict.fromkeys(candidates))[:2]  # dedupe, keep at most two names
+         if not candidates:
+             batch_results.append(["No known character relationships found"])
+             continue
+
+         query = f"Relationships of {' and '.join(candidates)} in Game of Thrones books"
+         q_vec = rel_model.encode([query], normalize_embeddings=True, show_progress_bar=False).astype("float32")
+         # Over-fetch so deduplication by character still yields top_k sentences
+         D, I = rel_index.search(q_vec, top_k * 2)
+         results = []
+         seen = set()
+         for idx in I[0]:
+             if idx == -1:
+                 continue
+             sent = rel_data["sentences"][idx]
+             char = rel_data["metadata"][idx]["display_name"]
+             if char not in seen:
+                 results.append(sent)
+                 seen.add(char)
+             if len(results) >= top_k:
+                 break
+         batch_results.append(results if results else ["No confirmed relationships found"])
+     return batch_results
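
A hypothetical call to the relationship retriever above, assuming the `relations/` FAISS index and metadata pickle are in place; the question and output are illustrative:

```python
from relationships_retreiver import batch_relationships

rels = batch_relationships(["How are Jon Snow and Arya Stark related?"], top_k=2)
print(rels[0])
# e.g. ["Jon Snow is Arya Stark's half-brother ...", ...], or
# ["No known character relationships found"] when no name variant matches
```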