# ---------------------------------------------------------------------------
# NOTE(review): this span was a newline-stripped git patch fragment. Line
# structure has been restored; the original patch metadata is preserved below
# as comments (it is not executable Python and must not reach the interpreter):
#   diff --git a/code/RL_model/verl/verl_train/reward_func/reward_func/evaluate_scores.py
#            b/code/RL_model/verl/verl_train/reward_func/reward_func/evaluate_scores.py
#   new file mode 100644
#   index 0000000000000000000000000000000000000000..2259c2602890dd6729babd89cfe8a2250aa880ac
#   --- /dev/null
#   +++ b/code/RL_model/verl/verl_train/reward_func/reward_func/evaluate_scores.py
#   @@ -0,0 +1,267 @@
# ---------------------------------------------------------------------------
#!/usr/bin/env python3
"""
Standalone evaluation script for computing factuality, hallucination, and
classifier scores on a JSON file.

Expected input JSON: a list of objects, each with:
    - fulltext           : str (source document)
    - fulltext_subclaims : list[str] (subclaims from fulltext, optional)
    - summary_text       : str (summary / reference text)
    - summary_subclaims  : list[str] (subclaims from summary)
    - generated_text     : str (model-generated text to evaluate)
    - label              : str (target literacy level, e.g. "low_health_literacy")

The script reuses the reward functions from reward_new_v6_bn_v4_rmv_src_cov.py:
    - factuality_score   : fraction of summary subclaims supported by generated_text
    - hallucination_score: fraction of gen subclaims NOT supported by fulltext
    - classifier_score   : whether generated_text matches the target literacy level

Requires the same vLLM endpoints as the reward file:
    - Support checker   : VLLM_SUPPORT_CHECK_BN_API_BASE       (default http://localhost:8090/v1)
    - Classifier        : VLLM_CLASSIFIER_BN_API_BASE          (default http://localhost:8040/v1)
    - Subclaim extractor: VLLM_SUBCLAIM_EXTRACTOR_BN_API_BASE  (default http://localhost:8050/v1)

Usage:
    python evaluate_scores.py --input data.json [--output results.json] [--target-level low_health_literacy]
"""

import argparse
import json
import os
import sys
import time
from typing import Any, Dict, List, Optional

# Third-party progress bar; required at runtime for the evaluation loop.
from tqdm import tqdm

