"""Run gpt-5-mini / gpt-5-nano inference over the verified_combined dataset.

For each dataset row the script selects a literacy-level prompt template
(low / intermediate / proficient), fills in the source text and gold
summary, calls an OpenAI-compatible chat-completions endpoint, and writes:
  * one JSONL file per model,
  * one combined JSONL across all models,
  * one JSON run summary with per-model success/error counts.
"""

import argparse
import json
import os
import time
import urllib.error
import urllib.request
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple

from tqdm import tqdm  # pyright: ignore[reportMissingModuleSource]

# JSON file mapping provider names to API keys, e.g. {"openai": "sk-..."}.
# NOTE(review): read at import time — the module cannot be imported on a
# machine without this file; confirm that is acceptable for this script.
api_file = "/home/mshahidul/api_new.json"
with open(api_file, "r", encoding="utf-8") as f:
    api_keys = json.load(f)

DEFAULT_API_BASE = "https://api.openai.com/v1"
DEFAULT_INPUT_PATH = (
    "/home/mshahidul/readctrl/data/annotators_validate_data_(20_80)/combine/"
    "verified_combined_0-80.json"
)
DEFAULT_OUTPUT_DIR = "/home/mshahidul/readctrl/code/rl_inference/test_result"
DEFAULT_PROMPT_LOW_PATH = (
    "/home/mshahidul/readctrl/code/RL_model/verl/verl_train/dataset/prompt_low"
)
DEFAULT_PROMPT_INTERMEDIATE_PATH = (
    "/home/mshahidul/readctrl/code/RL_model/verl/verl_train/dataset/prompt_intermediate"
)
DEFAULT_PROMPT_PROFICIENT_PATH = (
    "/home/mshahidul/readctrl/code/RL_model/verl/verl_train/dataset/prompt_proficient"
)
DEFAULT_MODELS = "gpt-5-mini,gpt-5-nano"

# Gold labels accepted from the input data; rows with any other label are skipped.
VALID_LABELS = {
    "low_health_literacy",
    "intermediate_health_literacy",
    "proficient_health_literacy",
}


def parse_args() -> argparse.Namespace:
    """Parse CLI arguments; environment variables override built-in defaults."""
    parser = argparse.ArgumentParser(
        description=(
            "Generate outputs with gpt-5-mini and gpt-5-nano using "
            "verified_combined dataset and literacy-level prompts."
        )
    )
    parser.add_argument("--api-base", default=os.environ.get("OPENAI_API_BASE", DEFAULT_API_BASE))
    parser.add_argument(
        "--api-key",
        # .get avoids an opaque KeyError at startup when the key file has no
        # "openai" entry; main() reports a clear error for an empty key.
        default=os.environ.get("OPENAI_API_KEY", api_keys.get("openai", "")),
    )
    parser.add_argument("--models", default=DEFAULT_MODELS, help="Comma-separated model list.")
    parser.add_argument("--input-path", default=DEFAULT_INPUT_PATH)
    parser.add_argument("--output-dir", default=DEFAULT_OUTPUT_DIR)
    parser.add_argument("--prompt-low-path", default=DEFAULT_PROMPT_LOW_PATH)
    parser.add_argument(
        "--prompt-intermediate-path",
        default=DEFAULT_PROMPT_INTERMEDIATE_PATH,
    )
    parser.add_argument(
        "--prompt-proficient-path",
        default=DEFAULT_PROMPT_PROFICIENT_PATH,
    )
    parser.add_argument(
        "--max-samples",
        type=int,
        default=-1,
        help="Use -1 for all rows.",
    )
    parser.add_argument("--temperature", type=float, default=0.0)
    parser.add_argument("--timeout-seconds", type=int, default=120)
    parser.add_argument("--max-retries", type=int, default=2)
    parser.add_argument("--retry-wait-seconds", type=float, default=2.0)
    return parser.parse_args()


