| import argparse |
| import json |
| import os |
| import time |
| import urllib.error |
| import urllib.request |
| from datetime import datetime |
| from typing import Any, Dict, List, Optional |
|
|
| from tqdm import tqdm |
|
|
|
|
# Load provider API keys from a local JSON file at import time.
# NOTE(review): hard-coded absolute path; the whole script fails at import if
# this file is missing, even when OPENAI_API_KEY is set — TODO confirm whether
# this load should be deferred/optional.
api_file = "/home/mshahidul/api_new.json"
with open(api_file, "r", encoding="utf-8") as f:
    api_keys = json.load(f)
|
|
# Default endpoint, dataset, output, and prompt-template locations. The
# prompt files live in the "wo_gs" (without gold summary) Bangla variant,
# one template per health-literacy level.
DEFAULT_API_BASE = "https://api.openai.com/v1"
DEFAULT_INPUT_PATH = (
    "/home/mshahidul/readctrl/code/fine_tune_sft_dpo/dataset/bn/test_bn.json"
)
DEFAULT_OUTPUT_DIR = "/home/mshahidul/readctrl/code/fine_tune_sft_dpo/results/bn"
DEFAULT_PROMPT_LOW_PATH = (
    "/home/mshahidul/readctrl/code/fine_tune_sft_dpo/prompt_bn_wo_gs/prompt_low"
)
DEFAULT_PROMPT_INTERMEDIATE_PATH = (
    "/home/mshahidul/readctrl/code/fine_tune_sft_dpo/prompt_bn_wo_gs/prompt_intermediate"
)
DEFAULT_PROMPT_PROFICIENT_PATH = (
    "/home/mshahidul/readctrl/code/fine_tune_sft_dpo/prompt_bn_wo_gs/prompt_proficient"
)
# Comma-separated default model list; parsed by parse_models().
DEFAULT_MODELS = "gpt-5,gpt-5-mini,gpt-5-nano"
# USD budget at which the run stops early (data is still saved).
DEFAULT_COST_LIMIT = 50.0


# USD prices per 1M tokens, split into fresh input / cached input / output.
# NOTE(review): snapshot of OpenAI pricing — verify against the current
# pricing page before long runs; unknown models are billed as $0 by
# compute_cost().
PRICING_PER_1M = {
    "gpt-5": {"input": 1.25, "cached_input": 0.125, "output": 10.00},
    "gpt-5-mini": {"input": 0.25, "cached_input": 0.025, "output": 2.00},
    "gpt-5-nano": {"input": 0.05, "cached_input": 0.005, "output": 0.40},
}


