shahidul034 commited on
Commit ·
93694bb
1
Parent(s): 030876e
"Update readCtrl repo"
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- code/RL_model/verl/verl_train/reward_func/reward_func/evaluate_scores.py +267 -0
- code/RL_model/verl/verl_train/reward_func/reward_func/reward_new_v6_bn_v4.py +4 -4
- code/RL_model/verl/verl_train/reward_func/reward_func/reward_new_v6_bn_v4_rmv_src_cov.py +835 -0
- code/RL_model/verl/verl_train/script/run_qwen3-8b_v3(2G)_v2.sh +1 -1
- code/RL_model/verl/verl_train/script/run_qwen3-8b_v3(2G)_v3.sh +56 -0
- code/RL_model/verl/verl_train/script/run_qwen3-8b_v3(4G)_v2.sh +1 -1
- code/fine_tune_sft_dpo/best_of_n_qwen3_vllm_bn.py +518 -0
- code/fine_tune_sft_dpo/dataset/bn/old/test_bn_subclaims.json +3 -0
- code/fine_tune_sft_dpo/dataset/bn/test_bn_subclaims.json +3 -0
- code/fine_tune_sft_dpo/eval.sh +0 -2
- code/fine_tune_sft_dpo/evaluate_scores_bn.py +554 -0
- code/fine_tune_sft_dpo/evaluate_scores_bn_vllm.py +611 -0
- code/fine_tune_sft_dpo/evaluate_scores_bn_vllm_rl.py +567 -0
- code/fine_tune_sft_dpo/evaluation/bn/bn_200_eval_results.json +3 -0
- code/fine_tune_sft_dpo/evaluation/bn/eval_gpt5_models/gpt_5_mini_eval_results_20260316_071029.json +3 -0
- code/fine_tune_sft_dpo/evaluation/bn/eval_gpt5_models/gpt_5_mini_prepared_20260316_071029.json +3 -0
- code/fine_tune_sft_dpo/evaluation/bn/eval_gpt5_models/gpt_5_nano_eval_results_20260316_071029.json +3 -0
- code/fine_tune_sft_dpo/evaluation/bn/eval_gpt5_models/gpt_5_nano_prepared_20260316_071029.json +3 -0
- code/fine_tune_sft_dpo/evaluation/bn/eval_gpt5_models/gpt_5_prepared_20260316_071029.json +3 -0
- code/fine_tune_sft_dpo/evaluation/bn/test_best_of_n_qwen3-4B_base_eval_results.json +3 -0
- code/fine_tune_sft_dpo/evaluation/bn/test_best_of_n_qwen3-4B_sft_eval_results.json +3 -0
- code/fine_tune_sft_dpo/evaluation/bn/test_inference_vllm_qwen3-4B_base_eval_results.json +3 -0
- code/fine_tune_sft_dpo/evaluation/bn/test_inference_vllm_qwen3-4B_sft_eval_results.json +3 -0
- code/fine_tune_sft_dpo/evaluation/bn/test_self_refine_vllm_qwen3_4B_base_eval_results.json +3 -0
- code/fine_tune_sft_dpo/evaluation/bn/test_self_refine_vllm_qwen3_4B_sft_eval_results.json +3 -0
- code/fine_tune_sft_dpo/evaluation/bn/v1/test_best_of_n_qwen3-4B_base_eval_results.json +3 -0
- code/fine_tune_sft_dpo/evaluation/bn/v1/test_best_of_n_qwen3-4B_sft_eval_results.json +3 -0
- code/fine_tune_sft_dpo/evaluation/bn/v1/test_inference_vllm_qwen3-4B_base_eval_results.json +3 -0
- code/fine_tune_sft_dpo/evaluation/bn/v1/test_inference_vllm_qwen3-4B_sft_eval_results.json +3 -0
- code/fine_tune_sft_dpo/evaluation/bn/v1/test_self_refine_vllm_qwen3_4B_base_eval_results.json +3 -0
- code/fine_tune_sft_dpo/evaluation/bn/v1/test_self_refine_vllm_qwen3_4B_sft_eval_results.json +3 -0
- code/fine_tune_sft_dpo/prompt_bn_wo_gs/prompt +58 -0
- code/fine_tune_sft_dpo/prompt_bn_wo_gs/prompt_intermediate +31 -0
- code/fine_tune_sft_dpo/prompt_bn_wo_gs/prompt_low +31 -0
- code/fine_tune_sft_dpo/prompt_bn_wo_gs/prompt_proficient +31 -0
- code/fine_tune_sft_dpo/qwen3-inference-vllm_bn.py +14 -4
- code/fine_tune_sft_dpo/qwen3_best_of_n.py +238 -0
- code/fine_tune_sft_dpo/qwen3_infer_bn.py +247 -0
- code/fine_tune_sft_dpo/results/bn/gpt5_inference_all_20260314_110445.jsonl +3 -0
- code/fine_tune_sft_dpo/results/bn/gpt5_inference_all_wo_gs_20260314_173736.jsonl +3 -0
- code/fine_tune_sft_dpo/results/bn/gpt5_inference_gpt-5-mini_20260314_110445.jsonl +3 -0
- code/fine_tune_sft_dpo/results/bn/gpt5_inference_gpt-5-mini_20260314_173736.jsonl +3 -0
- code/fine_tune_sft_dpo/results/bn/gpt5_inference_gpt-5-nano_20260314_110445.jsonl +3 -0
- code/fine_tune_sft_dpo/results/bn/gpt5_inference_gpt-5-nano_20260314_173736.jsonl +3 -0
- code/fine_tune_sft_dpo/results/bn/gpt5_inference_gpt-5_20260314_110445.jsonl +3 -0
- code/fine_tune_sft_dpo/results/bn/gpt5_inference_gpt-5_20260314_173736.jsonl +3 -0
- code/fine_tune_sft_dpo/results/bn/gpt5_inference_summary_20260314_110445.json +3 -0
- code/fine_tune_sft_dpo/results/bn/gpt5_inference_summary_wo_gs_20260314_173736.json +3 -0
- code/fine_tune_sft_dpo/results/bn/inference_summary_vllm_20260314_101627.json +3 -0
- code/fine_tune_sft_dpo/results/bn/{inference_summary_vllm_20260311_044629.json → misc/inference_summary_vllm_20260311_044629.json} +0 -0
code/RL_model/verl/verl_train/reward_func/reward_func/evaluate_scores.py
ADDED
|
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Standalone evaluation script for computing factuality, hallucination, and
|
| 4 |
+
classifier scores on a JSON file.
|
| 5 |
+
|
| 6 |
+
Expected input JSON: a list of objects, each with:
|
| 7 |
+
- fulltext : str (source document)
|
| 8 |
+
- fulltext_subclaims : list[str] (subclaims from fulltext, optional)
|
| 9 |
+
- summary_text : str (summary / reference text)
|
| 10 |
+
- summary_subclaims : list[str] (subclaims from summary)
|
| 11 |
+
- generated_text : str (model-generated text to evaluate)
|
| 12 |
+
- label : str (target literacy level, e.g. "low_health_literacy")
|
| 13 |
+
|
| 14 |
+
The script reuses the reward functions from reward_new_v6_bn_v4_rmv_src_cov.py:
|
| 15 |
+
- factuality_score : fraction of summary subclaims supported by generated_text
|
| 16 |
+
- hallucination_score: fraction of gen subclaims NOT supported by fulltext
|
| 17 |
+
- classifier_score : whether generated_text matches the target literacy level
|
| 18 |
+
|
| 19 |
+
Requires the same vLLM endpoints as the reward file:
|
| 20 |
+
- Support checker : VLLM_SUPPORT_CHECK_BN_API_BASE (default http://localhost:8090/v1)
|
| 21 |
+
- Classifier : VLLM_CLASSIFIER_BN_API_BASE (default http://localhost:8040/v1)
|
| 22 |
+
- Subclaim extractor: VLLM_SUBCLAIM_EXTRACTOR_BN_API_BASE (default http://localhost:8050/v1)
|
| 23 |
+
|
| 24 |
+
Usage:
|
| 25 |
+
python evaluate_scores.py --input data.json [--output results.json] [--target-level low_health_literacy]
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
import argparse
|
| 29 |
+
import json
|
| 30 |
+
import os
|
| 31 |
+
import sys
|
| 32 |
+
import time
|
| 33 |
+
from typing import Any, Dict, List, Optional
|
| 34 |
+
|
| 35 |
+
from tqdm import tqdm
|
| 36 |
+
|
| 37 |
+
# Import scoring utilities from the reward module (same directory).
|
| 38 |
+
from reward_new_v6_bn_v4_rmv_src_cov import (
|
| 39 |
+
_call_support_api,
|
| 40 |
+
_compute_classifier_reward,
|
| 41 |
+
_extract_subclaims_from_text,
|
| 42 |
+
_is_bangla_text,
|
| 43 |
+
_nonlinear_grounding,
|
| 44 |
+
compute_rewards,
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def evaluate_single(
|
| 49 |
+
item: Dict[str, Any],
|
| 50 |
+
target_level_override: Optional[str] = None,
|
| 51 |
+
) -> Dict[str, Any]:
|
| 52 |
+
"""
|
| 53 |
+
Evaluate a single item and return detailed scores.
|
| 54 |
+
"""
|
| 55 |
+
fulltext = item.get("fulltext", "")
|
| 56 |
+
summary_text = item.get("summary_text") or item.get("summary", "")
|
| 57 |
+
summary_subclaims = item.get("summary_subclaims", [])
|
| 58 |
+
generated_text = item.get("generated_text") or item.get("predicted_gen_text", "")
|
| 59 |
+
target_level = target_level_override or item.get("label", "")
|
| 60 |
+
|
| 61 |
+
result: Dict[str, Any] = {
|
| 62 |
+
"doc_id": item.get("doc_id", ""),
|
| 63 |
+
"target_level": target_level,
|
| 64 |
+
"generated_text_len": len(generated_text.strip()) if generated_text else 0,
|
| 65 |
+
"factuality_score": None,
|
| 66 |
+
"hallucination_score": None,
|
| 67 |
+
"classifier_score": None,
|
| 68 |
+
"grounding_score": None,
|
| 69 |
+
"factuality_supported": 0,
|
| 70 |
+
"total_summary_subclaims": len(summary_subclaims) if summary_subclaims else 0,
|
| 71 |
+
"hallucination_supported": 0,
|
| 72 |
+
"total_gen_segments": 0,
|
| 73 |
+
"skipped": False,
|
| 74 |
+
"skip_reason": "",
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
if not generated_text or len(generated_text.strip()) < 10:
|
| 78 |
+
result["skipped"] = True
|
| 79 |
+
result["skip_reason"] = "generated_text missing or too short (<10 chars)"
|
| 80 |
+
return result
|
| 81 |
+
|
| 82 |
+
# -- Factuality & Hallucination via compute_rewards --
|
| 83 |
+
rewards = compute_rewards(
|
| 84 |
+
fulltext=fulltext,
|
| 85 |
+
generated_text=generated_text,
|
| 86 |
+
target_level=target_level,
|
| 87 |
+
summary_subclaims=summary_subclaims,
|
| 88 |
+
summary_text=summary_text,
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
factuality_score = rewards["factuality_score"]
|
| 92 |
+
h_score = rewards["hallucination_score"]
|
| 93 |
+
|
| 94 |
+
if factuality_score is None:
|
| 95 |
+
factuality_score = 0.5
|
| 96 |
+
if h_score is None:
|
| 97 |
+
h_score = 0.5
|
| 98 |
+
|
| 99 |
+
grounding_score = _nonlinear_grounding(h_score)
|
| 100 |
+
|
| 101 |
+
# -- Classifier --
|
| 102 |
+
input_text = fulltext or ""
|
| 103 |
+
class_score = _compute_classifier_reward(target_level, generated_text, input_text)
|
| 104 |
+
|
| 105 |
+
result.update({
|
| 106 |
+
"factuality_score": round(factuality_score, 4),
|
| 107 |
+
"hallucination_score": round(h_score, 4),
|
| 108 |
+
"grounding_score": round(grounding_score, 4),
|
| 109 |
+
"classifier_score": round(class_score, 4),
|
| 110 |
+
"factuality_supported": rewards.get("factuality_supported", 0),
|
| 111 |
+
"total_summary_subclaims": rewards.get("total_summary_subclaims", 0),
|
| 112 |
+
"hallucination_supported": rewards.get("hallucination_supported", 0),
|
| 113 |
+
"total_gen_segments": rewards.get("total_gen_segments", 0),
|
| 114 |
+
})
|
| 115 |
+
|
| 116 |
+
return result
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def compute_aggregate(results: List[Dict[str, Any]]) -> Dict[str, Any]:
|
| 120 |
+
"""Compute aggregate statistics over all evaluated items."""
|
| 121 |
+
scored = [r for r in results if not r.get("skipped", False)]
|
| 122 |
+
n = len(scored)
|
| 123 |
+
total = len(results)
|
| 124 |
+
skipped = total - n
|
| 125 |
+
|
| 126 |
+
if n == 0:
|
| 127 |
+
return {
|
| 128 |
+
"total_items": total,
|
| 129 |
+
"scored_items": 0,
|
| 130 |
+
"skipped_items": skipped,
|
| 131 |
+
"avg_factuality_score": None,
|
| 132 |
+
"avg_hallucination_score": None,
|
| 133 |
+
"avg_grounding_score": None,
|
| 134 |
+
"avg_classifier_score": None,
|
| 135 |
+
}
|
| 136 |
+
|
| 137 |
+
def safe_avg(key):
|
| 138 |
+
vals = [r[key] for r in scored if r[key] is not None]
|
| 139 |
+
return round(sum(vals) / len(vals), 4) if vals else None
|
| 140 |
+
|
| 141 |
+
return {
|
| 142 |
+
"total_items": total,
|
| 143 |
+
"scored_items": n,
|
| 144 |
+
"skipped_items": skipped,
|
| 145 |
+
"avg_factuality_score": safe_avg("factuality_score"),
|
| 146 |
+
"avg_hallucination_score": safe_avg("hallucination_score"),
|
| 147 |
+
"avg_grounding_score": safe_avg("grounding_score"),
|
| 148 |
+
"avg_classifier_score": safe_avg("classifier_score"),
|
| 149 |
+
}
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def main():
|
| 153 |
+
parser = argparse.ArgumentParser(
|
| 154 |
+
description="Evaluate factuality, hallucination, and classifier scores on a JSON file."
|
| 155 |
+
)
|
| 156 |
+
parser.add_argument(
|
| 157 |
+
"--input", "-i", required=True,
|
| 158 |
+
help="Path to input JSON file (list of objects).",
|
| 159 |
+
)
|
| 160 |
+
parser.add_argument(
|
| 161 |
+
"--output", "-o", default=None,
|
| 162 |
+
help="Path to output JSON file with per-item scores. "
|
| 163 |
+
"Defaults to <input_stem>_eval_results.json.",
|
| 164 |
+
)
|
| 165 |
+
parser.add_argument(
|
| 166 |
+
"--target-level", "-t", default=None,
|
| 167 |
+
help="Override target literacy level for all items "
|
| 168 |
+
"(e.g. low_health_literacy). If not set, uses each item's 'label' field.",
|
| 169 |
+
)
|
| 170 |
+
parser.add_argument(
|
| 171 |
+
"--support-check-url", default=None,
|
| 172 |
+
help="Override VLLM_SUPPORT_CHECK_BN_API_BASE.",
|
| 173 |
+
)
|
| 174 |
+
parser.add_argument(
|
| 175 |
+
"--classifier-url", default=None,
|
| 176 |
+
help="Override VLLM_CLASSIFIER_BN_API_BASE.",
|
| 177 |
+
)
|
| 178 |
+
parser.add_argument(
|
| 179 |
+
"--subclaim-extractor-url", default=None,
|
| 180 |
+
help="Override VLLM_SUBCLAIM_EXTRACTOR_BN_API_BASE.",
|
| 181 |
+
)
|
| 182 |
+
args = parser.parse_args()
|
| 183 |
+
|
| 184 |
+
if args.support_check_url:
|
| 185 |
+
os.environ["VLLM_SUPPORT_CHECK_BN_API_BASE"] = args.support_check_url
|
| 186 |
+
if args.classifier_url:
|
| 187 |
+
os.environ["VLLM_CLASSIFIER_BN_API_BASE"] = args.classifier_url
|
| 188 |
+
if args.subclaim_extractor_url:
|
| 189 |
+
os.environ["VLLM_SUBCLAIM_EXTRACTOR_BN_API_BASE"] = args.subclaim_extractor_url
|
| 190 |
+
|
| 191 |
+
# Load input
|
| 192 |
+
with open(args.input, "r", encoding="utf-8") as f:
|
| 193 |
+
data = json.load(f)
|
| 194 |
+
|
| 195 |
+
if not isinstance(data, list):
|
| 196 |
+
print(f"Error: Expected a JSON list, got {type(data).__name__}.", file=sys.stderr)
|
| 197 |
+
sys.exit(1)
|
| 198 |
+
|
| 199 |
+
print(f"Loaded {len(data)} items from {args.input}")
|
| 200 |
+
print(f" Support check API : {os.getenv('VLLM_SUPPORT_CHECK_BN_API_BASE', 'http://localhost:8090/v1')}")
|
| 201 |
+
print(f" Classifier API : {os.getenv('VLLM_CLASSIFIER_BN_API_BASE', 'http://localhost:8040/v1')}")
|
| 202 |
+
print(f" Subclaim extractor: {os.getenv('VLLM_SUBCLAIM_EXTRACTOR_BN_API_BASE', 'http://localhost:8050/v1')}")
|
| 203 |
+
if args.target_level:
|
| 204 |
+
print(f" Target level override: {args.target_level}")
|
| 205 |
+
print("-" * 60)
|
| 206 |
+
|
| 207 |
+
# Evaluate each item
|
| 208 |
+
results = []
|
| 209 |
+
start_time = time.time()
|
| 210 |
+
for idx, item in enumerate(tqdm(data, desc="Evaluating")):
|
| 211 |
+
r = evaluate_single(item, target_level_override=args.target_level)
|
| 212 |
+
r["index"] = idx
|
| 213 |
+
results.append(r)
|
| 214 |
+
|
| 215 |
+
if (idx + 1) % 10 == 0 or idx == 0:
|
| 216 |
+
partial_agg = compute_aggregate(results)
|
| 217 |
+
tqdm.write(
|
| 218 |
+
f" [{idx+1}/{len(data)}] "
|
| 219 |
+
f"fact={partial_agg['avg_factuality_score']} "
|
| 220 |
+
f"hallu={partial_agg['avg_hallucination_score']} "
|
| 221 |
+
f"cls={partial_agg['avg_classifier_score']}"
|
| 222 |
+
)
|
| 223 |
+
|
| 224 |
+
elapsed = time.time() - start_time
|
| 225 |
+
|
| 226 |
+
# Aggregate
|
| 227 |
+
agg = compute_aggregate(results)
|
| 228 |
+
|
| 229 |
+
# Output path
|
| 230 |
+
if args.output:
|
| 231 |
+
out_path = args.output
|
| 232 |
+
else:
|
| 233 |
+
stem = os.path.splitext(os.path.basename(args.input))[0]
|
| 234 |
+
out_dir = os.path.dirname(args.input) or "."
|
| 235 |
+
out_path = os.path.join(out_dir, f"{stem}_eval_results.json")
|
| 236 |
+
|
| 237 |
+
output = {
|
| 238 |
+
"input_file": os.path.abspath(args.input),
|
| 239 |
+
"target_level_override": args.target_level,
|
| 240 |
+
"elapsed_seconds": round(elapsed, 2),
|
| 241 |
+
"aggregate": agg,
|
| 242 |
+
"per_item": results,
|
| 243 |
+
}
|
| 244 |
+
|
| 245 |
+
with open(out_path, "w", encoding="utf-8") as f:
|
| 246 |
+
json.dump(output, f, indent=2, ensure_ascii=False)
|
| 247 |
+
|
| 248 |
+
# Print summary
|
| 249 |
+
print("\n" + "=" * 60)
|
| 250 |
+
print("EVALUATION SUMMARY")
|
| 251 |
+
print("=" * 60)
|
| 252 |
+
print(f" Total items : {agg['total_items']}")
|
| 253 |
+
print(f" Scored items : {agg['scored_items']}")
|
| 254 |
+
print(f" Skipped items : {agg['skipped_items']}")
|
| 255 |
+
print(f" Elapsed time : {round(elapsed, 1)}s")
|
| 256 |
+
print("-" * 60)
|
| 257 |
+
print(f" Avg Factuality Score : {agg['avg_factuality_score']}")
|
| 258 |
+
print(f" Avg Hallucination Score: {agg['avg_hallucination_score']}")
|
| 259 |
+
print(f" Avg Grounding Score : {agg['avg_grounding_score']}")
|
| 260 |
+
print(f" Avg Classifier Score : {agg['avg_classifier_score']}")
|
| 261 |
+
print("-" * 60)
|
| 262 |
+
print(f" Results saved to: {out_path}")
|
| 263 |
+
print("=" * 60)
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
if __name__ == "__main__":
|
| 267 |
+
main()
|
code/RL_model/verl/verl_train/reward_func/reward_func/reward_new_v6_bn_v4.py
CHANGED
|
@@ -844,10 +844,10 @@ def compute_score(data_source, solution_str, ground_truth, extra_info=None):
|
|
| 844 |
the generated text; must stay within a level-specific
|
| 845 |
[min, max] range — too little OR too much is penalised.
|
| 846 |
"""
|
| 847 |
-
W_FACTUALITY = 0.
|
| 848 |
-
W_HALLU = 0.
|
| 849 |
-
W_SRC_COV = 0.
|
| 850 |
-
W_CLASSIFIER = 0.
|
| 851 |
W_LENGTH = 0.15
|
| 852 |
|
| 853 |
FAIL = {
|
|
|
|
| 844 |
the generated text; must stay within a level-specific
|
| 845 |
[min, max] range — too little OR too much is penalised.
|
| 846 |
"""
|
| 847 |
+
W_FACTUALITY = 0.25
|
| 848 |
+
W_HALLU = 0.15
|
| 849 |
+
W_SRC_COV = 0.25
|
| 850 |
+
W_CLASSIFIER = 0.20
|
| 851 |
W_LENGTH = 0.15
|
| 852 |
|
| 853 |
FAIL = {
|
code/RL_model/verl/verl_train/reward_func/reward_func/reward_new_v6_bn_v4_rmv_src_cov.py
ADDED
|
@@ -0,0 +1,835 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import ast
|
| 2 |
+
import os
|
| 3 |
+
import re
|
| 4 |
+
import json
|
| 5 |
+
import argparse
|
| 6 |
+
from typing import Any, List, Dict, Optional
|
| 7 |
+
import warnings
|
| 8 |
+
import requests
|
| 9 |
+
warnings.filterwarnings("ignore")
|
| 10 |
+
|
| 11 |
+
# vLLM endpoints (see script/vllm.sh: support-check port 8090, classifier port 8040)
|
| 12 |
+
# Use localhost when reward runs on same machine; set env vars for remote (e.g. http://HOST:8090/v1).
|
| 13 |
+
VLLM_SUPPORT_CHECK_BN_API_BASE = os.getenv(
|
| 14 |
+
"VLLM_SUPPORT_CHECK_BN_API_BASE",
|
| 15 |
+
"http://localhost:8090/v1",
|
| 16 |
+
)
|
| 17 |
+
# Both support-check and classifier use Bangla prompts.
|
| 18 |
+
SUPPORT_CHECK_PROMPT_LANGUAGE = "bn"
|
| 19 |
+
|
| 20 |
+
VLLM_CLASSIFIER_BN_API_BASE = os.getenv(
|
| 21 |
+
"VLLM_CLASSIFIER_BN_API_BASE",
|
| 22 |
+
"http://localhost:8040/v1",
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
# Subclaim-extractor vLLM endpoint (Bangla medical text → subclaim list)
|
| 26 |
+
VLLM_SUBCLAIM_EXTRACTOR_BN_API_BASE = os.getenv(
|
| 27 |
+
"VLLM_SUBCLAIM_EXTRACTOR_BN_API_BASE",
|
| 28 |
+
"http://localhost:8050/v1",
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# ---------------------------------------------------------------------------
|
| 33 |
+
# Support check via vLLM (Gemma-3 fine-tuned; same prompt as support_check gemma3-finetune)
|
| 34 |
+
# ---------------------------------------------------------------------------
|
| 35 |
+
|
| 36 |
+
def _build_support_list_user_prompt(context: str, subclaims: List[str]) -> str:
|
| 37 |
+
"""Build list-of-subclaims support prompt in Bangla (matches support_check model_finetune/gemma3-finetune.py)."""
|
| 38 |
+
numbered = "\n".join(f"{idx + 1}. {sc}" for idx, sc in enumerate(subclaims))
|
| 39 |
+
return (
|
| 40 |
+
"আপনাকে একটি মেডিকেল কেস বর্ণনা এবং একাধিক সাবক্লেইমের তালিকা দেওয়া হবে। "
|
| 41 |
+
"প্রতিটি সাবক্লেইম টেক্সট দ্বারা সমর্থিত কি না তা নির্ধারণ করুন।\n\n"
|
| 42 |
+
f"টেক্সট:\n{context}\n\n"
|
| 43 |
+
f"সাবক্লেইমগুলোর তালিকা:\n{numbered}\n\n"
|
| 44 |
+
"প্রতিটি সাবক্লেইমের জন্য ক্রমানুসারে লেবেল দিন। "
|
| 45 |
+
"নির্দিষ্টভাবে একটি JSON array আকারে উত্তর দিন, যেমন:\n"
|
| 46 |
+
'["supported", "not_supported", ...]\n'
|
| 47 |
+
"অন্য কিছু লিখবেন না।"
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def _parse_support_label_array(raw_text: str) -> List[str]:
|
| 52 |
+
"""Parse model output to list of 'supported' | 'not_supported' (matches support_check gemma3-finetune)."""
|
| 53 |
+
text = (raw_text or "").strip()
|
| 54 |
+
if not text:
|
| 55 |
+
return []
|
| 56 |
+
if "```" in text:
|
| 57 |
+
text = text.replace("```json", "").replace("```", "").strip()
|
| 58 |
+
start = text.find("[")
|
| 59 |
+
end = text.rfind("]")
|
| 60 |
+
if start != -1 and end != -1 and end > start:
|
| 61 |
+
text = text[start : end + 1]
|
| 62 |
+
parsed = None
|
| 63 |
+
for parser in (json.loads, ast.literal_eval):
|
| 64 |
+
try:
|
| 65 |
+
parsed = parser(text)
|
| 66 |
+
break
|
| 67 |
+
except Exception:
|
| 68 |
+
continue
|
| 69 |
+
if not isinstance(parsed, list):
|
| 70 |
+
return []
|
| 71 |
+
# import ipdb; ipdb.set_trace()
|
| 72 |
+
normalized = []
|
| 73 |
+
for item in parsed:
|
| 74 |
+
if not isinstance(item, str):
|
| 75 |
+
normalized.append("invalid")
|
| 76 |
+
continue
|
| 77 |
+
|
| 78 |
+
label = item.strip().lower().replace("-", "_").replace(" ", "_")
|
| 79 |
+
# Strict keyword check: if not one of these two, it is invalid
|
| 80 |
+
if label in {"supported", "not_supported"}:
|
| 81 |
+
normalized.append(label)
|
| 82 |
+
else:
|
| 83 |
+
normalized.append("invalid")
|
| 84 |
+
return normalized
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def _format_gemma3_for_support(user_message: str) -> str:
|
| 88 |
+
"""Format user message for Gemma-3 chat (vLLM)."""
|
| 89 |
+
return (
|
| 90 |
+
f"<start_of_turn>user\n{user_message}<end_of_turn>\n"
|
| 91 |
+
"<start_of_turn>model\n"
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def _call_vllm_support_check(prompt: str, max_tokens: int = 512, timeout: float = 120.0) -> Optional[str]:
|
| 96 |
+
"""Call vLLM completions API for support-check model. Returns generated text or None."""
|
| 97 |
+
base = VLLM_SUPPORT_CHECK_BN_API_BASE.rstrip("/")
|
| 98 |
+
url = f"{base}/completions"
|
| 99 |
+
payload = {
|
| 100 |
+
"prompt": prompt,
|
| 101 |
+
"max_tokens": max_tokens,
|
| 102 |
+
"temperature": 0.0,
|
| 103 |
+
"stop": ["<end_of_turn>", "<start_of_turn>", "\n\n"],
|
| 104 |
+
}
|
| 105 |
+
try:
|
| 106 |
+
resp = requests.post(url, json=payload, timeout=timeout)
|
| 107 |
+
resp.raise_for_status()
|
| 108 |
+
data = resp.json()
|
| 109 |
+
choices = data.get("choices")
|
| 110 |
+
# import ipdb; ipdb.set_trace()
|
| 111 |
+
if choices and len(choices) > 0:
|
| 112 |
+
text = choices[0].get("text", "")
|
| 113 |
+
return (text or "").strip()
|
| 114 |
+
return None
|
| 115 |
+
except requests.exceptions.RequestException:
|
| 116 |
+
return None
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def _call_support_api(
|
| 120 |
+
context: str,
|
| 121 |
+
subclaims: List[str],
|
| 122 |
+
threshold: float = 0.5,
|
| 123 |
+
batch_size: int = 128,
|
| 124 |
+
) -> Optional[List[str]]:
|
| 125 |
+
"""
|
| 126 |
+
Call the BN support-check model via vLLM (Gemma-3 fine-tuned, subclaim_list mode).
|
| 127 |
+
Same prompt as support_check/model_finetune/gemma3-finetune.py.
|
| 128 |
+
|
| 129 |
+
Returns
|
| 130 |
+
-------
|
| 131 |
+
List[str] : one label per subclaim — "supported" | "not_supported" | "invalid".
|
| 132 |
+
None : on total failure (network or empty/unparseable response).
|
| 133 |
+
"""
|
| 134 |
+
if not context or not subclaims:
|
| 135 |
+
return ["invalid"] * len(subclaims)
|
| 136 |
+
|
| 137 |
+
user_prompt = _build_support_list_user_prompt(context, subclaims)
|
| 138 |
+
prompt = _format_gemma3_for_support(user_prompt)
|
| 139 |
+
raw = _call_vllm_support_check(prompt)
|
| 140 |
+
# import ipdb; ipdb.set_trace()
|
| 141 |
+
if raw is None:
|
| 142 |
+
print("Warning: Support check vLLM call failed (returning None).")
|
| 143 |
+
return None
|
| 144 |
+
|
| 145 |
+
labels = _parse_support_label_array(raw)
|
| 146 |
+
n = len(subclaims)
|
| 147 |
+
if len(labels) < n:
|
| 148 |
+
labels = labels + ["invalid"] * (n - len(labels))
|
| 149 |
+
elif len(labels) > n:
|
| 150 |
+
labels = labels[:n]
|
| 151 |
+
# import ipdb; ipdb.set_trace()
|
| 152 |
+
return labels
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
# ---------------------------------------------------------------------------
|
| 156 |
+
# Subclaim extractor (Bangla, vLLM) + sentence splitter fallback
|
| 157 |
+
# ---------------------------------------------------------------------------
|
| 158 |
+
|
| 159 |
+
# Minimum character length for a sentence to be considered a real unit.
|
| 160 |
+
# Used only as a fallback when subclaim extraction is unavailable.
|
| 161 |
+
MIN_SENTENCE_CHARS = 15
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def _is_bangla_text(text: str, min_bangla_ratio: float = 0.6) -> bool:
|
| 165 |
+
"""
|
| 166 |
+
Heuristic check: returns True if the majority of alphabetic characters
|
| 167 |
+
in `text` are Bangla (Unicode block \u0980–\u09FF).
|
| 168 |
+
"""
|
| 169 |
+
if not text:
|
| 170 |
+
return False
|
| 171 |
+
bangla_chars = 0
|
| 172 |
+
alpha_chars = 0
|
| 173 |
+
for ch in text:
|
| 174 |
+
if ch.isalpha():
|
| 175 |
+
alpha_chars += 1
|
| 176 |
+
if "\u0980" <= ch <= "\u09FF":
|
| 177 |
+
bangla_chars += 1
|
| 178 |
+
if alpha_chars == 0:
|
| 179 |
+
return False
|
| 180 |
+
return (bangla_chars / alpha_chars) >= min_bangla_ratio
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def _split_into_sentences(text: str, min_chars: int = MIN_SENTENCE_CHARS) -> List[str]:
|
| 184 |
+
"""
|
| 185 |
+
Split text into sentences at [.!?] boundaries.
|
| 186 |
+
Segments shorter than `min_chars` characters are dropped to
|
| 187 |
+
prevent micro-fragment padding from gaming ratio-based scores.
|
| 188 |
+
"""
|
| 189 |
+
if not text or not text.strip():
|
| 190 |
+
return []
|
| 191 |
+
parts = re.split(r"(?<=[.!?\u0964])\s+", text.strip())
|
| 192 |
+
return [s.strip() for s in parts if len(s.strip()) >= min_chars]
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def _build_subclaim_extraction_prompt(medical_text: str) -> str:
|
| 196 |
+
"""
|
| 197 |
+
Bangla subclaim-extraction prompt (same wording as `extract_bn_subclaims_vllm.py`,
|
| 198 |
+
generalized to "medical text" so it works for any generated explanation).
|
| 199 |
+
"""
|
| 200 |
+
return f"""
|
| 201 |
+
You are an expert medical annotator. The following text is in Bangla (Bengali).
|
| 202 |
+
|
| 203 |
+
Your task is to extract granular, factual subclaims from the provided medical text.
|
| 204 |
+
A subclaim is the smallest standalone factual unit that can be independently verified.
|
| 205 |
+
|
| 206 |
+
Instructions:
|
| 207 |
+
1. Read the Bangla medical text carefully.
|
| 208 |
+
2. Extract factual statements explicitly stated in the text.
|
| 209 |
+
3. Each subclaim must:
|
| 210 |
+
- Be in Bangla (same language as the input)
|
| 211 |
+
- Contain exactly ONE factual assertion
|
| 212 |
+
- Come directly from the text (no inference or interpretation)
|
| 213 |
+
- Preserve original wording as much as possible
|
| 214 |
+
- Include any negation, uncertainty, or qualifier
|
| 215 |
+
4. Do NOT:
|
| 216 |
+
- Combine multiple facts into one subclaim
|
| 217 |
+
- Add new information
|
| 218 |
+
- Translate to another language
|
| 219 |
+
5. Return ONLY a valid JSON array of strings.
|
| 220 |
+
6. Use double quotes and valid JSON formatting only (no markdown, no commentary).
|
| 221 |
+
|
| 222 |
+
Medical Text (Bangla):
|
| 223 |
+
{medical_text}
|
| 224 |
+
|
| 225 |
+
Return format:
|
| 226 |
+
[
|
| 227 |
+
"subclaim 1",
|
| 228 |
+
"subclaim 2"
|
| 229 |
+
]
|
| 230 |
+
""".strip()
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
def _strip_markdown_json_block(text: str) -> str:
|
| 234 |
+
"""Strip optional markdown code fence (e.g. ```json\\n[...]\\n```), if present."""
|
| 235 |
+
text = (text or "").strip()
|
| 236 |
+
if not text:
|
| 237 |
+
return ""
|
| 238 |
+
if text.startswith("```json"):
|
| 239 |
+
text = text[7:].lstrip("\n")
|
| 240 |
+
elif text.startswith("```"):
|
| 241 |
+
text = text[3:].lstrip("\n")
|
| 242 |
+
if text.endswith("```"):
|
| 243 |
+
text = text[:-3].rstrip("\n")
|
| 244 |
+
return text.strip()
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
def _parse_subclaim_list_output(output_text: str) -> List[str]:
|
| 248 |
+
"""Parse subclaim-extractor model output into a list of Bangla subclaims."""
|
| 249 |
+
output_text = (output_text or "").strip()
|
| 250 |
+
if not output_text:
|
| 251 |
+
return []
|
| 252 |
+
|
| 253 |
+
if "</think>" in output_text:
|
| 254 |
+
output_text = output_text.split("</think>")[-1].strip()
|
| 255 |
+
|
| 256 |
+
output_text = _strip_markdown_json_block(output_text)
|
| 257 |
+
|
| 258 |
+
start_idx = output_text.find("[")
|
| 259 |
+
end_idx = output_text.rfind("]") + 1
|
| 260 |
+
if start_idx != -1 and end_idx > start_idx:
|
| 261 |
+
content = output_text[start_idx:end_idx]
|
| 262 |
+
parsed = json.loads(content)
|
| 263 |
+
if isinstance(parsed, list):
|
| 264 |
+
return [str(s).strip() for s in parsed if str(s).strip()]
|
| 265 |
+
|
| 266 |
+
raise ValueError("Incomplete or invalid JSON list")
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
def _call_vllm_subclaim_extractor(
|
| 270 |
+
text: str,
|
| 271 |
+
max_tokens: int = 2048,
|
| 272 |
+
temperature: float = 0.2,
|
| 273 |
+
timeout: float = 120.0,
|
| 274 |
+
) -> Optional[List[str]]:
|
| 275 |
+
"""
|
| 276 |
+
Call Bangla subclaim-extractor model via vLLM (OpenAI /chat/completions).
|
| 277 |
+
|
| 278 |
+
Returns a list of subclaims on success, or None on total failure.
|
| 279 |
+
"""
|
| 280 |
+
if not text or not text.strip():
|
| 281 |
+
return []
|
| 282 |
+
|
| 283 |
+
base = VLLM_SUBCLAIM_EXTRACTOR_BN_API_BASE.rstrip("/")
|
| 284 |
+
url = f"{base}/chat/completions"
|
| 285 |
+
|
| 286 |
+
prompt = _build_subclaim_extraction_prompt(text)
|
| 287 |
+
payload = {
|
| 288 |
+
"model": os.getenv("VLLM_SUBCLAIM_EXTRACTOR_MODEL_NAME", "subclaim-extractor"),
|
| 289 |
+
"messages": [{"role": "user", "content": prompt}],
|
| 290 |
+
"temperature": temperature,
|
| 291 |
+
"max_tokens": max_tokens,
|
| 292 |
+
}
|
| 293 |
+
|
| 294 |
+
try:
|
| 295 |
+
resp = requests.post(url, json=payload, timeout=timeout)
|
| 296 |
+
resp.raise_for_status()
|
| 297 |
+
data = resp.json()
|
| 298 |
+
choices = data.get("choices") or []
|
| 299 |
+
if not choices:
|
| 300 |
+
return None
|
| 301 |
+
content = (choices[0].get("message", {}) or {}).get("content", "") or ""
|
| 302 |
+
# import ipdb; ipdb.set_trace()
|
| 303 |
+
return _parse_subclaim_list_output(content)
|
| 304 |
+
except Exception:
|
| 305 |
+
return None
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
def _extract_subclaims_from_text(text: str) -> List[str]:
|
| 309 |
+
"""
|
| 310 |
+
Extract Bangla subclaims from generated text using the vLLM subclaim-extractor.
|
| 311 |
+
|
| 312 |
+
On failure (e.g., server down or parse error), falls back to sentence splitting
|
| 313 |
+
so the rest of the reward logic can still operate.
|
| 314 |
+
"""
|
| 315 |
+
subclaims = _call_vllm_subclaim_extractor(text)
|
| 316 |
+
if subclaims is None:
|
| 317 |
+
# Fallback: keep system running even if extractor is unavailable.
|
| 318 |
+
return _split_into_sentences(text)
|
| 319 |
+
return subclaims
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
# ---------------------------------------------------------------------------
|
| 323 |
+
# Two reward signals:
|
| 324 |
+
# 1. Factuality — summary subclaims vs gen_text (how much summary info is in gen_text)
|
| 325 |
+
# 2. Hallucination — gen_segments vs fulltext (how much gen info is NOT in fulltext)
|
| 326 |
+
# ---------------------------------------------------------------------------
|
| 327 |
+
|
| 328 |
+
def compute_rewards(
|
| 329 |
+
fulltext: str,
|
| 330 |
+
generated_text: str,
|
| 331 |
+
target_level: str,
|
| 332 |
+
summary_subclaims: Optional[List[str]] = None,
|
| 333 |
+
summary_text: Optional[str] = None,
|
| 334 |
+
threshold: float = 0.5,
|
| 335 |
+
batch_size: int = 128,
|
| 336 |
+
) -> Dict[str, Optional[float]]:
|
| 337 |
+
"""
|
| 338 |
+
Compute two independent reward signals.
|
| 339 |
+
|
| 340 |
+
1. **Factuality** (summary_subclaims → gen_text):
|
| 341 |
+
Use pre-extracted *summary_subclaims*, check how many are supported
|
| 342 |
+
by the generated text. Measures "how much of the summary's information
|
| 343 |
+
made it into the output".
|
| 344 |
+
|
| 345 |
+
2. **Hallucination** (gen_segments → fulltext):
|
| 346 |
+
Extract subclaims from the *generated text* (gen_segments), then check
|
| 347 |
+
how many are supported by the source fulltext. The *unsupported*
|
| 348 |
+
fraction is the hallucination score (lower is better).
|
| 349 |
+
|
| 350 |
+
Returns dict with:
|
| 351 |
+
factuality_score : [0,1] fraction of summary subclaims supported by gen_text
|
| 352 |
+
factuality_supported : int count
|
| 353 |
+
total_summary_subclaims : int
|
| 354 |
+
hallucination_score : [0,1] fraction of gen_segments NOT supported by fulltext
|
| 355 |
+
hallucination_supported : int count of gen_segments supported by fulltext
|
| 356 |
+
total_gen_segments : int
|
| 357 |
+
"""
|
| 358 |
+
result: Dict[str, Any] = {
|
| 359 |
+
"factuality_score": None,
|
| 360 |
+
"factuality_supported": 0,
|
| 361 |
+
"total_summary_subclaims": 0,
|
| 362 |
+
"hallucination_score": None,
|
| 363 |
+
"hallucination_supported": 0,
|
| 364 |
+
"total_gen_segments": 0,
|
| 365 |
+
}
|
| 366 |
+
|
| 367 |
+
gen_segments = _extract_subclaims_from_text(generated_text)
|
| 368 |
+
|
| 369 |
+
if not gen_segments:
|
| 370 |
+
result.update({
|
| 371 |
+
"hallucination_score": 0.0,
|
| 372 |
+
"factuality_score": 0.0,
|
| 373 |
+
})
|
| 374 |
+
return result
|
| 375 |
+
|
| 376 |
+
total_gen = len(gen_segments)
|
| 377 |
+
result["total_gen_segments"] = total_gen
|
| 378 |
+
|
| 379 |
+
# =====================================================================
|
| 380 |
+
# 1. FACTUALITY — summary subclaims checked against gen_text
|
| 381 |
+
# "How much information from the summary exists in the generated text?"
|
| 382 |
+
# =====================================================================
|
| 383 |
+
factuality_score = None
|
| 384 |
+
if summary_subclaims and len(summary_subclaims) > 0:
|
| 385 |
+
result["total_summary_subclaims"] = len(summary_subclaims)
|
| 386 |
+
|
| 387 |
+
labels_summary_vs_gen = _call_support_api(
|
| 388 |
+
context=generated_text,
|
| 389 |
+
subclaims=summary_subclaims,
|
| 390 |
+
threshold=threshold,
|
| 391 |
+
batch_size=batch_size,
|
| 392 |
+
)
|
| 393 |
+
if labels_summary_vs_gen is not None:
|
| 394 |
+
valid = [l for l in labels_summary_vs_gen if str(l).strip().lower() != "invalid"]
|
| 395 |
+
if valid:
|
| 396 |
+
sup = sum(1 for l in valid if str(l).strip().lower() == "supported")
|
| 397 |
+
factuality_score = sup / len(summary_subclaims)
|
| 398 |
+
result["factuality_supported"] = sup
|
| 399 |
+
else:
|
| 400 |
+
factuality_score = 0.0
|
| 401 |
+
|
| 402 |
+
result["factuality_score"] = factuality_score
|
| 403 |
+
|
| 404 |
+
# =====================================================================
|
| 405 |
+
# 2. HALLUCINATION — gen_segments checked against fulltext
|
| 406 |
+
# "How much info in gen_segments is NOT supported by the fulltext?"
|
| 407 |
+
# =====================================================================
|
| 408 |
+
hallucination_score = None
|
| 409 |
+
if fulltext and fulltext.strip():
|
| 410 |
+
labels_gen_vs_full = _call_support_api(
|
| 411 |
+
context=fulltext,
|
| 412 |
+
subclaims=gen_segments,
|
| 413 |
+
threshold=threshold,
|
| 414 |
+
batch_size=batch_size,
|
| 415 |
+
)
|
| 416 |
+
if labels_gen_vs_full is not None and len(labels_gen_vs_full) > 0:
|
| 417 |
+
sup_full = sum(
|
| 418 |
+
1 for l in labels_gen_vs_full
|
| 419 |
+
if str(l).strip().lower() == "supported"
|
| 420 |
+
)
|
| 421 |
+
|
| 422 |
+
unsupported_indices = [
|
| 423 |
+
i for i, l in enumerate(labels_gen_vs_full)
|
| 424 |
+
if str(l).strip().lower() != "supported"
|
| 425 |
+
]
|
| 426 |
+
|
| 427 |
+
if unsupported_indices and summary_text and summary_text.strip():
|
| 428 |
+
unsup_segments = [gen_segments[i] for i in unsupported_indices]
|
| 429 |
+
rescue_labels = _call_support_api(
|
| 430 |
+
context=summary_text,
|
| 431 |
+
subclaims=unsup_segments,
|
| 432 |
+
threshold=threshold,
|
| 433 |
+
batch_size=batch_size,
|
| 434 |
+
)
|
| 435 |
+
if rescue_labels:
|
| 436 |
+
rescued = sum(
|
| 437 |
+
1 for l in rescue_labels
|
| 438 |
+
if str(l).strip().lower() == "supported"
|
| 439 |
+
)
|
| 440 |
+
sup_full += rescued
|
| 441 |
+
|
| 442 |
+
hallucination_score = max(0.0, (total_gen - sup_full) / total_gen)
|
| 443 |
+
result["hallucination_supported"] = sup_full
|
| 444 |
+
else:
|
| 445 |
+
hallucination_score = 0.0
|
| 446 |
+
|
| 447 |
+
result["hallucination_score"] = hallucination_score
|
| 448 |
+
|
| 449 |
+
return result
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
# ---------------------------------------------------------------------------
|
| 453 |
+
# BN health-literacy classifier via vLLM (Gemma-3 fine-tuned model)
|
| 454 |
+
# Uses Bangla prompt; model is assumed running in vLLM.
|
| 455 |
+
# ---------------------------------------------------------------------------
|
| 456 |
+
|
| 457 |
+
|
| 458 |
+
def build_classification_user_prompt(fulltext: str, gen_text: str) -> str:
|
| 459 |
+
"""Build the classification user prompt in English (matches gemma3-finetune.py). Full text is reference; generated text is what to classify."""
|
| 460 |
+
return (
|
| 461 |
+
"You will be given a medical case description as reference (full text) and a generated text to classify. "
|
| 462 |
+
"Determine the patient's health literacy level based only on the generated text.\n\n"
|
| 463 |
+
f"Reference (full text):\n{fulltext}\n\n"
|
| 464 |
+
f"Generated text (to classify):\n{gen_text}\n\n"
|
| 465 |
+
"Reply with exactly one label from this set:\n"
|
| 466 |
+
"low_health_literacy, intermediate_health_literacy, proficient_health_literacy"
|
| 467 |
+
)
|
| 468 |
+
|
| 469 |
+
|
| 470 |
+
def format_gemma3_prompt(user_message: str) -> str:
|
| 471 |
+
"""Format user message for Gemma-3 chat (vLLM expects this)."""
|
| 472 |
+
return (
|
| 473 |
+
f"<start_of_turn>user\n{user_message}<end_of_turn>\n"
|
| 474 |
+
"<start_of_turn>model\n"
|
| 475 |
+
)
|
| 476 |
+
|
| 477 |
+
|
| 478 |
+
def _call_vllm_classifier(prompt: str, max_tokens: int = 64, timeout: float = 60.0) -> Optional[str]:
|
| 479 |
+
"""
|
| 480 |
+
Call vLLM completions API. Returns generated text or None on failure.
|
| 481 |
+
"""
|
| 482 |
+
url = f"{VLLM_CLASSIFIER_BN_API_BASE.rstrip('/')}/completions"
|
| 483 |
+
payload = {
|
| 484 |
+
"prompt": prompt,
|
| 485 |
+
"max_tokens": max_tokens,
|
| 486 |
+
"temperature": 0.0,
|
| 487 |
+
"stop": ["<end_of_turn>", "<start_of_turn>", "\n\n","<eos>"],
|
| 488 |
+
}
|
| 489 |
+
try:
|
| 490 |
+
resp = requests.post(url, json=payload, timeout=timeout)
|
| 491 |
+
resp.raise_for_status()
|
| 492 |
+
data = resp.json()
|
| 493 |
+
choices = data.get("choices")
|
| 494 |
+
# import ipdb; ipdb.set_trace()
|
| 495 |
+
if choices and len(choices) > 0:
|
| 496 |
+
text = choices[0].get("text", "")
|
| 497 |
+
return (text or "").strip()
|
| 498 |
+
return None
|
| 499 |
+
except requests.exceptions.RequestException as exc:
|
| 500 |
+
return None
|
| 501 |
+
|
| 502 |
+
|
| 503 |
+
def _parse_classifier_output(raw: str) -> str:
|
| 504 |
+
"""
|
| 505 |
+
Extract health literacy label from model output. Normalize to
|
| 506 |
+
low_health_literacy | intermediate_health_literacy | proficient_health_literacy.
|
| 507 |
+
Returns empty string if no valid label found.
|
| 508 |
+
"""
|
| 509 |
+
if not raw:
|
| 510 |
+
return ""
|
| 511 |
+
raw = raw.strip().lower()
|
| 512 |
+
# Take first line and clean
|
| 513 |
+
first_line = raw.split("\n")[0].strip()
|
| 514 |
+
for label in ["low_health_literacy", "intermediate_health_literacy", "proficient_health_literacy"]:
|
| 515 |
+
if label in first_line or label in raw:
|
| 516 |
+
# import ipdb; ipdb.set_trace()
|
| 517 |
+
return label
|
| 518 |
+
return ""
|
| 519 |
+
|
| 520 |
+
|
| 521 |
+
_CLASSIFIER_ERROR_LOGGED = False
|
| 522 |
+
|
| 523 |
+
|
| 524 |
+
def _predict_label(input_text: str, generated_text: str) -> str:
|
| 525 |
+
"""
|
| 526 |
+
Run BN health-literacy classifier via vLLM (Gemma-3 fine-tuned).
|
| 527 |
+
Uses fulltext=input_text and gen_text=generated_text; returns normalized label or "".
|
| 528 |
+
"""
|
| 529 |
+
global _CLASSIFIER_ERROR_LOGGED
|
| 530 |
+
try:
|
| 531 |
+
user_prompt = build_classification_user_prompt(input_text or "", generated_text or "")
|
| 532 |
+
prompt = format_gemma3_prompt(user_prompt)
|
| 533 |
+
raw = _call_vllm_classifier(prompt)
|
| 534 |
+
# import ipdb; ipdb.set_trace()
|
| 535 |
+
if raw is None:
|
| 536 |
+
if not _CLASSIFIER_ERROR_LOGGED:
|
| 537 |
+
print("Warning: BN classifier vLLM call failed, continuing without it.")
|
| 538 |
+
_CLASSIFIER_ERROR_LOGGED = True
|
| 539 |
+
return ""
|
| 540 |
+
return _parse_classifier_output(raw)
|
| 541 |
+
except Exception as exc:
|
| 542 |
+
if not _CLASSIFIER_ERROR_LOGGED:
|
| 543 |
+
print(f"Warning: BN literacy classifier unavailable, continuing without it: {exc}")
|
| 544 |
+
_CLASSIFIER_ERROR_LOGGED = True
|
| 545 |
+
return ""
|
| 546 |
+
|
| 547 |
+
|
| 548 |
+
def _parse_solution_json(solution_str):
|
| 549 |
+
if isinstance(solution_str, (dict, list)):
|
| 550 |
+
return solution_str
|
| 551 |
+
try:
|
| 552 |
+
cleaned_str = str(solution_str).strip()
|
| 553 |
+
if "```json" in cleaned_str:
|
| 554 |
+
cleaned_str = cleaned_str.split("```json")[1].split("```")[0].strip()
|
| 555 |
+
elif "```" in cleaned_str:
|
| 556 |
+
cleaned_str = cleaned_str.split("```")[1].split("```")[0].strip()
|
| 557 |
+
return json.loads(cleaned_str)
|
| 558 |
+
except Exception:
|
| 559 |
+
return None
|
| 560 |
+
|
| 561 |
+
|
| 562 |
+
def _compute_classifier_reward(target_level: str, gen_text: str, input_text: str) -> float:
|
| 563 |
+
"""
|
| 564 |
+
Soft classifier score in [0, 1] (NOT binary +1/-1).
|
| 565 |
+
|
| 566 |
+
1.0 — predicted label matches target level (correct style)
|
| 567 |
+
0.0 — predicted label does not match (wrong style)
|
| 568 |
+
0.5 — classifier unavailable; neutral / no signal
|
| 569 |
+
|
| 570 |
+
Uses BN classifier via vLLM (Gemma-3); needs input_text (fulltext) and gen_text.
|
| 571 |
+
"""
|
| 572 |
+
result = _predict_label(input_text, gen_text)
|
| 573 |
+
if result == "": # unavailable → neutral, no penalty
|
| 574 |
+
return 0.5
|
| 575 |
+
if result.strip().lower() == target_level.strip().lower():
|
| 576 |
+
# import ipdb; ipdb.set_trace()
|
| 577 |
+
return 1.0 # correct literacy style
|
| 578 |
+
return 0.0 # wrong literacy style (penalty-free cliff avoided)
|
| 579 |
+
|
| 580 |
+
|
| 581 |
+
# ---------------------------------------------------------------------------
|
| 582 |
+
# Copy-paste penalty (prevent trivial copy of input_text)
|
| 583 |
+
# ---------------------------------------------------------------------------
|
| 584 |
+
|
| 585 |
+
def _approx_copy_ratio(input_text: str, gen_text: str) -> float:
|
| 586 |
+
"""
|
| 587 |
+
Rough similarity estimate between input and generated text.
|
| 588 |
+
|
| 589 |
+
- Detects near-verbatim copy via substring + length ratio.
|
| 590 |
+
- Otherwise uses token overlap (gen tokens that also appear in input).
|
| 591 |
+
Returns value in [0, 1], where 1 ≈ almost exact copy.
|
| 592 |
+
"""
|
| 593 |
+
a = (input_text or "").strip()
|
| 594 |
+
b = (gen_text or "").strip()
|
| 595 |
+
if not a or not b:
|
| 596 |
+
return 0.0
|
| 597 |
+
|
| 598 |
+
len_a, len_b = len(a), len(b)
|
| 599 |
+
shorter, longer = (a, b) if len_a <= len_b else (b, a)
|
| 600 |
+
|
| 601 |
+
# Near-verbatim copy: one string almost fully contained in the other.
|
| 602 |
+
if shorter and shorter in longer:
|
| 603 |
+
ratio = len(shorter) / max(1, len(longer))
|
| 604 |
+
if ratio >= 0.9:
|
| 605 |
+
return 1.0
|
| 606 |
+
|
| 607 |
+
# Fallback: 3-gram (trigram) token overlap to reduce false positives
|
| 608 |
+
# from shared medical vocabulary (drug names, symptoms, etc.).
|
| 609 |
+
def _tokens(t: str):
|
| 610 |
+
return [tok for tok in re.split(r"\s+", t) if tok]
|
| 611 |
+
|
| 612 |
+
def _shingles(tokens, n=3):
|
| 613 |
+
if len(tokens) < n:
|
| 614 |
+
return set()
|
| 615 |
+
return {" ".join(tokens[i : i + n]) for i in range(len(tokens) - n + 1)}
|
| 616 |
+
|
| 617 |
+
toks_a = _tokens(a)
|
| 618 |
+
toks_b = _tokens(b)
|
| 619 |
+
if not toks_a or not toks_b:
|
| 620 |
+
return 0.0
|
| 621 |
+
|
| 622 |
+
sh_a = _shingles(toks_a, n=3)
|
| 623 |
+
sh_b = _shingles(toks_b, n=3)
|
| 624 |
+
if not sh_a or not sh_b:
|
| 625 |
+
return 0.0
|
| 626 |
+
|
| 627 |
+
overlap = len(sh_a & sh_b) / max(1, len(sh_b))
|
| 628 |
+
return max(0.0, min(1.0, overlap))
|
| 629 |
+
|
| 630 |
+
|
| 631 |
+
def _compute_copy_penalty(input_text: str, gen_text: str) -> float:
|
| 632 |
+
"""
|
| 633 |
+
Map copy ratio → penalty in [0, 1].
|
| 634 |
+
|
| 635 |
+
- ≤ 0.7 similarity → no penalty
|
| 636 |
+
- 0.7–1.0 → linearly ramp penalty up to 1.0
|
| 637 |
+
"""
|
| 638 |
+
ratio = _approx_copy_ratio(input_text, gen_text)
|
| 639 |
+
if ratio <= 0.7:
|
| 640 |
+
return 0.0
|
| 641 |
+
# Scale [0.7, 1.0] → [0, 1]
|
| 642 |
+
return max(0.0, min(1.0, (ratio - 0.7) / 0.3))
|
| 643 |
+
|
| 644 |
+
|
| 645 |
+
# ---------------------------------------------------------------------------
|
| 646 |
+
# Main scoring function
|
| 647 |
+
# ---------------------------------------------------------------------------
|
| 648 |
+
def _nonlinear_grounding(h_score: float) -> float:
|
| 649 |
+
"""
|
| 650 |
+
Sharper penalty for hallucination.
|
| 651 |
+
|
| 652 |
+
h_score=0.00 → 1.00 (perfect)
|
| 653 |
+
h_score=0.05 → 0.95 (mild)
|
| 654 |
+
h_score=0.10 → 0.82 (noticeable)
|
| 655 |
+
h_score=0.17 → 0.65 (significant — was 0.83 before!)
|
| 656 |
+
h_score=0.30 → 0.36 (harsh)
|
| 657 |
+
h_score=0.50 → 0.13 (near zero)
|
| 658 |
+
"""
|
| 659 |
+
return max(0.0, (1.0 - h_score) ** 2.5)
|
| 660 |
+
def compute_score(data_source, solution_str, ground_truth, extra_info=None):
|
| 661 |
+
"""
|
| 662 |
+
Reward = weighted sum of three components (all in [0, 1]):
|
| 663 |
+
|
| 664 |
+
W_FACTUALITY × factuality_score (summary info present in gen_text)
|
| 665 |
+
W_HALLU × (1 - hallucination_score) (gen_segments grounded in fulltext)
|
| 666 |
+
W_CLASSIFIER × classifier_score (style match)
|
| 667 |
+
|
| 668 |
+
1. Factuality : extract subclaims from *summary*, check how many are
|
| 669 |
+
supported by the generated text.
|
| 670 |
+
2. Hallucination: extract subclaims from *generated text*, check how many
|
| 671 |
+
are NOT supported by the fulltext.
|
| 672 |
+
"""
|
| 673 |
+
W_FACTUALITY = 0.40
|
| 674 |
+
W_HALLU = 0.25
|
| 675 |
+
W_CLASSIFIER = 0.35
|
| 676 |
+
|
| 677 |
+
FAIL = {
|
| 678 |
+
"score": -1.0,
|
| 679 |
+
"factuality_score": 0.0,
|
| 680 |
+
"hallucination_score": 0.0,
|
| 681 |
+
"classifier_score": 0.0,
|
| 682 |
+
"factuality_supported": 0,
|
| 683 |
+
"hallucination_supported": 0,
|
| 684 |
+
"total_gen_segments": 0,
|
| 685 |
+
}
|
| 686 |
+
|
| 687 |
+
# 1. Parse & validate
|
| 688 |
+
data = _parse_solution_json(solution_str)
|
| 689 |
+
if not data:
|
| 690 |
+
return FAIL
|
| 691 |
+
|
| 692 |
+
target_level = extra_info.get("target_level") if extra_info else None
|
| 693 |
+
gen_text = data.get(target_level, "") if target_level else ""
|
| 694 |
+
|
| 695 |
+
if not gen_text or len(gen_text.strip()) < 10:
|
| 696 |
+
return FAIL
|
| 697 |
+
|
| 698 |
+
if not _is_bangla_text(gen_text):
|
| 699 |
+
return FAIL
|
| 700 |
+
|
| 701 |
+
fulltext = ground_truth.get("fulltext") or ground_truth.get("input_text", "")
|
| 702 |
+
input_text = ground_truth.get("input_text", "")
|
| 703 |
+
summary_subclaims = ground_truth.get("summary_subclaims")
|
| 704 |
+
summary_text = ground_truth.get("summary_text", "")
|
| 705 |
+
|
| 706 |
+
# 2. Compute the two core rewards
|
| 707 |
+
rewards = compute_rewards(
|
| 708 |
+
fulltext=fulltext,
|
| 709 |
+
generated_text=gen_text,
|
| 710 |
+
target_level=target_level,
|
| 711 |
+
summary_subclaims=summary_subclaims,
|
| 712 |
+
summary_text=summary_text,
|
| 713 |
+
)
|
| 714 |
+
|
| 715 |
+
factuality_score = rewards["factuality_score"]
|
| 716 |
+
h_score = rewards["hallucination_score"]
|
| 717 |
+
total_gen_units = rewards.get("total_gen_segments", 0)
|
| 718 |
+
|
| 719 |
+
if factuality_score is None:
|
| 720 |
+
factuality_score = 0.5
|
| 721 |
+
if h_score is None:
|
| 722 |
+
h_score = 0.5
|
| 723 |
+
|
| 724 |
+
grounding_score = _nonlinear_grounding(h_score)
|
| 725 |
+
|
| 726 |
+
# 3. Classifier (style match)
|
| 727 |
+
class_score = _compute_classifier_reward(target_level, gen_text, input_text)
|
| 728 |
+
|
| 729 |
+
# 4. Final weighted sum
|
| 730 |
+
final_reward = (
|
| 731 |
+
W_FACTUALITY * factuality_score +
|
| 732 |
+
W_HALLU * grounding_score +
|
| 733 |
+
W_CLASSIFIER * class_score
|
| 734 |
+
)
|
| 735 |
+
|
| 736 |
+
# 5. Copy-paste penalty
|
| 737 |
+
copy_penalty = _compute_copy_penalty(input_text, gen_text)
|
| 738 |
+
if copy_penalty > 0.0:
|
| 739 |
+
final_reward = max(0.0, final_reward * (1.0 - copy_penalty))
|
| 740 |
+
|
| 741 |
+
return {
|
| 742 |
+
"score": float(final_reward),
|
| 743 |
+
"factuality_score": float(factuality_score),
|
| 744 |
+
"hallucination_score": float(h_score),
|
| 745 |
+
"classifier_score": float(class_score),
|
| 746 |
+
"factuality_supported": int(rewards.get("factuality_supported", 0)),
|
| 747 |
+
"hallucination_supported": int(rewards.get("hallucination_supported", 0)),
|
| 748 |
+
"total_gen_segments": int(total_gen_units),
|
| 749 |
+
}
|
| 750 |
+
|
| 751 |
+
|
| 752 |
+
# ---------------------------------------------------------------------------
|
| 753 |
+
# Test mode
|
| 754 |
+
# ---------------------------------------------------------------------------
|
| 755 |
+
|
| 756 |
+
test_mode = True
|
| 757 |
+
if test_mode:
|
| 758 |
+
import time
|
| 759 |
+
|
| 760 |
+
def run_actual_api_test():
|
| 761 |
+
# Bangla medical example (support-check and classifier use Bangla prompts)
|
| 762 |
+
ground_truth = {
|
| 763 |
+
"summary_subclaims": [
|
| 764 |
+
"লিসিনোপ্রিল উচ্চ রক্তচাপের চিকিৎসায় ব্যবহৃত হয়।",
|
| 765 |
+
"এটি ACE ইনহিবিটর যা আপনার হৃদযন্ত্রকে ভালো কাজ করতে সাহায্য করে।",
|
| 766 |
+
"সাধারণ পার্শ্বপ্রতিক্রিয়ার মধ্যে শুষ্ক কাশি রয়েছে।",
|
| 767 |
+
"গর্ভবতী হলে ব্যবহার করবেন না।",
|
| 768 |
+
],
|
| 769 |
+
"input_text": (
|
| 770 |
+
"লিসিনোপ্রিল উচ্চ রক্তচাপের চিকিৎসায় ব্যবহৃত হয়। "
|
| 771 |
+
"এটি ACE ইনহিবিটর নামক ওষুধ। "
|
| 772 |
+
"এটি আপনার হৃদযন্ত্রকে ভালো কাজ করতে সাহায্য করে।"
|
| 773 |
+
),
|
| 774 |
+
"summary_text": (
|
| 775 |
+
"লিসিনোপ্রিল উচ্চ রক্তচাপের চিকিৎসায় ব্যবহৃত হয়। "
|
| 776 |
+
"এটি ACE ইনহিবিটর যা আপনার হৃদযন্ত্রকে ভালো কাজ করতে সাহায্য করে। "
|
| 777 |
+
"সাধারণ পার্শ্বপ্রতিক্রিয়ার মধ্যে শুষ্ক কাশি রয়েছে। "
|
| 778 |
+
"গর্ভবতী হলে ব্যবহার করবেন না।"
|
| 779 |
+
),
|
| 780 |
+
}
|
| 781 |
+
|
| 782 |
+
# LLM output: low_health_literacy style, grounded in summary
|
| 783 |
+
generated_response = {
|
| 784 |
+
"low_health_literacy": (
|
| 785 |
+
"এই ওষুধ আপনার উচ্চ রক্তচাপের জন্য। "
|
| 786 |
+
"এটি ACE ইনহিবিটর ধরনের ওষুধ। "
|
| 787 |
+
"এটি আপনার ��ৃদযন্ত্রকে ভালো কাজ করতে সাহায্য করে। "
|
| 788 |
+
"গর্ভবতী হলে এই ওষুধ খাবেন না।"
|
| 789 |
+
)
|
| 790 |
+
}
|
| 791 |
+
|
| 792 |
+
solution_str = f"```json\n{json.dumps(generated_response)}\n```"
|
| 793 |
+
extra_info = {"target_level": "low_health_literacy"}
|
| 794 |
+
|
| 795 |
+
print("📡 Running BN reward test (Bangla example)...")
|
| 796 |
+
start_time = time.time()
|
| 797 |
+
|
| 798 |
+
try:
|
| 799 |
+
score = compute_score(
|
| 800 |
+
data_source="real_api_test",
|
| 801 |
+
solution_str=solution_str,
|
| 802 |
+
ground_truth=ground_truth,
|
| 803 |
+
extra_info=extra_info,
|
| 804 |
+
)
|
| 805 |
+
|
| 806 |
+
final_score = score["score"] if isinstance(score, dict) else score
|
| 807 |
+
|
| 808 |
+
duration = time.time() - start_time
|
| 809 |
+
print(f"\nAPI Call Successful ({round(duration, 2)}s)")
|
| 810 |
+
print("-" * 50)
|
| 811 |
+
print(f"Target Level : {extra_info['target_level']}")
|
| 812 |
+
print(f"Final Reward : {round(final_score, 4)}")
|
| 813 |
+
print(f"factuality_score : {round(score.get('factuality_score', 0), 4)} (summary subclaims in gen_text)")
|
| 814 |
+
print(f"hallucination_score : {round(score.get('hallucination_score', 0), 4)} (gen_segments NOT in fulltext)")
|
| 815 |
+
print(f"classifier_score : {round(score.get('classifier_score', 0), 4)}")
|
| 816 |
+
print(f"factuality_supported : {score.get('factuality_supported', 0)}")
|
| 817 |
+
print(f"hallucination_supported: {score.get('hallucination_supported', 0)}")
|
| 818 |
+
print(f"total_gen_segments : {score.get('total_gen_segments', 0)}")
|
| 819 |
+
print("-" * 50)
|
| 820 |
+
print("\nReward definitions:")
|
| 821 |
+
print("- factuality_score : fraction of *summary* subclaims supported by gen_text [0,1]")
|
| 822 |
+
print("- hallucination_score : fraction of *gen_segments* NOT supported by fulltext [0,1] (lower=better)")
|
| 823 |
+
print("- classifier_score : 1.0 match, 0.0 mismatch, 0.5 unavailable")
|
| 824 |
+
print("- Weights: factuality=0.35, grounding=0.30, classifier=0.35")
|
| 825 |
+
|
| 826 |
+
except Exception as e:
|
| 827 |
+
print(f"\n❌ API Call Failed!")
|
| 828 |
+
print(f"Error Type: {type(e).__name__}")
|
| 829 |
+
print(f"Details: {str(e)}")
|
| 830 |
+
print("\nPossible fixes:")
|
| 831 |
+
print("1. Check if the support-check vLLM server is running (VLLM_SUPPORT_CHECK_BN_API_BASE).")
|
| 832 |
+
print("2. Check if the classifier vLLM server is running (VLLM_CLASSIFIER_BN_API_BASE).")
|
| 833 |
+
|
| 834 |
+
if __name__ == "__main__":
|
| 835 |
+
run_actual_api_test()
|
code/RL_model/verl/verl_train/script/run_qwen3-8b_v3(2G)_v2.sh
CHANGED
|
@@ -10,7 +10,7 @@ PYTHONUNBUFFERED=1 python3 -m verl.trainer.main_ppo \
|
|
| 10 |
algorithm.adv_estimator=grpo \
|
| 11 |
data.train_files=/home/mshahidul/readctrl/code/RL_model/verl/verl_train/dataset/bn_dataset/train.parquet \
|
| 12 |
data.val_files=/home/mshahidul/readctrl/code/RL_model/verl/verl_train/dataset/bn_dataset/test.parquet \
|
| 13 |
-
custom_reward_function.path=/home/mshahidul/readctrl/code/RL_model/verl/verl_train/reward_func/reward_func/
|
| 14 |
data.train_batch_size=256 \
|
| 15 |
data.max_prompt_length=6144 \
|
| 16 |
data.max_response_length=2048 \
|
|
|
|
| 10 |
algorithm.adv_estimator=grpo \
|
| 11 |
data.train_files=/home/mshahidul/readctrl/code/RL_model/verl/verl_train/dataset/bn_dataset/train.parquet \
|
| 12 |
data.val_files=/home/mshahidul/readctrl/code/RL_model/verl/verl_train/dataset/bn_dataset/test.parquet \
|
| 13 |
+
custom_reward_function.path="/home/mshahidul/readctrl/code/RL_model/verl/verl_train/reward_func/reward_func/reward_new_v6_bn_v4_rmv_src_cov.py" \
|
| 14 |
data.train_batch_size=256 \
|
| 15 |
data.max_prompt_length=6144 \
|
| 16 |
data.max_response_length=2048 \
|
code/RL_model/verl/verl_train/script/run_qwen3-8b_v3(2G)_v3.sh
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
set -x
|
| 2 |
+
|
| 3 |
+
unset PYTORCH_CUDA_ALLOC_CONF
|
| 4 |
+
export EXPERIMENT_NAME=qwen3-4b-instruct-bn
|
| 5 |
+
export WAND_PROJECT='readctrl-verl'
|
| 6 |
+
export CUDA_DEVICE_ORDER="PCI_BUS_ID"
|
| 7 |
+
export CUDA_VISIBLE_DEVICES=1,2
|
| 8 |
+
|
| 9 |
+
PYTHONUNBUFFERED=1 python3 -m verl.trainer.main_ppo \
|
| 10 |
+
algorithm.adv_estimator=grpo \
|
| 11 |
+
data.train_files=/home/mshahidul/readctrl/code/RL_model/verl/verl_train/dataset/bn_dataset/train.parquet \
|
| 12 |
+
data.val_files=/home/mshahidul/readctrl/code/RL_model/verl/verl_train/dataset/bn_dataset/test.parquet \
|
| 13 |
+
custom_reward_function.path="/home/mshahidul/readctrl/code/RL_model/verl/verl_train/reward_func/reward_func/reward_new_v6_bn_v4.py" \
|
| 14 |
+
data.train_batch_size=256 \
|
| 15 |
+
data.max_prompt_length=6144 \
|
| 16 |
+
data.max_response_length=2048 \
|
| 17 |
+
data.filter_overlong_prompts=True \
|
| 18 |
+
data.truncation='error' \
|
| 19 |
+
actor_rollout_ref.model.path=Qwen/Qwen3-4B-Instruct-2507 \
|
| 20 |
+
actor_rollout_ref.actor.optim.lr=1e-6 \
|
| 21 |
+
actor_rollout_ref.model.use_remove_padding=True \
|
| 22 |
+
actor_rollout_ref.actor.ppo_mini_batch_size=256 \
|
| 23 |
+
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
|
| 24 |
+
actor_rollout_ref.actor.use_kl_loss=True \
|
| 25 |
+
actor_rollout_ref.actor.kl_loss_coef=0.001 \
|
| 26 |
+
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
|
| 27 |
+
actor_rollout_ref.actor.entropy_coeff=0 \
|
| 28 |
+
actor_rollout_ref.model.enable_gradient_checkpointing=True \
|
| 29 |
+
actor_rollout_ref.actor.fsdp_config.param_offload=False \
|
| 30 |
+
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
|
| 31 |
+
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \
|
| 32 |
+
actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
|
| 33 |
+
actor_rollout_ref.rollout.name=vllm \
|
| 34 |
+
actor_rollout_ref.rollout.gpu_memory_utilization=0.35 \
|
| 35 |
+
actor_rollout_ref.rollout.enforce_eager=True \
|
| 36 |
+
actor_rollout_ref.rollout.max_model_len=8192 \
|
| 37 |
+
actor_rollout_ref.rollout.n=5 \
|
| 38 |
+
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=2 \
|
| 39 |
+
actor_rollout_ref.ref.fsdp_config.param_offload=True \
|
| 40 |
+
algorithm.use_kl_in_reward=False \
|
| 41 |
+
trainer.critic_warmup=0 \
|
| 42 |
+
trainer.logger='["console","wandb"]' \
|
| 43 |
+
trainer.project_name=$WAND_PROJECT \
|
| 44 |
+
trainer.experiment_name=$EXPERIMENT_NAME \
|
| 45 |
+
trainer.n_gpus_per_node=2 \
|
| 46 |
+
trainer.nnodes=1 \
|
| 47 |
+
trainer.save_freq=20 \
|
| 48 |
+
trainer.test_freq=10 \
|
| 49 |
+
trainer.log_val_generations=1 \
|
| 50 |
+
+trainer.remove_previous_ckpt_in_save=true \
|
| 51 |
+
trainer.max_actor_ckpt_to_keep=1 \
|
| 52 |
+
trainer.max_critic_ckpt_to_keep=1 \
|
| 53 |
+
trainer.resume_mode=auto \
|
| 54 |
+
trainer.default_local_dir=/home/mshahidul/readctrl/code/RL_model/models/reward_new_v6_bn_v4_test2 \
|
| 55 |
+
trainer.total_epochs=45 $@ \
|
| 56 |
+
2>&1 | tee $EXPERIMENT_NAME.log
|
code/RL_model/verl/verl_train/script/run_qwen3-8b_v3(4G)_v2.sh
CHANGED
|
@@ -54,6 +54,6 @@ PYTHONUNBUFFERED=1 python3 -m verl.trainer.main_ppo \
|
|
| 54 |
trainer.max_actor_ckpt_to_keep=1 \
|
| 55 |
trainer.max_critic_ckpt_to_keep=1 \
|
| 56 |
trainer.resume_mode=auto \
|
| 57 |
-
trainer.default_local_dir=/home/mshahidul/readctrl/code/RL_model/models/
|
| 58 |
trainer.total_epochs=45 $@ \
|
| 59 |
2>&1 | tee $EXPERIMENT_NAME.log
|
|
|
|
| 54 |
trainer.max_actor_ckpt_to_keep=1 \
|
| 55 |
trainer.max_critic_ckpt_to_keep=1 \
|
| 56 |
trainer.resume_mode=auto \
|
| 57 |
+
trainer.default_local_dir=/home/mshahidul/readctrl/code/RL_model/models/readCtrl_RL_bn_srcCov_v1 \
|
| 58 |
trainer.total_epochs=45 $@ \
|
| 59 |
2>&1 | tee $EXPERIMENT_NAME.log
|
code/fine_tune_sft_dpo/best_of_n_qwen3_vllm_bn.py
ADDED
|
@@ -0,0 +1,518 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
| 3 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = "6"
|
| 4 |
+
|
| 5 |
+
import argparse
|
| 6 |
+
import json
|
| 7 |
+
import re
|
| 8 |
+
from datetime import datetime
|
| 9 |
+
from typing import Any, Dict, List, Tuple
|
| 10 |
+
|
| 11 |
+
from vllm import LLM, SamplingParams
|
| 12 |
+
from transformers import AutoTokenizer
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def strip_think_blocks(text: str) -> str:
|
| 16 |
+
"""Remove <think>...</think> reasoning blocks from model output."""
|
| 17 |
+
cleaned = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL).strip()
|
| 18 |
+
return cleaned if cleaned else text
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
BASE_DIR = "/home/mshahidul/readctrl/code/fine_tune_sft_dpo"
|
| 22 |
+
FINETUNED_MODEL_DIR = os.path.join(BASE_DIR, "model", "bn")
|
| 23 |
+
PROMPT_DIR = os.path.join(BASE_DIR, "prompt_bn")
|
| 24 |
+
TEST_JSON = os.path.join(BASE_DIR, "dataset", "bn", "test_bn.json")
|
| 25 |
+
RESULTS_DIR = os.path.join(BASE_DIR, "results", "bn")
|
| 26 |
+
|
| 27 |
+
SOURCE_LANG = "Bengali"
|
| 28 |
+
|
| 29 |
+
LABEL_TO_PROMPT_FILE = {
|
| 30 |
+
"low_health_literacy": "prompt_low",
|
| 31 |
+
"intermediate_health_literacy": "prompt_intermediate",
|
| 32 |
+
"proficient_health_literacy": "prompt_proficient",
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
LABEL_TO_READABILITY = {
|
| 36 |
+
"low_health_literacy": (
|
| 37 |
+
"Low Health Literacy (High Readability): individuals needing the simplest "
|
| 38 |
+
"terms for immediate action, using 'living room' language, one idea per "
|
| 39 |
+
"sentence, and focusing only on need-to-know information from the Gold Summary."
|
| 40 |
+
),
|
| 41 |
+
"intermediate_health_literacy": (
|
| 42 |
+
"Intermediate Health Literacy (Medium Readability): the general public at a "
|
| 43 |
+
"news-reading level, with standard vocabulary and some common medical terms, "
|
| 44 |
+
"and a balanced level of detail led by the Gold Summary."
|
| 45 |
+
),
|
| 46 |
+
"proficient_health_literacy": (
|
| 47 |
+
"Proficient Health Literacy (Low Readability): researchers, clinicians, or "
|
| 48 |
+
"highly informed patients, using technical and academic language, high "
|
| 49 |
+
"information density, and full clinical nuance and terminology from the "
|
| 50 |
+
"Source Text."
|
| 51 |
+
),
|
| 52 |
+
}
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def load_prompts(prompt_dir: str) -> Dict[str, str]:
|
| 56 |
+
prompts: Dict[str, str] = {}
|
| 57 |
+
for label, filename in LABEL_TO_PROMPT_FILE.items():
|
| 58 |
+
path = os.path.join(prompt_dir, filename)
|
| 59 |
+
if os.path.isfile(path):
|
| 60 |
+
with open(path, "r", encoding="utf-8") as f:
|
| 61 |
+
prompts[label] = f.read()
|
| 62 |
+
else:
|
| 63 |
+
raise FileNotFoundError(f"Prompt file not found: {path}")
|
| 64 |
+
return prompts
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def build_generation_user_message(
|
| 68 |
+
prompt_template: str,
|
| 69 |
+
full_text: str,
|
| 70 |
+
gold_summary: str,
|
| 71 |
+
source_lang: str = SOURCE_LANG,
|
| 72 |
+
) -> str:
|
| 73 |
+
return (
|
| 74 |
+
prompt_template.replace("{full_text}", full_text)
|
| 75 |
+
.replace("{gold_summary}", gold_summary)
|
| 76 |
+
.replace("{source_lang}", source_lang)
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def build_selection_user_message(
|
| 81 |
+
full_text: str,
|
| 82 |
+
label: str,
|
| 83 |
+
candidates: List[str],
|
| 84 |
+
source_lang: str = SOURCE_LANG,
|
| 85 |
+
) -> str:
|
| 86 |
+
readability = LABEL_TO_READABILITY.get(label, label)
|
| 87 |
+
numbered = []
|
| 88 |
+
for i, cand in enumerate(candidates, start=1):
|
| 89 |
+
numbered.append(f"[{i}]\n{cand.strip()}")
|
| 90 |
+
candidates_block = "\n\n".join(numbered)
|
| 91 |
+
|
| 92 |
+
return (
|
| 93 |
+
"You are selecting the best patient-friendly summary of a medical case.\n\n"
|
| 94 |
+
f"Original text ({source_lang}):\n{full_text}\n\n"
|
| 95 |
+
f"Readability requirement: {readability}.\n\n"
|
| 96 |
+
f"Here are {len(candidates)} candidate summaries:\n\n"
|
| 97 |
+
f"{candidates_block}\n\n"
|
| 98 |
+
"Choose the single candidate that best matches the readability "
|
| 99 |
+
"requirement and accurately reflects the key clinical information.\n"
|
| 100 |
+
"Answer with exactly one line in the form:\n"
|
| 101 |
+
'"BEST_INDEX: k"\n'
|
| 102 |
+
f"where k is an integer from 1 to {len(candidates)}."
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def parse_best_index(text: str, num_candidates: int) -> int:
|
| 107 |
+
# Look for an integer in the model output; default to 1 if parsing fails.
|
| 108 |
+
match = re.search(r"(\d+)", text)
|
| 109 |
+
if not match:
|
| 110 |
+
return 1
|
| 111 |
+
idx = int(match.group(1))
|
| 112 |
+
if idx < 1 or idx > num_candidates:
|
| 113 |
+
return 1
|
| 114 |
+
return idx
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def build_generation_prompts_for_model(
|
| 118 |
+
tokenizer,
|
| 119 |
+
test_list: List[Dict[str, Any]],
|
| 120 |
+
prompts: Dict[str, str],
|
| 121 |
+
source_lang: str = SOURCE_LANG,
|
| 122 |
+
) -> Tuple[List[str], List[Dict[str, Any]]]:
|
| 123 |
+
batched_prompts: List[str] = []
|
| 124 |
+
meta: List[Dict[str, Any]] = []
|
| 125 |
+
|
| 126 |
+
for idx, item in enumerate(test_list):
|
| 127 |
+
label = item.get("label")
|
| 128 |
+
doc_id = item.get("doc_id", idx)
|
| 129 |
+
fulltext = item.get("fulltext", "")
|
| 130 |
+
summary = item.get("summary", "")
|
| 131 |
+
gold_gen_text = item.get("gen_text", "")
|
| 132 |
+
|
| 133 |
+
if label not in prompts:
|
| 134 |
+
meta.append(
|
| 135 |
+
{
|
| 136 |
+
"doc_id": doc_id,
|
| 137 |
+
"label": label,
|
| 138 |
+
"gold_gen_text": gold_gen_text,
|
| 139 |
+
"fulltext": fulltext,
|
| 140 |
+
"summary": summary,
|
| 141 |
+
"error": f"Unknown label: {label}",
|
| 142 |
+
}
|
| 143 |
+
)
|
| 144 |
+
batched_prompts.append(None) # type: ignore[arg-type]
|
| 145 |
+
continue
|
| 146 |
+
|
| 147 |
+
user_prompt = build_generation_user_message(
|
| 148 |
+
prompts[label],
|
| 149 |
+
fulltext,
|
| 150 |
+
summary,
|
| 151 |
+
source_lang=source_lang,
|
| 152 |
+
)
|
| 153 |
+
chat = [{"role": "user", "content": user_prompt}]
|
| 154 |
+
formatted = tokenizer.apply_chat_template(
|
| 155 |
+
chat,
|
| 156 |
+
tokenize=False,
|
| 157 |
+
add_generation_prompt=True,
|
| 158 |
+
enable_thinking=False,
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
+
batched_prompts.append(formatted)
|
| 162 |
+
meta.append(
|
| 163 |
+
{
|
| 164 |
+
"doc_id": doc_id,
|
| 165 |
+
"label": label,
|
| 166 |
+
"gold_gen_text": gold_gen_text,
|
| 167 |
+
"fulltext": fulltext,
|
| 168 |
+
"summary": summary,
|
| 169 |
+
"error": None,
|
| 170 |
+
}
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
return batched_prompts, meta
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def run_best_of_n_for_model(
|
| 177 |
+
model_id: str,
|
| 178 |
+
model_key: str,
|
| 179 |
+
test_list: List[Dict[str, Any]],
|
| 180 |
+
prompts: Dict[str, str],
|
| 181 |
+
max_new_tokens: int,
|
| 182 |
+
temperature: float,
|
| 183 |
+
num_candidates: int,
|
| 184 |
+
batch_size: int,
|
| 185 |
+
source_lang: str = SOURCE_LANG,
|
| 186 |
+
) -> Dict[int, Dict[str, Any]]:
|
| 187 |
+
print(f"\n=== Running model {model_key}: {model_id} ===")
|
| 188 |
+
|
| 189 |
+
print("Loading tokenizer...")
|
| 190 |
+
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
| 191 |
+
|
| 192 |
+
print("Preparing prompts...")
|
| 193 |
+
batched_prompts, meta = build_generation_prompts_for_model(
|
| 194 |
+
tokenizer, test_list, prompts, source_lang=source_lang
|
| 195 |
+
)
|
| 196 |
+
|
| 197 |
+
print("Loading vLLM model...")
|
| 198 |
+
llm = LLM(
|
| 199 |
+
model=model_id,
|
| 200 |
+
trust_remote_code=True,
|
| 201 |
+
)
|
| 202 |
+
|
| 203 |
+
gen_sampling_params = SamplingParams(
|
| 204 |
+
temperature=temperature,
|
| 205 |
+
max_tokens=max_new_tokens,
|
| 206 |
+
n=num_candidates,
|
| 207 |
+
)
|
| 208 |
+
|
| 209 |
+
# Filter out None prompts (unknown labels) for generation
|
| 210 |
+
valid_indices = [i for i, p in enumerate(batched_prompts) if p is not None]
|
| 211 |
+
valid_prompts = [batched_prompts[i] for i in valid_indices]
|
| 212 |
+
|
| 213 |
+
total_valid = len(valid_prompts)
|
| 214 |
+
batch_size = max(1, batch_size)
|
| 215 |
+
print(
|
| 216 |
+
f"Running vLLM generation on {total_valid} samples "
|
| 217 |
+
f"in batches of {batch_size} with Best-of-{num_candidates}..."
|
| 218 |
+
)
|
| 219 |
+
|
| 220 |
+
candidates_per_idx: Dict[int, List[str]] = {}
|
| 221 |
+
|
| 222 |
+
num_batches = (total_valid + batch_size - 1) // batch_size
|
| 223 |
+
for batch_idx in range(num_batches):
|
| 224 |
+
start = batch_idx * batch_size
|
| 225 |
+
end = min(start + batch_size, total_valid)
|
| 226 |
+
batch_prompts = valid_prompts[start:end]
|
| 227 |
+
batch_indices = valid_indices[start:end]
|
| 228 |
+
|
| 229 |
+
print(
|
| 230 |
+
f"Generating batch {batch_idx + 1}/{num_batches} "
|
| 231 |
+
f"with {len(batch_prompts)} samples..."
|
| 232 |
+
)
|
| 233 |
+
outputs = llm.generate(batch_prompts, sampling_params=gen_sampling_params)
|
| 234 |
+
|
| 235 |
+
for idx_in_batch, output in enumerate(outputs):
|
| 236 |
+
original_idx = batch_indices[idx_in_batch]
|
| 237 |
+
cand_texts = [strip_think_blocks(o.text) for o in output.outputs]
|
| 238 |
+
candidates_per_idx[original_idx] = cand_texts
|
| 239 |
+
|
| 240 |
+
# Now build selection prompts to choose the best candidate for each valid sample.
|
| 241 |
+
print("Building selection prompts for Best-of-N choice...")
|
| 242 |
+
selection_prompts: List[str] = []
|
| 243 |
+
selection_indices: List[int] = []
|
| 244 |
+
reverse_map: Dict[int, int] = {}
|
| 245 |
+
|
| 246 |
+
for original_idx in valid_indices:
|
| 247 |
+
info = meta[original_idx]
|
| 248 |
+
if info["error"] is not None:
|
| 249 |
+
continue
|
| 250 |
+
cands = candidates_per_idx.get(original_idx, [])
|
| 251 |
+
if not cands:
|
| 252 |
+
continue
|
| 253 |
+
sel_user = build_selection_user_message(
|
| 254 |
+
info["fulltext"],
|
| 255 |
+
info["label"],
|
| 256 |
+
cands,
|
| 257 |
+
source_lang=source_lang,
|
| 258 |
+
)
|
| 259 |
+
chat = [{"role": "user", "content": sel_user}]
|
| 260 |
+
formatted = tokenizer.apply_chat_template(
|
| 261 |
+
chat,
|
| 262 |
+
tokenize=False,
|
| 263 |
+
add_generation_prompt=True,
|
| 264 |
+
enable_thinking=False,
|
| 265 |
+
)
|
| 266 |
+
reverse_map[len(selection_prompts)] = original_idx
|
| 267 |
+
selection_prompts.append(formatted)
|
| 268 |
+
|
| 269 |
+
select_sampling_params = SamplingParams(
|
| 270 |
+
temperature=0.0,
|
| 271 |
+
max_tokens=32,
|
| 272 |
+
n=1,
|
| 273 |
+
)
|
| 274 |
+
|
| 275 |
+
best_index_per_idx: Dict[int, int] = {}
|
| 276 |
+
|
| 277 |
+
total_select = len(selection_prompts)
|
| 278 |
+
if total_select > 0:
|
| 279 |
+
print(
|
| 280 |
+
f"Running selection passes on {total_select} samples "
|
| 281 |
+
f"in batches of {batch_size}..."
|
| 282 |
+
)
|
| 283 |
+
num_sel_batches = (total_select + batch_size - 1) // batch_size
|
| 284 |
+
for batch_idx in range(num_sel_batches):
|
| 285 |
+
start = batch_idx * batch_size
|
| 286 |
+
end = min(start + batch_size, total_select)
|
| 287 |
+
batch_prompts = selection_prompts[start:end]
|
| 288 |
+
|
| 289 |
+
print(
|
| 290 |
+
f"Selecting batch {batch_idx + 1}/{num_sel_batches} "
|
| 291 |
+
f"with {len(batch_prompts)} samples..."
|
| 292 |
+
)
|
| 293 |
+
outputs = llm.generate(
|
| 294 |
+
batch_prompts, sampling_params=select_sampling_params
|
| 295 |
+
)
|
| 296 |
+
|
| 297 |
+
for idx_in_batch, output in enumerate(outputs):
|
| 298 |
+
global_sel_idx = start + idx_in_batch
|
| 299 |
+
original_idx = reverse_map[global_sel_idx]
|
| 300 |
+
raw_text = strip_think_blocks(output.outputs[0].text)
|
| 301 |
+
best_idx = parse_best_index(raw_text, num_candidates)
|
| 302 |
+
best_index_per_idx[original_idx] = best_idx
|
| 303 |
+
|
| 304 |
+
# Build structured results per original index.
|
| 305 |
+
model_results: Dict[int, Dict[str, Any]] = {}
|
| 306 |
+
for idx, info in enumerate(meta):
|
| 307 |
+
if info["error"] is not None:
|
| 308 |
+
model_results[idx] = {
|
| 309 |
+
"error": info["error"],
|
| 310 |
+
}
|
| 311 |
+
continue
|
| 312 |
+
|
| 313 |
+
cands = candidates_per_idx.get(idx, [])
|
| 314 |
+
best_idx = best_index_per_idx.get(idx, 1 if cands else None)
|
| 315 |
+
best_summary = (
|
| 316 |
+
cands[best_idx - 1] if cands and best_idx is not None and 1 <= best_idx <= len(cands) else ""
|
| 317 |
+
)
|
| 318 |
+
|
| 319 |
+
model_results[idx] = {
|
| 320 |
+
"candidates": cands,
|
| 321 |
+
"best_index": best_idx,
|
| 322 |
+
"best_summary": best_summary,
|
| 323 |
+
}
|
| 324 |
+
|
| 325 |
+
return model_results
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
def parse_args():
|
| 329 |
+
p = argparse.ArgumentParser(
|
| 330 |
+
description=(
|
| 331 |
+
"Run vLLM inference with Best-of-N for both the finetuned "
|
| 332 |
+
"Qwen3 model and the base Qwen/Qwen3-4B-Instruct-2507 model "
|
| 333 |
+
"on test_bn.json (Bengali)."
|
| 334 |
+
)
|
| 335 |
+
)
|
| 336 |
+
p.add_argument(
|
| 337 |
+
"--prompt-dir",
|
| 338 |
+
type=str,
|
| 339 |
+
default=PROMPT_DIR,
|
| 340 |
+
help="Directory containing prompt files (prompt_low, prompt_intermediate, prompt_proficient).",
|
| 341 |
+
)
|
| 342 |
+
p.add_argument(
|
| 343 |
+
"--finetuned-model-dir",
|
| 344 |
+
type=str,
|
| 345 |
+
default=FINETUNED_MODEL_DIR,
|
| 346 |
+
help="Path to the merged finetuned model directory.",
|
| 347 |
+
)
|
| 348 |
+
p.add_argument(
|
| 349 |
+
"--test-data",
|
| 350 |
+
type=str,
|
| 351 |
+
default=TEST_JSON,
|
| 352 |
+
help="Path to the test data JSON file.",
|
| 353 |
+
)
|
| 354 |
+
p.add_argument(
|
| 355 |
+
"--src-lang",
|
| 356 |
+
type=str,
|
| 357 |
+
default=SOURCE_LANG,
|
| 358 |
+
help="Source language of the text (e.g. Bengali, English).",
|
| 359 |
+
)
|
| 360 |
+
p.add_argument(
|
| 361 |
+
"--base-model-id",
|
| 362 |
+
type=str,
|
| 363 |
+
default="Qwen/Qwen3-4B-Instruct-2507",
|
| 364 |
+
help="Hugging Face model id for the base Qwen3 instruct model.",
|
| 365 |
+
)
|
| 366 |
+
p.add_argument(
|
| 367 |
+
"--max-new-tokens",
|
| 368 |
+
type=int,
|
| 369 |
+
default=512,
|
| 370 |
+
help="Maximum number of new tokens to generate per candidate.",
|
| 371 |
+
)
|
| 372 |
+
p.add_argument(
|
| 373 |
+
"--temperature",
|
| 374 |
+
type=float,
|
| 375 |
+
default=0.7,
|
| 376 |
+
help="Sampling temperature for candidate generation.",
|
| 377 |
+
)
|
| 378 |
+
p.add_argument(
|
| 379 |
+
"--num-candidates",
|
| 380 |
+
type=int,
|
| 381 |
+
default=5,
|
| 382 |
+
help="Number of candidate summaries to generate per example (N in Best-of-N).",
|
| 383 |
+
)
|
| 384 |
+
p.add_argument(
|
| 385 |
+
"--batch-size",
|
| 386 |
+
type=int,
|
| 387 |
+
default=16,
|
| 388 |
+
help="Batch size for vLLM generation.",
|
| 389 |
+
)
|
| 390 |
+
p.add_argument(
|
| 391 |
+
"--output-file",
|
| 392 |
+
type=str,
|
| 393 |
+
default=None,
|
| 394 |
+
help=(
|
| 395 |
+
"Optional path for the main results JSON file. "
|
| 396 |
+
"If not set, a timestamped name in the results directory is used."
|
| 397 |
+
),
|
| 398 |
+
)
|
| 399 |
+
p.add_argument(
|
| 400 |
+
"--model",
|
| 401 |
+
type=str,
|
| 402 |
+
choices=["base", "finetuned", "both"],
|
| 403 |
+
default="both",
|
| 404 |
+
help=(
|
| 405 |
+
"Which model(s) to run: 'base' (Qwen3-4B-Instruct), "
|
| 406 |
+
"'finetuned' (local SFT model), or 'both' (default)."
|
| 407 |
+
),
|
| 408 |
+
)
|
| 409 |
+
return p.parse_args()
|
| 410 |
+
|
| 411 |
+
|
| 412 |
+
def main():
|
| 413 |
+
args = parse_args()
|
| 414 |
+
|
| 415 |
+
os.makedirs(RESULTS_DIR, exist_ok=True)
|
| 416 |
+
|
| 417 |
+
print("Loading prompts from", args.prompt_dir)
|
| 418 |
+
prompts = load_prompts(args.prompt_dir)
|
| 419 |
+
|
| 420 |
+
print("Loading test data from", args.test_data)
|
| 421 |
+
with open(args.test_data, "r", encoding="utf-8") as f:
|
| 422 |
+
test_list = json.load(f)
|
| 423 |
+
|
| 424 |
+
# Run Best-of-N for the selected model(s), one at a time to save GPU memory.
|
| 425 |
+
finetuned_results: Dict[int, Dict[str, Any]] = {}
|
| 426 |
+
base_results: Dict[int, Dict[str, Any]] = {}
|
| 427 |
+
|
| 428 |
+
if args.model in ("finetuned", "both"):
|
| 429 |
+
finetuned_results = run_best_of_n_for_model(
|
| 430 |
+
model_id=args.finetuned_model_dir,
|
| 431 |
+
model_key="qwen3_finetuned",
|
| 432 |
+
test_list=test_list,
|
| 433 |
+
prompts=prompts,
|
| 434 |
+
max_new_tokens=args.max_new_tokens,
|
| 435 |
+
temperature=args.temperature,
|
| 436 |
+
num_candidates=args.num_candidates,
|
| 437 |
+
batch_size=args.batch_size,
|
| 438 |
+
source_lang=args.src_lang,
|
| 439 |
+
)
|
| 440 |
+
|
| 441 |
+
if args.model in ("base", "both"):
|
| 442 |
+
base_results = run_best_of_n_for_model(
|
| 443 |
+
model_id=args.base_model_id,
|
| 444 |
+
model_key="qwen3_base",
|
| 445 |
+
test_list=test_list,
|
| 446 |
+
prompts=prompts,
|
| 447 |
+
max_new_tokens=args.max_new_tokens,
|
| 448 |
+
temperature=args.temperature,
|
| 449 |
+
num_candidates=args.num_candidates,
|
| 450 |
+
batch_size=args.batch_size,
|
| 451 |
+
source_lang=args.src_lang,
|
| 452 |
+
)
|
| 453 |
+
|
| 454 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 455 |
+
if args.output_file:
|
| 456 |
+
out_path = args.output_file
|
| 457 |
+
base, ext = os.path.splitext(out_path)
|
| 458 |
+
if not ext:
|
| 459 |
+
out_path = base + ".json"
|
| 460 |
+
base = out_path.rsplit(".", 1)[0]
|
| 461 |
+
summary_path = base + "_summary.json"
|
| 462 |
+
else:
|
| 463 |
+
out_path = os.path.join(RESULTS_DIR, f"test_best_of_n_vllm_{timestamp}.json")
|
| 464 |
+
summary_path = os.path.join(
|
| 465 |
+
RESULTS_DIR, f"inference_best_of_n_vllm_{timestamp}.json"
|
| 466 |
+
)
|
| 467 |
+
|
| 468 |
+
combined_results = []
|
| 469 |
+
for idx, item in enumerate(test_list):
|
| 470 |
+
label = item.get("label")
|
| 471 |
+
doc_id = item.get("doc_id", idx)
|
| 472 |
+
gold_gen_text = item.get("gen_text", "")
|
| 473 |
+
|
| 474 |
+
entry: Dict[str, Any] = {
|
| 475 |
+
"doc_id": doc_id,
|
| 476 |
+
"label": label,
|
| 477 |
+
"gold_gen_text": gold_gen_text,
|
| 478 |
+
"predicted_label": item.get("predicted_label", ""),
|
| 479 |
+
"prediction_correct": item.get("prediction_correct", None),
|
| 480 |
+
}
|
| 481 |
+
|
| 482 |
+
if args.model in ("finetuned", "both"):
|
| 483 |
+
entry["qwen3_finetuned"] = finetuned_results.get(idx, {})
|
| 484 |
+
if args.model in ("base", "both"):
|
| 485 |
+
entry["qwen3_base"] = base_results.get(idx, {})
|
| 486 |
+
|
| 487 |
+
combined_results.append(entry)
|
| 488 |
+
|
| 489 |
+
with open(out_path, "w", encoding="utf-8") as f:
|
| 490 |
+
json.dump(combined_results, f, ensure_ascii=False, indent=2)
|
| 491 |
+
|
| 492 |
+
summary_data: Dict[str, Any] = {
|
| 493 |
+
"model_run": args.model,
|
| 494 |
+
"test_json": args.test_data,
|
| 495 |
+
"prompt_dir": args.prompt_dir,
|
| 496 |
+
"src_lang": args.src_lang,
|
| 497 |
+
"num_test_samples": len(test_list),
|
| 498 |
+
"results_file": out_path,
|
| 499 |
+
"timestamp": timestamp,
|
| 500 |
+
"max_new_tokens": args.max_new_tokens,
|
| 501 |
+
"temperature": args.temperature,
|
| 502 |
+
"num_candidates": args.num_candidates,
|
| 503 |
+
}
|
| 504 |
+
if args.model in ("finetuned", "both"):
|
| 505 |
+
summary_data["finetuned_model_dir"] = args.finetuned_model_dir
|
| 506 |
+
if args.model in ("base", "both"):
|
| 507 |
+
summary_data["base_model_id"] = args.base_model_id
|
| 508 |
+
|
| 509 |
+
with open(summary_path, "w", encoding="utf-8") as f:
|
| 510 |
+
json.dump(summary_data, f, ensure_ascii=False, indent=2)
|
| 511 |
+
|
| 512 |
+
print(f"\nResults saved to {out_path}")
|
| 513 |
+
print(f"Summary saved to {summary_path}")
|
| 514 |
+
|
| 515 |
+
|
| 516 |
+
if __name__ == "__main__":
|
| 517 |
+
main()
|
| 518 |
+
|
code/fine_tune_sft_dpo/dataset/bn/old/test_bn_subclaims.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f99326a350c42c5012c10e6898e892ac2962230f627f91a1da814fe8a8f79bba
|
| 3 |
+
size 2249546
|
code/fine_tune_sft_dpo/dataset/bn/test_bn_subclaims.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6f58aca4d71f46ded80cabe51e3f6b96c0774eae5d4680e46eec0cadefe121e9
|
| 3 |
+
size 5741053
|
code/fine_tune_sft_dpo/eval.sh
CHANGED
|
@@ -1,2 +0,0 @@
|
|
| 1 |
-
python /home/mshahidul/readctrl/code/fine_tune_sft_dpo/test_classifier_with_subclaim_thresholds.py \
|
| 2 |
-
--input-file /home/mshahidul/readctrl/code/fine_tune_sft_dpo/results/en/test_best_of_n_qwen3-4B_sft.json
|
|
|
|
|
|
|
|
|
code/fine_tune_sft_dpo/evaluate_scores_bn.py
ADDED
|
@@ -0,0 +1,554 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Standalone evaluation script for computing factuality, hallucination, and
|
| 4 |
+
classifier scores on a JSON file.
|
| 5 |
+
|
| 6 |
+
Supports two input formats:
|
| 7 |
+
|
| 8 |
+
1. **Standard format** — a list of objects, each with:
|
| 9 |
+
- fulltext, summary_text, summary_subclaims, generated_text, label
|
| 10 |
+
|
| 11 |
+
2. **Best-of-N (BON) format** — a list of objects, each with:
|
| 12 |
+
- doc_id, label, qwen3_base.best_summary (JSON-wrapped generated text)
|
| 13 |
+
Requires a separate --subclaims file to supply fulltext, summary,
|
| 14 |
+
summary_subclaims, and fulltext_subclaims (keyed by doc_id).
|
| 15 |
+
|
| 16 |
+
3. **Inference format** — a list of objects, each with:
|
| 17 |
+
- doc_id, label, predicted_gen_text (JSON-wrapped evaluated summary),
|
| 18 |
+
optionally gold_gen_text
|
| 19 |
+
predicted_gen_text is the summary to evaluate (same JSON key-by-label
|
| 20 |
+
format as best_summary). Requires --subclaims for fulltext and subclaims.
|
| 21 |
+
|
| 22 |
+
4. **Self-refine format** — a list of objects, each with:
|
| 23 |
+
- doc_id, label, final_summary (the generated text to evaluate),
|
| 24 |
+
optionally gold_gen_text, gold_summary
|
| 25 |
+
final_summary is the summary to evaluate (plain text or JSON-wrapped by
|
| 26 |
+
label). Requires --subclaims for fulltext and subclaims.
|
| 27 |
+
|
| 28 |
+
The script reuses the reward functions from reward_new_v6_bn_v4_rmv_src_cov.py:
|
| 29 |
+
- factuality_score : fraction of summary subclaims supported by generated_text
|
| 30 |
+
- hallucination_score: fraction of gen subclaims NOT supported by fulltext
|
| 31 |
+
- classifier_score : whether generated_text matches the target literacy level
|
| 32 |
+
|
| 33 |
+
Requires the same vLLM endpoints as the reward file:
|
| 34 |
+
- Support checker : VLLM_SUPPORT_CHECK_BN_API_BASE (default http://localhost:8090/v1)
|
| 35 |
+
- Classifier : VLLM_CLASSIFIER_BN_API_BASE (default http://localhost:8040/v1)
|
| 36 |
+
- Subclaim extractor: VLLM_SUBCLAIM_EXTRACTOR_BN_API_BASE (default http://localhost:8050/v1)
|
| 37 |
+
|
| 38 |
+
Usage:
|
| 39 |
+
# Standard format
|
| 40 |
+
python evaluate_scores.py --input data.json [--output results.json]
|
| 41 |
+
|
| 42 |
+
# BON format with subclaims file
|
| 43 |
+
python evaluate_scores.py --input bon_results.json --subclaims subclaims.json --output-dir evaluation/bn/
|
| 44 |
+
|
| 45 |
+
# Inference format (predicted_gen_text as evaluated summary)
|
| 46 |
+
python evaluate_scores.py --input test_inference_vllm_qwen3-4B_base.json --subclaims subclaims.json --output results.json
|
| 47 |
+
|
| 48 |
+
# Self-refine format (final_summary as evaluated summary)
|
| 49 |
+
python evaluate_scores.py --input test_self_refine_vllm_qwen3_4B_base.json --subclaims subclaims.json --output-dir evaluation/bn/
|
| 50 |
+
"""
|
| 51 |
+
|
| 52 |
+
import argparse
|
| 53 |
+
import json
|
| 54 |
+
import os
|
| 55 |
+
import re
|
| 56 |
+
import sys
|
| 57 |
+
import time
|
| 58 |
+
from typing import Any, Dict, List, Optional
|
| 59 |
+
|
| 60 |
+
from tqdm import tqdm
|
| 61 |
+
|
| 62 |
+
# Import scoring utilities from the reward module (same directory).
|
| 63 |
+
from reward_new_v6_bn_v4_rmv_src_cov import (
|
| 64 |
+
_call_support_api,
|
| 65 |
+
_compute_classifier_reward,
|
| 66 |
+
_extract_subclaims_from_text,
|
| 67 |
+
_is_bangla_text,
|
| 68 |
+
_nonlinear_grounding,
|
| 69 |
+
compute_rewards,
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def extract_text_from_best_summary(best_summary: str, label: str) -> str:
|
| 74 |
+
"""Extract the raw generated text from a BON best_summary string.
|
| 75 |
+
|
| 76 |
+
The best_summary is a (possibly truncated) JSON string like:
|
| 77 |
+
'{"proficient_health_literacy": "...actual text..."}'
|
| 78 |
+
We locate the value after the label key and strip JSON wrapping.
|
| 79 |
+
"""
|
| 80 |
+
key_pattern = re.compile(re.escape(f'"{label}"') + r'\s*:\s*"')
|
| 81 |
+
m = key_pattern.search(best_summary)
|
| 82 |
+
if not m:
|
| 83 |
+
return best_summary.strip()
|
| 84 |
+
text = best_summary[m.end():]
|
| 85 |
+
if text.endswith('"\n}'):
|
| 86 |
+
text = text[:-3]
|
| 87 |
+
elif text.endswith('"}\n'):
|
| 88 |
+
text = text[:-3]
|
| 89 |
+
elif text.endswith('"}'):
|
| 90 |
+
text = text[:-2]
|
| 91 |
+
elif text.endswith('"'):
|
| 92 |
+
text = text[:-1]
|
| 93 |
+
text = text.replace("\\n", "\n").replace('\\"', '"')
|
| 94 |
+
return text.strip()
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def prepare_bon_items(
|
| 98 |
+
bon_data: List[Dict[str, Any]],
|
| 99 |
+
subclaims_data: List[Dict[str, Any]],
|
| 100 |
+
model_key: str = "qwen3_base",
|
| 101 |
+
) -> List[Dict[str, Any]]:
|
| 102 |
+
"""Merge BON results with subclaims data into the standard evaluation format."""
|
| 103 |
+
sc_by_docid = {}
|
| 104 |
+
for item in subclaims_data:
|
| 105 |
+
sc_by_docid[item["doc_id"]] = item
|
| 106 |
+
|
| 107 |
+
prepared = []
|
| 108 |
+
for item in bon_data:
|
| 109 |
+
doc_id = item["doc_id"]
|
| 110 |
+
label = item["label"]
|
| 111 |
+
sc = sc_by_docid.get(doc_id, {})
|
| 112 |
+
|
| 113 |
+
model_data = item.get(model_key, {})
|
| 114 |
+
best_summary = model_data.get("best_summary", "") or model_data.get("predicted_gen_text", "")
|
| 115 |
+
generated_text = extract_text_from_best_summary(best_summary, label)
|
| 116 |
+
|
| 117 |
+
prepared.append({
|
| 118 |
+
"doc_id": doc_id,
|
| 119 |
+
"label": label,
|
| 120 |
+
"fulltext": sc.get("fulltext", ""),
|
| 121 |
+
"summary_text": sc.get("summary", ""),
|
| 122 |
+
"summary_subclaims": sc.get("summary_subclaims", []),
|
| 123 |
+
"fulltext_subclaims": sc.get("fulltext_subclaims", []),
|
| 124 |
+
"generated_text": generated_text,
|
| 125 |
+
"gold_gen_text": item.get("gold_gen_text", ""),
|
| 126 |
+
"predicted_label": item.get("predicted_label", ""),
|
| 127 |
+
"prediction_correct": item.get("prediction_correct", False),
|
| 128 |
+
})
|
| 129 |
+
return prepared
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def prepare_inference_items(
|
| 133 |
+
inference_data: List[Dict[str, Any]],
|
| 134 |
+
subclaims_data: List[Dict[str, Any]],
|
| 135 |
+
) -> List[Dict[str, Any]]:
|
| 136 |
+
"""Merge inference-format results (doc_id, label, predicted_gen_text) with
|
| 137 |
+
subclaims data into the standard evaluation format. predicted_gen_text is
|
| 138 |
+
the JSON-wrapped evaluated summary; the raw text is extracted using the
|
| 139 |
+
item's label.
|
| 140 |
+
"""
|
| 141 |
+
sc_by_docid = {}
|
| 142 |
+
for item in subclaims_data:
|
| 143 |
+
sc_by_docid[item["doc_id"]] = item
|
| 144 |
+
|
| 145 |
+
prepared = []
|
| 146 |
+
for item in inference_data:
|
| 147 |
+
doc_id = item["doc_id"]
|
| 148 |
+
label = item["label"]
|
| 149 |
+
sc = sc_by_docid.get(doc_id, {})
|
| 150 |
+
|
| 151 |
+
raw_pred = item.get("predicted_gen_text", "") or ""
|
| 152 |
+
generated_text = extract_text_from_best_summary(raw_pred, label)
|
| 153 |
+
|
| 154 |
+
prepared.append({
|
| 155 |
+
"doc_id": doc_id,
|
| 156 |
+
"label": label,
|
| 157 |
+
"fulltext": sc.get("fulltext", ""),
|
| 158 |
+
"summary_text": sc.get("summary", ""),
|
| 159 |
+
"summary_subclaims": sc.get("summary_subclaims", []),
|
| 160 |
+
"fulltext_subclaims": sc.get("fulltext_subclaims", []),
|
| 161 |
+
"generated_text": generated_text,
|
| 162 |
+
"gold_gen_text": item.get("gold_gen_text", ""),
|
| 163 |
+
})
|
| 164 |
+
return prepared
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def prepare_self_refine_items(
|
| 168 |
+
self_refine_data: List[Dict[str, Any]],
|
| 169 |
+
subclaims_data: List[Dict[str, Any]],
|
| 170 |
+
) -> List[Dict[str, Any]]:
|
| 171 |
+
"""Merge self-refine format (doc_id, label, final_summary) with subclaims
|
| 172 |
+
data. final_summary is the generated text to evaluate (plain text or
|
| 173 |
+
JSON-wrapped by label); it is extracted and used as generated_text.
|
| 174 |
+
"""
|
| 175 |
+
sc_by_docid = {}
|
| 176 |
+
for item in subclaims_data:
|
| 177 |
+
sc_by_docid[item["doc_id"]] = item
|
| 178 |
+
|
| 179 |
+
prepared = []
|
| 180 |
+
for item in self_refine_data:
|
| 181 |
+
doc_id = item["doc_id"]
|
| 182 |
+
label = item["label"]
|
| 183 |
+
sc = sc_by_docid.get(doc_id, {})
|
| 184 |
+
|
| 185 |
+
raw_final = item.get("final_summary", "") or ""
|
| 186 |
+
generated_text = extract_text_from_best_summary(raw_final, label)
|
| 187 |
+
|
| 188 |
+
prepared.append({
|
| 189 |
+
"doc_id": doc_id,
|
| 190 |
+
"label": label,
|
| 191 |
+
"fulltext": sc.get("fulltext", ""),
|
| 192 |
+
"summary_text": sc.get("summary", ""),
|
| 193 |
+
"summary_subclaims": sc.get("summary_subclaims", []),
|
| 194 |
+
"fulltext_subclaims": sc.get("fulltext_subclaims", []),
|
| 195 |
+
"generated_text": generated_text,
|
| 196 |
+
"gold_gen_text": item.get("gold_gen_text", ""),
|
| 197 |
+
})
|
| 198 |
+
return prepared
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def evaluate_single(
|
| 202 |
+
item: Dict[str, Any],
|
| 203 |
+
target_level_override: Optional[str] = None,
|
| 204 |
+
) -> Dict[str, Any]:
|
| 205 |
+
"""
|
| 206 |
+
Evaluate a single item and return detailed scores.
|
| 207 |
+
"""
|
| 208 |
+
fulltext = item.get("fulltext", "")
|
| 209 |
+
summary_text = item.get("summary_text") or item.get("summary", "")
|
| 210 |
+
summary_subclaims = item.get("summary_subclaims", [])
|
| 211 |
+
generated_text = item.get("generated_text") or item.get("predicted_gen_text", "")
|
| 212 |
+
target_level = target_level_override or item.get("label", "")
|
| 213 |
+
|
| 214 |
+
result: Dict[str, Any] = {
|
| 215 |
+
"doc_id": item.get("doc_id", ""),
|
| 216 |
+
"target_level": target_level,
|
| 217 |
+
"generated_text_len": len(generated_text.strip()) if generated_text else 0,
|
| 218 |
+
"factuality_score": None,
|
| 219 |
+
"hallucination_score": None,
|
| 220 |
+
"classifier_score": None,
|
| 221 |
+
"grounding_score": None,
|
| 222 |
+
"factuality_supported": 0,
|
| 223 |
+
"total_summary_subclaims": len(summary_subclaims) if summary_subclaims else 0,
|
| 224 |
+
"hallucination_supported": 0,
|
| 225 |
+
"total_gen_segments": 0,
|
| 226 |
+
"skipped": False,
|
| 227 |
+
"skip_reason": "",
|
| 228 |
+
}
|
| 229 |
+
|
| 230 |
+
if not generated_text or len(generated_text.strip()) < 10:
|
| 231 |
+
result["skipped"] = True
|
| 232 |
+
result["skip_reason"] = "generated_text missing or too short (<10 chars)"
|
| 233 |
+
return result
|
| 234 |
+
|
| 235 |
+
# -- Factuality & Hallucination via compute_rewards --
|
| 236 |
+
rewards = compute_rewards(
|
| 237 |
+
fulltext=fulltext,
|
| 238 |
+
generated_text=generated_text,
|
| 239 |
+
target_level=target_level,
|
| 240 |
+
summary_subclaims=summary_subclaims,
|
| 241 |
+
summary_text=summary_text,
|
| 242 |
+
)
|
| 243 |
+
|
| 244 |
+
factuality_score = rewards["factuality_score"]
|
| 245 |
+
h_score = rewards["hallucination_score"]
|
| 246 |
+
|
| 247 |
+
if factuality_score is None:
|
| 248 |
+
factuality_score = 0.5
|
| 249 |
+
if h_score is None:
|
| 250 |
+
h_score = 0.5
|
| 251 |
+
|
| 252 |
+
grounding_score = _nonlinear_grounding(h_score)
|
| 253 |
+
|
| 254 |
+
# -- Classifier --
|
| 255 |
+
input_text = fulltext or ""
|
| 256 |
+
class_score = _compute_classifier_reward(target_level, generated_text, input_text)
|
| 257 |
+
|
| 258 |
+
result.update({
|
| 259 |
+
"factuality_score": round(factuality_score, 4),
|
| 260 |
+
"hallucination_score": round(h_score, 4),
|
| 261 |
+
"grounding_score": round(grounding_score, 4),
|
| 262 |
+
"classifier_score": round(class_score, 4),
|
| 263 |
+
"factuality_supported": rewards.get("factuality_supported", 0),
|
| 264 |
+
"total_summary_subclaims": rewards.get("total_summary_subclaims", 0),
|
| 265 |
+
"hallucination_supported": rewards.get("hallucination_supported", 0),
|
| 266 |
+
"total_gen_segments": rewards.get("total_gen_segments", 0),
|
| 267 |
+
})
|
| 268 |
+
|
| 269 |
+
return result
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
def compute_aggregate(results: List[Dict[str, Any]]) -> Dict[str, Any]:
|
| 273 |
+
"""Compute aggregate statistics over all evaluated items."""
|
| 274 |
+
scored = [r for r in results if not r.get("skipped", False)]
|
| 275 |
+
n = len(scored)
|
| 276 |
+
total = len(results)
|
| 277 |
+
skipped = total - n
|
| 278 |
+
|
| 279 |
+
if n == 0:
|
| 280 |
+
return {
|
| 281 |
+
"total_items": total,
|
| 282 |
+
"scored_items": 0,
|
| 283 |
+
"skipped_items": skipped,
|
| 284 |
+
"avg_factuality_score": None,
|
| 285 |
+
"avg_hallucination_score": None,
|
| 286 |
+
"avg_grounding_score": None,
|
| 287 |
+
"avg_classifier_score": None,
|
| 288 |
+
}
|
| 289 |
+
|
| 290 |
+
def safe_avg(key):
|
| 291 |
+
vals = [r[key] for r in scored if r[key] is not None]
|
| 292 |
+
return round(sum(vals) / len(vals), 4) if vals else None
|
| 293 |
+
|
| 294 |
+
return {
|
| 295 |
+
"total_items": total,
|
| 296 |
+
"scored_items": n,
|
| 297 |
+
"skipped_items": skipped,
|
| 298 |
+
"avg_factuality_score": safe_avg("factuality_score"),
|
| 299 |
+
"avg_hallucination_score": safe_avg("hallucination_score"),
|
| 300 |
+
"avg_grounding_score": safe_avg("grounding_score"),
|
| 301 |
+
"avg_classifier_score": safe_avg("classifier_score"),
|
| 302 |
+
}
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
def main():
|
| 306 |
+
parser = argparse.ArgumentParser(
|
| 307 |
+
description="Evaluate factuality, hallucination, and classifier scores on a JSON file."
|
| 308 |
+
)
|
| 309 |
+
parser.add_argument(
|
| 310 |
+
"--input", "-i", required=True,
|
| 311 |
+
help="Path to input JSON file (list of objects).",
|
| 312 |
+
)
|
| 313 |
+
parser.add_argument(
|
| 314 |
+
"--output", "-o", default=None,
|
| 315 |
+
help="Path to output JSON file with per-item scores. "
|
| 316 |
+
"Defaults to <input_stem>_eval_results.json.",
|
| 317 |
+
)
|
| 318 |
+
parser.add_argument(
|
| 319 |
+
"--output-dir", default=None,
|
| 320 |
+
help="Directory to save output files. If set, output filename is derived "
|
| 321 |
+
"from input filename and placed in this directory.",
|
| 322 |
+
)
|
| 323 |
+
parser.add_argument(
|
| 324 |
+
"--subclaims", "-s", default=None,
|
| 325 |
+
help="Path to subclaims JSON file (for BON format). Provides fulltext, "
|
| 326 |
+
"summary, summary_subclaims, and fulltext_subclaims keyed by doc_id.",
|
| 327 |
+
)
|
| 328 |
+
parser.add_argument(
|
| 329 |
+
"--model-key", default="qwen3_base",
|
| 330 |
+
help="Key in the BON data containing candidates/best_summary (default: qwen3_base).",
|
| 331 |
+
)
|
| 332 |
+
parser.add_argument(
|
| 333 |
+
"--target-level", "-t", default=None,
|
| 334 |
+
help="Override target literacy level for all items "
|
| 335 |
+
"(e.g. low_health_literacy). If not set, uses each item's 'label' field.",
|
| 336 |
+
)
|
| 337 |
+
parser.add_argument(
|
| 338 |
+
"--support-check-url", default=None,
|
| 339 |
+
help="Override VLLM_SUPPORT_CHECK_BN_API_BASE.",
|
| 340 |
+
)
|
| 341 |
+
parser.add_argument(
|
| 342 |
+
"--classifier-url", default=None,
|
| 343 |
+
help="Override VLLM_CLASSIFIER_BN_API_BASE.",
|
| 344 |
+
)
|
| 345 |
+
parser.add_argument(
|
| 346 |
+
"--subclaim-extractor-url", default=None,
|
| 347 |
+
help="Override VLLM_SUBCLAIM_EXTRACTOR_BN_API_BASE.",
|
| 348 |
+
)
|
| 349 |
+
args = parser.parse_args()
|
| 350 |
+
|
| 351 |
+
if args.support_check_url:
|
| 352 |
+
os.environ["VLLM_SUPPORT_CHECK_BN_API_BASE"] = args.support_check_url
|
| 353 |
+
if args.classifier_url:
|
| 354 |
+
os.environ["VLLM_CLASSIFIER_BN_API_BASE"] = args.classifier_url
|
| 355 |
+
if args.subclaim_extractor_url:
|
| 356 |
+
os.environ["VLLM_SUBCLAIM_EXTRACTOR_BN_API_BASE"] = args.subclaim_extractor_url
|
| 357 |
+
|
| 358 |
+
# Load input
|
| 359 |
+
with open(args.input, "r", encoding="utf-8") as f:
|
| 360 |
+
raw_data = json.load(f)
|
| 361 |
+
|
| 362 |
+
if not isinstance(raw_data, list):
|
| 363 |
+
print(f"Error: Expected a JSON list, got {type(raw_data).__name__}.", file=sys.stderr)
|
| 364 |
+
sys.exit(1)
|
| 365 |
+
|
| 366 |
+
# Detect BON format: items have a model key (e.g. qwen3_base) with best_summary
|
| 367 |
+
is_bon = (
|
| 368 |
+
len(raw_data) > 0
|
| 369 |
+
and args.model_key in raw_data[0]
|
| 370 |
+
and "best_summary" in raw_data[0].get(args.model_key, {})
|
| 371 |
+
)
|
| 372 |
+
|
| 373 |
+
# Detect inference format: top-level doc_id, label, predicted_gen_text; no fulltext/summary_subclaims
|
| 374 |
+
is_inference = (
|
| 375 |
+
len(raw_data) > 0
|
| 376 |
+
and "doc_id" in raw_data[0]
|
| 377 |
+
and "label" in raw_data[0]
|
| 378 |
+
and "predicted_gen_text" in raw_data[0]
|
| 379 |
+
and raw_data[0].get("fulltext") is None
|
| 380 |
+
and raw_data[0].get("summary_subclaims") is None
|
| 381 |
+
)
|
| 382 |
+
|
| 383 |
+
# Detect self-refine format: doc_id, label, final_summary as gen text; no fulltext/summary_subclaims
|
| 384 |
+
is_self_refine = (
|
| 385 |
+
len(raw_data) > 0
|
| 386 |
+
and "doc_id" in raw_data[0]
|
| 387 |
+
and "label" in raw_data[0]
|
| 388 |
+
and "final_summary" in raw_data[0]
|
| 389 |
+
and raw_data[0].get("fulltext") is None
|
| 390 |
+
and raw_data[0].get("summary_subclaims") is None
|
| 391 |
+
)
|
| 392 |
+
|
| 393 |
+
if is_bon:
|
| 394 |
+
if not args.subclaims:
|
| 395 |
+
print("Error: BON format detected but --subclaims file not provided.", file=sys.stderr)
|
| 396 |
+
sys.exit(1)
|
| 397 |
+
with open(args.subclaims, "r", encoding="utf-8") as f:
|
| 398 |
+
subclaims_data = json.load(f)
|
| 399 |
+
print(f"BON format detected (model_key={args.model_key})")
|
| 400 |
+
print(f"Loaded {len(raw_data)} BON items from {args.input}")
|
| 401 |
+
print(f"Loaded {len(subclaims_data)} subclaims entries from {args.subclaims}")
|
| 402 |
+
data = prepare_bon_items(raw_data, subclaims_data, model_key=args.model_key)
|
| 403 |
+
print(f"Prepared {len(data)} items for evaluation")
|
| 404 |
+
elif is_inference:
|
| 405 |
+
if not args.subclaims:
|
| 406 |
+
print("Error: Inference format detected (predicted_gen_text) but --subclaims file not provided.", file=sys.stderr)
|
| 407 |
+
sys.exit(1)
|
| 408 |
+
with open(args.subclaims, "r", encoding="utf-8") as f:
|
| 409 |
+
subclaims_data = json.load(f)
|
| 410 |
+
print("Inference format detected (predicted_gen_text as evaluated summary)")
|
| 411 |
+
print(f"Loaded {len(raw_data)} inference items from {args.input}")
|
| 412 |
+
print(f"Loaded {len(subclaims_data)} subclaims entries from {args.subclaims}")
|
| 413 |
+
data = prepare_inference_items(raw_data, subclaims_data)
|
| 414 |
+
print(f"Prepared {len(data)} items for evaluation")
|
| 415 |
+
elif is_self_refine:
|
| 416 |
+
if not args.subclaims:
|
| 417 |
+
print("Error: Self-refine format detected (final_summary) but --subclaims file not provided.", file=sys.stderr)
|
| 418 |
+
sys.exit(1)
|
| 419 |
+
with open(args.subclaims, "r", encoding="utf-8") as f:
|
| 420 |
+
subclaims_data = json.load(f)
|
| 421 |
+
print("Self-refine format detected (final_summary as evaluated summary)")
|
| 422 |
+
print(f"Loaded {len(raw_data)} self-refine items from {args.input}")
|
| 423 |
+
print(f"Loaded {len(subclaims_data)} subclaims entries from {args.subclaims}")
|
| 424 |
+
data = prepare_self_refine_items(raw_data, subclaims_data)
|
| 425 |
+
print(f"Prepared {len(data)} items for evaluation")
|
| 426 |
+
else:
|
| 427 |
+
data = raw_data
|
| 428 |
+
print(f"Loaded {len(data)} items from {args.input}")
|
| 429 |
+
|
| 430 |
+
print(f" Support check API : {os.getenv('VLLM_SUPPORT_CHECK_BN_API_BASE', 'http://localhost:8090/v1')}")
|
| 431 |
+
print(f" Classifier API : {os.getenv('VLLM_CLASSIFIER_BN_API_BASE', 'http://localhost:8040/v1')}")
|
| 432 |
+
print(f" Subclaim extractor: {os.getenv('VLLM_SUBCLAIM_EXTRACTOR_BN_API_BASE', 'http://localhost:8050/v1')}")
|
| 433 |
+
if args.target_level:
|
| 434 |
+
print(f" Target level override: {args.target_level}")
|
| 435 |
+
print("-" * 60)
|
| 436 |
+
|
| 437 |
+
# Evaluate each item
|
| 438 |
+
results = []
|
| 439 |
+
start_time = time.time()
|
| 440 |
+
for idx, item in enumerate(tqdm(data, desc="Evaluating")):
|
| 441 |
+
r = evaluate_single(item, target_level_override=args.target_level)
|
| 442 |
+
r["index"] = idx
|
| 443 |
+
r["doc_id"] = item.get("doc_id", "")
|
| 444 |
+
results.append(r)
|
| 445 |
+
|
| 446 |
+
if (idx + 1) % 10 == 0 or idx == 0:
|
| 447 |
+
partial_agg = compute_aggregate(results)
|
| 448 |
+
tqdm.write(
|
| 449 |
+
f" [{idx+1}/{len(data)}] "
|
| 450 |
+
f"fact={partial_agg['avg_factuality_score']} "
|
| 451 |
+
f"hallu={partial_agg['avg_hallucination_score']} "
|
| 452 |
+
f"cls={partial_agg['avg_classifier_score']}"
|
| 453 |
+
)
|
| 454 |
+
|
| 455 |
+
elapsed = time.time() - start_time
|
| 456 |
+
|
| 457 |
+
# --- Validation: all items must be evaluated with non-null scores ---
|
| 458 |
+
expected_count = len(data)
|
| 459 |
+
skipped_items = [r for r in results if r.get("skipped", False)]
|
| 460 |
+
null_score_items = []
|
| 461 |
+
for r in results:
|
| 462 |
+
if r.get("skipped", False):
|
| 463 |
+
continue
|
| 464 |
+
for key in ("factuality_score", "hallucination_score", "classifier_score", "grounding_score"):
|
| 465 |
+
if r.get(key) is None:
|
| 466 |
+
null_score_items.append((r.get("index"), r.get("doc_id"), key))
|
| 467 |
+
|
| 468 |
+
has_errors = False
|
| 469 |
+
if skipped_items:
|
| 470 |
+
has_errors = True
|
| 471 |
+
print(f"\nERROR: {len(skipped_items)} out of {expected_count} items were skipped:", file=sys.stderr)
|
| 472 |
+
for r in skipped_items:
|
| 473 |
+
print(f" index={r.get('index')} doc_id={r.get('doc_id')} reason={r.get('skip_reason')}", file=sys.stderr)
|
| 474 |
+
|
| 475 |
+
if null_score_items:
|
| 476 |
+
has_errors = True
|
| 477 |
+
print(f"\nERROR: {len(null_score_items)} null score(s) found:", file=sys.stderr)
|
| 478 |
+
for idx, doc_id, key in null_score_items:
|
| 479 |
+
print(f" index={idx} doc_id={doc_id} null_field={key}", file=sys.stderr)
|
| 480 |
+
|
| 481 |
+
if len(results) != expected_count:
|
| 482 |
+
has_errors = True
|
| 483 |
+
print(f"\nERROR: Expected {expected_count} results but got {len(results)}.", file=sys.stderr)
|
| 484 |
+
|
| 485 |
+
if has_errors:
|
| 486 |
+
print(f"\nAborting: will NOT save results. All {expected_count} items must be fully evaluated with non-null scores.", file=sys.stderr)
|
| 487 |
+
sys.exit(1)
|
| 488 |
+
|
| 489 |
+
# Aggregate
|
| 490 |
+
agg = compute_aggregate(results)
|
| 491 |
+
|
| 492 |
+
# Per-label aggregates
|
| 493 |
+
label_groups: Dict[str, List[Dict[str, Any]]] = {}
|
| 494 |
+
for r in results:
|
| 495 |
+
lbl = r.get("target_level", "unknown")
|
| 496 |
+
label_groups.setdefault(lbl, []).append(r)
|
| 497 |
+
per_label_agg = {lbl: compute_aggregate(items) for lbl, items in sorted(label_groups.items())}
|
| 498 |
+
|
| 499 |
+
# Output path
|
| 500 |
+
if args.output:
|
| 501 |
+
out_path = args.output
|
| 502 |
+
elif args.output_dir:
|
| 503 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 504 |
+
stem = os.path.splitext(os.path.basename(args.input))[0]
|
| 505 |
+
out_path = os.path.join(args.output_dir, f"{stem}_eval_results.json")
|
| 506 |
+
else:
|
| 507 |
+
stem = os.path.splitext(os.path.basename(args.input))[0]
|
| 508 |
+
out_dir = os.path.dirname(args.input) or "."
|
| 509 |
+
out_path = os.path.join(out_dir, f"{stem}_eval_results.json")
|
| 510 |
+
|
| 511 |
+
os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
|
| 512 |
+
|
| 513 |
+
output = {
|
| 514 |
+
"input_file": os.path.abspath(args.input),
|
| 515 |
+
"subclaims_file": os.path.abspath(args.subclaims) if args.subclaims else None,
|
| 516 |
+
"model_key": args.model_key if is_bon else None,
|
| 517 |
+
"inference_format": is_inference if not is_bon else False,
|
| 518 |
+
"self_refine_format": is_self_refine if not is_bon and not is_inference else False,
|
| 519 |
+
"target_level_override": args.target_level,
|
| 520 |
+
"elapsed_seconds": round(elapsed, 2),
|
| 521 |
+
"aggregate": agg,
|
| 522 |
+
"per_label_aggregate": per_label_agg,
|
| 523 |
+
"per_item": results,
|
| 524 |
+
}
|
| 525 |
+
|
| 526 |
+
with open(out_path, "w", encoding="utf-8") as f:
|
| 527 |
+
json.dump(output, f, indent=2, ensure_ascii=False)
|
| 528 |
+
|
| 529 |
+
# Print summary
|
| 530 |
+
print("\n" + "=" * 60)
|
| 531 |
+
print("EVALUATION SUMMARY")
|
| 532 |
+
print("=" * 60)
|
| 533 |
+
print(f" Total items : {agg['total_items']}")
|
| 534 |
+
print(f" Scored items : {agg['scored_items']}")
|
| 535 |
+
print(f" Skipped items : {agg['skipped_items']}")
|
| 536 |
+
print(f" Elapsed time : {round(elapsed, 1)}s")
|
| 537 |
+
print("-" * 60)
|
| 538 |
+
print(f" Avg Factuality Score : {agg['avg_factuality_score']}")
|
| 539 |
+
print(f" Avg Hallucination Score: {agg['avg_hallucination_score']}")
|
| 540 |
+
print(f" Avg Grounding Score : {agg['avg_grounding_score']}")
|
| 541 |
+
print(f" Avg Classifier Score : {agg['avg_classifier_score']}")
|
| 542 |
+
print("-" * 60)
|
| 543 |
+
for lbl, la in per_label_agg.items():
|
| 544 |
+
print(f" [{lbl}] items={la['scored_items']}"
|
| 545 |
+
f" fact={la['avg_factuality_score']}"
|
| 546 |
+
f" hallu={la['avg_hallucination_score']}"
|
| 547 |
+
f" cls={la['avg_classifier_score']}")
|
| 548 |
+
print("-" * 60)
|
| 549 |
+
print(f" Results saved to: {out_path}")
|
| 550 |
+
print("=" * 60)
|
| 551 |
+
|
| 552 |
+
|
| 553 |
+
if __name__ == "__main__":
|
| 554 |
+
main()
|
code/fine_tune_sft_dpo/evaluate_scores_bn_vllm.py
ADDED
|
@@ -0,0 +1,611 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Standalone evaluation script for computing factuality, hallucination, and
|
| 4 |
+
classifier scores on a JSON/JSONL file.
|
| 5 |
+
|
| 6 |
+
Supports these input formats:
|
| 7 |
+
|
| 8 |
+
1. **Standard format** — a list of objects, each with:
|
| 9 |
+
- fulltext, summary_text, summary_subclaims, generated_text, label
|
| 10 |
+
|
| 11 |
+
2. **Best-of-N (BON) format** — a list of objects, each with:
|
| 12 |
+
- doc_id, label, qwen3_base.best_summary (JSON-wrapped generated text)
|
| 13 |
+
Requires a separate --subclaims file to supply fulltext, summary,
|
| 14 |
+
summary_subclaims, and fulltext_subclaims (keyed by doc_id).
|
| 15 |
+
|
| 16 |
+
3. **Inference format** — a list of objects, each with:
|
| 17 |
+
- doc_id, label, predicted_gen_text (JSON-wrapped evaluated summary),
|
| 18 |
+
optionally gold_gen_text
|
| 19 |
+
predicted_gen_text is the summary to evaluate (same JSON key-by-label
|
| 20 |
+
format as best_summary). Requires --subclaims for fulltext and subclaims.
|
| 21 |
+
|
| 22 |
+
4. **Self-refine format** — a list of objects, each with:
|
| 23 |
+
- doc_id, label, final_summary (the generated text to evaluate),
|
| 24 |
+
optionally gold_gen_text, gold_summary
|
| 25 |
+
final_summary is the summary to evaluate (plain text or JSON-wrapped by
|
| 26 |
+
label). Requires --subclaims for fulltext and subclaims.
|
| 27 |
+
|
| 28 |
+
5. **RL inference format** (JSONL) — one JSON object per line, each with:
|
| 29 |
+
- doc_id, gold_label, input_text, summary_text, subclaims, generated_text
|
| 30 |
+
Self-contained: no --subclaims file needed. Field mapping:
|
| 31 |
+
gold_label -> label, input_text -> fulltext, subclaims -> summary_subclaims
|
| 32 |
+
|
| 33 |
+
The script reuses the reward functions from reward_new_v6_bn_v4_rmv_src_cov.py:
|
| 34 |
+
- factuality_score : fraction of summary subclaims supported by generated_text
|
| 35 |
+
- hallucination_score: fraction of gen subclaims NOT supported by fulltext
|
| 36 |
+
- classifier_score : whether generated_text matches the target literacy level
|
| 37 |
+
|
| 38 |
+
Requires the same vLLM endpoints as the reward file:
|
| 39 |
+
- Support checker : VLLM_SUPPORT_CHECK_BN_API_BASE (default http://localhost:8090/v1)
|
| 40 |
+
- Classifier : VLLM_CLASSIFIER_BN_API_BASE (default http://localhost:8040/v1)
|
| 41 |
+
- Subclaim extractor: VLLM_SUBCLAIM_EXTRACTOR_BN_API_BASE (default http://localhost:8050/v1)
|
| 42 |
+
|
| 43 |
+
Usage:
|
| 44 |
+
# Standard format
|
| 45 |
+
python evaluate_scores.py --input data.json [--output results.json]
|
| 46 |
+
|
| 47 |
+
# BON format with subclaims file
|
| 48 |
+
python evaluate_scores.py --input bon_results.json --subclaims subclaims.json --output-dir evaluation/bn/
|
| 49 |
+
|
| 50 |
+
# Inference format (predicted_gen_text as evaluated summary)
|
| 51 |
+
python evaluate_scores.py --input test_inference_vllm_qwen3-4B_base.json --subclaims subclaims.json --output results.json
|
| 52 |
+
|
| 53 |
+
# Self-refine format (final_summary as evaluated summary)
|
| 54 |
+
python evaluate_scores.py --input test_self_refine_vllm_qwen3_4B_base.json --subclaims subclaims.json --output-dir evaluation/bn/
|
| 55 |
+
|
| 56 |
+
# RL inference format (JSONL, self-contained)
|
| 57 |
+
python evaluate_scores.py --input bn_200.jsonl --output-dir evaluation/bn/
|
| 58 |
+
"""
|
| 59 |
+
|
| 60 |
+
import argparse
|
| 61 |
+
import json
|
| 62 |
+
import os
|
| 63 |
+
import re
|
| 64 |
+
import sys
|
| 65 |
+
import time
|
| 66 |
+
from typing import Any, Dict, List, Optional
|
| 67 |
+
|
| 68 |
+
from tqdm import tqdm
|
| 69 |
+
|
| 70 |
+
# Import scoring utilities from the reward module (same directory).
|
| 71 |
+
from reward_new_v6_bn_v4_rmv_src_cov import (
|
| 72 |
+
_call_support_api,
|
| 73 |
+
_compute_classifier_reward,
|
| 74 |
+
_extract_subclaims_from_text,
|
| 75 |
+
_is_bangla_text,
|
| 76 |
+
_nonlinear_grounding,
|
| 77 |
+
compute_rewards,
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def extract_text_from_best_summary(best_summary: str, label: str) -> str:
|
| 82 |
+
"""Extract the raw generated text from a BON best_summary string.
|
| 83 |
+
|
| 84 |
+
The best_summary is a (possibly truncated) JSON string like:
|
| 85 |
+
'{"proficient_health_literacy": "...actual text..."}'
|
| 86 |
+
We locate the value after the label key and strip JSON wrapping.
|
| 87 |
+
"""
|
| 88 |
+
key_pattern = re.compile(re.escape(f'"{label}"') + r'\s*:\s*"')
|
| 89 |
+
m = key_pattern.search(best_summary)
|
| 90 |
+
if not m:
|
| 91 |
+
return best_summary.strip()
|
| 92 |
+
text = best_summary[m.end():]
|
| 93 |
+
if text.endswith('"\n}'):
|
| 94 |
+
text = text[:-3]
|
| 95 |
+
elif text.endswith('"}\n'):
|
| 96 |
+
text = text[:-3]
|
| 97 |
+
elif text.endswith('"}'):
|
| 98 |
+
text = text[:-2]
|
| 99 |
+
elif text.endswith('"'):
|
| 100 |
+
text = text[:-1]
|
| 101 |
+
text = text.replace("\\n", "\n").replace('\\"', '"')
|
| 102 |
+
return text.strip()
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def prepare_bon_items(
|
| 106 |
+
bon_data: List[Dict[str, Any]],
|
| 107 |
+
subclaims_data: List[Dict[str, Any]],
|
| 108 |
+
model_key: str = "qwen3_base",
|
| 109 |
+
) -> List[Dict[str, Any]]:
|
| 110 |
+
"""Merge BON results with subclaims data into the standard evaluation format."""
|
| 111 |
+
sc_by_docid = {}
|
| 112 |
+
for item in subclaims_data:
|
| 113 |
+
sc_by_docid[item["doc_id"]] = item
|
| 114 |
+
|
| 115 |
+
prepared = []
|
| 116 |
+
for item in bon_data:
|
| 117 |
+
doc_id = item["doc_id"]
|
| 118 |
+
label = item["label"]
|
| 119 |
+
sc = sc_by_docid.get(doc_id, {})
|
| 120 |
+
|
| 121 |
+
model_data = item.get(model_key, {})
|
| 122 |
+
best_summary = model_data.get("best_summary", "") or model_data.get("predicted_gen_text", "")
|
| 123 |
+
generated_text = extract_text_from_best_summary(best_summary, label)
|
| 124 |
+
|
| 125 |
+
prepared.append({
|
| 126 |
+
"doc_id": doc_id,
|
| 127 |
+
"label": label,
|
| 128 |
+
"fulltext": sc.get("fulltext", ""),
|
| 129 |
+
"summary_text": sc.get("summary", ""),
|
| 130 |
+
"summary_subclaims": sc.get("summary_subclaims", []),
|
| 131 |
+
"fulltext_subclaims": sc.get("fulltext_subclaims", []),
|
| 132 |
+
"generated_text": generated_text,
|
| 133 |
+
"gold_gen_text": item.get("gold_gen_text", ""),
|
| 134 |
+
"predicted_label": item.get("predicted_label", ""),
|
| 135 |
+
"prediction_correct": item.get("prediction_correct", False),
|
| 136 |
+
})
|
| 137 |
+
return prepared
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def prepare_inference_items(
|
| 141 |
+
inference_data: List[Dict[str, Any]],
|
| 142 |
+
subclaims_data: List[Dict[str, Any]],
|
| 143 |
+
) -> List[Dict[str, Any]]:
|
| 144 |
+
"""Merge inference-format results (doc_id, label, predicted_gen_text) with
|
| 145 |
+
subclaims data into the standard evaluation format. predicted_gen_text is
|
| 146 |
+
the JSON-wrapped evaluated summary; the raw text is extracted using the
|
| 147 |
+
item's label.
|
| 148 |
+
"""
|
| 149 |
+
sc_by_docid = {}
|
| 150 |
+
for item in subclaims_data:
|
| 151 |
+
sc_by_docid[item["doc_id"]] = item
|
| 152 |
+
|
| 153 |
+
prepared = []
|
| 154 |
+
for item in inference_data:
|
| 155 |
+
doc_id = item["doc_id"]
|
| 156 |
+
label = item["label"]
|
| 157 |
+
sc = sc_by_docid.get(doc_id, {})
|
| 158 |
+
|
| 159 |
+
raw_pred = item.get("predicted_gen_text", "") or ""
|
| 160 |
+
generated_text = extract_text_from_best_summary(raw_pred, label)
|
| 161 |
+
|
| 162 |
+
prepared.append({
|
| 163 |
+
"doc_id": doc_id,
|
| 164 |
+
"label": label,
|
| 165 |
+
"fulltext": sc.get("fulltext", ""),
|
| 166 |
+
"summary_text": sc.get("summary", ""),
|
| 167 |
+
"summary_subclaims": sc.get("summary_subclaims", []),
|
| 168 |
+
"fulltext_subclaims": sc.get("fulltext_subclaims", []),
|
| 169 |
+
"generated_text": generated_text,
|
| 170 |
+
"gold_gen_text": item.get("gold_gen_text", ""),
|
| 171 |
+
})
|
| 172 |
+
return prepared
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def prepare_self_refine_items(
|
| 176 |
+
self_refine_data: List[Dict[str, Any]],
|
| 177 |
+
subclaims_data: List[Dict[str, Any]],
|
| 178 |
+
) -> List[Dict[str, Any]]:
|
| 179 |
+
"""Merge self-refine format (doc_id, label, final_summary) with subclaims
|
| 180 |
+
data. final_summary is the generated text to evaluate (plain text or
|
| 181 |
+
JSON-wrapped by label); it is extracted and used as generated_text.
|
| 182 |
+
"""
|
| 183 |
+
sc_by_docid = {}
|
| 184 |
+
for item in subclaims_data:
|
| 185 |
+
sc_by_docid[item["doc_id"]] = item
|
| 186 |
+
|
| 187 |
+
prepared = []
|
| 188 |
+
for item in self_refine_data:
|
| 189 |
+
doc_id = item["doc_id"]
|
| 190 |
+
label = item["label"]
|
| 191 |
+
sc = sc_by_docid.get(doc_id, {})
|
| 192 |
+
|
| 193 |
+
raw_final = item.get("final_summary", "") or ""
|
| 194 |
+
generated_text = extract_text_from_best_summary(raw_final, label)
|
| 195 |
+
|
| 196 |
+
prepared.append({
|
| 197 |
+
"doc_id": doc_id,
|
| 198 |
+
"label": label,
|
| 199 |
+
"fulltext": sc.get("fulltext", ""),
|
| 200 |
+
"summary_text": sc.get("summary", ""),
|
| 201 |
+
"summary_subclaims": sc.get("summary_subclaims", []),
|
| 202 |
+
"fulltext_subclaims": sc.get("fulltext_subclaims", []),
|
| 203 |
+
"generated_text": generated_text,
|
| 204 |
+
"gold_gen_text": item.get("gold_gen_text", ""),
|
| 205 |
+
})
|
| 206 |
+
return prepared
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def prepare_rl_inference_items(
|
| 210 |
+
rl_data: List[Dict[str, Any]],
|
| 211 |
+
) -> List[Dict[str, Any]]:
|
| 212 |
+
"""Convert RL inference JSONL items into the standard evaluation format.
|
| 213 |
+
|
| 214 |
+
Field mapping:
|
| 215 |
+
gold_label -> label
|
| 216 |
+
input_text -> fulltext
|
| 217 |
+
subclaims -> summary_subclaims
|
| 218 |
+
summary_text -> summary_text
|
| 219 |
+
generated_text -> generated_text (plain text, used as-is)
|
| 220 |
+
"""
|
| 221 |
+
prepared = []
|
| 222 |
+
for item in rl_data:
|
| 223 |
+
prepared.append({
|
| 224 |
+
"doc_id": item.get("doc_id", ""),
|
| 225 |
+
"label": item.get("gold_label", ""),
|
| 226 |
+
"fulltext": item.get("input_text", ""),
|
| 227 |
+
"summary_text": item.get("summary_text", ""),
|
| 228 |
+
"summary_subclaims": item.get("subclaims", []),
|
| 229 |
+
"generated_text": item.get("generated_text", ""),
|
| 230 |
+
})
|
| 231 |
+
return prepared
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
def evaluate_single(
|
| 235 |
+
item: Dict[str, Any],
|
| 236 |
+
target_level_override: Optional[str] = None,
|
| 237 |
+
) -> Dict[str, Any]:
|
| 238 |
+
"""
|
| 239 |
+
Evaluate a single item and return detailed scores.
|
| 240 |
+
"""
|
| 241 |
+
fulltext = item.get("fulltext", "")
|
| 242 |
+
summary_text = item.get("summary_text") or item.get("summary", "")
|
| 243 |
+
summary_subclaims = item.get("summary_subclaims", [])
|
| 244 |
+
generated_text = item.get("generated_text") or item.get("predicted_gen_text", "")
|
| 245 |
+
target_level = target_level_override or item.get("label", "")
|
| 246 |
+
|
| 247 |
+
result: Dict[str, Any] = {
|
| 248 |
+
"doc_id": item.get("doc_id", ""),
|
| 249 |
+
"target_level": target_level,
|
| 250 |
+
"generated_text_len": len(generated_text.strip()) if generated_text else 0,
|
| 251 |
+
"factuality_score": None,
|
| 252 |
+
"hallucination_score": None,
|
| 253 |
+
"classifier_score": None,
|
| 254 |
+
"grounding_score": None,
|
| 255 |
+
"factuality_supported": 0,
|
| 256 |
+
"total_summary_subclaims": len(summary_subclaims) if summary_subclaims else 0,
|
| 257 |
+
"hallucination_supported": 0,
|
| 258 |
+
"total_gen_segments": 0,
|
| 259 |
+
"skipped": False,
|
| 260 |
+
"skip_reason": "",
|
| 261 |
+
}
|
| 262 |
+
|
| 263 |
+
if not generated_text or len(generated_text.strip()) < 10:
|
| 264 |
+
result["skipped"] = True
|
| 265 |
+
result["skip_reason"] = "generated_text missing or too short (<10 chars)"
|
| 266 |
+
return result
|
| 267 |
+
|
| 268 |
+
# -- Factuality & Hallucination via compute_rewards --
|
| 269 |
+
rewards = compute_rewards(
|
| 270 |
+
fulltext=fulltext,
|
| 271 |
+
generated_text=generated_text,
|
| 272 |
+
target_level=target_level,
|
| 273 |
+
summary_subclaims=summary_subclaims,
|
| 274 |
+
summary_text=summary_text,
|
| 275 |
+
)
|
| 276 |
+
|
| 277 |
+
factuality_score = rewards["factuality_score"]
|
| 278 |
+
h_score = rewards["hallucination_score"]
|
| 279 |
+
|
| 280 |
+
if factuality_score is None:
|
| 281 |
+
factuality_score = 0.5
|
| 282 |
+
if h_score is None:
|
| 283 |
+
h_score = 0.5
|
| 284 |
+
|
| 285 |
+
grounding_score = _nonlinear_grounding(h_score)
|
| 286 |
+
|
| 287 |
+
# -- Classifier --
|
| 288 |
+
input_text = fulltext or ""
|
| 289 |
+
class_score = _compute_classifier_reward(target_level, generated_text, input_text)
|
| 290 |
+
|
| 291 |
+
result.update({
|
| 292 |
+
"factuality_score": round(factuality_score, 4),
|
| 293 |
+
"hallucination_score": round(h_score, 4),
|
| 294 |
+
"grounding_score": round(grounding_score, 4),
|
| 295 |
+
"classifier_score": round(class_score, 4),
|
| 296 |
+
"factuality_supported": rewards.get("factuality_supported", 0),
|
| 297 |
+
"total_summary_subclaims": rewards.get("total_summary_subclaims", 0),
|
| 298 |
+
"hallucination_supported": rewards.get("hallucination_supported", 0),
|
| 299 |
+
"total_gen_segments": rewards.get("total_gen_segments", 0),
|
| 300 |
+
})
|
| 301 |
+
|
| 302 |
+
return result
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
def compute_aggregate(results: List[Dict[str, Any]]) -> Dict[str, Any]:
|
| 306 |
+
"""Compute aggregate statistics over all evaluated items."""
|
| 307 |
+
scored = [r for r in results if not r.get("skipped", False)]
|
| 308 |
+
n = len(scored)
|
| 309 |
+
total = len(results)
|
| 310 |
+
skipped = total - n
|
| 311 |
+
|
| 312 |
+
if n == 0:
|
| 313 |
+
return {
|
| 314 |
+
"total_items": total,
|
| 315 |
+
"scored_items": 0,
|
| 316 |
+
"skipped_items": skipped,
|
| 317 |
+
"avg_factuality_score": None,
|
| 318 |
+
"avg_hallucination_score": None,
|
| 319 |
+
"avg_grounding_score": None,
|
| 320 |
+
"avg_classifier_score": None,
|
| 321 |
+
}
|
| 322 |
+
|
| 323 |
+
def safe_avg(key):
|
| 324 |
+
vals = [r[key] for r in scored if r[key] is not None]
|
| 325 |
+
return round(sum(vals) / len(vals), 4) if vals else None
|
| 326 |
+
|
| 327 |
+
return {
|
| 328 |
+
"total_items": total,
|
| 329 |
+
"scored_items": n,
|
| 330 |
+
"skipped_items": skipped,
|
| 331 |
+
"avg_factuality_score": safe_avg("factuality_score"),
|
| 332 |
+
"avg_hallucination_score": safe_avg("hallucination_score"),
|
| 333 |
+
"avg_grounding_score": safe_avg("grounding_score"),
|
| 334 |
+
"avg_classifier_score": safe_avg("classifier_score"),
|
| 335 |
+
}
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
def main():
|
| 339 |
+
parser = argparse.ArgumentParser(
|
| 340 |
+
description="Evaluate factuality, hallucination, and classifier scores on a JSON file."
|
| 341 |
+
)
|
| 342 |
+
parser.add_argument(
|
| 343 |
+
"--input", "-i", required=True,
|
| 344 |
+
help="Path to input JSON file (list of objects).",
|
| 345 |
+
)
|
| 346 |
+
parser.add_argument(
|
| 347 |
+
"--output", "-o", default=None,
|
| 348 |
+
help="Path to output JSON file with per-item scores. "
|
| 349 |
+
"Defaults to <input_stem>_eval_results.json.",
|
| 350 |
+
)
|
| 351 |
+
parser.add_argument(
|
| 352 |
+
"--output-dir", default=None,
|
| 353 |
+
help="Directory to save output files. If set, output filename is derived "
|
| 354 |
+
"from input filename and placed in this directory.",
|
| 355 |
+
)
|
| 356 |
+
parser.add_argument(
|
| 357 |
+
"--subclaims", "-s", default=None,
|
| 358 |
+
help="Path to subclaims JSON file (for BON format). Provides fulltext, "
|
| 359 |
+
"summary, summary_subclaims, and fulltext_subclaims keyed by doc_id.",
|
| 360 |
+
)
|
| 361 |
+
parser.add_argument(
|
| 362 |
+
"--model-key", default="qwen3_base",
|
| 363 |
+
help="Key in the BON data containing candidates/best_summary (default: qwen3_base).",
|
| 364 |
+
)
|
| 365 |
+
parser.add_argument(
|
| 366 |
+
"--target-level", "-t", default=None,
|
| 367 |
+
help="Override target literacy level for all items "
|
| 368 |
+
"(e.g. low_health_literacy). If not set, uses each item's 'label' field.",
|
| 369 |
+
)
|
| 370 |
+
parser.add_argument(
|
| 371 |
+
"--support-check-url", default=None,
|
| 372 |
+
help="Override VLLM_SUPPORT_CHECK_BN_API_BASE.",
|
| 373 |
+
)
|
| 374 |
+
parser.add_argument(
|
| 375 |
+
"--classifier-url", default=None,
|
| 376 |
+
help="Override VLLM_CLASSIFIER_BN_API_BASE.",
|
| 377 |
+
)
|
| 378 |
+
parser.add_argument(
|
| 379 |
+
"--subclaim-extractor-url", default=None,
|
| 380 |
+
help="Override VLLM_SUBCLAIM_EXTRACTOR_BN_API_BASE.",
|
| 381 |
+
)
|
| 382 |
+
args = parser.parse_args()
|
| 383 |
+
|
| 384 |
+
if args.support_check_url:
|
| 385 |
+
os.environ["VLLM_SUPPORT_CHECK_BN_API_BASE"] = args.support_check_url
|
| 386 |
+
if args.classifier_url:
|
| 387 |
+
os.environ["VLLM_CLASSIFIER_BN_API_BASE"] = args.classifier_url
|
| 388 |
+
if args.subclaim_extractor_url:
|
| 389 |
+
os.environ["VLLM_SUBCLAIM_EXTRACTOR_BN_API_BASE"] = args.subclaim_extractor_url
|
| 390 |
+
|
| 391 |
+
# Load input (JSON list or JSONL)
|
| 392 |
+
if args.input.endswith(".jsonl"):
|
| 393 |
+
raw_data = []
|
| 394 |
+
with open(args.input, "r", encoding="utf-8") as f:
|
| 395 |
+
for line in f:
|
| 396 |
+
line = line.strip()
|
| 397 |
+
if line:
|
| 398 |
+
raw_data.append(json.loads(line))
|
| 399 |
+
else:
|
| 400 |
+
with open(args.input, "r", encoding="utf-8") as f:
|
| 401 |
+
raw_data = json.load(f)
|
| 402 |
+
|
| 403 |
+
if not isinstance(raw_data, list):
|
| 404 |
+
print(f"Error: Expected a JSON list, got {type(raw_data).__name__}.", file=sys.stderr)
|
| 405 |
+
sys.exit(1)
|
| 406 |
+
|
| 407 |
+
# Detect BON format: items have a model key (e.g. qwen3_base) with best_summary
|
| 408 |
+
is_bon = (
|
| 409 |
+
len(raw_data) > 0
|
| 410 |
+
and args.model_key in raw_data[0]
|
| 411 |
+
and "best_summary" in raw_data[0].get(args.model_key, {})
|
| 412 |
+
)
|
| 413 |
+
|
| 414 |
+
# Detect inference format: top-level doc_id, label, predicted_gen_text; no fulltext/summary_subclaims
|
| 415 |
+
is_inference = (
|
| 416 |
+
len(raw_data) > 0
|
| 417 |
+
and "doc_id" in raw_data[0]
|
| 418 |
+
and "label" in raw_data[0]
|
| 419 |
+
and "predicted_gen_text" in raw_data[0]
|
| 420 |
+
and raw_data[0].get("fulltext") is None
|
| 421 |
+
and raw_data[0].get("summary_subclaims") is None
|
| 422 |
+
)
|
| 423 |
+
|
| 424 |
+
# Detect self-refine format: doc_id, label, final_summary as gen text; no fulltext/summary_subclaims
|
| 425 |
+
is_self_refine = (
|
| 426 |
+
len(raw_data) > 0
|
| 427 |
+
and "doc_id" in raw_data[0]
|
| 428 |
+
and "label" in raw_data[0]
|
| 429 |
+
and "final_summary" in raw_data[0]
|
| 430 |
+
and raw_data[0].get("fulltext") is None
|
| 431 |
+
and raw_data[0].get("summary_subclaims") is None
|
| 432 |
+
)
|
| 433 |
+
|
| 434 |
+
# Detect RL inference format: gold_label, input_text, subclaims, generated_text
|
| 435 |
+
is_rl_inference = (
|
| 436 |
+
len(raw_data) > 0
|
| 437 |
+
and "doc_id" in raw_data[0]
|
| 438 |
+
and "gold_label" in raw_data[0]
|
| 439 |
+
and "input_text" in raw_data[0]
|
| 440 |
+
and "generated_text" in raw_data[0]
|
| 441 |
+
and "subclaims" in raw_data[0]
|
| 442 |
+
)
|
| 443 |
+
|
| 444 |
+
if is_rl_inference:
|
| 445 |
+
print("RL inference format detected (gold_label, input_text, subclaims, generated_text)")
|
| 446 |
+
print(f"Loaded {len(raw_data)} RL inference items from {args.input}")
|
| 447 |
+
data = prepare_rl_inference_items(raw_data)
|
| 448 |
+
print(f"Prepared {len(data)} items for evaluation")
|
| 449 |
+
elif is_bon:
|
| 450 |
+
if not args.subclaims:
|
| 451 |
+
print("Error: BON format detected but --subclaims file not provided.", file=sys.stderr)
|
| 452 |
+
sys.exit(1)
|
| 453 |
+
with open(args.subclaims, "r", encoding="utf-8") as f:
|
| 454 |
+
subclaims_data = json.load(f)
|
| 455 |
+
print(f"BON format detected (model_key={args.model_key})")
|
| 456 |
+
print(f"Loaded {len(raw_data)} BON items from {args.input}")
|
| 457 |
+
print(f"Loaded {len(subclaims_data)} subclaims entries from {args.subclaims}")
|
| 458 |
+
data = prepare_bon_items(raw_data, subclaims_data, model_key=args.model_key)
|
| 459 |
+
print(f"Prepared {len(data)} items for evaluation")
|
| 460 |
+
elif is_inference:
|
| 461 |
+
if not args.subclaims:
|
| 462 |
+
print("Error: Inference format detected (predicted_gen_text) but --subclaims file not provided.", file=sys.stderr)
|
| 463 |
+
sys.exit(1)
|
| 464 |
+
with open(args.subclaims, "r", encoding="utf-8") as f:
|
| 465 |
+
subclaims_data = json.load(f)
|
| 466 |
+
print("Inference format detected (predicted_gen_text as evaluated summary)")
|
| 467 |
+
print(f"Loaded {len(raw_data)} inference items from {args.input}")
|
| 468 |
+
print(f"Loaded {len(subclaims_data)} subclaims entries from {args.subclaims}")
|
| 469 |
+
data = prepare_inference_items(raw_data, subclaims_data)
|
| 470 |
+
print(f"Prepared {len(data)} items for evaluation")
|
| 471 |
+
elif is_self_refine:
|
| 472 |
+
if not args.subclaims:
|
| 473 |
+
print("Error: Self-refine format detected (final_summary) but --subclaims file not provided.", file=sys.stderr)
|
| 474 |
+
sys.exit(1)
|
| 475 |
+
with open(args.subclaims, "r", encoding="utf-8") as f:
|
| 476 |
+
subclaims_data = json.load(f)
|
| 477 |
+
print("Self-refine format detected (final_summary as evaluated summary)")
|
| 478 |
+
print(f"Loaded {len(raw_data)} self-refine items from {args.input}")
|
| 479 |
+
print(f"Loaded {len(subclaims_data)} subclaims entries from {args.subclaims}")
|
| 480 |
+
data = prepare_self_refine_items(raw_data, subclaims_data)
|
| 481 |
+
print(f"Prepared {len(data)} items for evaluation")
|
| 482 |
+
else:
|
| 483 |
+
data = raw_data
|
| 484 |
+
print(f"Loaded {len(data)} items from {args.input}")
|
| 485 |
+
|
| 486 |
+
print(f" Support check API : {os.getenv('VLLM_SUPPORT_CHECK_BN_API_BASE', 'http://localhost:8090/v1')}")
|
| 487 |
+
print(f" Classifier API : {os.getenv('VLLM_CLASSIFIER_BN_API_BASE', 'http://localhost:8040/v1')}")
|
| 488 |
+
print(f" Subclaim extractor: {os.getenv('VLLM_SUBCLAIM_EXTRACTOR_BN_API_BASE', 'http://localhost:8050/v1')}")
|
| 489 |
+
if args.target_level:
|
| 490 |
+
print(f" Target level override: {args.target_level}")
|
| 491 |
+
print("-" * 60)
|
| 492 |
+
|
| 493 |
+
# Evaluate each item
|
| 494 |
+
results = []
|
| 495 |
+
start_time = time.time()
|
| 496 |
+
for idx, item in enumerate(tqdm(data, desc="Evaluating")):
|
| 497 |
+
r = evaluate_single(item, target_level_override=args.target_level)
|
| 498 |
+
r["index"] = idx
|
| 499 |
+
r["doc_id"] = item.get("doc_id", "")
|
| 500 |
+
results.append(r)
|
| 501 |
+
|
| 502 |
+
if (idx + 1) % 10 == 0 or idx == 0:
|
| 503 |
+
partial_agg = compute_aggregate(results)
|
| 504 |
+
tqdm.write(
|
| 505 |
+
f" [{idx+1}/{len(data)}] "
|
| 506 |
+
f"fact={partial_agg['avg_factuality_score']} "
|
| 507 |
+
f"hallu={partial_agg['avg_hallucination_score']} "
|
| 508 |
+
f"cls={partial_agg['avg_classifier_score']}"
|
| 509 |
+
)
|
| 510 |
+
|
| 511 |
+
elapsed = time.time() - start_time
|
| 512 |
+
|
| 513 |
+
# --- Validation: all items must be evaluated with non-null scores ---
|
| 514 |
+
expected_count = len(data)
|
| 515 |
+
skipped_items = [r for r in results if r.get("skipped", False)]
|
| 516 |
+
null_score_items = []
|
| 517 |
+
for r in results:
|
| 518 |
+
if r.get("skipped", False):
|
| 519 |
+
continue
|
| 520 |
+
for key in ("factuality_score", "hallucination_score", "classifier_score", "grounding_score"):
|
| 521 |
+
if r.get(key) is None:
|
| 522 |
+
null_score_items.append((r.get("index"), r.get("doc_id"), key))
|
| 523 |
+
|
| 524 |
+
has_errors = False
|
| 525 |
+
if skipped_items:
|
| 526 |
+
has_errors = True
|
| 527 |
+
print(f"\nERROR: {len(skipped_items)} out of {expected_count} items were skipped:", file=sys.stderr)
|
| 528 |
+
for r in skipped_items:
|
| 529 |
+
print(f" index={r.get('index')} doc_id={r.get('doc_id')} reason={r.get('skip_reason')}", file=sys.stderr)
|
| 530 |
+
|
| 531 |
+
if null_score_items:
|
| 532 |
+
has_errors = True
|
| 533 |
+
print(f"\nERROR: {len(null_score_items)} null score(s) found:", file=sys.stderr)
|
| 534 |
+
for idx, doc_id, key in null_score_items:
|
| 535 |
+
print(f" index={idx} doc_id={doc_id} null_field={key}", file=sys.stderr)
|
| 536 |
+
|
| 537 |
+
if len(results) != expected_count:
|
| 538 |
+
has_errors = True
|
| 539 |
+
print(f"\nERROR: Expected {expected_count} results but got {len(results)}.", file=sys.stderr)
|
| 540 |
+
|
| 541 |
+
if has_errors:
|
| 542 |
+
print(f"\nAborting: will NOT save results. All {expected_count} items must be fully evaluated with non-null scores.", file=sys.stderr)
|
| 543 |
+
sys.exit(1)
|
| 544 |
+
|
| 545 |
+
# Aggregate
|
| 546 |
+
agg = compute_aggregate(results)
|
| 547 |
+
|
| 548 |
+
# Per-label aggregates
|
| 549 |
+
label_groups: Dict[str, List[Dict[str, Any]]] = {}
|
| 550 |
+
for r in results:
|
| 551 |
+
lbl = r.get("target_level", "unknown")
|
| 552 |
+
label_groups.setdefault(lbl, []).append(r)
|
| 553 |
+
per_label_agg = {lbl: compute_aggregate(items) for lbl, items in sorted(label_groups.items())}
|
| 554 |
+
|
| 555 |
+
# Output path
|
| 556 |
+
if args.output:
|
| 557 |
+
out_path = args.output
|
| 558 |
+
elif args.output_dir:
|
| 559 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 560 |
+
stem = os.path.splitext(os.path.basename(args.input))[0]
|
| 561 |
+
out_path = os.path.join(args.output_dir, f"{stem}_eval_results.json")
|
| 562 |
+
else:
|
| 563 |
+
stem = os.path.splitext(os.path.basename(args.input))[0]
|
| 564 |
+
out_dir = os.path.dirname(args.input) or "."
|
| 565 |
+
out_path = os.path.join(out_dir, f"{stem}_eval_results.json")
|
| 566 |
+
|
| 567 |
+
os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
|
| 568 |
+
|
| 569 |
+
output = {
|
| 570 |
+
"input_file": os.path.abspath(args.input),
|
| 571 |
+
"subclaims_file": os.path.abspath(args.subclaims) if args.subclaims else None,
|
| 572 |
+
"model_key": args.model_key if is_bon else None,
|
| 573 |
+
"inference_format": is_inference if not is_bon else False,
|
| 574 |
+
"self_refine_format": is_self_refine if not is_bon and not is_inference else False,
|
| 575 |
+
"rl_inference_format": is_rl_inference,
|
| 576 |
+
"target_level_override": args.target_level,
|
| 577 |
+
"elapsed_seconds": round(elapsed, 2),
|
| 578 |
+
"aggregate": agg,
|
| 579 |
+
"per_label_aggregate": per_label_agg,
|
| 580 |
+
"per_item": results,
|
| 581 |
+
}
|
| 582 |
+
|
| 583 |
+
with open(out_path, "w", encoding="utf-8") as f:
|
| 584 |
+
json.dump(output, f, indent=2, ensure_ascii=False)
|
| 585 |
+
|
| 586 |
+
# Print summary
|
| 587 |
+
print("\n" + "=" * 60)
|
| 588 |
+
print("EVALUATION SUMMARY")
|
| 589 |
+
print("=" * 60)
|
| 590 |
+
print(f" Total items : {agg['total_items']}")
|
| 591 |
+
print(f" Scored items : {agg['scored_items']}")
|
| 592 |
+
print(f" Skipped items : {agg['skipped_items']}")
|
| 593 |
+
print(f" Elapsed time : {round(elapsed, 1)}s")
|
| 594 |
+
print("-" * 60)
|
| 595 |
+
print(f" Avg Factuality Score : {agg['avg_factuality_score']}")
|
| 596 |
+
print(f" Avg Hallucination Score: {agg['avg_hallucination_score']}")
|
| 597 |
+
print(f" Avg Grounding Score : {agg['avg_grounding_score']}")
|
| 598 |
+
print(f" Avg Classifier Score : {agg['avg_classifier_score']}")
|
| 599 |
+
print("-" * 60)
|
| 600 |
+
for lbl, la in per_label_agg.items():
|
| 601 |
+
print(f" [{lbl}] items={la['scored_items']}"
|
| 602 |
+
f" fact={la['avg_factuality_score']}"
|
| 603 |
+
f" hallu={la['avg_hallucination_score']}"
|
| 604 |
+
f" cls={la['avg_classifier_score']}")
|
| 605 |
+
print("-" * 60)
|
| 606 |
+
print(f" Results saved to: {out_path}")
|
| 607 |
+
print("=" * 60)
|
| 608 |
+
|
| 609 |
+
|
| 610 |
+
if __name__ == "__main__":
|
| 611 |
+
main()
|
code/fine_tune_sft_dpo/evaluate_scores_bn_vllm_rl.py
ADDED
|
@@ -0,0 +1,567 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Standalone evaluation script for computing factuality, hallucination, and
|
| 4 |
+
classifier scores on a JSON file.
|
| 5 |
+
|
| 6 |
+
Supports two input formats:
|
| 7 |
+
|
| 8 |
+
1. **Standard format** — a list of objects, each with:
|
| 9 |
+
- fulltext, summary_text, summary_subclaims, generated_text, label
|
| 10 |
+
|
| 11 |
+
2. **Best-of-N (BON) format** — a list of objects, each with:
|
| 12 |
+
- doc_id, label, qwen3_base.best_summary (JSON-wrapped generated text)
|
| 13 |
+
Requires a separate --subclaims file to supply fulltext, summary,
|
| 14 |
+
summary_subclaims, and fulltext_subclaims (keyed by doc_id).
|
| 15 |
+
|
| 16 |
+
3. **Inference format** — a list of objects, each with:
|
| 17 |
+
- doc_id, label, predicted_gen_text (JSON-wrapped evaluated summary),
|
| 18 |
+
optionally gold_gen_text
|
| 19 |
+
predicted_gen_text is the summary to evaluate (same JSON key-by-label
|
| 20 |
+
format as best_summary). Requires --subclaims for fulltext and subclaims.
|
| 21 |
+
|
| 22 |
+
4. **Self-refine format** — a list of objects, each with:
|
| 23 |
+
- doc_id, label, final_summary (the generated text to evaluate),
|
| 24 |
+
optionally gold_gen_text, gold_summary
|
| 25 |
+
final_summary is the summary to evaluate (plain text or JSON-wrapped by
|
| 26 |
+
label). Requires --subclaims for fulltext and subclaims.
|
| 27 |
+
|
| 28 |
+
The script reuses the reward functions from reward_new_v6_bn_v4_rmv_src_cov.py:
|
| 29 |
+
- factuality_score : fraction of summary subclaims supported by generated_text
|
| 30 |
+
- hallucination_score: fraction of gen subclaims NOT supported by fulltext
|
| 31 |
+
- classifier_score : whether generated_text matches the target literacy level
|
| 32 |
+
|
| 33 |
+
Requires the same vLLM endpoints as the reward file:
|
| 34 |
+
- Support checker : VLLM_SUPPORT_CHECK_BN_API_BASE (default http://localhost:8090/v1)
|
| 35 |
+
- Classifier : VLLM_CLASSIFIER_BN_API_BASE (default http://localhost:8040/v1)
|
| 36 |
+
- Subclaim extractor: VLLM_SUBCLAIM_EXTRACTOR_BN_API_BASE (default http://localhost:8050/v1)
|
| 37 |
+
|
| 38 |
+
Usage:
|
| 39 |
+
# Standard format
|
| 40 |
+
python evaluate_scores.py --input data.json [--output results.json]
|
| 41 |
+
|
| 42 |
+
# BON format with subclaims file
|
| 43 |
+
python evaluate_scores.py --input bon_results.json --subclaims subclaims.json --output-dir evaluation/bn/
|
| 44 |
+
|
| 45 |
+
# Inference format (predicted_gen_text as evaluated summary)
|
| 46 |
+
python evaluate_scores.py --input test_inference_vllm_qwen3-4B_base.json --subclaims subclaims.json --output results.json
|
| 47 |
+
|
| 48 |
+
# Self-refine format (final_summary as evaluated summary)
|
| 49 |
+
python evaluate_scores.py --input test_self_refine_vllm_qwen3_4B_base.json --subclaims subclaims.json --output-dir evaluation/bn/
|
| 50 |
+
"""
|
| 51 |
+
|
| 52 |
+
import argparse
|
| 53 |
+
import json
|
| 54 |
+
import os
|
| 55 |
+
import re
|
| 56 |
+
import sys
|
| 57 |
+
import time
|
| 58 |
+
from typing import Any, Dict, List, Optional
|
| 59 |
+
|
| 60 |
+
from tqdm import tqdm
|
| 61 |
+
|
| 62 |
+
# Import scoring utilities from the reward module (same directory).
|
| 63 |
+
from reward_new_v6_bn_v4_rmv_src_cov import (
|
| 64 |
+
_call_support_api,
|
| 65 |
+
_compute_classifier_reward,
|
| 66 |
+
_extract_subclaims_from_text,
|
| 67 |
+
_is_bangla_text,
|
| 68 |
+
_nonlinear_grounding,
|
| 69 |
+
compute_rewards,
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def extract_text_from_best_summary(best_summary: str, label: str) -> str:
|
| 74 |
+
"""Extract the raw generated text from a BON best_summary string.
|
| 75 |
+
|
| 76 |
+
The best_summary is a (possibly truncated) JSON string like:
|
| 77 |
+
'{"proficient_health_literacy": "...actual text..."}'
|
| 78 |
+
We locate the value after the label key and strip JSON wrapping.
|
| 79 |
+
"""
|
| 80 |
+
key_pattern = re.compile(re.escape(f'"{label}"') + r'\s*:\s*"')
|
| 81 |
+
m = key_pattern.search(best_summary)
|
| 82 |
+
if not m:
|
| 83 |
+
return best_summary.strip()
|
| 84 |
+
text = best_summary[m.end():]
|
| 85 |
+
if text.endswith('"\n}'):
|
| 86 |
+
text = text[:-3]
|
| 87 |
+
elif text.endswith('"}\n'):
|
| 88 |
+
text = text[:-3]
|
| 89 |
+
elif text.endswith('"}'):
|
| 90 |
+
text = text[:-2]
|
| 91 |
+
elif text.endswith('"'):
|
| 92 |
+
text = text[:-1]
|
| 93 |
+
text = text.replace("\\n", "\n").replace('\\"', '"')
|
| 94 |
+
return text.strip()
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def prepare_bon_items(
|
| 98 |
+
bon_data: List[Dict[str, Any]],
|
| 99 |
+
subclaims_data: List[Dict[str, Any]],
|
| 100 |
+
model_key: str = "qwen3_base",
|
| 101 |
+
) -> List[Dict[str, Any]]:
|
| 102 |
+
"""Merge BON results with subclaims data into the standard evaluation format."""
|
| 103 |
+
sc_by_docid = {}
|
| 104 |
+
for item in subclaims_data:
|
| 105 |
+
sc_by_docid[item["doc_id"]] = item
|
| 106 |
+
|
| 107 |
+
prepared = []
|
| 108 |
+
for item in bon_data:
|
| 109 |
+
doc_id = item["doc_id"]
|
| 110 |
+
label = item["label"]
|
| 111 |
+
sc = sc_by_docid.get(doc_id, {})
|
| 112 |
+
|
| 113 |
+
model_data = item.get(model_key, {})
|
| 114 |
+
best_summary = model_data.get("best_summary", "") or model_data.get("predicted_gen_text", "")
|
| 115 |
+
generated_text = extract_text_from_best_summary(best_summary, label)
|
| 116 |
+
|
| 117 |
+
prepared.append({
|
| 118 |
+
"doc_id": doc_id,
|
| 119 |
+
"label": label,
|
| 120 |
+
"fulltext": sc.get("fulltext", ""),
|
| 121 |
+
"summary_text": sc.get("summary", ""),
|
| 122 |
+
"summary_subclaims": sc.get("summary_subclaims", []),
|
| 123 |
+
"fulltext_subclaims": sc.get("fulltext_subclaims", []),
|
| 124 |
+
"generated_text": generated_text,
|
| 125 |
+
"gold_gen_text": item.get("gold_gen_text", ""),
|
| 126 |
+
"predicted_label": item.get("predicted_label", ""),
|
| 127 |
+
"prediction_correct": item.get("prediction_correct", False),
|
| 128 |
+
})
|
| 129 |
+
return prepared
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def prepare_inference_items(
|
| 133 |
+
inference_data: List[Dict[str, Any]],
|
| 134 |
+
subclaims_data: List[Dict[str, Any]],
|
| 135 |
+
) -> List[Dict[str, Any]]:
|
| 136 |
+
"""Merge inference-format results (doc_id, label, predicted_gen_text) with
|
| 137 |
+
subclaims data into the standard evaluation format. predicted_gen_text is
|
| 138 |
+
the JSON-wrapped evaluated summary; the raw text is extracted using the
|
| 139 |
+
item's label.
|
| 140 |
+
"""
|
| 141 |
+
sc_by_docid = {}
|
| 142 |
+
for item in subclaims_data:
|
| 143 |
+
sc_by_docid[item["doc_id"]] = item
|
| 144 |
+
|
| 145 |
+
prepared = []
|
| 146 |
+
for item in inference_data:
|
| 147 |
+
doc_id = item["doc_id"]
|
| 148 |
+
label = item["label"]
|
| 149 |
+
sc = sc_by_docid.get(doc_id, {})
|
| 150 |
+
|
| 151 |
+
raw_pred = item.get("predicted_gen_text", "") or "" or item.get("generated_text", "")
|
| 152 |
+
generated_text = extract_text_from_best_summary(raw_pred, label)
|
| 153 |
+
|
| 154 |
+
prepared.append({
|
| 155 |
+
"doc_id": doc_id,
|
| 156 |
+
"label": label,
|
| 157 |
+
"fulltext": sc.get("fulltext", ""),
|
| 158 |
+
"summary_text": sc.get("summary", ""),
|
| 159 |
+
"summary_subclaims": sc.get("summary_subclaims", []),
|
| 160 |
+
"fulltext_subclaims": sc.get("fulltext_subclaims", []),
|
| 161 |
+
"generated_text": generated_text,
|
| 162 |
+
"gold_gen_text": item.get("gold_gen_text", ""),
|
| 163 |
+
})
|
| 164 |
+
return prepared
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def prepare_self_refine_items(
|
| 168 |
+
self_refine_data: List[Dict[str, Any]],
|
| 169 |
+
subclaims_data: List[Dict[str, Any]],
|
| 170 |
+
) -> List[Dict[str, Any]]:
|
| 171 |
+
"""Merge self-refine format (doc_id, label, final_summary) with subclaims
|
| 172 |
+
data. final_summary is the generated text to evaluate (plain text or
|
| 173 |
+
JSON-wrapped by label); it is extracted and used as generated_text.
|
| 174 |
+
"""
|
| 175 |
+
sc_by_docid = {}
|
| 176 |
+
for item in subclaims_data:
|
| 177 |
+
sc_by_docid[item["doc_id"]] = item
|
| 178 |
+
|
| 179 |
+
prepared = []
|
| 180 |
+
for item in self_refine_data:
|
| 181 |
+
doc_id = item["doc_id"]
|
| 182 |
+
label = item["label"]
|
| 183 |
+
sc = sc_by_docid.get(doc_id, {})
|
| 184 |
+
|
| 185 |
+
raw_final = item.get("final_summary", "") or ""
|
| 186 |
+
generated_text = extract_text_from_best_summary(raw_final, label)
|
| 187 |
+
|
| 188 |
+
prepared.append({
|
| 189 |
+
"doc_id": doc_id,
|
| 190 |
+
"label": label,
|
| 191 |
+
"fulltext": sc.get("fulltext", ""),
|
| 192 |
+
"summary_text": sc.get("summary", ""),
|
| 193 |
+
"summary_subclaims": sc.get("summary_subclaims", []),
|
| 194 |
+
"fulltext_subclaims": sc.get("fulltext_subclaims", []),
|
| 195 |
+
"generated_text": generated_text,
|
| 196 |
+
"gold_gen_text": item.get("gold_gen_text", ""),
|
| 197 |
+
})
|
| 198 |
+
return prepared
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def evaluate_single(
|
| 202 |
+
item: Dict[str, Any],
|
| 203 |
+
target_level_override: Optional[str] = None,
|
| 204 |
+
) -> Dict[str, Any]:
|
| 205 |
+
"""
|
| 206 |
+
Evaluate a single item and return detailed scores.
|
| 207 |
+
"""
|
| 208 |
+
fulltext = item.get("fulltext", "")
|
| 209 |
+
summary_text = item.get("summary_text") or item.get("summary", "")
|
| 210 |
+
summary_subclaims = item.get("summary_subclaims", [])
|
| 211 |
+
generated_text = item.get("generated_text") or item.get("predicted_gen_text", "")
|
| 212 |
+
target_level = target_level_override or item.get("label", "")
|
| 213 |
+
|
| 214 |
+
result: Dict[str, Any] = {
|
| 215 |
+
"doc_id": item.get("doc_id", ""),
|
| 216 |
+
"target_level": target_level,
|
| 217 |
+
"generated_text_len": len(generated_text.strip()) if generated_text else 0,
|
| 218 |
+
"factuality_score": None,
|
| 219 |
+
"hallucination_score": None,
|
| 220 |
+
"classifier_score": None,
|
| 221 |
+
"grounding_score": None,
|
| 222 |
+
"factuality_supported": 0,
|
| 223 |
+
"total_summary_subclaims": len(summary_subclaims) if summary_subclaims else 0,
|
| 224 |
+
"hallucination_supported": 0,
|
| 225 |
+
"total_gen_segments": 0,
|
| 226 |
+
"skipped": False,
|
| 227 |
+
"skip_reason": "",
|
| 228 |
+
}
|
| 229 |
+
|
| 230 |
+
if not generated_text or len(generated_text.strip()) < 10:
|
| 231 |
+
result["skipped"] = True
|
| 232 |
+
result["skip_reason"] = "generated_text missing or too short (<10 chars)"
|
| 233 |
+
return result
|
| 234 |
+
|
| 235 |
+
# -- Factuality & Hallucination via compute_rewards --
|
| 236 |
+
rewards = compute_rewards(
|
| 237 |
+
fulltext=fulltext,
|
| 238 |
+
generated_text=generated_text,
|
| 239 |
+
target_level=target_level,
|
| 240 |
+
summary_subclaims=summary_subclaims,
|
| 241 |
+
summary_text=summary_text,
|
| 242 |
+
)
|
| 243 |
+
|
| 244 |
+
factuality_score = rewards["factuality_score"]
|
| 245 |
+
h_score = rewards["hallucination_score"]
|
| 246 |
+
|
| 247 |
+
if factuality_score is None:
|
| 248 |
+
factuality_score = 0.5
|
| 249 |
+
if h_score is None:
|
| 250 |
+
h_score = 0.5
|
| 251 |
+
|
| 252 |
+
grounding_score = _nonlinear_grounding(h_score)
|
| 253 |
+
|
| 254 |
+
# -- Classifier --
|
| 255 |
+
input_text = fulltext or ""
|
| 256 |
+
class_score = _compute_classifier_reward(target_level, generated_text, input_text)
|
| 257 |
+
|
| 258 |
+
result.update({
|
| 259 |
+
"factuality_score": round(factuality_score, 4),
|
| 260 |
+
"hallucination_score": round(h_score, 4),
|
| 261 |
+
"grounding_score": round(grounding_score, 4),
|
| 262 |
+
"classifier_score": round(class_score, 4),
|
| 263 |
+
"factuality_supported": rewards.get("factuality_supported", 0),
|
| 264 |
+
"total_summary_subclaims": rewards.get("total_summary_subclaims", 0),
|
| 265 |
+
"hallucination_supported": rewards.get("hallucination_supported", 0),
|
| 266 |
+
"total_gen_segments": rewards.get("total_gen_segments", 0),
|
| 267 |
+
})
|
| 268 |
+
|
| 269 |
+
return result
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
def compute_aggregate(results: List[Dict[str, Any]]) -> Dict[str, Any]:
|
| 273 |
+
"""Compute aggregate statistics over all evaluated items."""
|
| 274 |
+
scored = [r for r in results if not r.get("skipped", False)]
|
| 275 |
+
n = len(scored)
|
| 276 |
+
total = len(results)
|
| 277 |
+
skipped = total - n
|
| 278 |
+
|
| 279 |
+
if n == 0:
|
| 280 |
+
return {
|
| 281 |
+
"total_items": total,
|
| 282 |
+
"scored_items": 0,
|
| 283 |
+
"skipped_items": skipped,
|
| 284 |
+
"avg_factuality_score": None,
|
| 285 |
+
"avg_hallucination_score": None,
|
| 286 |
+
"avg_grounding_score": None,
|
| 287 |
+
"avg_classifier_score": None,
|
| 288 |
+
}
|
| 289 |
+
|
| 290 |
+
def safe_avg(key):
|
| 291 |
+
vals = [r[key] for r in scored if r[key] is not None]
|
| 292 |
+
return round(sum(vals) / len(vals), 4) if vals else None
|
| 293 |
+
|
| 294 |
+
return {
|
| 295 |
+
"total_items": total,
|
| 296 |
+
"scored_items": n,
|
| 297 |
+
"skipped_items": skipped,
|
| 298 |
+
"avg_factuality_score": safe_avg("factuality_score"),
|
| 299 |
+
"avg_hallucination_score": safe_avg("hallucination_score"),
|
| 300 |
+
"avg_grounding_score": safe_avg("grounding_score"),
|
| 301 |
+
"avg_classifier_score": safe_avg("classifier_score"),
|
| 302 |
+
}
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
def main():
|
| 306 |
+
parser = argparse.ArgumentParser(
|
| 307 |
+
description="Evaluate factuality, hallucination, and classifier scores on a JSON file."
|
| 308 |
+
)
|
| 309 |
+
parser.add_argument(
|
| 310 |
+
"--input", "-i", required=True,
|
| 311 |
+
help="Path to input JSON file (list of objects).",
|
| 312 |
+
)
|
| 313 |
+
parser.add_argument(
|
| 314 |
+
"--output", "-o", default=None,
|
| 315 |
+
help="Path to output JSON file with per-item scores. "
|
| 316 |
+
"Defaults to <input_stem>_eval_results.json.",
|
| 317 |
+
)
|
| 318 |
+
parser.add_argument(
|
| 319 |
+
"--output-dir", default=None,
|
| 320 |
+
help="Directory to save output files. If set, output filename is derived "
|
| 321 |
+
"from input filename and placed in this directory.",
|
| 322 |
+
)
|
| 323 |
+
parser.add_argument(
|
| 324 |
+
"--subclaims", "-s", default=None,
|
| 325 |
+
help="Path to subclaims JSON file (for BON format). Provides fulltext, "
|
| 326 |
+
"summary, summary_subclaims, and fulltext_subclaims keyed by doc_id.",
|
| 327 |
+
)
|
| 328 |
+
parser.add_argument(
|
| 329 |
+
"--model-key", default="qwen3_base",
|
| 330 |
+
help="Key in the BON data containing candidates/best_summary (default: qwen3_base).",
|
| 331 |
+
)
|
| 332 |
+
parser.add_argument(
|
| 333 |
+
"--target-level", "-t", default=None,
|
| 334 |
+
help="Override target literacy level for all items "
|
| 335 |
+
"(e.g. low_health_literacy). If not set, uses each item's 'label' field.",
|
| 336 |
+
)
|
| 337 |
+
parser.add_argument(
|
| 338 |
+
"--support-check-url", default=None,
|
| 339 |
+
help="Override VLLM_SUPPORT_CHECK_BN_API_BASE.",
|
| 340 |
+
)
|
| 341 |
+
parser.add_argument(
|
| 342 |
+
"--classifier-url", default=None,
|
| 343 |
+
help="Override VLLM_CLASSIFIER_BN_API_BASE.",
|
| 344 |
+
)
|
| 345 |
+
parser.add_argument(
|
| 346 |
+
"--subclaim-extractor-url", default=None,
|
| 347 |
+
help="Override VLLM_SUBCLAIM_EXTRACTOR_BN_API_BASE.",
|
| 348 |
+
)
|
| 349 |
+
args = parser.parse_args()
|
| 350 |
+
|
| 351 |
+
if args.support_check_url:
|
| 352 |
+
os.environ["VLLM_SUPPORT_CHECK_BN_API_BASE"] = args.support_check_url
|
| 353 |
+
if args.classifier_url:
|
| 354 |
+
os.environ["VLLM_CLASSIFIER_BN_API_BASE"] = args.classifier_url
|
| 355 |
+
if args.subclaim_extractor_url:
|
| 356 |
+
os.environ["VLLM_SUBCLAIM_EXTRACTOR_BN_API_BASE"] = args.subclaim_extractor_url
|
| 357 |
+
|
| 358 |
+
# Load input (supports both JSON array and JSONL)
|
| 359 |
+
with open(args.input, "r", encoding="utf-8") as f:
|
| 360 |
+
content = f.read().strip()
|
| 361 |
+
if content.startswith("["):
|
| 362 |
+
raw_data = json.loads(content)
|
| 363 |
+
else:
|
| 364 |
+
raw_data = [json.loads(line) for line in content.splitlines() if line.strip()]
|
| 365 |
+
|
| 366 |
+
if not isinstance(raw_data, list):
|
| 367 |
+
print(f"Error: Expected a JSON list, got {type(raw_data).__name__}.", file=sys.stderr)
|
| 368 |
+
sys.exit(1)
|
| 369 |
+
|
| 370 |
+
# Normalise field names from RL-inference JSONL format
|
| 371 |
+
for item in raw_data:
|
| 372 |
+
if "label" not in item and "gold_label" in item:
|
| 373 |
+
item["label"] = item["gold_label"]
|
| 374 |
+
if "fulltext" not in item and "input_text" in item:
|
| 375 |
+
item["fulltext"] = item["input_text"]
|
| 376 |
+
if "summary_subclaims" not in item and "subclaims" in item:
|
| 377 |
+
item["summary_subclaims"] = item["subclaims"]
|
| 378 |
+
|
| 379 |
+
# Detect BON format: items have a model key (e.g. qwen3_base) with best_summary
|
| 380 |
+
is_bon = (
|
| 381 |
+
len(raw_data) > 0
|
| 382 |
+
and args.model_key in raw_data[0]
|
| 383 |
+
and "best_summary" in raw_data[0].get(args.model_key, {})
|
| 384 |
+
)
|
| 385 |
+
|
| 386 |
+
# Detect inference format: top-level doc_id, label, predicted_gen_text; no fulltext/summary_subclaims
|
| 387 |
+
is_inference = (
|
| 388 |
+
len(raw_data) > 0
|
| 389 |
+
and "doc_id" in raw_data[0]
|
| 390 |
+
and "label" in raw_data[0]
|
| 391 |
+
and "predicted_gen_text" in raw_data[0]
|
| 392 |
+
and raw_data[0].get("fulltext") is None
|
| 393 |
+
and raw_data[0].get("summary_subclaims") is None
|
| 394 |
+
)
|
| 395 |
+
|
| 396 |
+
# Detect self-refine format: doc_id, label, final_summary as gen text; no fulltext/summary_subclaims
|
| 397 |
+
is_self_refine = (
|
| 398 |
+
len(raw_data) > 0
|
| 399 |
+
and "doc_id" in raw_data[0]
|
| 400 |
+
and "label" in raw_data[0]
|
| 401 |
+
and "final_summary" in raw_data[0]
|
| 402 |
+
and raw_data[0].get("fulltext") is None
|
| 403 |
+
and raw_data[0].get("summary_subclaims") is None
|
| 404 |
+
)
|
| 405 |
+
|
| 406 |
+
if is_bon:
|
| 407 |
+
if not args.subclaims:
|
| 408 |
+
print("Error: BON format detected but --subclaims file not provided.", file=sys.stderr)
|
| 409 |
+
sys.exit(1)
|
| 410 |
+
with open(args.subclaims, "r", encoding="utf-8") as f:
|
| 411 |
+
subclaims_data = json.load(f)
|
| 412 |
+
print(f"BON format detected (model_key={args.model_key})")
|
| 413 |
+
print(f"Loaded {len(raw_data)} BON items from {args.input}")
|
| 414 |
+
print(f"Loaded {len(subclaims_data)} subclaims entries from {args.subclaims}")
|
| 415 |
+
data = prepare_bon_items(raw_data, subclaims_data, model_key=args.model_key)
|
| 416 |
+
print(f"Prepared {len(data)} items for evaluation")
|
| 417 |
+
elif is_inference:
|
| 418 |
+
if not args.subclaims:
|
| 419 |
+
print("Error: Inference format detected (predicted_gen_text) but --subclaims file not provided.", file=sys.stderr)
|
| 420 |
+
sys.exit(1)
|
| 421 |
+
with open(args.subclaims, "r", encoding="utf-8") as f:
|
| 422 |
+
subclaims_data = json.load(f)
|
| 423 |
+
print("Inference format detected (predicted_gen_text as evaluated summary)")
|
| 424 |
+
print(f"Loaded {len(raw_data)} inference items from {args.input}")
|
| 425 |
+
print(f"Loaded {len(subclaims_data)} subclaims entries from {args.subclaims}")
|
| 426 |
+
data = prepare_inference_items(raw_data, subclaims_data)
|
| 427 |
+
print(f"Prepared {len(data)} items for evaluation")
|
| 428 |
+
elif is_self_refine:
|
| 429 |
+
if not args.subclaims:
|
| 430 |
+
print("Error: Self-refine format detected (final_summary) but --subclaims file not provided.", file=sys.stderr)
|
| 431 |
+
sys.exit(1)
|
| 432 |
+
with open(args.subclaims, "r", encoding="utf-8") as f:
|
| 433 |
+
subclaims_data = json.load(f)
|
| 434 |
+
print("Self-refine format detected (final_summary as evaluated summary)")
|
| 435 |
+
print(f"Loaded {len(raw_data)} self-refine items from {args.input}")
|
| 436 |
+
print(f"Loaded {len(subclaims_data)} subclaims entries from {args.subclaims}")
|
| 437 |
+
data = prepare_self_refine_items(raw_data, subclaims_data)
|
| 438 |
+
print(f"Prepared {len(data)} items for evaluation")
|
| 439 |
+
else:
|
| 440 |
+
data = raw_data
|
| 441 |
+
print(f"Loaded {len(data)} items from {args.input}")
|
| 442 |
+
|
| 443 |
+
print(f" Support check API : {os.getenv('VLLM_SUPPORT_CHECK_BN_API_BASE', 'http://localhost:8090/v1')}")
|
| 444 |
+
print(f" Classifier API : {os.getenv('VLLM_CLASSIFIER_BN_API_BASE', 'http://localhost:8040/v1')}")
|
| 445 |
+
print(f" Subclaim extractor: {os.getenv('VLLM_SUBCLAIM_EXTRACTOR_BN_API_BASE', 'http://localhost:8050/v1')}")
|
| 446 |
+
if args.target_level:
|
| 447 |
+
print(f" Target level override: {args.target_level}")
|
| 448 |
+
print("-" * 60)
|
| 449 |
+
|
| 450 |
+
# Evaluate each item
|
| 451 |
+
results = []
|
| 452 |
+
start_time = time.time()
|
| 453 |
+
for idx, item in enumerate(tqdm(data, desc="Evaluating")):
|
| 454 |
+
r = evaluate_single(item, target_level_override=args.target_level)
|
| 455 |
+
r["index"] = idx
|
| 456 |
+
r["doc_id"] = item.get("doc_id", "")
|
| 457 |
+
results.append(r)
|
| 458 |
+
|
| 459 |
+
if (idx + 1) % 10 == 0 or idx == 0:
|
| 460 |
+
partial_agg = compute_aggregate(results)
|
| 461 |
+
tqdm.write(
|
| 462 |
+
f" [{idx+1}/{len(data)}] "
|
| 463 |
+
f"fact={partial_agg['avg_factuality_score']} "
|
| 464 |
+
f"hallu={partial_agg['avg_hallucination_score']} "
|
| 465 |
+
f"cls={partial_agg['avg_classifier_score']}"
|
| 466 |
+
)
|
| 467 |
+
|
| 468 |
+
elapsed = time.time() - start_time
|
| 469 |
+
|
| 470 |
+
# --- Validation: all items must be evaluated with non-null scores ---
|
| 471 |
+
expected_count = len(data)
|
| 472 |
+
skipped_items = [r for r in results if r.get("skipped", False)]
|
| 473 |
+
null_score_items = []
|
| 474 |
+
for r in results:
|
| 475 |
+
if r.get("skipped", False):
|
| 476 |
+
continue
|
| 477 |
+
for key in ("factuality_score", "hallucination_score", "classifier_score", "grounding_score"):
|
| 478 |
+
if r.get(key) is None:
|
| 479 |
+
null_score_items.append((r.get("index"), r.get("doc_id"), key))
|
| 480 |
+
|
| 481 |
+
has_errors = False
|
| 482 |
+
if skipped_items:
|
| 483 |
+
has_errors = True
|
| 484 |
+
print(f"\nERROR: {len(skipped_items)} out of {expected_count} items were skipped:", file=sys.stderr)
|
| 485 |
+
for r in skipped_items:
|
| 486 |
+
print(f" index={r.get('index')} doc_id={r.get('doc_id')} reason={r.get('skip_reason')}", file=sys.stderr)
|
| 487 |
+
|
| 488 |
+
if null_score_items:
|
| 489 |
+
has_errors = True
|
| 490 |
+
print(f"\nERROR: {len(null_score_items)} null score(s) found:", file=sys.stderr)
|
| 491 |
+
for idx, doc_id, key in null_score_items:
|
| 492 |
+
print(f" index={idx} doc_id={doc_id} null_field={key}", file=sys.stderr)
|
| 493 |
+
|
| 494 |
+
if len(results) != expected_count:
|
| 495 |
+
has_errors = True
|
| 496 |
+
print(f"\nERROR: Expected {expected_count} results but got {len(results)}.", file=sys.stderr)
|
| 497 |
+
|
| 498 |
+
if has_errors:
|
| 499 |
+
print(f"\nAborting: will NOT save results. All {expected_count} items must be fully evaluated with non-null scores.", file=sys.stderr)
|
| 500 |
+
sys.exit(1)
|
| 501 |
+
|
| 502 |
+
# Aggregate
|
| 503 |
+
agg = compute_aggregate(results)
|
| 504 |
+
|
| 505 |
+
# Per-label aggregates
|
| 506 |
+
label_groups: Dict[str, List[Dict[str, Any]]] = {}
|
| 507 |
+
for r in results:
|
| 508 |
+
lbl = r.get("target_level", "unknown")
|
| 509 |
+
label_groups.setdefault(lbl, []).append(r)
|
| 510 |
+
per_label_agg = {lbl: compute_aggregate(items) for lbl, items in sorted(label_groups.items())}
|
| 511 |
+
|
| 512 |
+
# Output path
|
| 513 |
+
if args.output:
|
| 514 |
+
out_path = args.output
|
| 515 |
+
elif args.output_dir:
|
| 516 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 517 |
+
stem = os.path.splitext(os.path.basename(args.input))[0]
|
| 518 |
+
out_path = os.path.join(args.output_dir, f"{stem}_eval_results.json")
|
| 519 |
+
else:
|
| 520 |
+
stem = os.path.splitext(os.path.basename(args.input))[0]
|
| 521 |
+
out_dir = os.path.dirname(args.input) or "."
|
| 522 |
+
out_path = os.path.join(out_dir, f"{stem}_eval_results.json")
|
| 523 |
+
|
| 524 |
+
os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
|
| 525 |
+
|
| 526 |
+
output = {
|
| 527 |
+
"input_file": os.path.abspath(args.input),
|
| 528 |
+
"subclaims_file": os.path.abspath(args.subclaims) if args.subclaims else None,
|
| 529 |
+
"model_key": args.model_key if is_bon else None,
|
| 530 |
+
"inference_format": is_inference if not is_bon else False,
|
| 531 |
+
"self_refine_format": is_self_refine if not is_bon and not is_inference else False,
|
| 532 |
+
"target_level_override": args.target_level,
|
| 533 |
+
"elapsed_seconds": round(elapsed, 2),
|
| 534 |
+
"aggregate": agg,
|
| 535 |
+
"per_label_aggregate": per_label_agg,
|
| 536 |
+
"per_item": results,
|
| 537 |
+
}
|
| 538 |
+
|
| 539 |
+
with open(out_path, "w", encoding="utf-8") as f:
|
| 540 |
+
json.dump(output, f, indent=2, ensure_ascii=False)
|
| 541 |
+
|
| 542 |
+
# Print summary
|
| 543 |
+
print("\n" + "=" * 60)
|
| 544 |
+
print("EVALUATION SUMMARY")
|
| 545 |
+
print("=" * 60)
|
| 546 |
+
print(f" Total items : {agg['total_items']}")
|
| 547 |
+
print(f" Scored items : {agg['scored_items']}")
|
| 548 |
+
print(f" Skipped items : {agg['skipped_items']}")
|
| 549 |
+
print(f" Elapsed time : {round(elapsed, 1)}s")
|
| 550 |
+
print("-" * 60)
|
| 551 |
+
print(f" Avg Factuality Score : {agg['avg_factuality_score']}")
|
| 552 |
+
print(f" Avg Hallucination Score: {agg['avg_hallucination_score']}")
|
| 553 |
+
print(f" Avg Grounding Score : {agg['avg_grounding_score']}")
|
| 554 |
+
print(f" Avg Classifier Score : {agg['avg_classifier_score']}")
|
| 555 |
+
print("-" * 60)
|
| 556 |
+
for lbl, la in per_label_agg.items():
|
| 557 |
+
print(f" [{lbl}] items={la['scored_items']}"
|
| 558 |
+
f" fact={la['avg_factuality_score']}"
|
| 559 |
+
f" hallu={la['avg_hallucination_score']}"
|
| 560 |
+
f" cls={la['avg_classifier_score']}")
|
| 561 |
+
print("-" * 60)
|
| 562 |
+
print(f" Results saved to: {out_path}")
|
| 563 |
+
print("=" * 60)
|
| 564 |
+
|
| 565 |
+
|
| 566 |
+
if __name__ == "__main__":
|
| 567 |
+
main()
|
code/fine_tune_sft_dpo/evaluation/bn/bn_200_eval_results.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0cb20710f7c4b519b77457d541d9a132ded57b6a9252bc5552788cf358e9e436
|
| 3 |
+
size 96048
|
code/fine_tune_sft_dpo/evaluation/bn/eval_gpt5_models/gpt_5_mini_eval_results_20260316_071029.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:18c09e4d51cd941941cbc3c585699e3f7e2e98b051a5a26396bc242effa5438a
|
| 3 |
+
size 97818
|
code/fine_tune_sft_dpo/evaluation/bn/eval_gpt5_models/gpt_5_mini_prepared_20260316_071029.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:97ac7c77d004ac04c4b06a848b28c43b9cdcd4edc8d1c97f757ec11769445aae
|
| 3 |
+
size 6569807
|
code/fine_tune_sft_dpo/evaluation/bn/eval_gpt5_models/gpt_5_nano_eval_results_20260316_071029.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4cf921a153e1d69f2671c93cfc4a1d368cbcc3d3db0c5b30b706bd354402cb44
|
| 3 |
+
size 97656
|
code/fine_tune_sft_dpo/evaluation/bn/eval_gpt5_models/gpt_5_nano_prepared_20260316_071029.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f6638c5e9c0feb85cf471fac04021473211fd645cc250b5e42a39483e3a6e1fd
|
| 3 |
+
size 6114304
|
code/fine_tune_sft_dpo/evaluation/bn/eval_gpt5_models/gpt_5_prepared_20260316_071029.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:afefe278ca421f9eb698e638b2fcca1c0814525ac5a9a4bf576c0d3294290033
|
| 3 |
+
size 6349574
|
code/fine_tune_sft_dpo/evaluation/bn/test_best_of_n_qwen3-4B_base_eval_results.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:85d23851eb74b15da0c466c3a11ffd43ae371db75523ee9e85a0cec6f1cf6b5e
|
| 3 |
+
size 97215
|
code/fine_tune_sft_dpo/evaluation/bn/test_best_of_n_qwen3-4B_sft_eval_results.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f89d77f48683f42239d9bcdcbd469810f3f11d3d06d8e06d3442e69e34729535
|
| 3 |
+
size 96710
|
code/fine_tune_sft_dpo/evaluation/bn/test_inference_vllm_qwen3-4B_base_eval_results.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2eb135210f5dc0f92170e415936ad34927979be1efdfb0bb6f629374c42b79e4
|
| 3 |
+
size 97550
|
code/fine_tune_sft_dpo/evaluation/bn/test_inference_vllm_qwen3-4B_sft_eval_results.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b54519f530b0894294b4ec86b0af317f521bc9baff38e230dc46dccfb46e7d19
|
| 3 |
+
size 97002
|
code/fine_tune_sft_dpo/evaluation/bn/test_self_refine_vllm_qwen3_4B_base_eval_results.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:346286565d0a6ec98ea218bd463a9322e6b1738bf5afb8235cddc2cfcb22d488
|
| 3 |
+
size 97347
|
code/fine_tune_sft_dpo/evaluation/bn/test_self_refine_vllm_qwen3_4B_sft_eval_results.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9630e15bd5d4b3a74de3c0d97a5dc5fc40ddffb96eb16ea55b3804cf5db5a419
|
| 3 |
+
size 97016
|
code/fine_tune_sft_dpo/evaluation/bn/v1/test_best_of_n_qwen3-4B_base_eval_results.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c121bcd6eae3ab09fcc37eaa3a6b49edf437b4700573f62febc5db58948d26e8
|
| 3 |
+
size 97136
|
code/fine_tune_sft_dpo/evaluation/bn/v1/test_best_of_n_qwen3-4B_sft_eval_results.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ccf84db8e83c982492ad99641f4a01b882ff7e9a259159f7fbcb04105745776b
|
| 3 |
+
size 96671
|
code/fine_tune_sft_dpo/evaluation/bn/v1/test_inference_vllm_qwen3-4B_base_eval_results.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:de2460d80063dd684264a474babc483a510e293e42d94e9392665b85e67f351c
|
| 3 |
+
size 97496
|
code/fine_tune_sft_dpo/evaluation/bn/v1/test_inference_vllm_qwen3-4B_sft_eval_results.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:32bc196d3288b6355e7b16f9dc1bd31a7dbce1c2c837a0059c29e255cca698e1
|
| 3 |
+
size 97006
|
code/fine_tune_sft_dpo/evaluation/bn/v1/test_self_refine_vllm_qwen3_4B_base_eval_results.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b163e821453592fbcdb922291a803431114d0dfa828df32366cd58340da949ef
|
| 3 |
+
size 97349
|
code/fine_tune_sft_dpo/evaluation/bn/v1/test_self_refine_vllm_qwen3_4B_sft_eval_results.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e822e39c927b8f75d411d42b5e556696481987b5589cd9f6453bee8e65d068d7
|
| 3 |
+
size 97013
|
code/fine_tune_sft_dpo/prompt_bn_wo_gs/prompt
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
**সিস্টেম ভূমিকা:**
|
| 2 |
+
|
| 3 |
+
আপনি একজন বিশেষজ্ঞ চিকিৎসা সম্পাদক এবং স্বাস্থ্য-সাক্ষরতা বিশেষজ্ঞ। আপনার কাজ হলো জটিল চিকিৎসাবিষয়ক লেখাকে পাঠকের স্বাস্থ্য-সাক্ষরতার স্তর অনুযায়ী তিনটি ভিন্ন সংস্করণে রূপান্তর করা। আপনাকে ইনপুটের মূল ভাষা অবশ্যই অপরিবর্তিত রাখতে হবে, তবে ভাষার জটিলতা স্তর অনুযায়ী সমন্বয় করতে হবে। সরলীকৃত সংস্করণগুলো যেন সঠিক ও গুরুত্বপূর্ণ তথ্য বজায় রাখে, সে জন্য আপনাকে মূল তথ্য ও বার্তাকে ভিত্তি হিসেবে ব্যবহার করতে হবে।
|
| 4 |
+
|
| 5 |
+
**ব্যবহারকারী নির্দেশনা:**
|
| 6 |
+
|
| 7 |
+
দয়া করে নিচের চিকিৎসাবিষয়ক সোর্স টেক্সট ব্যবহার করে স্বাস্থ্য‑সাক্ষরতার তিনটি ভিন্ন স্তরের জন্য আলাদা আলাদা সংস্করণ তৈরি করুন।
|
| 8 |
+
|
| 9 |
+
### প্রতিটি স্তরের জন্য নির্দেশনা:
|
| 10 |
+
|
| 11 |
+
1. স্তর: কম স্বাস্থ্য‑সাক্ষরতা (উচ্চ পাঠযোগ্যতা)
|
| 12 |
+
|
| 13 |
+
লক্ষ্য পাঠক: যারা খুব সহজ, দৈনন্দিন ভাষায় দ্রুত বোঝার মতো ব্যাখ্যা চান।
|
| 14 |
+
|
| 15 |
+
ভাষাগত লক্ষ্য: একদম ঘরোয়া/দৈনন্দিন কথাবার্তার ভাষা ব্যবহার করুন। সব ধরনের চিকিৎসাবিষয়ক জারগনকে সহজ ব্যাখ্যামূলক ভাষায় রূপান্তর করুন (যেমন, "renal" এর পরিবর্তে "কিডনির সমস্যা" বা "কিডনি" লিখুন)।
|
| 16 |
+
|
| 17 |
+
তথ্যের ঘনত্ব: কেবলমাত্র "যা অবশ্যই জানা দরকার" ধরনের মূল তথ্যগুলো রাখুন।
|
| 18 |
+
|
| 19 |
+
কৌশল: বেশি মাত্রায় পুনর্লিখন ও উদাহরণ/উপমা ব্যবহার করুন। প্রতি বাক্যে একটি করে মূল ধারণা রাখুন।
|
| 20 |
+
|
| 21 |
+
বিশ্বস্ততা: লেখাটি অবশ্যই গোল্ড সামারি‑র সঙ্গে সম্পূর্ণ সামঞ্জস্যপূর্ণ হতে হবে।
|
| 22 |
+
|
| 23 |
+
2. স্তর: মাঝারি স্বাস্থ্য‑সাক্ষরতা (মাঝারি পাঠযোগ্যতা)
|
| 24 |
+
|
| 25 |
+
লক্ষ্য পাঠক: সাধারণ মানুষ, যারা সাধারণ সংবাদ বা তথ্যভিত্তিক লেখা পড়ে বুঝতে পারেন।
|
| 26 |
+
|
| 27 |
+
ভাষাগত লক্ষ্য: মানিকৃত/সাধারণ শব্দভাণ্ডার ব্যবহার করুন। সাধারণভাবে পরিচিত চিকিৎসাবিষয়ক শব্দ ব্যবহার করা যেতে পারে, তবে অতিরিক্ত টেকনিক্যাল "ডাক্তারি ভাষা" এড়িয়ে চলুন বা সহজভাবে ব্যাখ্যা করুন।
|
| 28 |
+
|
| 29 |
+
তথ্যের ঘনত্ব: ভারসাম্যপূর্ণ রাখুন। মূল বার্তাকে কেন্দ্র করে কাঠামো তৈরি করুন এবং প্রয়োজন অনুযায়ী সোর্স টেক্সট থেকে প্রাসঙ্গিক অতিরিক্ত প্রেক্ষাপট যোগ করুন।
|
| 30 |
+
|
| 31 |
+
কৌশল: মাঝারি মাত্রার পুনর্লিখন করুন। অপ্রয়োজনীয় টেকনিক্যাল খুঁটিনাটি বাদ দিন, যাতে পাঠক অতিরিক্ত তথ্যের চাপে না পড়েন।
|
| 32 |
+
|
| 33 |
+
বিশ্বস্ততা: লেখাটি যেন মূল বার্তা ও ধারাবাহিকতা বজায় রাখে।
|
| 34 |
+
|
| 35 |
+
3. স্তর: উচ্চ স্বাস্থ্য‑সাক্ষরতা / প্র���িসিয়েন্ট (কম পাঠযোগ্যতা, উচ্চ জটিলতা)
|
| 36 |
+
|
| 37 |
+
লক্ষ্য পাঠক: গবেষক, ক্লিনিশিয়ান বা অত্যন্ত সচেতন/অভিজ্ঞ রোগী।
|
| 38 |
+
|
| 39 |
+
ভাষাগত লক্ষ্য: প্রয়োজনে টেকনিক্যাল ও একাডেমিক ভাষা ব্যবহার করুন। ক্লিনিক্যাল নির্ভুলতা ও চিকিৎসাবিজ্ঞানভিত্তিক সূক্ষ্ম দিকগুলোকে অগ্রাধিকার দিন।
|
| 40 |
+
|
| 41 |
+
তথ্যের ঘনত্ব: বেশি রাখুন। পুরো সোর্স টেক্সট ব্যবহার করে ডেটা, শারীরবৃত্তীয় প্রক্রিয়া, পরিসংখ্যান ইত্যাদি প্রাসঙ্গিক তথ্য অন্তর্ভুক্ত করুন।
|
| 42 |
+
|
| 43 |
+
কৌশল: যতটা সম্ভব কম পুনর্লিখন করুন। মূল টেকনিক্যাল পরিভাষা ও বাক্য গঠন অধিকাংশই অক্ষুণ্ণ রাখুন।
|
| 44 |
+
|
| 45 |
+
বিশ্বস্ততা: সোর্স টেক্সটের সাথে ঘনিষ্ঠভাবে অনুগত থাকুন; প্রয়োজন হলে বৈজ্ঞানিক প্রেক্ষাপট বাড়াতে সম্পর্কিত উপ‑দাবি বা ব্যাখ্যা যোগ করতে পারেন।
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
আমি আপনাকে নিম্নোক্ত তথ্যগুলো প্রদান করব:
|
| 49 |
+
|
| 50 |
+
- ইনপুট ভাষা: <<<SOURCE_LANGUAGE>>>
|
| 51 |
+
- সোর্স টেক্সট (বিস্তারিত মূল লেখা): <<<FULL_TEXT>>>
|
| 52 |
+
|
| 53 |
+
**আউটপুট ফরম্যাট (শুধু JSON):**
|
| 54 |
+
{{
|
| 55 |
+
"low_health_literacy": "...",
|
| 56 |
+
"intermediate_health_literacy": "...",
|
| 57 |
+
"proficient_health_literacy": "..."
|
| 58 |
+
}}
|
code/fine_tune_sft_dpo/prompt_bn_wo_gs/prompt_intermediate
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
**সিস্টেম ভূমিকা:**
|
| 2 |
+
|
| 3 |
+
আপনি একজন বিশেষজ্ঞ চিকিৎসা সম্পাদক এবং স্বাস্থ্য‑সাক্ষরতা বিশেষজ্ঞ। আপনার কাজ হলো জটিল চিকিৎসাবিষয়ক লেখাকে এমন একটি সংস্করণে রূপান্তর করা, যা মাঝারি স্বাস্থ্য‑সাক্ষরতা সম্পন্ন পাঠকদের জন্য উপযোগী। আপনাকে ইনপুটের মূল ভাষা অপরিবর্তিত রেখে ভাষার জটিলতা ও তথ্যের ঘনত্বকে ভারসাম্যপূর্ণ করতে হবে। সরলীকৃত লেখাটি যেন সঠিক ও গুরুত্বপূর্ণ তথ্য বজায় রাখে, সে জন্য আপনাকে মূল তথ্য ও বার্তাকে ভিত্তি হিসেবে ব্যবহার করতে হবে।
|
| 4 |
+
|
| 5 |
+
**ব্যবহারকারী নির্দেশনা:**
|
| 6 |
+
|
| 7 |
+
দয়া করে নিচের চিকিৎসাবিষয়ক সোর্স টেক্সট ব্যবহার করে **মাঝারি স্বাস্থ্য‑সাক্ষরতা (মাঝারি পাঠযোগ্যতা)** স্তরের জন্য একটি সংস্করণ তৈরি করুন।
|
| 8 |
+
|
| 9 |
+
### নির্দেশনা:
|
| 10 |
+
|
| 11 |
+
স্তর: মাঝারি স্বাস্থ্য‑সাক্ষরতা (মাঝারি পাঠযোগ্যতা)
|
| 12 |
+
|
| 13 |
+
লক্ষ্য পাঠক: সাধারণ জনগণ, যারা সাধারণ সংবাদ বা তথ্যভিত্তিক লেখা পড়ে বুঝতে পারেন।
|
| 14 |
+
|
| 15 |
+
ভাষাগত লক্ষ্য: মানিকৃত ও সহজবোধ্য শব্দভাণ্ডার ব্যবহার করুন। পরিচিত চিকিৎসাবিষয়ক শব্দ ব্যবহার করা যেতে পারে, তবে অতিরিক্ত টেকনিক্যাল "ডাক্তারি ভাষা" এলে তা সহজ ব্যাখ্যায় রূপান্তর করুন।
|
| 16 |
+
|
| 17 |
+
তথ্যের ঘনত্ব: ভারসাম্যপূর্ণ রাখুন। মূল বার্তাকে সামনে রেখে মূল কাঠামো তৈরি করুন এবং প্রয়োজন হলে সোর্স টেক্সট থেকে প্রাসঙ্গিক অতিরিক্ত তথ্য বা প্রেক্ষাপট যোগ করুন।
|
| 18 |
+
|
| 19 |
+
কৌশল: মাঝারি মাত্রার পুনর্লিখন করুন। অতি খুঁটিনাটি টেকনিক্যাল ডিটেইল বাদ দিন, যাতে পাঠক তথ্যের চাপে না পড়ে কিন্তু মূল বিষয়টি স্পষ্টভাবে বুঝতে পারে।
|
| 20 |
+
|
| 21 |
+
বিশ্বস্ততা: লেখাটি যেন মূল বার্তা, ক্রম এবং যুক্তি বজায় রাখে।
|
| 22 |
+
|
| 23 |
+
আমি আপনাকে নিম্নোক্ত তথ্যগুলো প্রদান করব:
|
| 24 |
+
|
| 25 |
+
- ইনপুট ভাষা: {source_lang}
|
| 26 |
+
- সোর্স টেক্সট (বিস্তারিত মূল লেখা): {full_text}
|
| 27 |
+
|
| 28 |
+
**আউটপুট ফরম্যাট (শুধু JSON):**
|
| 29 |
+
{{
|
| 30 |
+
"intermediate_health_literacy": "..."
|
| 31 |
+
}}
|
code/fine_tune_sft_dpo/prompt_bn_wo_gs/prompt_low
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
**সিস্টেম ভূমিকা:**
|
| 2 |
+
|
| 3 |
+
আপনি একজন বিশেষজ্ঞ চিকিৎসা সম্পাদক এবং স্বাস্থ্য‑সাক্ষরতা বিশেষজ্ঞ। আপনার কাজ হলো জটিল চিকিৎসাবিষয়ক লেখাকে এমনভাবে রূপান্তর করা, যা কম স্বাস্থ্য‑সাক্ষরতা সম্পন্ন পাঠকদের জন্য সহজে বোঝা যায়। আপনাকে ইনপুটের মূল ভাষা অপরিবর্তিত রাখতে হবে, তবে ভাষার জটিলতা কমিয়ে আনতে হবে। সরলীকৃত লেখাটি যেন সঠিক ও প্রয়োজনীয় থাকে, সে জন্য আপনাকে মূল তথ্য ও বার্তাকে ভিত্তি হিসেবে ব্যবহার করতে হবে।
|
| 4 |
+
|
| 5 |
+
**ব্যবহারকারী নির্দেশনা:**
|
| 6 |
+
|
| 7 |
+
দয়া করে নিচের চিকিৎসাবিষয়ক সোর্স টেক্সট ব্যবহার করে **কম স্বাস্থ্য‑সাক্ষরতা (উচ্চ পাঠযোগ্যতা)** স্তরের জন্য একটি সংস্করণ তৈরি করুন।
|
| 8 |
+
|
| 9 |
+
### নির্দেশনা:
|
| 10 |
+
|
| 11 |
+
স্তর: কম স্বাস্থ্য‑সাক্ষরতা (উচ্চ পাঠযোগ্যতা)
|
| 12 |
+
|
| 13 |
+
লক্ষ্য পাঠক: এমন ব্যক্তি, যাঁরা খুব সহজ, সরাসরি ভাষায় তথ্য পেতে চান এবং তা থেকে দ্রুত পদক্ষেপ নিতে চান।
|
| 14 |
+
|
| 15 |
+
ভাষাগত লক্ষ্য: একদম ঘরোয়া/দৈনন্দিন কথাবার্তার ভাষা ব্যবহার করুন। সব ধরনের চিকিৎসাবিষয়ক জারগনকে সহজ বর্ণনামূলক শব্দে রূপান্তর করুন (যেমন, "renal" এর পরিবর্তে "কিডনির সমস্যা" বা "কিডনি" লিখুন)।
|
| 16 |
+
|
| 17 |
+
তথ্যের ঘনত্ব: কেবলমাত্র "অবশ্যই জানা দরকার" ধরনের মূল তথ্যগুলো রাখুন। অপ্রয়োজনীয় ব্যাখ্যা বা অতিরিক্ত ডেটা এড়িয়ে চলুন।
|
| 18 |
+
|
| 19 |
+
কৌশল: উচ্চ মাত্রার পুনর্লিখন করুন এবং প্রয়োজন হলে সহজ উপমা বা উদাহরণ ব্যবহার করুন। প্রতিটি বাক্যে একটি করে স্পষ্ট ধারণা রাখুন।
|
| 20 |
+
|
| 21 |
+
বিশ্বস্ততা: লেখাটি অবশ্যই গোল্ড সামারি‑র সাথে সম্পূর্ণ সামঞ্জস্যপূর্ণ হতে হবে; নতুন তথ্য যোগ করা যাবে না।
|
| 22 |
+
|
| 23 |
+
আমি আপনাকে নিম্নোক্ত তথ্যগুলো প্রদান করব:
|
| 24 |
+
|
| 25 |
+
- ইনপুট ভাষা: {source_lang}
|
| 26 |
+
- সোর্স টেক্সট (বিস্তারিত মূল লেখা): {full_text}
|
| 27 |
+
|
| 28 |
+
**আউটপুট ফরম্যাট (শুধু JSON):**
|
| 29 |
+
{{
|
| 30 |
+
"low_health_literacy": "..."
|
| 31 |
+
}}
|
code/fine_tune_sft_dpo/prompt_bn_wo_gs/prompt_proficient
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
**সিস্টেম ভূমিকা:**
|
| 2 |
+
|
| 3 |
+
আপনি একজন বিশেষজ্ঞ চিকিৎসা সম্পাদক এবং স্বাস্থ্য‑সাক্ষরতা বিশেষজ্ঞ। আপনার কাজ হলো জটিল চিকিৎসাবিষয়ক লেখাকে এমন একটি সংস্করণে রূপান্তর করা, যা উচ্চ/প্রফিসিয়েন্ট স্বাস্থ্য‑সাক্ষরতা সম্পন্ন পাঠকদের জন্য উপযোগী। আপনাকে ইনপুটের মূল ভাষা বজায় রেখে টেকনিক্যাল ও একাডেমিক ভাষার যথাযথ ব্যবহার করতে হবে। আপনি মূল তথ্যকে রেফারেন্স হিসেবে ব্যবহার করবেন, তবে প্রয়োজনে সোর্স টেক্সট থেকে গভীরতর বৈজ্ঞানিক প্রেক্ষাপটও যোগ করতে পারবেন।
|
| 4 |
+
|
| 5 |
+
**ব্যবহারকারী নির্দেশনা:**
|
| 6 |
+
|
| 7 |
+
দয়া করে নিচের চিকিৎসাবিষয়ক সোর্স টেক্সট ব্যবহার করে **উচ্চ/প্রফিসিয়েন্ট স্বাস্থ্য‑সাক্ষরতা (কম পাঠযোগ্যতা, উচ্চ জটিলতা)** স্তরের জন্য একটি সংস্করণ তৈরি করুন।
|
| 8 |
+
|
| 9 |
+
### নির্দেশনা:
|
| 10 |
+
|
| 11 |
+
স্তর: প্রফিসিয়েন্ট স্বাস্থ্য‑সাক্ষরতা (কম পাঠযোগ্যতা)
|
| 12 |
+
|
| 13 |
+
লক্ষ্য পাঠক: গবেষক, ক্লিনিশিয়ান, বা অত্যন্ত সচেতন/অভিজ্ঞ রোগী।
|
| 14 |
+
|
| 15 |
+
ভাষাগত লক্ষ্য: প্রয়োজন অনুযায়ী টেকনিক্যাল ও একাডেমিক ভাষা ব্যবহার করুন। ক্লিনিক্যাল সূক্ষ্মতা, প্যাথোফিজিওলজি, ডায়াগনস্টিক মানদণ্ড ইত্যাদির নির্ভুল উপস্থাপনাকে অগ্রাধিকার দিন।
|
| 16 |
+
|
| 17 |
+
তথ্যের ঘনত্ব: উচ্চ রাখুন। সোর্স টেক্সট থেকে ডেটা, পরিসংখ্যান, শারীরবৃত্তীয় প্রক্রিয়া, চিকিৎসাপদ্ধতি এবং গবেষণালব্ধ তথ্য উপযুক্তভাবে অন্তর্ভুক্ত করুন।
|
| 18 |
+
|
| 19 |
+
কৌশল: কম মাত্রার পুনর্লিখন করুন। মূল টেকনিক্যাল পরিভাষা, গঠন এবং গুরুত্বপূর্ণ বাক্যগুলো যতটা সম্ভব অক্ষুণ্ণ রাখুন; প্রয়োজনে কেবল ব্যাকরণগত বা শৈলগত সামঞ্জস্যের জন্য পরিবর্তন করুন।
|
| 20 |
+
|
| 21 |
+
বিশ্বস্ততা: সোর্স টেক্সটের প্রতি ঘনিষ্ঠভাবে অনুগত থাকুন; প্রয়োজন হলে বৈজ্ঞানিক প্রেক্ষাপট ও ব্যাখ্যা সম্প্রসারণ করতে সম্পর্কিত উপ‑দাবি বা তথ্য যোগ করতে পারেন, তবে ভিত্তিহীন নতুন দাবি যোগ করবেন না।
|
| 22 |
+
|
| 23 |
+
আমি আপনাকে নিম্নোক্ত তথ্যগুলো প্রদান করব:
|
| 24 |
+
|
| 25 |
+
- ইনপুট ভাষা: {source_lang}
|
| 26 |
+
- সোর্স টেক্সট (বিস্তারিত মূল লেখা): {full_text}
|
| 27 |
+
|
| 28 |
+
**আউটপুট ফরম্যাট (শুধু JSON):**
|
| 29 |
+
{{
|
| 30 |
+
"proficient_health_literacy": "..."
|
| 31 |
+
}}
|
code/fine_tune_sft_dpo/qwen3-inference-vllm_bn.py
CHANGED
|
@@ -7,16 +7,23 @@ merged model was saved to `/home/mshahidul/readctrl/code/fine_tune_sft_dpo/model
|
|
| 7 |
|
| 8 |
import os
|
| 9 |
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
| 10 |
-
os.environ["CUDA_VISIBLE_DEVICES"] = "
|
| 11 |
|
| 12 |
import argparse
|
| 13 |
import json
|
|
|
|
| 14 |
from datetime import datetime
|
| 15 |
|
| 16 |
from vllm import LLM, SamplingParams
|
| 17 |
from transformers import AutoTokenizer
|
| 18 |
|
| 19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
# Paths
|
| 21 |
BASE_DIR = "/home/mshahidul/readctrl/code/fine_tune_sft_dpo"
|
| 22 |
MODEL_DIR = os.path.join(BASE_DIR, "model", "bn")
|
|
@@ -73,7 +80,7 @@ def parse_args():
|
|
| 73 |
p.add_argument(
|
| 74 |
"--temperature",
|
| 75 |
type=float,
|
| 76 |
-
default=0.
|
| 77 |
help="Sampling temperature for generation.",
|
| 78 |
)
|
| 79 |
p.add_argument(
|
|
@@ -147,7 +154,10 @@ def main():
|
|
| 147 |
user_prompt = build_user_message(prompts[label], fulltext, summary)
|
| 148 |
chat = [{"role": "user", "content": user_prompt}]
|
| 149 |
formatted = tokenizer.apply_chat_template(
|
| 150 |
-
chat,
|
|
|
|
|
|
|
|
|
|
| 151 |
)
|
| 152 |
|
| 153 |
batched_prompts.append(formatted)
|
|
@@ -189,7 +199,7 @@ def main():
|
|
| 189 |
# Map generation results for this batch back to global indices
|
| 190 |
for idx_in_batch, output in enumerate(outputs):
|
| 191 |
original_idx = batch_indices[idx_in_batch]
|
| 192 |
-
text = output.outputs[0].text
|
| 193 |
generated_texts[original_idx] = text
|
| 194 |
|
| 195 |
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
|
|
|
| 7 |
|
| 8 |
import os
|
| 9 |
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
| 10 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = "5"
|
| 11 |
|
| 12 |
import argparse
|
| 13 |
import json
|
| 14 |
+
import re
|
| 15 |
from datetime import datetime
|
| 16 |
|
| 17 |
from vllm import LLM, SamplingParams
|
| 18 |
from transformers import AutoTokenizer
|
| 19 |
|
| 20 |
|
| 21 |
+
def strip_think_blocks(text: str) -> str:
|
| 22 |
+
"""Remove <think>...</think> reasoning blocks from model output."""
|
| 23 |
+
cleaned = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL).strip()
|
| 24 |
+
return cleaned if cleaned else text
|
| 25 |
+
|
| 26 |
+
|
| 27 |
# Paths
|
| 28 |
BASE_DIR = "/home/mshahidul/readctrl/code/fine_tune_sft_dpo"
|
| 29 |
MODEL_DIR = os.path.join(BASE_DIR, "model", "bn")
|
|
|
|
| 80 |
p.add_argument(
|
| 81 |
"--temperature",
|
| 82 |
type=float,
|
| 83 |
+
default=0.1,
|
| 84 |
help="Sampling temperature for generation.",
|
| 85 |
)
|
| 86 |
p.add_argument(
|
|
|
|
| 154 |
user_prompt = build_user_message(prompts[label], fulltext, summary)
|
| 155 |
chat = [{"role": "user", "content": user_prompt}]
|
| 156 |
formatted = tokenizer.apply_chat_template(
|
| 157 |
+
chat,
|
| 158 |
+
tokenize=False,
|
| 159 |
+
add_generation_prompt=True,
|
| 160 |
+
enable_thinking=False,
|
| 161 |
)
|
| 162 |
|
| 163 |
batched_prompts.append(formatted)
|
|
|
|
| 199 |
# Map generation results for this batch back to global indices
|
| 200 |
for idx_in_batch, output in enumerate(outputs):
|
| 201 |
original_idx = batch_indices[idx_in_batch]
|
| 202 |
+
text = strip_think_blocks(output.outputs[0].text)
|
| 203 |
generated_texts[original_idx] = text
|
| 204 |
|
| 205 |
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
code/fine_tune_sft_dpo/qwen3_best_of_n.py
ADDED
|
@@ -0,0 +1,238 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Run inference with the finetuned Bangla Qwen3 model on test_bn.json
|
| 3 |
+
and save the generation results under results/bn.
|
| 4 |
+
"""
|
| 5 |
+
import os
|
| 6 |
+
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
| 7 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = "5"
|
| 8 |
+
import argparse
|
| 9 |
+
import json
|
| 10 |
+
import os
|
| 11 |
+
from datetime import datetime
|
| 12 |
+
from typing import Any, Dict, List
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# Paths (keep in sync with qwen3-finetune_bn.py)
|
| 19 |
+
BASE_DIR = "/home/mshahidul/readctrl/code/fine_tune_sft_dpo"
|
| 20 |
+
MODEL_SAVE_DIR = "/home/mshahidul/readctrl/code/fine_tune_sft_dpo/model/bn"
|
| 21 |
+
PROMPT_DIR = os.path.join(BASE_DIR, "prompt_bn")
|
| 22 |
+
TEST_JSON = "/home/mshahidul/readctrl/code/fine_tune_sft_dpo/dataset/bn/test_bn.json"
|
| 23 |
+
RESULTS_DIR = os.path.join(BASE_DIR, "results", "bn")
|
| 24 |
+
SOURCE_LANG = "Bangla"
|
| 25 |
+
|
| 26 |
+
LABEL_TO_PROMPT_FILE = {
|
| 27 |
+
"low_health_literacy": "prompt_low",
|
| 28 |
+
"intermediate_health_literacy": "prompt_intermediate",
|
| 29 |
+
"proficient_health_literacy": "prompt_proficient",
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def load_prompts() -> Dict[str, str]:
|
| 34 |
+
"""Load prompt templates from prompt_bn directory."""
|
| 35 |
+
prompts: Dict[str, str] = {}
|
| 36 |
+
for label, filename in LABEL_TO_PROMPT_FILE.items():
|
| 37 |
+
path = os.path.join(PROMPT_DIR, filename)
|
| 38 |
+
if os.path.isfile(path):
|
| 39 |
+
with open(path, "r", encoding="utf-8") as f:
|
| 40 |
+
prompts[label] = f.read()
|
| 41 |
+
else:
|
| 42 |
+
raise FileNotFoundError(f"Prompt file not found: {path}")
|
| 43 |
+
return prompts
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def build_user_message(
|
| 47 |
+
prompt_template: str,
|
| 48 |
+
full_text: str,
|
| 49 |
+
gold_summary: str,
|
| 50 |
+
source_lang: str = SOURCE_LANG,
|
| 51 |
+
) -> str:
|
| 52 |
+
"""Fill prompt template with full_text, gold_summary, source_lang."""
|
| 53 |
+
return (
|
| 54 |
+
prompt_template.replace("{full_text}", full_text)
|
| 55 |
+
.replace("{gold_summary}", gold_summary)
|
| 56 |
+
.replace("{source_lang}", source_lang)
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def parse_args() -> argparse.Namespace:
|
| 61 |
+
p = argparse.ArgumentParser(
|
| 62 |
+
description="Run inference with finetuned Qwen3-4B Bangla model on test_bn.json."
|
| 63 |
+
)
|
| 64 |
+
p.add_argument(
|
| 65 |
+
"--max-new-tokens",
|
| 66 |
+
type=int,
|
| 67 |
+
default=512,
|
| 68 |
+
help="Maximum number of new tokens to generate per sample.",
|
| 69 |
+
)
|
| 70 |
+
p.add_argument(
|
| 71 |
+
"--temperature",
|
| 72 |
+
type=float,
|
| 73 |
+
default=0.7,
|
| 74 |
+
help="Sampling temperature.",
|
| 75 |
+
)
|
| 76 |
+
p.add_argument(
|
| 77 |
+
"--top-p",
|
| 78 |
+
type=float,
|
| 79 |
+
default=0.9,
|
| 80 |
+
help="Top-p (nucleus) sampling value.",
|
| 81 |
+
)
|
| 82 |
+
p.add_argument(
|
| 83 |
+
"--output-json",
|
| 84 |
+
type=str,
|
| 85 |
+
default="test_bn_qwen3-4B_sft_inference.json",
|
| 86 |
+
help=(
|
| 87 |
+
"Output JSON filename (saved under results/bn). "
|
| 88 |
+
"If it already exists, it will be overwritten."
|
| 89 |
+
),
|
| 90 |
+
)
|
| 91 |
+
return p.parse_args()
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def load_model_and_tokenizer(model_dir: str):
|
| 95 |
+
"""Load the merged finetuned model and tokenizer for inference."""
|
| 96 |
+
if not os.path.isdir(model_dir):
|
| 97 |
+
raise FileNotFoundError(
|
| 98 |
+
f"Finetuned model directory not found: {model_dir}. "
|
| 99 |
+
"Make sure qwen3-finetune_bn.py was run with model saving enabled."
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
print(f"Loading tokenizer from {model_dir}")
|
| 103 |
+
tokenizer = AutoTokenizer.from_pretrained(model_dir)
|
| 104 |
+
if tokenizer.pad_token is None:
|
| 105 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 106 |
+
|
| 107 |
+
print(f"Loading model from {model_dir}")
|
| 108 |
+
if torch.cuda.is_available():
|
| 109 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 110 |
+
model_dir,
|
| 111 |
+
torch_dtype=torch.bfloat16,
|
| 112 |
+
device_map="auto",
|
| 113 |
+
)
|
| 114 |
+
else:
|
| 115 |
+
model = AutoModelForCausalLM.from_pretrained(model_dir)
|
| 116 |
+
|
| 117 |
+
model.eval()
|
| 118 |
+
return model, tokenizer
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def run_inference(
|
| 122 |
+
model,
|
| 123 |
+
tokenizer,
|
| 124 |
+
test_items: List[Dict[str, Any]],
|
| 125 |
+
prompts: Dict[str, str],
|
| 126 |
+
max_new_tokens: int,
|
| 127 |
+
temperature: float,
|
| 128 |
+
top_p: float,
|
| 129 |
+
) -> List[Dict[str, Any]]:
|
| 130 |
+
"""Generate adapted texts for each test item."""
|
| 131 |
+
results: List[Dict[str, Any]] = []
|
| 132 |
+
|
| 133 |
+
device = next(model.parameters()).device
|
| 134 |
+
|
| 135 |
+
for idx, item in enumerate(test_items):
|
| 136 |
+
label = item.get("label")
|
| 137 |
+
fulltext = item.get("fulltext", "")
|
| 138 |
+
summary = item.get("summary", "")
|
| 139 |
+
|
| 140 |
+
if not fulltext or label not in prompts:
|
| 141 |
+
# Keep the original item, but note that generation was skipped.
|
| 142 |
+
out_item = dict(item)
|
| 143 |
+
out_item["model_gen_text"] = ""
|
| 144 |
+
out_item["model_gen_skipped"] = True
|
| 145 |
+
results.append(out_item)
|
| 146 |
+
continue
|
| 147 |
+
|
| 148 |
+
user_msg = build_user_message(prompts[label], fulltext, summary)
|
| 149 |
+
messages = [{"role": "user", "content": user_msg}]
|
| 150 |
+
|
| 151 |
+
text = tokenizer.apply_chat_template(
|
| 152 |
+
messages,
|
| 153 |
+
add_generation_prompt=True,
|
| 154 |
+
tokenize=False,
|
| 155 |
+
)
|
| 156 |
+
inputs = tokenizer(text, return_tensors="pt").to(device)
|
| 157 |
+
input_len = inputs["input_ids"].shape[-1]
|
| 158 |
+
|
| 159 |
+
with torch.no_grad():
|
| 160 |
+
generated_ids = model.generate(
|
| 161 |
+
**inputs,
|
| 162 |
+
max_new_tokens=max_new_tokens,
|
| 163 |
+
do_sample=True,
|
| 164 |
+
temperature=temperature,
|
| 165 |
+
top_p=top_p,
|
| 166 |
+
pad_token_id=tokenizer.eos_token_id,
|
| 167 |
+
)
|
| 168 |
+
|
| 169 |
+
gen_ids = generated_ids[0, input_len:]
|
| 170 |
+
gen_text = tokenizer.decode(gen_ids, skip_special_tokens=True).strip()
|
| 171 |
+
|
| 172 |
+
out_item = dict(item)
|
| 173 |
+
out_item["model_gen_text"] = gen_text
|
| 174 |
+
out_item["model_name"] = "qwen3-4B_sft_bn"
|
| 175 |
+
out_item["model_max_new_tokens"] = max_new_tokens
|
| 176 |
+
out_item["model_temperature"] = temperature
|
| 177 |
+
out_item["model_top_p"] = top_p
|
| 178 |
+
results.append(out_item)
|
| 179 |
+
|
| 180 |
+
if (idx + 1) % 10 == 0:
|
| 181 |
+
print(f"Processed {idx + 1} / {len(test_items)} samples")
|
| 182 |
+
|
| 183 |
+
return results
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def main():
|
| 187 |
+
args = parse_args()
|
| 188 |
+
|
| 189 |
+
os.makedirs(RESULTS_DIR, exist_ok=True)
|
| 190 |
+
|
| 191 |
+
print("Loading prompts from", PROMPT_DIR)
|
| 192 |
+
prompts = load_prompts()
|
| 193 |
+
|
| 194 |
+
print("Loading test data from", TEST_JSON)
|
| 195 |
+
with open(TEST_JSON, "r", encoding="utf-8") as f:
|
| 196 |
+
test_items = json.load(f)
|
| 197 |
+
|
| 198 |
+
print(f"Test samples: {len(test_items)}")
|
| 199 |
+
|
| 200 |
+
model, tokenizer = load_model_and_tokenizer(MODEL_SAVE_DIR)
|
| 201 |
+
|
| 202 |
+
print(
|
| 203 |
+
f"Running inference with max_new_tokens={args.max_new_tokens}, "
|
| 204 |
+
f"temperature={args.temperature}, top_p={args.top_p}"
|
| 205 |
+
)
|
| 206 |
+
results = run_inference(
|
| 207 |
+
model=model,
|
| 208 |
+
tokenizer=tokenizer,
|
| 209 |
+
test_items=test_items,
|
| 210 |
+
prompts=prompts,
|
| 211 |
+
max_new_tokens=args.max_new_tokens,
|
| 212 |
+
temperature=args.temperature,
|
| 213 |
+
top_p=args.top_p,
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 217 |
+
output_filename = args.output_json
|
| 218 |
+
if not output_filename.endswith(".json"):
|
| 219 |
+
output_filename += ".json"
|
| 220 |
+
|
| 221 |
+
output_path = os.path.join(RESULTS_DIR, output_filename)
|
| 222 |
+
|
| 223 |
+
# If the filename already exists, append a timestamp to avoid silent overwrite.
|
| 224 |
+
if os.path.exists(output_path):
|
| 225 |
+
name, ext = os.path.splitext(output_filename)
|
| 226 |
+
output_filename = f"{name}_{timestamp}{ext}"
|
| 227 |
+
output_path = os.path.join(RESULTS_DIR, output_filename)
|
| 228 |
+
|
| 229 |
+
print("Saving results to", output_path)
|
| 230 |
+
with open(output_path, "w", encoding="utf-8") as f:
|
| 231 |
+
json.dump(results, f, ensure_ascii=False, indent=2)
|
| 232 |
+
|
| 233 |
+
print("Done.")
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
if __name__ == "__main__":
|
| 237 |
+
main()
|
| 238 |
+
|
code/fine_tune_sft_dpo/qwen3_infer_bn.py
ADDED
|
@@ -0,0 +1,247 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Run inference with the finetuned Bangla Qwen3 model on test_bn.json
|
| 3 |
+
and save the generation results under results/bn.
|
| 4 |
+
"""
|
| 5 |
+
import os
|
| 6 |
+
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
| 7 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = "5"
|
| 8 |
+
import argparse
|
| 9 |
+
import json
|
| 10 |
+
import os
|
| 11 |
+
import re
|
| 12 |
+
from datetime import datetime
|
| 13 |
+
from typing import Any, Dict, List
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def strip_think_blocks(text: str) -> str:
|
| 20 |
+
"""Remove <think>...</think> reasoning blocks from model output."""
|
| 21 |
+
cleaned = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL).strip()
|
| 22 |
+
return cleaned if cleaned else text
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# Paths (keep in sync with qwen3-finetune_bn.py)
|
| 26 |
+
BASE_DIR = "/home/mshahidul/readctrl/code/fine_tune_sft_dpo"
|
| 27 |
+
MODEL_SAVE_DIR = "/home/mshahidul/readctrl/code/fine_tune_sft_dpo/model/bn"
|
| 28 |
+
PROMPT_DIR = os.path.join(BASE_DIR, "prompt_bn")
|
| 29 |
+
TEST_JSON = "/home/mshahidul/readctrl/code/fine_tune_sft_dpo/dataset/bn/test_bn.json"
|
| 30 |
+
RESULTS_DIR = os.path.join(BASE_DIR, "results", "bn")
|
| 31 |
+
SOURCE_LANG = "Bangla"
|
| 32 |
+
|
| 33 |
+
LABEL_TO_PROMPT_FILE = {
|
| 34 |
+
"low_health_literacy": "prompt_low",
|
| 35 |
+
"intermediate_health_literacy": "prompt_intermediate",
|
| 36 |
+
"proficient_health_literacy": "prompt_proficient",
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def load_prompts() -> Dict[str, str]:
|
| 41 |
+
"""Load prompt templates from prompt_bn directory."""
|
| 42 |
+
prompts: Dict[str, str] = {}
|
| 43 |
+
for label, filename in LABEL_TO_PROMPT_FILE.items():
|
| 44 |
+
path = os.path.join(PROMPT_DIR, filename)
|
| 45 |
+
if os.path.isfile(path):
|
| 46 |
+
with open(path, "r", encoding="utf-8") as f:
|
| 47 |
+
prompts[label] = f.read()
|
| 48 |
+
else:
|
| 49 |
+
raise FileNotFoundError(f"Prompt file not found: {path}")
|
| 50 |
+
return prompts
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def build_user_message(
|
| 54 |
+
prompt_template: str,
|
| 55 |
+
full_text: str,
|
| 56 |
+
gold_summary: str,
|
| 57 |
+
source_lang: str = SOURCE_LANG,
|
| 58 |
+
) -> str:
|
| 59 |
+
"""Fill prompt template with full_text, gold_summary, source_lang."""
|
| 60 |
+
return (
|
| 61 |
+
prompt_template.replace("{full_text}", full_text)
|
| 62 |
+
.replace("{gold_summary}", gold_summary)
|
| 63 |
+
.replace("{source_lang}", source_lang)
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def parse_args() -> argparse.Namespace:
|
| 68 |
+
p = argparse.ArgumentParser(
|
| 69 |
+
description="Run inference with finetuned Qwen3-4B Bangla model on test_bn.json."
|
| 70 |
+
)
|
| 71 |
+
p.add_argument(
|
| 72 |
+
"--max-new-tokens",
|
| 73 |
+
type=int,
|
| 74 |
+
default=512,
|
| 75 |
+
help="Maximum number of new tokens to generate per sample.",
|
| 76 |
+
)
|
| 77 |
+
p.add_argument(
|
| 78 |
+
"--temperature",
|
| 79 |
+
type=float,
|
| 80 |
+
default=0.7,
|
| 81 |
+
help="Sampling temperature.",
|
| 82 |
+
)
|
| 83 |
+
p.add_argument(
|
| 84 |
+
"--top-p",
|
| 85 |
+
type=float,
|
| 86 |
+
default=0.9,
|
| 87 |
+
help="Top-p (nucleus) sampling value.",
|
| 88 |
+
)
|
| 89 |
+
p.add_argument(
|
| 90 |
+
"--output-json",
|
| 91 |
+
type=str,
|
| 92 |
+
default="test_bn_qwen3-4B_sft_inference.json",
|
| 93 |
+
help=(
|
| 94 |
+
"Output JSON filename (saved under results/bn). "
|
| 95 |
+
"If it already exists, it will be overwritten."
|
| 96 |
+
),
|
| 97 |
+
)
|
| 98 |
+
return p.parse_args()
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def load_model_and_tokenizer(model_dir: str):
|
| 102 |
+
"""Load the merged finetuned model and tokenizer for inference."""
|
| 103 |
+
if not os.path.isdir(model_dir):
|
| 104 |
+
raise FileNotFoundError(
|
| 105 |
+
f"Finetuned model directory not found: {model_dir}. "
|
| 106 |
+
"Make sure qwen3-finetune_bn.py was run with model saving enabled."
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
print(f"Loading tokenizer from {model_dir}")
|
| 110 |
+
tokenizer = AutoTokenizer.from_pretrained(model_dir)
|
| 111 |
+
if tokenizer.pad_token is None:
|
| 112 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 113 |
+
|
| 114 |
+
print(f"Loading model from {model_dir}")
|
| 115 |
+
if torch.cuda.is_available():
|
| 116 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 117 |
+
model_dir,
|
| 118 |
+
torch_dtype=torch.bfloat16,
|
| 119 |
+
device_map="auto",
|
| 120 |
+
)
|
| 121 |
+
else:
|
| 122 |
+
model = AutoModelForCausalLM.from_pretrained(model_dir)
|
| 123 |
+
|
| 124 |
+
model.eval()
|
| 125 |
+
return model, tokenizer
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def run_inference(
|
| 129 |
+
model,
|
| 130 |
+
tokenizer,
|
| 131 |
+
test_items: List[Dict[str, Any]],
|
| 132 |
+
prompts: Dict[str, str],
|
| 133 |
+
max_new_tokens: int,
|
| 134 |
+
temperature: float,
|
| 135 |
+
top_p: float,
|
| 136 |
+
) -> List[Dict[str, Any]]:
|
| 137 |
+
"""Generate adapted texts for each test item."""
|
| 138 |
+
results: List[Dict[str, Any]] = []
|
| 139 |
+
|
| 140 |
+
device = next(model.parameters()).device
|
| 141 |
+
|
| 142 |
+
for idx, item in enumerate(test_items):
|
| 143 |
+
label = item.get("label")
|
| 144 |
+
fulltext = item.get("fulltext", "")
|
| 145 |
+
summary = item.get("summary", "")
|
| 146 |
+
|
| 147 |
+
if not fulltext or label not in prompts:
|
| 148 |
+
# Keep the original item, but note that generation was skipped.
|
| 149 |
+
out_item = dict(item)
|
| 150 |
+
out_item["model_gen_text"] = ""
|
| 151 |
+
out_item["model_gen_skipped"] = True
|
| 152 |
+
results.append(out_item)
|
| 153 |
+
continue
|
| 154 |
+
|
| 155 |
+
user_msg = build_user_message(prompts[label], fulltext, summary)
|
| 156 |
+
messages = [{"role": "user", "content": user_msg}]
|
| 157 |
+
|
| 158 |
+
text = tokenizer.apply_chat_template(
|
| 159 |
+
messages,
|
| 160 |
+
add_generation_prompt=True,
|
| 161 |
+
tokenize=False,
|
| 162 |
+
enable_thinking=False,
|
| 163 |
+
)
|
| 164 |
+
inputs = tokenizer(text, return_tensors="pt").to(device)
|
| 165 |
+
input_len = inputs["input_ids"].shape[-1]
|
| 166 |
+
|
| 167 |
+
with torch.no_grad():
|
| 168 |
+
generated_ids = model.generate(
|
| 169 |
+
**inputs,
|
| 170 |
+
max_new_tokens=max_new_tokens,
|
| 171 |
+
do_sample=True,
|
| 172 |
+
temperature=temperature,
|
| 173 |
+
top_p=top_p,
|
| 174 |
+
pad_token_id=tokenizer.eos_token_id,
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
gen_ids = generated_ids[0, input_len:]
|
| 178 |
+
gen_text = tokenizer.decode(gen_ids, skip_special_tokens=True).strip()
|
| 179 |
+
gen_text = strip_think_blocks(gen_text)
|
| 180 |
+
|
| 181 |
+
out_item = dict(item)
|
| 182 |
+
out_item["model_gen_text"] = gen_text
|
| 183 |
+
out_item["model_name"] = "qwen3-4B_sft_bn"
|
| 184 |
+
out_item["model_max_new_tokens"] = max_new_tokens
|
| 185 |
+
out_item["model_temperature"] = temperature
|
| 186 |
+
out_item["model_top_p"] = top_p
|
| 187 |
+
results.append(out_item)
|
| 188 |
+
|
| 189 |
+
if (idx + 1) % 10 == 0:
|
| 190 |
+
print(f"Processed {idx + 1} / {len(test_items)} samples")
|
| 191 |
+
|
| 192 |
+
return results
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def main():
|
| 196 |
+
args = parse_args()
|
| 197 |
+
|
| 198 |
+
os.makedirs(RESULTS_DIR, exist_ok=True)
|
| 199 |
+
|
| 200 |
+
print("Loading prompts from", PROMPT_DIR)
|
| 201 |
+
prompts = load_prompts()
|
| 202 |
+
|
| 203 |
+
print("Loading test data from", TEST_JSON)
|
| 204 |
+
with open(TEST_JSON, "r", encoding="utf-8") as f:
|
| 205 |
+
test_items = json.load(f)
|
| 206 |
+
|
| 207 |
+
print(f"Test samples: {len(test_items)}")
|
| 208 |
+
|
| 209 |
+
model, tokenizer = load_model_and_tokenizer(MODEL_SAVE_DIR)
|
| 210 |
+
|
| 211 |
+
print(
|
| 212 |
+
f"Running inference with max_new_tokens={args.max_new_tokens}, "
|
| 213 |
+
f"temperature={args.temperature}, top_p={args.top_p}"
|
| 214 |
+
)
|
| 215 |
+
results = run_inference(
|
| 216 |
+
model=model,
|
| 217 |
+
tokenizer=tokenizer,
|
| 218 |
+
test_items=test_items,
|
| 219 |
+
prompts=prompts,
|
| 220 |
+
max_new_tokens=args.max_new_tokens,
|
| 221 |
+
temperature=args.temperature,
|
| 222 |
+
top_p=args.top_p,
|
| 223 |
+
)
|
| 224 |
+
|
| 225 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 226 |
+
output_filename = args.output_json
|
| 227 |
+
if not output_filename.endswith(".json"):
|
| 228 |
+
output_filename += ".json"
|
| 229 |
+
|
| 230 |
+
output_path = os.path.join(RESULTS_DIR, output_filename)
|
| 231 |
+
|
| 232 |
+
# If the filename already exists, append a timestamp to avoid silent overwrite.
|
| 233 |
+
if os.path.exists(output_path):
|
| 234 |
+
name, ext = os.path.splitext(output_filename)
|
| 235 |
+
output_filename = f"{name}_{timestamp}{ext}"
|
| 236 |
+
output_path = os.path.join(RESULTS_DIR, output_filename)
|
| 237 |
+
|
| 238 |
+
print("Saving results to", output_path)
|
| 239 |
+
with open(output_path, "w", encoding="utf-8") as f:
|
| 240 |
+
json.dump(results, f, ensure_ascii=False, indent=2)
|
| 241 |
+
|
| 242 |
+
print("Done.")
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
if __name__ == "__main__":
|
| 246 |
+
main()
|
| 247 |
+
|
code/fine_tune_sft_dpo/results/bn/gpt5_inference_all_20260314_110445.jsonl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:72cc6e0f72af25f7b56ac4c623d23be4d57d35e09473e4a012eff27680942e85
|
| 3 |
+
size 16659336
|
code/fine_tune_sft_dpo/results/bn/gpt5_inference_all_wo_gs_20260314_173736.jsonl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:195bb9580d98d1a43f642f06f2c982397336c42cb0cf68bc641138997b699f9e
|
| 3 |
+
size 17385349
|
code/fine_tune_sft_dpo/results/bn/gpt5_inference_gpt-5-mini_20260314_110445.jsonl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:20bf9c12039c4bc33421452806e48ddc1ef7dca13403de06b95642cdf27e4334
|
| 3 |
+
size 5984175
|
code/fine_tune_sft_dpo/results/bn/gpt5_inference_gpt-5-mini_20260314_173736.jsonl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:49eef179000c958ac8986ee7a3fe0a79d44714b35e9a2a09ca68402eda9cd37f
|
| 3 |
+
size 6234645
|
code/fine_tune_sft_dpo/results/bn/gpt5_inference_gpt-5-nano_20260314_110445.jsonl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5a2a2acef3f83a916faee240a310e9595f6b6cb043c7e1081f014fce80c5d836
|
| 3 |
+
size 5237182
|
code/fine_tune_sft_dpo/results/bn/gpt5_inference_gpt-5-nano_20260314_173736.jsonl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b5439fb64537f7f159f61b97f46f4f400238fee1c65fd08f7b514ffc46f9160a
|
| 3 |
+
size 5355465
|
code/fine_tune_sft_dpo/results/bn/gpt5_inference_gpt-5_20260314_110445.jsonl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fbd39a27ea21e26c4d1ba7d4ef85c41964713b655ff03574c5d260ec7b94a5e4
|
| 3 |
+
size 5437979
|
code/fine_tune_sft_dpo/results/bn/gpt5_inference_gpt-5_20260314_173736.jsonl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3f16ff37f4c633d22a245b087dfb14583f14c3e27def8d896811e31a61af45e5
|
| 3 |
+
size 5795239
|
code/fine_tune_sft_dpo/results/bn/gpt5_inference_summary_20260314_110445.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3a0b61606edceb252580fdb18a15ac237d5d69bd896bea1d430cad5a6113a152
|
| 3 |
+
size 1684
|
code/fine_tune_sft_dpo/results/bn/gpt5_inference_summary_wo_gs_20260314_173736.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:eb177630eb8e44f478897e4e96af2ead55760791e59605ce091ff49c8bb7f403
|
| 3 |
+
size 1690
|
code/fine_tune_sft_dpo/results/bn/inference_summary_vllm_20260314_101627.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b5d5cf9a06af93f982c38c5901fbd8478e4bce6ed6bdb3b5480ad1ff586187c2
|
| 3 |
+
size 471
|
code/fine_tune_sft_dpo/results/bn/{inference_summary_vllm_20260311_044629.json → misc/inference_summary_vllm_20260311_044629.json}
RENAMED
|
File without changes
|