|
|
|
|
|
"""
|
|
|
generate_prompts_v8_batch_fixed.py
|
|
|
|
|
|
- Uses batch retrieval for Context, QA, and Relationships
|
|
|
- Saves in batches with checkpointing
|
|
|
- Pads contexts and QA to fixed sizes
|
|
|
- Appends metadata at the end
|
|
|
"""
|
|
|
|
|
|
import os, json, torch, numpy as np
|
|
|
from pathlib import Path
|
|
|
from tqdm import tqdm
|
|
|
from sentence_transformers import SentenceTransformer
|
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
|
|
|
|
from context_retreiver import retriever as context_retriever
|
|
|
from qa_retreiver import search_topk as qa_retreiver
|
|
|
from relationships_retreiver import batch_relationships
|
|
|
|
|
|
# ---- Pipeline configuration -------------------------------------------------
# Input QA dataset and output directory for generated prompt batches.
QA_FILE = Path("got_all_qa_final.json")
OUT_DIR = Path("prompts_out")
# Checkpoint file storing the next record index, enabling resumable runs.
CHECKPOINT_FILE = OUT_DIR / "checkpoint.json"
# Number of prompt records accumulated before flushing one JSON batch file.
SAVE_BATCH_SIZE = 512
# Number of questions sent to the retrievers per iteration.
EMBED_BATCH_SIZE = 32

# Prefer the first CUDA device when available; otherwise fall back to CPU.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
print(f"[INFO] Using device: {DEVICE}")

# NOTE(review): EMBED_MODEL is loaded here but never referenced elsewhere in
# this file — confirm whether it is needed (e.g. by a retriever module) or can
# be removed to save startup time and GPU memory.
EMBED_MODEL = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", device=DEVICE)

# Special structural tokens that delimit each section of the generated prompt.
# Written to special_tokens_used.txt at the end of a run so the tokenizer side
# can register them.
STRUCTURAL_TOKENS = [
    "<|CTX_QA|>", "<|/CTX_QA|>",
    "<|CTX_REL|>", "<|/CTX_REL|>",
    "<|INSTR|>", "<|/INSTR|>",
    "<|QUESTION|>", "<|/QUESTION|>",
    "<|ANSWER|>", "<|/ANSWER|>",
    "<|QA_SIM_1|>", "<|/QA_SIM_1|>",
    "<|QA_SIM_2|>", "<|/QA_SIM_2|>",
    "<|QA_SIM_3|>", "<|/QA_SIM_3|>",
    "<|QA_SIM_4|>", "<|/QA_SIM_4|>",
    "<|QA_SIM_5|>", "<|/QA_SIM_5|>"
]
|
|
|
|
|
|
def read_checkpoint():
    """Return the index of the next record to process (0 if starting fresh).

    Falls back to 0 when the checkpoint file is missing, unreadable, or
    malformed, so a corrupted checkpoint restarts the run instead of
    crashing it.
    """
    if not CHECKPOINT_FILE.exists():
        return 0
    try:
        return int(json.loads(CHECKPOINT_FILE.read_text())["next_index"])
    except (OSError, ValueError, KeyError, TypeError):
        # BUG FIX: was a bare `except:`, which also swallowed SystemExit and
        # KeyboardInterrupt. Catch only the expected failure modes:
        # read errors (OSError), bad JSON / non-int (ValueError, incl.
        # json.JSONDecodeError), missing key (KeyError), wrong type (TypeError).
        return 0
|
|
|
|
|
|
def write_checkpoint(idx):
    """Persist *idx* as the next record index so an interrupted run can resume."""
    OUT_DIR.mkdir(parents=True, exist_ok=True)
    payload = json.dumps({"next_index": idx})
    CHECKPOINT_FILE.write_text(payload)
