# got_retreivers/prompter.py — uploaded by hash-map ("Upload 5 files", commit dff5c6e, verified)
#!/usr/bin/env python3
"""
generate_prompts_v8_batch_fixed.py
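- Retrieves contexts with one batch call per chunk of questions; similar QA pairs
  and relationship text are fetched per question on a thread pool
- Saves prompts in fixed-size batch files, with checkpointing for resume
- Pads retrieved contexts to 3 entries and the similar-QA slots to 5
- Appends Context1's text and metadata to the end of each gold answer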
"""
import os, json, torch, numpy as np
from pathlib import Path
from tqdm import tqdm
from sentence_transformers import SentenceTransformer
from concurrent.futures import ThreadPoolExecutor
from context_retreiver import retriever as context_retriever
from qa_retreiver import search_topk as qa_retreiver
from relationships_retreiver import batch_relationships
# --- Input / output locations ------------------------------------------------
QA_FILE = Path("got_all_qa_final.json")
OUT_DIR = Path("prompts_out")
CHECKPOINT_FILE = OUT_DIR.joinpath("checkpoint.json")

# --- Batching ------------------------------------------------------------------
# Number of prompt records written to each prompts_batch_NNN.json file.
SAVE_BATCH_SIZE = 512
# Number of questions handed to the retrievers per loop iteration in main();
# presumably it also bounds the GPU batch inside the retrievers — confirm.
EMBED_BATCH_SIZE = 32

# --- Device and embedding model ---------------------------------------------------
if torch.cuda.is_available():
    DEVICE = "cuda:0"
else:
    DEVICE = "cpu"
print(f"[INFO] Using device: {DEVICE}")
# NOTE(review): the model is loaded at import time but EMBED_MODEL is not
# referenced by the code visible here; confirm whether another module imports
# it before making this lazy or removing it.
EMBED_MODEL = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", device=DEVICE)
# Structural special tokens that delimit prompt sections. Each opening tag
# "<|NAME|>" is listed immediately before its closing "<|/NAME|>"; main()
# writes this list to special_tokens_used.txt.
STRUCTURAL_TOKENS = [
    token
    for tag_pair in (
        ("<|CTX_QA|>", "<|/CTX_QA|>"),
        ("<|CTX_REL|>", "<|/CTX_REL|>"),
        ("<|INSTR|>", "<|/INSTR|>"),
        ("<|QUESTION|>", "<|/QUESTION|>"),
        ("<|ANSWER|>", "<|/ANSWER|>"),
        ("<|QA_SIM_1|>", "<|/QA_SIM_1|>"),
        ("<|QA_SIM_2|>", "<|/QA_SIM_2|>"),
        ("<|QA_SIM_3|>", "<|/QA_SIM_3|>"),
        ("<|QA_SIM_4|>", "<|/QA_SIM_4|>"),
        ("<|QA_SIM_5|>", "<|/QA_SIM_5|>"),
    )
    for token in tag_pair
]
def read_checkpoint(path=None):
    """Return the dataset index at which processing should resume.

    Reads ``{"next_index": N}`` from the checkpoint file and returns ``int(N)``.
    A missing, unreadable or malformed checkpoint yields 0, so a bad file
    restarts the run from the beginning instead of aborting it.

    Args:
        path: Optional checkpoint file (str or Path); defaults to CHECKPOINT_FILE.

    Returns:
        The next index to process (0 when no usable checkpoint exists).
    """
    path = CHECKPOINT_FILE if path is None else Path(path)
    if not path.exists():
        return 0  # fresh run: nothing to resume from
    try:
        return int(json.loads(path.read_text(encoding="utf-8"))["next_index"])
    except (OSError, ValueError, KeyError, TypeError) as err:
        # Narrow catch: corrupt JSON, a missing key, or a non-numeric / non-dict
        # payload fall back to 0, while KeyboardInterrupt/SystemExit still propagate.
        print(f"[WARN] Ignoring unreadable checkpoint {path}: {err!r}")
        return 0
def write_checkpoint(idx, path=None):
    """Persist ``idx`` as the index of the next item to process.

    The JSON ``{"next_index": idx}`` is first written to a sibling temporary
    file and then moved over the target. An interruption mid-write therefore
    cannot leave a truncated checkpoint behind; read_checkpoint would otherwise
    treat such a file as 0 and restart the whole run.

    Args:
        idx: Next dataset index to process on resume.
        path: Optional checkpoint file (str or Path); defaults to CHECKPOINT_FILE.
    """
    path = CHECKPOINT_FILE if path is None else Path(path)
    # For the default path this is OUT_DIR, matching the original behaviour.
    path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = path.with_name(path.name + ".tmp")
    tmp_path.write_text(json.dumps({"next_index": idx}), encoding="utf-8")
    tmp_path.replace(path)  # atomic rename over any existing checkpoint
def metadata_to_str(meta):
    """Render the scalar entries of ``meta`` as ``"k1=v1; k2=v2"``.

    Only values of type str/int/float/bool are kept; anything else (lists,
    dicts, None, ...) is dropped. A falsy ``meta`` (None, {}) yields "".
    """
    if not meta:
        return ""
    pieces = []
    for key, value in meta.items():
        if isinstance(value, (str, int, float, bool)):
            pieces.append(f"{key}={value}")
    return "; ".join(pieces)
def append_metadata_at_end(answer, context1_text, context1_meta):
    """Build the gold-answer string from the answer, Context1 text and its metadata.

    Pieces appear in this order, separated by single spaces:
    the stripped answer, ``[Context1: <stripped text>]`` and ``(meta: <k=v; ...>)``.
    A piece is omitted when its source is falsy (or, for metadata, when
    metadata_to_str renders it as "").
    """
    pieces = [
        answer.strip() if answer else None,
        f"[Context1: {context1_text.strip()}]" if context1_text else None,
    ]
    meta_str = metadata_to_str(context1_meta)
    pieces.append(f"(meta: {meta_str})" if meta_str else None)
    # Only skipped pieces are None; a whitespace-only answer still contributes ""
    # (and hence its separator), exactly as before.
    return " ".join(piece for piece in pieces if piece is not None)
def build_prompt(ctx_texts, rel_text, sim_qas, question):
    """Assemble the tagged prompt text, with sections separated by blank lines.

    Section order:
      * one <|CTX_QA|> block per non-empty entry of ``ctx_texts``
        (main passes [context2, context3]);
      * a <|CTX_REL|> block when ``rel_text`` is non-empty;
      * exactly five <|QA_SIM_n|> slots, filled from the first five items of
        ``sim_qas`` (each indexed with 'question' and 'answer') and left empty
        when fewer are given;
      * the fixed instruction, the question, and a trailing open <|ANSWER|> tag.
    """
    sections = [f"<|CTX_QA|> {ctx} <|/CTX_QA|>" for ctx in ctx_texts if ctx]
    if rel_text:
        sections.append(f"<|CTX_REL|> {rel_text} <|/CTX_REL|>")
    available = len(sim_qas)
    for i in range(5):
        if i >= available:
            # Keep the empty slot so every prompt has the same five-slot layout.
            sections.append(f"<|QA_SIM_{i+1}|> <|/QA_SIM_{i+1}|>")
            continue
        qa = sim_qas[i]
        sections.append(f"<|QA_SIM_{i+1}|> Q: {qa['question']} A: {qa['answer']} <|/QA_SIM_{i+1}|>")
    sections += [
        "<|INSTR|> Use above contexts to answer concisely. <|/INSTR|>",
        f"<|QUESTION|> {question} <|/QUESTION|>",
        "<|ANSWER|>",
    ]
    return "\n\n".join(sections)
def retrieve_contexts(questions, top_k=3):
    """Batch-retrieve context passages and their metadata for ``questions``.

    Returns one ``(texts, metas)`` tuple per result list from
    ``context_retriever.batch_retrieve``. Each tuple holds at most ``top_k`` hits
    and is right-padded to ``top_k`` entries with "" / {} placeholders.
    """
    batch_res = context_retriever.batch_retrieve(questions, top_k=top_k)
    contexts = []
    for hits in batch_res:
        kept = hits[:top_k]
        missing = top_k - len(kept)
        texts = [hit["text"] for hit in kept] + [""] * missing
        # A fresh {} per padded slot, so placeholders never share a dict.
        metas = [hit["metadata"] for hit in kept] + [{} for _ in range(missing)]
        contexts.append((texts, metas))
    return contexts
def retrieve_qas_and_rels(questions, max_workers=20):
    """Fetch similar QA pairs and relationship text for each question on a thread pool.

    The two lookups run one after the other: all similar-QA lookups complete
    before the relationship lookups start. Within each lookup, questions are
    fanned out across ``max_workers`` threads.

    Returns:
        ``(sim_qas_list, rel_list)``, each aligned with ``questions``.
    """
    def similar_qas(q):
        # NOTE(review): this result is used as-is, while the relationship result
        # below is indexed with [0]. build_prompt expects a flat list of dicts
        # with 'question'/'answer' keys — confirm search_topk returns that for a
        # single query rather than a list of per-query lists.
        return qa_retreiver([q], k=5)

    def relationship(q):
        return batch_relationships([q], top_k=1)[0]

    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        sim_qas_list = list(pool.map(similar_qas, questions))
        rel_list = list(pool.map(relationship, questions))
    return sim_qas_list, rel_list
def _write_batch_file(prompts, batch_no):
    """Serialize ``prompts`` to OUT_DIR/prompts_batch_<batch_no:03d>.json."""
    out_path = OUT_DIR / f"prompts_batch_{batch_no:03d}.json"
    # ensure_ascii=False emits raw non-ASCII text, so pin UTF-8 rather than
    # relying on the platform's default encoding.
    out_path.write_text(json.dumps(prompts, ensure_ascii=False, indent=2), encoding="utf-8")


def main():
    """Generate tagged prompts for every QA item in QA_FILE, resuming from the checkpoint.

    Items are processed EMBED_BATCH_SIZE at a time. For each item with a
    non-empty question, the script retrieves:
      * three contexts;
      * similar QA pairs;
      * a relationship text.
    Contexts 2-3, the relationship text and the QA pairs go into the prompt.
    Context 1 and its metadata are appended to the gold answer instead.

    Records are flushed to OUT_DIR in files of SAVE_BATCH_SIZE. The checkpoint
    only advances when no generated prompts are left unsaved in memory.

    NOTE: the final, partial file is written without advancing the checkpoint,
    so a re-run regenerates (and overwrites) that last file rather than skipping it.
    """
    OUT_DIR.mkdir(parents=True, exist_ok=True)
    with open(QA_FILE, "r", encoding="utf-8") as f:
        qas = json.load(f)
    total = len(qas)

    start_idx = read_checkpoint()
    if start_idx >= total:
        print("[INFO] Checkpoint beyond dataset length.")
        return

    prompts_accum = []
    # Each saved file holds exactly SAVE_BATCH_SIZE records, and the checkpoint
    # is always at or after the last saved record. So this is never less than
    # the number of complete files already written.
    batch_count = start_idx // SAVE_BATCH_SIZE

    for batch_start in tqdm(range(start_idx, total, EMBED_BATCH_SIZE)):
        batch_end = min(batch_start + EMBED_BATCH_SIZE, total)
        batch_items = qas[batch_start:batch_end]
        questions = [it.get("question") or it.get("q") or it.get("Question") for it in batch_items]
        orig_answers = [it.get("answer") or it.get("a") or it.get("Answer", "") for it in batch_items]

        # Empty/missing questions are skipped below, so don't send them
        # (as None/"") to the retrievers at all.
        valid_pos = [i for i, q in enumerate(questions) if q]
        valid_questions = [questions[i] for i in valid_pos]
        if valid_questions:
            contexts = retrieve_contexts(valid_questions, top_k=3)
            sim_qas_list, rel_list = retrieve_qas_and_rels(valid_questions)
        else:
            contexts, sim_qas_list, rel_list = [], [], []
        # Position within the batch -> row in the retrieval results.
        row_of = {pos: row for row, pos in enumerate(valid_pos)}

        for i, q in enumerate(questions):
            next_idx = batch_start + i + 1
            if not q:
                # Advance past a skipped item only when nothing generated is still
                # held in memory; otherwise a crash would lose those prompts on resume.
                if not prompts_accum:
                    write_checkpoint(next_idx)
                continue

            row = row_of[i]
            ctx_texts, ctx_metas = contexts[row]
            context1, context2, context3 = ctx_texts
            meta1 = ctx_metas[0]
            prompt_text = build_prompt([context2, context3], rel_list[row], sim_qas_list[row], q)
            gold = append_metadata_at_end(orig_answers[i], context1, meta1)
            prompts_accum.append({
                "id": batch_start + i,
                "question": q,
                "prompt": prompt_text,
                "gold_answer": gold,
                "context1": context1,
                "retrieved_qas": sim_qas_list[row],
                "relation_text": rel_list[row],
            })

            if len(prompts_accum) >= SAVE_BATCH_SIZE:
                _write_batch_file(prompts_accum, batch_count)
                batch_count += 1
                prompts_accum = []
                # Everything up to and including this item is now on disk.
                write_checkpoint(next_idx)

    # Flush the final partial batch (checkpoint intentionally not advanced; see docstring).
    if prompts_accum:
        _write_batch_file(prompts_accum, batch_count)

    OUT_DIR.joinpath("special_tokens_used.txt").write_text("\n".join(STRUCTURAL_TOKENS))
    print("[DONE] All prompts processed.")


if __name__ == "__main__":
    main()