| | import argparse |
| | import json |
| | import os |
| | import time |
| | import urllib.error |
| | import urllib.request |
| | from datetime import datetime |
| | from typing import Any, Dict, List, Optional |
| |
|
| | from tqdm import tqdm |
| |
|
| |
|
# Load API keys from a hard-coded, machine-specific JSON file at import time.
# NOTE(review): this read runs on import, so the module cannot even be
# imported on a machine where the path is absent — consider lazy loading or
# an environment-variable fallback. parse_args() reads the "openai" entry of
# this mapping as the default for --api-key.
api_file = "/home/mshahidul/api_new.json"
with open(api_file, "r", encoding="utf-8") as f:
    api_keys = json.load(f)
| |
|
# OpenAI-compatible endpoint used when OPENAI_API_BASE is not set.
DEFAULT_API_BASE = "https://api.openai.com/v1"
# Verified annotation dataset: a top-level JSON array of row dicts with
# "label", "fulltext", "summary" (and optionally "doc_id") keys.
DEFAULT_INPUT_PATH = (
    "/home/mshahidul/readctrl/data/annotators_validate_data_(20_80)/combine/"
    "verified_combined_0-80.json"
)
# Destination for per-model JSONL files, the combined JSONL, and the summary.
DEFAULT_OUTPUT_DIR = "/home/mshahidul/readctrl/code/rl_inference/test_result"
# Prompt template files, one per target health-literacy level.
DEFAULT_PROMPT_LOW_PATH = (
    "/home/mshahidul/readctrl/code/RL_model/verl/verl_train/dataset/prompt_low"
)
DEFAULT_PROMPT_INTERMEDIATE_PATH = (
    "/home/mshahidul/readctrl/code/RL_model/verl/verl_train/dataset/prompt_intermediate"
)
DEFAULT_PROMPT_PROFICIENT_PATH = (
    "/home/mshahidul/readctrl/code/RL_model/verl/verl_train/dataset/prompt_proficient"
)
# Comma-separated default for the --models flag.
DEFAULT_MODELS = "gpt-5-mini,gpt-5-nano"