# Import scoring utilities from the reward module (same directory).
+from reward_new_v6_bn_v4_rmv_src_cov import ( + _call_support_api, + _compute_classifier_reward, + _extract_subclaims_from_text, + _is_bangla_text, + _nonlinear_grounding, + compute_rewards, +) + + +def evaluate_single( + item: Dict[str, Any], + target_level_override: Optional[str] = None, +) -> Dict[str, Any]: + """ + Evaluate a single item and return detailed scores. + """ + fulltext = item.get("fulltext", "") + summary_text = item.get("summary_text") or item.get("summary", "") + summary_subclaims = item.get("summary_subclaims", []) + generated_text = item.get("generated_text") or item.get("predicted_gen_text", "") + target_level = target_level_override or item.get("label", "") + + result: Dict[str, Any] = { + "doc_id": item.get("doc_id", ""), + "target_level": target_level, + "generated_text_len": len(generated_text.strip()) if generated_text else 0, + "factuality_score": None, + "hallucination_score": None, + "classifier_score": None, + "grounding_score": None, + "factuality_supported": 0, + "total_summary_subclaims": len(summary_subclaims) if summary_subclaims else 0, + "hallucination_supported": 0, + "total_gen_segments": 0, + "skipped": False, + "skip_reason": "", + } + + if not generated_text or len(generated_text.strip()) < 10: + result["skipped"] = True + result["skip_reason"] = "generated_text missing or too short (<10 chars)" + return result + + # -- Factuality & Hallucination via compute_rewards -- + rewards = compute_rewards( + fulltext=fulltext, + generated_text=generated_text, + target_level=target_level, + summary_subclaims=summary_subclaims, + summary_text=summary_text, + ) + + factuality_score = rewards["factuality_score"] + h_score = rewards["hallucination_score"] + + if factuality_score is None: + factuality_score = 0.5 + if h_score is None: + h_score = 0.5 + + grounding_score = _nonlinear_grounding(h_score) + + # -- Classifier -- + input_text = fulltext or "" + class_score = _compute_classifier_reward(target_level, generated_text, input_text) 
+ + result.update({ + "factuality_score": round(factuality_score, 4), + "hallucination_score": round(h_score, 4), + "grounding_score": round(grounding_score, 4), + "classifier_score": round(class_score, 4), + "factuality_supported": rewards.get("factuality_supported", 0), + "total_summary_subclaims": rewards.get("total_summary_subclaims", 0), + "hallucination_supported": rewards.get("hallucination_supported", 0), + "total_gen_segments": rewards.get("total_gen_segments", 0), + }) + + return result + + +def compute_aggregate(results: List[Dict[str, Any]]) -> Dict[str, Any]: + """Compute aggregate statistics over all evaluated items.""" + scored = [r for r in results if not r.get("skipped", False)] + n = len(scored) + total = len(results) + skipped = total - n + + if n == 0: + return { + "total_items": total, + "scored_items": 0, + "skipped_items": skipped, + "avg_factuality_score": None, + "avg_hallucination_score": None, + "avg_grounding_score": None, + "avg_classifier_score": None, + } + + def safe_avg(key): + vals = [r[key] for r in scored if r[key] is not None] + return round(sum(vals) / len(vals), 4) if vals else None + + return { + "total_items": total, + "scored_items": n, + "skipped_items": skipped, + "avg_factuality_score": safe_avg("factuality_score"), + "avg_hallucination_score": safe_avg("hallucination_score"), + "avg_grounding_score": safe_avg("grounding_score"), + "avg_classifier_score": safe_avg("classifier_score"), + } + + +def main(): + parser = argparse.ArgumentParser( + description="Evaluate factuality, hallucination, and classifier scores on a JSON file." + ) + parser.add_argument( + "--input", "-i", required=True, + help="Path to input JSON file (list of objects).", + ) + parser.add_argument( + "--output", "-o", default=None, + help="Path to output JSON file with per-item scores. " + "Defaults to _eval_results.json.", + ) + parser.add_argument( + "--target-level", "-t", default=None, + help="Override target literacy level for all items " + "(e.g. 
low_health_literacy). If not set, uses each item's 'label' field.", + ) + parser.add_argument( + "--support-check-url", default=None, + help="Override VLLM_SUPPORT_CHECK_BN_API_BASE.", + ) + parser.add_argument( + "--classifier-url", default=None, + help="Override VLLM_CLASSIFIER_BN_API_BASE.", + ) + parser.add_argument( + "--subclaim-extractor-url", default=None, + help="Override VLLM_SUBCLAIM_EXTRACTOR_BN_API_BASE.", + ) + args = parser.parse_args() + + if args.support_check_url: + os.environ["VLLM_SUPPORT_CHECK_BN_API_BASE"] = args.support_check_url + if args.classifier_url: + os.environ["VLLM_CLASSIFIER_BN_API_BASE"] = args.classifier_url + if args.subclaim_extractor_url: + os.environ["VLLM_SUBCLAIM_EXTRACTOR_BN_API_BASE"] = args.subclaim_extractor_url + + # Load input + with open(args.input, "r", encoding="utf-8") as f: + data = json.load(f) + + if not isinstance(data, list): + print(f"Error: Expected a JSON list, got {type(data).__name__}.", file=sys.stderr) + sys.exit(1) + + print(f"Loaded {len(data)} items from {args.input}") + print(f" Support check API : {os.getenv('VLLM_SUPPORT_CHECK_BN_API_BASE', 'http://localhost:8090/v1')}") + print(f" Classifier API : {os.getenv('VLLM_CLASSIFIER_BN_API_BASE', 'http://localhost:8040/v1')}") + print(f" Subclaim extractor: {os.getenv('VLLM_SUBCLAIM_EXTRACTOR_BN_API_BASE', 'http://localhost:8050/v1')}") + if args.target_level: + print(f" Target level override: {args.target_level}") + print("-" * 60) + + # Evaluate each item + results = [] + start_time = time.time() + for idx, item in enumerate(tqdm(data, desc="Evaluating")): + r = evaluate_single(item, target_level_override=args.target_level) + r["index"] = idx + results.append(r) + + if (idx + 1) % 10 == 0 or idx == 0: + partial_agg = compute_aggregate(results) + tqdm.write( + f" [{idx+1}/{len(data)}] " + f"fact={partial_agg['avg_factuality_score']} " + f"hallu={partial_agg['avg_hallucination_score']} " + f"cls={partial_agg['avg_classifier_score']}" + ) + + elapsed = 
time.time() - start_time + + # Aggregate + agg = compute_aggregate(results) + + # Output path + if args.output: + out_path = args.output + else: + stem = os.path.splitext(os.path.basename(args.input))[0] + out_dir = os.path.dirname(args.input) or "." + out_path = os.path.join(out_dir, f"{stem}_eval_results.json") + + output = { + "input_file": os.path.abspath(args.input), + "target_level_override": args.target_level, + "elapsed_seconds": round(elapsed, 2), + "aggregate": agg, + "per_item": results, + } + + with open(out_path, "w", encoding="utf-8") as f: + json.dump(output, f, indent=2, ensure_ascii=False) + + # Print summary + print("\n" + "=" * 60) + print("EVALUATION SUMMARY") + print("=" * 60) + print(f" Total items : {agg['total_items']}") + print(f" Scored items : {agg['scored_items']}") + print(f" Skipped items : {agg['skipped_items']}") + print(f" Elapsed time : {round(elapsed, 1)}s") + print("-" * 60) + print(f" Avg Factuality Score : {agg['avg_factuality_score']}") + print(f" Avg Hallucination Score: {agg['avg_hallucination_score']}") + print(f" Avg Grounding Score : {agg['avg_grounding_score']}") + print(f" Avg Classifier Score : {agg['avg_classifier_score']}") + print("-" * 60) + print(f" Results saved to: {out_path}") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/code/RL_model/verl/verl_train/reward_func/reward_func/reward_new_v6_bn_v4.py b/code/RL_model/verl/verl_train/reward_func/reward_func/reward_new_v6_bn_v4.py index 3b6590fe4bf8439b432d04b0796aaacd6ebdc979..a6954101e4cd32be4ae348b469da4365aaad1779 100644 --- a/code/RL_model/verl/verl_train/reward_func/reward_func/reward_new_v6_bn_v4.py +++ b/code/RL_model/verl/verl_train/reward_func/reward_func/reward_new_v6_bn_v4.py @@ -844,10 +844,10 @@ def compute_score(data_source, solution_str, ground_truth, extra_info=None): the generated text; must stay within a level-specific [min, max] range — too little OR too much is penalised. 
""" - W_FACTUALITY = 0.20 - W_HALLU = 0.20 - W_SRC_COV = 0.20 - W_CLASSIFIER = 0.25 + W_FACTUALITY = 0.25 + W_HALLU = 0.15 + W_SRC_COV = 0.25 + W_CLASSIFIER = 0.20 W_LENGTH = 0.15 FAIL = { diff --git a/code/RL_model/verl/verl_train/reward_func/reward_func/reward_new_v6_bn_v4_rmv_src_cov.py b/code/RL_model/verl/verl_train/reward_func/reward_func/reward_new_v6_bn_v4_rmv_src_cov.py new file mode 100644 index 0000000000000000000000000000000000000000..089a8bf00932411a21056236f7522b0300cc9a2f --- /dev/null +++ b/code/RL_model/verl/verl_train/reward_func/reward_func/reward_new_v6_bn_v4_rmv_src_cov.py @@ -0,0 +1,835 @@ +import ast +import os +import re +import json +import argparse +from typing import Any, List, Dict, Optional +import warnings +import requests +warnings.filterwarnings("ignore") + +# vLLM endpoints (see script/vllm.sh: support-check port 8090, classifier port 8040) +# Use localhost when reward runs on same machine; set env vars for remote (e.g. http://HOST:8090/v1). +VLLM_SUPPORT_CHECK_BN_API_BASE = os.getenv( + "VLLM_SUPPORT_CHECK_BN_API_BASE", + "http://localhost:8090/v1", +) +# Both support-check and classifier use Bangla prompts. +SUPPORT_CHECK_PROMPT_LANGUAGE = "bn" + +VLLM_CLASSIFIER_BN_API_BASE = os.getenv( + "VLLM_CLASSIFIER_BN_API_BASE", + "http://localhost:8040/v1", +) + +# Subclaim-extractor vLLM endpoint (Bangla medical text → subclaim list) +VLLM_SUBCLAIM_EXTRACTOR_BN_API_BASE = os.getenv( + "VLLM_SUBCLAIM_EXTRACTOR_BN_API_BASE", + "http://localhost:8050/v1", +) + + +# --------------------------------------------------------------------------- +# Support check via vLLM (Gemma-3 fine-tuned; same prompt as support_check gemma3-finetune) +# --------------------------------------------------------------------------- + +def _build_support_list_user_prompt(context: str, subclaims: List[str]) -> str: + """Build list-of-subclaims support prompt in Bangla (matches support_check model_finetune/gemma3-finetune.py).""" + numbered = "\n".join(f"{idx + 1}. 
{sc}" for idx, sc in enumerate(subclaims)) + return ( + "আপনাকে একটি মেডিকেল কেস বর্ণনা এবং একাধিক সাবক্লেইমের তালিকা দেওয়া হবে। " + "প্রতিটি সাবক্লেইম টেক্সট দ্বারা সমর্থিত কি না তা নির্ধারণ করুন।\n\n" + f"টেক্সট:\n{context}\n\n" + f"সাবক্লেইমগুলোর তালিকা:\n{numbered}\n\n" + "প্রতিটি সাবক্লেইমের জন্য ক্রমানুসারে লেবেল দিন। " + "নির্দিষ্টভাবে একটি JSON array আকারে উত্তর দিন, যেমন:\n" + '["supported", "not_supported", ...]\n' + "অন্য কিছু লিখবেন না।" + ) + + +def _parse_support_label_array(raw_text: str) -> List[str]: + """Parse model output to list of 'supported' | 'not_supported' (matches support_check gemma3-finetune).""" + text = (raw_text or "").strip() + if not text: + return [] + if "```" in text: + text = text.replace("```json", "").replace("```", "").strip() + start = text.find("[") + end = text.rfind("]") + if start != -1 and end != -1 and end > start: + text = text[start : end + 1] + parsed = None + for parser in (json.loads, ast.literal_eval): + try: + parsed = parser(text) + break + except Exception: + continue + if not isinstance(parsed, list): + return [] + # import ipdb; ipdb.set_trace() + normalized = [] + for item in parsed: + if not isinstance(item, str): + normalized.append("invalid") + continue + + label = item.strip().lower().replace("-", "_").replace(" ", "_") + # Strict keyword check: if not one of these two, it is invalid + if label in {"supported", "not_supported"}: + normalized.append(label) + else: + normalized.append("invalid") + return normalized + + +def _format_gemma3_for_support(user_message: str) -> str: + """Format user message for Gemma-3 chat (vLLM).""" + return ( + f"user\n{user_message}\n" + "model\n" + ) + + +def _call_vllm_support_check(prompt: str, max_tokens: int = 512, timeout: float = 120.0) -> Optional[str]: + """Call vLLM completions API for support-check model. 
Returns generated text or None.""" + base = VLLM_SUPPORT_CHECK_BN_API_BASE.rstrip("/") + url = f"{base}/completions" + payload = { + "prompt": prompt, + "max_tokens": max_tokens, + "temperature": 0.0, + "stop": ["", "", "\n\n"], + } + try: + resp = requests.post(url, json=payload, timeout=timeout) + resp.raise_for_status() + data = resp.json() + choices = data.get("choices") + # import ipdb; ipdb.set_trace() + if choices and len(choices) > 0: + text = choices[0].get("text", "") + return (text or "").strip() + return None + except requests.exceptions.RequestException: + return None + + +def _call_support_api( + context: str, + subclaims: List[str], + threshold: float = 0.5, + batch_size: int = 128, +) -> Optional[List[str]]: + """ + Call the BN support-check model via vLLM (Gemma-3 fine-tuned, subclaim_list mode). + Same prompt as support_check/model_finetune/gemma3-finetune.py. + + Returns + ------- + List[str] : one label per subclaim — "supported" | "not_supported" | "invalid". + None : on total failure (network or empty/unparseable response). + """ + if not context or not subclaims: + return ["invalid"] * len(subclaims) + + user_prompt = _build_support_list_user_prompt(context, subclaims) + prompt = _format_gemma3_for_support(user_prompt) + raw = _call_vllm_support_check(prompt) + # import ipdb; ipdb.set_trace() + if raw is None: + print("Warning: Support check vLLM call failed (returning None).") + return None + + labels = _parse_support_label_array(raw) + n = len(subclaims) + if len(labels) < n: + labels = labels + ["invalid"] * (n - len(labels)) + elif len(labels) > n: + labels = labels[:n] + # import ipdb; ipdb.set_trace() + return labels + + +# --------------------------------------------------------------------------- +# Subclaim extractor (Bangla, vLLM) + sentence splitter fallback +# --------------------------------------------------------------------------- + +# Minimum character length for a sentence to be considered a real unit. 
+# Used only as a fallback when subclaim extraction is unavailable. +MIN_SENTENCE_CHARS = 15 + + +def _is_bangla_text(text: str, min_bangla_ratio: float = 0.6) -> bool: + """ + Heuristic check: returns True if the majority of alphabetic characters + in `text` are Bangla (Unicode block \u0980–\u09FF). + """ + if not text: + return False + bangla_chars = 0 + alpha_chars = 0 + for ch in text: + if ch.isalpha(): + alpha_chars += 1 + if "\u0980" <= ch <= "\u09FF": + bangla_chars += 1 + if alpha_chars == 0: + return False + return (bangla_chars / alpha_chars) >= min_bangla_ratio + + +def _split_into_sentences(text: str, min_chars: int = MIN_SENTENCE_CHARS) -> List[str]: + """ + Split text into sentences at [.!?] boundaries. + Segments shorter than `min_chars` characters are dropped to + prevent micro-fragment padding from gaming ratio-based scores. + """ + if not text or not text.strip(): + return [] + parts = re.split(r"(?<=[.!?\u0964])\s+", text.strip()) + return [s.strip() for s in parts if len(s.strip()) >= min_chars] + + +def _build_subclaim_extraction_prompt(medical_text: str) -> str: + """ + Bangla subclaim-extraction prompt (same wording as `extract_bn_subclaims_vllm.py`, + generalized to "medical text" so it works for any generated explanation). + """ + return f""" +You are an expert medical annotator. The following text is in Bangla (Bengali). + +Your task is to extract granular, factual subclaims from the provided medical text. +A subclaim is the smallest standalone factual unit that can be independently verified. + +Instructions: +1. Read the Bangla medical text carefully. +2. Extract factual statements explicitly stated in the text. +3. Each subclaim must: + - Be in Bangla (same language as the input) + - Contain exactly ONE factual assertion + - Come directly from the text (no inference or interpretation) + - Preserve original wording as much as possible + - Include any negation, uncertainty, or qualifier +4. 
Do NOT: + - Combine multiple facts into one subclaim + - Add new information + - Translate to another language +5. Return ONLY a valid JSON array of strings. +6. Use double quotes and valid JSON formatting only (no markdown, no commentary). + +Medical Text (Bangla): +{medical_text} + +Return format: +[ + "subclaim 1", + "subclaim 2" +] +""".strip() + + +def _strip_markdown_json_block(text: str) -> str: + """Strip optional markdown code fence (e.g. ```json\\n[...]\\n```), if present.""" + text = (text or "").strip() + if not text: + return "" + if text.startswith("```json"): + text = text[7:].lstrip("\n") + elif text.startswith("```"): + text = text[3:].lstrip("\n") + if text.endswith("```"): + text = text[:-3].rstrip("\n") + return text.strip() + + +def _parse_subclaim_list_output(output_text: str) -> List[str]: + """Parse subclaim-extractor model output into a list of Bangla subclaims.""" + output_text = (output_text or "").strip() + if not output_text: + return [] + + if "" in output_text: + output_text = output_text.split("")[-1].strip() + + output_text = _strip_markdown_json_block(output_text) + + start_idx = output_text.find("[") + end_idx = output_text.rfind("]") + 1 + if start_idx != -1 and end_idx > start_idx: + content = output_text[start_idx:end_idx] + parsed = json.loads(content) + if isinstance(parsed, list): + return [str(s).strip() for s in parsed if str(s).strip()] + + raise ValueError("Incomplete or invalid JSON list") + + +def _call_vllm_subclaim_extractor( + text: str, + max_tokens: int = 2048, + temperature: float = 0.2, + timeout: float = 120.0, +) -> Optional[List[str]]: + """ + Call Bangla subclaim-extractor model via vLLM (OpenAI /chat/completions). + + Returns a list of subclaims on success, or None on total failure. 
+ """ + if not text or not text.strip(): + return [] + + base = VLLM_SUBCLAIM_EXTRACTOR_BN_API_BASE.rstrip("/") + url = f"{base}/chat/completions" + + prompt = _build_subclaim_extraction_prompt(text) + payload = { + "model": os.getenv("VLLM_SUBCLAIM_EXTRACTOR_MODEL_NAME", "subclaim-extractor"), + "messages": [{"role": "user", "content": prompt}], + "temperature": temperature, + "max_tokens": max_tokens, + } + + try: + resp = requests.post(url, json=payload, timeout=timeout) + resp.raise_for_status() + data = resp.json() + choices = data.get("choices") or [] + if not choices: + return None + content = (choices[0].get("message", {}) or {}).get("content", "") or "" + # import ipdb; ipdb.set_trace() + return _parse_subclaim_list_output(content) + except Exception: + return None + + +def _extract_subclaims_from_text(text: str) -> List[str]: + """ + Extract Bangla subclaims from generated text using the vLLM subclaim-extractor. + + On failure (e.g., server down or parse error), falls back to sentence splitting + so the rest of the reward logic can still operate. + """ + subclaims = _call_vllm_subclaim_extractor(text) + if subclaims is None: + # Fallback: keep system running even if extractor is unavailable. + return _split_into_sentences(text) + return subclaims + + +# --------------------------------------------------------------------------- +# Two reward signals: +# 1. Factuality — summary subclaims vs gen_text (how much summary info is in gen_text) +# 2. Hallucination — gen_segments vs fulltext (how much gen info is NOT in fulltext) +# --------------------------------------------------------------------------- + +def compute_rewards( + fulltext: str, + generated_text: str, + target_level: str, + summary_subclaims: Optional[List[str]] = None, + summary_text: Optional[str] = None, + threshold: float = 0.5, + batch_size: int = 128, +) -> Dict[str, Optional[float]]: + """ + Compute two independent reward signals. + + 1. 
**Factuality** (summary_subclaims → gen_text): + Use pre-extracted *summary_subclaims*, check how many are supported + by the generated text. Measures "how much of the summary's information + made it into the output". + + 2. **Hallucination** (gen_segments → fulltext): + Extract subclaims from the *generated text* (gen_segments), then check + how many are supported by the source fulltext. The *unsupported* + fraction is the hallucination score (lower is better). + + Returns dict with: + factuality_score : [0,1] fraction of summary subclaims supported by gen_text + factuality_supported : int count + total_summary_subclaims : int + hallucination_score : [0,1] fraction of gen_segments NOT supported by fulltext + hallucination_supported : int count of gen_segments supported by fulltext + total_gen_segments : int + """ + result: Dict[str, Any] = { + "factuality_score": None, + "factuality_supported": 0, + "total_summary_subclaims": 0, + "hallucination_score": None, + "hallucination_supported": 0, + "total_gen_segments": 0, + } + + gen_segments = _extract_subclaims_from_text(generated_text) + + if not gen_segments: + result.update({ + "hallucination_score": 0.0, + "factuality_score": 0.0, + }) + return result + + total_gen = len(gen_segments) + result["total_gen_segments"] = total_gen + + # ===================================================================== + # 1. FACTUALITY — summary subclaims checked against gen_text + # "How much information from the summary exists in the generated text?" 
+ # ===================================================================== + factuality_score = None + if summary_subclaims and len(summary_subclaims) > 0: + result["total_summary_subclaims"] = len(summary_subclaims) + + labels_summary_vs_gen = _call_support_api( + context=generated_text, + subclaims=summary_subclaims, + threshold=threshold, + batch_size=batch_size, + ) + if labels_summary_vs_gen is not None: + valid = [l for l in labels_summary_vs_gen if str(l).strip().lower() != "invalid"] + if valid: + sup = sum(1 for l in valid if str(l).strip().lower() == "supported") + factuality_score = sup / len(summary_subclaims) + result["factuality_supported"] = sup + else: + factuality_score = 0.0 + + result["factuality_score"] = factuality_score + + # ===================================================================== + # 2. HALLUCINATION — gen_segments checked against fulltext + # "How much info in gen_segments is NOT supported by the fulltext?" + # ===================================================================== + hallucination_score = None + if fulltext and fulltext.strip(): + labels_gen_vs_full = _call_support_api( + context=fulltext, + subclaims=gen_segments, + threshold=threshold, + batch_size=batch_size, + ) + if labels_gen_vs_full is not None and len(labels_gen_vs_full) > 0: + sup_full = sum( + 1 for l in labels_gen_vs_full + if str(l).strip().lower() == "supported" + ) + + unsupported_indices = [ + i for i, l in enumerate(labels_gen_vs_full) + if str(l).strip().lower() != "supported" + ] + + if unsupported_indices and summary_text and summary_text.strip(): + unsup_segments = [gen_segments[i] for i in unsupported_indices] + rescue_labels = _call_support_api( + context=summary_text, + subclaims=unsup_segments, + threshold=threshold, + batch_size=batch_size, + ) + if rescue_labels: + rescued = sum( + 1 for l in rescue_labels + if str(l).strip().lower() == "supported" + ) + sup_full += rescued + + hallucination_score = max(0.0, (total_gen - sup_full) / 
total_gen) + result["hallucination_supported"] = sup_full + else: + hallucination_score = 0.0 + + result["hallucination_score"] = hallucination_score + + return result + + +# --------------------------------------------------------------------------- +# BN health-literacy classifier via vLLM (Gemma-3 fine-tuned model) +# Uses Bangla prompt; model is assumed running in vLLM. +# --------------------------------------------------------------------------- + + +def build_classification_user_prompt(fulltext: str, gen_text: str) -> str: + """Build the classification user prompt in English (matches gemma3-finetune.py). Full text is reference; generated text is what to classify.""" + return ( + "You will be given a medical case description as reference (full text) and a generated text to classify. " + "Determine the patient's health literacy level based only on the generated text.\n\n" + f"Reference (full text):\n{fulltext}\n\n" + f"Generated text (to classify):\n{gen_text}\n\n" + "Reply with exactly one label from this set:\n" + "low_health_literacy, intermediate_health_literacy, proficient_health_literacy" + ) + + +def format_gemma3_prompt(user_message: str) -> str: + """Format user message for Gemma-3 chat (vLLM expects this).""" + return ( + f"user\n{user_message}\n" + "model\n" + ) + + +def _call_vllm_classifier(prompt: str, max_tokens: int = 64, timeout: float = 60.0) -> Optional[str]: + """ + Call vLLM completions API. Returns generated text or None on failure. 
+ """ + url = f"{VLLM_CLASSIFIER_BN_API_BASE.rstrip('/')}/completions" + payload = { + "prompt": prompt, + "max_tokens": max_tokens, + "temperature": 0.0, + "stop": ["", "", "\n\n",""], + } + try: + resp = requests.post(url, json=payload, timeout=timeout) + resp.raise_for_status() + data = resp.json() + choices = data.get("choices") + # import ipdb; ipdb.set_trace() + if choices and len(choices) > 0: + text = choices[0].get("text", "") + return (text or "").strip() + return None + except requests.exceptions.RequestException as exc: + return None + + +def _parse_classifier_output(raw: str) -> str: + """ + Extract health literacy label from model output. Normalize to + low_health_literacy | intermediate_health_literacy | proficient_health_literacy. + Returns empty string if no valid label found. + """ + if not raw: + return "" + raw = raw.strip().lower() + # Take first line and clean + first_line = raw.split("\n")[0].strip() + for label in ["low_health_literacy", "intermediate_health_literacy", "proficient_health_literacy"]: + if label in first_line or label in raw: + # import ipdb; ipdb.set_trace() + return label + return "" + + +_CLASSIFIER_ERROR_LOGGED = False + + +def _predict_label(input_text: str, generated_text: str) -> str: + """ + Run BN health-literacy classifier via vLLM (Gemma-3 fine-tuned). + Uses fulltext=input_text and gen_text=generated_text; returns normalized label or "". 
+ """ + global _CLASSIFIER_ERROR_LOGGED + try: + user_prompt = build_classification_user_prompt(input_text or "", generated_text or "") + prompt = format_gemma3_prompt(user_prompt) + raw = _call_vllm_classifier(prompt) + # import ipdb; ipdb.set_trace() + if raw is None: + if not _CLASSIFIER_ERROR_LOGGED: + print("Warning: BN classifier vLLM call failed, continuing without it.") + _CLASSIFIER_ERROR_LOGGED = True + return "" + return _parse_classifier_output(raw) + except Exception as exc: + if not _CLASSIFIER_ERROR_LOGGED: + print(f"Warning: BN literacy classifier unavailable, continuing without it: {exc}") + _CLASSIFIER_ERROR_LOGGED = True + return "" + + +def _parse_solution_json(solution_str): + if isinstance(solution_str, (dict, list)): + return solution_str + try: + cleaned_str = str(solution_str).strip() + if "```json" in cleaned_str: + cleaned_str = cleaned_str.split("```json")[1].split("```")[0].strip() + elif "```" in cleaned_str: + cleaned_str = cleaned_str.split("```")[1].split("```")[0].strip() + return json.loads(cleaned_str) + except Exception: + return None + + +def _compute_classifier_reward(target_level: str, gen_text: str, input_text: str) -> float: + """ + Soft classifier score in [0, 1] (NOT binary +1/-1). + + 1.0 — predicted label matches target level (correct style) + 0.0 — predicted label does not match (wrong style) + 0.5 — classifier unavailable; neutral / no signal + + Uses BN classifier via vLLM (Gemma-3); needs input_text (fulltext) and gen_text. 
+ """ + result = _predict_label(input_text, gen_text) + if result == "": # unavailable → neutral, no penalty + return 0.5 + if result.strip().lower() == target_level.strip().lower(): + # import ipdb; ipdb.set_trace() + return 1.0 # correct literacy style + return 0.0 # wrong literacy style (penalty-free cliff avoided) + + +# --------------------------------------------------------------------------- +# Copy-paste penalty (prevent trivial copy of input_text) +# --------------------------------------------------------------------------- + +def _approx_copy_ratio(input_text: str, gen_text: str) -> float: + """ + Rough similarity estimate between input and generated text. + + - Detects near-verbatim copy via substring + length ratio. + - Otherwise uses token overlap (gen tokens that also appear in input). + Returns value in [0, 1], where 1 ≈ almost exact copy. + """ + a = (input_text or "").strip() + b = (gen_text or "").strip() + if not a or not b: + return 0.0 + + len_a, len_b = len(a), len(b) + shorter, longer = (a, b) if len_a <= len_b else (b, a) + + # Near-verbatim copy: one string almost fully contained in the other. + if shorter and shorter in longer: + ratio = len(shorter) / max(1, len(longer)) + if ratio >= 0.9: + return 1.0 + + # Fallback: 3-gram (trigram) token overlap to reduce false positives + # from shared medical vocabulary (drug names, symptoms, etc.). 
+ def _tokens(t: str): + return [tok for tok in re.split(r"\s+", t) if tok] + + def _shingles(tokens, n=3): + if len(tokens) < n: + return set() + return {" ".join(tokens[i : i + n]) for i in range(len(tokens) - n + 1)} + + toks_a = _tokens(a) + toks_b = _tokens(b) + if not toks_a or not toks_b: + return 0.0 + + sh_a = _shingles(toks_a, n=3) + sh_b = _shingles(toks_b, n=3) + if not sh_a or not sh_b: + return 0.0 + + overlap = len(sh_a & sh_b) / max(1, len(sh_b)) + return max(0.0, min(1.0, overlap)) + + +def _compute_copy_penalty(input_text: str, gen_text: str) -> float: + """ + Map copy ratio → penalty in [0, 1]. + + - ≤ 0.7 similarity → no penalty + - 0.7–1.0 → linearly ramp penalty up to 1.0 + """ + ratio = _approx_copy_ratio(input_text, gen_text) + if ratio <= 0.7: + return 0.0 + # Scale [0.7, 1.0] → [0, 1] + return max(0.0, min(1.0, (ratio - 0.7) / 0.3)) + + +# --------------------------------------------------------------------------- +# Main scoring function +# --------------------------------------------------------------------------- +def _nonlinear_grounding(h_score: float) -> float: + """ + Sharper penalty for hallucination. + + h_score=0.00 → 1.00 (perfect) + h_score=0.05 → 0.95 (mild) + h_score=0.10 → 0.82 (noticeable) + h_score=0.17 → 0.65 (significant — was 0.83 before!) + h_score=0.30 → 0.36 (harsh) + h_score=0.50 → 0.13 (near zero) + """ + return max(0.0, (1.0 - h_score) ** 2.5) +def compute_score(data_source, solution_str, ground_truth, extra_info=None): + """ + Reward = weighted sum of three components (all in [0, 1]): + + W_FACTUALITY × factuality_score (summary info present in gen_text) + W_HALLU × (1 - hallucination_score) (gen_segments grounded in fulltext) + W_CLASSIFIER × classifier_score (style match) + + 1. Factuality : extract subclaims from *summary*, check how many are + supported by the generated text. + 2. Hallucination: extract subclaims from *generated text*, check how many + are NOT supported by the fulltext. 
+ """ + W_FACTUALITY = 0.40 + W_HALLU = 0.25 + W_CLASSIFIER = 0.35 + + FAIL = { + "score": -1.0, + "factuality_score": 0.0, + "hallucination_score": 0.0, + "classifier_score": 0.0, + "factuality_supported": 0, + "hallucination_supported": 0, + "total_gen_segments": 0, + } + + # 1. Parse & validate + data = _parse_solution_json(solution_str) + if not data: + return FAIL + + target_level = extra_info.get("target_level") if extra_info else None + gen_text = data.get(target_level, "") if target_level else "" + + if not gen_text or len(gen_text.strip()) < 10: + return FAIL + + if not _is_bangla_text(gen_text): + return FAIL + + fulltext = ground_truth.get("fulltext") or ground_truth.get("input_text", "") + input_text = ground_truth.get("input_text", "") + summary_subclaims = ground_truth.get("summary_subclaims") + summary_text = ground_truth.get("summary_text", "") + + # 2. Compute the two core rewards + rewards = compute_rewards( + fulltext=fulltext, + generated_text=gen_text, + target_level=target_level, + summary_subclaims=summary_subclaims, + summary_text=summary_text, + ) + + factuality_score = rewards["factuality_score"] + h_score = rewards["hallucination_score"] + total_gen_units = rewards.get("total_gen_segments", 0) + + if factuality_score is None: + factuality_score = 0.5 + if h_score is None: + h_score = 0.5 + + grounding_score = _nonlinear_grounding(h_score) + + # 3. Classifier (style match) + class_score = _compute_classifier_reward(target_level, gen_text, input_text) + + # 4. Final weighted sum + final_reward = ( + W_FACTUALITY * factuality_score + + W_HALLU * grounding_score + + W_CLASSIFIER * class_score + ) + + # 5. 
Copy-paste penalty + copy_penalty = _compute_copy_penalty(input_text, gen_text) + if copy_penalty > 0.0: + final_reward = max(0.0, final_reward * (1.0 - copy_penalty)) + + return { + "score": float(final_reward), + "factuality_score": float(factuality_score), + "hallucination_score": float(h_score), + "classifier_score": float(class_score), + "factuality_supported": int(rewards.get("factuality_supported", 0)), + "hallucination_supported": int(rewards.get("hallucination_supported", 0)), + "total_gen_segments": int(total_gen_units), + } + + +# --------------------------------------------------------------------------- +# Test mode +# --------------------------------------------------------------------------- + +test_mode = True +if test_mode: + import time + + def run_actual_api_test(): + # Bangla medical example (support-check and classifier use Bangla prompts) + ground_truth = { + "summary_subclaims": [ + "লিসিনোপ্রিল উচ্চ রক্তচাপের চিকিৎসায় ব্যবহৃত হয়।", + "এটি ACE ইনহিবিটর যা আপনার হৃদযন্ত্রকে ভালো কাজ করতে সাহায্য করে।", + "সাধারণ পার্শ্বপ্রতিক্রিয়ার মধ্যে শুষ্ক কাশি রয়েছে।", + "গর্ভবতী হলে ব্যবহার করবেন না।", + ], + "input_text": ( + "লিসিনোপ্রিল উচ্চ রক্তচাপের চিকিৎসায় ব্যবহৃত হয়। " + "এটি ACE ইনহিবিটর নামক ওষুধ। " + "এটি আপনার হৃদযন্ত্রকে ভালো কাজ করতে সাহায্য করে।" + ), + "summary_text": ( + "লিসিনোপ্রিল উচ্চ রক্তচাপের চিকিৎসায় ব্যবহৃত হয়। " + "এটি ACE ইনহিবিটর যা আপনার হৃদযন্ত্রকে ভালো কাজ করতে সাহায্য করে। " + "সাধারণ পার্শ্বপ্রতিক্রিয়ার মধ্যে শুষ্ক কাশি রয়েছে। " + "গর্ভবতী হলে ব্যবহার করবেন না।" + ), + } + + # LLM output: low_health_literacy style, grounded in summary + generated_response = { + "low_health_literacy": ( + "এই ওষুধ আপনার উচ্চ রক্তচাপের জন্য। " + "এটি ACE ইনহিবিটর ধরনের ওষুধ। " + "এটি আপনার হৃদযন্ত্রকে ভালো কাজ করতে সাহায্য করে। " + "গর্ভবতী হলে এই ওষুধ খাবেন না।" + ) + } + + solution_str = f"```json\n{json.dumps(generated_response)}\n```" + extra_info = {"target_level": "low_health_literacy"} + + print("📡 Running BN reward test 
(Bangla example)...") + start_time = time.time() + + try: + score = compute_score( + data_source="real_api_test", + solution_str=solution_str, + ground_truth=ground_truth, + extra_info=extra_info, + ) + + final_score = score["score"] if isinstance(score, dict) else score + + duration = time.time() - start_time + print(f"\nAPI Call Successful ({round(duration, 2)}s)") + print("-" * 50) + print(f"Target Level : {extra_info['target_level']}") + print(f"Final Reward : {round(final_score, 4)}") + print(f"factuality_score : {round(score.get('factuality_score', 0), 4)} (summary subclaims in gen_text)") + print(f"hallucination_score : {round(score.get('hallucination_score', 0), 4)} (gen_segments NOT in fulltext)") + print(f"classifier_score : {round(score.get('classifier_score', 0), 4)}") + print(f"factuality_supported : {score.get('factuality_supported', 0)}") + print(f"hallucination_supported: {score.get('hallucination_supported', 0)}") + print(f"total_gen_segments : {score.get('total_gen_segments', 0)}") + print("-" * 50) + print("\nReward definitions:") + print("- factuality_score : fraction of *summary* subclaims supported by gen_text [0,1]") + print("- hallucination_score : fraction of *gen_segments* NOT supported by fulltext [0,1] (lower=better)") + print("- classifier_score : 1.0 match, 0.0 mismatch, 0.5 unavailable") + print("- Weights: factuality=0.40, grounding=0.25, classifier=0.35") + + except Exception as e: + print(f"\n❌ API Call Failed!") + print(f"Error Type: {type(e).__name__}") + print(f"Details: {str(e)}") + print("\nPossible fixes:") + print("1. Check if the support-check vLLM server is running (VLLM_SUPPORT_CHECK_BN_API_BASE).") + print("2. 
Check if the classifier vLLM server is running (VLLM_CLASSIFIER_BN_API_BASE).") + + if __name__ == "__main__": + run_actual_api_test() \ No newline at end of file diff --git a/code/RL_model/verl/verl_train/script/run_qwen3-8b_v3(2G)_v2.sh b/code/RL_model/verl/verl_train/script/run_qwen3-8b_v3(2G)_v2.sh index b2279f55f8514e1f2c9c5229ee6d7a40b7fb984f..c7ecbb0bcd33df9872c8ca46e6676d7f159cd8a4 100644 --- a/code/RL_model/verl/verl_train/script/run_qwen3-8b_v3(2G)_v2.sh +++ b/code/RL_model/verl/verl_train/script/run_qwen3-8b_v3(2G)_v2.sh @@ -10,7 +10,7 @@ PYTHONUNBUFFERED=1 python3 -m verl.trainer.main_ppo \ algorithm.adv_estimator=grpo \ data.train_files=/home/mshahidul/readctrl/code/RL_model/verl/verl_train/dataset/bn_dataset/train.parquet \ data.val_files=/home/mshahidul/readctrl/code/RL_model/verl/verl_train/dataset/bn_dataset/test.parquet \ - custom_reward_function.path=/home/mshahidul/readctrl/code/RL_model/verl/verl_train/reward_func/reward_func/reward_new_v6_bn_v4.py \ + custom_reward_function.path="/home/mshahidul/readctrl/code/RL_model/verl/verl_train/reward_func/reward_func/reward_new_v6_bn_v4_rmv_src_cov.py" \ data.train_batch_size=256 \ data.max_prompt_length=6144 \ data.max_response_length=2048 \ diff --git a/code/RL_model/verl/verl_train/script/run_qwen3-8b_v3(2G)_v3.sh b/code/RL_model/verl/verl_train/script/run_qwen3-8b_v3(2G)_v3.sh new file mode 100644 index 0000000000000000000000000000000000000000..5e396b910396d49325af77fed0380e0cb8f99f53 --- /dev/null +++ b/code/RL_model/verl/verl_train/script/run_qwen3-8b_v3(2G)_v3.sh @@ -0,0 +1,56 @@ +set -x + +unset PYTORCH_CUDA_ALLOC_CONF +export EXPERIMENT_NAME=qwen3-4b-instruct-bn +export WAND_PROJECT='readctrl-verl' +export CUDA_DEVICE_ORDER="PCI_BUS_ID" +export CUDA_VISIBLE_DEVICES=1,2 + +PYTHONUNBUFFERED=1 python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=/home/mshahidul/readctrl/code/RL_model/verl/verl_train/dataset/bn_dataset/train.parquet \ + 
data.val_files=/home/mshahidul/readctrl/code/RL_model/verl/verl_train/dataset/bn_dataset/test.parquet \ + custom_reward_function.path="/home/mshahidul/readctrl/code/RL_model/verl/verl_train/reward_func/reward_func/reward_new_v6_bn_v4.py" \ + data.train_batch_size=256 \ + data.max_prompt_length=6144 \ + data.max_response_length=2048 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + actor_rollout_ref.model.path=Qwen/Qwen3-4B-Instruct-2507 \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=256 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.35 \ + actor_rollout_ref.rollout.enforce_eager=True \ + actor_rollout_ref.rollout.max_model_len=8192 \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name=$WAND_PROJECT \ + trainer.experiment_name=$EXPERIMENT_NAME \ + trainer.n_gpus_per_node=2 \ + trainer.nnodes=1 \ + trainer.save_freq=20 \ + trainer.test_freq=10 \ + trainer.log_val_generations=1 \ + +trainer.remove_previous_ckpt_in_save=true \ + trainer.max_actor_ckpt_to_keep=1 \ + trainer.max_critic_ckpt_to_keep=1 \ + trainer.resume_mode=auto \ + 
trainer.default_local_dir=/home/mshahidul/readctrl/code/RL_model/models/reward_new_v6_bn_v4_test2 \ + trainer.total_epochs=45 $@ \ + 2>&1 | tee $EXPERIMENT_NAME.log \ No newline at end of file diff --git a/code/RL_model/verl/verl_train/script/run_qwen3-8b_v3(4G)_v2.sh b/code/RL_model/verl/verl_train/script/run_qwen3-8b_v3(4G)_v2.sh index 275cf70da4ca87d03b3a64e9f2fc4ae03e851917..0726500bb289eee6ad95604fa764d6d8ef4c5fcf 100644 --- a/code/RL_model/verl/verl_train/script/run_qwen3-8b_v3(4G)_v2.sh +++ b/code/RL_model/verl/verl_train/script/run_qwen3-8b_v3(4G)_v2.sh @@ -54,6 +54,6 @@ PYTHONUNBUFFERED=1 python3 -m verl.trainer.main_ppo \ trainer.max_actor_ckpt_to_keep=1 \ trainer.max_critic_ckpt_to_keep=1 \ trainer.resume_mode=auto \ - trainer.default_local_dir=/home/mshahidul/readctrl/code/RL_model/models/bn_wo_summary \ + trainer.default_local_dir=/home/mshahidul/readctrl/code/RL_model/models/readCtrl_RL_bn_srcCov_v1 \ trainer.total_epochs=45 $@ \ 2>&1 | tee $EXPERIMENT_NAME.log \ No newline at end of file diff --git a/code/fine_tune_sft_dpo/best_of_n_qwen3_vllm_bn.py b/code/fine_tune_sft_dpo/best_of_n_qwen3_vllm_bn.py new file mode 100644 index 0000000000000000000000000000000000000000..5352d7a7d5be8d747e12457ae420d94f46250140 --- /dev/null +++ b/code/fine_tune_sft_dpo/best_of_n_qwen3_vllm_bn.py @@ -0,0 +1,518 @@ +import os +os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" +os.environ["CUDA_VISIBLE_DEVICES"] = "6" + +import argparse +import json +import re +from datetime import datetime +from typing import Any, Dict, List, Tuple + +from vllm import LLM, SamplingParams +from transformers import AutoTokenizer + + +def strip_think_blocks(text: str) -> str: + """Remove ... 
reasoning blocks from model output.""" + cleaned = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL).strip() + return cleaned if cleaned else text + + +BASE_DIR = "/home/mshahidul/readctrl/code/fine_tune_sft_dpo" +FINETUNED_MODEL_DIR = os.path.join(BASE_DIR, "model", "bn") +PROMPT_DIR = os.path.join(BASE_DIR, "prompt_bn") +TEST_JSON = os.path.join(BASE_DIR, "dataset", "bn", "test_bn.json") +RESULTS_DIR = os.path.join(BASE_DIR, "results", "bn") + +SOURCE_LANG = "Bengali" + +LABEL_TO_PROMPT_FILE = { + "low_health_literacy": "prompt_low", + "intermediate_health_literacy": "prompt_intermediate", + "proficient_health_literacy": "prompt_proficient", +} + +LABEL_TO_READABILITY = { + "low_health_literacy": ( + "Low Health Literacy (High Readability): individuals needing the simplest " + "terms for immediate action, using 'living room' language, one idea per " + "sentence, and focusing only on need-to-know information from the Gold Summary." + ), + "intermediate_health_literacy": ( + "Intermediate Health Literacy (Medium Readability): the general public at a " + "news-reading level, with standard vocabulary and some common medical terms, " + "and a balanced level of detail led by the Gold Summary." + ), + "proficient_health_literacy": ( + "Proficient Health Literacy (Low Readability): researchers, clinicians, or " + "highly informed patients, using technical and academic language, high " + "information density, and full clinical nuance and terminology from the " + "Source Text." 
+ ), +} + + +def load_prompts(prompt_dir: str) -> Dict[str, str]: + prompts: Dict[str, str] = {} + for label, filename in LABEL_TO_PROMPT_FILE.items(): + path = os.path.join(prompt_dir, filename) + if os.path.isfile(path): + with open(path, "r", encoding="utf-8") as f: + prompts[label] = f.read() + else: + raise FileNotFoundError(f"Prompt file not found: {path}") + return prompts + + +def build_generation_user_message( + prompt_template: str, + full_text: str, + gold_summary: str, + source_lang: str = SOURCE_LANG, +) -> str: + return ( + prompt_template.replace("{full_text}", full_text) + .replace("{gold_summary}", gold_summary) + .replace("{source_lang}", source_lang) + ) + + +def build_selection_user_message( + full_text: str, + label: str, + candidates: List[str], + source_lang: str = SOURCE_LANG, +) -> str: + readability = LABEL_TO_READABILITY.get(label, label) + numbered = [] + for i, cand in enumerate(candidates, start=1): + numbered.append(f"[{i}]\n{cand.strip()}") + candidates_block = "\n\n".join(numbered) + + return ( + "You are selecting the best patient-friendly summary of a medical case.\n\n" + f"Original text ({source_lang}):\n{full_text}\n\n" + f"Readability requirement: {readability}.\n\n" + f"Here are {len(candidates)} candidate summaries:\n\n" + f"{candidates_block}\n\n" + "Choose the single candidate that best matches the readability " + "requirement and accurately reflects the key clinical information.\n" + "Answer with exactly one line in the form:\n" + '"BEST_INDEX: k"\n' + f"where k is an integer from 1 to {len(candidates)}." + ) + + +def parse_best_index(text: str, num_candidates: int) -> int: + # Look for an integer in the model output; default to 1 if parsing fails. 
+ match = re.search(r"(\d+)", text) + if not match: + return 1 + idx = int(match.group(1)) + if idx < 1 or idx > num_candidates: + return 1 + return idx + + +def build_generation_prompts_for_model( + tokenizer, + test_list: List[Dict[str, Any]], + prompts: Dict[str, str], + source_lang: str = SOURCE_LANG, +) -> Tuple[List[str], List[Dict[str, Any]]]: + batched_prompts: List[str] = [] + meta: List[Dict[str, Any]] = [] + + for idx, item in enumerate(test_list): + label = item.get("label") + doc_id = item.get("doc_id", idx) + fulltext = item.get("fulltext", "") + summary = item.get("summary", "") + gold_gen_text = item.get("gen_text", "") + + if label not in prompts: + meta.append( + { + "doc_id": doc_id, + "label": label, + "gold_gen_text": gold_gen_text, + "fulltext": fulltext, + "summary": summary, + "error": f"Unknown label: {label}", + } + ) + batched_prompts.append(None) # type: ignore[arg-type] + continue + + user_prompt = build_generation_user_message( + prompts[label], + fulltext, + summary, + source_lang=source_lang, + ) + chat = [{"role": "user", "content": user_prompt}] + formatted = tokenizer.apply_chat_template( + chat, + tokenize=False, + add_generation_prompt=True, + enable_thinking=False, + ) + + batched_prompts.append(formatted) + meta.append( + { + "doc_id": doc_id, + "label": label, + "gold_gen_text": gold_gen_text, + "fulltext": fulltext, + "summary": summary, + "error": None, + } + ) + + return batched_prompts, meta + + +def run_best_of_n_for_model( + model_id: str, + model_key: str, + test_list: List[Dict[str, Any]], + prompts: Dict[str, str], + max_new_tokens: int, + temperature: float, + num_candidates: int, + batch_size: int, + source_lang: str = SOURCE_LANG, +) -> Dict[int, Dict[str, Any]]: + print(f"\n=== Running model {model_key}: {model_id} ===") + + print("Loading tokenizer...") + tokenizer = AutoTokenizer.from_pretrained(model_id) + + print("Preparing prompts...") + batched_prompts, meta = build_generation_prompts_for_model( + 
tokenizer, test_list, prompts, source_lang=source_lang + ) + + print("Loading vLLM model...") + llm = LLM( + model=model_id, + trust_remote_code=True, + ) + + gen_sampling_params = SamplingParams( + temperature=temperature, + max_tokens=max_new_tokens, + n=num_candidates, + ) + + # Filter out None prompts (unknown labels) for generation + valid_indices = [i for i, p in enumerate(batched_prompts) if p is not None] + valid_prompts = [batched_prompts[i] for i in valid_indices] + + total_valid = len(valid_prompts) + batch_size = max(1, batch_size) + print( + f"Running vLLM generation on {total_valid} samples " + f"in batches of {batch_size} with Best-of-{num_candidates}..." + ) + + candidates_per_idx: Dict[int, List[str]] = {} + + num_batches = (total_valid + batch_size - 1) // batch_size + for batch_idx in range(num_batches): + start = batch_idx * batch_size + end = min(start + batch_size, total_valid) + batch_prompts = valid_prompts[start:end] + batch_indices = valid_indices[start:end] + + print( + f"Generating batch {batch_idx + 1}/{num_batches} " + f"with {len(batch_prompts)} samples..." + ) + outputs = llm.generate(batch_prompts, sampling_params=gen_sampling_params) + + for idx_in_batch, output in enumerate(outputs): + original_idx = batch_indices[idx_in_batch] + cand_texts = [strip_think_blocks(o.text) for o in output.outputs] + candidates_per_idx[original_idx] = cand_texts + + # Now build selection prompts to choose the best candidate for each valid sample. 
+ print("Building selection prompts for Best-of-N choice...") + selection_prompts: List[str] = [] + selection_indices: List[int] = [] + reverse_map: Dict[int, int] = {} + + for original_idx in valid_indices: + info = meta[original_idx] + if info["error"] is not None: + continue + cands = candidates_per_idx.get(original_idx, []) + if not cands: + continue + sel_user = build_selection_user_message( + info["fulltext"], + info["label"], + cands, + source_lang=source_lang, + ) + chat = [{"role": "user", "content": sel_user}] + formatted = tokenizer.apply_chat_template( + chat, + tokenize=False, + add_generation_prompt=True, + enable_thinking=False, + ) + reverse_map[len(selection_prompts)] = original_idx + selection_prompts.append(formatted) + + select_sampling_params = SamplingParams( + temperature=0.0, + max_tokens=32, + n=1, + ) + + best_index_per_idx: Dict[int, int] = {} + + total_select = len(selection_prompts) + if total_select > 0: + print( + f"Running selection passes on {total_select} samples " + f"in batches of {batch_size}..." + ) + num_sel_batches = (total_select + batch_size - 1) // batch_size + for batch_idx in range(num_sel_batches): + start = batch_idx * batch_size + end = min(start + batch_size, total_select) + batch_prompts = selection_prompts[start:end] + + print( + f"Selecting batch {batch_idx + 1}/{num_sel_batches} " + f"with {len(batch_prompts)} samples..." + ) + outputs = llm.generate( + batch_prompts, sampling_params=select_sampling_params + ) + + for idx_in_batch, output in enumerate(outputs): + global_sel_idx = start + idx_in_batch + original_idx = reverse_map[global_sel_idx] + raw_text = strip_think_blocks(output.outputs[0].text) + best_idx = parse_best_index(raw_text, num_candidates) + best_index_per_idx[original_idx] = best_idx + + # Build structured results per original index. 
+ model_results: Dict[int, Dict[str, Any]] = {} + for idx, info in enumerate(meta): + if info["error"] is not None: + model_results[idx] = { + "error": info["error"], + } + continue + + cands = candidates_per_idx.get(idx, []) + best_idx = best_index_per_idx.get(idx, 1 if cands else None) + best_summary = ( + cands[best_idx - 1] if cands and best_idx is not None and 1 <= best_idx <= len(cands) else "" + ) + + model_results[idx] = { + "candidates": cands, + "best_index": best_idx, + "best_summary": best_summary, + } + + return model_results + + +def parse_args(): + p = argparse.ArgumentParser( + description=( + "Run vLLM inference with Best-of-N for both the finetuned " + "Qwen3 model and the base Qwen/Qwen3-4B-Instruct-2507 model " + "on test_bn.json (Bengali)." + ) + ) + p.add_argument( + "--prompt-dir", + type=str, + default=PROMPT_DIR, + help="Directory containing prompt files (prompt_low, prompt_intermediate, prompt_proficient).", + ) + p.add_argument( + "--finetuned-model-dir", + type=str, + default=FINETUNED_MODEL_DIR, + help="Path to the merged finetuned model directory.", + ) + p.add_argument( + "--test-data", + type=str, + default=TEST_JSON, + help="Path to the test data JSON file.", + ) + p.add_argument( + "--src-lang", + type=str, + default=SOURCE_LANG, + help="Source language of the text (e.g. 
Bengali, English).", + ) + p.add_argument( + "--base-model-id", + type=str, + default="Qwen/Qwen3-4B-Instruct-2507", + help="Hugging Face model id for the base Qwen3 instruct model.", + ) + p.add_argument( + "--max-new-tokens", + type=int, + default=512, + help="Maximum number of new tokens to generate per candidate.", + ) + p.add_argument( + "--temperature", + type=float, + default=0.7, + help="Sampling temperature for candidate generation.", + ) + p.add_argument( + "--num-candidates", + type=int, + default=5, + help="Number of candidate summaries to generate per example (N in Best-of-N).", + ) + p.add_argument( + "--batch-size", + type=int, + default=16, + help="Batch size for vLLM generation.", + ) + p.add_argument( + "--output-file", + type=str, + default=None, + help=( + "Optional path for the main results JSON file. " + "If not set, a timestamped name in the results directory is used." + ), + ) + p.add_argument( + "--model", + type=str, + choices=["base", "finetuned", "both"], + default="both", + help=( + "Which model(s) to run: 'base' (Qwen3-4B-Instruct), " + "'finetuned' (local SFT model), or 'both' (default)." + ), + ) + return p.parse_args() + + +def main(): + args = parse_args() + + os.makedirs(RESULTS_DIR, exist_ok=True) + + print("Loading prompts from", args.prompt_dir) + prompts = load_prompts(args.prompt_dir) + + print("Loading test data from", args.test_data) + with open(args.test_data, "r", encoding="utf-8") as f: + test_list = json.load(f) + + # Run Best-of-N for the selected model(s), one at a time to save GPU memory. 
+ finetuned_results: Dict[int, Dict[str, Any]] = {} + base_results: Dict[int, Dict[str, Any]] = {} + + if args.model in ("finetuned", "both"): + finetuned_results = run_best_of_n_for_model( + model_id=args.finetuned_model_dir, + model_key="qwen3_finetuned", + test_list=test_list, + prompts=prompts, + max_new_tokens=args.max_new_tokens, + temperature=args.temperature, + num_candidates=args.num_candidates, + batch_size=args.batch_size, + source_lang=args.src_lang, + ) + + if args.model in ("base", "both"): + base_results = run_best_of_n_for_model( + model_id=args.base_model_id, + model_key="qwen3_base", + test_list=test_list, + prompts=prompts, + max_new_tokens=args.max_new_tokens, + temperature=args.temperature, + num_candidates=args.num_candidates, + batch_size=args.batch_size, + source_lang=args.src_lang, + ) + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + if args.output_file: + out_path = args.output_file + base, ext = os.path.splitext(out_path) + if not ext: + out_path = base + ".json" + base = out_path.rsplit(".", 1)[0] + summary_path = base + "_summary.json" + else: + out_path = os.path.join(RESULTS_DIR, f"test_best_of_n_vllm_{timestamp}.json") + summary_path = os.path.join( + RESULTS_DIR, f"inference_best_of_n_vllm_{timestamp}.json" + ) + + combined_results = [] + for idx, item in enumerate(test_list): + label = item.get("label") + doc_id = item.get("doc_id", idx) + gold_gen_text = item.get("gen_text", "") + + entry: Dict[str, Any] = { + "doc_id": doc_id, + "label": label, + "gold_gen_text": gold_gen_text, + "predicted_label": item.get("predicted_label", ""), + "prediction_correct": item.get("prediction_correct", None), + } + + if args.model in ("finetuned", "both"): + entry["qwen3_finetuned"] = finetuned_results.get(idx, {}) + if args.model in ("base", "both"): + entry["qwen3_base"] = base_results.get(idx, {}) + + combined_results.append(entry) + + with open(out_path, "w", encoding="utf-8") as f: + json.dump(combined_results, f, 
ensure_ascii=False, indent=2) + + summary_data: Dict[str, Any] = { + "model_run": args.model, + "test_json": args.test_data, + "prompt_dir": args.prompt_dir, + "src_lang": args.src_lang, + "num_test_samples": len(test_list), + "results_file": out_path, + "timestamp": timestamp, + "max_new_tokens": args.max_new_tokens, + "temperature": args.temperature, + "num_candidates": args.num_candidates, + } + if args.model in ("finetuned", "both"): + summary_data["finetuned_model_dir"] = args.finetuned_model_dir + if args.model in ("base", "both"): + summary_data["base_model_id"] = args.base_model_id + + with open(summary_path, "w", encoding="utf-8") as f: + json.dump(summary_data, f, ensure_ascii=False, indent=2) + + print(f"\nResults saved to {out_path}") + print(f"Summary saved to {summary_path}") + + +if __name__ == "__main__": + main() + diff --git a/code/fine_tune_sft_dpo/dataset/bn/old/test_bn_subclaims.json b/code/fine_tune_sft_dpo/dataset/bn/old/test_bn_subclaims.json new file mode 100644 index 0000000000000000000000000000000000000000..d29ae811a35a7b33d8f9e9e753bcb42387758e47 --- /dev/null +++ b/code/fine_tune_sft_dpo/dataset/bn/old/test_bn_subclaims.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f99326a350c42c5012c10e6898e892ac2962230f627f91a1da814fe8a8f79bba +size 2249546 diff --git a/code/fine_tune_sft_dpo/dataset/bn/test_bn_subclaims.json b/code/fine_tune_sft_dpo/dataset/bn/test_bn_subclaims.json new file mode 100644 index 0000000000000000000000000000000000000000..803daab1e332a69a788c29bba3d7a714a7f93566 --- /dev/null +++ b/code/fine_tune_sft_dpo/dataset/bn/test_bn_subclaims.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6f58aca4d71f46ded80cabe51e3f6b96c0774eae5d4680e46eec0cadefe121e9 +size 5741053 diff --git a/code/fine_tune_sft_dpo/eval.sh b/code/fine_tune_sft_dpo/eval.sh index 7d3bb8d8a6b3cdb71b3d8f2b028cd5e8cbf36eee..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 100644 --- a/code/fine_tune_sft_dpo/eval.sh 
+++ b/code/fine_tune_sft_dpo/eval.sh @@ -1,2 +0,0 @@ -python /home/mshahidul/readctrl/code/fine_tune_sft_dpo/test_classifier_with_subclaim_thresholds.py \ ---input-file /home/mshahidul/readctrl/code/fine_tune_sft_dpo/results/en/test_best_of_n_qwen3-4B_sft.json \ No newline at end of file diff --git a/code/fine_tune_sft_dpo/evaluate_scores_bn.py b/code/fine_tune_sft_dpo/evaluate_scores_bn.py new file mode 100644 index 0000000000000000000000000000000000000000..c1e29bffd93449027cedcec11e510bb33fa92577 --- /dev/null +++ b/code/fine_tune_sft_dpo/evaluate_scores_bn.py @@ -0,0 +1,554 @@ +#!/usr/bin/env python3 +""" +Standalone evaluation script for computing factuality, hallucination, and +classifier scores on a JSON file. + +Supports four input formats: + +1. **Standard format** — a list of objects, each with: + - fulltext, summary_text, summary_subclaims, generated_text, label + +2. **Best-of-N (BON) format** — a list of objects, each with: + - doc_id, label, qwen3_base.best_summary (JSON-wrapped generated text) + Requires a separate --subclaims file to supply fulltext, summary, + summary_subclaims, and fulltext_subclaims (keyed by doc_id). + +3. **Inference format** — a list of objects, each with: + - doc_id, label, predicted_gen_text (JSON-wrapped evaluated summary), + optionally gold_gen_text + predicted_gen_text is the summary to evaluate (same JSON key-by-label + format as best_summary). Requires --subclaims for fulltext and subclaims. + +4. **Self-refine format** — a list of objects, each with: + - doc_id, label, final_summary (the generated text to evaluate), + optionally gold_gen_text, gold_summary + final_summary is the summary to evaluate (plain text or JSON-wrapped by + label). Requires --subclaims for fulltext and subclaims. 
+ +The script reuses the reward functions from reward_new_v6_bn_v4_rmv_src_cov.py: + - factuality_score : fraction of summary subclaims supported by generated_text + - hallucination_score: fraction of gen subclaims NOT supported by fulltext + - classifier_score : whether generated_text matches the target literacy level + +Requires the same vLLM endpoints as the reward file: + - Support checker : VLLM_SUPPORT_CHECK_BN_API_BASE (default http://localhost:8090/v1) + - Classifier : VLLM_CLASSIFIER_BN_API_BASE (default http://localhost:8040/v1) + - Subclaim extractor: VLLM_SUBCLAIM_EXTRACTOR_BN_API_BASE (default http://localhost:8050/v1) + +Usage: + # Standard format + python evaluate_scores.py --input data.json [--output results.json] + + # BON format with subclaims file + python evaluate_scores.py --input bon_results.json --subclaims subclaims.json --output-dir evaluation/bn/ + + # Inference format (predicted_gen_text as evaluated summary) + python evaluate_scores.py --input test_inference_vllm_qwen3-4B_base.json --subclaims subclaims.json --output results.json + + # Self-refine format (final_summary as evaluated summary) + python evaluate_scores.py --input test_self_refine_vllm_qwen3_4B_base.json --subclaims subclaims.json --output-dir evaluation/bn/ +""" + +import argparse +import json +import os +import re +import sys +import time +from typing import Any, Dict, List, Optional + +from tqdm import tqdm + +# Import scoring utilities from the reward module (same directory). +from reward_new_v6_bn_v4_rmv_src_cov import ( + _call_support_api, + _compute_classifier_reward, + _extract_subclaims_from_text, + _is_bangla_text, + _nonlinear_grounding, + compute_rewards, +) + + +def extract_text_from_best_summary(best_summary: str, label: str) -> str: + """Extract the raw generated text from a BON best_summary string. 
+ + The best_summary is a (possibly truncated) JSON string like: + '{"proficient_health_literacy": "...actual text..."}' + We locate the value after the label key and strip JSON wrapping. + """ + key_pattern = re.compile(re.escape(f'"{label}"') + r'\s*:\s*"') + m = key_pattern.search(best_summary) + if not m: + return best_summary.strip() + text = best_summary[m.end():] + if text.endswith('"\n}'): + text = text[:-3] + elif text.endswith('"}\n'): + text = text[:-3] + elif text.endswith('"}'): + text = text[:-2] + elif text.endswith('"'): + text = text[:-1] + text = text.replace("\\n", "\n").replace('\\"', '"') + return text.strip() + + +def prepare_bon_items( + bon_data: List[Dict[str, Any]], + subclaims_data: List[Dict[str, Any]], + model_key: str = "qwen3_base", +) -> List[Dict[str, Any]]: + """Merge BON results with subclaims data into the standard evaluation format.""" + sc_by_docid = {} + for item in subclaims_data: + sc_by_docid[item["doc_id"]] = item + + prepared = [] + for item in bon_data: + doc_id = item["doc_id"] + label = item["label"] + sc = sc_by_docid.get(doc_id, {}) + + model_data = item.get(model_key, {}) + best_summary = model_data.get("best_summary", "") or model_data.get("predicted_gen_text", "") + generated_text = extract_text_from_best_summary(best_summary, label) + + prepared.append({ + "doc_id": doc_id, + "label": label, + "fulltext": sc.get("fulltext", ""), + "summary_text": sc.get("summary", ""), + "summary_subclaims": sc.get("summary_subclaims", []), + "fulltext_subclaims": sc.get("fulltext_subclaims", []), + "generated_text": generated_text, + "gold_gen_text": item.get("gold_gen_text", ""), + "predicted_label": item.get("predicted_label", ""), + "prediction_correct": item.get("prediction_correct", False), + }) + return prepared + + +def prepare_inference_items( + inference_data: List[Dict[str, Any]], + subclaims_data: List[Dict[str, Any]], +) -> List[Dict[str, Any]]: + """Merge inference-format results (doc_id, label, predicted_gen_text) 
with + subclaims data into the standard evaluation format. predicted_gen_text is + the JSON-wrapped evaluated summary; the raw text is extracted using the + item's label. + """ + sc_by_docid = {} + for item in subclaims_data: + sc_by_docid[item["doc_id"]] = item + + prepared = [] + for item in inference_data: + doc_id = item["doc_id"] + label = item["label"] + sc = sc_by_docid.get(doc_id, {}) + + raw_pred = item.get("predicted_gen_text", "") or "" + generated_text = extract_text_from_best_summary(raw_pred, label) + + prepared.append({ + "doc_id": doc_id, + "label": label, + "fulltext": sc.get("fulltext", ""), + "summary_text": sc.get("summary", ""), + "summary_subclaims": sc.get("summary_subclaims", []), + "fulltext_subclaims": sc.get("fulltext_subclaims", []), + "generated_text": generated_text, + "gold_gen_text": item.get("gold_gen_text", ""), + }) + return prepared + + +def prepare_self_refine_items( + self_refine_data: List[Dict[str, Any]], + subclaims_data: List[Dict[str, Any]], +) -> List[Dict[str, Any]]: + """Merge self-refine format (doc_id, label, final_summary) with subclaims + data. final_summary is the generated text to evaluate (plain text or + JSON-wrapped by label); it is extracted and used as generated_text. 
+ """ + sc_by_docid = {} + for item in subclaims_data: + sc_by_docid[item["doc_id"]] = item + + prepared = [] + for item in self_refine_data: + doc_id = item["doc_id"] + label = item["label"] + sc = sc_by_docid.get(doc_id, {}) + + raw_final = item.get("final_summary", "") or "" + generated_text = extract_text_from_best_summary(raw_final, label) + + prepared.append({ + "doc_id": doc_id, + "label": label, + "fulltext": sc.get("fulltext", ""), + "summary_text": sc.get("summary", ""), + "summary_subclaims": sc.get("summary_subclaims", []), + "fulltext_subclaims": sc.get("fulltext_subclaims", []), + "generated_text": generated_text, + "gold_gen_text": item.get("gold_gen_text", ""), + }) + return prepared + + +def evaluate_single( + item: Dict[str, Any], + target_level_override: Optional[str] = None, +) -> Dict[str, Any]: + """ + Evaluate a single item and return detailed scores. + """ + fulltext = item.get("fulltext", "") + summary_text = item.get("summary_text") or item.get("summary", "") + summary_subclaims = item.get("summary_subclaims", []) + generated_text = item.get("generated_text") or item.get("predicted_gen_text", "") + target_level = target_level_override or item.get("label", "") + + result: Dict[str, Any] = { + "doc_id": item.get("doc_id", ""), + "target_level": target_level, + "generated_text_len": len(generated_text.strip()) if generated_text else 0, + "factuality_score": None, + "hallucination_score": None, + "classifier_score": None, + "grounding_score": None, + "factuality_supported": 0, + "total_summary_subclaims": len(summary_subclaims) if summary_subclaims else 0, + "hallucination_supported": 0, + "total_gen_segments": 0, + "skipped": False, + "skip_reason": "", + } + + if not generated_text or len(generated_text.strip()) < 10: + result["skipped"] = True + result["skip_reason"] = "generated_text missing or too short (<10 chars)" + return result + + # -- Factuality & Hallucination via compute_rewards -- + rewards = compute_rewards( + fulltext=fulltext, 
+ generated_text=generated_text, + target_level=target_level, + summary_subclaims=summary_subclaims, + summary_text=summary_text, + ) + + factuality_score = rewards["factuality_score"] + h_score = rewards["hallucination_score"] + + if factuality_score is None: + factuality_score = 0.5 + if h_score is None: + h_score = 0.5 + + grounding_score = _nonlinear_grounding(h_score) + + # -- Classifier -- + input_text = fulltext or "" + class_score = _compute_classifier_reward(target_level, generated_text, input_text) + + result.update({ + "factuality_score": round(factuality_score, 4), + "hallucination_score": round(h_score, 4), + "grounding_score": round(grounding_score, 4), + "classifier_score": round(class_score, 4), + "factuality_supported": rewards.get("factuality_supported", 0), + "total_summary_subclaims": rewards.get("total_summary_subclaims", 0), + "hallucination_supported": rewards.get("hallucination_supported", 0), + "total_gen_segments": rewards.get("total_gen_segments", 0), + }) + + return result + + +def compute_aggregate(results: List[Dict[str, Any]]) -> Dict[str, Any]: + """Compute aggregate statistics over all evaluated items.""" + scored = [r for r in results if not r.get("skipped", False)] + n = len(scored) + total = len(results) + skipped = total - n + + if n == 0: + return { + "total_items": total, + "scored_items": 0, + "skipped_items": skipped, + "avg_factuality_score": None, + "avg_hallucination_score": None, + "avg_grounding_score": None, + "avg_classifier_score": None, + } + + def safe_avg(key): + vals = [r[key] for r in scored if r[key] is not None] + return round(sum(vals) / len(vals), 4) if vals else None + + return { + "total_items": total, + "scored_items": n, + "skipped_items": skipped, + "avg_factuality_score": safe_avg("factuality_score"), + "avg_hallucination_score": safe_avg("hallucination_score"), + "avg_grounding_score": safe_avg("grounding_score"), + "avg_classifier_score": safe_avg("classifier_score"), + } + + +def main(): + parser = 
argparse.ArgumentParser(
+ description="Evaluate factuality, hallucination, and classifier scores on a JSON file."
+ )
+ parser.add_argument(
+ "--input", "-i", required=True,
+ help="Path to input JSON file (list of objects).",
+ )
+ parser.add_argument(
+ "--output", "-o", default=None,
+ help="Path to output JSON file with per-item scores. "
+ "Defaults to <input>_eval_results.json.",
+ )
+ parser.add_argument(
+ "--output-dir", default=None,
+ help="Directory to save output files. If set, output filename is derived "
+ "from input filename and placed in this directory.",
+ )
+ parser.add_argument(
+ "--subclaims", "-s", default=None,
+ help="Path to subclaims JSON file (for BON format). Provides fulltext, "
+ "summary, summary_subclaims, and fulltext_subclaims keyed by doc_id.",
+ )
+ parser.add_argument(
+ "--model-key", default="qwen3_base",
+ help="Key in the BON data containing candidates/best_summary (default: qwen3_base).",
+ )
+ parser.add_argument(
+ "--target-level", "-t", default=None,
+ help="Override target literacy level for all items "
+ "(e.g. low_health_literacy). 
If not set, uses each item's 'label' field.", + ) + parser.add_argument( + "--support-check-url", default=None, + help="Override VLLM_SUPPORT_CHECK_BN_API_BASE.", + ) + parser.add_argument( + "--classifier-url", default=None, + help="Override VLLM_CLASSIFIER_BN_API_BASE.", + ) + parser.add_argument( + "--subclaim-extractor-url", default=None, + help="Override VLLM_SUBCLAIM_EXTRACTOR_BN_API_BASE.", + ) + args = parser.parse_args() + + if args.support_check_url: + os.environ["VLLM_SUPPORT_CHECK_BN_API_BASE"] = args.support_check_url + if args.classifier_url: + os.environ["VLLM_CLASSIFIER_BN_API_BASE"] = args.classifier_url + if args.subclaim_extractor_url: + os.environ["VLLM_SUBCLAIM_EXTRACTOR_BN_API_BASE"] = args.subclaim_extractor_url + + # Load input + with open(args.input, "r", encoding="utf-8") as f: + raw_data = json.load(f) + + if not isinstance(raw_data, list): + print(f"Error: Expected a JSON list, got {type(raw_data).__name__}.", file=sys.stderr) + sys.exit(1) + + # Detect BON format: items have a model key (e.g. 
qwen3_base) with best_summary + is_bon = ( + len(raw_data) > 0 + and args.model_key in raw_data[0] + and "best_summary" in raw_data[0].get(args.model_key, {}) + ) + + # Detect inference format: top-level doc_id, label, predicted_gen_text; no fulltext/summary_subclaims + is_inference = ( + len(raw_data) > 0 + and "doc_id" in raw_data[0] + and "label" in raw_data[0] + and "predicted_gen_text" in raw_data[0] + and raw_data[0].get("fulltext") is None + and raw_data[0].get("summary_subclaims") is None + ) + + # Detect self-refine format: doc_id, label, final_summary as gen text; no fulltext/summary_subclaims + is_self_refine = ( + len(raw_data) > 0 + and "doc_id" in raw_data[0] + and "label" in raw_data[0] + and "final_summary" in raw_data[0] + and raw_data[0].get("fulltext") is None + and raw_data[0].get("summary_subclaims") is None + ) + + if is_bon: + if not args.subclaims: + print("Error: BON format detected but --subclaims file not provided.", file=sys.stderr) + sys.exit(1) + with open(args.subclaims, "r", encoding="utf-8") as f: + subclaims_data = json.load(f) + print(f"BON format detected (model_key={args.model_key})") + print(f"Loaded {len(raw_data)} BON items from {args.input}") + print(f"Loaded {len(subclaims_data)} subclaims entries from {args.subclaims}") + data = prepare_bon_items(raw_data, subclaims_data, model_key=args.model_key) + print(f"Prepared {len(data)} items for evaluation") + elif is_inference: + if not args.subclaims: + print("Error: Inference format detected (predicted_gen_text) but --subclaims file not provided.", file=sys.stderr) + sys.exit(1) + with open(args.subclaims, "r", encoding="utf-8") as f: + subclaims_data = json.load(f) + print("Inference format detected (predicted_gen_text as evaluated summary)") + print(f"Loaded {len(raw_data)} inference items from {args.input}") + print(f"Loaded {len(subclaims_data)} subclaims entries from {args.subclaims}") + data = prepare_inference_items(raw_data, subclaims_data) + print(f"Prepared 
{len(data)} items for evaluation") + elif is_self_refine: + if not args.subclaims: + print("Error: Self-refine format detected (final_summary) but --subclaims file not provided.", file=sys.stderr) + sys.exit(1) + with open(args.subclaims, "r", encoding="utf-8") as f: + subclaims_data = json.load(f) + print("Self-refine format detected (final_summary as evaluated summary)") + print(f"Loaded {len(raw_data)} self-refine items from {args.input}") + print(f"Loaded {len(subclaims_data)} subclaims entries from {args.subclaims}") + data = prepare_self_refine_items(raw_data, subclaims_data) + print(f"Prepared {len(data)} items for evaluation") + else: + data = raw_data + print(f"Loaded {len(data)} items from {args.input}") + + print(f" Support check API : {os.getenv('VLLM_SUPPORT_CHECK_BN_API_BASE', 'http://localhost:8090/v1')}") + print(f" Classifier API : {os.getenv('VLLM_CLASSIFIER_BN_API_BASE', 'http://localhost:8040/v1')}") + print(f" Subclaim extractor: {os.getenv('VLLM_SUBCLAIM_EXTRACTOR_BN_API_BASE', 'http://localhost:8050/v1')}") + if args.target_level: + print(f" Target level override: {args.target_level}") + print("-" * 60) + + # Evaluate each item + results = [] + start_time = time.time() + for idx, item in enumerate(tqdm(data, desc="Evaluating")): + r = evaluate_single(item, target_level_override=args.target_level) + r["index"] = idx + r["doc_id"] = item.get("doc_id", "") + results.append(r) + + if (idx + 1) % 10 == 0 or idx == 0: + partial_agg = compute_aggregate(results) + tqdm.write( + f" [{idx+1}/{len(data)}] " + f"fact={partial_agg['avg_factuality_score']} " + f"hallu={partial_agg['avg_hallucination_score']} " + f"cls={partial_agg['avg_classifier_score']}" + ) + + elapsed = time.time() - start_time + + # --- Validation: all items must be evaluated with non-null scores --- + expected_count = len(data) + skipped_items = [r for r in results if r.get("skipped", False)] + null_score_items = [] + for r in results: + if r.get("skipped", False): + continue + for 
key in ("factuality_score", "hallucination_score", "classifier_score", "grounding_score"): + if r.get(key) is None: + null_score_items.append((r.get("index"), r.get("doc_id"), key)) + + has_errors = False + if skipped_items: + has_errors = True + print(f"\nERROR: {len(skipped_items)} out of {expected_count} items were skipped:", file=sys.stderr) + for r in skipped_items: + print(f" index={r.get('index')} doc_id={r.get('doc_id')} reason={r.get('skip_reason')}", file=sys.stderr) + + if null_score_items: + has_errors = True + print(f"\nERROR: {len(null_score_items)} null score(s) found:", file=sys.stderr) + for idx, doc_id, key in null_score_items: + print(f" index={idx} doc_id={doc_id} null_field={key}", file=sys.stderr) + + if len(results) != expected_count: + has_errors = True + print(f"\nERROR: Expected {expected_count} results but got {len(results)}.", file=sys.stderr) + + if has_errors: + print(f"\nAborting: will NOT save results. All {expected_count} items must be fully evaluated with non-null scores.", file=sys.stderr) + sys.exit(1) + + # Aggregate + agg = compute_aggregate(results) + + # Per-label aggregates + label_groups: Dict[str, List[Dict[str, Any]]] = {} + for r in results: + lbl = r.get("target_level", "unknown") + label_groups.setdefault(lbl, []).append(r) + per_label_agg = {lbl: compute_aggregate(items) for lbl, items in sorted(label_groups.items())} + + # Output path + if args.output: + out_path = args.output + elif args.output_dir: + os.makedirs(args.output_dir, exist_ok=True) + stem = os.path.splitext(os.path.basename(args.input))[0] + out_path = os.path.join(args.output_dir, f"{stem}_eval_results.json") + else: + stem = os.path.splitext(os.path.basename(args.input))[0] + out_dir = os.path.dirname(args.input) or "." 
+ out_path = os.path.join(out_dir, f"{stem}_eval_results.json") + + os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True) + + output = { + "input_file": os.path.abspath(args.input), + "subclaims_file": os.path.abspath(args.subclaims) if args.subclaims else None, + "model_key": args.model_key if is_bon else None, + "inference_format": is_inference if not is_bon else False, + "self_refine_format": is_self_refine if not is_bon and not is_inference else False, + "target_level_override": args.target_level, + "elapsed_seconds": round(elapsed, 2), + "aggregate": agg, + "per_label_aggregate": per_label_agg, + "per_item": results, + } + + with open(out_path, "w", encoding="utf-8") as f: + json.dump(output, f, indent=2, ensure_ascii=False) + + # Print summary + print("\n" + "=" * 60) + print("EVALUATION SUMMARY") + print("=" * 60) + print(f" Total items : {agg['total_items']}") + print(f" Scored items : {agg['scored_items']}") + print(f" Skipped items : {agg['skipped_items']}") + print(f" Elapsed time : {round(elapsed, 1)}s") + print("-" * 60) + print(f" Avg Factuality Score : {agg['avg_factuality_score']}") + print(f" Avg Hallucination Score: {agg['avg_hallucination_score']}") + print(f" Avg Grounding Score : {agg['avg_grounding_score']}") + print(f" Avg Classifier Score : {agg['avg_classifier_score']}") + print("-" * 60) + for lbl, la in per_label_agg.items(): + print(f" [{lbl}] items={la['scored_items']}" + f" fact={la['avg_factuality_score']}" + f" hallu={la['avg_hallucination_score']}" + f" cls={la['avg_classifier_score']}") + print("-" * 60) + print(f" Results saved to: {out_path}") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/code/fine_tune_sft_dpo/evaluate_scores_bn_vllm.py b/code/fine_tune_sft_dpo/evaluate_scores_bn_vllm.py new file mode 100644 index 0000000000000000000000000000000000000000..14416cfbae79f99d54f2b5636c98dda95c4f4b9a --- /dev/null +++ b/code/fine_tune_sft_dpo/evaluate_scores_bn_vllm.py @@ -0,0 +1,611 @@ 
+#!/usr/bin/env python3 +""" +Standalone evaluation script for computing factuality, hallucination, and +classifier scores on a JSON/JSONL file. + +Supports these input formats: + +1. **Standard format** — a list of objects, each with: + - fulltext, summary_text, summary_subclaims, generated_text, label + +2. **Best-of-N (BON) format** — a list of objects, each with: + - doc_id, label, qwen3_base.best_summary (JSON-wrapped generated text) + Requires a separate --subclaims file to supply fulltext, summary, + summary_subclaims, and fulltext_subclaims (keyed by doc_id). + +3. **Inference format** — a list of objects, each with: + - doc_id, label, predicted_gen_text (JSON-wrapped evaluated summary), + optionally gold_gen_text + predicted_gen_text is the summary to evaluate (same JSON key-by-label + format as best_summary). Requires --subclaims for fulltext and subclaims. + +4. **Self-refine format** — a list of objects, each with: + - doc_id, label, final_summary (the generated text to evaluate), + optionally gold_gen_text, gold_summary + final_summary is the summary to evaluate (plain text or JSON-wrapped by + label). Requires --subclaims for fulltext and subclaims. + +5. **RL inference format** (JSONL) — one JSON object per line, each with: + - doc_id, gold_label, input_text, summary_text, subclaims, generated_text + Self-contained: no --subclaims file needed. 
Field mapping: + gold_label -> label, input_text -> fulltext, subclaims -> summary_subclaims + +The script reuses the reward functions from reward_new_v6_bn_v4_rmv_src_cov.py: + - factuality_score : fraction of summary subclaims supported by generated_text + - hallucination_score: fraction of gen subclaims NOT supported by fulltext + - classifier_score : whether generated_text matches the target literacy level + +Requires the same vLLM endpoints as the reward file: + - Support checker : VLLM_SUPPORT_CHECK_BN_API_BASE (default http://localhost:8090/v1) + - Classifier : VLLM_CLASSIFIER_BN_API_BASE (default http://localhost:8040/v1) + - Subclaim extractor: VLLM_SUBCLAIM_EXTRACTOR_BN_API_BASE (default http://localhost:8050/v1) + +Usage: + # Standard format + python evaluate_scores.py --input data.json [--output results.json] + + # BON format with subclaims file + python evaluate_scores.py --input bon_results.json --subclaims subclaims.json --output-dir evaluation/bn/ + + # Inference format (predicted_gen_text as evaluated summary) + python evaluate_scores.py --input test_inference_vllm_qwen3-4B_base.json --subclaims subclaims.json --output results.json + + # Self-refine format (final_summary as evaluated summary) + python evaluate_scores.py --input test_self_refine_vllm_qwen3_4B_base.json --subclaims subclaims.json --output-dir evaluation/bn/ + + # RL inference format (JSONL, self-contained) + python evaluate_scores.py --input bn_200.jsonl --output-dir evaluation/bn/ +""" + +import argparse +import json +import os +import re +import sys +import time +from typing import Any, Dict, List, Optional + +from tqdm import tqdm + +# Import scoring utilities from the reward module (same directory). 
+from reward_new_v6_bn_v4_rmv_src_cov import ( + _call_support_api, + _compute_classifier_reward, + _extract_subclaims_from_text, + _is_bangla_text, + _nonlinear_grounding, + compute_rewards, +) + + +def extract_text_from_best_summary(best_summary: str, label: str) -> str: + """Extract the raw generated text from a BON best_summary string. + + The best_summary is a (possibly truncated) JSON string like: + '{"proficient_health_literacy": "...actual text..."}' + We locate the value after the label key and strip JSON wrapping. + """ + key_pattern = re.compile(re.escape(f'"{label}"') + r'\s*:\s*"') + m = key_pattern.search(best_summary) + if not m: + return best_summary.strip() + text = best_summary[m.end():] + if text.endswith('"\n}'): + text = text[:-3] + elif text.endswith('"}\n'): + text = text[:-3] + elif text.endswith('"}'): + text = text[:-2] + elif text.endswith('"'): + text = text[:-1] + text = text.replace("\\n", "\n").replace('\\"', '"') + return text.strip() + + +def prepare_bon_items( + bon_data: List[Dict[str, Any]], + subclaims_data: List[Dict[str, Any]], + model_key: str = "qwen3_base", +) -> List[Dict[str, Any]]: + """Merge BON results with subclaims data into the standard evaluation format.""" + sc_by_docid = {} + for item in subclaims_data: + sc_by_docid[item["doc_id"]] = item + + prepared = [] + for item in bon_data: + doc_id = item["doc_id"] + label = item["label"] + sc = sc_by_docid.get(doc_id, {}) + + model_data = item.get(model_key, {}) + best_summary = model_data.get("best_summary", "") or model_data.get("predicted_gen_text", "") + generated_text = extract_text_from_best_summary(best_summary, label) + + prepared.append({ + "doc_id": doc_id, + "label": label, + "fulltext": sc.get("fulltext", ""), + "summary_text": sc.get("summary", ""), + "summary_subclaims": sc.get("summary_subclaims", []), + "fulltext_subclaims": sc.get("fulltext_subclaims", []), + "generated_text": generated_text, + "gold_gen_text": item.get("gold_gen_text", ""), + 
"predicted_label": item.get("predicted_label", ""), + "prediction_correct": item.get("prediction_correct", False), + }) + return prepared + + +def prepare_inference_items( + inference_data: List[Dict[str, Any]], + subclaims_data: List[Dict[str, Any]], +) -> List[Dict[str, Any]]: + """Merge inference-format results (doc_id, label, predicted_gen_text) with + subclaims data into the standard evaluation format. predicted_gen_text is + the JSON-wrapped evaluated summary; the raw text is extracted using the + item's label. + """ + sc_by_docid = {} + for item in subclaims_data: + sc_by_docid[item["doc_id"]] = item + + prepared = [] + for item in inference_data: + doc_id = item["doc_id"] + label = item["label"] + sc = sc_by_docid.get(doc_id, {}) + + raw_pred = item.get("predicted_gen_text", "") or "" + generated_text = extract_text_from_best_summary(raw_pred, label) + + prepared.append({ + "doc_id": doc_id, + "label": label, + "fulltext": sc.get("fulltext", ""), + "summary_text": sc.get("summary", ""), + "summary_subclaims": sc.get("summary_subclaims", []), + "fulltext_subclaims": sc.get("fulltext_subclaims", []), + "generated_text": generated_text, + "gold_gen_text": item.get("gold_gen_text", ""), + }) + return prepared + + +def prepare_self_refine_items( + self_refine_data: List[Dict[str, Any]], + subclaims_data: List[Dict[str, Any]], +) -> List[Dict[str, Any]]: + """Merge self-refine format (doc_id, label, final_summary) with subclaims + data. final_summary is the generated text to evaluate (plain text or + JSON-wrapped by label); it is extracted and used as generated_text. 
+ """ + sc_by_docid = {} + for item in subclaims_data: + sc_by_docid[item["doc_id"]] = item + + prepared = [] + for item in self_refine_data: + doc_id = item["doc_id"] + label = item["label"] + sc = sc_by_docid.get(doc_id, {}) + + raw_final = item.get("final_summary", "") or "" + generated_text = extract_text_from_best_summary(raw_final, label) + + prepared.append({ + "doc_id": doc_id, + "label": label, + "fulltext": sc.get("fulltext", ""), + "summary_text": sc.get("summary", ""), + "summary_subclaims": sc.get("summary_subclaims", []), + "fulltext_subclaims": sc.get("fulltext_subclaims", []), + "generated_text": generated_text, + "gold_gen_text": item.get("gold_gen_text", ""), + }) + return prepared + + +def prepare_rl_inference_items( + rl_data: List[Dict[str, Any]], +) -> List[Dict[str, Any]]: + """Convert RL inference JSONL items into the standard evaluation format. + + Field mapping: + gold_label -> label + input_text -> fulltext + subclaims -> summary_subclaims + summary_text -> summary_text + generated_text -> generated_text (plain text, used as-is) + """ + prepared = [] + for item in rl_data: + prepared.append({ + "doc_id": item.get("doc_id", ""), + "label": item.get("gold_label", ""), + "fulltext": item.get("input_text", ""), + "summary_text": item.get("summary_text", ""), + "summary_subclaims": item.get("subclaims", []), + "generated_text": item.get("generated_text", ""), + }) + return prepared + + +def evaluate_single( + item: Dict[str, Any], + target_level_override: Optional[str] = None, +) -> Dict[str, Any]: + """ + Evaluate a single item and return detailed scores. 
+ """ + fulltext = item.get("fulltext", "") + summary_text = item.get("summary_text") or item.get("summary", "") + summary_subclaims = item.get("summary_subclaims", []) + generated_text = item.get("generated_text") or item.get("predicted_gen_text", "") + target_level = target_level_override or item.get("label", "") + + result: Dict[str, Any] = { + "doc_id": item.get("doc_id", ""), + "target_level": target_level, + "generated_text_len": len(generated_text.strip()) if generated_text else 0, + "factuality_score": None, + "hallucination_score": None, + "classifier_score": None, + "grounding_score": None, + "factuality_supported": 0, + "total_summary_subclaims": len(summary_subclaims) if summary_subclaims else 0, + "hallucination_supported": 0, + "total_gen_segments": 0, + "skipped": False, + "skip_reason": "", + } + + if not generated_text or len(generated_text.strip()) < 10: + result["skipped"] = True + result["skip_reason"] = "generated_text missing or too short (<10 chars)" + return result + + # -- Factuality & Hallucination via compute_rewards -- + rewards = compute_rewards( + fulltext=fulltext, + generated_text=generated_text, + target_level=target_level, + summary_subclaims=summary_subclaims, + summary_text=summary_text, + ) + + factuality_score = rewards["factuality_score"] + h_score = rewards["hallucination_score"] + + if factuality_score is None: + factuality_score = 0.5 + if h_score is None: + h_score = 0.5 + + grounding_score = _nonlinear_grounding(h_score) + + # -- Classifier -- + input_text = fulltext or "" + class_score = _compute_classifier_reward(target_level, generated_text, input_text) + + result.update({ + "factuality_score": round(factuality_score, 4), + "hallucination_score": round(h_score, 4), + "grounding_score": round(grounding_score, 4), + "classifier_score": round(class_score, 4), + "factuality_supported": rewards.get("factuality_supported", 0), + "total_summary_subclaims": rewards.get("total_summary_subclaims", 0), + 
"hallucination_supported": rewards.get("hallucination_supported", 0), + "total_gen_segments": rewards.get("total_gen_segments", 0), + }) + + return result + + +def compute_aggregate(results: List[Dict[str, Any]]) -> Dict[str, Any]: + """Compute aggregate statistics over all evaluated items.""" + scored = [r for r in results if not r.get("skipped", False)] + n = len(scored) + total = len(results) + skipped = total - n + + if n == 0: + return { + "total_items": total, + "scored_items": 0, + "skipped_items": skipped, + "avg_factuality_score": None, + "avg_hallucination_score": None, + "avg_grounding_score": None, + "avg_classifier_score": None, + } + + def safe_avg(key): + vals = [r[key] for r in scored if r[key] is not None] + return round(sum(vals) / len(vals), 4) if vals else None + + return { + "total_items": total, + "scored_items": n, + "skipped_items": skipped, + "avg_factuality_score": safe_avg("factuality_score"), + "avg_hallucination_score": safe_avg("hallucination_score"), + "avg_grounding_score": safe_avg("grounding_score"), + "avg_classifier_score": safe_avg("classifier_score"), + } + + +def main(): + parser = argparse.ArgumentParser( + description="Evaluate factuality, hallucination, and classifier scores on a JSON file." + ) + parser.add_argument( + "--input", "-i", required=True, + help="Path to input JSON file (list of objects).", + ) + parser.add_argument( + "--output", "-o", default=None, + help="Path to output JSON file with per-item scores. " + "Defaults to _eval_results.json.", + ) + parser.add_argument( + "--output-dir", default=None, + help="Directory to save output files. If set, output filename is derived " + "from input filename and placed in this directory.", + ) + parser.add_argument( + "--subclaims", "-s", default=None, + help="Path to subclaims JSON file (for BON format). 
Provides fulltext, " + "summary, summary_subclaims, and fulltext_subclaims keyed by doc_id.", + ) + parser.add_argument( + "--model-key", default="qwen3_base", + help="Key in the BON data containing candidates/best_summary (default: qwen3_base).", + ) + parser.add_argument( + "--target-level", "-t", default=None, + help="Override target literacy level for all items " + "(e.g. low_health_literacy). If not set, uses each item's 'label' field.", + ) + parser.add_argument( + "--support-check-url", default=None, + help="Override VLLM_SUPPORT_CHECK_BN_API_BASE.", + ) + parser.add_argument( + "--classifier-url", default=None, + help="Override VLLM_CLASSIFIER_BN_API_BASE.", + ) + parser.add_argument( + "--subclaim-extractor-url", default=None, + help="Override VLLM_SUBCLAIM_EXTRACTOR_BN_API_BASE.", + ) + args = parser.parse_args() + + if args.support_check_url: + os.environ["VLLM_SUPPORT_CHECK_BN_API_BASE"] = args.support_check_url + if args.classifier_url: + os.environ["VLLM_CLASSIFIER_BN_API_BASE"] = args.classifier_url + if args.subclaim_extractor_url: + os.environ["VLLM_SUBCLAIM_EXTRACTOR_BN_API_BASE"] = args.subclaim_extractor_url + + # Load input (JSON list or JSONL) + if args.input.endswith(".jsonl"): + raw_data = [] + with open(args.input, "r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if line: + raw_data.append(json.loads(line)) + else: + with open(args.input, "r", encoding="utf-8") as f: + raw_data = json.load(f) + + if not isinstance(raw_data, list): + print(f"Error: Expected a JSON list, got {type(raw_data).__name__}.", file=sys.stderr) + sys.exit(1) + + # Detect BON format: items have a model key (e.g. 
qwen3_base) with best_summary + is_bon = ( + len(raw_data) > 0 + and args.model_key in raw_data[0] + and "best_summary" in raw_data[0].get(args.model_key, {}) + ) + + # Detect inference format: top-level doc_id, label, predicted_gen_text; no fulltext/summary_subclaims + is_inference = ( + len(raw_data) > 0 + and "doc_id" in raw_data[0] + and "label" in raw_data[0] + and "predicted_gen_text" in raw_data[0] + and raw_data[0].get("fulltext") is None + and raw_data[0].get("summary_subclaims") is None + ) + + # Detect self-refine format: doc_id, label, final_summary as gen text; no fulltext/summary_subclaims + is_self_refine = ( + len(raw_data) > 0 + and "doc_id" in raw_data[0] + and "label" in raw_data[0] + and "final_summary" in raw_data[0] + and raw_data[0].get("fulltext") is None + and raw_data[0].get("summary_subclaims") is None + ) + + # Detect RL inference format: gold_label, input_text, subclaims, generated_text + is_rl_inference = ( + len(raw_data) > 0 + and "doc_id" in raw_data[0] + and "gold_label" in raw_data[0] + and "input_text" in raw_data[0] + and "generated_text" in raw_data[0] + and "subclaims" in raw_data[0] + ) + + if is_rl_inference: + print("RL inference format detected (gold_label, input_text, subclaims, generated_text)") + print(f"Loaded {len(raw_data)} RL inference items from {args.input}") + data = prepare_rl_inference_items(raw_data) + print(f"Prepared {len(data)} items for evaluation") + elif is_bon: + if not args.subclaims: + print("Error: BON format detected but --subclaims file not provided.", file=sys.stderr) + sys.exit(1) + with open(args.subclaims, "r", encoding="utf-8") as f: + subclaims_data = json.load(f) + print(f"BON format detected (model_key={args.model_key})") + print(f"Loaded {len(raw_data)} BON items from {args.input}") + print(f"Loaded {len(subclaims_data)} subclaims entries from {args.subclaims}") + data = prepare_bon_items(raw_data, subclaims_data, model_key=args.model_key) + print(f"Prepared {len(data)} items for 
evaluation") + elif is_inference: + if not args.subclaims: + print("Error: Inference format detected (predicted_gen_text) but --subclaims file not provided.", file=sys.stderr) + sys.exit(1) + with open(args.subclaims, "r", encoding="utf-8") as f: + subclaims_data = json.load(f) + print("Inference format detected (predicted_gen_text as evaluated summary)") + print(f"Loaded {len(raw_data)} inference items from {args.input}") + print(f"Loaded {len(subclaims_data)} subclaims entries from {args.subclaims}") + data = prepare_inference_items(raw_data, subclaims_data) + print(f"Prepared {len(data)} items for evaluation") + elif is_self_refine: + if not args.subclaims: + print("Error: Self-refine format detected (final_summary) but --subclaims file not provided.", file=sys.stderr) + sys.exit(1) + with open(args.subclaims, "r", encoding="utf-8") as f: + subclaims_data = json.load(f) + print("Self-refine format detected (final_summary as evaluated summary)") + print(f"Loaded {len(raw_data)} self-refine items from {args.input}") + print(f"Loaded {len(subclaims_data)} subclaims entries from {args.subclaims}") + data = prepare_self_refine_items(raw_data, subclaims_data) + print(f"Prepared {len(data)} items for evaluation") + else: + data = raw_data + print(f"Loaded {len(data)} items from {args.input}") + + print(f" Support check API : {os.getenv('VLLM_SUPPORT_CHECK_BN_API_BASE', 'http://localhost:8090/v1')}") + print(f" Classifier API : {os.getenv('VLLM_CLASSIFIER_BN_API_BASE', 'http://localhost:8040/v1')}") + print(f" Subclaim extractor: {os.getenv('VLLM_SUBCLAIM_EXTRACTOR_BN_API_BASE', 'http://localhost:8050/v1')}") + if args.target_level: + print(f" Target level override: {args.target_level}") + print("-" * 60) + + # Evaluate each item + results = [] + start_time = time.time() + for idx, item in enumerate(tqdm(data, desc="Evaluating")): + r = evaluate_single(item, target_level_override=args.target_level) + r["index"] = idx + r["doc_id"] = item.get("doc_id", "") + 
results.append(r) + + if (idx + 1) % 10 == 0 or idx == 0: + partial_agg = compute_aggregate(results) + tqdm.write( + f" [{idx+1}/{len(data)}] " + f"fact={partial_agg['avg_factuality_score']} " + f"hallu={partial_agg['avg_hallucination_score']} " + f"cls={partial_agg['avg_classifier_score']}" + ) + + elapsed = time.time() - start_time + + # --- Validation: all items must be evaluated with non-null scores --- + expected_count = len(data) + skipped_items = [r for r in results if r.get("skipped", False)] + null_score_items = [] + for r in results: + if r.get("skipped", False): + continue + for key in ("factuality_score", "hallucination_score", "classifier_score", "grounding_score"): + if r.get(key) is None: + null_score_items.append((r.get("index"), r.get("doc_id"), key)) + + has_errors = False + if skipped_items: + has_errors = True + print(f"\nERROR: {len(skipped_items)} out of {expected_count} items were skipped:", file=sys.stderr) + for r in skipped_items: + print(f" index={r.get('index')} doc_id={r.get('doc_id')} reason={r.get('skip_reason')}", file=sys.stderr) + + if null_score_items: + has_errors = True + print(f"\nERROR: {len(null_score_items)} null score(s) found:", file=sys.stderr) + for idx, doc_id, key in null_score_items: + print(f" index={idx} doc_id={doc_id} null_field={key}", file=sys.stderr) + + if len(results) != expected_count: + has_errors = True + print(f"\nERROR: Expected {expected_count} results but got {len(results)}.", file=sys.stderr) + + if has_errors: + print(f"\nAborting: will NOT save results. 
All {expected_count} items must be fully evaluated with non-null scores.", file=sys.stderr) + sys.exit(1) + + # Aggregate + agg = compute_aggregate(results) + + # Per-label aggregates + label_groups: Dict[str, List[Dict[str, Any]]] = {} + for r in results: + lbl = r.get("target_level", "unknown") + label_groups.setdefault(lbl, []).append(r) + per_label_agg = {lbl: compute_aggregate(items) for lbl, items in sorted(label_groups.items())} + + # Output path + if args.output: + out_path = args.output + elif args.output_dir: + os.makedirs(args.output_dir, exist_ok=True) + stem = os.path.splitext(os.path.basename(args.input))[0] + out_path = os.path.join(args.output_dir, f"{stem}_eval_results.json") + else: + stem = os.path.splitext(os.path.basename(args.input))[0] + out_dir = os.path.dirname(args.input) or "." + out_path = os.path.join(out_dir, f"{stem}_eval_results.json") + + os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True) + + output = { + "input_file": os.path.abspath(args.input), + "subclaims_file": os.path.abspath(args.subclaims) if args.subclaims else None, + "model_key": args.model_key if is_bon else None, + "inference_format": is_inference if not is_bon else False, + "self_refine_format": is_self_refine if not is_bon and not is_inference else False, + "rl_inference_format": is_rl_inference, + "target_level_override": args.target_level, + "elapsed_seconds": round(elapsed, 2), + "aggregate": agg, + "per_label_aggregate": per_label_agg, + "per_item": results, + } + + with open(out_path, "w", encoding="utf-8") as f: + json.dump(output, f, indent=2, ensure_ascii=False) + + # Print summary + print("\n" + "=" * 60) + print("EVALUATION SUMMARY") + print("=" * 60) + print(f" Total items : {agg['total_items']}") + print(f" Scored items : {agg['scored_items']}") + print(f" Skipped items : {agg['skipped_items']}") + print(f" Elapsed time : {round(elapsed, 1)}s") + print("-" * 60) + print(f" Avg Factuality Score : {agg['avg_factuality_score']}") + print(f" Avg 
Hallucination Score: {agg['avg_hallucination_score']}") + print(f" Avg Grounding Score : {agg['avg_grounding_score']}") + print(f" Avg Classifier Score : {agg['avg_classifier_score']}") + print("-" * 60) + for lbl, la in per_label_agg.items(): + print(f" [{lbl}] items={la['scored_items']}" + f" fact={la['avg_factuality_score']}" + f" hallu={la['avg_hallucination_score']}" + f" cls={la['avg_classifier_score']}") + print("-" * 60) + print(f" Results saved to: {out_path}") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/code/fine_tune_sft_dpo/evaluate_scores_bn_vllm_rl.py b/code/fine_tune_sft_dpo/evaluate_scores_bn_vllm_rl.py new file mode 100644 index 0000000000000000000000000000000000000000..183719b1037cb9cb524f1aeb708af9214a6aa6ce --- /dev/null +++ b/code/fine_tune_sft_dpo/evaluate_scores_bn_vllm_rl.py @@ -0,0 +1,567 @@ +#!/usr/bin/env python3 +""" +Standalone evaluation script for computing factuality, hallucination, and +classifier scores on a JSON file. + +Supports two input formats: + +1. **Standard format** — a list of objects, each with: + - fulltext, summary_text, summary_subclaims, generated_text, label + +2. **Best-of-N (BON) format** — a list of objects, each with: + - doc_id, label, qwen3_base.best_summary (JSON-wrapped generated text) + Requires a separate --subclaims file to supply fulltext, summary, + summary_subclaims, and fulltext_subclaims (keyed by doc_id). + +3. **Inference format** — a list of objects, each with: + - doc_id, label, predicted_gen_text (JSON-wrapped evaluated summary), + optionally gold_gen_text + predicted_gen_text is the summary to evaluate (same JSON key-by-label + format as best_summary). Requires --subclaims for fulltext and subclaims. + +4. **Self-refine format** — a list of objects, each with: + - doc_id, label, final_summary (the generated text to evaluate), + optionally gold_gen_text, gold_summary + final_summary is the summary to evaluate (plain text or JSON-wrapped by + label). 
Requires --subclaims for fulltext and subclaims. + +The script reuses the reward functions from reward_new_v6_bn_v4_rmv_src_cov.py: + - factuality_score : fraction of summary subclaims supported by generated_text + - hallucination_score: fraction of gen subclaims NOT supported by fulltext + - classifier_score : whether generated_text matches the target literacy level + +Requires the same vLLM endpoints as the reward file: + - Support checker : VLLM_SUPPORT_CHECK_BN_API_BASE (default http://localhost:8090/v1) + - Classifier : VLLM_CLASSIFIER_BN_API_BASE (default http://localhost:8040/v1) + - Subclaim extractor: VLLM_SUBCLAIM_EXTRACTOR_BN_API_BASE (default http://localhost:8050/v1) + +Usage: + # Standard format + python evaluate_scores.py --input data.json [--output results.json] + + # BON format with subclaims file + python evaluate_scores.py --input bon_results.json --subclaims subclaims.json --output-dir evaluation/bn/ + + # Inference format (predicted_gen_text as evaluated summary) + python evaluate_scores.py --input test_inference_vllm_qwen3-4B_base.json --subclaims subclaims.json --output results.json + + # Self-refine format (final_summary as evaluated summary) + python evaluate_scores.py --input test_self_refine_vllm_qwen3_4B_base.json --subclaims subclaims.json --output-dir evaluation/bn/ +""" + +import argparse +import json +import os +import re +import sys +import time +from typing import Any, Dict, List, Optional + +from tqdm import tqdm + +# Import scoring utilities from the reward module (same directory). +from reward_new_v6_bn_v4_rmv_src_cov import ( + _call_support_api, + _compute_classifier_reward, + _extract_subclaims_from_text, + _is_bangla_text, + _nonlinear_grounding, + compute_rewards, +) + + +def extract_text_from_best_summary(best_summary: str, label: str) -> str: + """Extract the raw generated text from a BON best_summary string. 
+ + The best_summary is a (possibly truncated) JSON string like: + '{"proficient_health_literacy": "...actual text..."}' + We locate the value after the label key and strip JSON wrapping. + """ + key_pattern = re.compile(re.escape(f'"{label}"') + r'\s*:\s*"') + m = key_pattern.search(best_summary) + if not m: + return best_summary.strip() + text = best_summary[m.end():] + if text.endswith('"\n}'): + text = text[:-3] + elif text.endswith('"}\n'): + text = text[:-3] + elif text.endswith('"}'): + text = text[:-2] + elif text.endswith('"'): + text = text[:-1] + text = text.replace("\\n", "\n").replace('\\"', '"') + return text.strip() + + +def prepare_bon_items( + bon_data: List[Dict[str, Any]], + subclaims_data: List[Dict[str, Any]], + model_key: str = "qwen3_base", +) -> List[Dict[str, Any]]: + """Merge BON results with subclaims data into the standard evaluation format.""" + sc_by_docid = {} + for item in subclaims_data: + sc_by_docid[item["doc_id"]] = item + + prepared = [] + for item in bon_data: + doc_id = item["doc_id"] + label = item["label"] + sc = sc_by_docid.get(doc_id, {}) + + model_data = item.get(model_key, {}) + best_summary = model_data.get("best_summary", "") or model_data.get("predicted_gen_text", "") + generated_text = extract_text_from_best_summary(best_summary, label) + + prepared.append({ + "doc_id": doc_id, + "label": label, + "fulltext": sc.get("fulltext", ""), + "summary_text": sc.get("summary", ""), + "summary_subclaims": sc.get("summary_subclaims", []), + "fulltext_subclaims": sc.get("fulltext_subclaims", []), + "generated_text": generated_text, + "gold_gen_text": item.get("gold_gen_text", ""), + "predicted_label": item.get("predicted_label", ""), + "prediction_correct": item.get("prediction_correct", False), + }) + return prepared + + +def prepare_inference_items( + inference_data: List[Dict[str, Any]], + subclaims_data: List[Dict[str, Any]], +) -> List[Dict[str, Any]]: + """Merge inference-format results (doc_id, label, predicted_gen_text) 
with + subclaims data into the standard evaluation format. predicted_gen_text is + the JSON-wrapped evaluated summary; the raw text is extracted using the + item's label. + """ + sc_by_docid = {} + for item in subclaims_data: + sc_by_docid[item["doc_id"]] = item + + prepared = [] + for item in inference_data: + doc_id = item["doc_id"] + label = item["label"] + sc = sc_by_docid.get(doc_id, {}) + + raw_pred = item.get("predicted_gen_text", "") or "" or item.get("generated_text", "") + generated_text = extract_text_from_best_summary(raw_pred, label) + + prepared.append({ + "doc_id": doc_id, + "label": label, + "fulltext": sc.get("fulltext", ""), + "summary_text": sc.get("summary", ""), + "summary_subclaims": sc.get("summary_subclaims", []), + "fulltext_subclaims": sc.get("fulltext_subclaims", []), + "generated_text": generated_text, + "gold_gen_text": item.get("gold_gen_text", ""), + }) + return prepared + + +def prepare_self_refine_items( + self_refine_data: List[Dict[str, Any]], + subclaims_data: List[Dict[str, Any]], +) -> List[Dict[str, Any]]: + """Merge self-refine format (doc_id, label, final_summary) with subclaims + data. final_summary is the generated text to evaluate (plain text or + JSON-wrapped by label); it is extracted and used as generated_text. 
+ """ + sc_by_docid = {} + for item in subclaims_data: + sc_by_docid[item["doc_id"]] = item + + prepared = [] + for item in self_refine_data: + doc_id = item["doc_id"] + label = item["label"] + sc = sc_by_docid.get(doc_id, {}) + + raw_final = item.get("final_summary", "") or "" + generated_text = extract_text_from_best_summary(raw_final, label) + + prepared.append({ + "doc_id": doc_id, + "label": label, + "fulltext": sc.get("fulltext", ""), + "summary_text": sc.get("summary", ""), + "summary_subclaims": sc.get("summary_subclaims", []), + "fulltext_subclaims": sc.get("fulltext_subclaims", []), + "generated_text": generated_text, + "gold_gen_text": item.get("gold_gen_text", ""), + }) + return prepared + + +def evaluate_single( + item: Dict[str, Any], + target_level_override: Optional[str] = None, +) -> Dict[str, Any]: + """ + Evaluate a single item and return detailed scores. + """ + fulltext = item.get("fulltext", "") + summary_text = item.get("summary_text") or item.get("summary", "") + summary_subclaims = item.get("summary_subclaims", []) + generated_text = item.get("generated_text") or item.get("predicted_gen_text", "") + target_level = target_level_override or item.get("label", "") + + result: Dict[str, Any] = { + "doc_id": item.get("doc_id", ""), + "target_level": target_level, + "generated_text_len": len(generated_text.strip()) if generated_text else 0, + "factuality_score": None, + "hallucination_score": None, + "classifier_score": None, + "grounding_score": None, + "factuality_supported": 0, + "total_summary_subclaims": len(summary_subclaims) if summary_subclaims else 0, + "hallucination_supported": 0, + "total_gen_segments": 0, + "skipped": False, + "skip_reason": "", + } + + if not generated_text or len(generated_text.strip()) < 10: + result["skipped"] = True + result["skip_reason"] = "generated_text missing or too short (<10 chars)" + return result + + # -- Factuality & Hallucination via compute_rewards -- + rewards = compute_rewards( + fulltext=fulltext, 
+ generated_text=generated_text, + target_level=target_level, + summary_subclaims=summary_subclaims, + summary_text=summary_text, + ) + + factuality_score = rewards["factuality_score"] + h_score = rewards["hallucination_score"] + + if factuality_score is None: + factuality_score = 0.5 + if h_score is None: + h_score = 0.5 + + grounding_score = _nonlinear_grounding(h_score) + + # -- Classifier -- + input_text = fulltext or "" + class_score = _compute_classifier_reward(target_level, generated_text, input_text) + + result.update({ + "factuality_score": round(factuality_score, 4), + "hallucination_score": round(h_score, 4), + "grounding_score": round(grounding_score, 4), + "classifier_score": round(class_score, 4), + "factuality_supported": rewards.get("factuality_supported", 0), + "total_summary_subclaims": rewards.get("total_summary_subclaims", 0), + "hallucination_supported": rewards.get("hallucination_supported", 0), + "total_gen_segments": rewards.get("total_gen_segments", 0), + }) + + return result + + +def compute_aggregate(results: List[Dict[str, Any]]) -> Dict[str, Any]: + """Compute aggregate statistics over all evaluated items.""" + scored = [r for r in results if not r.get("skipped", False)] + n = len(scored) + total = len(results) + skipped = total - n + + if n == 0: + return { + "total_items": total, + "scored_items": 0, + "skipped_items": skipped, + "avg_factuality_score": None, + "avg_hallucination_score": None, + "avg_grounding_score": None, + "avg_classifier_score": None, + } + + def safe_avg(key): + vals = [r[key] for r in scored if r[key] is not None] + return round(sum(vals) / len(vals), 4) if vals else None + + return { + "total_items": total, + "scored_items": n, + "skipped_items": skipped, + "avg_factuality_score": safe_avg("factuality_score"), + "avg_hallucination_score": safe_avg("hallucination_score"), + "avg_grounding_score": safe_avg("grounding_score"), + "avg_classifier_score": safe_avg("classifier_score"), + } + + +def main(): + parser = 
argparse.ArgumentParser( + description="Evaluate factuality, hallucination, and classifier scores on a JSON file." + ) + parser.add_argument( + "--input", "-i", required=True, + help="Path to input JSON file (list of objects).", + ) + parser.add_argument( + "--output", "-o", default=None, + help="Path to output JSON file with per-item scores. " + "Defaults to _eval_results.json.", + ) + parser.add_argument( + "--output-dir", default=None, + help="Directory to save output files. If set, output filename is derived " + "from input filename and placed in this directory.", + ) + parser.add_argument( + "--subclaims", "-s", default=None, + help="Path to subclaims JSON file (for BON format). Provides fulltext, " + "summary, summary_subclaims, and fulltext_subclaims keyed by doc_id.", + ) + parser.add_argument( + "--model-key", default="qwen3_base", + help="Key in the BON data containing candidates/best_summary (default: qwen3_base).", + ) + parser.add_argument( + "--target-level", "-t", default=None, + help="Override target literacy level for all items " + "(e.g. low_health_literacy). 
If not set, uses each item's 'label' field.", + ) + parser.add_argument( + "--support-check-url", default=None, + help="Override VLLM_SUPPORT_CHECK_BN_API_BASE.", + ) + parser.add_argument( + "--classifier-url", default=None, + help="Override VLLM_CLASSIFIER_BN_API_BASE.", + ) + parser.add_argument( + "--subclaim-extractor-url", default=None, + help="Override VLLM_SUBCLAIM_EXTRACTOR_BN_API_BASE.", + ) + args = parser.parse_args() + + if args.support_check_url: + os.environ["VLLM_SUPPORT_CHECK_BN_API_BASE"] = args.support_check_url + if args.classifier_url: + os.environ["VLLM_CLASSIFIER_BN_API_BASE"] = args.classifier_url + if args.subclaim_extractor_url: + os.environ["VLLM_SUBCLAIM_EXTRACTOR_BN_API_BASE"] = args.subclaim_extractor_url + + # Load input (supports both JSON array and JSONL) + with open(args.input, "r", encoding="utf-8") as f: + content = f.read().strip() + if content.startswith("["): + raw_data = json.loads(content) + else: + raw_data = [json.loads(line) for line in content.splitlines() if line.strip()] + + if not isinstance(raw_data, list): + print(f"Error: Expected a JSON list, got {type(raw_data).__name__}.", file=sys.stderr) + sys.exit(1) + + # Normalise field names from RL-inference JSONL format + for item in raw_data: + if "label" not in item and "gold_label" in item: + item["label"] = item["gold_label"] + if "fulltext" not in item and "input_text" in item: + item["fulltext"] = item["input_text"] + if "summary_subclaims" not in item and "subclaims" in item: + item["summary_subclaims"] = item["subclaims"] + + # Detect BON format: items have a model key (e.g. 
qwen3_base) with best_summary + is_bon = ( + len(raw_data) > 0 + and args.model_key in raw_data[0] + and "best_summary" in raw_data[0].get(args.model_key, {}) + ) + + # Detect inference format: top-level doc_id, label, predicted_gen_text; no fulltext/summary_subclaims + is_inference = ( + len(raw_data) > 0 + and "doc_id" in raw_data[0] + and "label" in raw_data[0] + and "predicted_gen_text" in raw_data[0] + and raw_data[0].get("fulltext") is None + and raw_data[0].get("summary_subclaims") is None + ) + + # Detect self-refine format: doc_id, label, final_summary as gen text; no fulltext/summary_subclaims + is_self_refine = ( + len(raw_data) > 0 + and "doc_id" in raw_data[0] + and "label" in raw_data[0] + and "final_summary" in raw_data[0] + and raw_data[0].get("fulltext") is None + and raw_data[0].get("summary_subclaims") is None + ) + + if is_bon: + if not args.subclaims: + print("Error: BON format detected but --subclaims file not provided.", file=sys.stderr) + sys.exit(1) + with open(args.subclaims, "r", encoding="utf-8") as f: + subclaims_data = json.load(f) + print(f"BON format detected (model_key={args.model_key})") + print(f"Loaded {len(raw_data)} BON items from {args.input}") + print(f"Loaded {len(subclaims_data)} subclaims entries from {args.subclaims}") + data = prepare_bon_items(raw_data, subclaims_data, model_key=args.model_key) + print(f"Prepared {len(data)} items for evaluation") + elif is_inference: + if not args.subclaims: + print("Error: Inference format detected (predicted_gen_text) but --subclaims file not provided.", file=sys.stderr) + sys.exit(1) + with open(args.subclaims, "r", encoding="utf-8") as f: + subclaims_data = json.load(f) + print("Inference format detected (predicted_gen_text as evaluated summary)") + print(f"Loaded {len(raw_data)} inference items from {args.input}") + print(f"Loaded {len(subclaims_data)} subclaims entries from {args.subclaims}") + data = prepare_inference_items(raw_data, subclaims_data) + print(f"Prepared 
{len(data)} items for evaluation") + elif is_self_refine: + if not args.subclaims: + print("Error: Self-refine format detected (final_summary) but --subclaims file not provided.", file=sys.stderr) + sys.exit(1) + with open(args.subclaims, "r", encoding="utf-8") as f: + subclaims_data = json.load(f) + print("Self-refine format detected (final_summary as evaluated summary)") + print(f"Loaded {len(raw_data)} self-refine items from {args.input}") + print(f"Loaded {len(subclaims_data)} subclaims entries from {args.subclaims}") + data = prepare_self_refine_items(raw_data, subclaims_data) + print(f"Prepared {len(data)} items for evaluation") + else: + data = raw_data + print(f"Loaded {len(data)} items from {args.input}") + + print(f" Support check API : {os.getenv('VLLM_SUPPORT_CHECK_BN_API_BASE', 'http://localhost:8090/v1')}") + print(f" Classifier API : {os.getenv('VLLM_CLASSIFIER_BN_API_BASE', 'http://localhost:8040/v1')}") + print(f" Subclaim extractor: {os.getenv('VLLM_SUBCLAIM_EXTRACTOR_BN_API_BASE', 'http://localhost:8050/v1')}") + if args.target_level: + print(f" Target level override: {args.target_level}") + print("-" * 60) + + # Evaluate each item + results = [] + start_time = time.time() + for idx, item in enumerate(tqdm(data, desc="Evaluating")): + r = evaluate_single(item, target_level_override=args.target_level) + r["index"] = idx + r["doc_id"] = item.get("doc_id", "") + results.append(r) + + if (idx + 1) % 10 == 0 or idx == 0: + partial_agg = compute_aggregate(results) + tqdm.write( + f" [{idx+1}/{len(data)}] " + f"fact={partial_agg['avg_factuality_score']} " + f"hallu={partial_agg['avg_hallucination_score']} " + f"cls={partial_agg['avg_classifier_score']}" + ) + + elapsed = time.time() - start_time + + # --- Validation: all items must be evaluated with non-null scores --- + expected_count = len(data) + skipped_items = [r for r in results if r.get("skipped", False)] + null_score_items = [] + for r in results: + if r.get("skipped", False): + continue + for 
key in ("factuality_score", "hallucination_score", "classifier_score", "grounding_score"): + if r.get(key) is None: + null_score_items.append((r.get("index"), r.get("doc_id"), key)) + + has_errors = False + if skipped_items: + has_errors = True + print(f"\nERROR: {len(skipped_items)} out of {expected_count} items were skipped:", file=sys.stderr) + for r in skipped_items: + print(f" index={r.get('index')} doc_id={r.get('doc_id')} reason={r.get('skip_reason')}", file=sys.stderr) + + if null_score_items: + has_errors = True + print(f"\nERROR: {len(null_score_items)} null score(s) found:", file=sys.stderr) + for idx, doc_id, key in null_score_items: + print(f" index={idx} doc_id={doc_id} null_field={key}", file=sys.stderr) + + if len(results) != expected_count: + has_errors = True + print(f"\nERROR: Expected {expected_count} results but got {len(results)}.", file=sys.stderr) + + if has_errors: + print(f"\nAborting: will NOT save results. All {expected_count} items must be fully evaluated with non-null scores.", file=sys.stderr) + sys.exit(1) + + # Aggregate + agg = compute_aggregate(results) + + # Per-label aggregates + label_groups: Dict[str, List[Dict[str, Any]]] = {} + for r in results: + lbl = r.get("target_level", "unknown") + label_groups.setdefault(lbl, []).append(r) + per_label_agg = {lbl: compute_aggregate(items) for lbl, items in sorted(label_groups.items())} + + # Output path + if args.output: + out_path = args.output + elif args.output_dir: + os.makedirs(args.output_dir, exist_ok=True) + stem = os.path.splitext(os.path.basename(args.input))[0] + out_path = os.path.join(args.output_dir, f"{stem}_eval_results.json") + else: + stem = os.path.splitext(os.path.basename(args.input))[0] + out_dir = os.path.dirname(args.input) or "." 
+ out_path = os.path.join(out_dir, f"{stem}_eval_results.json") + + os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True) + + output = { + "input_file": os.path.abspath(args.input), + "subclaims_file": os.path.abspath(args.subclaims) if args.subclaims else None, + "model_key": args.model_key if is_bon else None, + "inference_format": is_inference if not is_bon else False, + "self_refine_format": is_self_refine if not is_bon and not is_inference else False, + "target_level_override": args.target_level, + "elapsed_seconds": round(elapsed, 2), + "aggregate": agg, + "per_label_aggregate": per_label_agg, + "per_item": results, + } + + with open(out_path, "w", encoding="utf-8") as f: + json.dump(output, f, indent=2, ensure_ascii=False) + + # Print summary + print("\n" + "=" * 60) + print("EVALUATION SUMMARY") + print("=" * 60) + print(f" Total items : {agg['total_items']}") + print(f" Scored items : {agg['scored_items']}") + print(f" Skipped items : {agg['skipped_items']}") + print(f" Elapsed time : {round(elapsed, 1)}s") + print("-" * 60) + print(f" Avg Factuality Score : {agg['avg_factuality_score']}") + print(f" Avg Hallucination Score: {agg['avg_hallucination_score']}") + print(f" Avg Grounding Score : {agg['avg_grounding_score']}") + print(f" Avg Classifier Score : {agg['avg_classifier_score']}") + print("-" * 60) + for lbl, la in per_label_agg.items(): + print(f" [{lbl}] items={la['scored_items']}" + f" fact={la['avg_factuality_score']}" + f" hallu={la['avg_hallucination_score']}" + f" cls={la['avg_classifier_score']}") + print("-" * 60) + print(f" Results saved to: {out_path}") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/code/fine_tune_sft_dpo/evaluation/bn/bn_200_eval_results.json b/code/fine_tune_sft_dpo/evaluation/bn/bn_200_eval_results.json new file mode 100644 index 0000000000000000000000000000000000000000..14d40bc80db4c8c71e9492d942e193b0b17ef963 --- /dev/null +++ 
b/code/fine_tune_sft_dpo/evaluation/bn/bn_200_eval_results.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0cb20710f7c4b519b77457d541d9a132ded57b6a9252bc5552788cf358e9e436 +size 96048 diff --git a/code/fine_tune_sft_dpo/evaluation/bn/eval_gpt5_models/gpt_5_mini_eval_results_20260316_071029.json b/code/fine_tune_sft_dpo/evaluation/bn/eval_gpt5_models/gpt_5_mini_eval_results_20260316_071029.json new file mode 100644 index 0000000000000000000000000000000000000000..86088c00107b080c4b8334dee5350e1cca04a276 --- /dev/null +++ b/code/fine_tune_sft_dpo/evaluation/bn/eval_gpt5_models/gpt_5_mini_eval_results_20260316_071029.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:18c09e4d51cd941941cbc3c585699e3f7e2e98b051a5a26396bc242effa5438a +size 97818 diff --git a/code/fine_tune_sft_dpo/evaluation/bn/eval_gpt5_models/gpt_5_mini_prepared_20260316_071029.json b/code/fine_tune_sft_dpo/evaluation/bn/eval_gpt5_models/gpt_5_mini_prepared_20260316_071029.json new file mode 100644 index 0000000000000000000000000000000000000000..0f8b9b265f7b3876a1fa54be7dd79625c1101080 --- /dev/null +++ b/code/fine_tune_sft_dpo/evaluation/bn/eval_gpt5_models/gpt_5_mini_prepared_20260316_071029.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:97ac7c77d004ac04c4b06a848b28c43b9cdcd4edc8d1c97f757ec11769445aae +size 6569807 diff --git a/code/fine_tune_sft_dpo/evaluation/bn/eval_gpt5_models/gpt_5_nano_eval_results_20260316_071029.json b/code/fine_tune_sft_dpo/evaluation/bn/eval_gpt5_models/gpt_5_nano_eval_results_20260316_071029.json new file mode 100644 index 0000000000000000000000000000000000000000..63cba6528d73356916055e06055a187ec8046c33 --- /dev/null +++ b/code/fine_tune_sft_dpo/evaluation/bn/eval_gpt5_models/gpt_5_nano_eval_results_20260316_071029.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4cf921a153e1d69f2671c93cfc4a1d368cbcc3d3db0c5b30b706bd354402cb44 +size 97656 diff --git 
a/code/fine_tune_sft_dpo/evaluation/bn/eval_gpt5_models/gpt_5_nano_prepared_20260316_071029.json b/code/fine_tune_sft_dpo/evaluation/bn/eval_gpt5_models/gpt_5_nano_prepared_20260316_071029.json new file mode 100644 index 0000000000000000000000000000000000000000..c737e9b9b7cb1dd61ea3f79d3beb307090135f08 --- /dev/null +++ b/code/fine_tune_sft_dpo/evaluation/bn/eval_gpt5_models/gpt_5_nano_prepared_20260316_071029.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f6638c5e9c0feb85cf471fac04021473211fd645cc250b5e42a39483e3a6e1fd +size 6114304 diff --git a/code/fine_tune_sft_dpo/evaluation/bn/eval_gpt5_models/gpt_5_prepared_20260316_071029.json b/code/fine_tune_sft_dpo/evaluation/bn/eval_gpt5_models/gpt_5_prepared_20260316_071029.json new file mode 100644 index 0000000000000000000000000000000000000000..117eb7fff0108d708aa944d6da84bb11ebcddffa --- /dev/null +++ b/code/fine_tune_sft_dpo/evaluation/bn/eval_gpt5_models/gpt_5_prepared_20260316_071029.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:afefe278ca421f9eb698e638b2fcca1c0814525ac5a9a4bf576c0d3294290033 +size 6349574 diff --git a/code/fine_tune_sft_dpo/evaluation/bn/test_best_of_n_qwen3-4B_base_eval_results.json b/code/fine_tune_sft_dpo/evaluation/bn/test_best_of_n_qwen3-4B_base_eval_results.json new file mode 100644 index 0000000000000000000000000000000000000000..8b3c328e8fc1f09aa6b33fd85793bfe5759a54b9 --- /dev/null +++ b/code/fine_tune_sft_dpo/evaluation/bn/test_best_of_n_qwen3-4B_base_eval_results.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:85d23851eb74b15da0c466c3a11ffd43ae371db75523ee9e85a0cec6f1cf6b5e +size 97215 diff --git a/code/fine_tune_sft_dpo/evaluation/bn/test_best_of_n_qwen3-4B_sft_eval_results.json b/code/fine_tune_sft_dpo/evaluation/bn/test_best_of_n_qwen3-4B_sft_eval_results.json new file mode 100644 index 0000000000000000000000000000000000000000..c5d8b243f6fd5bfcd25427d1108644a8713fac12 --- /dev/null +++ 
b/code/fine_tune_sft_dpo/evaluation/bn/test_best_of_n_qwen3-4B_sft_eval_results.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f89d77f48683f42239d9bcdcbd469810f3f11d3d06d8e06d3442e69e34729535 +size 96710 diff --git a/code/fine_tune_sft_dpo/evaluation/bn/test_inference_vllm_qwen3-4B_base_eval_results.json b/code/fine_tune_sft_dpo/evaluation/bn/test_inference_vllm_qwen3-4B_base_eval_results.json new file mode 100644 index 0000000000000000000000000000000000000000..dc7bb8afb483baca8b3f40c72aa65b5b3bf3b2e0 --- /dev/null +++ b/code/fine_tune_sft_dpo/evaluation/bn/test_inference_vllm_qwen3-4B_base_eval_results.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2eb135210f5dc0f92170e415936ad34927979be1efdfb0bb6f629374c42b79e4 +size 97550 diff --git a/code/fine_tune_sft_dpo/evaluation/bn/test_inference_vllm_qwen3-4B_sft_eval_results.json b/code/fine_tune_sft_dpo/evaluation/bn/test_inference_vllm_qwen3-4B_sft_eval_results.json new file mode 100644 index 0000000000000000000000000000000000000000..0337c9ff1878d9400deb814cbd8ed1a3066667ba --- /dev/null +++ b/code/fine_tune_sft_dpo/evaluation/bn/test_inference_vllm_qwen3-4B_sft_eval_results.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b54519f530b0894294b4ec86b0af317f521bc9baff38e230dc46dccfb46e7d19 +size 97002 diff --git a/code/fine_tune_sft_dpo/evaluation/bn/test_self_refine_vllm_qwen3_4B_base_eval_results.json b/code/fine_tune_sft_dpo/evaluation/bn/test_self_refine_vllm_qwen3_4B_base_eval_results.json new file mode 100644 index 0000000000000000000000000000000000000000..9b7a87869b4984422d93b55e220c8dcbbc2e15a7 --- /dev/null +++ b/code/fine_tune_sft_dpo/evaluation/bn/test_self_refine_vllm_qwen3_4B_base_eval_results.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:346286565d0a6ec98ea218bd463a9322e6b1738bf5afb8235cddc2cfcb22d488 +size 97347 diff --git 
a/code/fine_tune_sft_dpo/evaluation/bn/test_self_refine_vllm_qwen3_4B_sft_eval_results.json b/code/fine_tune_sft_dpo/evaluation/bn/test_self_refine_vllm_qwen3_4B_sft_eval_results.json new file mode 100644 index 0000000000000000000000000000000000000000..07f32cc2362435c8916a045bc764405c9e79077f --- /dev/null +++ b/code/fine_tune_sft_dpo/evaluation/bn/test_self_refine_vllm_qwen3_4B_sft_eval_results.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9630e15bd5d4b3a74de3c0d97a5dc5fc40ddffb96eb16ea55b3804cf5db5a419 +size 97016 diff --git a/code/fine_tune_sft_dpo/evaluation/bn/v1/test_best_of_n_qwen3-4B_base_eval_results.json b/code/fine_tune_sft_dpo/evaluation/bn/v1/test_best_of_n_qwen3-4B_base_eval_results.json new file mode 100644 index 0000000000000000000000000000000000000000..ee0f09aa43823ffe1b8c535fe2057d24aac8f98a --- /dev/null +++ b/code/fine_tune_sft_dpo/evaluation/bn/v1/test_best_of_n_qwen3-4B_base_eval_results.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c121bcd6eae3ab09fcc37eaa3a6b49edf437b4700573f62febc5db58948d26e8 +size 97136 diff --git a/code/fine_tune_sft_dpo/evaluation/bn/v1/test_best_of_n_qwen3-4B_sft_eval_results.json b/code/fine_tune_sft_dpo/evaluation/bn/v1/test_best_of_n_qwen3-4B_sft_eval_results.json new file mode 100644 index 0000000000000000000000000000000000000000..d03a6a8c1b8ddf167d4bc6ed12c3118a86ebcef9 --- /dev/null +++ b/code/fine_tune_sft_dpo/evaluation/bn/v1/test_best_of_n_qwen3-4B_sft_eval_results.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ccf84db8e83c982492ad99641f4a01b882ff7e9a259159f7fbcb04105745776b +size 96671 diff --git a/code/fine_tune_sft_dpo/evaluation/bn/v1/test_inference_vllm_qwen3-4B_base_eval_results.json b/code/fine_tune_sft_dpo/evaluation/bn/v1/test_inference_vllm_qwen3-4B_base_eval_results.json new file mode 100644 index 0000000000000000000000000000000000000000..0f44eefd55a332e7f1f07de495fff7fbcafce517 --- /dev/null +++ 
b/code/fine_tune_sft_dpo/evaluation/bn/v1/test_inference_vllm_qwen3-4B_base_eval_results.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:de2460d80063dd684264a474babc483a510e293e42d94e9392665b85e67f351c +size 97496 diff --git a/code/fine_tune_sft_dpo/evaluation/bn/v1/test_inference_vllm_qwen3-4B_sft_eval_results.json b/code/fine_tune_sft_dpo/evaluation/bn/v1/test_inference_vllm_qwen3-4B_sft_eval_results.json new file mode 100644 index 0000000000000000000000000000000000000000..194a3c854e131c9fcac05ba0ec4ae6e39d1a59ae --- /dev/null +++ b/code/fine_tune_sft_dpo/evaluation/bn/v1/test_inference_vllm_qwen3-4B_sft_eval_results.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:32bc196d3288b6355e7b16f9dc1bd31a7dbce1c2c837a0059c29e255cca698e1 +size 97006 diff --git a/code/fine_tune_sft_dpo/evaluation/bn/v1/test_self_refine_vllm_qwen3_4B_base_eval_results.json b/code/fine_tune_sft_dpo/evaluation/bn/v1/test_self_refine_vllm_qwen3_4B_base_eval_results.json new file mode 100644 index 0000000000000000000000000000000000000000..3d32171ec180d1cee67444cf51db21966f31f008 --- /dev/null +++ b/code/fine_tune_sft_dpo/evaluation/bn/v1/test_self_refine_vllm_qwen3_4B_base_eval_results.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b163e821453592fbcdb922291a803431114d0dfa828df32366cd58340da949ef +size 97349 diff --git a/code/fine_tune_sft_dpo/evaluation/bn/v1/test_self_refine_vllm_qwen3_4B_sft_eval_results.json b/code/fine_tune_sft_dpo/evaluation/bn/v1/test_self_refine_vllm_qwen3_4B_sft_eval_results.json new file mode 100644 index 0000000000000000000000000000000000000000..ff5739514e065e86647933d3a3303ce0b2501f65 --- /dev/null +++ b/code/fine_tune_sft_dpo/evaluation/bn/v1/test_self_refine_vllm_qwen3_4B_sft_eval_results.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e822e39c927b8f75d411d42b5e556696481987b5589cd9f6453bee8e65d068d7 +size 97013 diff --git 
a/code/fine_tune_sft_dpo/prompt_bn_wo_gs/prompt b/code/fine_tune_sft_dpo/prompt_bn_wo_gs/prompt new file mode 100644 index 0000000000000000000000000000000000000000..d2b00f7dd4cf6785afabe1f4be3a7bf88acb97f5 --- /dev/null +++ b/code/fine_tune_sft_dpo/prompt_bn_wo_gs/prompt @@ -0,0 +1,58 @@ +**সিস্টেম ভূমিকা:** + +আপনি একজন বিশেষজ্ঞ চিকিৎসা সম্পাদক এবং স্বাস্থ্য-সাক্ষরতা বিশেষজ্ঞ। আপনার কাজ হলো জটিল চিকিৎসাবিষয়ক লেখাকে পাঠকের স্বাস্থ্য-সাক্ষরতার স্তর অনুযায়ী তিনটি ভিন্ন সংস্করণে রূপান্তর করা। আপনাকে ইনপুটের মূল ভাষা অবশ্যই অপরিবর্তিত রাখতে হবে, তবে ভাষার জটিলতা স্তর অনুযায়ী সমন্বয় করতে হবে। সরলীকৃত সংস্করণগুলো যেন সঠিক ও গুরুত্বপূর্ণ তথ্য বজায় রাখে, সে জন্য আপনাকে মূল তথ্য ও বার্তাকে ভিত্তি হিসেবে ব্যবহার করতে হবে। + +**ব্যবহারকারী নির্দেশনা:** + +দয়া করে নিচের চিকিৎসাবিষয়ক সোর্স টেক্সট ব্যবহার করে স্বাস্থ্য‑সাক্ষরতার তিনটি ভিন্ন স্তরের জন্য আলাদা আলাদা সংস্করণ তৈরি করুন। + +### প্রতিটি স্তরের জন্য নির্দেশনা: + +1. স্তর: কম স্বাস্থ্য‑সাক্ষরতা (উচ্চ পাঠযোগ্যতা) + +লক্ষ্য পাঠক: যারা খুব সহজ, দৈনন্দিন ভাষায় দ্রুত বোঝার মতো ব্যাখ্যা চান। + +ভাষাগত লক্ষ্য: একদম ঘরোয়া/দৈনন্দিন কথাবার্তার ভাষা ব্যবহার করুন। সব ধরনের চিকিৎসাবিষয়ক জারগনকে সহজ ব্যাখ্যামূলক ভাষায় রূপান্তর করুন (যেমন, "renal" এর পরিবর্তে "কিডনির সমস্যা" বা "কিডনি" লিখুন)। + +তথ্যের ঘনত্ব: কেবলমাত্র "যা অবশ্যই জানা দরকার" ধরনের মূল তথ্যগুলো রাখুন। + +কৌশল: বেশি মাত্রায় পুনর্লিখন ও উদাহরণ/উপমা ব্যবহার করুন। প্রতি বাক্যে একটি করে মূল ধারণা রাখুন। + +বিশ্বস্ততা: লেখাটি অবশ্যই গোল্ড সামারি‑র সঙ্গে সম্পূর্ণ সামঞ্জস্যপূর্ণ হতে হবে। + +2. 
স্তর: মাঝারি স্বাস্থ্য‑সাক্ষরতা (মাঝারি পাঠযোগ্যতা) + +লক্ষ্য পাঠক: সাধারণ মানুষ, যারা সাধারণ সংবাদ বা তথ্যভিত্তিক লেখা পড়ে বুঝতে পারেন। + +ভাষাগত লক্ষ্য: মানিকৃত/সাধারণ শব্দভাণ্ডার ব্যবহার করুন। সাধারণভাবে পরিচিত চিকিৎসাবিষয়ক শব্দ ব্যবহার করা যেতে পারে, তবে অতিরিক্ত টেকনিক্যাল "ডাক্তারি ভাষা" এড়িয়ে চলুন বা সহজভাবে ব্যাখ্যা করুন। + +তথ্যের ঘনত্ব: ভারসাম্যপূর্ণ রাখুন। মূল বার্তাকে কেন্দ্র করে কাঠামো তৈরি করুন এবং প্রয়োজন অনুযায়ী সোর্স টেক্সট থেকে প্রাসঙ্গিক অতিরিক্ত প্রেক্ষাপট যোগ করুন। + +কৌশল: মাঝারি মাত্রার পুনর্লিখন করুন। অপ্রয়োজনীয় টেকনিক্যাল খুঁটিনাটি বাদ দিন, যাতে পাঠক অতিরিক্ত তথ্যের চাপে না পড়েন। + +বিশ্বস্ততা: লেখাটি যেন মূল বার্তা ও ধারাবাহিকতা বজায় রাখে। + +3. স্তর: উচ্চ স্বাস্থ্য‑সাক্ষরতা / প্রফিসিয়েন্ট (কম পাঠযোগ্যতা, উচ্চ জটিলতা) + +লক্ষ্য পাঠক: গবেষক, ক্লিনিশিয়ান বা অত্যন্ত সচেতন/অভিজ্ঞ রোগী। + +ভাষাগত লক্ষ্য: প্রয়োজনে টেকনিক্যাল ও একাডেমিক ভাষা ব্যবহার করুন। ক্লিনিক্যাল নির্ভুলতা ও চিকিৎসাবিজ্ঞানভিত্তিক সূক্ষ্ম দিকগুলোকে অগ্রাধিকার দিন। + +তথ্যের ঘনত্ব: বেশি রাখুন। পুরো সোর্স টেক্সট ব্যবহার করে ডেটা, শারীরবৃত্তীয় প্রক্রিয়া, পরিসংখ্যান ইত্যাদি প্রাসঙ্গিক তথ্য অন্তর্ভুক্ত করুন। + +কৌশল: যতটা সম্ভব কম পুনর্লিখন করুন। মূল টেকনিক্যাল পরিভাষা ও বাক্য গঠন অধিকাংশই অক্ষুণ্ণ রাখুন। + +বিশ্বস্ততা: সোর্স টেক্সটের সাথে ঘনিষ্ঠভাবে অনুগত থাকুন; প্রয়োজন হলে বৈজ্ঞানিক প্রেক্ষাপট বাড়াতে সম্পর্কিত উপ‑দাবি বা ব্যাখ্যা যোগ করতে পারেন। + + +আমি আপনাকে নিম্নোক্ত তথ্যগুলো প্রদান করব: + +- ইনপুট ভাষা: <<>> +- সোর্স টেক্সট (বিস্তারিত মূল লেখা): <<>> + +**আউটপুট ফরম্যাট (শুধু JSON):** + {{ + "low_health_literacy": "...", + "intermediate_health_literacy": "...", + "proficient_health_literacy": "..." 
+ }} \ No newline at end of file diff --git a/code/fine_tune_sft_dpo/prompt_bn_wo_gs/prompt_intermediate b/code/fine_tune_sft_dpo/prompt_bn_wo_gs/prompt_intermediate new file mode 100644 index 0000000000000000000000000000000000000000..5a93c6fd475cbde28553260fa5203805841aa026 --- /dev/null +++ b/code/fine_tune_sft_dpo/prompt_bn_wo_gs/prompt_intermediate @@ -0,0 +1,31 @@ +**সিস্টেম ভূমিকা:** + +আপনি একজন বিশেষজ্ঞ চিকিৎসা সম্পাদক এবং স্বাস্থ্য‑সাক্ষরতা বিশেষজ্ঞ। আপনার কাজ হলো জটিল চিকিৎসাবিষয়ক লেখাকে এমন একটি সংস্করণে রূপান্তর করা, যা মাঝারি স্বাস্থ্য‑সাক্ষরতা সম্পন্ন পাঠকদের জন্য উপযোগী। আপনাকে ইনপুটের মূল ভাষা অপরিবর্তিত রেখে ভাষার জটিলতা ও তথ্যের ঘনত্বকে ভারসাম্যপূর্ণ করতে হবে। সরলীকৃত লেখাটি যেন সঠিক ও গুরুত্বপূর্ণ তথ্য বজায় রাখে, সে জন্য আপনাকে মূল তথ্য ও বার্তাকে ভিত্তি হিসেবে ব্যবহার করতে হবে। + +**ব্যবহারকারী নির্দেশনা:** + +দয়া করে নিচের চিকিৎসাবিষয়ক সোর্স টেক্সট ব্যবহার করে **মাঝারি স্বাস্থ্য‑সাক্ষরতা (মাঝারি পাঠযোগ্যতা)** স্তরের জন্য একটি সংস্করণ তৈরি করুন। + +### নির্দেশনা: + +স্তর: মাঝারি স্বাস্থ্য‑সাক্ষরতা (মাঝারি পাঠযোগ্যতা) + +লক্ষ্য পাঠক: সাধারণ জনগণ, যারা সাধারণ সংবাদ বা তথ্যভিত্তিক লেখা পড়ে বুঝতে পারেন। + +ভাষাগত লক্ষ্য: মানিকৃত ও সহজবোধ্য শব্দভাণ্ডার ব্যবহার করুন। পরিচিত চিকিৎসাবিষয়ক শব্দ ব্যবহার করা যেতে পারে, তবে অতিরিক্ত টেকনিক্যাল "ডাক্তারি ভাষা" এলে তা সহজ ব্যাখ্যায় রূপান্তর করুন। + +তথ্যের ঘনত্ব: ভারসাম্যপূর্ণ রাখুন। মূল বার্তাকে সামনে রেখে মূল কাঠামো তৈরি করুন এবং প্রয়োজন হলে সোর্স টেক্সট থেকে প্রাসঙ্গিক অতিরিক্ত তথ্য বা প্রেক্ষাপট যোগ করুন। + +কৌশল: মাঝারি মাত্রার পুনর্লিখন করুন। অতি খুঁটিনাটি টেকনিক্যাল ডিটেইল বাদ দিন, যাতে পাঠক তথ্যের চাপে না পড়ে কিন্তু মূল বিষয়টি স্পষ্টভাবে বুঝতে পারে। + +বিশ্বস্ততা: লেখাটি যেন মূল বার্তা, ক্রম এবং যুক্তি বজায় রাখে। + +আমি আপনাকে নিম্নোক্ত তথ্যগুলো প্রদান করব: + +- ইনপুট ভাষা: {source_lang} +- সোর্স টেক্সট (বিস্তারিত মূল লেখা): {full_text} + +**আউটপুট ফরম্যাট (শুধু JSON):** + {{ + "intermediate_health_literacy": "..." 
+ }} diff --git a/code/fine_tune_sft_dpo/prompt_bn_wo_gs/prompt_low b/code/fine_tune_sft_dpo/prompt_bn_wo_gs/prompt_low new file mode 100644 index 0000000000000000000000000000000000000000..d3bb7e616e11d22cd92eeb1d41d8790303db4c65 --- /dev/null +++ b/code/fine_tune_sft_dpo/prompt_bn_wo_gs/prompt_low @@ -0,0 +1,31 @@ +**সিস্টেম ভূমিকা:** + +আপনি একজন বিশেষজ্ঞ চিকিৎসা সম্পাদক এবং স্বাস্থ্য‑সাক্ষরতা বিশেষজ্ঞ। আপনার কাজ হলো জটিল চিকিৎসাবিষয়ক লেখাকে এমনভাবে রূপান্তর করা, যা কম স্বাস্থ্য‑সাক্ষরতা সম্পন্ন পাঠকদের জন্য সহজে বোঝা যায়। আপনাকে ইনপুটের মূল ভাষা অপরিবর্তিত রাখতে হবে, তবে ভাষার জটিলতা কমিয়ে আনতে হবে। সরলীকৃত লেখাটি যেন সঠিক ও প্রয়োজনীয় থাকে, সে জন্য আপনাকে মূল তথ্য ও বার্তাকে ভিত্তি হিসেবে ব্যবহার করতে হবে। + +**ব্যবহারকারী নির্দেশনা:** + +দয়া করে নিচের চিকিৎসাবিষয়ক সোর্স টেক্সট ব্যবহার করে **কম স্বাস্থ্য‑সাক্ষরতা (উচ্চ পাঠযোগ্যতা)** স্তরের জন্য একটি সংস্করণ তৈরি করুন। + +### নির্দেশনা: + +স্তর: কম স্বাস্থ্য‑সাক্ষরতা (উচ্চ পাঠযোগ্যতা) + +লক্ষ্য পাঠক: এমন ব্যক্তি, যাঁরা খুব সহজ, সরাসরি ভাষায় তথ্য পেতে চান এবং তা থেকে দ্রুত পদক্ষেপ নিতে চান। + +ভাষাগত লক্ষ্য: একদম ঘরোয়া/দৈনন্দিন কথাবার্তার ভাষা ব্যবহার করুন। সব ধরনের চিকিৎসাবিষয়ক জারগনকে সহজ বর্ণনামূলক শব্দে রূপান্তর করুন (যেমন, "renal" এর পরিবর্তে "কিডনির সমস্যা" বা "কিডনি" লিখুন)। + +তথ্যের ঘনত্ব: কেবলমাত্র "অবশ্যই জানা দরকার" ধরনের মূল তথ্যগুলো রাখুন। অপ্রয়োজনীয় ব্যাখ্যা বা অতিরিক্ত ডেটা এড়িয়ে চলুন। + +কৌশল: উচ্চ মাত্রার পুনর্লিখন করুন এবং প্রয়োজন হলে সহজ উপমা বা উদাহরণ ব্যবহার করুন। প্রতিটি বাক্যে একটি করে স্পষ্ট ধারণা রাখুন। + +বিশ্বস্ততা: লেখাটি অবশ্যই গোল্ড সামারি‑র সাথে সম্পূর্ণ সামঞ্জস্যপূর্ণ হতে হবে; নতুন তথ্য যোগ করা যাবে না। + +আমি আপনাকে নিম্নোক্ত তথ্যগুলো প্রদান করব: + +- ইনপুট ভাষা: {source_lang} +- সোর্স টেক্সট (বিস্তারিত মূল লেখা): {full_text} + +**আউটপুট ফরম্যাট (শুধু JSON):** + {{ + "low_health_literacy": "..." 
+ }} diff --git a/code/fine_tune_sft_dpo/prompt_bn_wo_gs/prompt_proficient b/code/fine_tune_sft_dpo/prompt_bn_wo_gs/prompt_proficient new file mode 100644 index 0000000000000000000000000000000000000000..3aa185264db5ec672f66a1c929e2644b640440ce --- /dev/null +++ b/code/fine_tune_sft_dpo/prompt_bn_wo_gs/prompt_proficient @@ -0,0 +1,31 @@ +**সিস্টেম ভূমিকা:** + +আপনি একজন বিশেষজ্ঞ চিকিৎসা সম্পাদক এবং স্বাস্থ্য‑সাক্ষরতা বিশেষজ্ঞ। আপনার কাজ হলো জটিল চিকিৎসাবিষয়ক লেখাকে এমন একটি সংস্করণে রূপান্তর করা, যা উচ্চ/প্রফিসিয়েন্ট স্বাস্থ্য‑সাক্ষরতা সম্পন্ন পাঠকদের জন্য উপযোগী। আপনাকে ইনপুটের মূল ভাষা বজায় রেখে টেকনিক্যাল ও একাডেমিক ভাষার যথাযথ ব্যবহার করতে হবে। আপনি মূল তথ্যকে রেফারেন্স হিসেবে ব্যবহার করবেন, তবে প্রয়োজনে সোর্স টেক্সট থেকে গভীরতর বৈজ্ঞানিক প্রেক্ষাপটও যোগ করতে পারবেন। + +**ব্যবহারকারী নির্দেশনা:** + +দয়া করে নিচের চিকিৎসাবিষয়ক সোর্স টেক্সট ব্যবহার করে **উচ্চ/প্রফিসিয়েন্ট স্বাস্থ্য‑সাক্ষরতা (কম পাঠযোগ্যতা, উচ্চ জটিলতা)** স্তরের জন্য একটি সংস্করণ তৈরি করুন। + +### নির্দেশনা: + +স্তর: প্রফিসিয়েন্ট স্বাস্থ্য‑সাক্ষরতা (কম পাঠযোগ্যতা) + +লক্ষ্য পাঠক: গবেষক, ক্লিনিশিয়ান, বা অত্যন্ত সচেতন/অভিজ্ঞ রোগী। + +ভাষাগত লক্ষ্য: প্রয়োজন অনুযায়ী টেকনিক্যাল ও একাডেমিক ভাষা ব্যবহার করুন। ক্লিনিক্যাল সূক্ষ্মতা, প্যাথোফিজিওলজি, ডায়াগনস্টিক মানদণ্ড ইত্যাদির নির্ভুল উপস্থাপনাকে অগ্রাধিকার দিন। + +তথ্যের ঘনত্ব: উচ্চ রাখুন। সোর্স টেক্সট থেকে ডেটা, পরিসংখ্যান, শারীরবৃত্তীয় প্রক্রিয়া, চিকিৎসাপদ্ধতি এবং গবেষণালব্ধ তথ্য উপযুক্তভাবে অন্তর্ভুক্ত করুন। + +কৌশল: কম মাত্রার পুনর্লিখন করুন। মূল টেকনিক্যাল পরিভাষা, গঠন এবং গুরুত্বপূর্ণ বাক্যগুলো যতটা সম্ভব অক্ষুণ্ণ রাখুন; প্রয়োজনে কেবল ব্যাকরণগত বা শৈলগত সামঞ্জস্যের জন্য পরিবর্তন করুন। + +বিশ্বস্ততা: সোর্স টেক্সটের প্রতি ঘনিষ্ঠভাবে অনুগত থাকুন; প্রয়োজন হলে বৈজ্ঞানিক প্রেক্ষাপট ও ব্যাখ্যা সম্প্রসারণ করতে সম্পর্কিত উপ‑দাবি বা তথ্য যোগ করতে পারেন, তবে ভিত্তিহীন নতুন দাবি যোগ করবেন না। + +আমি আপনাকে নিম্নোক্ত তথ্যগুলো প্রদান করব: + +- ইনপুট ভাষা: {source_lang} +- সোর্স টেক্সট (বিস্তারিত মূল লেখা): {full_text} + +**আউটপুট ফরম্যাট (শুধু 
JSON):** + {{ + "proficient_health_literacy": "..." + }} diff --git a/code/fine_tune_sft_dpo/qwen3-inference-vllm_bn.py b/code/fine_tune_sft_dpo/qwen3-inference-vllm_bn.py index 6ac53693346bef6f8dd9bcadfe088d9bed88a40c..78907358123da0bb00972a71d682ede0658fd80c 100644 --- a/code/fine_tune_sft_dpo/qwen3-inference-vllm_bn.py +++ b/code/fine_tune_sft_dpo/qwen3-inference-vllm_bn.py @@ -7,16 +7,23 @@ merged model was saved to `/home/mshahidul/readctrl/code/fine_tune_sft_dpo/model import os os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" -os.environ["CUDA_VISIBLE_DEVICES"] = "6" +os.environ["CUDA_VISIBLE_DEVICES"] = "5" import argparse import json +import re from datetime import datetime from vllm import LLM, SamplingParams from transformers import AutoTokenizer +def strip_think_blocks(text: str) -> str: + """Remove ... reasoning blocks from model output.""" + cleaned = re.sub(r".*?", "", text, flags=re.DOTALL).strip() + return cleaned if cleaned else text + + # Paths BASE_DIR = "/home/mshahidul/readctrl/code/fine_tune_sft_dpo" MODEL_DIR = os.path.join(BASE_DIR, "model", "bn") @@ -73,7 +80,7 @@ def parse_args(): p.add_argument( "--temperature", type=float, - default=0.0, + default=0.1, help="Sampling temperature for generation.", ) p.add_argument( @@ -147,7 +154,10 @@ def main(): user_prompt = build_user_message(prompts[label], fulltext, summary) chat = [{"role": "user", "content": user_prompt}] formatted = tokenizer.apply_chat_template( - chat, tokenize=False, add_generation_prompt=True + chat, + tokenize=False, + add_generation_prompt=True, + enable_thinking=False, ) batched_prompts.append(formatted) @@ -189,7 +199,7 @@ def main(): # Map generation results for this batch back to global indices for idx_in_batch, output in enumerate(outputs): original_idx = batch_indices[idx_in_batch] - text = output.outputs[0].text.strip() + text = strip_think_blocks(output.outputs[0].text) generated_texts[original_idx] = text timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") diff 
--git a/code/fine_tune_sft_dpo/qwen3_best_of_n.py b/code/fine_tune_sft_dpo/qwen3_best_of_n.py new file mode 100644 index 0000000000000000000000000000000000000000..a6cf4948ec8fda7cc1540f622aa482e3b529e21e --- /dev/null +++ b/code/fine_tune_sft_dpo/qwen3_best_of_n.py @@ -0,0 +1,238 @@ +""" +Run inference with the finetuned Bangla Qwen3 model on test_bn.json +and save the generation results under results/bn. +""" +import os +os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" +os.environ["CUDA_VISIBLE_DEVICES"] = "5" +import argparse +import json +import os +from datetime import datetime +from typing import Any, Dict, List + +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + + +# Paths (keep in sync with qwen3-finetune_bn.py) +BASE_DIR = "/home/mshahidul/readctrl/code/fine_tune_sft_dpo" +MODEL_SAVE_DIR = "/home/mshahidul/readctrl/code/fine_tune_sft_dpo/model/bn" +PROMPT_DIR = os.path.join(BASE_DIR, "prompt_bn") +TEST_JSON = "/home/mshahidul/readctrl/code/fine_tune_sft_dpo/dataset/bn/test_bn.json" +RESULTS_DIR = os.path.join(BASE_DIR, "results", "bn") +SOURCE_LANG = "Bangla" + +LABEL_TO_PROMPT_FILE = { + "low_health_literacy": "prompt_low", + "intermediate_health_literacy": "prompt_intermediate", + "proficient_health_literacy": "prompt_proficient", +} + + +def load_prompts() -> Dict[str, str]: + """Load prompt templates from prompt_bn directory.""" + prompts: Dict[str, str] = {} + for label, filename in LABEL_TO_PROMPT_FILE.items(): + path = os.path.join(PROMPT_DIR, filename) + if os.path.isfile(path): + with open(path, "r", encoding="utf-8") as f: + prompts[label] = f.read() + else: + raise FileNotFoundError(f"Prompt file not found: {path}") + return prompts + + +def build_user_message( + prompt_template: str, + full_text: str, + gold_summary: str, + source_lang: str = SOURCE_LANG, +) -> str: + """Fill prompt template with full_text, gold_summary, source_lang.""" + return ( + prompt_template.replace("{full_text}", full_text) + 
.replace("{gold_summary}", gold_summary) + .replace("{source_lang}", source_lang) + ) + + +def parse_args() -> argparse.Namespace: + p = argparse.ArgumentParser( + description="Run inference with finetuned Qwen3-4B Bangla model on test_bn.json." + ) + p.add_argument( + "--max-new-tokens", + type=int, + default=512, + help="Maximum number of new tokens to generate per sample.", + ) + p.add_argument( + "--temperature", + type=float, + default=0.7, + help="Sampling temperature.", + ) + p.add_argument( + "--top-p", + type=float, + default=0.9, + help="Top-p (nucleus) sampling value.", + ) + p.add_argument( + "--output-json", + type=str, + default="test_bn_qwen3-4B_sft_inference.json", + help=( + "Output JSON filename (saved under results/bn). " + "If it already exists, it will be overwritten." + ), + ) + return p.parse_args() + + +def load_model_and_tokenizer(model_dir: str): + """Load the merged finetuned model and tokenizer for inference.""" + if not os.path.isdir(model_dir): + raise FileNotFoundError( + f"Finetuned model directory not found: {model_dir}. " + "Make sure qwen3-finetune_bn.py was run with model saving enabled." 
+ ) + + print(f"Loading tokenizer from {model_dir}") + tokenizer = AutoTokenizer.from_pretrained(model_dir) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + print(f"Loading model from {model_dir}") + if torch.cuda.is_available(): + model = AutoModelForCausalLM.from_pretrained( + model_dir, + torch_dtype=torch.bfloat16, + device_map="auto", + ) + else: + model = AutoModelForCausalLM.from_pretrained(model_dir) + + model.eval() + return model, tokenizer + + +def run_inference( + model, + tokenizer, + test_items: List[Dict[str, Any]], + prompts: Dict[str, str], + max_new_tokens: int, + temperature: float, + top_p: float, +) -> List[Dict[str, Any]]: + """Generate adapted texts for each test item.""" + results: List[Dict[str, Any]] = [] + + device = next(model.parameters()).device + + for idx, item in enumerate(test_items): + label = item.get("label") + fulltext = item.get("fulltext", "") + summary = item.get("summary", "") + + if not fulltext or label not in prompts: + # Keep the original item, but note that generation was skipped. 
+ out_item = dict(item) + out_item["model_gen_text"] = "" + out_item["model_gen_skipped"] = True + results.append(out_item) + continue + + user_msg = build_user_message(prompts[label], fulltext, summary) + messages = [{"role": "user", "content": user_msg}] + + text = tokenizer.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=False, + ) + inputs = tokenizer(text, return_tensors="pt").to(device) + input_len = inputs["input_ids"].shape[-1] + + with torch.no_grad(): + generated_ids = model.generate( + **inputs, + max_new_tokens=max_new_tokens, + do_sample=True, + temperature=temperature, + top_p=top_p, + pad_token_id=tokenizer.eos_token_id, + ) + + gen_ids = generated_ids[0, input_len:] + gen_text = tokenizer.decode(gen_ids, skip_special_tokens=True).strip() + + out_item = dict(item) + out_item["model_gen_text"] = gen_text + out_item["model_name"] = "qwen3-4B_sft_bn" + out_item["model_max_new_tokens"] = max_new_tokens + out_item["model_temperature"] = temperature + out_item["model_top_p"] = top_p + results.append(out_item) + + if (idx + 1) % 10 == 0: + print(f"Processed {idx + 1} / {len(test_items)} samples") + + return results + + +def main(): + args = parse_args() + + os.makedirs(RESULTS_DIR, exist_ok=True) + + print("Loading prompts from", PROMPT_DIR) + prompts = load_prompts() + + print("Loading test data from", TEST_JSON) + with open(TEST_JSON, "r", encoding="utf-8") as f: + test_items = json.load(f) + + print(f"Test samples: {len(test_items)}") + + model, tokenizer = load_model_and_tokenizer(MODEL_SAVE_DIR) + + print( + f"Running inference with max_new_tokens={args.max_new_tokens}, " + f"temperature={args.temperature}, top_p={args.top_p}" + ) + results = run_inference( + model=model, + tokenizer=tokenizer, + test_items=test_items, + prompts=prompts, + max_new_tokens=args.max_new_tokens, + temperature=args.temperature, + top_p=args.top_p, + ) + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + output_filename = args.output_json + if 
not output_filename.endswith(".json"): + output_filename += ".json" + + output_path = os.path.join(RESULTS_DIR, output_filename) + + # If the filename already exists, append a timestamp to avoid silent overwrite. + if os.path.exists(output_path): + name, ext = os.path.splitext(output_filename) + output_filename = f"{name}_{timestamp}{ext}" + output_path = os.path.join(RESULTS_DIR, output_filename) + + print("Saving results to", output_path) + with open(output_path, "w", encoding="utf-8") as f: + json.dump(results, f, ensure_ascii=False, indent=2) + + print("Done.") + + +if __name__ == "__main__": + main() + diff --git a/code/fine_tune_sft_dpo/qwen3_infer_bn.py b/code/fine_tune_sft_dpo/qwen3_infer_bn.py new file mode 100644 index 0000000000000000000000000000000000000000..736d4363c5af6e082b5f39728c950f945fb26fba --- /dev/null +++ b/code/fine_tune_sft_dpo/qwen3_infer_bn.py @@ -0,0 +1,247 @@ +""" +Run inference with the finetuned Bangla Qwen3 model on test_bn.json +and save the generation results under results/bn. +""" +import os +os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" +os.environ["CUDA_VISIBLE_DEVICES"] = "5" +import argparse +import json +import os +import re +from datetime import datetime +from typing import Any, Dict, List + +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + + +def strip_think_blocks(text: str) -> str: + """Remove ... 
reasoning blocks from model output.""" + cleaned = re.sub(r".*?", "", text, flags=re.DOTALL).strip() + return cleaned if cleaned else text + + +# Paths (keep in sync with qwen3-finetune_bn.py) +BASE_DIR = "/home/mshahidul/readctrl/code/fine_tune_sft_dpo" +MODEL_SAVE_DIR = "/home/mshahidul/readctrl/code/fine_tune_sft_dpo/model/bn" +PROMPT_DIR = os.path.join(BASE_DIR, "prompt_bn") +TEST_JSON = "/home/mshahidul/readctrl/code/fine_tune_sft_dpo/dataset/bn/test_bn.json" +RESULTS_DIR = os.path.join(BASE_DIR, "results", "bn") +SOURCE_LANG = "Bangla" + +LABEL_TO_PROMPT_FILE = { + "low_health_literacy": "prompt_low", + "intermediate_health_literacy": "prompt_intermediate", + "proficient_health_literacy": "prompt_proficient", +} + + +def load_prompts() -> Dict[str, str]: + """Load prompt templates from prompt_bn directory.""" + prompts: Dict[str, str] = {} + for label, filename in LABEL_TO_PROMPT_FILE.items(): + path = os.path.join(PROMPT_DIR, filename) + if os.path.isfile(path): + with open(path, "r", encoding="utf-8") as f: + prompts[label] = f.read() + else: + raise FileNotFoundError(f"Prompt file not found: {path}") + return prompts + + +def build_user_message( + prompt_template: str, + full_text: str, + gold_summary: str, + source_lang: str = SOURCE_LANG, +) -> str: + """Fill prompt template with full_text, gold_summary, source_lang.""" + return ( + prompt_template.replace("{full_text}", full_text) + .replace("{gold_summary}", gold_summary) + .replace("{source_lang}", source_lang) + ) + + +def parse_args() -> argparse.Namespace: + p = argparse.ArgumentParser( + description="Run inference with finetuned Qwen3-4B Bangla model on test_bn.json." 
+ ) + p.add_argument( + "--max-new-tokens", + type=int, + default=512, + help="Maximum number of new tokens to generate per sample.", + ) + p.add_argument( + "--temperature", + type=float, + default=0.7, + help="Sampling temperature.", + ) + p.add_argument( + "--top-p", + type=float, + default=0.9, + help="Top-p (nucleus) sampling value.", + ) + p.add_argument( + "--output-json", + type=str, + default="test_bn_qwen3-4B_sft_inference.json", + help=( + "Output JSON filename (saved under results/bn). " + "If it already exists, it will be overwritten." + ), + ) + return p.parse_args() + + +def load_model_and_tokenizer(model_dir: str): + """Load the merged finetuned model and tokenizer for inference.""" + if not os.path.isdir(model_dir): + raise FileNotFoundError( + f"Finetuned model directory not found: {model_dir}. " + "Make sure qwen3-finetune_bn.py was run with model saving enabled." + ) + + print(f"Loading tokenizer from {model_dir}") + tokenizer = AutoTokenizer.from_pretrained(model_dir) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + print(f"Loading model from {model_dir}") + if torch.cuda.is_available(): + model = AutoModelForCausalLM.from_pretrained( + model_dir, + torch_dtype=torch.bfloat16, + device_map="auto", + ) + else: + model = AutoModelForCausalLM.from_pretrained(model_dir) + + model.eval() + return model, tokenizer + + +def run_inference( + model, + tokenizer, + test_items: List[Dict[str, Any]], + prompts: Dict[str, str], + max_new_tokens: int, + temperature: float, + top_p: float, +) -> List[Dict[str, Any]]: + """Generate adapted texts for each test item.""" + results: List[Dict[str, Any]] = [] + + device = next(model.parameters()).device + + for idx, item in enumerate(test_items): + label = item.get("label") + fulltext = item.get("fulltext", "") + summary = item.get("summary", "") + + if not fulltext or label not in prompts: + # Keep the original item, but note that generation was skipped. 
+ out_item = dict(item) + out_item["model_gen_text"] = "" + out_item["model_gen_skipped"] = True + results.append(out_item) + continue + + user_msg = build_user_message(prompts[label], fulltext, summary) + messages = [{"role": "user", "content": user_msg}] + + text = tokenizer.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=False, + enable_thinking=False, + ) + inputs = tokenizer(text, return_tensors="pt").to(device) + input_len = inputs["input_ids"].shape[-1] + + with torch.no_grad(): + generated_ids = model.generate( + **inputs, + max_new_tokens=max_new_tokens, + do_sample=True, + temperature=temperature, + top_p=top_p, + pad_token_id=tokenizer.eos_token_id, + ) + + gen_ids = generated_ids[0, input_len:] + gen_text = tokenizer.decode(gen_ids, skip_special_tokens=True).strip() + gen_text = strip_think_blocks(gen_text) + + out_item = dict(item) + out_item["model_gen_text"] = gen_text + out_item["model_name"] = "qwen3-4B_sft_bn" + out_item["model_max_new_tokens"] = max_new_tokens + out_item["model_temperature"] = temperature + out_item["model_top_p"] = top_p + results.append(out_item) + + if (idx + 1) % 10 == 0: + print(f"Processed {idx + 1} / {len(test_items)} samples") + + return results + + +def main(): + args = parse_args() + + os.makedirs(RESULTS_DIR, exist_ok=True) + + print("Loading prompts from", PROMPT_DIR) + prompts = load_prompts() + + print("Loading test data from", TEST_JSON) + with open(TEST_JSON, "r", encoding="utf-8") as f: + test_items = json.load(f) + + print(f"Test samples: {len(test_items)}") + + model, tokenizer = load_model_and_tokenizer(MODEL_SAVE_DIR) + + print( + f"Running inference with max_new_tokens={args.max_new_tokens}, " + f"temperature={args.temperature}, top_p={args.top_p}" + ) + results = run_inference( + model=model, + tokenizer=tokenizer, + test_items=test_items, + prompts=prompts, + max_new_tokens=args.max_new_tokens, + temperature=args.temperature, + top_p=args.top_p, + ) + + timestamp = 
datetime.now().strftime("%Y%m%d_%H%M%S") + output_filename = args.output_json + if not output_filename.endswith(".json"): + output_filename += ".json" + + output_path = os.path.join(RESULTS_DIR, output_filename) + + # If the filename already exists, append a timestamp to avoid silent overwrite. + if os.path.exists(output_path): + name, ext = os.path.splitext(output_filename) + output_filename = f"{name}_{timestamp}{ext}" + output_path = os.path.join(RESULTS_DIR, output_filename) + + print("Saving results to", output_path) + with open(output_path, "w", encoding="utf-8") as f: + json.dump(results, f, ensure_ascii=False, indent=2) + + print("Done.") + + +if __name__ == "__main__": + main() + diff --git a/code/fine_tune_sft_dpo/results/bn/gpt5_inference_all_20260314_110445.jsonl b/code/fine_tune_sft_dpo/results/bn/gpt5_inference_all_20260314_110445.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..076984d51538bf32035d28e35326f415009ad037 --- /dev/null +++ b/code/fine_tune_sft_dpo/results/bn/gpt5_inference_all_20260314_110445.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:72cc6e0f72af25f7b56ac4c623d23be4d57d35e09473e4a012eff27680942e85 +size 16659336 diff --git a/code/fine_tune_sft_dpo/results/bn/gpt5_inference_all_wo_gs_20260314_173736.jsonl b/code/fine_tune_sft_dpo/results/bn/gpt5_inference_all_wo_gs_20260314_173736.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..f58d97e4d7a7dea36c6556bc73d0286891d8959a --- /dev/null +++ b/code/fine_tune_sft_dpo/results/bn/gpt5_inference_all_wo_gs_20260314_173736.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:195bb9580d98d1a43f642f06f2c982397336c42cb0cf68bc641138997b699f9e +size 17385349 diff --git a/code/fine_tune_sft_dpo/results/bn/gpt5_inference_gpt-5-mini_20260314_110445.jsonl b/code/fine_tune_sft_dpo/results/bn/gpt5_inference_gpt-5-mini_20260314_110445.jsonl new file mode 100644 index 
0000000000000000000000000000000000000000..0a8d1b2493fd826f0fd503aff1c37e96d13f2077 --- /dev/null +++ b/code/fine_tune_sft_dpo/results/bn/gpt5_inference_gpt-5-mini_20260314_110445.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:20bf9c12039c4bc33421452806e48ddc1ef7dca13403de06b95642cdf27e4334 +size 5984175 diff --git a/code/fine_tune_sft_dpo/results/bn/gpt5_inference_gpt-5-mini_20260314_173736.jsonl b/code/fine_tune_sft_dpo/results/bn/gpt5_inference_gpt-5-mini_20260314_173736.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..125eeb88a4d31768c3ae0ef00d2dc1ba72cee13c --- /dev/null +++ b/code/fine_tune_sft_dpo/results/bn/gpt5_inference_gpt-5-mini_20260314_173736.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:49eef179000c958ac8986ee7a3fe0a79d44714b35e9a2a09ca68402eda9cd37f +size 6234645 diff --git a/code/fine_tune_sft_dpo/results/bn/gpt5_inference_gpt-5-nano_20260314_110445.jsonl b/code/fine_tune_sft_dpo/results/bn/gpt5_inference_gpt-5-nano_20260314_110445.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..37884f227fa35cc94ca8b3cc7b2690d77c80a1fb --- /dev/null +++ b/code/fine_tune_sft_dpo/results/bn/gpt5_inference_gpt-5-nano_20260314_110445.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5a2a2acef3f83a916faee240a310e9595f6b6cb043c7e1081f014fce80c5d836 +size 5237182 diff --git a/code/fine_tune_sft_dpo/results/bn/gpt5_inference_gpt-5-nano_20260314_173736.jsonl b/code/fine_tune_sft_dpo/results/bn/gpt5_inference_gpt-5-nano_20260314_173736.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..0ac0efb393b26027def0468ba402b05119b6c59a --- /dev/null +++ b/code/fine_tune_sft_dpo/results/bn/gpt5_inference_gpt-5-nano_20260314_173736.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b5439fb64537f7f159f61b97f46f4f400238fee1c65fd08f7b514ffc46f9160a +size 5355465 diff --git 
a/code/fine_tune_sft_dpo/results/bn/gpt5_inference_gpt-5_20260314_110445.jsonl b/code/fine_tune_sft_dpo/results/bn/gpt5_inference_gpt-5_20260314_110445.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..e4ddf0d7dfb84f09a5c6cfbc37cd380b2b13cc21 --- /dev/null +++ b/code/fine_tune_sft_dpo/results/bn/gpt5_inference_gpt-5_20260314_110445.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fbd39a27ea21e26c4d1ba7d4ef85c41964713b655ff03574c5d260ec7b94a5e4 +size 5437979 diff --git a/code/fine_tune_sft_dpo/results/bn/gpt5_inference_gpt-5_20260314_173736.jsonl b/code/fine_tune_sft_dpo/results/bn/gpt5_inference_gpt-5_20260314_173736.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..b342db6aa399ccc57cc1faaf7d5a0eb8b25076a0 --- /dev/null +++ b/code/fine_tune_sft_dpo/results/bn/gpt5_inference_gpt-5_20260314_173736.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3f16ff37f4c633d22a245b087dfb14583f14c3e27def8d896811e31a61af45e5 +size 5795239 diff --git a/code/fine_tune_sft_dpo/results/bn/gpt5_inference_summary_20260314_110445.json b/code/fine_tune_sft_dpo/results/bn/gpt5_inference_summary_20260314_110445.json new file mode 100644 index 0000000000000000000000000000000000000000..891268559d71900c96306b092a52c291fad89b0d --- /dev/null +++ b/code/fine_tune_sft_dpo/results/bn/gpt5_inference_summary_20260314_110445.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3a0b61606edceb252580fdb18a15ac237d5d69bd896bea1d430cad5a6113a152 +size 1684 diff --git a/code/fine_tune_sft_dpo/results/bn/gpt5_inference_summary_wo_gs_20260314_173736.json b/code/fine_tune_sft_dpo/results/bn/gpt5_inference_summary_wo_gs_20260314_173736.json new file mode 100644 index 0000000000000000000000000000000000000000..c994a51a4128b4e82f55e6f8670571a5fec66704 --- /dev/null +++ b/code/fine_tune_sft_dpo/results/bn/gpt5_inference_summary_wo_gs_20260314_173736.json @@ -0,0 +1,3 @@ +version 
https://git-lfs.github.com/spec/v1 +oid sha256:eb177630eb8e44f478897e4e96af2ead55760791e59605ce091ff49c8bb7f403 +size 1690 diff --git a/code/fine_tune_sft_dpo/results/bn/inference_summary_vllm_20260314_101627.json b/code/fine_tune_sft_dpo/results/bn/inference_summary_vllm_20260314_101627.json new file mode 100644 index 0000000000000000000000000000000000000000..9b5f3e0a2cf7bb5075456c6b011132749f1633c5 --- /dev/null +++ b/code/fine_tune_sft_dpo/results/bn/inference_summary_vllm_20260314_101627.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b5d5cf9a06af93f982c38c5901fbd8478e4bce6ed6bdb3b5480ad1ff586187c2 +size 471 diff --git a/code/fine_tune_sft_dpo/results/bn/inference_summary_vllm_20260311_044629.json b/code/fine_tune_sft_dpo/results/bn/misc/inference_summary_vllm_20260311_044629.json similarity index 100% rename from code/fine_tune_sft_dpo/results/bn/inference_summary_vllm_20260311_044629.json rename to code/fine_tune_sft_dpo/results/bn/misc/inference_summary_vllm_20260311_044629.json diff --git a/code/fine_tune_sft_dpo/results/bn/inference_summary_vllm_20260311_045131.json b/code/fine_tune_sft_dpo/results/bn/misc/inference_summary_vllm_20260311_045131.json similarity index 100% rename from code/fine_tune_sft_dpo/results/bn/inference_summary_vllm_20260311_045131.json rename to code/fine_tune_sft_dpo/results/bn/misc/inference_summary_vllm_20260311_045131.json diff --git a/code/fine_tune_sft_dpo/results/bn/misc/test_best_of_n_qwen3-4B_sft.json b/code/fine_tune_sft_dpo/results/bn/misc/test_best_of_n_qwen3-4B_sft.json new file mode 100644 index 0000000000000000000000000000000000000000..8c163eeb8b32f0c399f2821b94bfab4b9eface92 --- /dev/null +++ b/code/fine_tune_sft_dpo/results/bn/misc/test_best_of_n_qwen3-4B_sft.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:57ceb2553a90adc8336508482293abc2c8c12d23dc1ff2d1488d86a02d943f12 +size 2852232 diff --git 
a/code/fine_tune_sft_dpo/results/bn/misc/test_best_of_n_qwen3-4B_sft_summary.json b/code/fine_tune_sft_dpo/results/bn/misc/test_best_of_n_qwen3-4B_sft_summary.json new file mode 100644 index 0000000000000000000000000000000000000000..15a2bf653d31480159628da676a1707152c82299 --- /dev/null +++ b/code/fine_tune_sft_dpo/results/bn/misc/test_best_of_n_qwen3-4B_sft_summary.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dde19c0eaf8223164be1b279444fe334917136e86b37ff09cb75311d825e1160 +size 503 diff --git a/code/fine_tune_sft_dpo/results/bn/test_inference_vllm_qwen3-4B_bn_sft.json b/code/fine_tune_sft_dpo/results/bn/misc/test_inference_vllm_qwen3-4B_bn_sft.json similarity index 100% rename from code/fine_tune_sft_dpo/results/bn/test_inference_vllm_qwen3-4B_bn_sft.json rename to code/fine_tune_sft_dpo/results/bn/misc/test_inference_vllm_qwen3-4B_bn_sft.json diff --git a/code/fine_tune_sft_dpo/results/bn/test_best_of_n_qwen3-4B_sft.json b/code/fine_tune_sft_dpo/results/bn/test_best_of_n_qwen3-4B_sft.json index 8c163eeb8b32f0c399f2821b94bfab4b9eface92..404a345c249d96ab0dc92bf64a9f1b47cf444222 100644 --- a/code/fine_tune_sft_dpo/results/bn/test_best_of_n_qwen3-4B_sft.json +++ b/code/fine_tune_sft_dpo/results/bn/test_best_of_n_qwen3-4B_sft.json @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:57ceb2553a90adc8336508482293abc2c8c12d23dc1ff2d1488d86a02d943f12 -size 2852232 +oid sha256:3b228311ab314c4af1c697bc14e9e33b9d0e73dd6a5fcd9feb7f1b676fd2f8d7 +size 2450937 diff --git a/code/fine_tune_sft_dpo/results/bn/test_best_of_n_qwen3-4B_sft_summary.json b/code/fine_tune_sft_dpo/results/bn/test_best_of_n_qwen3-4B_sft_summary.json index 15a2bf653d31480159628da676a1707152c82299..656a1dd4b3f84e2be78a4df163d543b8d497f49e 100644 --- a/code/fine_tune_sft_dpo/results/bn/test_best_of_n_qwen3-4B_sft_summary.json +++ b/code/fine_tune_sft_dpo/results/bn/test_best_of_n_qwen3-4B_sft_summary.json @@ -1,3 +1,3 @@ version 
https://git-lfs.github.com/spec/v1 -oid sha256:dde19c0eaf8223164be1b279444fe334917136e86b37ff09cb75311d825e1160 +oid sha256:e137a19d0286ae5044045f8b813d55ed9956f4704e26eaa6293009e384b6bdb1 size 503 diff --git a/code/fine_tune_sft_dpo/results/bn/test_inference_vllm_qwen3-4B_sft.json b/code/fine_tune_sft_dpo/results/bn/test_inference_vllm_qwen3-4B_sft.json new file mode 100644 index 0000000000000000000000000000000000000000..011973b4b586077eafa2606c539c12cc494d7277 --- /dev/null +++ b/code/fine_tune_sft_dpo/results/bn/test_inference_vllm_qwen3-4B_sft.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:10031f2e3b0b0c88bba7f77d98fac60a982f393fd1d5a66b5f922e73f44e4359 +size 1344633 diff --git a/code/fine_tune_sft_dpo/results/bn/test_self_refine_vllm_qwen3_4B_sft.json b/code/fine_tune_sft_dpo/results/bn/test_self_refine_vllm_qwen3_4B_sft.json index ac64cd6f3f71576a59a43b40ca1ca6d1ffb16466..383aa7815da4771908a100e62f3ca24931d5a116 100644 --- a/code/fine_tune_sft_dpo/results/bn/test_self_refine_vllm_qwen3_4B_sft.json +++ b/code/fine_tune_sft_dpo/results/bn/test_self_refine_vllm_qwen3_4B_sft.json @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:15e86b8167744f3948bdf65a5c7cbf6e24558d6d775864a6e9b8f41404b13513 -size 9547804 +oid sha256:9019fff90b41cad73548460ee391cbdd3401e7f3069f29e469d352cb6da1ca91 +size 9548534 diff --git a/code/fine_tune_sft_dpo/reward_new_v6_bn_v4_rmv_src_cov.py b/code/fine_tune_sft_dpo/reward_new_v6_bn_v4_rmv_src_cov.py new file mode 100644 index 0000000000000000000000000000000000000000..089a8bf00932411a21056236f7522b0300cc9a2f --- /dev/null +++ b/code/fine_tune_sft_dpo/reward_new_v6_bn_v4_rmv_src_cov.py @@ -0,0 +1,835 @@ +import ast +import os +import re +import json +import argparse +from typing import Any, List, Dict, Optional +import warnings +import requests +warnings.filterwarnings("ignore") + +# vLLM endpoints (see script/vllm.sh: support-check port 8090, classifier port 8040) +# Use localhost when 
reward runs on same machine; set env vars for remote (e.g. http://HOST:8090/v1). +VLLM_SUPPORT_CHECK_BN_API_BASE = os.getenv( + "VLLM_SUPPORT_CHECK_BN_API_BASE", + "http://localhost:8090/v1", +) +# Both support-check and classifier use Bangla prompts. +SUPPORT_CHECK_PROMPT_LANGUAGE = "bn" + +VLLM_CLASSIFIER_BN_API_BASE = os.getenv( + "VLLM_CLASSIFIER_BN_API_BASE", + "http://localhost:8040/v1", +) + +# Subclaim-extractor vLLM endpoint (Bangla medical text → subclaim list) +VLLM_SUBCLAIM_EXTRACTOR_BN_API_BASE = os.getenv( + "VLLM_SUBCLAIM_EXTRACTOR_BN_API_BASE", + "http://localhost:8050/v1", +) + + +# --------------------------------------------------------------------------- +# Support check via vLLM (Gemma-3 fine-tuned; same prompt as support_check gemma3-finetune) +# --------------------------------------------------------------------------- + +def _build_support_list_user_prompt(context: str, subclaims: List[str]) -> str: + """Build list-of-subclaims support prompt in Bangla (matches support_check model_finetune/gemma3-finetune.py).""" + numbered = "\n".join(f"{idx + 1}. 
{sc}" for idx, sc in enumerate(subclaims)) + return ( + "আপনাকে একটি মেডিকেল কেস বর্ণনা এবং একাধিক সাবক্লেইমের তালিকা দেওয়া হবে। " + "প্রতিটি সাবক্লেইম টেক্সট দ্বারা সমর্থিত কি না তা নির্ধারণ করুন।\n\n" + f"টেক্সট:\n{context}\n\n" + f"সাবক্লেইমগুলোর তালিকা:\n{numbered}\n\n" + "প্রতিটি সাবক্লেইমের জন্য ক্রমানুসারে লেবেল দিন। " + "নির্দিষ্টভাবে একটি JSON array আকারে উত্তর দিন, যেমন:\n" + '["supported", "not_supported", ...]\n' + "অন্য কিছু লিখবেন না।" + ) + + +def _parse_support_label_array(raw_text: str) -> List[str]: + """Parse model output to list of 'supported' | 'not_supported' (matches support_check gemma3-finetune).""" + text = (raw_text or "").strip() + if not text: + return [] + if "```" in text: + text = text.replace("```json", "").replace("```", "").strip() + start = text.find("[") + end = text.rfind("]") + if start != -1 and end != -1 and end > start: + text = text[start : end + 1] + parsed = None + for parser in (json.loads, ast.literal_eval): + try: + parsed = parser(text) + break + except Exception: + continue + if not isinstance(parsed, list): + return [] + # import ipdb; ipdb.set_trace() + normalized = [] + for item in parsed: + if not isinstance(item, str): + normalized.append("invalid") + continue + + label = item.strip().lower().replace("-", "_").replace(" ", "_") + # Strict keyword check: if not one of these two, it is invalid + if label in {"supported", "not_supported"}: + normalized.append(label) + else: + normalized.append("invalid") + return normalized + + +def _format_gemma3_for_support(user_message: str) -> str: + """Format user message for Gemma-3 chat (vLLM).""" + return ( + f"user\n{user_message}\n" + "model\n" + ) + + +def _call_vllm_support_check(prompt: str, max_tokens: int = 512, timeout: float = 120.0) -> Optional[str]: + """Call vLLM completions API for support-check model. 
Returns generated text or None.""" + base = VLLM_SUPPORT_CHECK_BN_API_BASE.rstrip("/") + url = f"{base}/completions" + payload = { + "prompt": prompt, + "max_tokens": max_tokens, + "temperature": 0.0, + "stop": ["", "", "\n\n"], + } + try: + resp = requests.post(url, json=payload, timeout=timeout) + resp.raise_for_status() + data = resp.json() + choices = data.get("choices") + # import ipdb; ipdb.set_trace() + if choices and len(choices) > 0: + text = choices[0].get("text", "") + return (text or "").strip() + return None + except requests.exceptions.RequestException: + return None + + +def _call_support_api( + context: str, + subclaims: List[str], + threshold: float = 0.5, + batch_size: int = 128, +) -> Optional[List[str]]: + """ + Call the BN support-check model via vLLM (Gemma-3 fine-tuned, subclaim_list mode). + Same prompt as support_check/model_finetune/gemma3-finetune.py. + + Returns + ------- + List[str] : one label per subclaim — "supported" | "not_supported" | "invalid". + None : on total failure (network or empty/unparseable response). + """ + if not context or not subclaims: + return ["invalid"] * len(subclaims) + + user_prompt = _build_support_list_user_prompt(context, subclaims) + prompt = _format_gemma3_for_support(user_prompt) + raw = _call_vllm_support_check(prompt) + # import ipdb; ipdb.set_trace() + if raw is None: + print("Warning: Support check vLLM call failed (returning None).") + return None + + labels = _parse_support_label_array(raw) + n = len(subclaims) + if len(labels) < n: + labels = labels + ["invalid"] * (n - len(labels)) + elif len(labels) > n: + labels = labels[:n] + # import ipdb; ipdb.set_trace() + return labels + + +# --------------------------------------------------------------------------- +# Subclaim extractor (Bangla, vLLM) + sentence splitter fallback +# --------------------------------------------------------------------------- + +# Minimum character length for a sentence to be considered a real unit. 
+# Used only as a fallback when subclaim extraction is unavailable. +MIN_SENTENCE_CHARS = 15 + + +def _is_bangla_text(text: str, min_bangla_ratio: float = 0.6) -> bool: + """ + Heuristic check: returns True if the majority of alphabetic characters + in `text` are Bangla (Unicode block \u0980–\u09FF). + """ + if not text: + return False + bangla_chars = 0 + alpha_chars = 0 + for ch in text: + if ch.isalpha(): + alpha_chars += 1 + if "\u0980" <= ch <= "\u09FF": + bangla_chars += 1 + if alpha_chars == 0: + return False + return (bangla_chars / alpha_chars) >= min_bangla_ratio + + +def _split_into_sentences(text: str, min_chars: int = MIN_SENTENCE_CHARS) -> List[str]: + """ + Split text into sentences at [.!?] boundaries. + Segments shorter than `min_chars` characters are dropped to + prevent micro-fragment padding from gaming ratio-based scores. + """ + if not text or not text.strip(): + return [] + parts = re.split(r"(?<=[.!?\u0964])\s+", text.strip()) + return [s.strip() for s in parts if len(s.strip()) >= min_chars] + + +def _build_subclaim_extraction_prompt(medical_text: str) -> str: + """ + Bangla subclaim-extraction prompt (same wording as `extract_bn_subclaims_vllm.py`, + generalized to "medical text" so it works for any generated explanation). + """ + return f""" +You are an expert medical annotator. The following text is in Bangla (Bengali). + +Your task is to extract granular, factual subclaims from the provided medical text. +A subclaim is the smallest standalone factual unit that can be independently verified. + +Instructions: +1. Read the Bangla medical text carefully. +2. Extract factual statements explicitly stated in the text. +3. Each subclaim must: + - Be in Bangla (same language as the input) + - Contain exactly ONE factual assertion + - Come directly from the text (no inference or interpretation) + - Preserve original wording as much as possible + - Include any negation, uncertainty, or qualifier +4. 
Do NOT: + - Combine multiple facts into one subclaim + - Add new information + - Translate to another language +5. Return ONLY a valid JSON array of strings. +6. Use double quotes and valid JSON formatting only (no markdown, no commentary). + +Medical Text (Bangla): +{medical_text} + +Return format: +[ + "subclaim 1", + "subclaim 2" +] +""".strip() + + +def _strip_markdown_json_block(text: str) -> str: + """Strip optional markdown code fence (e.g. ```json\\n[...]\\n```), if present.""" + text = (text or "").strip() + if not text: + return "" + if text.startswith("```json"): + text = text[7:].lstrip("\n") + elif text.startswith("```"): + text = text[3:].lstrip("\n") + if text.endswith("```"): + text = text[:-3].rstrip("\n") + return text.strip() + + +def _parse_subclaim_list_output(output_text: str) -> List[str]: + """Parse subclaim-extractor model output into a list of Bangla subclaims.""" + output_text = (output_text or "").strip() + if not output_text: + return [] + + if "" in output_text: + output_text = output_text.split("")[-1].strip() + + output_text = _strip_markdown_json_block(output_text) + + start_idx = output_text.find("[") + end_idx = output_text.rfind("]") + 1 + if start_idx != -1 and end_idx > start_idx: + content = output_text[start_idx:end_idx] + parsed = json.loads(content) + if isinstance(parsed, list): + return [str(s).strip() for s in parsed if str(s).strip()] + + raise ValueError("Incomplete or invalid JSON list") + + +def _call_vllm_subclaim_extractor( + text: str, + max_tokens: int = 2048, + temperature: float = 0.2, + timeout: float = 120.0, +) -> Optional[List[str]]: + """ + Call Bangla subclaim-extractor model via vLLM (OpenAI /chat/completions). + + Returns a list of subclaims on success, or None on total failure. 
+ """ + if not text or not text.strip(): + return [] + + base = VLLM_SUBCLAIM_EXTRACTOR_BN_API_BASE.rstrip("/") + url = f"{base}/chat/completions" + + prompt = _build_subclaim_extraction_prompt(text) + payload = { + "model": os.getenv("VLLM_SUBCLAIM_EXTRACTOR_MODEL_NAME", "subclaim-extractor"), + "messages": [{"role": "user", "content": prompt}], + "temperature": temperature, + "max_tokens": max_tokens, + } + + try: + resp = requests.post(url, json=payload, timeout=timeout) + resp.raise_for_status() + data = resp.json() + choices = data.get("choices") or [] + if not choices: + return None + content = (choices[0].get("message", {}) or {}).get("content", "") or "" + # import ipdb; ipdb.set_trace() + return _parse_subclaim_list_output(content) + except Exception: + return None + + +def _extract_subclaims_from_text(text: str) -> List[str]: + """ + Extract Bangla subclaims from generated text using the vLLM subclaim-extractor. + + On failure (e.g., server down or parse error), falls back to sentence splitting + so the rest of the reward logic can still operate. + """ + subclaims = _call_vllm_subclaim_extractor(text) + if subclaims is None: + # Fallback: keep system running even if extractor is unavailable. + return _split_into_sentences(text) + return subclaims + + +# --------------------------------------------------------------------------- +# Two reward signals: +# 1. Factuality — summary subclaims vs gen_text (how much summary info is in gen_text) +# 2. Hallucination — gen_segments vs fulltext (how much gen info is NOT in fulltext) +# --------------------------------------------------------------------------- + +def compute_rewards( + fulltext: str, + generated_text: str, + target_level: str, + summary_subclaims: Optional[List[str]] = None, + summary_text: Optional[str] = None, + threshold: float = 0.5, + batch_size: int = 128, +) -> Dict[str, Optional[float]]: + """ + Compute two independent reward signals. + + 1. 
**Factuality** (summary_subclaims → gen_text): + Use pre-extracted *summary_subclaims*, check how many are supported + by the generated text. Measures "how much of the summary's information + made it into the output". + + 2. **Hallucination** (gen_segments → fulltext): + Extract subclaims from the *generated text* (gen_segments), then check + how many are supported by the source fulltext. The *unsupported* + fraction is the hallucination score (lower is better). + + Returns dict with: + factuality_score : [0,1] fraction of summary subclaims supported by gen_text + factuality_supported : int count + total_summary_subclaims : int + hallucination_score : [0,1] fraction of gen_segments NOT supported by fulltext + hallucination_supported : int count of gen_segments supported by fulltext + total_gen_segments : int + """ + result: Dict[str, Any] = { + "factuality_score": None, + "factuality_supported": 0, + "total_summary_subclaims": 0, + "hallucination_score": None, + "hallucination_supported": 0, + "total_gen_segments": 0, + } + + gen_segments = _extract_subclaims_from_text(generated_text) + + if not gen_segments: + result.update({ + "hallucination_score": 0.0, + "factuality_score": 0.0, + }) + return result + + total_gen = len(gen_segments) + result["total_gen_segments"] = total_gen + + # ===================================================================== + # 1. FACTUALITY — summary subclaims checked against gen_text + # "How much information from the summary exists in the generated text?" 
+ # ===================================================================== + factuality_score = None + if summary_subclaims and len(summary_subclaims) > 0: + result["total_summary_subclaims"] = len(summary_subclaims) + + labels_summary_vs_gen = _call_support_api( + context=generated_text, + subclaims=summary_subclaims, + threshold=threshold, + batch_size=batch_size, + ) + if labels_summary_vs_gen is not None: + valid = [l for l in labels_summary_vs_gen if str(l).strip().lower() != "invalid"] + if valid: + sup = sum(1 for l in valid if str(l).strip().lower() == "supported") + factuality_score = sup / len(summary_subclaims) + result["factuality_supported"] = sup + else: + factuality_score = 0.0 + + result["factuality_score"] = factuality_score + + # ===================================================================== + # 2. HALLUCINATION — gen_segments checked against fulltext + # "How much info in gen_segments is NOT supported by the fulltext?" + # ===================================================================== + hallucination_score = None + if fulltext and fulltext.strip(): + labels_gen_vs_full = _call_support_api( + context=fulltext, + subclaims=gen_segments, + threshold=threshold, + batch_size=batch_size, + ) + if labels_gen_vs_full is not None and len(labels_gen_vs_full) > 0: + sup_full = sum( + 1 for l in labels_gen_vs_full + if str(l).strip().lower() == "supported" + ) + + unsupported_indices = [ + i for i, l in enumerate(labels_gen_vs_full) + if str(l).strip().lower() != "supported" + ] + + if unsupported_indices and summary_text and summary_text.strip(): + unsup_segments = [gen_segments[i] for i in unsupported_indices] + rescue_labels = _call_support_api( + context=summary_text, + subclaims=unsup_segments, + threshold=threshold, + batch_size=batch_size, + ) + if rescue_labels: + rescued = sum( + 1 for l in rescue_labels + if str(l).strip().lower() == "supported" + ) + sup_full += rescued + + hallucination_score = max(0.0, (total_gen - sup_full) / 
total_gen) + result["hallucination_supported"] = sup_full + else: + hallucination_score = 0.0 + + result["hallucination_score"] = hallucination_score + + return result + + +# --------------------------------------------------------------------------- +# BN health-literacy classifier via vLLM (Gemma-3 fine-tuned model) +# Uses Bangla prompt; model is assumed running in vLLM. +# --------------------------------------------------------------------------- + + +def build_classification_user_prompt(fulltext: str, gen_text: str) -> str: + """Build the classification user prompt in English (matches gemma3-finetune.py). Full text is reference; generated text is what to classify.""" + return ( + "You will be given a medical case description as reference (full text) and a generated text to classify. " + "Determine the patient's health literacy level based only on the generated text.\n\n" + f"Reference (full text):\n{fulltext}\n\n" + f"Generated text (to classify):\n{gen_text}\n\n" + "Reply with exactly one label from this set:\n" + "low_health_literacy, intermediate_health_literacy, proficient_health_literacy" + ) + + +def format_gemma3_prompt(user_message: str) -> str: + """Format user message for Gemma-3 chat (vLLM expects this).""" + return ( + f"user\n{user_message}\n" + "model\n" + ) + + +def _call_vllm_classifier(prompt: str, max_tokens: int = 64, timeout: float = 60.0) -> Optional[str]: + """ + Call vLLM completions API. Returns generated text or None on failure. 
+ """ + url = f"{VLLM_CLASSIFIER_BN_API_BASE.rstrip('/')}/completions" + payload = { + "prompt": prompt, + "max_tokens": max_tokens, + "temperature": 0.0, + "stop": ["", "", "\n\n",""], + } + try: + resp = requests.post(url, json=payload, timeout=timeout) + resp.raise_for_status() + data = resp.json() + choices = data.get("choices") + # import ipdb; ipdb.set_trace() + if choices and len(choices) > 0: + text = choices[0].get("text", "") + return (text or "").strip() + return None + except requests.exceptions.RequestException as exc: + return None + + +def _parse_classifier_output(raw: str) -> str: + """ + Extract health literacy label from model output. Normalize to + low_health_literacy | intermediate_health_literacy | proficient_health_literacy. + Returns empty string if no valid label found. + """ + if not raw: + return "" + raw = raw.strip().lower() + # Take first line and clean + first_line = raw.split("\n")[0].strip() + for label in ["low_health_literacy", "intermediate_health_literacy", "proficient_health_literacy"]: + if label in first_line or label in raw: + # import ipdb; ipdb.set_trace() + return label + return "" + + +_CLASSIFIER_ERROR_LOGGED = False + + +def _predict_label(input_text: str, generated_text: str) -> str: + """ + Run BN health-literacy classifier via vLLM (Gemma-3 fine-tuned). + Uses fulltext=input_text and gen_text=generated_text; returns normalized label or "". 
+ """ + global _CLASSIFIER_ERROR_LOGGED + try: + user_prompt = build_classification_user_prompt(input_text or "", generated_text or "") + prompt = format_gemma3_prompt(user_prompt) + raw = _call_vllm_classifier(prompt) + # import ipdb; ipdb.set_trace() + if raw is None: + if not _CLASSIFIER_ERROR_LOGGED: + print("Warning: BN classifier vLLM call failed, continuing without it.") + _CLASSIFIER_ERROR_LOGGED = True + return "" + return _parse_classifier_output(raw) + except Exception as exc: + if not _CLASSIFIER_ERROR_LOGGED: + print(f"Warning: BN literacy classifier unavailable, continuing without it: {exc}") + _CLASSIFIER_ERROR_LOGGED = True + return "" + + +def _parse_solution_json(solution_str): + if isinstance(solution_str, (dict, list)): + return solution_str + try: + cleaned_str = str(solution_str).strip() + if "```json" in cleaned_str: + cleaned_str = cleaned_str.split("```json")[1].split("```")[0].strip() + elif "```" in cleaned_str: + cleaned_str = cleaned_str.split("```")[1].split("```")[0].strip() + return json.loads(cleaned_str) + except Exception: + return None + + +def _compute_classifier_reward(target_level: str, gen_text: str, input_text: str) -> float: + """ + Soft classifier score in [0, 1] (NOT binary +1/-1). + + 1.0 — predicted label matches target level (correct style) + 0.0 — predicted label does not match (wrong style) + 0.5 — classifier unavailable; neutral / no signal + + Uses BN classifier via vLLM (Gemma-3); needs input_text (fulltext) and gen_text. 
+ """ + result = _predict_label(input_text, gen_text) + if result == "": # unavailable → neutral, no penalty + return 0.5 + if result.strip().lower() == target_level.strip().lower(): + # import ipdb; ipdb.set_trace() + return 1.0 # correct literacy style + return 0.0 # wrong literacy style (penalty-free cliff avoided) + + +# --------------------------------------------------------------------------- +# Copy-paste penalty (prevent trivial copy of input_text) +# --------------------------------------------------------------------------- + +def _approx_copy_ratio(input_text: str, gen_text: str) -> float: + """ + Rough similarity estimate between input and generated text. + + - Detects near-verbatim copy via substring + length ratio. + - Otherwise uses token overlap (gen tokens that also appear in input). + Returns value in [0, 1], where 1 ≈ almost exact copy. + """ + a = (input_text or "").strip() + b = (gen_text or "").strip() + if not a or not b: + return 0.0 + + len_a, len_b = len(a), len(b) + shorter, longer = (a, b) if len_a <= len_b else (b, a) + + # Near-verbatim copy: one string almost fully contained in the other. + if shorter and shorter in longer: + ratio = len(shorter) / max(1, len(longer)) + if ratio >= 0.9: + return 1.0 + + # Fallback: 3-gram (trigram) token overlap to reduce false positives + # from shared medical vocabulary (drug names, symptoms, etc.). 
+ def _tokens(t: str): + return [tok for tok in re.split(r"\s+", t) if tok] + + def _shingles(tokens, n=3): + if len(tokens) < n: + return set() + return {" ".join(tokens[i : i + n]) for i in range(len(tokens) - n + 1)} + + toks_a = _tokens(a) + toks_b = _tokens(b) + if not toks_a or not toks_b: + return 0.0 + + sh_a = _shingles(toks_a, n=3) + sh_b = _shingles(toks_b, n=3) + if not sh_a or not sh_b: + return 0.0 + + overlap = len(sh_a & sh_b) / max(1, len(sh_b)) + return max(0.0, min(1.0, overlap)) + + +def _compute_copy_penalty(input_text: str, gen_text: str) -> float: + """ + Map copy ratio → penalty in [0, 1]. + + - ≤ 0.7 similarity → no penalty + - 0.7–1.0 → linearly ramp penalty up to 1.0 + """ + ratio = _approx_copy_ratio(input_text, gen_text) + if ratio <= 0.7: + return 0.0 + # Scale [0.7, 1.0] → [0, 1] + return max(0.0, min(1.0, (ratio - 0.7) / 0.3)) + + +# --------------------------------------------------------------------------- +# Main scoring function +# --------------------------------------------------------------------------- +def _nonlinear_grounding(h_score: float) -> float: + """ + Sharper penalty for hallucination. + + h_score=0.00 → 1.00 (perfect) + h_score=0.05 → 0.95 (mild) + h_score=0.10 → 0.82 (noticeable) + h_score=0.17 → 0.65 (significant — was 0.83 before!) + h_score=0.30 → 0.36 (harsh) + h_score=0.50 → 0.13 (near zero) + """ + return max(0.0, (1.0 - h_score) ** 2.5) +def compute_score(data_source, solution_str, ground_truth, extra_info=None): + """ + Reward = weighted sum of three components (all in [0, 1]): + + W_FACTUALITY × factuality_score (summary info present in gen_text) + W_HALLU × (1 - hallucination_score) (gen_segments grounded in fulltext) + W_CLASSIFIER × classifier_score (style match) + + 1. Factuality : extract subclaims from *summary*, check how many are + supported by the generated text. + 2. Hallucination: extract subclaims from *generated text*, check how many + are NOT supported by the fulltext. 
+ """ + W_FACTUALITY = 0.40 + W_HALLU = 0.25 + W_CLASSIFIER = 0.35 + + FAIL = { + "score": -1.0, + "factuality_score": 0.0, + "hallucination_score": 0.0, + "classifier_score": 0.0, + "factuality_supported": 0, + "hallucination_supported": 0, + "total_gen_segments": 0, + } + + # 1. Parse & validate + data = _parse_solution_json(solution_str) + if not data: + return FAIL + + target_level = extra_info.get("target_level") if extra_info else None + gen_text = data.get(target_level, "") if target_level else "" + + if not gen_text or len(gen_text.strip()) < 10: + return FAIL + + if not _is_bangla_text(gen_text): + return FAIL + + fulltext = ground_truth.get("fulltext") or ground_truth.get("input_text", "") + input_text = ground_truth.get("input_text", "") + summary_subclaims = ground_truth.get("summary_subclaims") + summary_text = ground_truth.get("summary_text", "") + + # 2. Compute the two core rewards + rewards = compute_rewards( + fulltext=fulltext, + generated_text=gen_text, + target_level=target_level, + summary_subclaims=summary_subclaims, + summary_text=summary_text, + ) + + factuality_score = rewards["factuality_score"] + h_score = rewards["hallucination_score"] + total_gen_units = rewards.get("total_gen_segments", 0) + + if factuality_score is None: + factuality_score = 0.5 + if h_score is None: + h_score = 0.5 + + grounding_score = _nonlinear_grounding(h_score) + + # 3. Classifier (style match) + class_score = _compute_classifier_reward(target_level, gen_text, input_text) + + # 4. Final weighted sum + final_reward = ( + W_FACTUALITY * factuality_score + + W_HALLU * grounding_score + + W_CLASSIFIER * class_score + ) + + # 5. 
Copy-paste penalty + copy_penalty = _compute_copy_penalty(input_text, gen_text) + if copy_penalty > 0.0: + final_reward = max(0.0, final_reward * (1.0 - copy_penalty)) + + return { + "score": float(final_reward), + "factuality_score": float(factuality_score), + "hallucination_score": float(h_score), + "classifier_score": float(class_score), + "factuality_supported": int(rewards.get("factuality_supported", 0)), + "hallucination_supported": int(rewards.get("hallucination_supported", 0)), + "total_gen_segments": int(total_gen_units), + } + + +# --------------------------------------------------------------------------- +# Test mode +# --------------------------------------------------------------------------- + +test_mode = True +if test_mode: + import time + + def run_actual_api_test(): + # Bangla medical example (support-check and classifier use Bangla prompts) + ground_truth = { + "summary_subclaims": [ + "লিসিনোপ্রিল উচ্চ রক্তচাপের চিকিৎসায় ব্যবহৃত হয়।", + "এটি ACE ইনহিবিটর যা আপনার হৃদযন্ত্রকে ভালো কাজ করতে সাহায্য করে।", + "সাধারণ পার্শ্বপ্রতিক্রিয়ার মধ্যে শুষ্ক কাশি রয়েছে।", + "গর্ভবতী হলে ব্যবহার করবেন না।", + ], + "input_text": ( + "লিসিনোপ্রিল উচ্চ রক্তচাপের চিকিৎসায় ব্যবহৃত হয়। " + "এটি ACE ইনহিবিটর নামক ওষুধ। " + "এটি আপনার হৃদযন্ত্রকে ভালো কাজ করতে সাহায্য করে।" + ), + "summary_text": ( + "লিসিনোপ্রিল উচ্চ রক্তচাপের চিকিৎসায় ব্যবহৃত হয়। " + "এটি ACE ইনহিবিটর যা আপনার হৃদযন্ত্রকে ভালো কাজ করতে সাহায্য করে। " + "সাধারণ পার্শ্বপ্রতিক্রিয়ার মধ্যে শুষ্ক কাশি রয়েছে। " + "গর্ভবতী হলে ব্যবহার করবেন না।" + ), + } + + # LLM output: low_health_literacy style, grounded in summary + generated_response = { + "low_health_literacy": ( + "এই ওষুধ আপনার উচ্চ রক্তচাপের জন্য। " + "এটি ACE ইনহিবিটর ধরনের ওষুধ। " + "এটি আপনার হৃদযন্ত্রকে ভালো কাজ করতে সাহায্য করে। " + "গর্ভবতী হলে এই ওষুধ খাবেন না।" + ) + } + + solution_str = f"```json\n{json.dumps(generated_response)}\n```" + extra_info = {"target_level": "low_health_literacy"} + + print("📡 Running BN reward test 
(Bangla example)...") + start_time = time.time() + + try: + score = compute_score( + data_source="real_api_test", + solution_str=solution_str, + ground_truth=ground_truth, + extra_info=extra_info, + ) + + final_score = score["score"] if isinstance(score, dict) else score + + duration = time.time() - start_time + print(f"\nAPI Call Successful ({round(duration, 2)}s)") + print("-" * 50) + print(f"Target Level : {extra_info['target_level']}") + print(f"Final Reward : {round(final_score, 4)}") + print(f"factuality_score : {round(score.get('factuality_score', 0), 4)} (summary subclaims in gen_text)") + print(f"hallucination_score : {round(score.get('hallucination_score', 0), 4)} (gen_segments NOT in fulltext)") + print(f"classifier_score : {round(score.get('classifier_score', 0), 4)}") + print(f"factuality_supported : {score.get('factuality_supported', 0)}") + print(f"hallucination_supported: {score.get('hallucination_supported', 0)}") + print(f"total_gen_segments : {score.get('total_gen_segments', 0)}") + print("-" * 50) + print("\nReward definitions:") + print("- factuality_score : fraction of *summary* subclaims supported by gen_text [0,1]") + print("- hallucination_score : fraction of *gen_segments* NOT supported by fulltext [0,1] (lower=better)") + print("- classifier_score : 1.0 match, 0.0 mismatch, 0.5 unavailable") + print("- Weights: factuality=0.35, grounding=0.30, classifier=0.35") + + except Exception as e: + print(f"\n❌ API Call Failed!") + print(f"Error Type: {type(e).__name__}") + print(f"Details: {str(e)}") + print("\nPossible fixes:") + print("1. Check if the support-check vLLM server is running (VLLM_SUPPORT_CHECK_BN_API_BASE).") + print("2. 
Check if the classifier vLLM server is running (VLLM_CLASSIFIER_BN_API_BASE).") + + if __name__ == "__main__": + run_actual_api_test() \ No newline at end of file diff --git a/code/fine_tune_sft_dpo/run_eval_gpt5_models.py b/code/fine_tune_sft_dpo/run_eval_gpt5_models.py new file mode 100644 index 0000000000000000000000000000000000000000..c5cdf0d056cc397e4ad4581224e8aeca7371bfa7 --- /dev/null +++ b/code/fine_tune_sft_dpo/run_eval_gpt5_models.py @@ -0,0 +1,116 @@ +#!/usr/bin/env python3 +""" +Split the GPT-5 inference JSONL by model (gpt-5, gpt-5-mini, gpt-5-nano), +merge each model's generated_text with the subclaims data, and run +evaluate_scores_bn.py on each model's data. + +Usage: + python run_eval_gpt5_models.py +""" + +import json +import os +import subprocess +import sys +from collections import defaultdict +from datetime import datetime + +SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) + +JSONL_PATH = os.path.join( + SCRIPT_DIR, + "results/bn/gpt5_inference_all_wo_gs_20260314_173736.jsonl", +) +SUBCLAIMS_PATH = os.path.join(SCRIPT_DIR, "dataset/bn/test_bn_subclaims.json") +OUTPUT_DIR = os.path.join(SCRIPT_DIR, "evaluation/bn/eval_gpt5_models") +EVAL_SCRIPT = os.path.join(SCRIPT_DIR, "evaluate_scores_bn.py") + +MODELS = ["gpt-5", "gpt-5-mini", "gpt-5-nano"] + + +def main(): + os.makedirs(OUTPUT_DIR, exist_ok=True) + + with open(JSONL_PATH, "r", encoding="utf-8") as f: + all_items = [json.loads(line) for line in f if line.strip()] + + with open(SUBCLAIMS_PATH, "r", encoding="utf-8") as f: + subclaims_data = json.load(f) + + sc_by_key = {} + for sc in subclaims_data: + key = (sc["doc_id"], sc["label"]) + sc_by_key[key] = sc + + by_model = defaultdict(list) + for item in all_items: + by_model[item["model"]].append(item) + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + + for model_name in MODELS: + items = by_model.get(model_name, []) + if not items: + print(f"[WARN] No items for model {model_name}, skipping.") + continue + + safe_name = 
model_name.replace("-", "_") + + prepared = [] + skipped_no_subclaims = 0 + for item in items: + doc_id = item["doc_id"] + label = item["gold_label"] + key = (doc_id, label) + sc = sc_by_key.get(key) + + if sc is None: + skipped_no_subclaims += 1 + continue + + prepared.append({ + "doc_id": doc_id, + "label": label, + "fulltext": sc.get("fulltext", ""), + "summary_text": sc.get("summary", ""), + "summary_subclaims": sc.get("summary_subclaims", []), + "fulltext_subclaims": sc.get("fulltext_subclaims", []), + "generated_text": item.get("generated_text", ""), + }) + + input_path = os.path.join( + OUTPUT_DIR, f"{safe_name}_prepared_{timestamp}.json" + ) + with open(input_path, "w", encoding="utf-8") as f: + json.dump(prepared, f, indent=2, ensure_ascii=False) + + print(f"\n{'='*60}") + print(f"Model: {model_name}") + print(f" Total JSONL items : {len(items)}") + print(f" Matched w/ subclaims: {len(prepared)}") + print(f" Skipped (no subcl.) : {skipped_no_subclaims}") + print(f" Prepared file : {input_path}") + print(f"{'='*60}") + + output_path = os.path.join( + OUTPUT_DIR, f"{safe_name}_eval_results_{timestamp}.json" + ) + + cmd = [ + sys.executable, + EVAL_SCRIPT, + "--input", input_path, + "--output", output_path, + ] + + print(f" Running: {' '.join(cmd)}") + result = subprocess.run(cmd, cwd=SCRIPT_DIR) + + if result.returncode != 0: + print(f" [ERROR] Evaluation failed for {model_name} (exit code {result.returncode})") + else: + print(f" [OK] Results saved to: {output_path}") + + +if __name__ == "__main__": + main() diff --git a/code/fine_tune_sft_dpo/run_gpt5mini_nano_inference.py b/code/fine_tune_sft_dpo/run_gpt5mini_nano_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..08f5d88debe9db62cadb6c873e3ab9ab2d07930b --- /dev/null +++ b/code/fine_tune_sft_dpo/run_gpt5mini_nano_inference.py @@ -0,0 +1,448 @@ +import argparse +import json +import os +import time +import urllib.error +import urllib.request +from datetime import datetime 
import argparse
import json
import os
import urllib.error
import urllib.request

# Local file holding API keys, e.g. {"openai": "sk-..."}.
# FIX(review): the original read this hard-coded personal path unconditionally
# at import time, so the whole script crashed with FileNotFoundError on any
# machine where the file does not exist. The load is now best-effort; the
# OPENAI_API_KEY environment variable / --api-key flag take precedence, and
# main() still raises a clear "Missing API key" error when nothing is set.
api_file = "/home/mshahidul/api_new.json"
try:
    with open(api_file, "r", encoding="utf-8") as f:
        api_keys = json.load(f)
except (OSError, json.JSONDecodeError):
    api_keys = {}

DEFAULT_API_BASE = "https://api.openai.com/v1"
DEFAULT_INPUT_PATH = (
    "/home/mshahidul/readctrl/code/fine_tune_sft_dpo/dataset/bn/test_bn.json"
)
DEFAULT_OUTPUT_DIR = "/home/mshahidul/readctrl/code/fine_tune_sft_dpo/results/bn"
DEFAULT_PROMPT_LOW_PATH = (
    "/home/mshahidul/readctrl/code/fine_tune_sft_dpo/prompt_bn_wo_gs/prompt_low"
)
DEFAULT_PROMPT_INTERMEDIATE_PATH = (
    "/home/mshahidul/readctrl/code/fine_tune_sft_dpo/prompt_bn_wo_gs/prompt_intermediate"
)
DEFAULT_PROMPT_PROFICIENT_PATH = (
    "/home/mshahidul/readctrl/code/fine_tune_sft_dpo/prompt_bn_wo_gs/prompt_proficient"
)
DEFAULT_MODELS = "gpt-5,gpt-5-mini,gpt-5-nano"
DEFAULT_COST_LIMIT = 50.0

# The three health-literacy levels the dataset labels rows with.
VALID_LABELS = {
    "low_health_literacy",
    "intermediate_health_literacy",
    "proficient_health_literacy",
}


def parse_args() -> argparse.Namespace:
    """Parse CLI options for the GPT-5 inference run.

    Defaults come from environment variables (OPENAI_API_BASE /
    OPENAI_API_KEY) where applicable, then from the module constants.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Generate outputs with gpt-5-mini and gpt-5-nano using "
            "verified_combined dataset and literacy-level prompts."
        )
    )
    parser.add_argument("--api-base", default=os.environ.get("OPENAI_API_BASE", DEFAULT_API_BASE))
    parser.add_argument(
        "--api-key",
        # .get with "" fallback keeps import working when the key file is absent.
        default=os.environ.get("OPENAI_API_KEY", api_keys.get("openai", "")),
    )
    parser.add_argument("--models", default=DEFAULT_MODELS, help="Comma-separated model list.")
    parser.add_argument("--input-path", default=DEFAULT_INPUT_PATH)
    parser.add_argument("--output-dir", default=DEFAULT_OUTPUT_DIR)
    parser.add_argument("--prompt-low-path", default=DEFAULT_PROMPT_LOW_PATH)
    parser.add_argument(
        "--prompt-intermediate-path",
        default=DEFAULT_PROMPT_INTERMEDIATE_PATH,
    )
    parser.add_argument(
        "--prompt-proficient-path",
        default=DEFAULT_PROMPT_PROFICIENT_PATH,
    )
    parser.add_argument(
        "--max-samples",
        type=int,
        default=-1,
        help="Use -1 for all rows.",
    )
    parser.add_argument("--temperature", type=float, default=0.0)
    parser.add_argument("--timeout-seconds", type=int, default=120)
    parser.add_argument("--max-retries", type=int, default=2)
    parser.add_argument("--retry-wait-seconds", type=float, default=2.0)
    parser.add_argument(
        "--cost-limit",
        type=float,
        default=DEFAULT_COST_LIMIT,
        help="Stop and save when cumulative API cost exceeds this amount in USD.",
    )
    return parser.parse_args()