def check_api_base(api_base: str, api_key: str, timeout_seconds: int) -> None:
    """Fail fast if the OpenAI-compatible endpoint is unreachable.

    Issues a GET against ``{api_base}/models`` and raises ``ConnectionError``
    (chained to the underlying ``URLError``) when the request fails.
    """
    models_url = api_base.rstrip("/") + "/models"
    req = urllib.request.Request(models_url, method="GET")
    if api_key:
        req.add_header("Authorization", f"Bearer {api_key}")
    try:
        with urllib.request.urlopen(req, timeout=timeout_seconds) as resp:
            # Defensive: urlopen normally raises HTTPError for 4xx/5xx, so
            # this branch is only hit by non-raising opener implementations.
            if resp.status >= 400:
                raise RuntimeError(
                    f"Endpoint reachable but unhealthy: {models_url} (status={resp.status})"
                )
    except urllib.error.URLError as exc:
        # HTTPError subclasses URLError, so auth failures (401/403) also
        # land here and surface the "check API key" hint below.
        raise ConnectionError(
            "Cannot reach OpenAI-compatible endpoint. "
            f"api_base={api_base}. Check network/API base/API key."
        ) from exc


def load_prompt_templates(args: argparse.Namespace) -> Dict[str, str]:
    """Read the three literacy-level prompt files, keyed by gold label."""
    prompt_path_by_label = {
        "low_health_literacy": args.prompt_low_path,
        "intermediate_health_literacy": args.prompt_intermediate_path,
        "proficient_health_literacy": args.prompt_proficient_path,
    }
    templates: Dict[str, str] = {}
    for label, path in prompt_path_by_label.items():
        if not os.path.exists(path):
            raise FileNotFoundError(f"Prompt file not found: {path}")
        with open(path, "r", encoding="utf-8") as handle:
            templates[label] = handle.read()
    return templates


def infer_source_lang(fulltext: str) -> str:
    """Heuristic language tag: any ASCII letter => "English", else "Unknown"."""
    if fulltext and any("a" <= ch.lower() <= "z" for ch in fulltext):
        return "English"
    return "Unknown"


def build_prompt(template: str, fulltext: str, summary: str, source_lang: str) -> str:
    """Fill the template's {source_lang}/{gold_summary}/{full_text} slots."""
    return (
        template.replace("{source_lang}", source_lang)
        .replace("{gold_summary}", summary)
        .replace("{full_text}", fulltext)
    )


def load_verified_rows(path: str) -> List[Dict[str, Any]]:
    """Load the input JSON array, keeping only dict-shaped rows."""
    if not os.path.exists(path):
        raise FileNotFoundError(f"Input file not found: {path}")
    with open(path, "r", encoding="utf-8") as handle:
        parsed = json.load(handle)
    if not isinstance(parsed, list):
        raise ValueError(f"Expected top-level JSON array in {path}")
    return [row for row in parsed if isinstance(row, dict)]


def parse_models(models_arg: str) -> List[str]:
    """Split the comma-separated --models value; reject an empty list."""
    models = [m.strip() for m in models_arg.split(",") if m.strip()]
    if not models:
        raise ValueError("No models provided. Example: --models gpt-5-mini,gpt-5-nano")
    return models


def _clean_json_block(text: str) -> str:
    """Strip a surrounding ```json ... ``` (or bare ```) fence, if present."""
    cleaned = text.strip()
    if "```json" in cleaned:
        cleaned = cleaned.split("```json", 1)[1].split("```", 1)[0].strip()
    elif "```" in cleaned:
        cleaned = cleaned.split("```", 1)[1].split("```", 1)[0].strip()
    return cleaned


def extract_generated_text(raw_response: str, expected_label: str) -> str:
    """Pull the text for ``expected_label`` out of a JSON model response.

    Falls back to the raw (stripped) response when the response is not JSON
    or does not contain a non-empty string under the expected key.
    """
    cleaned = _clean_json_block(raw_response)
    try:
        parsed = json.loads(cleaned)
    except json.JSONDecodeError:
        return raw_response.strip()
    if isinstance(parsed, dict):
        value = parsed.get(expected_label)
        if isinstance(value, str) and value.strip():
            return value.strip()
    return raw_response.strip()


