File size: 6,439 Bytes
dff5c6e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 |
#!/usr/bin/env python3
"""
generate_prompts_v8_batch_fixed.py
- Uses batch retrieval for Context, QA, and Relationships
- Saves in batches with checkpointing
- Pads contexts and QA to fixed sizes
- Appends metadata at the end
"""
import os, json, torch, numpy as np
from pathlib import Path
from tqdm import tqdm
from sentence_transformers import SentenceTransformer
from concurrent.futures import ThreadPoolExecutor
from context_retreiver import retriever as context_retriever
from qa_retreiver import search_topk as qa_retreiver
from relationships_retreiver import batch_relationships
# --- Input / output locations -------------------------------------------
# JSON file that main() loads in full; main() slices it and calls .get()
# on each element, so it is expected to hold a list of dict records.
QA_FILE = Path("got_all_qa_final.json")
# Directory that receives prompt batch files, the checkpoint and the
# special-token listing.
OUT_DIR = Path("prompts_out")
# Holds {"next_index": N}; read by read_checkpoint() to resume a run.
CHECKPOINT_FILE = OUT_DIR / "checkpoint.json"
# Number of prompt records accumulated before main() writes a batch file.
SAVE_BATCH_SIZE = 512
# Number of questions main() sends through retrieval per loop iteration.
EMBED_BATCH_SIZE = 32 # GPU batch size
# Use the first CUDA device when torch reports one, otherwise the CPU.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
print(f"[INFO] Using device: {DEVICE}")
# Loaded at import time.
# NOTE(review): EMBED_MODEL is not referenced by any function visible in
# this file — confirm whether it is still needed before removing it.
EMBED_MODEL = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", device=DEVICE)
# Marker tokens for the prompt sections; main() writes this list to
# special_tokens_used.txt. build_prompt() emits these markers, except
# "<|/ANSWER|>", which it leaves open after "<|ANSWER|>".
STRUCTURAL_TOKENS = [
"<|CTX_QA|>", "<|/CTX_QA|>",
"<|CTX_REL|>", "<|/CTX_REL|>",
"<|INSTR|>", "<|/INSTR|>",
"<|QUESTION|>", "<|/QUESTION|>",
"<|ANSWER|>", "<|/ANSWER|>",
"<|QA_SIM_1|>", "<|/QA_SIM_1|>",
"<|QA_SIM_2|>", "<|/QA_SIM_2|>",
"<|QA_SIM_3|>", "<|/QA_SIM_3|>",
"<|QA_SIM_4|>", "<|/QA_SIM_4|>",
"<|QA_SIM_5|>", "<|/QA_SIM_5|>"
]
def read_checkpoint():
    """Return the dataset index at which processing should resume.

    Reads ``CHECKPOINT_FILE``, which is expected to contain JSON of the
    form ``{"next_index": N}``.

    Returns:
        The stored ``next_index`` as an ``int``, or ``0`` when the
        checkpoint file is missing, unreadable or malformed.
    """
    try:
        payload = json.loads(CHECKPOINT_FILE.read_text(encoding="utf-8"))
        return int(payload["next_index"])
    except FileNotFoundError:
        # No checkpoint yet: a fresh run starts at the beginning.
        return 0
    except (OSError, ValueError, KeyError, TypeError) as exc:
        # ValueError covers invalid JSON and non-numeric values; KeyError and
        # TypeError cover a payload of the wrong shape. Only these are
        # caught so KeyboardInterrupt/SystemExit still propagate.
        print(f"[WARN] Ignoring unreadable checkpoint {CHECKPOINT_FILE}: {exc!r}")
        return 0
def write_checkpoint(idx):
    """Persist ``idx`` as the next dataset index to process.

    The payload ``{"next_index": idx}`` is written to a sibling temporary
    file and then moved over ``CHECKPOINT_FILE`` with ``os.replace``, so an
    interruption during the write cannot leave a truncated checkpoint
    behind.

    Args:
        idx: Index of the first dataset item that still has to be processed.
    """
    OUT_DIR.mkdir(parents=True, exist_ok=True)
    tmp_path = CHECKPOINT_FILE.with_name(CHECKPOINT_FILE.name + ".tmp")
    tmp_path.write_text(json.dumps({"next_index": idx}), encoding="utf-8")
    # Swap the finished file into place in a single step.
    os.replace(tmp_path, CHECKPOINT_FILE)
