Upload 5 files
Browse files- context_retreiver.py +75 -0
- full_rag.zip +3 -0
- prompter.py +170 -0
- qa_retreiver.py +68 -0
- relationships_retreiver.py +49 -0
context_retreiver.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# context_retriever.py
|
| 2 |
+
import os, re, json, pickle, logging, numpy as np, faiss
|
| 3 |
+
from tqdm.notebook import tqdm
|
| 4 |
+
from sentence_transformers import SentenceTransformer
|
| 5 |
+
from langchain_community.retrievers import BM25Retriever
|
| 6 |
+
from langchain.docstore.document import Document
|
| 7 |
+
|
| 8 |
+
# Module-level setup: configure logging, locate artifacts, and eagerly load
# the full RAG corpus into memory at import time.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Artifact locations under the "context" working directory.
WORK = "context"
JSONL = f"{WORK}/rag_documents.jsonl"          # one JSON document per line
FAISS_INDEX = f"{WORK}/faiss_ivf.index"        # prebuilt dense (IVF) index
BM25_PICKLE = f"{WORK}/bm25_retriever.pkl"     # cached sparse retriever

logger.info("Loading all RAG documents...")
# Each JSONL record is expected to carry "text" and "metadata" keys
# (both are indexed below and by HybridRetriever).
with open(JSONL, encoding='utf-8') as f:
    ALL_DOCS = [json.loads(line) for line in f]

# Row-number -> payload lookups; FAISS row ids are assumed to align with
# the JSONL line order — TODO confirm against the index build script.
LINE_TO_TEXT = {i: doc["text"] for i, doc in enumerate(ALL_DOCS)}
LINE_TO_META = {i: doc["metadata"] for i, doc in enumerate(ALL_DOCS)}
|
| 22 |
+
|
| 23 |
+
class HybridRetriever:
    """Hybrid dense (FAISS) + sparse (BM25) retriever over the RAG corpus.

    FAISS provides the primary semantically ranked hits; BM25 tops up the
    result list only when FAISS returns fewer than ``top_k`` usable rows.
    """

    def __init__(self):
        # Dense index (CPU FAISS).
        self.faiss_index = faiss.read_index(FAISS_INDEX)
        logger.info(f"FAISS loaded ({self.faiss_index.ntotal:,} vectors)")

        # Query encoder — GPU only when CUDA_VISIBLE_DEVICES is set.
        self.model = SentenceTransformer(
            "sentence-transformers/all-MiniLM-L6-v2",
            device="cuda" if os.environ.get("CUDA_VISIBLE_DEVICES") else "cpu")

        # Sparse retriever: load the cached pickle when present, otherwise
        # build it from ALL_DOCS (stripping the Filename/FullPath header so
        # it does not dominate keyword matching) and cache it for next time.
        if os.path.exists(BM25_PICKLE):
            # Fix: use a context manager — pickle.load(open(...)) leaked the
            # file handle.
            with open(BM25_PICKLE, "rb") as fh:
                self.bm25 = pickle.load(fh)
            logger.info("BM25 loaded")
        else:
            logger.info("Building BM25...")
            docs = [Document(page_content=re.sub(r"^Filename:.*\nFullPath:.*\n\n", "",
                                                 doc["text"], flags=re.M),
                             metadata=doc["metadata"]) for doc in ALL_DOCS]
            self.bm25 = BM25Retriever.from_documents(docs)
            self.bm25.k = 30
            # Fix: close the handle after writing the cache.
            with open(BM25_PICKLE, "wb") as fh:
                pickle.dump(self.bm25, fh)
            logger.info("BM25 built and saved")

    def batch_retrieve(self, queries, top_k=3, faiss_k=10, bm25_k=3):
        """Retrieve up to ``top_k`` context documents for each query.

        Args:
            queries: list of query strings (encoded as one batch).
            top_k: max results returned per query.
            faiss_k: neighbours requested from FAISS (over-fetch for dedup).
            bm25_k: max BM25 hits considered when FAISS under-fills.

        Returns:
            list (one entry per query) of dicts with keys
            ``score``/``text``/``metadata``/``source``.
        """
        qvecs = self.model.encode(queries, show_progress_bar=False,
                                  normalize_embeddings=True).astype("float32")
        D, I = self.faiss_index.search(qvecs, faiss_k)

        batch_results = []
        for qi, (scores, indices) in enumerate(zip(D, I)):
            results = []
            seen = set()
            for score, idx in zip(scores, indices):
                # -1 is FAISS padding when fewer than faiss_k neighbours exist.
                if idx == -1 or idx in seen:
                    continue
                results.append({"score": float(score), "text": LINE_TO_TEXT[idx],
                                "metadata": LINE_TO_META[idx], "source": "FAISS"})
                seen.add(idx)
                if len(results) >= top_k:
                    break

            # BM25 supplement — only fires when FAISS under-filled top_k.
            bm25_docs = self.bm25.invoke(queries[qi])
            for doc in bm25_docs[:bm25_k]:
                ln = doc.metadata.get("line_no")
                # Fix: also skip docs with no line_no — the original appended
                # an empty-text result the first time ln was None.
                if ln is None or ln in seen:
                    continue
                results.append({"score": 0.0, "text": LINE_TO_TEXT.get(ln, ""),
                                "metadata": LINE_TO_META.get(ln, doc.metadata),
                                "source": "BM25"})
                seen.add(ln)
                if len(results) >= top_k:
                    break
            batch_results.append(results)
        return batch_results


