File size: 4,644 Bytes
c7a6fe6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import argparse
import json
from pathlib import Path
from typing import Any, Dict, List

from openai import OpenAI


# Prompt template file fed verbatim to the model before each input record.
PROMPT_PATH = Path("/home/mshahidul/readctrl/prompts/support_check_data_generate")
# JSON file holding API credentials; the "openai" key is read at startup.
API_FILE = Path("/home/mshahidul/api_new.json")
# Source dataset: list of records with doc_id / label / diff_label_texts fields.
INPUT_PATH = Path(
    "/home/mshahidul/readctrl/code/text_classifier/data/verified_combined_0-80_clean200.json"
)
# Destination directory for generated subclaim files (created if missing).
OUTPUT_DIR = Path("/home/mshahidul/readctrl/data/extracting_subclaim")
# Default output filename; can be overridden via --output-file.
DEFAULT_OUTPUT_FILE = "synthetic_subclaims_first200.json"


def load_openai_client() -> OpenAI:
    """Build an OpenAI client from the key stored under "openai" in API_FILE.

    Raises:
        FileNotFoundError: if API_FILE does not exist.
        KeyError: if the credentials file has no "openai" entry.
    """
    credentials = json.loads(API_FILE.read_text(encoding="utf-8"))
    return OpenAI(api_key=credentials["openai"])


def normalize_difficulty(label: str) -> str:
    """Map a health-literacy label to a coarse difficulty bucket.

    Any unrecognized label falls back to "intermediate".
    """
    if label == "low_health_literacy":
        return "easy"
    if label == "proficient_health_literacy":
        return "hard"
    # "intermediate_health_literacy" and every unknown label land here.
    return "intermediate"


def clean_json_response(raw: str) -> Dict[str, Any]:
    """Parse a model response that may be wrapped in a Markdown code fence.

    Only the *outer* fence markers (```json ... ``` or ``` ... ```) are
    stripped. The previous blanket ``.replace("```", "")`` removed every
    backtick-run in the string, silently corrupting JSON payloads whose
    values legitimately contain backticks.

    Raises:
        json.JSONDecodeError: if the unfenced text is not valid JSON.
    """
    text = raw.strip()
    # Drop an opening fence token, preferring the longer "```json" form.
    if text.startswith("```json"):
        text = text[len("```json"):]
    elif text.startswith("```"):
        text = text[len("```"):]
    # Drop a closing fence token, if present.
    if text.endswith("```"):
        text = text[: -len("```")]
    return json.loads(text.strip())


def make_prompt(template: str, item: Dict[str, Any]) -> str:
    """Compose the full LLM prompt: template, then one JSON-encoded input record."""
    doc_id = item.get("doc_id", "unknown")
    raw_label = item.get("label", "unknown")
    record = {
        "passage_id": f"{doc_id}_{raw_label}",
        "passage": item.get("diff_label_texts", ""),
        "difficulty_label": normalize_difficulty(item.get("label", "")),
    }
    encoded = json.dumps(record, ensure_ascii=False, indent=2)
    # Trailing "" yields the final newline after the JSON payload.
    return "\n".join([template, "", "Now generate output for this input:", encoded, ""])


def load_input_data(limit: int) -> List[Dict[str, Any]]:
    """Load the source dataset from INPUT_PATH, keeping only the first *limit* records."""
    with INPUT_PATH.open("r", encoding="utf-8") as handle:
        records = json.load(handle)
    return records[:limit]


def load_existing(path: Path) -> List[Dict[str, Any]]:
    """Return previously saved results from *path*, or an empty list when the file is absent."""
    if path.exists():
        with path.open("r", encoding="utf-8") as handle:
            return json.load(handle)
    return []


def save_json(path: Path, data: List[Dict[str, Any]]) -> None:
    """Write *data* to *path* as pretty-printed, UTF-8, non-ASCII-preserving JSON."""
    serialized = json.dumps(data, ensure_ascii=False, indent=2)
    with path.open("w", encoding="utf-8") as handle:
        handle.write(serialized)


def main() -> None:
    """CLI entry point: generate subclaim records for each input passage via OpenAI.

    The run is resumable: rows already present in the output file (matched by
    ``source_key``) are skipped, and partial results are flushed to disk every
    ``--save-every`` processed items.
    """
    parser = argparse.ArgumentParser(
        description="Generate synthetic claim-verification subclaim dataset from diff_label_texts."
    )
    parser.add_argument("--limit", type=int, default=200, help="Number of input items to process.")
    parser.add_argument("--model", type=str, default="gpt-5", help="OpenAI model name.")
    parser.add_argument(
        "--output-file",
        type=str,
        default=DEFAULT_OUTPUT_FILE,
        help="Output filename inside output directory.",
    )
    parser.add_argument(
        "--save-every",
        type=int,
        default=2,
        help="Persist results after every N processed items.",
    )
    args = parser.parse_args()

    with PROMPT_PATH.open("r", encoding="utf-8") as f:
        prompt_template = f.read().strip()

    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    output_path = OUTPUT_DIR / args.output_file

    data = load_input_data(limit=args.limit)
    results = load_existing(output_path)
    done_keys = {item.get("source_key") for item in results}

    client = load_openai_client()

    for idx, item in enumerate(data):
        source_key = f"{item.get('doc_id')}_{item.get('label')}_{idx}"
        if source_key in done_keys:
            continue

        prompt = make_prompt(prompt_template, item)
        # `content` stays "" if the API call itself fails; if the call succeeds
        # but the response is not valid JSON, it holds the raw model output.
        # This replaces the fragile `"response" in locals()` check, which also
        # lost the raw text whenever parsing (not the call) was what failed.
        content = ""
        try:
            response = client.chat.completions.create(
                model=args.model,
                messages=[
                    {"role": "system", "content": "You are a helpful assistant."},
                    {"role": "user", "content": prompt},
                ],
            )
            content = response.choices[0].message.content or ""
            generated = clean_json_response(content)
        except Exception as e:  # noqa: BLE001
            generated = {"error": str(e), "raw_response": content}

        results.append(
            {
                "source_key": source_key,
                "doc_id": item.get("doc_id"),
                "source_label": item.get("label"),
                "difficulty_label": normalize_difficulty(item.get("label", "")),
                "generated": generated,
            }
        )
        done_keys.add(source_key)

        # Guard against `--save-every 0`, which previously raised ZeroDivisionError.
        if args.save_every > 0 and len(results) % args.save_every == 0:
            save_json(output_path, results)
            print(f"Saved {len(results)} rows to {output_path}")

    # Final flush so the last partial batch is never lost.
    save_json(output_path, results)
    print(f"Done. Saved {len(results)} rows to {output_path}")


# Run generation only when executed as a script, not on import.
if __name__ == "__main__":
    main()