|
|
|
|
|
|
def metadata_to_str(meta):
    """Render a metadata dict as "k=v; k=v", keeping only scalar-valued entries.

    Returns an empty string for a falsy *meta* (None or empty dict). Values
    that are not str/int/float/bool (lists, dicts, ...) are dropped.
    """
    if not meta:
        return ""
    scalar_types = (str, int, float, bool)
    pairs = (f"{key}={value}" for key, value in meta.items()
             if isinstance(value, scalar_types))
    return "; ".join(pairs)
|
|
|
|
|
|
def append_metadata_at_end(answer, context1_text, context1_meta):
    """Compose the gold-answer string: answer text, then context #1, then its metadata.

    Empty/falsy pieces are omitted; the surviving pieces are space-joined.
    """
    meta_str = metadata_to_str(context1_meta)
    pieces = [
        answer.strip() if answer else "",
        f"[Context1: {context1_text.strip()}]" if context1_text else "",
        f"(meta: {meta_str})" if meta_str else "",
    ]
    return " ".join(piece for piece in pieces if piece)
|
|
|
|
|
|
def build_prompt(ctx_texts, rel_text, sim_qas, question):
    """Assemble the structured prompt string from its tagged sections.

    Sections, in order: one <|CTX_QA|> block per non-empty context, an
    optional <|CTX_REL|> block, exactly five <|QA_SIM_i|> slots (empty slots
    are padded so the layout is fixed), the instruction, the question, and a
    trailing <|ANSWER|> marker. Sections are separated by blank lines.
    """
    segments = [f"<|CTX_QA|> {text} <|/CTX_QA|>" for text in ctx_texts if text]

    if rel_text:
        segments.append(f"<|CTX_REL|> {rel_text} <|/CTX_REL|>")

    # Always emit five QA_SIM slots; pad missing ones with empty tag pairs.
    for slot in range(1, 6):
        open_tag, close_tag = f"<|QA_SIM_{slot}|>", f"<|/QA_SIM_{slot}|>"
        if slot <= len(sim_qas):
            pair = sim_qas[slot - 1]
            segments.append(f"{open_tag} Q: {pair['question']} A: {pair['answer']} {close_tag}")
        else:
            segments.append(f"{open_tag} {close_tag}")

    segments.append("<|INSTR|> Use above contexts to answer concisely. <|/INSTR|>")
    segments.append(f"<|QUESTION|> {question} <|/QUESTION|>")
    segments.append("<|ANSWER|>")
    return "\n\n".join(segments)
|
|
|
|
|
|
def retrieve_contexts(questions, top_k=3):
    """Batch-retrieve up to *top_k* context passages (text + metadata) per question.

    Each entry in the returned list is a (texts, metas) tuple padded to
    exactly *top_k* elements with empty strings / fresh empty dicts, so the
    caller can unpack a fixed shape.
    """
    batch_results = context_retriever.batch_retrieve(questions, top_k=top_k)
    padded = []
    for hits in batch_results:
        texts = [hit["text"] for hit in hits[:top_k]]
        metas = [hit["metadata"] for hit in hits[:top_k]]
        shortfall = top_k - len(texts)
        texts.extend("" for _ in range(shortfall))
        # Fresh dict per slot — a shared dict would alias mutations.
        metas.extend({} for _ in range(shortfall))
        padded.append((texts, metas))
    return padded
|
|
|
|
|
|
def retrieve_qas_and_rels(questions, max_workers=20):
    """Fetch top-5 similar QA pairs and the top relationship text for each question.

    Both retrievals fan out over a thread pool (I/O-bound retriever calls),
    preserving input order. Returns (sim_qas_list, rel_list), one entry per
    question.
    """
    def fetch_similar(question):
        return qa_retreiver([question], k=5)

    def fetch_relation(question):
        return batch_relationships([question], top_k=1)[0]

    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        sim_qas_list = list(pool.map(fetch_similar, questions))
        rel_list = list(pool.map(fetch_relation, questions))
    return sim_qas_list, rel_list
