| import argparse | |
| import json | |
| import os | |
| import re | |
| from datetime import datetime | |
| from typing import Any, Dict, List, Optional | |
| import pandas as pd | |
| import requests | |
| from tqdm import tqdm | |
| from transformers import AutoTokenizer | |
# ---------------------------------------------------------------------------
# Defaults for the command-line options declared in parse_args().
# ---------------------------------------------------------------------------

# Tokenizer / chat-template source (default of --model_path).
DEFAULT_MODEL_PATH = "Qwen/Qwen3-4B-Instruct-2507"

# Input JSON file holding the rows to run inference on (default of --dataset_path).
DEFAULT_DATASET_PATH = "/home/mshahidul/readctrl/code/text_classifier/data/verified_combined_0-80_clean200.json"

# Directory that receives the JSONL / Parquet / metadata outputs.
DEFAULT_OUTPUT_DIR = "/home/mshahidul/readctrl/code/RL_model/inference_data"

# Base URL of the OpenAI-compatible server and the model name it serves.
DEFAULT_BASE_URL = "http://127.0.0.1:8001/v1"
DEFAULT_SERVED_MODEL_NAME = "inference"

# One prompt-template file per label.
DEFAULT_PROMPT_LOW_PATH = "/home/mshahidul/readctrl/code/RL_model/verl/verl_train/dataset/prompt_low"
DEFAULT_PROMPT_INTERMEDIATE_PATH = "/home/mshahidul/readctrl/code/RL_model/verl/verl_train/dataset/prompt_intermediate"
DEFAULT_PROMPT_PROFICIENT_PATH = "/home/mshahidul/readctrl/code/RL_model/verl/verl_train/dataset/prompt_proficient"

# Labels a dataset row must carry to be processed; rows with any other label
# are skipped by main().
VALID_LABELS = {
    "low_health_literacy",
    "intermediate_health_literacy",
    "proficient_health_literacy",
}
def parse_args() -> argparse.Namespace:
    """Parse the command-line options for a batched inference run.

    Returns:
        The parsed namespace. Attributes: ``model_path``, ``dataset_path``,
        ``prompt_low_path``, ``prompt_intermediate_path``,
        ``prompt_proficient_path``, ``output_dir``, ``base_url``,
        ``served_model_name``, ``batch_size``, ``max_samples``,
        ``max_tokens``, ``temperature``, ``top_p``, ``api_key`` and
        ``timeout_sec``.

    Exits:
        With argparse's usual usage error (status 2) when ``--batch_size`` is
        smaller than 1.
    """
    parser = argparse.ArgumentParser(description="Run batched inference via vLLM OpenAI-compatible server.")
    parser.add_argument("--model_path", type=str, default=DEFAULT_MODEL_PATH, help="Local path for tokenizer/chat template.")
    parser.add_argument("--dataset_path", type=str, default=DEFAULT_DATASET_PATH)
    # The prompt-path flags were declared with hyphens while every other flag
    # uses underscores. Both spellings are accepted; argparse derives the
    # attribute name from the first option string, so the dest is unchanged.
    parser.add_argument("--prompt-low-path", "--prompt_low_path", type=str, default=DEFAULT_PROMPT_LOW_PATH)
    parser.add_argument(
        "--prompt-intermediate-path",
        "--prompt_intermediate_path",
        type=str,
        default=DEFAULT_PROMPT_INTERMEDIATE_PATH,
    )
    parser.add_argument(
        "--prompt-proficient-path",
        "--prompt_proficient_path",
        type=str,
        default=DEFAULT_PROMPT_PROFICIENT_PATH,
    )
    parser.add_argument("--output_dir", type=str, default=DEFAULT_OUTPUT_DIR)
    parser.add_argument("--base_url", type=str, default=DEFAULT_BASE_URL, help="vLLM OpenAI base URL, e.g. http://127.0.0.1:8000/v1")
    parser.add_argument("--served_model_name", type=str, default=DEFAULT_SERVED_MODEL_NAME, help="Model name exposed by vLLM server.")
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--max_samples", type=int, default=-1, help="Use -1 for full dataset.")
    parser.add_argument("--max_tokens", type=int, default=1024)
    parser.add_argument("--temperature", type=float, default=0.7)
    parser.add_argument("--top_p", type=float, default=0.8)
    parser.add_argument("--api_key", type=str, default="EMPTY")
    parser.add_argument("--timeout_sec", type=int, default=300)
    args = parser.parse_args()

    # The batch size is used as a range() step: 0 raises ValueError there and
    # a negative value silently produces no batches at all. Reject both here.
    if args.batch_size < 1:
        parser.error("--batch_size must be a positive integer.")
    return args