# Accepted values of the dataset's "label" field; rows with any other label
# are skipped by main().
VALID_LABELS = {
    "low_health_literacy",
    "intermediate_health_literacy",
    "proficient_health_literacy",
}
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Define and parse the command-line options for this inference script."""
    cli = argparse.ArgumentParser(
        description=(
            "Generate outputs with gpt-5-mini and gpt-5-nano using "
            "verified_combined dataset and literacy-level prompts."
        )
    )
    # Credentials/endpoint: environment variables win over file-loaded keys.
    cli.add_argument(
        "--api-base", default=os.environ.get("OPENAI_API_BASE", DEFAULT_API_BASE)
    )
    cli.add_argument(
        "--api-key",
        default=os.environ.get("OPENAI_API_KEY", api_keys["openai"]),
    )
    cli.add_argument("--models", default=DEFAULT_MODELS, help="Comma-separated model list.")
    # Dataset/prompt/output locations all share the same plain-string shape.
    for flag, default in (
        ("--input-path", DEFAULT_INPUT_PATH),
        ("--output-dir", DEFAULT_OUTPUT_DIR),
        ("--prompt-low-path", DEFAULT_PROMPT_LOW_PATH),
        ("--prompt-intermediate-path", DEFAULT_PROMPT_INTERMEDIATE_PATH),
        ("--prompt-proficient-path", DEFAULT_PROMPT_PROFICIENT_PATH),
    ):
        cli.add_argument(flag, default=default)
    cli.add_argument(
        "--max-samples",
        type=int,
        default=-1,
        help="Use -1 for all rows.",
    )
    cli.add_argument("--temperature", type=float, default=0.0)
    cli.add_argument("--timeout-seconds", type=int, default=120)
    cli.add_argument("--max-retries", type=int, default=2)
    cli.add_argument("--retry-wait-seconds", type=float, default=2.0)
    cli.add_argument(
        "--cost-limit",
        type=float,
        default=DEFAULT_COST_LIMIT,
        help="Stop and save when cumulative API cost exceeds this amount in USD.",
    )
    return cli.parse_args()
|
|
|
|
def check_api_base(api_base: str, api_key: str, timeout_seconds: int) -> None:
    """Fail fast if the OpenAI-compatible endpoint is unreachable or unhealthy.

    Performs a GET on ``{api_base}/models`` with the bearer token (if any).

    Raises:
        RuntimeError: the endpoint answered with an HTTP error status
            (e.g. 401 for a bad API key, 5xx for a server problem).
        ConnectionError: the endpoint could not be reached at all
            (DNS/network/TLS failure).
    """
    models_url = api_base.rstrip("/") + "/models"
    req = urllib.request.Request(models_url, method="GET")
    if api_key:
        req.add_header("Authorization", f"Bearer {api_key}")
    try:
        with urllib.request.urlopen(req, timeout=timeout_seconds) as resp:
            # Defensive: urlopen normally raises HTTPError for >=400, so this
            # branch only fires for exotic handler configurations.
            if resp.status >= 400:
                raise RuntimeError(
                    f"Endpoint reachable but unhealthy: {models_url} (status={resp.status})"
                )
    except urllib.error.HTTPError as exc:
        # Bug fix: HTTPError is a subclass of URLError, so HTTP failures such
        # as 401 (bad API key) used to be misreported as "cannot reach
        # endpoint". Surface the actual status instead.
        raise RuntimeError(
            f"Endpoint reachable but unhealthy: {models_url} (status={exc.code})"
        ) from exc
    except urllib.error.URLError as exc:
        raise ConnectionError(
            "Cannot reach OpenAI-compatible endpoint. "
            f"api_base={api_base}. Check network/API base/API key."
        ) from exc
|
|
|
|
def load_prompt_templates(args: argparse.Namespace) -> Dict[str, str]:
    """Read the three literacy-level prompt files into a label-keyed dict.

    Raises FileNotFoundError if any configured prompt file is missing.
    """
    templates: Dict[str, str] = {}
    label_paths = (
        ("low_health_literacy", args.prompt_low_path),
        ("intermediate_health_literacy", args.prompt_intermediate_path),
        ("proficient_health_literacy", args.prompt_proficient_path),
    )
    for label, prompt_path in label_paths:
        if not os.path.exists(prompt_path):
            raise FileNotFoundError(f"Prompt file not found: {prompt_path}")
        with open(prompt_path, "r", encoding="utf-8") as handle:
            templates[label] = handle.read()
    return templates
|
|
|
|
def infer_source_lang(fulltext: str) -> str:
    """Heuristically classify the text's language.

    Returns "Bangla" if any character falls in the Bengali Unicode block
    (U+0980-U+09FF), otherwise "English" if any ASCII letter appears,
    otherwise "Unknown" (also for empty input).
    """
    if fulltext:
        for ch in fulltext:
            if "\u0980" <= ch <= "\u09FF":
                return "Bangla"
        for ch in fulltext:
            if "a" <= ch.lower() <= "z":
                return "English"
    return "Unknown"
|
|
|
|
def build_prompt(template: str, fulltext: str, summary: str, source_lang: str) -> str:
    """Fill a prompt template's placeholders with one row's data.

    Replaces {source_lang}, {gold_summary}, and {full_text} in that order.
    """
    filled = template.replace("{source_lang}", source_lang)
    filled = filled.replace("{gold_summary}", summary)
    return filled.replace("{full_text}", fulltext)
|
|
|
|
def load_verified_rows(path: str) -> List[Dict[str, Any]]:
    """Load the dataset JSON file and keep only its dict entries.

    Raises FileNotFoundError when *path* does not exist and ValueError when
    the file's top-level value is not a JSON array.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"Input file not found: {path}")
    with open(path, "r", encoding="utf-8") as handle:
        payload = json.load(handle)
    if not isinstance(payload, list):
        raise ValueError(f"Expected top-level JSON array in {path}")
    return [entry for entry in payload if isinstance(entry, dict)]
|
|
|
|
def parse_models(models_arg: str) -> List[str]:
    """Split a comma-separated model string into a non-empty name list.

    Whitespace around names is stripped and empty segments are dropped;
    raises ValueError when no model names remain.
    """
    models: List[str] = []
    for piece in models_arg.split(","):
        name = piece.strip()
        if name:
            models.append(name)
    if not models:
        raise ValueError("No models provided. Example: --models gpt-5-mini,gpt-5-nano")
    return models