def check_api_base(api_base: str, api_key: str, timeout_seconds: int) -> None:
    """Fail fast if the OpenAI-compatible endpoint is unreachable or unhealthy.

    Probes GET {api_base}/models; raises RuntimeError on an HTTP error status
    and ConnectionError (chained from the URLError) when unreachable.
    """
    models_url = api_base.rstrip("/") + "/models"
    req = urllib.request.Request(models_url, method="GET")
    if api_key:
        req.add_header("Authorization", f"Bearer {api_key}")
    try:
        with urllib.request.urlopen(req, timeout=timeout_seconds) as resp:
            if resp.status >= 400:
                raise RuntimeError(
                    f"Endpoint reachable but unhealthy: {models_url} (status={resp.status})"
                )
    except urllib.error.URLError as exc:
        raise ConnectionError(
            "Cannot reach OpenAI-compatible endpoint. "
            f"api_base={api_base}. Check network/API base/API key."
        ) from exc
def load_prompt_templates(args: argparse.Namespace) -> Dict[str, str]:
    """Read the three literacy-level prompt files into a label -> template map.

    Raises FileNotFoundError if any prompt file is missing.
    """
    path_for_label = {
        "low_health_literacy": args.prompt_low_path,
        "intermediate_health_literacy": args.prompt_intermediate_path,
        "proficient_health_literacy": args.prompt_proficient_path,
    }
    loaded: Dict[str, str] = {}
    for label, prompt_path in path_for_label.items():
        if not os.path.exists(prompt_path):
            raise FileNotFoundError(f"Prompt file not found: {prompt_path}")
        with open(prompt_path, "r", encoding="utf-8") as fh:
            loaded[label] = fh.read()
    return loaded