def call_chat_completion(
    *,
    api_base: str,
    api_key: str,
    model: str,
    prompt: str,
    temperature: float,
    timeout_seconds: int,
    max_retries: int,
    retry_wait_seconds: float,
) -> str:
    """POST one chat-completion request and return the message content.

    Retries up to ``max_retries`` additional times on retriable HTTP status
    codes, network errors, and malformed response bodies, sleeping
    ``retry_wait_seconds`` between attempts; re-raises the final failure.

    NOTE(review): ``temperature`` is accepted but deliberately(?) NOT sent in
    the payload — presumably because gpt-5 models only accept the default.
    Confirm, and either include it in the payload or drop the CLI flag.
    """
    url = api_base.rstrip("/") + "/chat/completions"
    payload = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
    }
    data = json.dumps(payload).encode("utf-8")
    last_error: Optional[Exception] = None
    for attempt in range(max_retries + 1):
        req = urllib.request.Request(url, data=data, method="POST")
        req.add_header("Content-Type", "application/json")
        if api_key:
            req.add_header("Authorization", f"Bearer {api_key}")
        try:
            with urllib.request.urlopen(req, timeout=timeout_seconds) as resp:
                body = resp.read().decode("utf-8")
            parsed = json.loads(body)
            return str(parsed["choices"][0]["message"]["content"]).strip()
        except urllib.error.HTTPError as exc:
            # Retry only transient statuses; client errors like 400/401 raise
            # immediately since retrying cannot help.
            retriable = exc.code in (408, 409, 429, 500, 502, 503, 504)
            last_error = exc
            if attempt < max_retries and retriable:
                time.sleep(retry_wait_seconds)
                continue
            raise
        except (urllib.error.URLError, KeyError, IndexError, json.JSONDecodeError) as exc:
            last_error = exc
            if attempt < max_retries:
                time.sleep(retry_wait_seconds)
                continue
            raise
    if last_error:
        raise last_error
    raise RuntimeError("Unknown error during chat completion call.")


def _prepare_items(
    rows: List[Dict[str, Any]], templates: Dict[str, str]
) -> List[Dict[str, Any]]:
    """Filter dataset rows and convert them into prompt-ready work items.

    Rows with an unknown label or an empty fulltext/summary are skipped;
    ``row_index`` preserves the position in the original input list.
    """
    items: List[Dict[str, Any]] = []
    for idx, row in enumerate(rows):
        gold_label = str(row.get("label", "")).strip()
        fulltext = str(row.get("fulltext", "")).strip()
        summary = str(row.get("summary", "")).strip()
        if gold_label not in VALID_LABELS:
            continue
        if not fulltext or not summary:
            continue
        source_lang = infer_source_lang(fulltext)
        prompt = build_prompt(
            template=templates[gold_label],
            fulltext=fulltext,
            summary=summary,
            source_lang=source_lang,
        )
        items.append(
            {
                "row_index": idx,
                "doc_id": row.get("doc_id"),
                "gold_label": gold_label,
                "source_lang": source_lang,
                "prompt": prompt,
            }
        )
    return items