|
|
|
|
def _clean_json_block(text: str) -> str:
    """Strip a Markdown code fence (```json ... ``` or ``` ... ```) if present.

    Returns the fenced interior, stripped, preferring a ```json fence over a
    plain one; unfenced input is simply stripped.
    """
    cleaned = text.strip()
    for fence in ("```json", "```"):
        if fence in cleaned:
            cleaned = cleaned.split(fence, 1)[1].split("```", 1)[0].strip()
            break
    return cleaned
|
|
|
|
def extract_generated_text(raw_response: str, expected_label: str) -> str:
    """Pull the string stored under *expected_label* out of a JSON response.

    The response may be wrapped in a Markdown code fence, which is unwrapped
    first. When the payload is not valid JSON, is not an object, or lacks a
    non-empty string at the expected key, the stripped raw response is
    returned unchanged as a fallback.
    """
    # Unwrap an optional code fence (```json preferred over a bare ```).
    candidate = raw_response.strip()
    if "```json" in candidate:
        candidate = candidate.split("```json", 1)[1].split("```", 1)[0].strip()
    elif "```" in candidate:
        candidate = candidate.split("```", 1)[1].split("```", 1)[0].strip()

    try:
        payload = json.loads(candidate)
    except json.JSONDecodeError:
        return raw_response.strip()

    if isinstance(payload, dict):
        text = payload.get(expected_label)
        if isinstance(text, str) and text.strip():
            return text.strip()
    return raw_response.strip()
|
|
|
|
def compute_cost(model: str, input_tokens: int, output_tokens: int,
                 cached_input_tokens: int = 0) -> float:
    """Return the USD cost of one API call per the PRICING_PER_1M table.

    Cached input tokens are billed at the discounted cached-input rate and
    the remaining input tokens at the full input rate. Models missing from
    the pricing table cost 0.0.
    """
    rates = PRICING_PER_1M.get(model)
    if rates is None:
        return 0.0
    # Tokens served from the prompt cache are subtracted from the full-rate
    # input count (clamped at zero for safety).
    fresh_input = max(0, input_tokens - cached_input_tokens)
    return (
        fresh_input * rates["input"] / 1_000_000
        + cached_input_tokens * rates["cached_input"] / 1_000_000
        + output_tokens * rates["output"] / 1_000_000
    )
|
|
|
|
def call_chat_completion(
    *,
    api_base: str,
    api_key: str,
    model: str,
    prompt: str,
    temperature: float,
    timeout_seconds: int,
    max_retries: int,
    retry_wait_seconds: float,
) -> Dict[str, Any]:
    """POST one single-turn chat completion and return its content and usage.

    Returns a dict with keys "content", "prompt_tokens", "completion_tokens",
    and "cached_tokens" (cached prompt tokens from usage details, 0 when
    absent). Makes up to ``max_retries`` extra attempts, sleeping
    ``retry_wait_seconds`` between them, and re-raises the final error when
    retries are exhausted.

    NOTE(review): ``temperature`` is accepted but never placed in the request
    payload — confirm this is intentional (the gpt-5 family rejects
    non-default temperature values).
    """
    url = api_base.rstrip("/") + "/chat/completions"
    payload = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
    }
    # Encode once; the same bytes are re-sent on every retry attempt.
    data = json.dumps(payload).encode("utf-8")

    last_error: Optional[Exception] = None
    for attempt in range(max_retries + 1):
        req = urllib.request.Request(url, data=data, method="POST")
        req.add_header("Content-Type", "application/json")
        if api_key:
            req.add_header("Authorization", f"Bearer {api_key}")
        try:
            with urllib.request.urlopen(req, timeout=timeout_seconds) as resp:
                body = resp.read().decode("utf-8")
                parsed = json.loads(body)
                content = str(parsed["choices"][0]["message"]["content"]).strip()
                usage = parsed.get("usage", {})
                return {
                    "content": content,
                    "prompt_tokens": usage.get("prompt_tokens", 0),
                    "completion_tokens": usage.get("completion_tokens", 0),
                    "cached_tokens": usage.get("prompt_tokens_details", {}).get(
                        "cached_tokens", 0
                    ),
                }
        # HTTPError must be caught before URLError (it is a subclass):
        # only transient HTTP statuses are retried, everything else re-raises.
        except urllib.error.HTTPError as exc:
            retriable = exc.code in (408, 409, 429, 500, 502, 503, 504)
            last_error = exc
            if attempt < max_retries and retriable:
                time.sleep(retry_wait_seconds)
                continue
            raise
        # Network failures and malformed/unexpected response bodies are all
        # treated as retriable until attempts run out.
        except (urllib.error.URLError, KeyError, IndexError, json.JSONDecodeError) as exc:
            last_error = exc
            if attempt < max_retries:
                time.sleep(retry_wait_seconds)
                continue
            raise

    # Defensive tail: the loop either returns or re-raises, so this is
    # effectively unreachable; kept as a safety net.
    if last_error:
        raise last_error
    raise RuntimeError("Unknown error during chat completion call.")