def metadata_to_str(meta):
    """Flatten a metadata mapping into a ``key=value; key=value`` string.

    Only entries whose value is a ``str``, ``int``, ``float`` or ``bool``
    are included; every other value is skipped. A falsy ``meta`` (``None``
    or an empty mapping) yields an empty string.
    """
    if not meta:
        return ""
    scalar_types = (str, int, float, bool)
    pieces = []
    for key, value in meta.items():
        if isinstance(value, scalar_types):
            pieces.append(f"{key}={value}")
    return "; ".join(pieces)
def append_metadata_at_end(answer, context1_text, context1_meta):
    """Build the gold-answer string: answer, then context, then metadata.

    Args:
        answer: Original answer text; omitted when falsy, otherwise stripped.
        context1_text: Context text; when truthy it is stripped and added as
            ``[Context1: ...]``.
        context1_meta: Metadata mapping rendered through ``metadata_to_str``
            and added as ``(meta: ...)`` when the rendering is non-empty.

    Returns:
        The included segments joined by single spaces.
    """
    answer_part = answer.strip() if answer else None
    context_part = f"[Context1: {context1_text.strip()}]" if context1_text else None
    meta_str = metadata_to_str(context1_meta)
    meta_part = f"(meta: {meta_str})" if meta_str else None
    # Filter on None (not truthiness) so an answer that strips down to ""
    # is still kept as a segment.
    segments = (answer_part, context_part, meta_part)
    return " ".join(seg for seg in segments if seg is not None)
def build_prompt(ctx_texts, rel_text, sim_qas, question):
    """Assemble the tagged prompt text for one question.

    Args:
        ctx_texts: Context strings; each non-empty one becomes a
            ``<|CTX_QA|>`` section.
        rel_text: Relationship text; added as a ``<|CTX_REL|>`` section when
            truthy.
        sim_qas: Sequence of mappings with ``'question'`` and ``'answer'``
            keys. Exactly five ``<|QA_SIM_n|>`` slots are always emitted:
            the first five entries fill them, missing ones are left empty,
            and entries past the fifth are ignored.
        question: The question to place in the ``<|QUESTION|>`` section.

    Returns:
        All sections joined by blank lines, ending with an open
        ``<|ANSWER|>`` marker.
    """
    sections = [f"<|CTX_QA|> {ctx} <|/CTX_QA|>" for ctx in ctx_texts if ctx]
    if rel_text:
        sections.append(f"<|CTX_REL|> {rel_text} <|/CTX_REL|>")

    available = len(sim_qas)
    for slot in range(1, 6):
        open_tag = f"<|QA_SIM_{slot}|>"
        close_tag = f"<|/QA_SIM_{slot}|>"
        if slot <= available:
            pair = sim_qas[slot - 1]
            sections.append(
                f"{open_tag} Q: {pair['question']} A: {pair['answer']} {close_tag}"
            )
        else:
            # Keep the slot present but empty so the layout is fixed.
            sections.append(f"{open_tag} {close_tag}")

    sections.extend(
        [
            "<|INSTR|> Use above contexts to answer concisely. <|/INSTR|>",
            f"<|QUESTION|> {question} <|/QUESTION|>",
            "<|ANSWER|>",
        ]
    )
    return "\n\n".join(sections)
def retrieve_contexts(questions, top_k=3):
    """Batch-retrieve context texts and metadata for each question.

    Calls ``context_retriever.batch_retrieve(questions, top_k=top_k)`` and,
    for each returned result list, collects the ``"text"`` and
    ``"metadata"`` fields of its first ``top_k`` hits. Shorter result lists
    are padded with ``""`` / ``{}`` so both lists have ``top_k`` entries.

    Args:
        questions: Question strings to look up.
        top_k: Number of contexts to keep (and pad to) per question.

    Returns:
        A list with one ``(texts, metas)`` tuple per result list.
    """
    # assumes batch_retrieve yields one list of dict hits per question —
    # TODO confirm against context_retreiver
    padded = []
    for hits in context_retriever.batch_retrieve(questions, top_k=top_k):
        kept = hits[:top_k]
        texts = [hit["text"] for hit in kept]
        metas = [hit["metadata"] for hit in kept]
        shortfall = top_k - len(texts)
        if shortfall > 0:
            texts.extend([""] * shortfall)
            # A fresh dict per slot, so padded entries never share state.
            metas.extend({} for _ in range(shortfall))
        padded.append((texts, metas))
    return padded