# Module-level singleton shared by importers (e.g. prompter.py).
retriever = HybridRetriever()
|
full_rag.zip
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e9c92278e3df812534acaa211928b76a888453c81cfbe6b70bdea5d5cb330c61
|
| 3 |
+
size 1597083267
|
prompter.py
ADDED
|
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
generate_prompts_v8_batch_fixed.py
|
| 4 |
+
|
| 5 |
+
- Uses batch retrieval for Context, QA, and Relationships
|
| 6 |
+
- Saves in batches with checkpointing
|
| 7 |
+
- Pads contexts and QA to fixed sizes
|
| 8 |
+
- Appends metadata at the end
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import os, json, torch, numpy as np
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
from sentence_transformers import SentenceTransformer
|
| 15 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 16 |
+
|
| 17 |
+
from context_retreiver import retriever as context_retriever
|
| 18 |
+
from qa_retreiver import search_topk as qa_retreiver
|
| 19 |
+
from relationships_retreiver import batch_relationships
|
| 20 |
+
|
| 21 |
+
# Input/output locations and batching knobs.
QA_FILE = Path("got_all_qa_final.json")          # source QA dataset
OUT_DIR = Path("prompts_out")                    # batch JSON files land here
CHECKPOINT_FILE = OUT_DIR / "checkpoint.json"    # resume marker ({"next_index": N})
SAVE_BATCH_SIZE = 512                            # prompts per output file
EMBED_BATCH_SIZE = 32  # GPU batch size (also the retrieval batch size)

DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
print(f"[INFO] Using device: {DEVICE}")

# NOTE(review): EMBED_MODEL is not referenced anywhere else in this file
# (retrievers load their own encoders) — confirm it is still needed.
EMBED_MODEL = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", device=DEVICE)

# Structural tokens delimiting each prompt section; the full list is dumped
# to special_tokens_used.txt at the end of main() for tokenizer setup.
STRUCTURAL_TOKENS = [
    "<|CTX_QA|>", "<|/CTX_QA|>",
    "<|CTX_REL|>", "<|/CTX_REL|>",
    "<|INSTR|>", "<|/INSTR|>",
    "<|QUESTION|>", "<|/QUESTION|>",
    "<|ANSWER|>", "<|/ANSWER|>",
    "<|QA_SIM_1|>", "<|/QA_SIM_1|>",
    "<|QA_SIM_2|>", "<|/QA_SIM_2|>",
    "<|QA_SIM_3|>", "<|/QA_SIM_3|>",
    "<|QA_SIM_4|>", "<|/QA_SIM_4|>",
    "<|QA_SIM_5|>", "<|/QA_SIM_5|>"
]
|
| 44 |
+
|
| 45 |
+
def read_checkpoint():
    """Return the next dataset index to process, or 0 if no valid checkpoint exists."""
    if CHECKPOINT_FILE.exists():
        try:
            return int(json.loads(CHECKPOINT_FILE.read_text())["next_index"])
        except (ValueError, KeyError, TypeError, OSError):
            # Corrupt or unreadable checkpoint: restart from the beginning.
            # Fix: was a bare `except:` which also swallowed KeyboardInterrupt
            # and SystemExit; narrowed to what the body can actually raise
            # (JSONDecodeError is a ValueError subclass).
            return 0
    return 0
|
| 52 |
+
|
| 53 |
+
def write_checkpoint(idx):
    """Persist *idx* as the next dataset index to resume from."""
    # Ensure the output directory exists before writing the marker file.
    OUT_DIR.mkdir(parents=True, exist_ok=True)
    payload = json.dumps({"next_index": idx})
    CHECKPOINT_FILE.write_text(payload)