def infer_source_lang(fulltext: str) -> str:
    """Best-effort script detection: Bangla wins over English, else Unknown."""
    if not fulltext:
        return "Unknown"
    # Any character inside the Unicode Bengali block marks the text as Bangla.
    if any("\u0980" <= ch <= "\u09FF" for ch in fulltext):
        return "Bangla"
    # Otherwise any Latin letter marks it as English.
    if any("a" <= ch.lower() <= "z" for ch in fulltext):
        return "English"
    return "Unknown"


def build_prompt(template: str, fulltext: str, summary: str, source_lang: str) -> str:
    """Fill the {source_lang}, {gold_summary} and {full_text} placeholders."""
    filled = template.replace("{source_lang}", source_lang)
    filled = filled.replace("{gold_summary}", summary)
    return filled.replace("{full_text}", fulltext)


def load_verified_rows(path: str) -> List[Dict[str, Any]]:
    """Load a JSON array of row dicts from *path*; non-dict entries are dropped.

    Raises FileNotFoundError when the file is missing and ValueError when the
    top-level JSON value is not a list.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"Input file not found: {path}")
    with open(path, "r", encoding="utf-8") as fh:
        payload = json.load(fh)
    if not isinstance(payload, list):
        raise ValueError(f"Expected top-level JSON array in {path}")
    return [entry for entry in payload if isinstance(entry, dict)]
Example: --models gpt-5-mini,gpt-5-nano") + return models + + +def _clean_json_block(text: str) -> str: + cleaned = text.strip() + if "```json" in cleaned: + cleaned = cleaned.split("```json", 1)[1].split("```", 1)[0].strip() + elif "```" in cleaned: + cleaned = cleaned.split("```", 1)[1].split("```", 1)[0].strip() + return cleaned + + +def extract_generated_text(raw_response: str, expected_label: str) -> str: + cleaned = _clean_json_block(raw_response) + try: + parsed = json.loads(cleaned) + except json.JSONDecodeError: + return raw_response.strip() + + if isinstance(parsed, dict): + value = parsed.get(expected_label) + if isinstance(value, str) and value.strip(): + return value.strip() + return raw_response.strip() + + +def compute_cost(model: str, input_tokens: int, output_tokens: int, + cached_input_tokens: int = 0) -> float: + pricing = PRICING_PER_1M.get(model) + if pricing is None: + return 0.0 + non_cached_input = max(0, input_tokens - cached_input_tokens) + cost = ( + non_cached_input * pricing["input"] / 1_000_000 + + cached_input_tokens * pricing["cached_input"] / 1_000_000 + + output_tokens * pricing["output"] / 1_000_000 + ) + return cost + + +def call_chat_completion( + *, + api_base: str, + api_key: str, + model: str, + prompt: str, + temperature: float, + timeout_seconds: int, + max_retries: int, + retry_wait_seconds: float, +) -> Dict[str, Any]: + url = api_base.rstrip("/") + "/chat/completions" + payload = { + "model": model, + "messages": [{"role": "user", "content": prompt}], + } + data = json.dumps(payload).encode("utf-8") + + last_error: Optional[Exception] = None + for attempt in range(max_retries + 1): + req = urllib.request.Request(url, data=data, method="POST") + req.add_header("Content-Type", "application/json") + if api_key: + req.add_header("Authorization", f"Bearer {api_key}") + try: + with urllib.request.urlopen(req, timeout=timeout_seconds) as resp: + body = resp.read().decode("utf-8") + parsed = json.loads(body) + content = 
def call_chat_completion(
    *,
    api_base: str,
    api_key: str,
    model: str,
    prompt: str,
    temperature: float,
    timeout_seconds: int,
    max_retries: int,
    retry_wait_seconds: float,
) -> Dict[str, Any]:
    """POST one chat-completion request; return content and token usage.

    Retries up to *max_retries* extra attempts on retriable HTTP statuses,
    network errors, and malformed responses, sleeping *retry_wait_seconds*
    between attempts. Re-raises the last error when retries are exhausted.

    NOTE(review): *temperature* is accepted but never placed in the request
    payload — presumably because gpt-5-family endpoints reject non-default
    temperatures. Confirm, then either wire it through or drop the parameter.
    """
    url = api_base.rstrip("/") + "/chat/completions"
    payload = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
    }
    data = json.dumps(payload).encode("utf-8")

    last_error: Optional[Exception] = None
    for attempt in range(max_retries + 1):
        req = urllib.request.Request(url, data=data, method="POST")
        req.add_header("Content-Type", "application/json")
        if api_key:
            req.add_header("Authorization", f"Bearer {api_key}")
        try:
            with urllib.request.urlopen(req, timeout=timeout_seconds) as resp:
                body = resp.read().decode("utf-8")
                parsed = json.loads(body)
                content = str(parsed["choices"][0]["message"]["content"]).strip()
                usage = parsed.get("usage", {})
                return {
                    "content": content,
                    "prompt_tokens": usage.get("prompt_tokens", 0),
                    "completion_tokens": usage.get("completion_tokens", 0),
                    # cached_tokens lives under prompt_tokens_details in the
                    # OpenAI usage object; default 0 when absent.
                    "cached_tokens": usage.get("prompt_tokens_details", {}).get(
                        "cached_tokens", 0
                    ),
                }
        except urllib.error.HTTPError as exc:
            # Retry only transient statuses; 4xx client errors re-raise at once.
            retriable = exc.code in (408, 409, 429, 500, 502, 503, 504)
            last_error = exc
            if attempt < max_retries and retriable:
                time.sleep(retry_wait_seconds)
                continue
            raise
        except (urllib.error.URLError, KeyError, IndexError, json.JSONDecodeError) as exc:
            last_error = exc
            if attempt < max_retries:
                time.sleep(retry_wait_seconds)
                continue
            raise

    if last_error:
        raise last_error
    raise RuntimeError("Unknown error during chat completion call.")


def main() -> None:
    """Run inference for every requested model over the verified dataset.

    Per-model JSONL files are written as rows complete; the combined JSONL
    and run summary are now guaranteed to be saved even when the run aborts
    (budget cap or an uncaught exception).
    """
    args = parse_args()
    if not args.api_key:
        raise ValueError("Missing API key. Set OPENAI_API_KEY or pass --api-key.")

    for path in (
        args.prompt_low_path,
        args.prompt_intermediate_path,
        args.prompt_proficient_path,
    ):
        if not os.path.exists(path):
            raise FileNotFoundError(f"Prompt file not found: {path}")

    check_api_base(args.api_base, args.api_key, args.timeout_seconds)
    models = parse_models(args.models)
    templates = load_prompt_templates(args)
    rows = load_verified_rows(args.input_path)

    # Build one prompt per valid row (known label, non-empty text fields).
    parsed_items: List[Dict[str, Any]] = []
    for idx, row in enumerate(rows):
        gold_label = str(row.get("label", "")).strip()
        fulltext = str(row.get("fulltext", "")).strip()
        summary = str(row.get("summary", "")).strip()
        if gold_label not in VALID_LABELS:
            continue
        if not fulltext or not summary:
            continue
        source_lang = infer_source_lang(fulltext)
        prompt = build_prompt(
            template=templates[gold_label],
            fulltext=fulltext,
            summary=summary,
            source_lang=source_lang,
        )
        parsed_items.append(
            {
                "row_index": idx,
                "doc_id": row.get("doc_id"),
                "gold_label": gold_label,
                "source_lang": source_lang,
                "prompt": prompt,
            }
        )

    if args.max_samples > 0:
        parsed_items = parsed_items[: args.max_samples]
    if not parsed_items:
        raise RuntimeError("No valid rows found in input file.")

    ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    os.makedirs(args.output_dir, exist_ok=True)
    summary_path = os.path.join(args.output_dir, f"gpt5_inference_summary_wo_gs_{ts}.json")
    combined_path = os.path.join(args.output_dir, f"gpt5_inference_all_wo_gs_{ts}.jsonl")

    combined_records: List[Dict[str, Any]] = []
    model_stats: Dict[str, Dict[str, Any]] = {}
    total_cost = 0.0
    total_input_tokens = 0
    total_output_tokens = 0
    budget_exceeded = False

    def _save_outputs() -> None:
        # Rewrites both files from scratch; idempotent and safe to call
        # multiple times or after a partial run.
        with open(combined_path, "w", encoding="utf-8") as f_all:
            for rec in combined_records:
                f_all.write(json.dumps(rec, ensure_ascii=False) + "\n")
        summary_obj = {
            "input_path": args.input_path,
            "api_base": args.api_base,
            "models": models,
            "max_samples": args.max_samples,
            "temperature": args.temperature,
            "cost_limit_usd": args.cost_limit,
            "total_cost_usd": round(total_cost, 6),
            "total_input_tokens": total_input_tokens,
            "total_output_tokens": total_output_tokens,
            "budget_exceeded": budget_exceeded,
            "total_dataset_rows_used": len(parsed_items),
            "combined_output_path": combined_path,
            "model_stats": model_stats,
        }
        with open(summary_path, "w", encoding="utf-8") as f_summary:
            json.dump(summary_obj, f_summary, ensure_ascii=False, indent=2)

    # FIX(review): outputs were previously written only after the model loop
    # completed, so any uncaught exception mid-run lost all paid-for results.
    # The finally-block guarantees a save on success, budget stop, or crash.
    try:
        for model in models:
            if budget_exceeded:
                break

            model_slug = model.replace("/", "_")
            model_output_path = os.path.join(
                args.output_dir, f"gpt5_inference_{model_slug}_{ts}.jsonl"
            )
            success_count = 0
            error_count = 0
            model_cost = 0.0
            model_input_tokens = 0
            model_output_tokens = 0

            with open(model_output_path, "w", encoding="utf-8") as f_model:
                total = len(parsed_items)
                progress_iter = tqdm(
                    parsed_items,
                    total=total,
                    desc=f"{model}",
                    unit="item",
                )
                for item in progress_iter:
                    record: Dict[str, Any] = {
                        "model": model,
                        "row_index": item["row_index"],
                        "doc_id": item.get("doc_id"),
                        "gold_label": item["gold_label"],
                        "source_lang": item["source_lang"],
                        "prompt": item["prompt"],
                    }
                    try:
                        result = call_chat_completion(
                            api_base=args.api_base,
                            api_key=args.api_key,
                            model=model,
                            prompt=item["prompt"],
                            temperature=args.temperature,
                            timeout_seconds=args.timeout_seconds,
                            max_retries=args.max_retries,
                            retry_wait_seconds=args.retry_wait_seconds,
                        )
                        raw_response = result["content"]
                        p_tokens = result["prompt_tokens"]
                        c_tokens = result["completion_tokens"]
                        cached = result["cached_tokens"]

                        call_cost = compute_cost(model, p_tokens, c_tokens, cached)
                        total_cost += call_cost
                        model_cost += call_cost
                        total_input_tokens += p_tokens
                        total_output_tokens += c_tokens
                        model_input_tokens += p_tokens
                        model_output_tokens += c_tokens

                        generated_text = extract_generated_text(raw_response, item["gold_label"])
                        record["prediction"] = raw_response
                        record["generated_text"] = generated_text
                        record["error"] = ""
                        record["prompt_tokens"] = p_tokens
                        record["completion_tokens"] = c_tokens
                        record["call_cost_usd"] = round(call_cost, 6)
                        success_count += 1
                    except Exception as exc:
                        # A failed call still produces a record so row counts
                        # stay aligned with the dataset.
                        record["prediction"] = ""
                        record["generated_text"] = ""
                        record["error"] = f"{type(exc).__name__}: {exc}"
                        record["prompt_tokens"] = 0
                        record["completion_tokens"] = 0
                        record["call_cost_usd"] = 0.0
                        error_count += 1

                    f_model.write(json.dumps(record, ensure_ascii=False) + "\n")
                    combined_records.append(record)

                    progress_iter.set_postfix(
                        cost=f"${total_cost:.4f}",
                        limit=f"${args.cost_limit:.2f}",
                    )

                    if total_cost >= args.cost_limit:
                        budget_exceeded = True
                        print(
                            f"\n[BUDGET] Cost ${total_cost:.4f} reached limit "
                            f"${args.cost_limit:.2f}. Saving data and stopping."
                        )
                        break

            model_stats[model] = {
                "output_path": model_output_path,
                "total_rows": len(parsed_items),
                "rows_processed": success_count + error_count,
                "success_count": success_count,
                "error_count": error_count,
                "model_cost_usd": round(model_cost, 6),
                "model_input_tokens": model_input_tokens,
                "model_output_tokens": model_output_tokens,
            }
            print(
                f"[DONE] {model} | cost: ${model_cost:.4f} | "
                f"output: {model_output_path}"
            )
    finally:
        _save_outputs()

    print(f"\n[COST] Total API cost: ${total_cost:.4f} / ${args.cost_limit:.2f} limit")
    print(f"[COST] Total tokens — input: {total_input_tokens}, output: {total_output_tokens}")
    if budget_exceeded:
        print("[COST] Budget exceeded — run stopped early. All data saved.")
    print(f"[DONE] Combined output: {combined_path}")
    print(f"[DONE] Summary output: {summary_path}")


if __name__ == "__main__":
    main()
--finetuned-model-dir /home/mshahidul/readctrl/code/fine_tune_sft_dpo/model/bn


# Plain vLLM inference with the Bengali SFT model.
python /home/mshahidul/readctrl/code/fine_tune_sft_dpo/qwen3-inference-vllm_bn.py \
    --model-dir /home/mshahidul/readctrl/code/fine_tune_sft_dpo/model/bn \
    --output-file results/bn/test_inference_vllm_qwen3-4B_sft.json

cd /home/mshahidul/readctrl/code/fine_tune_sft_dpo

# Score the self-refine outputs.
python evaluate_scores.py \
    --input results/bn/test_self_refine_vllm_qwen3_4B_sft.json \
    --subclaims dataset/bn/test_bn_subclaims.json \
    --output-dir /home/mshahidul/readctrl/code/fine_tune_sft_dpo/evaluation/bn

# NOTE(review): this call passes --model-key qwen3_finetuned but reads the
# *_base.json results file — confirm the input/model-key pairing is intended.
python evaluate_scores_bn.py \
    --input /home/mshahidul/readctrl/code/fine_tune_sft_dpo/results/bn/test_best_of_n_qwen3-4B_base.json \
    --subclaims dataset/bn/test_bn_subclaims.json \
    --model-key qwen3_finetuned \
    --output-dir /home/mshahidul/readctrl/code/fine_tune_sft_dpo/evaluation/bn

python evaluate_scores_bn.py \
    --input /home/mshahidul/readctrl/code/fine_tune_sft_dpo/results/bn/test_best_of_n_qwen3-4B_base.json \
    --subclaims dataset/bn/test_bn_subclaims.json \
    --output-dir /home/mshahidul/readctrl/code/fine_tune_sft_dpo/evaluation/bn


python /home/mshahidul/readctrl/code/fine_tune_sft_dpo/evaluate_scores_bn_vllm_rl.py \
    --input /home/mshahidul/readctrl/code/fine_tune_sft_dpo/results/bn/test_best_of_n_qwen3-4B_base.json \
    --subclaims dataset/bn/test_bn_subclaims.json \
    --output-dir /home/mshahidul/readctrl/code/fine_tune_sft_dpo/evaluation/bn

# FIX: the original omitted the line-continuation backslash after
# --output-dir, so the --subclaims line executed as a separate (failing)
# command and evaluate_scores_bn_vllm.py ran without a subclaims file.
python evaluate_scores_bn_vllm.py \
    --input /home/mshahidul/readctrl/code/readctrl_rl_inference/vllm_model_result/bn_temp/bn_200.jsonl \
    --output-dir evaluation/bn/ \
    --subclaims dataset/bn/test_bn_subclaims.json
+(APIServer pid=24913) INFO 03-14 18:23:48 [model.py:529] Resolved architecture: Qwen3ForCausalLM +(APIServer pid=24913) INFO 03-14 18:23:48 [model.py:1549] Using max model len 16384 +(APIServer pid=24913) INFO 03-14 18:23:48 [cache.py:214] Using fp8 data type to store kv cache. It reduces the GPU memory footprint and boosts the performance. Meanwhile, it may cause accuracy drop without a proper scaling factor. +(APIServer pid=24913) INFO 03-14 18:23:49 [scheduler.py:224] Chunked prefill is enabled with max_num_batched_tokens=8192. +(APIServer pid=24913) INFO 03-14 18:23:49 [vllm.py:689] Asynchronous scheduling is enabled. +(APIServer pid=24913) WARNING 03-14 18:23:49 [vllm.py:727] Enforce eager set, overriding optimization level to -O0 +(APIServer pid=24913) INFO 03-14 18:23:49 [vllm.py:845] Cudagraph is disabled under eager mode +Skipping import of cpp extensions due to incompatible torch version 2.9.1+cu128 for torchao version 0.16.0 Please see https://github.com/pytorch/ao/issues/2919 for more info +(EngineCore_DP0 pid=25731) INFO 03-14 18:23:59 [core.py:97] Initializing a V1 LLM engine (v0.16.0) with config: model='huihui-ai/Qwen3-32B-abliterated', speculative_config=None, tokenizer='huihui-ai/Qwen3-32B-abliterated', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.bfloat16, max_seq_len=16384, download_dir=None, load_format=auto, tensor_parallel_size=1, pipeline_parallel_size=1, data_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=True, enable_return_routed_experts=False, kv_cache_dtype=fp8, device_config=cuda, structured_outputs_config=StructuredOutputsConfig(backend='auto', disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_parser='', reasoning_parser_plugin='', enable_in_reasoning=False), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, 
collect_detailed_traces=None, kv_cache_metrics=False, kv_cache_metrics_sample=0.01, cudagraph_metrics=False, enable_layerwise_nvtx_tracing=False, enable_mfu_metrics=False, enable_mm_processor_stats=False, enable_logging_iteration_details=False), seed=0, served_model_name=generator, enable_prefix_caching=True, enable_chunked_prefill=True, pooler_config=None, compilation_config={'level': None, 'mode': , 'debug_dump_path': None, 'cache_dir': '', 'compile_cache_save_format': 'binary', 'backend': 'inductor', 'custom_ops': ['all'], 'splitting_ops': [], 'compile_mm_encoder': False, 'compile_sizes': [], 'compile_ranges_split_points': [8192], 'inductor_compile_config': {'enable_auto_functionalized_v2': False, 'combo_kernels': True, 'benchmark_combo_kernel': True}, 'inductor_passes': {}, 'cudagraph_mode': , 'cudagraph_num_of_warmups': 0, 'cudagraph_capture_sizes': [], 'cudagraph_copy_inputs': False, 'cudagraph_specialize_lora': True, 'use_inductor_graph_partition': False, 'pass_config': {'fuse_norm_quant': False, 'fuse_act_quant': False, 'fuse_attn_quant': False, 'eliminate_noops': False, 'enable_sp': False, 'fuse_gemm_comms': False, 'fuse_allreduce_rms': False, 'fuse_act_padding': False}, 'max_cudagraph_capture_size': 0, 'dynamic_shapes_config': {'type': , 'evaluate_guards': False, 'assume_32_bit_indexing': False}, 'local_cache_dir': None, 'fast_moe_cold_start': True, 'static_all_moe_layers': []} +(EngineCore_DP0 pid=25731) INFO 03-14 18:24:01 [parallel_state.py:1234] world_size=1 rank=0 local_rank=0 distributed_init_method=tcp://172.16.34.19:56091 backend=nccl +(EngineCore_DP0 pid=25731) INFO 03-14 18:24:01 [parallel_state.py:1445] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, PCP rank 0, TP rank 0, EP rank N/A +(EngineCore_DP0 pid=25731) ERROR 03-14 18:24:01 [core.py:1006] EngineCore failed to start. 
+(EngineCore_DP0 pid=25731) ERROR 03-14 18:24:01 [core.py:1006] Traceback (most recent call last): +(EngineCore_DP0 pid=25731) ERROR 03-14 18:24:01 [core.py:1006] File "/home/mshahidul/miniconda3/envs/unsloth/lib/python3.12/site-packages/vllm/v1/engine/core.py", line 996, in run_engine_core +(EngineCore_DP0 pid=25731) ERROR 03-14 18:24:01 [core.py:1006] engine_core = EngineCoreProc(*args, engine_index=dp_rank, **kwargs) +(EngineCore_DP0 pid=25731) ERROR 03-14 18:24:01 [core.py:1006] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +(EngineCore_DP0 pid=25731) ERROR 03-14 18:24:01 [core.py:1006] File "/home/mshahidul/miniconda3/envs/unsloth/lib/python3.12/site-packages/vllm/v1/engine/core.py", line 740, in __init__ +(EngineCore_DP0 pid=25731) ERROR 03-14 18:24:01 [core.py:1006] super().__init__( +(EngineCore_DP0 pid=25731) ERROR 03-14 18:24:01 [core.py:1006] File "/home/mshahidul/miniconda3/envs/unsloth/lib/python3.12/site-packages/vllm/v1/engine/core.py", line 106, in __init__ +(EngineCore_DP0 pid=25731) ERROR 03-14 18:24:01 [core.py:1006] self.model_executor = executor_class(vllm_config) +(EngineCore_DP0 pid=25731) ERROR 03-14 18:24:01 [core.py:1006] ^^^^^^^^^^^^^^^^^^^^^^^^^^^ +(EngineCore_DP0 pid=25731) ERROR 03-14 18:24:01 [core.py:1006] File "/home/mshahidul/miniconda3/envs/unsloth/lib/python3.12/site-packages/vllm/v1/executor/abstract.py", line 103, in __init__ +(EngineCore_DP0 pid=25731) ERROR 03-14 18:24:01 [core.py:1006] self._init_executor() +(EngineCore_DP0 pid=25731) ERROR 03-14 18:24:01 [core.py:1006] File "/home/mshahidul/miniconda3/envs/unsloth/lib/python3.12/site-packages/vllm/v1/executor/uniproc_executor.py", line 47, in _init_executor +(EngineCore_DP0 pid=25731) ERROR 03-14 18:24:01 [core.py:1006] self.driver_worker.init_device() +(EngineCore_DP0 pid=25731) ERROR 03-14 18:24:01 [core.py:1006] File "/home/mshahidul/miniconda3/envs/unsloth/lib/python3.12/site-packages/vllm/v1/worker/worker_base.py", line 322, in init_device +(EngineCore_DP0 
pid=25731) ERROR 03-14 18:24:01 [core.py:1006] self.worker.init_device() # type: ignore +(EngineCore_DP0 pid=25731) ERROR 03-14 18:24:01 [core.py:1006] ^^^^^^^^^^^^^^^^^^^^^^^^^ +(EngineCore_DP0 pid=25731) ERROR 03-14 18:24:01 [core.py:1006] File "/home/mshahidul/miniconda3/envs/unsloth/lib/python3.12/site-packages/vllm/v1/worker/gpu_worker.py", line 252, in init_device +(EngineCore_DP0 pid=25731) ERROR 03-14 18:24:01 [core.py:1006] self.requested_memory = request_memory(init_snapshot, self.cache_config) +(EngineCore_DP0 pid=25731) ERROR 03-14 18:24:01 [core.py:1006] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +(EngineCore_DP0 pid=25731) ERROR 03-14 18:24:01 [core.py:1006] File "/home/mshahidul/miniconda3/envs/unsloth/lib/python3.12/site-packages/vllm/v1/worker/utils.py", line 102, in request_memory +(EngineCore_DP0 pid=25731) ERROR 03-14 18:24:01 [core.py:1006] raise ValueError( +(EngineCore_DP0 pid=25731) ERROR 03-14 18:24:01 [core.py:1006] ValueError: Free memory on device cuda:0 (113.92/139.8 GiB) on startup is less than desired GPU memory utilization (0.85, 118.83 GiB). Decrease GPU memory utilization or reduce GPU memory used by other processes. 
+(EngineCore_DP0 pid=25731) Process EngineCore_DP0: +(EngineCore_DP0 pid=25731) Traceback (most recent call last): +(EngineCore_DP0 pid=25731) File "/home/mshahidul/miniconda3/envs/unsloth/lib/python3.12/multiprocessing/process.py", line 314, in _bootstrap +(EngineCore_DP0 pid=25731) self.run() +(EngineCore_DP0 pid=25731) File "/home/mshahidul/miniconda3/envs/unsloth/lib/python3.12/multiprocessing/process.py", line 108, in run +(EngineCore_DP0 pid=25731) self._target(*self._args, **self._kwargs) +(EngineCore_DP0 pid=25731) File "/home/mshahidul/miniconda3/envs/unsloth/lib/python3.12/site-packages/vllm/v1/engine/core.py", line 1010, in run_engine_core +(EngineCore_DP0 pid=25731) raise e +(EngineCore_DP0 pid=25731) File "/home/mshahidul/miniconda3/envs/unsloth/lib/python3.12/site-packages/vllm/v1/engine/core.py", line 996, in run_engine_core +(EngineCore_DP0 pid=25731) engine_core = EngineCoreProc(*args, engine_index=dp_rank, **kwargs) +(EngineCore_DP0 pid=25731) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +(EngineCore_DP0 pid=25731) File "/home/mshahidul/miniconda3/envs/unsloth/lib/python3.12/site-packages/vllm/v1/engine/core.py", line 740, in __init__ +(EngineCore_DP0 pid=25731) super().__init__( +(EngineCore_DP0 pid=25731) File "/home/mshahidul/miniconda3/envs/unsloth/lib/python3.12/site-packages/vllm/v1/engine/core.py", line 106, in __init__ +(EngineCore_DP0 pid=25731) self.model_executor = executor_class(vllm_config) +(EngineCore_DP0 pid=25731) ^^^^^^^^^^^^^^^^^^^^^^^^^^^ +(EngineCore_DP0 pid=25731) File "/home/mshahidul/miniconda3/envs/unsloth/lib/python3.12/site-packages/vllm/v1/executor/abstract.py", line 103, in __init__ +(EngineCore_DP0 pid=25731) self._init_executor() +(EngineCore_DP0 pid=25731) File "/home/mshahidul/miniconda3/envs/unsloth/lib/python3.12/site-packages/vllm/v1/executor/uniproc_executor.py", line 47, in _init_executor +(EngineCore_DP0 pid=25731) self.driver_worker.init_device() +(EngineCore_DP0 pid=25731) File 
"/home/mshahidul/miniconda3/envs/unsloth/lib/python3.12/site-packages/vllm/v1/worker/worker_base.py", line 322, in init_device +(EngineCore_DP0 pid=25731) self.worker.init_device() # type: ignore +(EngineCore_DP0 pid=25731) ^^^^^^^^^^^^^^^^^^^^^^^^^ +(EngineCore_DP0 pid=25731) File "/home/mshahidul/miniconda3/envs/unsloth/lib/python3.12/site-packages/vllm/v1/worker/gpu_worker.py", line 252, in init_device +(EngineCore_DP0 pid=25731) self.requested_memory = request_memory(init_snapshot, self.cache_config) +(EngineCore_DP0 pid=25731) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +(EngineCore_DP0 pid=25731) File "/home/mshahidul/miniconda3/envs/unsloth/lib/python3.12/site-packages/vllm/v1/worker/utils.py", line 102, in request_memory +(EngineCore_DP0 pid=25731) raise ValueError( +(EngineCore_DP0 pid=25731) ValueError: Free memory on device cuda:0 (113.92/139.8 GiB) on startup is less than desired GPU memory utilization (0.85, 118.83 GiB). Decrease GPU memory utilization or reduce GPU memory used by other processes. +[rank0]:[W314 18:24:02.729451114 ProcessGroupNCCL.cpp:1524] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. 
For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator()) +(APIServer pid=24913) Traceback (most recent call last): +(APIServer pid=24913) File "/home/mshahidul/miniconda3/envs/unsloth/bin/vllm", line 6, in +(APIServer pid=24913) sys.exit(main()) +(APIServer pid=24913) ^^^^^^ +(APIServer pid=24913) File "/home/mshahidul/miniconda3/envs/unsloth/lib/python3.12/site-packages/vllm/entrypoints/cli/main.py", line 73, in main +(APIServer pid=24913) args.dispatch_function(args) +(APIServer pid=24913) File "/home/mshahidul/miniconda3/envs/unsloth/lib/python3.12/site-packages/vllm/entrypoints/cli/serve.py", line 111, in cmd +(APIServer pid=24913) uvloop.run(run_server(args)) +(APIServer pid=24913) File "/home/mshahidul/miniconda3/envs/unsloth/lib/python3.12/site-packages/uvloop/__init__.py", line 96, in run +(APIServer pid=24913) return __asyncio.run( +(APIServer pid=24913) ^^^^^^^^^^^^^^ +(APIServer pid=24913) File "/home/mshahidul/miniconda3/envs/unsloth/lib/python3.12/asyncio/runners.py", line 195, in run +(APIServer pid=24913) return runner.run(main) +(APIServer pid=24913) ^^^^^^^^^^^^^^^^ +(APIServer pid=24913) File "/home/mshahidul/miniconda3/envs/unsloth/lib/python3.12/asyncio/runners.py", line 118, in run +(APIServer pid=24913) return self._loop.run_until_complete(task) +(APIServer pid=24913) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +(APIServer pid=24913) File "uvloop/loop.pyx", line 1518, in uvloop.loop.Loop.run_until_complete +(APIServer pid=24913) File "/home/mshahidul/miniconda3/envs/unsloth/lib/python3.12/site-packages/uvloop/__init__.py", line 48, in wrapper +(APIServer pid=24913) return await main +(APIServer pid=24913) ^^^^^^^^^^ +(APIServer pid=24913) File "/home/mshahidul/miniconda3/envs/unsloth/lib/python3.12/site-packages/vllm/entrypoints/openai/api_server.py", line 457, in run_server +(APIServer pid=24913) await run_server_worker(listen_address, sock, args, **uvicorn_kwargs) +(APIServer pid=24913) File 
"/home/mshahidul/miniconda3/envs/unsloth/lib/python3.12/site-packages/vllm/entrypoints/openai/api_server.py", line 476, in run_server_worker +(APIServer pid=24913) async with build_async_engine_client( +(APIServer pid=24913) ^^^^^^^^^^^^^^^^^^^^^^^^^^ +(APIServer pid=24913) File "/home/mshahidul/miniconda3/envs/unsloth/lib/python3.12/contextlib.py", line 210, in __aenter__ +(APIServer pid=24913) return await anext(self.gen) +(APIServer pid=24913) ^^^^^^^^^^^^^^^^^^^^^ +(APIServer pid=24913) File "/home/mshahidul/miniconda3/envs/unsloth/lib/python3.12/site-packages/vllm/entrypoints/openai/api_server.py", line 96, in build_async_engine_client +(APIServer pid=24913) async with build_async_engine_client_from_engine_args( +(APIServer pid=24913) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +(APIServer pid=24913) File "/home/mshahidul/miniconda3/envs/unsloth/lib/python3.12/contextlib.py", line 210, in __aenter__ +(APIServer pid=24913) return await anext(self.gen) +(APIServer pid=24913) ^^^^^^^^^^^^^^^^^^^^^ +(APIServer pid=24913) File "/home/mshahidul/miniconda3/envs/unsloth/lib/python3.12/site-packages/vllm/entrypoints/openai/api_server.py", line 137, in build_async_engine_client_from_engine_args +(APIServer pid=24913) async_llm = AsyncLLM.from_vllm_config( +(APIServer pid=24913) ^^^^^^^^^^^^^^^^^^^^^^^^^^ +(APIServer pid=24913) File "/home/mshahidul/miniconda3/envs/unsloth/lib/python3.12/site-packages/vllm/v1/engine/async_llm.py", line 222, in from_vllm_config +(APIServer pid=24913) return cls( +(APIServer pid=24913) ^^^^ +(APIServer pid=24913) File "/home/mshahidul/miniconda3/envs/unsloth/lib/python3.12/site-packages/vllm/v1/engine/async_llm.py", line 148, in __init__ +(APIServer pid=24913) self.engine_core = EngineCoreClient.make_async_mp_client( +(APIServer pid=24913) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +(APIServer pid=24913) File "/home/mshahidul/miniconda3/envs/unsloth/lib/python3.12/site-packages/vllm/v1/engine/core_client.py", line 124, in make_async_mp_client 
+(APIServer pid=24913) return AsyncMPClient(*client_args) +(APIServer pid=24913) ^^^^^^^^^^^^^^^^^^^^^^^^^^^ +(APIServer pid=24913) File "/home/mshahidul/miniconda3/envs/unsloth/lib/python3.12/site-packages/vllm/v1/engine/core_client.py", line 835, in __init__ +(APIServer pid=24913) super().__init__( +(APIServer pid=24913) File "/home/mshahidul/miniconda3/envs/unsloth/lib/python3.12/site-packages/vllm/v1/engine/core_client.py", line 490, in __init__ +(APIServer pid=24913) with launch_core_engines(vllm_config, executor_class, log_stats) as ( +(APIServer pid=24913) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +(APIServer pid=24913) File "/home/mshahidul/miniconda3/envs/unsloth/lib/python3.12/contextlib.py", line 144, in __exit__ +(APIServer pid=24913) next(self.gen) +(APIServer pid=24913) File "/home/mshahidul/miniconda3/envs/unsloth/lib/python3.12/site-packages/vllm/v1/engine/utils.py", line 925, in launch_core_engines +(APIServer pid=24913) wait_for_engine_startup( +(APIServer pid=24913) File "/home/mshahidul/miniconda3/envs/unsloth/lib/python3.12/site-packages/vllm/v1/engine/utils.py", line 984, in wait_for_engine_startup +(APIServer pid=24913) raise RuntimeError( +(APIServer pid=24913) RuntimeError: Engine core initialization failed. See root cause above. 
Failed core proc(s): {} diff --git a/code/fine_tune_sft_dpo/vllm_logs/guard.log b/code/fine_tune_sft_dpo/vllm_logs/guard.log new file mode 100644 index 0000000000000000000000000000000000000000..6fc62209a88e3dcc2db1a8c07ebf3858d36d4185 --- /dev/null +++ b/code/fine_tune_sft_dpo/vllm_logs/guard.log @@ -0,0 +1,81 @@ +Skipping import of cpp extensions due to incompatible torch version 2.9.1+cu128 for torchao version 0.16.0 Please see https://github.com/pytorch/ao/issues/2919 for more info +(APIServer pid=22412) INFO 03-14 18:23:08 [utils.py:287] +(APIServer pid=22412) INFO 03-14 18:23:08 [utils.py:287] █ █ █▄ ▄█ +(APIServer pid=22412) INFO 03-14 18:23:08 [utils.py:287] ▄▄ ▄█ █ █ █ ▀▄▀ █ version 0.16.0 +(APIServer pid=22412) INFO 03-14 18:23:08 [utils.py:287] █▄█▀ █ █ █ █ model Qwen/Qwen3Guard-Gen-8B +(APIServer pid=22412) INFO 03-14 18:23:08 [utils.py:287] ▀▀ ▀▀▀▀▀ ▀▀▀▀▀ ▀ ▀ +(APIServer pid=22412) INFO 03-14 18:23:08 [utils.py:287] +(APIServer pid=22412) INFO 03-14 18:23:08 [utils.py:223] non-default args: {'model_tag': 'Qwen/Qwen3Guard-Gen-8B', 'api_server_count': 1, 'host': '0.0.0.0', 'port': 8065, 'model': 'Qwen/Qwen3Guard-Gen-8B', 'trust_remote_code': True, 'max_model_len': 4096, 'enforce_eager': True, 'gpu_memory_utilization': 0.18, 'kv_cache_dtype': 'fp8', 'max_num_seqs': 32} +(APIServer pid=22412) The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored. +(APIServer pid=22412) The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored. +(APIServer pid=22412) INFO 03-14 18:23:08 [model.py:529] Resolved architecture: Qwen3ForCausalLM +(APIServer pid=22412) INFO 03-14 18:23:08 [model.py:1549] Using max model len 4096 +(APIServer pid=22412) INFO 03-14 18:23:08 [cache.py:214] Using fp8 data type to store kv cache. It reduces the GPU memory footprint and boosts the performance. Meanwhile, it may cause accuracy drop without a proper scaling factor. 
+(APIServer pid=22412) INFO 03-14 18:23:08 [scheduler.py:224] Chunked prefill is enabled with max_num_batched_tokens=8192. +(APIServer pid=22412) INFO 03-14 18:23:08 [vllm.py:689] Asynchronous scheduling is enabled. +(APIServer pid=22412) WARNING 03-14 18:23:08 [vllm.py:727] Enforce eager set, overriding optimization level to -O0 +(APIServer pid=22412) INFO 03-14 18:23:08 [vllm.py:845] Cudagraph is disabled under eager mode +Skipping import of cpp extensions due to incompatible torch version 2.9.1+cu128 for torchao version 0.16.0 Please see https://github.com/pytorch/ao/issues/2919 for more info +(EngineCore_DP0 pid=23313) INFO 03-14 18:23:19 [core.py:97] Initializing a V1 LLM engine (v0.16.0) with config: model='Qwen/Qwen3Guard-Gen-8B', speculative_config=None, tokenizer='Qwen/Qwen3Guard-Gen-8B', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.bfloat16, max_seq_len=4096, download_dir=None, load_format=auto, tensor_parallel_size=1, pipeline_parallel_size=1, data_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=True, enable_return_routed_experts=False, kv_cache_dtype=fp8, device_config=cuda, structured_outputs_config=StructuredOutputsConfig(backend='auto', disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_parser='', reasoning_parser_plugin='', enable_in_reasoning=False), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None, kv_cache_metrics=False, kv_cache_metrics_sample=0.01, cudagraph_metrics=False, enable_layerwise_nvtx_tracing=False, enable_mfu_metrics=False, enable_mm_processor_stats=False, enable_logging_iteration_details=False), seed=0, served_model_name=Qwen/Qwen3Guard-Gen-8B, enable_prefix_caching=True, enable_chunked_prefill=True, pooler_config=None, compilation_config={'level': None, 'mode': , 'debug_dump_path': 
None, 'cache_dir': '', 'compile_cache_save_format': 'binary', 'backend': 'inductor', 'custom_ops': ['all'], 'splitting_ops': [], 'compile_mm_encoder': False, 'compile_sizes': [], 'compile_ranges_split_points': [8192], 'inductor_compile_config': {'enable_auto_functionalized_v2': False, 'combo_kernels': True, 'benchmark_combo_kernel': True}, 'inductor_passes': {}, 'cudagraph_mode': , 'cudagraph_num_of_warmups': 0, 'cudagraph_capture_sizes': [], 'cudagraph_copy_inputs': False, 'cudagraph_specialize_lora': True, 'use_inductor_graph_partition': False, 'pass_config': {'fuse_norm_quant': False, 'fuse_act_quant': False, 'fuse_attn_quant': False, 'eliminate_noops': False, 'enable_sp': False, 'fuse_gemm_comms': False, 'fuse_allreduce_rms': False, 'fuse_act_padding': False}, 'max_cudagraph_capture_size': 0, 'dynamic_shapes_config': {'type': , 'evaluate_guards': False, 'assume_32_bit_indexing': False}, 'local_cache_dir': None, 'fast_moe_cold_start': True, 'static_all_moe_layers': []} +(EngineCore_DP0 pid=23313) INFO 03-14 18:23:21 [parallel_state.py:1234] world_size=1 rank=0 local_rank=0 distributed_init_method=tcp://172.16.34.19:52353 backend=nccl +(EngineCore_DP0 pid=23313) INFO 03-14 18:23:21 [parallel_state.py:1445] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, PCP rank 0, TP rank 0, EP rank N/A +(EngineCore_DP0 pid=23313) INFO 03-14 18:23:21 [gpu_model_runner.py:4124] Starting to load model Qwen/Qwen3Guard-Gen-8B... +(EngineCore_DP0 pid=23313) INFO 03-14 18:23:22 [cuda.py:367] Using FLASH_ATTN attention backend out of potential backends: ['FLASH_ATTN', 'FLASHINFER', 'TRITON_ATTN']. 
+(EngineCore_DP0 pid=23313) Loading safetensors checkpoint shards: 0% Completed | 0/5 [00:00 argparse.Namespace: parser.add_argument("--served_model_name", type=str, default=DEFAULT_SERVED_MODEL_NAME, help="Model name exposed by vLLM server.") parser.add_argument("--batch_size", type=int, default=64) parser.add_argument("--max_samples", type=int, default=-1, help="Use -1 for full dataset.") - parser.add_argument("--max_tokens", type=int, default=1024) + parser.add_argument("--max_tokens", type=int, default=2048) parser.add_argument("--temperature", type=float, default=0.1) parser.add_argument("--top_p", type=float, default=0.8) parser.add_argument("--api_key", type=str, default="EMPTY") diff --git a/code/readctrl_rl_inference/run_inference_vllm_server_bn_api_wo_gs.py b/code/readctrl_rl_inference/run_inference_vllm_server_bn_api_wo_gs.py new file mode 100644 index 0000000000000000000000000000000000000000..1155c7d3ce01d27ca479750abe1adcd48623c8e9 --- /dev/null +++ b/code/readctrl_rl_inference/run_inference_vllm_server_bn_api_wo_gs.py @@ -0,0 +1,385 @@ +import argparse +import json +import os +import re +import time +from concurrent.futures import ThreadPoolExecutor, as_completed +from datetime import datetime +from typing import Any, Dict, List, Optional + +import pandas as pd +import requests +from tqdm import tqdm +from transformers import AutoTokenizer + + +DEFAULT_MODEL_PATH = "/home/mshahidul/readctrl/code/RL_model/models/converted_model/bn_40" +DEFAULT_DATASET_PATH = ( + "/home/mshahidul/readctrl/code/fine_tune_sft_dpo/dataset/bn/test_bn.json" +) +DEFAULT_OUTPUT_DIR = "/home/mshahidul/readctrl/code/readctrl_rl_inference/vllm_model_result/bn_temp" +DEFAULT_BASE_URL = "http://127.0.0.1:8021/v1" +DEFAULT_SERVED_MODEL_NAME = "inference" +DEFAULT_PROMPT_LOW_PATH = ( + "/home/mshahidul/readctrl/code/readctrl_rl_inference/prompt/prompt_bn_wo_gs/prompt_low" +) +DEFAULT_PROMPT_INTERMEDIATE_PATH = ( + 
"/home/mshahidul/readctrl/code/readctrl_rl_inference/prompt/prompt_bn_wo_gs/prompt_intermediate" +) +DEFAULT_PROMPT_PROFICIENT_PATH = ( + "/home/mshahidul/readctrl/code/readctrl_rl_inference/prompt/prompt_bn_wo_gs/prompt_proficient" +) +VALID_LABELS = { + "low_health_literacy", + "intermediate_health_literacy", + "proficient_health_literacy", +} + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Run batched inference via vLLM OpenAI-compatible server.") + parser.add_argument("--model_path", type=str, default=DEFAULT_MODEL_PATH, help="Local path for tokenizer/chat template.") + parser.add_argument("--dataset_path", type=str, default=DEFAULT_DATASET_PATH) + parser.add_argument( + "--input_name", + type=str, + default=None, + help=( + "Optional short name for the input file; used in output filenames. " + "If not provided, derived from the basename of --dataset_path." + ), + ) + parser.add_argument( + "--output_name", + type=str, + default=None, + help=( + "Base name (without extension) for output files. " + "If not provided, uses vllm_inference_{model_tag}_{input_name_or_dataset}_{timestamp}." + ), + ) + parser.add_argument("--prompt-low-path", type=str, default=DEFAULT_PROMPT_LOW_PATH) + parser.add_argument("--prompt-intermediate-path", type=str, default=DEFAULT_PROMPT_INTERMEDIATE_PATH) + parser.add_argument("--prompt-proficient-path", type=str, default=DEFAULT_PROMPT_PROFICIENT_PATH) + parser.add_argument("--output_dir", type=str, default=DEFAULT_OUTPUT_DIR) + parser.add_argument("--base_url", type=str, default=DEFAULT_BASE_URL, help="vLLM OpenAI base URL, e.g. 
http://127.0.0.1:8000/v1") + parser.add_argument("--served_model_name", type=str, default=DEFAULT_SERVED_MODEL_NAME, help="Model name exposed by vLLM server.") + parser.add_argument("--batch_size", type=int, default=64) + parser.add_argument("--max_samples", type=int, default=-1, help="Use -1 for full dataset.") + parser.add_argument("--max_tokens", type=int, default=2048) + parser.add_argument("--temperature", type=float, default=0.1) + parser.add_argument("--top_p", type=float, default=0.8) + parser.add_argument("--api_key", type=str, default="EMPTY") + parser.add_argument("--timeout_sec", type=int, default=300) + parser.add_argument("--num_workers", type=int, default=4, help="Concurrent request threads to keep server pipeline full.") + return parser.parse_args() + + +def load_prompt_templates(args: argparse.Namespace) -> Dict[str, str]: + prompt_path_by_label = { + "low_health_literacy": args.prompt_low_path, + "intermediate_health_literacy": args.prompt_intermediate_path, + "proficient_health_literacy": args.prompt_proficient_path, + } + templates: Dict[str, str] = {} + for label, path in prompt_path_by_label.items(): + if not os.path.exists(path): + raise FileNotFoundError(f"Prompt file not found: {path}") + with open(path, "r", encoding="utf-8") as f: + templates[label] = f.read() + return templates + + +def load_verified_rows(path: str) -> List[Dict[str, Any]]: + if not os.path.exists(path): + raise FileNotFoundError(f"Input file not found: {path}") + with open(path, "r", encoding="utf-8") as f: + parsed = json.load(f) + if not isinstance(parsed, list): + raise ValueError(f"Expected top-level JSON array in {path}") + return [row for row in parsed if isinstance(row, dict)] + + +def infer_source_lang(fulltext: str) -> str: + if fulltext and any("a" <= ch.lower() <= "z" for ch in fulltext): + return "English" + return "Unknown" + + +def build_prompt(template: str, fulltext: str, source_lang: str) -> str: + return ( + template.replace("{source_lang}", 
source_lang) + .replace("{full_text}", fulltext) + ) + + +def _clean_json_block(text: str) -> str: + cleaned = text.strip() + if "```json" in cleaned: + cleaned = cleaned.split("```json", 1)[1].split("```", 1)[0].strip() + elif "```" in cleaned: + cleaned = cleaned.split("```", 1)[1].split("```", 1)[0].strip() + return cleaned + + +def extract_generated_text(raw_response: str, expected_label: str) -> str: + cleaned = _clean_json_block(raw_response) + try: + parsed = json.loads(cleaned) + except json.JSONDecodeError: + return raw_response.strip() + + if isinstance(parsed, dict): + value = parsed.get(expected_label) + if isinstance(value, str) and value.strip(): + return value.strip() + return raw_response.strip() + + +def _normalize_messages(prompt_obj: Any) -> List[Dict[str, str]]: + if hasattr(prompt_obj, "tolist"): + prompt_obj = prompt_obj.tolist() + + if isinstance(prompt_obj, dict): + if "role" in prompt_obj and "content" in prompt_obj: + return [{"role": str(prompt_obj["role"]), "content": str(prompt_obj["content"])}] + return [{"role": "user", "content": json.dumps(prompt_obj, ensure_ascii=False)}] + + if isinstance(prompt_obj, list): + messages = [] + for item in prompt_obj: + if isinstance(item, dict) and "role" in item and "content" in item: + messages.append({"role": str(item["role"]), "content": str(item["content"])}) + else: + messages.append({"role": "user", "content": str(item)}) + if messages: + return messages + + return [{"role": "user", "content": str(prompt_obj)}] + + +def build_prompt_text(tokenizer: AutoTokenizer, prompt_obj: Any) -> str: + messages = _normalize_messages(prompt_obj) + if tokenizer.chat_template: + return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + return "\n".join(m["content"] for m in messages) + "\n\nAssistant:" + + +def sanitize_model_tag(model_path: str, max_len: int = 80) -> str: + tag = re.sub(r"[^A-Za-z0-9]+", "-", model_path).strip("-").lower() + if not tag: + return 
"unknown-model" + if len(tag) > max_len: + return tag[:max_len].rstrip("-") + return tag + + +def check_server(base_url: str, headers: Dict[str, str], timeout_sec: int) -> Optional[List[Dict[str, Any]]]: + models_url = f"{base_url.rstrip('/')}/models" + resp = requests.get(models_url, headers=headers, timeout=timeout_sec) + resp.raise_for_status() + payload = resp.json() + return payload.get("data", []) + + +def batched_completion_request( + base_url: str, + headers: Dict[str, str], + model_name: str, + prompts: List[str], + max_tokens: int, + temperature: float, + top_p: float, + timeout_sec: int, +) -> List[str]: + payload = { + "model": model_name, + "prompt": prompts, + "max_tokens": max_tokens, + "temperature": temperature, + "top_p": top_p, + } + url = f"{base_url.rstrip('/')}/completions" + resp = requests.post(url, headers=headers, json=payload, timeout=timeout_sec) + resp.raise_for_status() + data = resp.json() + choices = data.get("choices", []) + + preds = [""] * len(prompts) + for choice in choices: + idx = choice.get("index", None) + text = str(choice.get("text", "")).strip() + if isinstance(idx, int) and 0 <= idx < len(preds) and not preds[idx]: + preds[idx] = text + + if any(not p for p in preds): + fallback_texts = [str(c.get("text", "")).strip() for c in choices] + for i in range(len(preds)): + if not preds[i]: + preds[i] = fallback_texts[i] if i < len(fallback_texts) else "" + + return preds + + +def main() -> None: + args = parse_args() + os.makedirs(args.output_dir, exist_ok=True) + + run_ts = datetime.now().strftime("%Y%m%d_%H%M%S") + model_tag = sanitize_model_tag(args.model_path) + + input_tag_raw = ( + args.input_name + if args.input_name + else os.path.splitext(os.path.basename(args.dataset_path))[0] + ) + input_tag = sanitize_model_tag(input_tag_raw) + default_base = f"vllm_inference_{model_tag}_{input_tag}_{run_ts}" + base_name = args.output_name if args.output_name else default_base + output_jsonl = os.path.join(args.output_dir, 
f"{base_name}.jsonl") + meta_path = os.path.join(args.output_dir, f"{base_name}_meta.json") + + headers = { + "Authorization": f"Bearer {args.api_key}", + "Content-Type": "application/json", + } + + print(f"[INFO] Checking vLLM server: {args.base_url}") + models = check_server(args.base_url, headers=headers, timeout_sec=args.timeout_sec) + available_model_ids = [m.get("id", "") for m in models or []] + print(f"[INFO] Server models: {available_model_ids}") + if args.served_model_name not in available_model_ids: + print( + f"[WARN] Served model '{args.served_model_name}' not found in /models. " + "Will still try requests with provided name." + ) + + print(f"[INFO] Loading tokenizer from: {args.model_path}") + tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True) + + print(f"[INFO] Reading dataset: {args.dataset_path}") + templates = load_prompt_templates(args) + rows = load_verified_rows(args.dataset_path) + parsed_items: List[Dict[str, Any]] = [] + for idx, row in enumerate(rows): + gold_label = str(row.get("label", "")).strip() + fulltext = str(row.get("fulltext", "")).strip() + if gold_label not in VALID_LABELS: + continue + if not fulltext: + continue + source_lang = "Bengali" + prompt = build_prompt( + template=templates[gold_label], + fulltext=fulltext, + source_lang=source_lang, + ) + # import ipdb; ipdb.set_trace() + parsed_items.append( + { + "row_index": idx, + "doc_id": row.get("doc_id"), + "gold_label": gold_label, + "source_lang": source_lang, + "input_text": fulltext, + "prompt": prompt, + } + ) + + df = pd.DataFrame(parsed_items) + if args.max_samples > 0: + df = df.head(args.max_samples) + print(f"[INFO] Rows to process: {len(df)}") + if df.empty: + raise RuntimeError("No valid rows found in input file.") + + batch_ranges = list(range(0, len(df), args.batch_size)) + total_batches = len(batch_ranges) + num_workers = min(args.num_workers, total_batches) + print(f"[INFO] {total_batches} batches × {args.batch_size} prompts, 
{num_workers} concurrent workers") + + t0 = time.time() + + def _process_batch(start: int) -> List[Dict[str, Any]]: + batch_df = df.iloc[start : start + args.batch_size] + prompts = [build_prompt_text(tokenizer, row.get("prompt", "")) for _, row in batch_df.iterrows()] + preds = batched_completion_request( + base_url=args.base_url, + headers=headers, + model_name=args.served_model_name, + prompts=prompts, + max_tokens=args.max_tokens, + temperature=args.temperature, + top_p=args.top_p, + timeout_sec=args.timeout_sec, + ) + records = [] + for (row_idx, row), pred in zip(batch_df.iterrows(), preds): + gold_label = str(row.get("gold_label", "")) + records.append( + { + "row_index": int(row.get("row_index", row_idx)), + "doc_id": row.get("doc_id"), + "gold_label": gold_label, + "source_lang": row.get("source_lang"), + "input_text": row.get("input_text", ""), + "prediction": pred, + "generated_text": extract_generated_text(pred, gold_label) + if gold_label + else pred.strip(), + } + ) + return records + + pending_results: Dict[int, List[Dict[str, Any]]] = {} + next_write_idx = 0 + outputs: List[Dict[str, Any]] = [] + + with open(output_jsonl, "w", encoding="utf-8") as f_out: + with ThreadPoolExecutor(max_workers=num_workers) as executor: + future_to_idx = { + executor.submit(_process_batch, batch_ranges[i]): i + for i in range(total_batches) + } + pbar = tqdm(total=total_batches, desc="Batches") + for future in as_completed(future_to_idx): + batch_idx = future_to_idx[future] + records = future.result() + pending_results[batch_idx] = records + pbar.update(1) + + while next_write_idx in pending_results: + for rec in pending_results.pop(next_write_idx): + outputs.append(rec) + f_out.write(json.dumps(rec, ensure_ascii=False) + "\n") + next_write_idx += 1 + pbar.close() + + elapsed = time.time() - t0 + print(f"[INFO] Inference done: {len(outputs)} samples in {elapsed:.1f}s ({len(outputs)/elapsed:.1f} samples/s)") + + with open(meta_path, "w", encoding="utf-8") as f_meta: + 
json.dump( + { + "model_path_for_tokenizer": args.model_path, + "dataset_path": args.dataset_path, + "input_name": input_tag, + "output_name": base_name, + "base_url": args.base_url, + "served_model_name": args.served_model_name, + "batch_size": args.batch_size, + "num_samples": len(outputs), + "output_jsonl": output_jsonl, + }, + f_meta, + ensure_ascii=False, + indent=2, + ) + + print("[DONE] vLLM batch inference complete.") + print(f"[DONE] JSONL: {output_jsonl}") + print(f"[DONE] Meta: {meta_path}") + + +if __name__ == "__main__": + main() diff --git a/code/readctrl_rl_inference/vllm_model_result/bn_temp/bn_200.jsonl b/code/readctrl_rl_inference/vllm_model_result/bn_temp/bn_200.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..d9f25994df5aa59719be9ebb393a953ad836fabd --- /dev/null +++ b/code/readctrl_rl_inference/vllm_model_result/bn_temp/bn_200.jsonl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:da74dcc0cbc71e0d81e20da4918c36fe705558deeb9720a4409a4d73ec4674d2 +size 3870275 diff --git a/code/readctrl_rl_inference/vllm_model_result/bn_temp/bn_200_meta.json b/code/readctrl_rl_inference/vllm_model_result/bn_temp/bn_200_meta.json new file mode 100644 index 0000000000000000000000000000000000000000..7ea433ebcad7ac3118823e6250782d3add8b72fd --- /dev/null +++ b/code/readctrl_rl_inference/vllm_model_result/bn_temp/bn_200_meta.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f4aed7cfc40d1c5bbf14a128a496520902e770a1c849226371659fa45b0d1977 +size 481 diff --git a/code/subclaim_support_extraction/extract_bn_subclaims_from_test_bn.py b/code/subclaim_support_extraction/extract_bn_subclaims_from_test_bn.py new file mode 100644 index 0000000000000000000000000000000000000000..cf192387be3c2c48f36712eff261d2cf8cf6efec --- /dev/null +++ b/code/subclaim_support_extraction/extract_bn_subclaims_from_test_bn.py @@ -0,0 +1,462 @@ +#!/usr/bin/env python3 +""" +Extract Bangla subclaims for the BN test set using 
the subclaim-extractor vLLM server. + +Input: + - Single JSON file (list of objects) with at least: + - "fulltext": Bangla full article text + - "summary": Bangla summary text + +Output: + - Single JSON file with the original fields plus: + - "fulltext_subclaims": list[str] + - "summary_subclaims": list[str] + +Defaults are wired to: + - input_file: /home/mshahidul/readctrl/code/fine_tune_sft_dpo/dataset/bn/test_bn.json + - save_path: /home/mshahidul/readctrl/code/fine_tune_sft_dpo/dataset/bn/test_bn_subclaims.json + +cd /home/mshahidul/readctrl/code/subclaim_support_extraction + +python extract_bn_subclaims_from_test_bn.py \ + --input_file /home/mshahidul/readctrl/code/fine_tune_sft_dpo/dataset/bn/test_bn.json \ + --save_path /home/mshahidul/readctrl/code/fine_tune_sft_dpo/dataset/bn/test_bn_subclaims.json \ + --port 8055 \ + --model subclaim-extractor +""" + +import argparse +import json +import os + +from extract_bn_subclaims_vllm_v2_mod import ( + DEFAULT_API_URL, + DEFAULT_MODEL_NAME, + infer_subclaims_api, +) + + +def load_items(input_file: str): + with open(input_file, "r", encoding="utf-8") as f: + data = json.load(f) + if not isinstance(data, list): + data = [data] + return data + + +def main(): + parser = argparse.ArgumentParser( + description="Extract Bangla subclaims for BN test set (fields: fulltext, summary)" + ) + parser.add_argument( + "--input_file", + type=str, + default="/home/mshahidul/readctrl/code/fine_tune_sft_dpo/dataset/bn/test_bn.json", + help="Input JSON file (list of items with 'fulltext' and 'summary')", + ) + parser.add_argument( + "--save_path", + type=str, + default="/home/mshahidul/readctrl/code/fine_tune_sft_dpo/dataset/bn/test_bn_subclaims.json", + help="Output JSON file path", + ) + parser.add_argument( + "--api_url", + type=str, + default=DEFAULT_API_URL, + help="vLLM OpenAI-compatible API base URL (default: http://localhost:8050/v1)", + ) + parser.add_argument( + "--port", + type=int, + default=None, + help="Server port (e.g. 
8050). Builds API URL as http://localhost:PORT/v1 (overrides --api_url if set)", + ) + parser.add_argument( + "--model", + type=str, + default=DEFAULT_MODEL_NAME, + help="Served model name (default: subclaim-extractor)", + ) + parser.add_argument("--start", type=int, default=0, help="Start index") + parser.add_argument("--end", type=int, default=None, help="End index (exclusive)") + parser.add_argument( + "--resume", + type=str, + default=None, + help="Existing output JSON to resume from (append by id/doc_id)", + ) + args = parser.parse_args() + + if args.port is not None: + args.api_url = f"http://localhost:{args.port}/v1" + print(f"Using API URL: {args.api_url}") + + os.makedirs(os.path.dirname(args.save_path), exist_ok=True) + + all_items = load_items(args.input_file) + end = args.end if args.end is not None else len(all_items) + subset = all_items[args.start : end] + print(f"Loaded {len(all_items)} items from {args.input_file}") + print(f"Processing indices [{args.start}:{end}], total items in this run: {len(subset)}") + + processed_by_id = {} + output_file = args.save_path + if args.resume and os.path.isfile(args.resume): + output_file = args.resume + with open(args.resume, "r", encoding="utf-8") as f: + existing = json.load(f) + for it in existing: + uid = it.get("doc_id") or it.get("id") + if uid is None: + continue + processed_by_id[uid] = it + print(f"Resumed from {args.resume}: {len(processed_by_id)} existing entries") + + try: + import tqdm + + iterator = tqdm.tqdm(subset, desc="Extracting subclaims (test_bn)") + except ImportError: + iterator = subset + + for item in iterator: + uid = item.get("doc_id") or item.get("id") + if uid is None: + # fallback: index-based id to keep uniqueness + uid = f"index_{all_items.index(item)}" + + if uid in processed_by_id: + continue + + fulltext = (item.get("fulltext") or "").strip() + summary = (item.get("summary") or "").strip() + + if not fulltext and not summary: + enriched = dict(item) + 
enriched["fulltext_subclaims"] = [] + enriched["summary_subclaims"] = [] + processed_by_id[uid] = enriched + continue + + fulltext_subclaims = infer_subclaims_api( + fulltext, + is_summary=False, + max_tokens=4096, + base_url=args.api_url, + model_name=args.model, + ) + summary_subclaims = infer_subclaims_api( + summary, + is_summary=True, + max_tokens=2048, + base_url=args.api_url, + model_name=args.model, + ) + + enriched = dict(item) + enriched["fulltext_subclaims"] = fulltext_subclaims + enriched["summary_subclaims"] = summary_subclaims + processed_by_id[uid] = enriched + + # Checkpoint every 20 items + if len(processed_by_id) % 20 == 0: + with open(output_file, "w", encoding="utf-8") as f: + json.dump( + list(processed_by_id.values()), + f, + indent=2, + ensure_ascii=False, + ) + + with open(output_file, "w", encoding="utf-8") as f: + json.dump( + list(processed_by_id.values()), + f, + indent=2, + ensure_ascii=False, + ) + print(f"Saved {len(processed_by_id)} entries to {output_file}") + + +if __name__ == "__main__": + main() + +#!/usr/bin/env python3 +""" +Extract Bangla subclaims from test_bn.json (fulltext and summary fields) +using the subclaim-extractor vLLM server. + +Uses async batch processing for concurrent requests against the vLLM server. + +Input: test_bn.json with fields: doc_id, label, gen_text, fulltext, summary, ... +Output: JSON with all fields except predicted_label and prediction_correct, + plus fulltext_subclaims and summary_subclaims. 
+""" + +import os +import json +import asyncio +import argparse +import aiohttp + +DEFAULT_API_URL = "http://localhost:8055/v1" +DEFAULT_MODEL_NAME = "subclaim-extractor" + +MAX_SUBCLAIMS_FULLTEXT = 80 +MAX_SUBCLAIMS_SUMMARY = 40 + +EXCLUDE_KEYS = {"predicted_label", "prediction_correct"} + + +def extraction_prompt( + medical_text: str, + is_summary: bool = False, + max_subclaims: int = None, +) -> str: + source_type = "summary" if is_summary else "full medical text" + limit = max_subclaims if max_subclaims is not None else ( + MAX_SUBCLAIMS_SUMMARY if is_summary else MAX_SUBCLAIMS_FULLTEXT + ) + return f""" +You are an expert medical annotator. The following text is in Bangla (Bengali). + +Your task is to extract granular, factual subclaims from the provided {source_type}. +A subclaim is the smallest standalone factual unit that can be independently verified. + +IMPORTANT: Extract at most {limit} subclaims. Prioritize the most important factual statements. If the text contains more, list only the first {limit} and stop. + +Instructions: +1. Read the Bangla medical text carefully. +2. Extract factual statements explicitly stated in the text (at most {limit}). +3. Each subclaim must: + - Be in Bangla (same language as the input) + - Contain exactly ONE factual assertion + - Come directly from the text (no inference or interpretation) + - Preserve original wording as much as possible + - Include any negation, uncertainty, or qualifier +4. Do NOT: + - Combine multiple facts into one subclaim + - Add new information + - Translate to another language + - Exceed {limit} subclaims +5. Return ONLY a valid JSON array of strings. +6. Use double quotes and valid JSON formatting only (no markdown, no commentary). 
+ +Medical Text (Bangla): +{medical_text} + +Return format: +[ + "subclaim 1", + "subclaim 2" +] +""".strip() + + +def parse_subclaims(output_text: str) -> list: + if "" in output_text: + output_text = output_text.split("")[-1].strip() + + start_idx = output_text.find("[") + end_idx = output_text.rfind("]") + 1 + if start_idx != -1 and end_idx > start_idx: + content = output_text[start_idx:end_idx] + parsed = json.loads(content) + if isinstance(parsed, list): + return [str(s).strip() for s in parsed if s] + + raise ValueError("Incomplete or invalid JSON list") + + +async def infer_subclaims_async( + session: aiohttp.ClientSession, + medical_text: str, + is_summary: bool = False, + temperature: float = 0.2, + max_tokens: int = 4196, + max_subclaims: int = None, + retries: int = 2, + api_url: str = DEFAULT_API_URL, + model_name: str = DEFAULT_MODEL_NAME, +) -> list: + if not medical_text or not medical_text.strip(): + return [] + + prompt = extraction_prompt( + medical_text, is_summary=is_summary, max_subclaims=max_subclaims + ) + url = f"{api_url}/chat/completions" + + for attempt in range(retries + 1): + try: + payload = { + "model": model_name, + "messages": [{"role": "user", "content": prompt}], + "temperature": temperature, + "max_tokens": max_tokens, + } + async with session.post(url, json=payload) as resp: + resp.raise_for_status() + data = await resp.json() + output_text = data["choices"][0]["message"]["content"].strip() + return parse_subclaims(output_text) + except Exception as e: + if attempt < retries: + max_tokens = max_tokens + 1024 + print(f" [Warning] {e}. 
Retry with max_tokens={max_tokens}") + continue + print(f" [Error] Failed after retries: {e}") + return [] + + return [] + + +async def process_one_item( + sem: asyncio.Semaphore, + session: aiohttp.ClientSession, + item: dict, + api_url: str, + model_name: str, +) -> dict: + async with sem: + fulltext = item.get("fulltext", "") + summary = item.get("summary", "") + + fulltext_task = infer_subclaims_async( + session, fulltext, + is_summary=False, max_tokens=4096, + api_url=api_url, model_name=model_name, + ) + summary_task = infer_subclaims_async( + session, summary, + is_summary=True, max_tokens=2048, + api_url=api_url, model_name=model_name, + ) + + fulltext_subclaims, summary_subclaims = await asyncio.gather( + fulltext_task, summary_task + ) + + result = {k: v for k, v in item.items() if k not in EXCLUDE_KEYS} + result["fulltext_subclaims"] = fulltext_subclaims + result["summary_subclaims"] = summary_subclaims + return result + + +def save_checkpoint(processed_by_id: dict, output_file: str): + with open(output_file, "w", encoding="utf-8") as f: + json.dump( + list(processed_by_id.values()), + f, indent=2, ensure_ascii=False, + ) + + +async def run(args): + if args.port is not None: + args.api_url = f"http://localhost:{args.port}/v1" + print(f"Using API URL: {args.api_url}") + + with open(args.input_file, "r", encoding="utf-8") as f: + all_items = json.load(f) + print(f"Loaded {len(all_items)} items from {args.input_file}") + + end = args.end if args.end is not None else len(all_items) + subset = all_items[args.start:end] + print(f"Processing indices [{args.start}:{end}], total items: {len(subset)}") + + processed_by_id = {} + if args.resume and os.path.isfile(args.output_file): + with open(args.output_file, "r", encoding="utf-8") as f: + existing = json.load(f) + for item in existing: + processed_by_id[item["doc_id"]] = item + print(f"Resumed: {len(processed_by_id)} existing entries from {args.output_file}") + + to_process = [item for item in subset if 
item["doc_id"] not in processed_by_id] + print(f"Items to process: {len(to_process)} (skipping {len(subset) - len(to_process)} already done)") + + if not to_process: + print("Nothing to process.") + return + + sem = asyncio.Semaphore(args.batch_size) + timeout = aiohttp.ClientTimeout(total=300) + + async with aiohttp.ClientSession( + timeout=timeout, + headers={"Authorization": "Bearer EMPTY"}, + ) as session: + batch_count = 0 + total = len(to_process) + + for batch_start in range(0, total, args.batch_size): + batch = to_process[batch_start : batch_start + args.batch_size] + tasks = [ + process_one_item(sem, session, item, args.api_url, args.model) + for item in batch + ] + results = await asyncio.gather(*tasks, return_exceptions=True) + + for item, result in zip(batch, results): + if isinstance(result, Exception): + print(f" [Error] doc_id={item['doc_id']}: {result}") + continue + processed_by_id[result["doc_id"]] = result + + batch_count += len(batch) + print(f"Progress: {batch_count}/{total} items done") + + if batch_count % args.checkpoint_every == 0 or batch_count == total: + save_checkpoint(processed_by_id, args.output_file) + + save_checkpoint(processed_by_id, args.output_file) + print(f"Saved {len(processed_by_id)} entries to {args.output_file}") + + +def main(): + parser = argparse.ArgumentParser( + description="Extract Bangla subclaims from test_bn.json (fulltext & summary) with batch processing" + ) + parser.add_argument( + "--input_file", type=str, + default="/home/mshahidul/readctrl/code/fine_tune_sft_dpo/dataset/bn/test_bn.json", + help="Input JSON file", + ) + parser.add_argument( + "--output_file", type=str, + default="/home/mshahidul/readctrl/code/fine_tune_sft_dpo/dataset/bn/test_bn_subclaims.json", + help="Output JSON file", + ) + parser.add_argument( + "--api_url", type=str, default=DEFAULT_API_URL, + help="vLLM OpenAI-compatible API base URL", + ) + parser.add_argument( + "--port", type=int, default=None, + help="Server port (builds URL as 
http://localhost:PORT/v1, overrides --api_url)", + ) + parser.add_argument( + "--model", type=str, default=DEFAULT_MODEL_NAME, + help="Served model name (default: subclaim-extractor)", + ) + parser.add_argument("--start", type=int, default=0, help="Start index") + parser.add_argument("--end", type=int, default=None, help="End index (exclusive)") + parser.add_argument( + "--resume", action="store_true", + help="Resume from existing output file (skip already processed doc_ids)", + ) + parser.add_argument( + "--batch_size", type=int, default=16, + help="Number of concurrent requests (default: 16)", + ) + parser.add_argument( + "--checkpoint_every", type=int, default=20, + help="Save checkpoint every N items (default: 20)", + ) + args = parser.parse_args() + asyncio.run(run(args)) + + +if __name__ == "__main__": + main() diff --git a/code/subclaim_support_extraction/extract_bn_subclaims_vllm_v2_mod.py b/code/subclaim_support_extraction/extract_bn_subclaims_vllm_v2_mod.py new file mode 100644 index 0000000000000000000000000000000000000000..f05c3fe818e8cb8b2c3cc4509b790769efb54159 --- /dev/null +++ b/code/subclaim_support_extraction/extract_bn_subclaims_vllm_v2_mod.py @@ -0,0 +1,328 @@ +#!/usr/bin/env python3 +""" +Extract Bangla subclaims from translated MultiClinSum files using the +subclaim-extractor vLLM server (google/gemma-3-27b-it on port 8050). + +- Input: JSON files in translation_testing_3396 (attrs: translated_fulltext, translated_summary) +- Output: Save to extracting_subclaim/bn without fulltext/summary. 
+""" + +import os +import json +import glob +import argparse +from openai import OpenAI + +# ----------------------------- +# API CONFIGURATION (subclaim-extractor vLLM server) +# ----------------------------- +DEFAULT_API_URL = "http://localhost:8050/v1" +DEFAULT_MODEL_NAME = "subclaim-extractor" + +client = None + + +def get_client(base_url: str = None, api_key: str = "EMPTY"): + global client + if client is None: + client = OpenAI(base_url=base_url or DEFAULT_API_URL, api_key=api_key) + return client + + +# ----------------------------- +# SUBCLAIM EXTRACTION PROMPT (Bangla) +# ----------------------------- +# Max subclaims to request (keeps output within max_tokens) +MAX_SUBCLAIMS_FULLTEXT = 80 +MAX_SUBCLAIMS_SUMMARY = 40 + + +def extraction_prompt( + medical_text: str, + is_summary: bool = False, + max_subclaims: int = None, +) -> str: + source_type = "summary" if is_summary else "full medical text" + limit = max_subclaims if max_subclaims is not None else ( + MAX_SUBCLAIMS_SUMMARY if is_summary else MAX_SUBCLAIMS_FULLTEXT + ) + return f""" +You are an expert medical annotator. The following text is in Bangla (Bengali). + +Your task is to extract granular, factual subclaims from the provided {source_type}. +A subclaim is the smallest standalone factual unit that can be independently verified. + +IMPORTANT: Extract at most {limit} subclaims. Prioritize the most important factual statements. If the text contains more, list only the first {limit} and stop. + +Instructions: +1. Read the Bangla medical text carefully. +2. Extract factual statements explicitly stated in the text (at most {limit}). +3. Each subclaim must: + - Be in Bangla (same language as the input) + - Contain exactly ONE factual assertion + - Come directly from the text (no inference or interpretation) + - Preserve original wording as much as possible + - Include any negation, uncertainty, or qualifier +4. 
Do NOT: + - Combine multiple facts into one subclaim + - Add new information + - Translate to another language + - Exceed {limit} subclaims +5. Return ONLY a valid JSON array of strings. +6. Use double quotes and valid JSON formatting only (no markdown, no commentary). + +Medical Text (Bangla): +{medical_text} + +Return format: +[ + "subclaim 1", + "subclaim 2" +] +""".strip() + + +def infer_subclaims_api( + medical_text: str, + is_summary: bool = False, + temperature: float = 0.2, + max_tokens: int = 2048, + max_subclaims: int = None, + retries: int = 2, + base_url: str = None, + model_name: str = None, +) -> list: + if not medical_text or not medical_text.strip(): + return [] + + prompt = extraction_prompt( + medical_text, is_summary=is_summary, max_subclaims=max_subclaims + ) + c = get_client(base_url=base_url) + model = model_name or DEFAULT_MODEL_NAME + + for attempt in range(retries + 1): + try: + response = c.chat.completions.create( + model=model, + messages=[{"role": "user", "content": prompt}], + temperature=temperature, + max_tokens=max_tokens, + ) + output_text = response.choices[0].message.content.strip() + + + if "" in output_text: + output_text = output_text.split("")[-1].strip() + + start_idx = output_text.find("[") + end_idx = output_text.rfind("]") + 1 + if start_idx != -1 and end_idx > start_idx: + content = output_text[start_idx:end_idx] + parsed = json.loads(content) + if isinstance(parsed, list): + return [str(s).strip() for s in parsed if s] + + raise ValueError("Incomplete or invalid JSON list") + except (json.JSONDecodeError, ValueError, Exception) as e: + if attempt < retries: + max_tokens = max_tokens + 1024 + print(f" [Warning] {e}. 
Retry with max_tokens={max_tokens}") + continue + print(f" [Error] Failed after retries: {e}") + return [] + + return [] + + +def _has_null_translation(item: dict) -> bool: + """True if translated_fulltext or translated_summary is None (ignore such instances).""" + return item.get("translated_fulltext") is None or item.get("translated_summary") is None + + +def load_from_single_file(input_path: str) -> list: + """Load items from a single JSON file (list or single object). Ignore instances with null translations.""" + with open(input_path, "r", encoding="utf-8") as f: + data = json.load(f) + if not isinstance(data, list): + data = [data] + return [item for item in data if not _has_null_translation(item)] + + +def load_all_translation_items(input_dir: str) -> list: + """Load and merge all JSON arrays from translation_testing_3396. Ignore instances with null translations.""" + pattern = os.path.join(input_dir, "*.json") + files = sorted(glob.glob(pattern)) + if not files: + raise FileNotFoundError(f"No JSON files in {input_dir}") + all_items = [] + seen_ids = set() + for path in files: + with open(path, "r", encoding="utf-8") as f: + data = json.load(f) + if not isinstance(data, list): + data = [data] + for item in data: + if _has_null_translation(item): + continue + uid = item.get("id") + if uid in seen_ids: + continue + seen_ids.add(uid) + all_items.append(item) + return all_items + + +def main(): + parser = argparse.ArgumentParser(description="Extract Bangla subclaims via subclaim-extractor vLLM") + parser.add_argument( + "--input_dir", + type=str, + default="/home/mshahidul/readctrl/data/translated_data/translation_testing_3396", + help="Directory containing translated JSON files (used when --input_file is not set)", + ) + parser.add_argument( + "--input_file", + type=str, + default=None, + help="Single JSON file to process (overrides --input_dir)", + ) + parser.add_argument( + "--save_dir", + type=str, + 
default="/home/mshahidul/readctrl/data/extracting_subclaim/bn", + help="Directory to save output JSON files", + ) + parser.add_argument( + "--api_url", + type=str, + default=DEFAULT_API_URL, + help="vLLM OpenAI-compatible API base URL (default: http://localhost:8050/v1)", + ) + parser.add_argument( + "--port", + type=int, + default=None, + help="Server port (e.g. 8050). Builds API URL as http://localhost:PORT/v1 (overrides --api_url if set)", + ) + parser.add_argument( + "--model", + type=str, + default=DEFAULT_MODEL_NAME, + help="Served model name (default: subclaim-extractor)", + ) + parser.add_argument("--start", type=int, default=0, help="Start index") + parser.add_argument("--end", type=int, default=None, help="End index (exclusive)") + parser.add_argument( + "--resume", + type=str, + default=None, + help="Path to existing output JSON to resume (append new items by id)", + ) + args = parser.parse_args() + + if args.port is not None: + args.api_url = f"http://localhost:{args.port}/v1" + print(f"Using API URL: {args.api_url}") + + os.makedirs(args.save_dir, exist_ok=True) + + if args.input_file: + if not os.path.isfile(args.input_file): + raise FileNotFoundError(f"Input file not found: {args.input_file}") + all_items = load_from_single_file(args.input_file) + print(f"Loaded {len(all_items)} items from {args.input_file}") + else: + all_items = load_all_translation_items(args.input_dir) + end = args.end if args.end is not None else len(all_items) + subset = all_items[args.start : end] + print(f"Processing indices [{args.start}:{end}], total items: {len(subset)}") + + # Resume: load existing by id + processed_by_id = {} + if args.resume and os.path.isfile(args.resume): + with open(args.resume, "r", encoding="utf-8") as f: + existing = json.load(f) + for item in existing: + processed_by_id[item["id"]] = item + print(f"Resumed: {len(processed_by_id)} existing entries from {args.resume}") + + # Single output file for this run (resume appends into same structure) + 
output_file = os.path.join( + args.save_dir, + f"extracted_subclaims_bn_{args.start}_{end if end != len(all_items) else 'end'}.json", + ) + if args.resume: + output_file = args.resume + + try: + import tqdm + iterator = tqdm.tqdm(subset, desc="Extracting subclaims") + except ImportError: + iterator = subset + + for item in iterator: + uid = item.get("id") + if uid in processed_by_id: + continue + + translated_fulltext = item.get("translated_fulltext") or "" + translated_summary = item.get("translated_summary") or "" + + # Skip if both missing + if not translated_fulltext.strip() and not translated_summary.strip(): + processed_by_id[uid] = { + "id": uid, + "translated_fulltext": translated_fulltext, + "translated_summary": translated_summary, + "fulltext_subclaims": [], + "summary_subclaims": [], + } + continue + + fulltext_subclaims = infer_subclaims_api( + translated_fulltext, + is_summary=False, + max_tokens=4096, + base_url=args.api_url, + model_name=args.model, + ) + summary_subclaims = infer_subclaims_api( + translated_summary, + is_summary=True, + max_tokens=2048, + base_url=args.api_url, + model_name=args.model, + ) + + # Save only requested fields; no fulltext, no summary + processed_by_id[uid] = { + "id": uid, + "translated_fulltext": translated_fulltext, + "translated_summary": translated_summary, + "fulltext_subclaims": fulltext_subclaims, + "summary_subclaims": summary_subclaims, + } + + # Checkpoint every 20 items + if len(processed_by_id) % 20 == 0: + with open(output_file, "w", encoding="utf-8") as f: + json.dump( + list(processed_by_id.values()), + f, + indent=2, + ensure_ascii=False, + ) + + with open(output_file, "w", encoding="utf-8") as f: + json.dump( + list(processed_by_id.values()), + f, + indent=2, + ensure_ascii=False, + ) + print(f"Saved {len(processed_by_id)} entries to {output_file}") + + +if __name__ == "__main__": + main() diff --git a/code/subclaim_support_extraction/s.sh b/code/subclaim_support_extraction/s.sh new file mode 100644 
index 0000000000000000000000000000000000000000..3ed7a484b639e7eb08e22d6642297ccd8d4ed040
--- /dev/null
+++ b/code/subclaim_support_extraction/s.sh
@@ -0,0 +1,9 @@
# Serve gemma-3-27b-it as the subclaim extractor.
# NOTE(review): this serves on port 8055, while the extractor scripts default
# to http://localhost:8050/v1 -- pass --port 8055 to them or align the ports.
CUDA_DEVICE_ORDER="PCI_BUS_ID" CUDA_VISIBLE_DEVICES=5 vllm serve google/gemma-3-27b-it \
    --gpu-memory-utilization 0.95 \
    --max-model-len 16384 \
    --enable-prefix-caching \
    --kv-cache-dtype fp8 \
    --max-num-batched-tokens 32768 \
    --trust-remote-code \
    --port 8055 \
    --served-model-name subclaim-extractor
diff --git a/push_to_hf.sh b/push_to_hf.sh
new file mode 100644
index 0000000000000000000000000000000000000000..efcddf4e5a49b4d7dcdbc69437f577965c8956bf
--- /dev/null
+++ b/push_to_hf.sh
@@ -0,0 +1,32 @@
#!/usr/bin/env bash

# Simple helper script to push this repo to Hugging Face.
# Assumes:
#   - You are in the repo root (same directory as this script), or
#   - `origin` remote is already set to https://huggingface.co/shahidul034/readCtrl_lambda
#   - You have configured your Hugging Face token for git authentication.

set -euo pipefail

REPO_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
cd "$REPO_DIR"

# Use first argument as commit message, or a default one.
# BUGFIX: the original default was "${1:-\"Update readCtrl repo\"}", which
# embeds literal quote characters in the commit message; inside ${1:-...}
# the fallback word needs no extra quoting.
COMMIT_MSG="${1:-Update readCtrl repo}"

echo "Adding changes (respecting .gitignore)..."
git add .

# Exit cleanly when the index is unchanged; `git commit` would fail otherwise.
if git diff --cached --quiet; then
    echo "No changes to commit. Nothing to push."
    exit 0
fi

echo "Committing with message: $COMMIT_MSG"
git commit -m "$COMMIT_MSG"

echo "Pushing to origin main..."
git push origin main

echo "Done."