# readctrl/code/data_creation/generate_subclaim_synthetic_dataset.py
# Uploaded by shahidul034 via upload-large-folder tool (commit c7a6fe6, verified).
import argparse
import json
from pathlib import Path
from typing import Any, Dict, List
from openai import OpenAI
# Prompt template containing the task instructions for subclaim generation.
PROMPT_PATH = Path("/home/mshahidul/readctrl/prompts/support_check_data_generate")
# JSON file holding API keys; the "openai" entry is read by load_openai_client().
API_FILE = Path("/home/mshahidul/api_new.json")
# Source dataset of passages (diff_label_texts) to generate subclaims from.
INPUT_PATH = Path(
"/home/mshahidul/readctrl/code/text_classifier/data/verified_combined_0-80_clean200.json"
)
# Directory where generated results are written.
OUTPUT_DIR = Path("/home/mshahidul/readctrl/data/extracting_subclaim")
# Default output filename inside OUTPUT_DIR (overridable via --output-file).
DEFAULT_OUTPUT_FILE = "synthetic_subclaims_first200.json"
def load_openai_client() -> OpenAI:
    """Build an OpenAI client from the key stored under "openai" in API_FILE."""
    keys = json.loads(API_FILE.read_text(encoding="utf-8"))
    return OpenAI(api_key=keys["openai"])
def normalize_difficulty(label: str) -> str:
    """Map a health-literacy label onto a coarse difficulty bucket.

    Any label outside the three known ones falls back to "intermediate".
    """
    if label == "low_health_literacy":
        return "easy"
    if label == "proficient_health_literacy":
        return "hard"
    # "intermediate_health_literacy" and every unknown label land here.
    return "intermediate"
def clean_json_response(raw: str) -> Dict[str, Any]:
    """Parse an LLM response as JSON, tolerating a Markdown code fence.

    Only a *leading* ``` / ```json fence line and a *trailing* ``` fence are
    stripped. (A blanket ``.replace("```", "")`` would also delete fence
    markers that appear inside the JSON payload itself, corrupting valid data.)

    Raises:
        json.JSONDecodeError: if the remaining text is not valid JSON.
    """
    cleaned = raw.strip()
    if cleaned.startswith("```"):
        # Drop the opening fence line (handles both ``` and ```json).
        cleaned = cleaned.split("\n", 1)[1] if "\n" in cleaned else ""
    if cleaned.endswith("```"):
        cleaned = cleaned[: -len("```")]
    return json.loads(cleaned.strip())
def make_prompt(template: str, item: Dict[str, Any]) -> str:
    """Compose the full prompt: instruction template plus one JSON input record."""
    doc_id = item.get("doc_id", "unknown")
    raw_label = item.get("label", "unknown")
    record = {
        "passage_id": f"{doc_id}_{raw_label}",
        "passage": item.get("diff_label_texts", ""),
        "difficulty_label": normalize_difficulty(item.get("label", "")),
    }
    encoded = json.dumps(record, ensure_ascii=False, indent=2)
    return "\n".join(
        [template, "", "Now generate output for this input:", encoded, ""]
    )
def load_input_data(limit: int) -> List[Dict[str, Any]]:
    """Load the verified input dataset and keep only the first *limit* records."""
    records = json.loads(INPUT_PATH.read_text(encoding="utf-8"))
    return records[:limit]
def load_existing(path: Path) -> List[Dict[str, Any]]:
    """Return previously saved results from *path*, or an empty list if absent."""
    if path.exists():
        with path.open("r", encoding="utf-8") as handle:
            return json.load(handle)
    return []
def save_json(path: Path, data: List[Dict[str, Any]]) -> None:
    """Serialize *data* to *path* as pretty-printed UTF-8 JSON."""
    payload = json.dumps(data, ensure_ascii=False, indent=2)
    with path.open("w", encoding="utf-8") as handle:
        handle.write(payload)
def main() -> None:
    """Generate synthetic subclaim records for each input passage via the OpenAI API.

    The run is resumable: rows already present in the output file (matched by
    ``source_key``) are skipped, and partial results are checkpointed every
    ``--save-every`` processed items.
    """
    parser = argparse.ArgumentParser(
        description="Generate synthetic claim-verification subclaim dataset from diff_label_texts."
    )
    parser.add_argument("--limit", type=int, default=200, help="Number of input items to process.")
    parser.add_argument("--model", type=str, default="gpt-5", help="OpenAI model name.")
    parser.add_argument(
        "--output-file",
        type=str,
        default=DEFAULT_OUTPUT_FILE,
        help="Output filename inside output directory.",
    )
    parser.add_argument(
        "--save-every",
        type=int,
        default=2,
        help="Persist results after every N processed items.",
    )
    args = parser.parse_args()

    with PROMPT_PATH.open("r", encoding="utf-8") as f:
        prompt_template = f.read().strip()

    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    output_path = OUTPUT_DIR / args.output_file

    data = load_input_data(limit=args.limit)
    results = load_existing(output_path)
    # source_keys already written to the output file; lets reruns resume.
    done_keys = {item.get("source_key") for item in results}
    client = load_openai_client()

    for idx, item in enumerate(data):
        source_key = f"{item.get('doc_id')}_{item.get('label')}_{idx}"
        if source_key in done_keys:
            continue
        prompt = make_prompt(prompt_template, item)
        # Initialize before the try so the except branch can always record the
        # raw model text (replaces the fragile `"response" in locals()` check,
        # which also lost content received before a JSON-parse failure).
        content = ""
        try:
            response = client.chat.completions.create(
                model=args.model,
                messages=[
                    {"role": "system", "content": "You are a helpful assistant."},
                    {"role": "user", "content": prompt},
                ],
            )
            content = response.choices[0].message.content or ""
            generated = clean_json_response(content)
        except Exception as e:  # noqa: BLE001
            # Best-effort: record the failure per item instead of aborting the run.
            generated = {"error": str(e), "raw_response": content}
        results.append(
            {
                "source_key": source_key,
                "doc_id": item.get("doc_id"),
                "source_label": item.get("label"),
                "difficulty_label": normalize_difficulty(item.get("label", "")),
                "generated": generated,
            }
        )
        done_keys.add(source_key)
        # Periodic checkpoint. The > 0 guard prevents a ZeroDivisionError that
        # the unguarded modulo raised when --save-every was 0.
        if args.save_every > 0 and len(results) % args.save_every == 0:
            save_json(output_path, results)
            print(f"Saved {len(results)} rows to {output_path}")

    save_json(output_path, results)
    print(f"Done. Saved {len(results)} rows to {output_path}")
# Run the generator only when executed as a script (not on import).
if __name__ == "__main__":
    main()