import argparse
import json
from pathlib import Path
from typing import Any, Dict, List
from openai import OpenAI
PROMPT_PATH = Path("/home/mshahidul/readctrl/prompts/support_check_data_generate")
API_FILE = Path("/home/mshahidul/api_new.json")
INPUT_PATH = Path(
"/home/mshahidul/readctrl/code/text_classifier/data/verified_combined_0-80_clean200.json"
)
OUTPUT_DIR = Path("/home/mshahidul/readctrl/data/extracting_subclaim")
DEFAULT_OUTPUT_FILE = "synthetic_subclaims_first200.json"
def load_openai_client() -> OpenAI:
with API_FILE.open("r", encoding="utf-8") as f:
api_keys = json.load(f)
openai_api_key = api_keys["openai"]
return OpenAI(api_key=openai_api_key)
def normalize_difficulty(label: str) -> str:
mapping = {
"low_health_literacy": "easy",
"intermediate_health_literacy": "intermediate",
"proficient_health_literacy": "hard",
}
return mapping.get(label, "intermediate")
def clean_json_response(raw: str) -> Dict[str, Any]:
cleaned = raw.strip().replace("```json", "").replace("```", "").strip()
return json.loads(cleaned)
def make_prompt(template: str, item: Dict[str, Any]) -> str:
payload = {
"passage_id": f"{item.get('doc_id', 'unknown')}_{item.get('label', 'unknown')}",
"passage": item.get("diff_label_texts", ""),
"difficulty_label": normalize_difficulty(item.get("label", "")),
}
return (
f"{template}\n\n"
"Now generate output for this input:\n"
f"{json.dumps(payload, ensure_ascii=False, indent=2)}\n"
)
def load_input_data(limit: int) -> List[Dict[str, Any]]:
with INPUT_PATH.open("r", encoding="utf-8") as f:
data = json.load(f)
return data[:limit]
def load_existing(path: Path) -> List[Dict[str, Any]]:
if not path.exists():
return []
with path.open("r", encoding="utf-8") as f:
return json.load(f)
def save_json(path: Path, data: List[Dict[str, Any]]) -> None:
with path.open("w", encoding="utf-8") as f:
json.dump(data, f, ensure_ascii=False, indent=2)
def main() -> None:
parser = argparse.ArgumentParser(
description="Generate synthetic claim-verification subclaim dataset from diff_label_texts."
)
parser.add_argument("--limit", type=int, default=200, help="Number of input items to process.")
parser.add_argument("--model", type=str, default="gpt-5", help="OpenAI model name.")
parser.add_argument(
"--output-file",
type=str,
default=DEFAULT_OUTPUT_FILE,
help="Output filename inside output directory.",
)
parser.add_argument(
"--save-every",
type=int,
default=2,
help="Persist results after every N processed items.",
)
args = parser.parse_args()
with PROMPT_PATH.open("r", encoding="utf-8") as f:
prompt_template = f.read().strip()
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
output_path = OUTPUT_DIR / args.output_file
data = load_input_data(limit=args.limit)
results = load_existing(output_path)
done_keys = {item.get("source_key") for item in results}
client = load_openai_client()
for idx, item in enumerate(data):
source_key = f"{item.get('doc_id')}_{item.get('label')}_{idx}"
if source_key in done_keys:
continue
prompt = make_prompt(prompt_template, item)
try:
response = client.chat.completions.create(
model=args.model,
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": prompt},
],
)
content = response.choices[0].message.content or ""
generated = clean_json_response(content)
except Exception as e: # noqa: BLE001
generated = {"error": str(e), "raw_response": response.choices[0].message.content if "response" in locals() else ""}
results.append(
{
"source_key": source_key,
"doc_id": item.get("doc_id"),
"source_label": item.get("label"),
"difficulty_label": normalize_difficulty(item.get("label", "")),
"generated": generated,
}
)
done_keys.add(source_key)
if len(results) % args.save_every == 0:
save_json(output_path, results)
print(f"Saved {len(results)} rows to {output_path}")
save_json(output_path, results)
print(f"Done. Saved {len(results)} rows to {output_path}")
if __name__ == "__main__":
main()