|
| 56 |
+
|
| 57 |
+
def metadata_to_str(meta):
    """Render the scalar entries of *meta* as a '; '-separated 'k=v' string."""
    if not meta:
        return ""
    pieces = []
    for key, value in meta.items():
        # Only scalar values are rendered; nested structures are dropped.
        if isinstance(value, (str, int, float, bool)):
            pieces.append(f"{key}={value}")
    return "; ".join(pieces)

def append_metadata_at_end(answer, context1_text, context1_meta):
    """Compose the gold answer: answer text, then Context1, then its metadata."""
    segments = []
    if answer:
        segments.append(answer.strip())
    if context1_text:
        segments.append(f"[Context1: {context1_text.strip()}]")
    rendered = metadata_to_str(context1_meta)
    if rendered:
        segments.append(f"(meta: {rendered})")
    return " ".join(segments)
|
| 68 |
+
|
| 69 |
+
def build_prompt(ctx_texts, rel_text, sim_qas, question):
    """Assemble the structured training prompt from its sections.

    ctx_texts holds the prompt-side contexts (context2/context3 — context1
    is reserved for the gold answer by the caller). Exactly five QA_SIM
    slots are always emitted; unused slots are left empty so the prompt
    layout stays fixed.
    """
    sections = [f"<|CTX_QA|> {ctx} <|/CTX_QA|>" for ctx in ctx_texts if ctx]
    if rel_text:
        sections.append(f"<|CTX_REL|> {rel_text} <|/CTX_REL|>")
    for slot in range(1, 6):
        if slot <= len(sim_qas):
            pair = sim_qas[slot - 1]
            sections.append(
                f"<|QA_SIM_{slot}|> Q: {pair['question']} A: {pair['answer']} <|/QA_SIM_{slot}|>")
        else:
            sections.append(f"<|QA_SIM_{slot}|> <|/QA_SIM_{slot}|>")
    sections.append("<|INSTR|> Use above contexts to answer concisely. <|/INSTR|>")
    sections.append(f"<|QUESTION|> {question} <|/QUESTION|>")
    sections.append("<|ANSWER|>")
    return "\n\n".join(sections)
|
| 85 |
+
|
| 86 |
+
def retrieve_contexts(questions, top_k=3):
    """Batch-retrieve context texts + metadata, padded to exactly top_k entries each."""
    batch_res = context_retriever.batch_retrieve(questions, top_k=top_k)
    contexts = []
    for hits in batch_res:
        texts = [hit["text"] for hit in hits[:top_k]]
        metas = [hit["metadata"] for hit in hits[:top_k]]
        # Pad short result lists so callers can always unpack top_k items.
        shortfall = top_k - len(texts)
        texts.extend([""] * shortfall)
        metas.extend({} for _ in range(shortfall))
        contexts.append((texts, metas))
    return contexts
|
| 97 |
+
|
| 98 |
+
def retrieve_qas_and_rels(questions, max_workers=20):
    """Fetch similar QA pairs and relationship snippets for each question, threaded."""
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # QA lookups first, then relationships — each fanned out per question.
        sim_qas_list = list(pool.map(lambda q: qa_retreiver([q], k=5), questions))
        rel_list = list(pool.map(lambda q: batch_relationships([q], top_k=1)[0], questions))
    return sim_qas_list, rel_list