# Dataset "label" values that are processed; rows carrying any other label
# are skipped in main().
VALID_LABELS = {
    "low_health_literacy",
    "intermediate_health_literacy",
    "proficient_health_literacy",
}
| |
|
| |
|
def parse_args() -> argparse.Namespace:
    """Parse CLI options for the GPT-5 inference run.

    Defaults fall back to environment variables (OPENAI_API_BASE /
    OPENAI_API_KEY) and then to the module-level constants and the
    pre-loaded ``api_keys`` mapping.

    Returns:
        argparse.Namespace holding the full run configuration.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Generate outputs with gpt-5-mini and gpt-5-nano using "
            "verified_combined dataset and literacy-level prompts."
        )
    )
    parser.add_argument("--api-base", default=os.environ.get("OPENAI_API_BASE", DEFAULT_API_BASE))
    # BUGFIX: argparse defaults are evaluated eagerly, so the previous
    # api_keys["openai"] raised KeyError at argument-definition time whenever
    # the key file lacked an "openai" entry — even if OPENAI_API_KEY or
    # --api-key was supplied. Using .get() defers the failure to main()'s
    # explicit "Missing API key" check.
    parser.add_argument(
        "--api-key",
        default=os.environ.get("OPENAI_API_KEY", api_keys.get("openai", "")),
    )
    parser.add_argument("--models", default=DEFAULT_MODELS, help="Comma-separated model list.")
    parser.add_argument("--input-path", default=DEFAULT_INPUT_PATH)
    parser.add_argument("--output-dir", default=DEFAULT_OUTPUT_DIR)
    parser.add_argument("--prompt-low-path", default=DEFAULT_PROMPT_LOW_PATH)
    parser.add_argument(
        "--prompt-intermediate-path",
        default=DEFAULT_PROMPT_INTERMEDIATE_PATH,
    )
    parser.add_argument(
        "--prompt-proficient-path",
        default=DEFAULT_PROMPT_PROFICIENT_PATH,
    )
    parser.add_argument(
        "--max-samples",
        type=int,
        default=-1,
        help="Use -1 for all rows.",
    )
    # NOTE(review): --temperature is parsed and recorded in the run summary,
    # but call_chat_completion() does not forward it to the API — confirm.
    parser.add_argument("--temperature", type=float, default=0.0)
    parser.add_argument("--timeout-seconds", type=int, default=120)
    parser.add_argument("--max-retries", type=int, default=2)
    parser.add_argument("--retry-wait-seconds", type=float, default=2.0)
    return parser.parse_args()
| |
|
| |
|
def check_api_base(api_base: str, api_key: str, timeout_seconds: int) -> None:
    """Probe the endpoint's /models route to confirm it is reachable.

    Raises:
        RuntimeError: the endpoint responded with an HTTP error status.
        ConnectionError: the endpoint could not be reached at all.
    """
    endpoint = api_base.rstrip("/") + "/models"
    request = urllib.request.Request(endpoint, method="GET")
    if api_key:
        request.add_header("Authorization", f"Bearer {api_key}")
    try:
        with urllib.request.urlopen(request, timeout=timeout_seconds) as response:
            if response.status >= 400:
                raise RuntimeError(
                    f"Endpoint reachable but unhealthy: {endpoint} (status={response.status})"
                )
    except urllib.error.URLError as exc:
        # HTTPError is a URLError subclass, so error statuses land here too.
        raise ConnectionError(
            "Cannot reach OpenAI-compatible endpoint. "
            f"api_base={api_base}. Check network/API base/API key."
        ) from exc
| |
|
| |
|
def load_prompt_templates(args: argparse.Namespace) -> Dict[str, str]:
    """Read the three literacy-level prompt templates from disk.

    Returns:
        Mapping from literacy label to the raw template text.

    Raises:
        FileNotFoundError: if any template file is missing.
    """
    label_paths = (
        ("low_health_literacy", args.prompt_low_path),
        ("intermediate_health_literacy", args.prompt_intermediate_path),
        ("proficient_health_literacy", args.prompt_proficient_path),
    )
    loaded: Dict[str, str] = {}
    for label, template_path in label_paths:
        if not os.path.exists(template_path):
            raise FileNotFoundError(f"Prompt file not found: {template_path}")
        with open(template_path, "r", encoding="utf-8") as handle:
            loaded[label] = handle.read()
    return loaded
| |
|
| |
|
def infer_source_lang(fulltext: str) -> str:
    """Return "English" if the text contains any Latin letter, else "Unknown"."""
    if not fulltext:
        return "Unknown"
    for ch in fulltext:
        # Lowercasing first so uppercase letters also fall in the a-z range.
        if "a" <= ch.lower() <= "z":
            return "English"
    return "Unknown"
| |
|
| |
|
def build_prompt(template: str, fulltext: str, summary: str, source_lang: str) -> str:
    """Fill the template's placeholders with the row's values.

    Substitutions run in a fixed order: {source_lang}, then {gold_summary},
    then {full_text}.
    """
    filled = template
    for placeholder, value in (
        ("{source_lang}", source_lang),
        ("{gold_summary}", summary),
        ("{full_text}", fulltext),
    ):
        filled = filled.replace(placeholder, value)
    return filled
| |
|
| |
|
def load_verified_rows(path: str) -> List[Dict[str, Any]]:
    """Load the verified dataset and keep only dict-shaped rows.

    Raises:
        FileNotFoundError: if *path* does not exist.
        ValueError: if the file's top-level JSON value is not a list.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"Input file not found: {path}")
    with open(path, "r", encoding="utf-8") as handle:
        payload = json.load(handle)
    if not isinstance(payload, list):
        raise ValueError(f"Expected top-level JSON array in {path}")
    # Non-dict entries (strings, numbers, nested lists) are dropped silently.
    dict_rows: List[Dict[str, Any]] = []
    for entry in payload:
        if isinstance(entry, dict):
            dict_rows.append(entry)
    return dict_rows
| |
|
| |
|
def parse_models(models_arg: str) -> List[str]:
    """Split a comma-separated model string into a clean, non-empty list.

    Raises:
        ValueError: when no non-empty model name remains after splitting.
    """
    pieces = (piece.strip() for piece in models_arg.split(","))
    models = [name for name in pieces if name]
    if not models:
        raise ValueError("No models provided. Example: --models gpt-5-mini,gpt-5-nano")
    return models