def _run_model(
    args: argparse.Namespace,
    model: str,
    parsed_items: List[Dict[str, Any]],
    ts: str,
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
    """Run one model over all items, streaming records to its JSONL file.

    Returns the list of per-item records and a stats dict for the summary.
    Per-item failures are captured in the record's "error" field so one bad
    row cannot abort the whole batch.
    """
    model_slug = model.replace("/", "_")
    model_output_path = os.path.join(
        args.output_dir, f"gpt5_inference_{model_slug}_{ts}.jsonl"
    )
    success_count = 0
    error_count = 0
    records: List[Dict[str, Any]] = []
    with open(model_output_path, "w", encoding="utf-8") as f_model:
        progress_iter = tqdm(
            parsed_items,
            total=len(parsed_items),
            desc=f"{model}",
            unit="item",
        )
        for item in progress_iter:
            record: Dict[str, Any] = {
                "model": model,
                "row_index": item["row_index"],
                "doc_id": item.get("doc_id"),
                "gold_label": item["gold_label"],
                "source_lang": item["source_lang"],
                "prompt": item["prompt"],
            }
            try:
                raw_response = call_chat_completion(
                    api_base=args.api_base,
                    api_key=args.api_key,
                    model=model,
                    prompt=item["prompt"],
                    temperature=args.temperature,
                    timeout_seconds=args.timeout_seconds,
                    max_retries=args.max_retries,
                    retry_wait_seconds=args.retry_wait_seconds,
                )
                generated_text = extract_generated_text(raw_response, item["gold_label"])
                record["prediction"] = raw_response
                record["generated_text"] = generated_text
                record["error"] = ""
                success_count += 1
            except Exception as exc:  # deliberate: keep the batch running
                record["prediction"] = ""
                record["generated_text"] = ""
                record["error"] = f"{type(exc).__name__}: {exc}"
                error_count += 1
            # Write incrementally so partial results survive a crash.
            f_model.write(json.dumps(record, ensure_ascii=False) + "\n")
            records.append(record)
    stats = {
        "output_path": model_output_path,
        "total_rows": len(parsed_items),
        "success_count": success_count,
        "error_count": error_count,
    }
    return records, stats


def main() -> None:
    """Entry point: validate inputs, run each model, write combined outputs."""
    args = parse_args()
    if not args.api_key:
        raise ValueError("Missing API key. Set OPENAI_API_KEY or pass --api-key.")

    for path in (
        args.prompt_low_path,
        args.prompt_intermediate_path,
        args.prompt_proficient_path,
    ):
        if not os.path.exists(path):
            raise FileNotFoundError(f"Prompt file not found: {path}")

    check_api_base(args.api_base, args.api_key, args.timeout_seconds)

    models = parse_models(args.models)
    templates = load_prompt_templates(args)
    rows = load_verified_rows(args.input_path)

    parsed_items = _prepare_items(rows, templates)
    if args.max_samples > 0:
        parsed_items = parsed_items[: args.max_samples]
    if not parsed_items:
        raise RuntimeError("No valid rows found in input file.")

    ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    os.makedirs(args.output_dir, exist_ok=True)
    summary_path = os.path.join(args.output_dir, f"gpt5_inference_summary_{ts}.json")
    combined_path = os.path.join(args.output_dir, f"gpt5_inference_all_{ts}.jsonl")

    combined_records: List[Dict[str, Any]] = []
    model_stats: Dict[str, Dict[str, Any]] = {}

    for model in models:
        records, stats = _run_model(args, model, parsed_items, ts)
        combined_records.extend(records)
        model_stats[model] = stats
        print(f"[DONE] {model} output: {stats['output_path']}")

    with open(combined_path, "w", encoding="utf-8") as f_all:
        for record in combined_records:
            f_all.write(json.dumps(record, ensure_ascii=False) + "\n")

    summary_obj = {
        "input_path": args.input_path,
        "api_base": args.api_base,
        "models": models,
        "max_samples": args.max_samples,
        # NOTE(review): recorded for provenance, but temperature is not
        # currently sent to the API (see call_chat_completion).
        "temperature": args.temperature,
        "total_dataset_rows_used": len(parsed_items),
        "combined_output_path": combined_path,
        "model_stats": model_stats,
    }
    with open(summary_path, "w", encoding="utf-8") as f_summary:
        json.dump(summary_obj, f_summary, ensure_ascii=False, indent=2)

    print(f"[DONE] Combined output: {combined_path}")
    print(f"[DONE] Summary output: {summary_path}")


if __name__ == "__main__":
    main()