def retrieve_qas_and_rels(questions, max_workers=20):
    """Look up similar QA pairs and relationship text, one task per question.

    Both lookups run on a shared thread pool. All QA lookups are collected
    before the relationship lookups are submitted.

    Args:
        questions: Question strings to look up.
        max_workers: Size of the thread pool.

    Returns:
        ``(sim_qas_list, rel_list)``, each with one entry per question, in
        input order.
    """

    def _similar_qas(question):
        # NOTE(review): the QA result for the one-element query list is kept
        # as returned, while the relationship result below is unwrapped with
        # [0] — confirm the return shape of search_topk.
        return qa_retreiver([question], k=5)

    def _top_relationship(question):
        return batch_relationships([question], top_k=1)[0]

    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        sim_qas_list = list(pool.map(_similar_qas, questions))
        rel_list = list(pool.map(_top_relationship, questions))
    return sim_qas_list, rel_list
def _save_prompt_batch(prompts, batch_count):
    """Write one batch of prompt records to ``prompts_batch_NNN.json`` (UTF-8)."""
    out_path = OUT_DIR / f"prompts_batch_{batch_count:03d}.json"
    out_path.write_text(
        json.dumps(prompts, ensure_ascii=False, indent=2), encoding="utf-8"
    )


def main():
    """Generate prompt records for every QA item and save them in batches.

    Loads ``QA_FILE``, resumes from the index stored in the checkpoint, and
    walks the dataset ``EMBED_BATCH_SIZE`` items at a time. For each item
    with a non-empty question it retrieves three contexts, similar QA pairs
    and relationship text. Contexts 2 and 3 go into the prompt; context 1
    and its metadata are appended to the gold answer. Records are written
    to ``OUT_DIR`` every ``SAVE_BATCH_SIZE`` prompts, and the checkpoint is
    advanced only after such a write.
    """
    OUT_DIR.mkdir(parents=True, exist_ok=True)
    with open(QA_FILE, "r", encoding="utf-8") as f:
        qas = json.load(f)
    total = len(qas)
    start_idx = read_checkpoint()
    if start_idx >= total:
        print("[INFO] Checkpoint beyond dataset length.")
        return

    prompts_accum = []
    batch_count = start_idx // SAVE_BATCH_SIZE
    for batch_start in tqdm(range(start_idx, total, EMBED_BATCH_SIZE)):
        batch_items = qas[batch_start:batch_start + EMBED_BATCH_SIZE]

        # Drop items without a question BEFORE retrieval so an empty/None
        # question is never handed to the retrievers.
        valid = []
        for offset, item in enumerate(batch_items):
            question = item.get("question") or item.get("q") or item.get("Question")
            if not question:
                # Deliberately no checkpoint write here: advancing it on a
                # skip would move it past prompts that are still only in
                # memory, and a crash would then lose them on resume.
                continue
            answer = item.get("answer") or item.get("a") or item.get("Answer", "")
            valid.append((batch_start + offset, question, answer))
        if not valid:
            continue

        questions = [question for _, question, _ in valid]
        # --- retrieve contexts ---
        contexts = retrieve_contexts(questions, top_k=3)
        # --- QA & relationships ---
        sim_qas_list, rel_list = retrieve_qas_and_rels(questions)

        for pos, (item_id, question, answer) in enumerate(valid):
            ctx_texts, ctx_metas = contexts[pos]
            context1, context2, context3 = ctx_texts
            meta1 = ctx_metas[0]
            prompt_text = build_prompt(
                [context2, context3], rel_list[pos], sim_qas_list[pos], question
            )
            gold = append_metadata_at_end(answer, context1, meta1)
            prompts_accum.append(
                {
                    "id": item_id,
                    "question": question,
                    "prompt": prompt_text,
                    "gold_answer": gold,
                    "context1": context1,
                    "retrieved_qas": sim_qas_list[pos],
                    "relation_text": rel_list[pos],
                }
            )
            # --- Save batch ---
            if len(prompts_accum) >= SAVE_BATCH_SIZE:
                _save_prompt_batch(prompts_accum, batch_count)
                batch_count += 1
                prompts_accum = []
                # Everything up to and including item_id is now on disk.
                write_checkpoint(item_id + 1)

    # Save the remaining partial batch. The checkpoint is intentionally not
    # advanced here: on resume batch_count is derived from the checkpoint
    # index, so a rerun regenerates this same file rather than a later run
    # reusing its number for different items.
    if prompts_accum:
        _save_prompt_batch(prompts_accum, batch_count)
    OUT_DIR.joinpath("special_tokens_used.txt").write_text(
        "\n".join(STRUCTURAL_TOKENS), encoding="utf-8"
    )
    print("[DONE] All prompts processed.")
# Run the generator only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|