| |
|
| |
|
| | def _clean_json_block(text: str) -> str: |
| | cleaned = text.strip() |
| | if "```json" in cleaned: |
| | cleaned = cleaned.split("```json", 1)[1].split("```", 1)[0].strip() |
| | elif "```" in cleaned: |
| | cleaned = cleaned.split("```", 1)[1].split("```", 1)[0].strip() |
| | return cleaned |
| |
|
| |
|
def extract_generated_text(raw_response: str, expected_label: str) -> str:
    """Pull the text stored under *expected_label* out of a JSON model reply.

    The reply may be wrapped in a Markdown code fence, which is removed
    before parsing. When the reply is not valid JSON, is not an object, or
    has no non-empty string under *expected_label*, the whole raw reply
    (stripped) is returned instead.
    """
    fallback = raw_response.strip()

    # Remove a surrounding Markdown fence, preferring the ```json marker.
    candidate = fallback
    if "```json" in candidate:
        candidate = candidate.split("```json", 1)[1].split("```", 1)[0].strip()
    elif "```" in candidate:
        candidate = candidate.split("```", 1)[1].split("```", 1)[0].strip()

    try:
        decoded = json.loads(candidate)
    except json.JSONDecodeError:
        return fallback

    if not isinstance(decoded, dict):
        return fallback
    value = decoded.get(expected_label)
    if isinstance(value, str) and value.strip():
        return value.strip()
    return fallback
| |
|
| |
|
def call_chat_completion(
    *,
    api_base: str,
    api_key: str,
    model: str,
    prompt: str,
    temperature: float,
    timeout_seconds: int,
    max_retries: int,
    retry_wait_seconds: float,
) -> str:
    """POST one single-user-message chat completion and return the reply text.

    Makes up to ``max_retries + 1`` attempts, sleeping ``retry_wait_seconds``
    between them. HTTP errors are retried only for transient status codes;
    connection errors and malformed responses are always retried. The last
    error is re-raised when all attempts fail.

    NOTE(review): ``temperature`` is accepted but never added to the request
    payload, so the API default applies — possibly deliberate (some models
    reject non-default values); confirm and either send it or drop the
    parameter.
    """
    url = api_base.rstrip("/") + "/chat/completions"
    payload = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
    }
    # Payload is identical across attempts, so serialize it once.
    data = json.dumps(payload).encode("utf-8")

    last_error: Optional[Exception] = None
    for attempt in range(max_retries + 1):
        req = urllib.request.Request(url, data=data, method="POST")
        req.add_header("Content-Type", "application/json")
        if api_key:
            req.add_header("Authorization", f"Bearer {api_key}")
        try:
            with urllib.request.urlopen(req, timeout=timeout_seconds) as resp:
                body = resp.read().decode("utf-8")
                parsed = json.loads(body)
                # Standard chat-completions shape: choices[0].message.content.
                return str(parsed["choices"][0]["message"]["content"]).strip()
        except urllib.error.HTTPError as exc:
            # Retry only transient statuses (timeouts, rate limits, 5xx).
            retriable = exc.code in (408, 409, 429, 500, 502, 503, 504)
            last_error = exc
            if attempt < max_retries and retriable:
                time.sleep(retry_wait_seconds)
                continue
            raise
        except (urllib.error.URLError, KeyError, IndexError, json.JSONDecodeError) as exc:
            # Network failures and malformed/unexpected response bodies are
            # always retried until attempts run out.
            last_error = exc
            if attempt < max_retries:
                time.sleep(retry_wait_seconds)
                continue
            raise

    # Defensive: the loop above always returns or re-raises on the final
    # attempt, so this is effectively unreachable.
    if last_error:
        raise last_error
    raise RuntimeError("Unknown error during chat completion call.")