def load_prompt_templates(args: argparse.Namespace) -> Dict[str, str]:
    """Read the three prompt template files named in ``args``.

    Args:
        args: Namespace providing ``prompt_low_path``,
            ``prompt_intermediate_path`` and ``prompt_proficient_path``.

    Returns:
        Mapping from label name to the full text of its template file.

    Raises:
        FileNotFoundError: If any of the template paths does not exist.
    """
    label_and_path = (
        ("low_health_literacy", args.prompt_low_path),
        ("intermediate_health_literacy", args.prompt_intermediate_path),
        ("proficient_health_literacy", args.prompt_proficient_path),
    )

    loaded: Dict[str, str] = {}
    for literacy_label, template_path in label_and_path:
        if os.path.exists(template_path):
            with open(template_path, "r", encoding="utf-8") as handle:
                loaded[literacy_label] = handle.read()
            continue
        raise FileNotFoundError(f"Prompt file not found: {template_path}")
    return loaded
def load_verified_rows(path: str) -> List[Dict[str, Any]]:
    """Load the dataset file and keep only its dict entries.

    Args:
        path: Location of a JSON file whose top-level value is an array.

    Returns:
        The array elements that are JSON objects, in file order; elements of
        any other type are dropped.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        ValueError: If the top-level JSON value is not an array.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"Input file not found: {path}")

    with open(path, "r", encoding="utf-8") as handle:
        payload = json.load(handle)

    if isinstance(payload, list):
        return [entry for entry in payload if isinstance(entry, dict)]
    raise ValueError(f"Expected top-level JSON array in {path}")
def infer_source_lang(fulltext: str) -> str:
    """Guess the source language with a minimal heuristic.

    Returns ``"English"`` as soon as one character whose lowercase form lies
    between ``"a"`` and ``"z"`` is found, and ``"Unknown"`` for empty input
    or text without such a character.
    """
    for ch in fulltext or "":
        if "a" <= ch.lower() <= "z":
            return "English"
    return "Unknown"
def build_prompt(template: str, fulltext: str, summary: str, source_lang: str) -> str:
    """Fill the ``{source_lang}``, ``{gold_summary}`` and ``{full_text}`` placeholders.

    The template is scanned exactly once, so placeholder-looking text inside
    the inserted values is left untouched. Chaining ``str.replace`` calls would
    re-scan already inserted text: a summary containing the literal
    ``{full_text}`` would have the whole document spliced into it.

    Args:
        template: Prompt template containing zero or more of the placeholders.
        fulltext: Replacement for ``{full_text}``.
        summary: Replacement for ``{gold_summary}``.
        source_lang: Replacement for ``{source_lang}``.

    Returns:
        The template with every placeholder occurrence substituted; any other
        brace expression is left as is.
    """
    values = {
        "source_lang": source_lang,
        "gold_summary": summary,
        "full_text": fulltext,
    }
    # A callable replacement also keeps backslashes in the values literal;
    # a replacement *string* would have its escapes interpreted by re.sub.
    return re.sub(
        r"\{(source_lang|gold_summary|full_text)\}",
        lambda match: values[match.group(1)],
        template,
    )
def _clean_json_block(text: str) -> str:
    """Return the contents of the first Markdown code fence in ``text``.

    A fence opened with ``json`` (matched case-insensitively) is preferred
    over any other fence. For a generic fence, a language tag on the opening
    line (for example ``javascript``) is removed, because leaving it in place
    makes the payload unparseable as JSON. An unterminated fence extends to
    the end of the text. Text without a fence is returned stripped.
    """
    cleaned = text.strip()

    json_fence = re.search(r"```json\b", cleaned, re.IGNORECASE)
    if json_fence is not None:
        body = cleaned[json_fence.end():]
    elif "```" in cleaned:
        body = cleaned.split("```", 1)[1]
        # Drop a language tag only when it sits directly after the opening
        # backticks and is followed by a line break.
        body = re.sub(r"\A[A-Za-z0-9_+-]+[ \t]*\r?\n", "", body)
    else:
        return cleaned

    return body.split("```", 1)[0].strip()


def extract_generated_text(raw_response: str, expected_label: str) -> str:
    """Extract the generated text for ``expected_label`` from a model response.

    The response is expected to be a JSON object, optionally wrapped in a
    Markdown code fence, with the generated text stored under the label name.

    Args:
        raw_response: Raw completion text returned by the model.
        expected_label: Key to look up in the parsed JSON object.

    Returns:
        The stripped string stored under ``expected_label`` when the response
        parses to a JSON object holding a non-blank string for that key;
        otherwise the stripped raw response.
    """
    fallback = raw_response.strip()
    try:
        parsed = json.loads(_clean_json_block(raw_response))
    except json.JSONDecodeError:
        return fallback

    if not isinstance(parsed, dict):
        return fallback

    value = parsed.get(expected_label)
    if isinstance(value, str) and value.strip():
        return value.strip()
    return fallback
def _normalize_messages(prompt_obj: Any) -> List[Dict[str, str]]:
    """Coerce an arbitrary prompt value into a list of chat messages.

    Handling, in order:
      * an object exposing ``tolist()`` is first converted with it;
      * a dict with ``role`` and ``content`` becomes a single message;
      * any other dict becomes one user message holding its JSON dump;
      * a non-empty list is mapped element by element (role/content dicts
        are kept, everything else becomes a user message);
      * anything else, including an empty list, becomes one user message
        holding ``str(prompt_obj)``.

    Role and content values are always converted with ``str``.
    """

    def _looks_like_message(candidate: Any) -> bool:
        return isinstance(candidate, dict) and "role" in candidate and "content" in candidate

    def _message(role: Any, content: Any) -> Dict[str, str]:
        return {"role": str(role), "content": str(content)}

    if hasattr(prompt_obj, "tolist"):
        prompt_obj = prompt_obj.tolist()

    if isinstance(prompt_obj, dict):
        if _looks_like_message(prompt_obj):
            return [_message(prompt_obj["role"], prompt_obj["content"])]
        return [_message("user", json.dumps(prompt_obj, ensure_ascii=False))]

    if isinstance(prompt_obj, list):
        converted = [
            _message(entry["role"], entry["content"]) if _looks_like_message(entry) else _message("user", entry)
            for entry in prompt_obj
        ]
        if converted:
            return converted

    return [_message("user", prompt_obj)]
def build_prompt_text(tokenizer: AutoTokenizer, prompt_obj: Any) -> str:
    """Render ``prompt_obj`` as the plain-text prompt sent to the server.

    The prompt is first normalized into chat messages. When the tokenizer has
    a chat template, that template is applied (untokenized, with the
    generation prompt appended). Otherwise the message contents are joined
    with newlines and terminated by an ``Assistant:`` cue.
    """
    messages = _normalize_messages(prompt_obj)

    if not tokenizer.chat_template:
        joined = "\n".join(message["content"] for message in messages)
        return joined + "\n\nAssistant:"

    return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
def sanitize_model_tag(model_path: str, max_len: int = 80) -> str:
    """Turn a model path or id into a lowercase, filename-safe tag.

    Every run of characters outside ``A-Za-z0-9`` collapses to one hyphen,
    and leading/trailing hyphens are removed before lowercasing.

    Args:
        model_path: Model path or identifier to convert.
        max_len: Maximum tag length; longer tags are cut and any hyphen left
            dangling at the cut is removed.

    Returns:
        The tag, or ``"unknown-model"`` when nothing alphanumeric remains.
    """
    slug = re.sub(r"[^A-Za-z0-9]+", "-", model_path).strip("-").lower()
    if not slug:
        return "unknown-model"
    return slug if len(slug) <= max_len else slug[:max_len].rstrip("-")
def check_server(base_url: str, headers: Dict[str, str], timeout_sec: int) -> Optional[List[Dict[str, Any]]]:
    """Query the server's ``/models`` endpoint.

    Args:
        base_url: Server base URL; a trailing slash is tolerated.
        headers: HTTP headers sent with the request.
        timeout_sec: Request timeout passed to ``requests``.

    Returns:
        The ``data`` entry of the JSON response, or an empty list when the
        response has no such key.

    Raises:
        requests.HTTPError: Via ``raise_for_status`` on an error status code.
    """
    endpoint = base_url.rstrip("/") + "/models"
    response = requests.get(endpoint, headers=headers, timeout=timeout_sec)
    response.raise_for_status()
    return response.json().get("data", [])
def batched_completion_request(
    base_url: str,
    headers: Dict[str, str],
    model_name: str,
    prompts: List[str],
    max_tokens: int,
    temperature: float,
    top_p: float,
    timeout_sec: int,
) -> List[str]:
    """Send one batched request to ``/completions`` and align the results.

    Args:
        base_url: Server base URL; a trailing slash is tolerated.
        headers: HTTP headers sent with the request.
        model_name: Value of the ``model`` field in the request payload.
        prompts: Prompt strings submitted together as one batch.
        max_tokens: Value of the ``max_tokens`` field.
        temperature: Value of the ``temperature`` field.
        top_p: Value of the ``top_p`` field.
        timeout_sec: Request timeout passed to ``requests``.

    Returns:
        One stripped completion per prompt, in prompt order. Each returned
        choice is placed by its ``index`` field; choices without a usable
        index fill the remaining slots in the order received, and slots that
        still have no completion are returned as ``""``.

    Raises:
        requests.HTTPError: Via ``raise_for_status`` on an error status code.
    """
    payload = {
        "model": model_name,
        "prompt": prompts,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
    }
    url = f"{base_url.rstrip('/')}/completions"
    resp = requests.post(url, headers=headers, json=payload, timeout=timeout_sec)
    resp.raise_for_status()
    choices = resp.json().get("choices") or []

    # None (not "") marks an unfilled slot. Using the empty string for that
    # made a legitimately empty completion look missing, and the positional
    # fallback could then overwrite it with the text of a different prompt
    # whenever the choices were not already sorted by index.
    preds: List[Optional[str]] = [None] * len(prompts)
    unplaced: List[str] = []
    for choice in choices:
        idx = choice.get("index")
        # "or" maps a null text to "" rather than the literal string "None".
        text = str(choice.get("text") or "").strip()
        if isinstance(idx, int) and 0 <= idx < len(preds) and preds[idx] is None:
            preds[idx] = text
        else:
            unplaced.append(text)

    # Choices that could not be placed by index fill the gaps in order.
    leftovers = iter(unplaced)
    return [pred if pred is not None else next(leftovers, "") for pred in preds]
def _field_text(row: Dict[str, Any], key: str) -> str:
    """Return ``row[key]`` as stripped text; a missing or null value gives ``""``."""
    value = row.get(key)
    return "" if value is None else str(value).strip()


def main() -> None:
    """Run the end-to-end batch inference and write the result files.

    Steps: parse the CLI options, query the server's model list, load the
    tokenizer, prompt templates and dataset, build one prompt per usable row,
    request completions batch by batch, and write a JSONL file (appended to
    as batches complete), a Parquet file and a metadata JSON file into the
    output directory.

    Raises:
        ValueError: If the batch size is not a positive integer.
        RuntimeError: If the dataset yields no usable row.
    """
    args = parse_args()
    # range() below rejects a zero step and a negative step silently yields
    # no batches, so fail before any network or model work is done.
    if args.batch_size < 1:
        raise ValueError("--batch_size must be a positive integer.")

    os.makedirs(args.output_dir, exist_ok=True)
    run_ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    model_tag = sanitize_model_tag(args.model_path)
    output_jsonl = os.path.join(args.output_dir, f"vllm_inference_{model_tag}_{run_ts}.jsonl")
    output_parquet = os.path.join(args.output_dir, f"vllm_inference_{model_tag}_{run_ts}.parquet")
    meta_path = os.path.join(args.output_dir, f"vllm_inference_{model_tag}_{run_ts}_meta.json")

    headers = {
        "Authorization": f"Bearer {args.api_key}",
        "Content-Type": "application/json",
    }

    print(f"[INFO] Checking vLLM server: {args.base_url}")
    models = check_server(args.base_url, headers=headers, timeout_sec=args.timeout_sec)
    available_model_ids = [m.get("id", "") for m in models or []]
    print(f"[INFO] Server models: {available_model_ids}")
    if args.served_model_name not in available_model_ids:
        print(
            f"[WARN] Served model '{args.served_model_name}' not found in /models. "
            "Will still try requests with provided name."
        )

    print(f"[INFO] Loading tokenizer from: {args.model_path}")
    tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)

    print(f"[INFO] Reading dataset: {args.dataset_path}")
    templates = load_prompt_templates(args)
    rows = load_verified_rows(args.dataset_path)

    parsed_items: List[Dict[str, Any]] = []
    skipped = 0
    for idx, row in enumerate(rows):
        # JSON null must count as "missing"; str(None) would yield the text
        # "None", which passes the emptiness checks below.
        gold_label = _field_text(row, "label")
        fulltext = _field_text(row, "fulltext")
        summary = _field_text(row, "summary")
        if gold_label not in VALID_LABELS or not fulltext or not summary:
            skipped += 1
            continue

        source_lang = infer_source_lang(fulltext)
        prompt = build_prompt(
            template=templates[gold_label],
            fulltext=fulltext,
            summary=summary,
            source_lang=source_lang,
        )
        parsed_items.append(
            {
                "row_index": idx,
                "doc_id": row.get("doc_id"),
                "gold_label": gold_label,
                "source_lang": source_lang,
                "prompt": prompt,
            }
        )

    # Items stay plain dicts instead of being round-tripped through a
    # DataFrame, so values such as doc_id reach the JSONL exactly as loaded
    # (a DataFrame column mixing ints and missing values becomes float/NaN).
    if args.max_samples > 0:
        parsed_items = parsed_items[: args.max_samples]

    print(f"[INFO] Rows to process: {len(parsed_items)}")
    if skipped:
        print(f"[INFO] Skipped {skipped} rows with an unknown label or missing fulltext/summary.")
    if not parsed_items:
        raise RuntimeError("No valid rows found in input file.")

    outputs: List[Dict[str, Any]] = []
    with open(output_jsonl, "w", encoding="utf-8") as f_out:
        for start in tqdm(range(0, len(parsed_items), args.batch_size), desc="Batches"):
            batch = parsed_items[start : start + args.batch_size]
            prompts = [build_prompt_text(tokenizer, item["prompt"]) for item in batch]
            preds = batched_completion_request(
                base_url=args.base_url,
                headers=headers,
                model_name=args.served_model_name,
                prompts=prompts,
                max_tokens=args.max_tokens,
                temperature=args.temperature,
                top_p=args.top_p,
                timeout_sec=args.timeout_sec,
            )
            for item, pred in zip(batch, preds):
                gold_label = item["gold_label"]
                record = {
                    "row_index": item["row_index"],
                    "doc_id": item["doc_id"],
                    "gold_label": gold_label,
                    "source_lang": item["source_lang"],
                    "prediction": pred,
                    "generated_text": extract_generated_text(pred, gold_label),
                }
                outputs.append(record)
                f_out.write(json.dumps(record, ensure_ascii=False) + "\n")

    pd.DataFrame(outputs).to_parquet(output_parquet, index=False)

    with open(meta_path, "w", encoding="utf-8") as f_meta:
        json.dump(
            {
                "model_path_for_tokenizer": args.model_path,
                "dataset_path": args.dataset_path,
                "base_url": args.base_url,
                "served_model_name": args.served_model_name,
                "batch_size": args.batch_size,
                # Sampling settings are recorded so a run can be reproduced.
                "max_tokens": args.max_tokens,
                "temperature": args.temperature,
                "top_p": args.top_p,
                "num_samples": len(outputs),
                "output_jsonl": output_jsonl,
                "output_parquet": output_parquet,
            },
            f_meta,
            ensure_ascii=False,
            indent=2,
        )

    print("[DONE] vLLM batch inference complete.")
    print(f"[DONE] JSONL: {output_jsonl}")
    print(f"[DONE] Parquet: {output_parquet}")
    print(f"[DONE] Meta: {meta_path}")


if __name__ == "__main__":
    main()