|
| 106 |
+
|
| 107 |
+
def main():
    """Generate a structured prompt for every QA item, saving in batches with checkpointing."""
    OUT_DIR.mkdir(parents=True, exist_ok=True)
    with open(QA_FILE, 'r', encoding='utf-8') as f:
        qas = json.load(f)
    total = len(qas)
    start_idx = read_checkpoint()
    if start_idx >= total:
        print("[INFO] Checkpoint beyond dataset length.")
        return

    prompts_accum = []
    batch_count = start_idx // SAVE_BATCH_SIZE

    for batch_start in tqdm(range(start_idx, total, EMBED_BATCH_SIZE)):
        batch_end = min(batch_start + EMBED_BATCH_SIZE, total)
        batch_items = qas[batch_start:batch_end]
        # Tolerate the three key spellings observed in the source data.
        questions = [it.get("question") or it.get("q") or it.get("Question") for it in batch_items]
        orig_answers = [it.get("answer") or it.get("a") or it.get("Answer", "") for it in batch_items]

        # --- batched retrieval: contexts, then QA pairs + relationships ---
        contexts = retrieve_contexts(questions, top_k=3)
        sim_qas_list, rel_list = retrieve_qas_and_rels(questions)

        for i, q in enumerate(questions):
            if not q:
                # Skip items with no recognizable question key.
                write_checkpoint(batch_start + i + 1)
                continue
            ctx_texts, ctx_metas = contexts[i]
            # context1 is held out of the prompt and appended to the gold
            # answer instead; context2/context3 go into the prompt body.
            context1, context2, context3 = ctx_texts
            meta1 = ctx_metas[0]
            prompt_text = build_prompt([context2, context3], rel_list[i], sim_qas_list[i], q)
            gold = append_metadata_at_end(orig_answers[i], context1, meta1)

            prompts_accum.append({
                "id": batch_start + i,
                "question": q,
                "prompt": prompt_text,
                "gold_answer": gold,
                "context1": context1,
                "retrieved_qas": sim_qas_list[i],
                "relation_text": rel_list[i],
            })

        # --- save a full batch file once enough prompts have accumulated ---
        if len(prompts_accum) >= SAVE_BATCH_SIZE:
            out_path = OUT_DIR / f"prompts_batch_{batch_count:03d}.json"
            out_path.write_text(json.dumps(prompts_accum, ensure_ascii=False, indent=2),
                                encoding='utf-8')
            batch_count += 1
            prompts_accum = []

        # NOTE(review): the checkpoint advances past items still sitting in
        # prompts_accum; a crash before the next save loses those prompts on
        # resume — confirm this trade-off is acceptable.
        write_checkpoint(batch_start + i + 1)

    # Save any remainder smaller than SAVE_BATCH_SIZE.
    if prompts_accum:
        out_path = OUT_DIR / f"prompts_batch_{batch_count:03d}.json"
        # Fix: write with explicit UTF-8 like the in-loop save — the platform
        # default encoding can fail on non-ASCII prompt text.
        out_path.write_text(json.dumps(prompts_accum, ensure_ascii=False, indent=2),
                            encoding='utf-8')

    OUT_DIR.joinpath("special_tokens_used.txt").write_text("\n".join(STRUCTURAL_TOKENS),
                                                           encoding='utf-8')
    print("[DONE] All prompts processed.")

if __name__=="__main__":
    main()
|
qa_retreiver.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# qa_retriever.py
|
| 2 |
+
import os, pickle, faiss
|
| 3 |
+
from sentence_transformers import SentenceTransformer
|
| 4 |
+
from typing import List, Dict, Any, Optional
|
| 5 |
+
|
| 6 |
+
# Artifact locations for the QA similarity index.
MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
# NOTE(review): CLEAN_JSON is not referenced in this module's visible code —
# confirm it is still needed.
CLEAN_JSON = "qa_pairs/asoiaf_qa_clean.json"
INDEX_FILE = "qa_pairs/faiss_index.index"      # prebuilt FAISS index over questions
QA_DATA_FILE = "qa_pairs/qa_data.pkl"          # pickled list of {"question", "answer"} dicts

# Lazily-populated module-level singletons (see _load_embed_model /
# build_or_load_index).
EMBED_MODEL: Optional[SentenceTransformer] = None
INDEX = None  # faiss index once loaded
QA_PAIRS: List[Dict[str, Any]] = []
|
| 14 |
+
|
| 15 |
+
def _load_embed_model():
    """Lazily create the shared SentenceTransformer, caching it in EMBED_MODEL."""
    global EMBED_MODEL
    if EMBED_MODEL is None:
        # GPU is used only when CUDA_VISIBLE_DEVICES is set in the environment.
        target_device = "cuda" if os.environ.get("CUDA_VISIBLE_DEVICES") else "cpu"
        EMBED_MODEL = SentenceTransformer(MODEL_NAME, device=target_device)
    return EMBED_MODEL
|
| 21 |
+
|
| 22 |
+
def build_or_load_index():
    """Load (once) and return the FAISS index, QA pairs, and embedding model.

    Returns:
        (index, qa_pairs, model) tuple; all three are cached module globals.
    """
    global INDEX, QA_PAIRS
    # Fast path: already loaded. Fixes vs. original: `is not None` instead of
    # relying on the truthiness of a faiss index object, and the model is
    # re-resolved via _load_embed_model() so this path can never return a
    # stale None for EMBED_MODEL.
    if INDEX is not None and QA_PAIRS:
        return INDEX, QA_PAIRS, _load_embed_model()

    INDEX = faiss.read_index(INDEX_FILE)
    with open(QA_DATA_FILE, "rb") as f:
        QA_PAIRS = pickle.load(f)
    return INDEX, QA_PAIRS, _load_embed_model()