| |
|
| |
|
def main() -> None:
    """Run chat-completion inference for each model over the verified dataset.

    Pipeline: validate config and endpoint → filter dataset rows and build
    prompts → call every model on every prompt → write one JSONL per model,
    one combined JSONL, and a JSON run summary.
    """
    args = parse_args()
    if not args.api_key:
        raise ValueError("Missing API key. Set OPENAI_API_KEY or pass --api-key.")

    # Fail fast on missing prompt files (load_prompt_templates re-checks the
    # same paths later; this duplicate check just surfaces the error earlier).
    for path in (
        args.prompt_low_path,
        args.prompt_intermediate_path,
        args.prompt_proficient_path,
    ):
        if not os.path.exists(path):
            raise FileNotFoundError(f"Prompt file not found: {path}")

    # Verify the endpoint is reachable before doing any dataset work.
    check_api_base(args.api_base, args.api_key, args.timeout_seconds)
    models = parse_models(args.models)
    templates = load_prompt_templates(args)
    rows = load_verified_rows(args.input_path)

    # Build one prompt per usable row; rows with an unrecognized label or a
    # missing fulltext/summary are skipped silently.
    parsed_items: List[Dict[str, Any]] = []
    for idx, row in enumerate(rows):
        gold_label = str(row.get("label", "")).strip()
        fulltext = str(row.get("fulltext", "")).strip()
        summary = str(row.get("summary", "")).strip()
        if gold_label not in VALID_LABELS:
            continue
        if not fulltext or not summary:
            continue
        source_lang = infer_source_lang(fulltext)
        prompt = build_prompt(
            template=templates[gold_label],
            fulltext=fulltext,
            summary=summary,
            source_lang=source_lang,
        )
        parsed_items.append(
            {
                "row_index": idx,  # position in the original dataset
                "doc_id": row.get("doc_id"),
                "gold_label": gold_label,
                "source_lang": source_lang,
                "prompt": prompt,
            }
        )

    # --max-samples <= 0 (default -1) means "use every valid row".
    if args.max_samples > 0:
        parsed_items = parsed_items[: args.max_samples]
    if not parsed_items:
        raise RuntimeError("No valid rows found in input file.")

    # Timestamped output names so repeated runs never overwrite each other.
    ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    os.makedirs(args.output_dir, exist_ok=True)
    summary_path = os.path.join(args.output_dir, f"gpt5_inference_summary_{ts}.json")
    combined_path = os.path.join(args.output_dir, f"gpt5_inference_all_{ts}.jsonl")

    combined_records: List[Dict[str, Any]] = []
    model_stats: Dict[str, Dict[str, Any]] = {}

    for model in models:
        # Keep path separators out of the per-model filename.
        model_slug = model.replace("/", "_")
        model_output_path = os.path.join(
            args.output_dir, f"gpt5_inference_{model_slug}_{ts}.jsonl"
        )
        success_count = 0
        error_count = 0

        with open(model_output_path, "w", encoding="utf-8") as f_model:
            total = len(parsed_items)
            progress_iter = tqdm(
                parsed_items,
                total=total,
                desc=f"{model}",
                unit="item",
            )
            for item in progress_iter:
                # One output record per (model, item), written even on error.
                record: Dict[str, Any] = {
                    "model": model,
                    "row_index": item["row_index"],
                    "doc_id": item.get("doc_id"),
                    "gold_label": item["gold_label"],
                    "source_lang": item["source_lang"],
                    "prompt": item["prompt"],
                }
                try:
                    raw_response = call_chat_completion(
                        api_base=args.api_base,
                        api_key=args.api_key,
                        model=model,
                        prompt=item["prompt"],
                        temperature=args.temperature,
                        timeout_seconds=args.timeout_seconds,
                        max_retries=args.max_retries,
                        retry_wait_seconds=args.retry_wait_seconds,
                    )
                    generated_text = extract_generated_text(raw_response, item["gold_label"])
                    record["prediction"] = raw_response
                    record["generated_text"] = generated_text
                    record["error"] = ""
                    success_count += 1
                except Exception as exc:
                    # Record the failure and keep going: one bad row must not
                    # abort the whole run.
                    record["prediction"] = ""
                    record["generated_text"] = ""
                    record["error"] = f"{type(exc).__name__}: {exc}"
                    error_count += 1

                # Write immediately so partial results survive interruption.
                f_model.write(json.dumps(record, ensure_ascii=False) + "\n")
                combined_records.append(record)

        model_stats[model] = {
            "output_path": model_output_path,
            "total_rows": len(parsed_items),
            "success_count": success_count,
            "error_count": error_count,
        }
        print(f"[DONE] {model} output: {model_output_path}")

    # Merge every model's records into a single JSONL for downstream analysis.
    with open(combined_path, "w", encoding="utf-8") as f_all:
        for record in combined_records:
            f_all.write(json.dumps(record, ensure_ascii=False) + "\n")

    # Machine-readable run summary: configuration plus per-model statistics.
    summary_obj = {
        "input_path": args.input_path,
        "api_base": args.api_base,
        "models": models,
        "max_samples": args.max_samples,
        "temperature": args.temperature,
        "total_dataset_rows_used": len(parsed_items),
        "combined_output_path": combined_path,
        "model_stats": model_stats,
    }
    with open(summary_path, "w", encoding="utf-8") as f_summary:
        json.dump(summary_obj, f_summary, ensure_ascii=False, indent=2)

    print(f"[DONE] Combined output: {combined_path}")
    print(f"[DONE] Summary output: {summary_path}")


if __name__ == "__main__":
    main()
| |
|