|
|
|
|
def main() -> None:
    """Run chat-completion inference for each configured model over the dataset.

    Pipeline: parse CLI args, verify prompt files and endpoint health, build
    one prompt per valid dataset row, then for every model call the API row
    by row — writing a per-model JSONL as it goes — and finally write a
    combined JSONL plus a JSON run summary. The run stops early (still
    saving everything collected so far) once cumulative cost reaches
    --cost-limit.
    """
    args = parse_args()
    if not args.api_key:
        raise ValueError("Missing API key. Set OPENAI_API_KEY or pass --api-key.")

    # Fail fast if any prompt template file is missing.
    for path in (
        args.prompt_low_path,
        args.prompt_intermediate_path,
        args.prompt_proficient_path,
    ):
        if not os.path.exists(path):
            raise FileNotFoundError(f"Prompt file not found: {path}")

    check_api_base(args.api_base, args.api_key, args.timeout_seconds)
    models = parse_models(args.models)
    templates = load_prompt_templates(args)
    rows = load_verified_rows(args.input_path)

    # Pre-build one prompt per usable row; rows with an unrecognized label or
    # an empty fulltext/summary are silently skipped.
    parsed_items: List[Dict[str, Any]] = []
    for idx, row in enumerate(rows):
        gold_label = str(row.get("label", "")).strip()
        fulltext = str(row.get("fulltext", "")).strip()
        summary = str(row.get("summary", "")).strip()
        if gold_label not in VALID_LABELS:
            continue
        if not fulltext or not summary:
            continue
        source_lang = infer_source_lang(fulltext)
        prompt = build_prompt(
            template=templates[gold_label],
            fulltext=fulltext,
            summary=summary,
            source_lang=source_lang,
        )
        parsed_items.append(
            {
                "row_index": idx,
                "doc_id": row.get("doc_id"),
                "gold_label": gold_label,
                "source_lang": source_lang,
                "prompt": prompt,
            }
        )

    # Optional truncation for smoke tests; -1 (the default) keeps all rows.
    if args.max_samples > 0:
        parsed_items = parsed_items[: args.max_samples]
    if not parsed_items:
        raise RuntimeError("No valid rows found in input file.")

    # Timestamped output names so repeated runs never clobber each other.
    ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    os.makedirs(args.output_dir, exist_ok=True)
    summary_path = os.path.join(args.output_dir, f"gpt5_inference_summary_wo_gs_{ts}.json")
    combined_path = os.path.join(args.output_dir, f"gpt5_inference_all_wo_gs_{ts}.jsonl")

    # Accumulators shared (via closure) with _save_outputs below.
    combined_records: List[Dict[str, Any]] = []
    model_stats: Dict[str, Dict[str, Any]] = {}
    total_cost = 0.0
    total_input_tokens = 0
    total_output_tokens = 0
    budget_exceeded = False

    def _save_outputs() -> None:
        # Persist the combined JSONL and the run summary. Called once after
        # the model loop; the early budget stop reaches it via the loop's
        # normal fallthrough.
        with open(combined_path, "w", encoding="utf-8") as f_all:
            for rec in combined_records:
                f_all.write(json.dumps(rec, ensure_ascii=False) + "\n")
        summary_obj = {
            "input_path": args.input_path,
            "api_base": args.api_base,
            "models": models,
            "max_samples": args.max_samples,
            "temperature": args.temperature,
            "cost_limit_usd": args.cost_limit,
            "total_cost_usd": round(total_cost, 6),
            "total_input_tokens": total_input_tokens,
            "total_output_tokens": total_output_tokens,
            "budget_exceeded": budget_exceeded,
            "total_dataset_rows_used": len(parsed_items),
            "combined_output_path": combined_path,
            "model_stats": model_stats,
        }
        with open(summary_path, "w", encoding="utf-8") as f_summary:
            json.dump(summary_obj, f_summary, ensure_ascii=False, indent=2)

    for model in models:
        # A budget stop inside one model's row loop also skips all later models.
        if budget_exceeded:
            break

        model_slug = model.replace("/", "_")
        model_output_path = os.path.join(
            args.output_dir, f"gpt5_inference_{model_slug}_{ts}.jsonl"
        )
        success_count = 0
        error_count = 0
        model_cost = 0.0
        model_input_tokens = 0
        model_output_tokens = 0

        # Stream records to the per-model JSONL as each call completes so a
        # crash loses at most the in-flight row.
        with open(model_output_path, "w", encoding="utf-8") as f_model:
            total = len(parsed_items)
            progress_iter = tqdm(
                parsed_items,
                total=total,
                desc=f"{model}",
                unit="item",
            )
            for item in progress_iter:

                record: Dict[str, Any] = {
                    "model": model,
                    "row_index": item["row_index"],
                    "doc_id": item.get("doc_id"),
                    "gold_label": item["gold_label"],
                    "source_lang": item["source_lang"],
                    "prompt": item["prompt"],
                }
                try:
                    result = call_chat_completion(
                        api_base=args.api_base,
                        api_key=args.api_key,
                        model=model,
                        prompt=item["prompt"],
                        temperature=args.temperature,
                        timeout_seconds=args.timeout_seconds,
                        max_retries=args.max_retries,
                        retry_wait_seconds=args.retry_wait_seconds,
                    )
                    raw_response = result["content"]
                    p_tokens = result["prompt_tokens"]
                    c_tokens = result["completion_tokens"]
                    cached = result["cached_tokens"]

                    call_cost = compute_cost(model, p_tokens, c_tokens, cached)
                    total_cost += call_cost
                    model_cost += call_cost
                    total_input_tokens += p_tokens
                    total_output_tokens += c_tokens
                    model_input_tokens += p_tokens
                    model_output_tokens += c_tokens

                    generated_text = extract_generated_text(raw_response, item["gold_label"])
                    record["prediction"] = raw_response
                    record["generated_text"] = generated_text
                    record["error"] = ""
                    record["prompt_tokens"] = p_tokens
                    record["completion_tokens"] = c_tokens
                    record["call_cost_usd"] = round(call_cost, 6)
                    success_count += 1
                except Exception as exc:
                    # Per-row failures (after call_chat_completion's own
                    # retries) are recorded and the run continues.
                    record["prediction"] = ""
                    record["generated_text"] = ""
                    record["error"] = f"{type(exc).__name__}: {exc}"
                    record["prompt_tokens"] = 0
                    record["completion_tokens"] = 0
                    record["call_cost_usd"] = 0.0
                    error_count += 1

                f_model.write(json.dumps(record, ensure_ascii=False) + "\n")
                combined_records.append(record)

                progress_iter.set_postfix(
                    cost=f"${total_cost:.4f}",
                    limit=f"${args.cost_limit:.2f}",
                )

                # Budget check after every call; stats and outputs for work
                # done so far are still written below.
                if total_cost >= args.cost_limit:
                    budget_exceeded = True
                    print(
                        f"\n[BUDGET] Cost ${total_cost:.4f} reached limit "
                        f"${args.cost_limit:.2f}. Saving data and stopping."
                    )
                    break

        model_stats[model] = {
            "output_path": model_output_path,
            "total_rows": len(parsed_items),
            "rows_processed": success_count + error_count,
            "success_count": success_count,
            "error_count": error_count,
            "model_cost_usd": round(model_cost, 6),
            "model_input_tokens": model_input_tokens,
            "model_output_tokens": model_output_tokens,
        }
        print(
            f"[DONE] {model} | cost: ${model_cost:.4f} | "
            f"output: {model_output_path}"
        )

    _save_outputs()

    print(f"\n[COST] Total API cost: ${total_cost:.4f} / ${args.cost_limit:.2f} limit")
    print(f"[COST] Total tokens — input: {total_input_tokens}, output: {total_output_tokens}")
    if budget_exceeded:
        print("[COST] Budget exceeded — run stopped early. All data saved.")
    print(f"[DONE] Combined output: {combined_path}")
    print(f"[DONE] Summary output: {summary_path}")
|
|
|
# Standard script entry guard: run only when executed directly.
if __name__ == "__main__":
    main()
|
|