|
| 31 |
+
def search_topk(query: str, index=None, qa_pairs=None, model=None, k: int = 5):
    """Return up to `k` similar Q&A entries per query.

    Accepts a single query string or a list of queries; a single query
    yields one list of result dicts, a list of queries yields a list of
    lists. Any of index/qa_pairs/model may be injected; missing pieces are
    loaded via the module-level cache.
    """
    queries = query if isinstance(query, list) else [query]

    if model is None:
        model = _load_embed_model()
    if index is None or qa_pairs is None:
        index, qa_pairs, model = build_or_load_index()

    vecs = model.encode(queries, convert_to_numpy=True,
                        normalize_embeddings=True,
                        show_progress_bar=False).astype("float32")

    all_results = []
    for vec in vecs:
        # Over-fetch (k*3 neighbours) so duplicate questions and invalid
        # ids can be skipped while still filling k slots.
        scores, indices = index.search(vec[None, :], k * 3)
        picked = []
        seen_questions = set()
        for score, idx in zip(scores[0], indices[0]):
            if len(picked) >= k:
                break
            if idx < 0 or idx >= len(qa_pairs):
                continue
            entry = qa_pairs[idx]
            q_text = entry.get("question", "")
            if q_text in seen_questions:
                continue
            seen_questions.add(q_text)
            # Strip the trailing "Reference:" footer from stored answers.
            answer = entry.get("answer", "").split("\n\nReference:")[0].strip()
            picked.append({
                "similarity": float(score),
                "question": q_text,
                "answer": answer
            })
        all_results.append(picked)

    return all_results[0] if len(all_results) == 1 else all_results
|
relationships_retreiver.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# relationship_retriever.py
|
| 2 |
+
import os, pickle, logging
|
| 3 |
+
import faiss
|
| 4 |
+
from sentence_transformers import SentenceTransformer
|
| 5 |
+
|
| 6 |
+
# Module-level setup: eagerly load the relationship FAISS index, its
# metadata pickle, and the query encoder at import time.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Artifact locations under the "relations" directory.
RELATIONS = "relations"
REL_INDEX = f"{RELATIONS}/got_rels.faiss"       # dense index over relationship sentences
REL_DATA = f"{RELATIONS}/got_rels_meta.pkl"     # pickled dict: name_map / sentences / metadata

logger.info("Loading relationship FAISS index...")
rel_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2",
                                device="cuda" if os.environ.get("CUDA_VISIBLE_DEVICES") else "cpu")
rel_index = faiss.read_index(REL_INDEX)
with open(REL_DATA, "rb") as f:
    rel_data = pickle.load(f)
# name_map maps name variants to display names; batch_relationships compares
# variants against upper-cased questions, so keys are presumably stored
# upper-cased — TODO confirm against the pickle builder.
name_map = rel_data["name_map"]
| 20 |
+
|
| 21 |
+
def batch_relationships(questions, top_k=3):
    """For each question, return up to `top_k` relationship sentences for the
    characters mentioned in it.

    Args:
        questions: iterable of natural-language question strings.
        top_k: maximum relationship sentences returned per question.

    Returns:
        list (one entry per question) of sentence lists; a single
        placeholder message is returned when no known character, or no
        confirmed relationship, is found.
    """
    batch_results = []
    for q in questions:
        q_upper = q.upper()
        # Fix: hoist the loop-invariant space-stripped question — the
        # original recomputed q_upper.replace(" ", "") for every name variant.
        q_upper_compact = q_upper.replace(" ", "")
        candidates = []
        for variant, display in name_map.items():
            if len(variant) < 3:
                continue  # very short aliases over-match
            if variant in q_upper or variant.replace(" ", "") in q_upper_compact:
                candidates.append(display)
        # De-duplicate preserving order; cap at two characters per query.
        candidates = list(dict.fromkeys(candidates))[:2]
        if not candidates:
            batch_results.append(["No known character relationships found"])
            continue

        query = f"Relationships of {' and '.join(candidates)} in Game of Thrones books"
        q_vec = rel_model.encode([query], normalize_embeddings=True,
                                 show_progress_bar=False).astype("float32")
        # Over-fetch (top_k*2) so one-sentence-per-character dedup below can
        # still fill top_k slots. Distances are unused.
        _, I = rel_index.search(q_vec, top_k * 2)
        results = []
        seen = set()
        for idx in I[0]:
            if idx == -1:
                continue  # FAISS padding for missing neighbours
            sent = rel_data["sentences"][idx]
            char = rel_data["metadata"][idx]["display_name"]
            if char not in seen:
                results.append(sent)
                seen.add(char)
            if len(results) >= top_k:
                break
        batch_results.append(results if results else ["No confirmed relationships found"])
    return batch_results
|