|
|
|
|
|
|
def main():
    """Drive the end-to-end prompt-generation pipeline.

    Loads QA records from QA_FILE, retrieves contexts / similar QA pairs /
    relationship text in EMBED_BATCH_SIZE chunks, builds structured prompts,
    and flushes them to numbered JSON files of SAVE_BATCH_SIZE records.
    A checkpoint of the next record index is kept so interrupted runs resume.
    """
    OUT_DIR.mkdir(parents=True, exist_ok=True)
    with open(QA_FILE, 'r', encoding='utf-8') as f:
        qas = json.load(f)
    total = len(qas)

    start_idx = read_checkpoint()
    if start_idx >= total:
        print("[INFO] Checkpoint beyond dataset length.")
        return

    prompts_accum = []
    batch_count = start_idx // SAVE_BATCH_SIZE  # resume output-file numbering

    for batch_start in tqdm(range(start_idx, total, EMBED_BATCH_SIZE)):
        batch_end = min(batch_start + EMBED_BATCH_SIZE, total)
        batch_items = qas[batch_start:batch_end]
        # Tolerate several key spellings in the input records.
        questions = [it.get("question") or it.get("q") or it.get("Question") for it in batch_items]
        orig_answers = [it.get("answer") or it.get("a") or it.get("Answer", "") for it in batch_items]

        contexts = retrieve_contexts(questions, top_k=3)
        sim_qas_list, rel_list = retrieve_qas_and_rels(questions)

        for i, q in enumerate(questions):
            if not q:
                # No usable question: skip the record but advance the
                # checkpoint so it is not reprocessed on resume.
                write_checkpoint(batch_start + i + 1)
                continue

            ctx_texts, ctx_metas = contexts[i]
            # Context #1 is held out of the prompt and folded into the gold
            # answer; contexts #2 and #3 go into the prompt body.
            context1, context2, context3 = ctx_texts
            meta1 = ctx_metas[0]
            prompt_text = build_prompt([context2, context3], rel_list[i], sim_qas_list[i], q)
            gold = append_metadata_at_end(orig_answers[i], context1, meta1)

            prompts_accum.append({
                "id": batch_start + i,
                "question": q,
                "prompt": prompt_text,
                "gold_answer": gold,
                "context1": context1,
                "retrieved_qas": sim_qas_list[i],
                "relation_text": rel_list[i],
            })

            if len(prompts_accum) >= SAVE_BATCH_SIZE:
                out_path = OUT_DIR / f"prompts_batch_{batch_count:03d}.json"
                out_path.write_text(json.dumps(prompts_accum, ensure_ascii=False, indent=2), encoding='utf-8')
                batch_count += 1
                prompts_accum = []

        # BUG FIX: the original wrote write_checkpoint(batch_start + i + 1)
        # here, relying on the inner loop variable `i` leaking out of its
        # loop; batch_end is the same value, stated explicitly.
        # NOTE(review): the checkpoint advances past records still buffered
        # in prompts_accum — a crash before the next flush loses them on
        # resume. Confirm whether the checkpoint should instead track the
        # last *flushed* record.
        write_checkpoint(batch_end)

    if prompts_accum:
        out_path = OUT_DIR / f"prompts_batch_{batch_count:03d}.json"
        # BUG FIX: encoding='utf-8' was omitted on this final write while
        # ensure_ascii=False emits raw non-ASCII characters, so the write
        # could fail or vary with the platform's locale encoding.
        out_path.write_text(json.dumps(prompts_accum, ensure_ascii=False, indent=2), encoding='utf-8')

    OUT_DIR.joinpath("special_tokens_used.txt").write_text("\n".join(STRUCTURAL_TOKENS))
    print("[DONE] All prompts processed.")
|
|
|
|
|
|
# Script entry point: run the pipeline only when executed directly.
if __name__ == "__main__":
    main()
|
|
|
|