#!/usr/bin/env python3
"""
generate_prompts_v8_batch_fixed.py
- Uses batch retrieval for Context, QA, and Relationships
- Saves in batches with checkpointing
- Pads contexts and QA to fixed sizes
- Appends metadata at the end
"""
import json
import torch
from pathlib import Path
from tqdm import tqdm
from sentence_transformers import SentenceTransformer
from concurrent.futures import ThreadPoolExecutor

from context_retreiver import retriever as context_retriever
from qa_retreiver import search_topk as qa_retreiver
from relationships_retreiver import batch_relationships

QA_FILE = Path("got_all_qa_final.json")
OUT_DIR = Path("prompts_out")
CHECKPOINT_FILE = OUT_DIR / "checkpoint.json"
SAVE_BATCH_SIZE = 512   # prompts per output file
EMBED_BATCH_SIZE = 32   # GPU batch size

DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
print(f"[INFO] Using device: {DEVICE}")
EMBED_MODEL = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", device=DEVICE)

STRUCTURAL_TOKENS = [
    "<|CTX_QA|>", "<|/CTX_QA|>",
    "<|CTX_REL|>", "<|/CTX_REL|>",
    "<|INSTR|>", "<|/INSTR|>",
    "<|QUESTION|>", "<|/QUESTION|>",
    "<|ANSWER|>", "<|/ANSWER|>",
    "<|QA_SIM_1|>", "<|/QA_SIM_1|>",
    "<|QA_SIM_2|>", "<|/QA_SIM_2|>",
    "<|QA_SIM_3|>", "<|/QA_SIM_3|>",
    "<|QA_SIM_4|>", "<|/QA_SIM_4|>",
    "<|QA_SIM_5|>", "<|/QA_SIM_5|>",
]


def read_checkpoint():
    """Return the index of the next unprocessed item (0 if no checkpoint)."""
    if CHECKPOINT_FILE.exists():
        try:
            return int(json.loads(CHECKPOINT_FILE.read_text())["next_index"])
        except Exception:
            return 0
    return 0


def write_checkpoint(idx):
    OUT_DIR.mkdir(parents=True, exist_ok=True)
    CHECKPOINT_FILE.write_text(json.dumps({"next_index": idx}))


def metadata_to_str(meta):
    if not meta:
        return ""
    return "; ".join(f"{k}={v}" for k, v in meta.items()
                     if isinstance(v, (str, int, float, bool)))


def append_metadata_at_end(answer, context1_text, context1_meta):
    """Build the gold answer: answer text, then the top context, then its metadata."""
    parts = []
    if answer:
        parts.append(answer.strip())
    if context1_text:
        parts.append(f"[Context1: {context1_text.strip()}]")
    meta_str = metadata_to_str(context1_meta)
    if meta_str:
        parts.append(f"(meta: {meta_str})")
    return " ".join(parts)


def build_prompt(ctx_texts, rel_text, sim_qas, question):
    parts = []
    # ctx_texts = [ctx2, ctx3]; context1 is held out for the gold answer
    for ctx in ctx_texts:
        if ctx:
            parts.append(f"<|CTX_QA|> {ctx} <|/CTX_QA|>")
    if rel_text:
        parts.append(f"<|CTX_REL|> {rel_text} <|/CTX_REL|>")
    # Always emit exactly five QA_SIM slots; pad missing ones with empty tags
    for i in range(5):
        if i < len(sim_qas):
            qa = sim_qas[i]
            parts.append(f"<|QA_SIM_{i+1}|> Q: {qa['question']} A: {qa['answer']} <|/QA_SIM_{i+1}|>")
        else:
            parts.append(f"<|QA_SIM_{i+1}|> <|/QA_SIM_{i+1}|>")
    parts.append("<|INSTR|> Use above contexts to answer concisely. <|/INSTR|>")
    parts.append(f"<|QUESTION|> {question} <|/QUESTION|>")
    parts.append("<|ANSWER|>")
    return "\n\n".join(parts)
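# A rendered prompt looks roughly like the sketch below (blank lines come from
# the "\n\n" join; empty QA_SIM slots keep the layout fixed at five):
#
#   <|CTX_QA|> ...context 2... <|/CTX_QA|>
#
#   <|CTX_REL|> ...relationship text... <|/CTX_REL|>
#
#   <|QA_SIM_1|> Q: ... A: ... <|/QA_SIM_1|>
#   ...
#   <|QA_SIM_5|> <|/QA_SIM_5|>
#
#   <|INSTR|> Use above contexts to answer concisely. <|/INSTR|>
#
#   <|QUESTION|> ... <|/QUESTION|>
#
#   <|ANSWER|>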
def retrieve_contexts(questions, top_k=3):
    """Batch retrieve context texts + metadata, padded to exactly top_k each."""
    batch_res = context_retriever.batch_retrieve(questions, top_k=top_k)
    contexts = []
    for res_list in batch_res:
        ctx_texts = [r["text"] for r in res_list[:top_k]]
        ctx_metas = [r["metadata"] for r in res_list[:top_k]]
        # pad to top_k so ctx_texts/ctx_metas always unpack cleanly
        while len(ctx_texts) < top_k:
            ctx_texts.append("")
            ctx_metas.append({})
        contexts.append((ctx_texts, ctx_metas))
    return contexts


def retrieve_qas_and_rels(questions):
    """Batch retrieve similar QA pairs and relationship text per question.

    Assumes search_topk(questions, k) returns, per question, a list of
    {"question": ..., "answer": ...} dicts, and batch_relationships(questions)
    returns one relation string per question. The two retrievers are
    independent, so they run concurrently on a small thread pool.
    """
    with ThreadPoolExecutor(max_workers=2) as pool:
        qa_future = pool.submit(qa_retreiver, questions, 5)
        rel_future = pool.submit(batch_relationships, questions)
        return qa_future.result(), rel_future.result()


def main():
    OUT_DIR.mkdir(parents=True, exist_ok=True)
    qas = json.loads(QA_FILE.read_text(encoding="utf-8"))
    total = len(qas)
    start_idx = read_checkpoint()
    if start_idx >= total:
        print("[INFO] Checkpoint beyond dataset length.")
        return

    prompts_accum = []
    batch_count = start_idx // SAVE_BATCH_SIZE
    for batch_start in tqdm(range(start_idx, total, EMBED_BATCH_SIZE)):
        batch_end = min(batch_start + EMBED_BATCH_SIZE, total)
        batch_items = qas[batch_start:batch_end]
        questions = [it.get("question") or it.get("q") or it.get("Question") for it in batch_items]
        orig_answers = [it.get("answer") or it.get("a") or it.get("Answer", "") for it in batch_items]

        # --- retrieve contexts ---
        contexts = retrieve_contexts(questions, top_k=3)
        # --- QA & relationships ---
        sim_qas_list, rel_list = retrieve_qas_and_rels(questions)

        for i, q in enumerate(questions):
            if not q:
                # No recognizable question field; skip without advancing the
                # checkpoint past items still waiting in prompts_accum.
                continue
            ctx_texts, ctx_metas = contexts[i]
            context1, context2, context3 = ctx_texts
            meta1 = ctx_metas[0]
            prompt_text = build_prompt([context2, context3], rel_list[i], sim_qas_list[i], q)
            gold = append_metadata_at_end(orig_answers[i], context1, meta1)
            obj = {
                "id": batch_start + i,
                "question": q,
                "prompt": prompt_text,
                "gold_answer": gold,
                "context1": context1,
                "retrieved_qas": sim_qas_list[i],
                "relation_text": rel_list[i],
            }
            prompts_accum.append(obj)

            # --- Save batch ---
            if len(prompts_accum) >= SAVE_BATCH_SIZE:
                out_path = OUT_DIR / f"prompts_batch_{batch_count:03d}.json"
                out_path.write_text(json.dumps(prompts_accum, ensure_ascii=False, indent=2), encoding="utf-8")
                batch_count += 1
                prompts_accum = []
                # Checkpoint only after a successful save, so a restart never
                # skips items that were accumulated but not yet written.
                write_checkpoint(batch_start + i + 1)

    # save remaining
    if prompts_accum:
        out_path = OUT_DIR / f"prompts_batch_{batch_count:03d}.json"
        out_path.write_text(json.dumps(prompts_accum, ensure_ascii=False, indent=2), encoding="utf-8")
    write_checkpoint(total)
    OUT_DIR.joinpath("special_tokens_used.txt").write_text("\n".join(STRUCTURAL_TOKENS))
    print("[DONE] All prompts processed.")


if __name__ == "__main__":
    main()
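# Expected input: got_all_qa_final.json, a JSON list of QA records. Key casing
# is not fixed: main() probes "question"/"q"/"Question" and
# "answer"/"a"/"Answer", so records like {"question": "...", "answer": "..."}
# work. Re-running the script resumes from prompts_out/checkpoint.json; the
# output file index is re-derived as start_idx // SAVE_BATCH_SIZE.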