shahidul034 committed
Commit 93694bb · 1 Parent(s): 030876e

"Update readCtrl repo"

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the rest.
Files changed (50)
  1. code/RL_model/verl/verl_train/reward_func/reward_func/evaluate_scores.py +267 -0
  2. code/RL_model/verl/verl_train/reward_func/reward_func/reward_new_v6_bn_v4.py +4 -4
  3. code/RL_model/verl/verl_train/reward_func/reward_func/reward_new_v6_bn_v4_rmv_src_cov.py +835 -0
  4. code/RL_model/verl/verl_train/script/run_qwen3-8b_v3(2G)_v2.sh +1 -1
  5. code/RL_model/verl/verl_train/script/run_qwen3-8b_v3(2G)_v3.sh +56 -0
  6. code/RL_model/verl/verl_train/script/run_qwen3-8b_v3(4G)_v2.sh +1 -1
  7. code/fine_tune_sft_dpo/best_of_n_qwen3_vllm_bn.py +518 -0
  8. code/fine_tune_sft_dpo/dataset/bn/old/test_bn_subclaims.json +3 -0
  9. code/fine_tune_sft_dpo/dataset/bn/test_bn_subclaims.json +3 -0
  10. code/fine_tune_sft_dpo/eval.sh +0 -2
  11. code/fine_tune_sft_dpo/evaluate_scores_bn.py +554 -0
  12. code/fine_tune_sft_dpo/evaluate_scores_bn_vllm.py +611 -0
  13. code/fine_tune_sft_dpo/evaluate_scores_bn_vllm_rl.py +567 -0
  14. code/fine_tune_sft_dpo/evaluation/bn/bn_200_eval_results.json +3 -0
  15. code/fine_tune_sft_dpo/evaluation/bn/eval_gpt5_models/gpt_5_mini_eval_results_20260316_071029.json +3 -0
  16. code/fine_tune_sft_dpo/evaluation/bn/eval_gpt5_models/gpt_5_mini_prepared_20260316_071029.json +3 -0
  17. code/fine_tune_sft_dpo/evaluation/bn/eval_gpt5_models/gpt_5_nano_eval_results_20260316_071029.json +3 -0
  18. code/fine_tune_sft_dpo/evaluation/bn/eval_gpt5_models/gpt_5_nano_prepared_20260316_071029.json +3 -0
  19. code/fine_tune_sft_dpo/evaluation/bn/eval_gpt5_models/gpt_5_prepared_20260316_071029.json +3 -0
  20. code/fine_tune_sft_dpo/evaluation/bn/test_best_of_n_qwen3-4B_base_eval_results.json +3 -0
  21. code/fine_tune_sft_dpo/evaluation/bn/test_best_of_n_qwen3-4B_sft_eval_results.json +3 -0
  22. code/fine_tune_sft_dpo/evaluation/bn/test_inference_vllm_qwen3-4B_base_eval_results.json +3 -0
  23. code/fine_tune_sft_dpo/evaluation/bn/test_inference_vllm_qwen3-4B_sft_eval_results.json +3 -0
  24. code/fine_tune_sft_dpo/evaluation/bn/test_self_refine_vllm_qwen3_4B_base_eval_results.json +3 -0
  25. code/fine_tune_sft_dpo/evaluation/bn/test_self_refine_vllm_qwen3_4B_sft_eval_results.json +3 -0
  26. code/fine_tune_sft_dpo/evaluation/bn/v1/test_best_of_n_qwen3-4B_base_eval_results.json +3 -0
  27. code/fine_tune_sft_dpo/evaluation/bn/v1/test_best_of_n_qwen3-4B_sft_eval_results.json +3 -0
  28. code/fine_tune_sft_dpo/evaluation/bn/v1/test_inference_vllm_qwen3-4B_base_eval_results.json +3 -0
  29. code/fine_tune_sft_dpo/evaluation/bn/v1/test_inference_vllm_qwen3-4B_sft_eval_results.json +3 -0
  30. code/fine_tune_sft_dpo/evaluation/bn/v1/test_self_refine_vllm_qwen3_4B_base_eval_results.json +3 -0
  31. code/fine_tune_sft_dpo/evaluation/bn/v1/test_self_refine_vllm_qwen3_4B_sft_eval_results.json +3 -0
  32. code/fine_tune_sft_dpo/prompt_bn_wo_gs/prompt +58 -0
  33. code/fine_tune_sft_dpo/prompt_bn_wo_gs/prompt_intermediate +31 -0
  34. code/fine_tune_sft_dpo/prompt_bn_wo_gs/prompt_low +31 -0
  35. code/fine_tune_sft_dpo/prompt_bn_wo_gs/prompt_proficient +31 -0
  36. code/fine_tune_sft_dpo/qwen3-inference-vllm_bn.py +14 -4
  37. code/fine_tune_sft_dpo/qwen3_best_of_n.py +238 -0
  38. code/fine_tune_sft_dpo/qwen3_infer_bn.py +247 -0
  39. code/fine_tune_sft_dpo/results/bn/gpt5_inference_all_20260314_110445.jsonl +3 -0
  40. code/fine_tune_sft_dpo/results/bn/gpt5_inference_all_wo_gs_20260314_173736.jsonl +3 -0
  41. code/fine_tune_sft_dpo/results/bn/gpt5_inference_gpt-5-mini_20260314_110445.jsonl +3 -0
  42. code/fine_tune_sft_dpo/results/bn/gpt5_inference_gpt-5-mini_20260314_173736.jsonl +3 -0
  43. code/fine_tune_sft_dpo/results/bn/gpt5_inference_gpt-5-nano_20260314_110445.jsonl +3 -0
  44. code/fine_tune_sft_dpo/results/bn/gpt5_inference_gpt-5-nano_20260314_173736.jsonl +3 -0
  45. code/fine_tune_sft_dpo/results/bn/gpt5_inference_gpt-5_20260314_110445.jsonl +3 -0
  46. code/fine_tune_sft_dpo/results/bn/gpt5_inference_gpt-5_20260314_173736.jsonl +3 -0
  47. code/fine_tune_sft_dpo/results/bn/gpt5_inference_summary_20260314_110445.json +3 -0
  48. code/fine_tune_sft_dpo/results/bn/gpt5_inference_summary_wo_gs_20260314_173736.json +3 -0
  49. code/fine_tune_sft_dpo/results/bn/inference_summary_vllm_20260314_101627.json +3 -0
  50. code/fine_tune_sft_dpo/results/bn/{inference_summary_vllm_20260311_044629.json → misc/inference_summary_vllm_20260311_044629.json} +0 -0
code/RL_model/verl/verl_train/reward_func/reward_func/evaluate_scores.py ADDED
@@ -0,0 +1,267 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Standalone evaluation script for computing factuality, hallucination, and
4
+ classifier scores on a JSON file.
5
+
6
+ Expected input JSON: a list of objects, each with:
7
+ - fulltext : str (source document)
8
+ - fulltext_subclaims : list[str] (subclaims from fulltext, optional)
9
+ - summary_text : str (summary / reference text)
10
+ - summary_subclaims : list[str] (subclaims from summary)
11
+ - generated_text : str (model-generated text to evaluate)
12
+ - label : str (target literacy level, e.g. "low_health_literacy")
13
+
14
+ The script reuses the reward functions from reward_new_v6_bn_v4_rmv_src_cov.py:
15
+ - factuality_score : fraction of summary subclaims supported by generated_text
16
+ - hallucination_score: fraction of gen subclaims NOT supported by fulltext
17
+ - classifier_score : whether generated_text matches the target literacy level
18
+
19
+ Requires the same vLLM endpoints as the reward file:
20
+ - Support checker : VLLM_SUPPORT_CHECK_BN_API_BASE (default http://localhost:8090/v1)
21
+ - Classifier : VLLM_CLASSIFIER_BN_API_BASE (default http://localhost:8040/v1)
22
+ - Subclaim extractor: VLLM_SUBCLAIM_EXTRACTOR_BN_API_BASE (default http://localhost:8050/v1)
23
+
24
+ Usage:
25
+ python evaluate_scores.py --input data.json [--output results.json] [--target-level low_health_literacy]
26
+ """
27
+
28
+ import argparse
29
+ import json
30
+ import os
31
+ import sys
32
+ import time
33
+ from typing import Any, Dict, List, Optional
34
+
35
+ from tqdm import tqdm
36
+
37
+ # Import scoring utilities from the reward module (same directory).
38
+ from reward_new_v6_bn_v4_rmv_src_cov import (
39
+ _call_support_api,
40
+ _compute_classifier_reward,
41
+ _extract_subclaims_from_text,
42
+ _is_bangla_text,
43
+ _nonlinear_grounding,
44
+ compute_rewards,
45
+ )
46
+
47
+
48
+ def evaluate_single(
49
+ item: Dict[str, Any],
50
+ target_level_override: Optional[str] = None,
51
+ ) -> Dict[str, Any]:
52
+ """
53
+ Evaluate a single item and return detailed scores.
54
+ """
55
+ fulltext = item.get("fulltext", "")
56
+ summary_text = item.get("summary_text") or item.get("summary", "")
57
+ summary_subclaims = item.get("summary_subclaims", [])
58
+ generated_text = item.get("generated_text") or item.get("predicted_gen_text", "")
59
+ target_level = target_level_override or item.get("label", "")
60
+
61
+ result: Dict[str, Any] = {
62
+ "doc_id": item.get("doc_id", ""),
63
+ "target_level": target_level,
64
+ "generated_text_len": len(generated_text.strip()) if generated_text else 0,
65
+ "factuality_score": None,
66
+ "hallucination_score": None,
67
+ "classifier_score": None,
68
+ "grounding_score": None,
69
+ "factuality_supported": 0,
70
+ "total_summary_subclaims": len(summary_subclaims) if summary_subclaims else 0,
71
+ "hallucination_supported": 0,
72
+ "total_gen_segments": 0,
73
+ "skipped": False,
74
+ "skip_reason": "",
75
+ }
76
+
77
+ if not generated_text or len(generated_text.strip()) < 10:
78
+ result["skipped"] = True
79
+ result["skip_reason"] = "generated_text missing or too short (<10 chars)"
80
+ return result
81
+
82
+ # -- Factuality & Hallucination via compute_rewards --
83
+ rewards = compute_rewards(
84
+ fulltext=fulltext,
85
+ generated_text=generated_text,
86
+ target_level=target_level,
87
+ summary_subclaims=summary_subclaims,
88
+ summary_text=summary_text,
89
+ )
90
+
91
+ factuality_score = rewards["factuality_score"]
92
+ h_score = rewards["hallucination_score"]
93
+
94
+ if factuality_score is None:
95
+ factuality_score = 0.5
96
+ if h_score is None:
97
+ h_score = 0.5
98
+
99
+ grounding_score = _nonlinear_grounding(h_score)
100
+
101
+ # -- Classifier --
102
+ input_text = fulltext or ""
103
+ class_score = _compute_classifier_reward(target_level, generated_text, input_text)
104
+
105
+ result.update({
106
+ "factuality_score": round(factuality_score, 4),
107
+ "hallucination_score": round(h_score, 4),
108
+ "grounding_score": round(grounding_score, 4),
109
+ "classifier_score": round(class_score, 4),
110
+ "factuality_supported": rewards.get("factuality_supported", 0),
111
+ "total_summary_subclaims": rewards.get("total_summary_subclaims", 0),
112
+ "hallucination_supported": rewards.get("hallucination_supported", 0),
113
+ "total_gen_segments": rewards.get("total_gen_segments", 0),
114
+ })
115
+
116
+ return result
117
+
118
+
119
+ def compute_aggregate(results: List[Dict[str, Any]]) -> Dict[str, Any]:
120
+ """Compute aggregate statistics over all evaluated items."""
121
+ scored = [r for r in results if not r.get("skipped", False)]
122
+ n = len(scored)
123
+ total = len(results)
124
+ skipped = total - n
125
+
126
+ if n == 0:
127
+ return {
128
+ "total_items": total,
129
+ "scored_items": 0,
130
+ "skipped_items": skipped,
131
+ "avg_factuality_score": None,
132
+ "avg_hallucination_score": None,
133
+ "avg_grounding_score": None,
134
+ "avg_classifier_score": None,
135
+ }
136
+
137
+ def safe_avg(key):
138
+ vals = [r[key] for r in scored if r[key] is not None]
139
+ return round(sum(vals) / len(vals), 4) if vals else None
140
+
141
+ return {
142
+ "total_items": total,
143
+ "scored_items": n,
144
+ "skipped_items": skipped,
145
+ "avg_factuality_score": safe_avg("factuality_score"),
146
+ "avg_hallucination_score": safe_avg("hallucination_score"),
147
+ "avg_grounding_score": safe_avg("grounding_score"),
148
+ "avg_classifier_score": safe_avg("classifier_score"),
149
+ }
150
+
151
+
152
+ def main():
153
+ parser = argparse.ArgumentParser(
154
+ description="Evaluate factuality, hallucination, and classifier scores on a JSON file."
155
+ )
156
+ parser.add_argument(
157
+ "--input", "-i", required=True,
158
+ help="Path to input JSON file (list of objects).",
159
+ )
160
+ parser.add_argument(
161
+ "--output", "-o", default=None,
162
+ help="Path to output JSON file with per-item scores. "
163
+ "Defaults to <input_stem>_eval_results.json.",
164
+ )
165
+ parser.add_argument(
166
+ "--target-level", "-t", default=None,
167
+ help="Override target literacy level for all items "
168
+ "(e.g. low_health_literacy). If not set, uses each item's 'label' field.",
169
+ )
170
+ parser.add_argument(
171
+ "--support-check-url", default=None,
172
+ help="Override VLLM_SUPPORT_CHECK_BN_API_BASE.",
173
+ )
174
+ parser.add_argument(
175
+ "--classifier-url", default=None,
176
+ help="Override VLLM_CLASSIFIER_BN_API_BASE.",
177
+ )
178
+ parser.add_argument(
179
+ "--subclaim-extractor-url", default=None,
180
+ help="Override VLLM_SUBCLAIM_EXTRACTOR_BN_API_BASE.",
181
+ )
182
+ args = parser.parse_args()
183
+
184
+ if args.support_check_url:
185
+ os.environ["VLLM_SUPPORT_CHECK_BN_API_BASE"] = args.support_check_url
186
+ if args.classifier_url:
187
+ os.environ["VLLM_CLASSIFIER_BN_API_BASE"] = args.classifier_url
188
+ if args.subclaim_extractor_url:
189
+ os.environ["VLLM_SUBCLAIM_EXTRACTOR_BN_API_BASE"] = args.subclaim_extractor_url
190
+
191
+ # Load input
192
+ with open(args.input, "r", encoding="utf-8") as f:
193
+ data = json.load(f)
194
+
195
+ if not isinstance(data, list):
196
+ print(f"Error: Expected a JSON list, got {type(data).__name__}.", file=sys.stderr)
197
+ sys.exit(1)
198
+
199
+ print(f"Loaded {len(data)} items from {args.input}")
200
+ print(f" Support check API : {os.getenv('VLLM_SUPPORT_CHECK_BN_API_BASE', 'http://localhost:8090/v1')}")
201
+ print(f" Classifier API : {os.getenv('VLLM_CLASSIFIER_BN_API_BASE', 'http://localhost:8040/v1')}")
202
+ print(f" Subclaim extractor: {os.getenv('VLLM_SUBCLAIM_EXTRACTOR_BN_API_BASE', 'http://localhost:8050/v1')}")
203
+ if args.target_level:
204
+ print(f" Target level override: {args.target_level}")
205
+ print("-" * 60)
206
+
207
+ # Evaluate each item
208
+ results = []
209
+ start_time = time.time()
210
+ for idx, item in enumerate(tqdm(data, desc="Evaluating")):
211
+ r = evaluate_single(item, target_level_override=args.target_level)
212
+ r["index"] = idx
213
+ results.append(r)
214
+
215
+ if (idx + 1) % 10 == 0 or idx == 0:
216
+ partial_agg = compute_aggregate(results)
217
+ tqdm.write(
218
+ f" [{idx+1}/{len(data)}] "
219
+ f"fact={partial_agg['avg_factuality_score']} "
220
+ f"hallu={partial_agg['avg_hallucination_score']} "
221
+ f"cls={partial_agg['avg_classifier_score']}"
222
+ )
223
+
224
+ elapsed = time.time() - start_time
225
+
226
+ # Aggregate
227
+ agg = compute_aggregate(results)
228
+
229
+ # Output path
230
+ if args.output:
231
+ out_path = args.output
232
+ else:
233
+ stem = os.path.splitext(os.path.basename(args.input))[0]
234
+ out_dir = os.path.dirname(args.input) or "."
235
+ out_path = os.path.join(out_dir, f"{stem}_eval_results.json")
236
+
237
+ output = {
238
+ "input_file": os.path.abspath(args.input),
239
+ "target_level_override": args.target_level,
240
+ "elapsed_seconds": round(elapsed, 2),
241
+ "aggregate": agg,
242
+ "per_item": results,
243
+ }
244
+
245
+ with open(out_path, "w", encoding="utf-8") as f:
246
+ json.dump(output, f, indent=2, ensure_ascii=False)
247
+
248
+ # Print summary
249
+ print("\n" + "=" * 60)
250
+ print("EVALUATION SUMMARY")
251
+ print("=" * 60)
252
+ print(f" Total items : {agg['total_items']}")
253
+ print(f" Scored items : {agg['scored_items']}")
254
+ print(f" Skipped items : {agg['skipped_items']}")
255
+ print(f" Elapsed time : {round(elapsed, 1)}s")
256
+ print("-" * 60)
257
+ print(f" Avg Factuality Score : {agg['avg_factuality_score']}")
258
+ print(f" Avg Hallucination Score: {agg['avg_hallucination_score']}")
259
+ print(f" Avg Grounding Score : {agg['avg_grounding_score']}")
260
+ print(f" Avg Classifier Score : {agg['avg_classifier_score']}")
261
+ print("-" * 60)
262
+ print(f" Results saved to: {out_path}")
263
+ print("=" * 60)
264
+
265
+
266
+ if __name__ == "__main__":
267
+ main()
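For clarity, here is a minimal sketch of the input that evaluate_scores.py expects and how it would be invoked, assuming the vLLM endpoints listed in the module docstring are already running. The field values and file names below are illustrative placeholders, not taken from the repo.

import json

# Hypothetical single-record input; keys follow the fields described in the
# module docstring above (values are placeholders, not real data).
items = [
    {
        "doc_id": "example-001",
        "fulltext": "<Bangla source document>",
        "summary_text": "<reference summary>",
        "summary_subclaims": ["<subclaim 1>", "<subclaim 2>"],
        "generated_text": "<model output to evaluate>",
        "label": "low_health_literacy",
    }
]

with open("data.json", "w", encoding="utf-8") as f:
    json.dump(items, f, ensure_ascii=False, indent=2)

# Then, with the support-check, classifier, and subclaim-extractor servers up:
#   python evaluate_scores.py --input data.json --output results.json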
code/RL_model/verl/verl_train/reward_func/reward_func/reward_new_v6_bn_v4.py CHANGED
@@ -844,10 +844,10 @@ def compute_score(data_source, solution_str, ground_truth, extra_info=None):
844
  the generated text; must stay within a level-specific
845
  [min, max] range — too little OR too much is penalised.
846
  """
847
- W_FACTUALITY = 0.20
848
- W_HALLU = 0.20
849
- W_SRC_COV = 0.20
850
- W_CLASSIFIER = 0.25
851
  W_LENGTH = 0.15
852
 
853
  FAIL = {
 
844
  the generated text; must stay within a level-specific
845
  [min, max] range — too little OR too much is penalised.
846
  """
847
+ W_FACTUALITY = 0.25
848
+ W_HALLU = 0.15
849
+ W_SRC_COV = 0.25
850
+ W_CLASSIFIER = 0.20
851
  W_LENGTH = 0.15
852
 
853
  FAIL = {
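The hunk above only rebalances the component weights in compute_score: factuality and source coverage go up (0.20 to 0.25), hallucination and classifier go down (0.20 to 0.15 and 0.25 to 0.20), and W_LENGTH stays at 0.15. A rough sketch of how that shifts the weighted sum for fixed component scores (the example scores below are made up, not from the repo):

# Weights copied from the old and new versions of compute_score; scores are illustrative.
old_w = {"fact": 0.20, "hallu": 0.20, "src_cov": 0.20, "cls": 0.25, "len": 0.15}
new_w = {"fact": 0.25, "hallu": 0.15, "src_cov": 0.25, "cls": 0.20, "len": 0.15}
scores = {"fact": 0.8, "hallu": 0.9, "src_cov": 0.6, "cls": 1.0, "len": 0.7}

for name, w in (("old", old_w), ("new", new_w)):
    print(name, round(sum(w[k] * scores[k] for k in scores), 3))
# old 0.775, new 0.79: the same rollout scores slightly higher once
# factuality and source coverage carry more weight.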
code/RL_model/verl/verl_train/reward_func/reward_func/reward_new_v6_bn_v4_rmv_src_cov.py ADDED
@@ -0,0 +1,835 @@
1
+ import ast
2
+ import os
3
+ import re
4
+ import json
5
+ import argparse
6
+ from typing import Any, List, Dict, Optional
7
+ import warnings
8
+ import requests
9
+ warnings.filterwarnings("ignore")
10
+
11
+ # vLLM endpoints (see script/vllm.sh: support-check port 8090, classifier port 8040)
12
+ # Use localhost when reward runs on same machine; set env vars for remote (e.g. http://HOST:8090/v1).
13
+ VLLM_SUPPORT_CHECK_BN_API_BASE = os.getenv(
14
+ "VLLM_SUPPORT_CHECK_BN_API_BASE",
15
+ "http://localhost:8090/v1",
16
+ )
17
+ # Both support-check and classifier use Bangla prompts.
18
+ SUPPORT_CHECK_PROMPT_LANGUAGE = "bn"
19
+
20
+ VLLM_CLASSIFIER_BN_API_BASE = os.getenv(
21
+ "VLLM_CLASSIFIER_BN_API_BASE",
22
+ "http://localhost:8040/v1",
23
+ )
24
+
25
+ # Subclaim-extractor vLLM endpoint (Bangla medical text → subclaim list)
26
+ VLLM_SUBCLAIM_EXTRACTOR_BN_API_BASE = os.getenv(
27
+ "VLLM_SUBCLAIM_EXTRACTOR_BN_API_BASE",
28
+ "http://localhost:8050/v1",
29
+ )
30
+
31
+
32
+ # ---------------------------------------------------------------------------
33
+ # Support check via vLLM (Gemma-3 fine-tuned; same prompt as support_check gemma3-finetune)
34
+ # ---------------------------------------------------------------------------
35
+
36
+ def _build_support_list_user_prompt(context: str, subclaims: List[str]) -> str:
37
+ """Build list-of-subclaims support prompt in Bangla (matches support_check model_finetune/gemma3-finetune.py)."""
38
+ numbered = "\n".join(f"{idx + 1}. {sc}" for idx, sc in enumerate(subclaims))
39
+ return (
40
+ "আপনাকে একটি মেডিকেল কেস বর্ণনা এবং একাধিক সাবক্লেইমের তালিকা দেওয়া হবে। "
41
+ "প্রতিটি সাবক্লেইম টেক্সট দ্বারা সমর্থিত কি না তা নির্ধারণ করুন।\n\n"
42
+ f"টেক্সট:\n{context}\n\n"
43
+ f"সাবক্লেইমগুলোর তালিকা:\n{numbered}\n\n"
44
+ "প্রতিটি সাবক্লেইমের জন্য ক্রমানুসারে লেবেল দিন। "
45
+ "নির্দিষ্টভাবে একটি JSON array আকারে উত্তর দিন, যেমন:\n"
46
+ '["supported", "not_supported", ...]\n'
47
+ "অন্য কিছু লিখবেন না।"
48
+ )
49
+
50
+
51
+ def _parse_support_label_array(raw_text: str) -> List[str]:
52
+ """Parse model output to list of 'supported' | 'not_supported' (matches support_check gemma3-finetune)."""
53
+ text = (raw_text or "").strip()
54
+ if not text:
55
+ return []
56
+ if "```" in text:
57
+ text = text.replace("```json", "").replace("```", "").strip()
58
+ start = text.find("[")
59
+ end = text.rfind("]")
60
+ if start != -1 and end != -1 and end > start:
61
+ text = text[start : end + 1]
62
+ parsed = None
63
+ for parser in (json.loads, ast.literal_eval):
64
+ try:
65
+ parsed = parser(text)
66
+ break
67
+ except Exception:
68
+ continue
69
+ if not isinstance(parsed, list):
70
+ return []
71
+ # import ipdb; ipdb.set_trace()
72
+ normalized = []
73
+ for item in parsed:
74
+ if not isinstance(item, str):
75
+ normalized.append("invalid")
76
+ continue
77
+
78
+ label = item.strip().lower().replace("-", "_").replace(" ", "_")
79
+ # Strict keyword check: if not one of these two, it is invalid
80
+ if label in {"supported", "not_supported"}:
81
+ normalized.append(label)
82
+ else:
83
+ normalized.append("invalid")
84
+ return normalized
85
+
86
+
87
+ def _format_gemma3_for_support(user_message: str) -> str:
88
+ """Format user message for Gemma-3 chat (vLLM)."""
89
+ return (
90
+ f"<start_of_turn>user\n{user_message}<end_of_turn>\n"
91
+ "<start_of_turn>model\n"
92
+ )
93
+
94
+
95
+ def _call_vllm_support_check(prompt: str, max_tokens: int = 512, timeout: float = 120.0) -> Optional[str]:
96
+ """Call vLLM completions API for support-check model. Returns generated text or None."""
97
+ base = VLLM_SUPPORT_CHECK_BN_API_BASE.rstrip("/")
98
+ url = f"{base}/completions"
99
+ payload = {
100
+ "prompt": prompt,
101
+ "max_tokens": max_tokens,
102
+ "temperature": 0.0,
103
+ "stop": ["<end_of_turn>", "<start_of_turn>", "\n\n"],
104
+ }
105
+ try:
106
+ resp = requests.post(url, json=payload, timeout=timeout)
107
+ resp.raise_for_status()
108
+ data = resp.json()
109
+ choices = data.get("choices")
110
+ # import ipdb; ipdb.set_trace()
111
+ if choices and len(choices) > 0:
112
+ text = choices[0].get("text", "")
113
+ return (text or "").strip()
114
+ return None
115
+ except requests.exceptions.RequestException:
116
+ return None
117
+
118
+
119
+ def _call_support_api(
120
+ context: str,
121
+ subclaims: List[str],
122
+ threshold: float = 0.5,
123
+ batch_size: int = 128,
124
+ ) -> Optional[List[str]]:
125
+ """
126
+ Call the BN support-check model via vLLM (Gemma-3 fine-tuned, subclaim_list mode).
127
+ Same prompt as support_check/model_finetune/gemma3-finetune.py.
128
+
129
+ Returns
130
+ -------
131
+ List[str] : one label per subclaim — "supported" | "not_supported" | "invalid".
132
+ None : on total failure (network or empty/unparseable response).
133
+ """
134
+ if not context or not subclaims:
135
+ return ["invalid"] * len(subclaims)
136
+
137
+ user_prompt = _build_support_list_user_prompt(context, subclaims)
138
+ prompt = _format_gemma3_for_support(user_prompt)
139
+ raw = _call_vllm_support_check(prompt)
140
+ # import ipdb; ipdb.set_trace()
141
+ if raw is None:
142
+ print("Warning: Support check vLLM call failed (returning None).")
143
+ return None
144
+
145
+ labels = _parse_support_label_array(raw)
146
+ n = len(subclaims)
147
+ if len(labels) < n:
148
+ labels = labels + ["invalid"] * (n - len(labels))
149
+ elif len(labels) > n:
150
+ labels = labels[:n]
151
+ # import ipdb; ipdb.set_trace()
152
+ return labels
153
+
154
+
155
+ # ---------------------------------------------------------------------------
156
+ # Subclaim extractor (Bangla, vLLM) + sentence splitter fallback
157
+ # ---------------------------------------------------------------------------
158
+
159
+ # Minimum character length for a sentence to be considered a real unit.
160
+ # Used only as a fallback when subclaim extraction is unavailable.
161
+ MIN_SENTENCE_CHARS = 15
162
+
163
+
164
+ def _is_bangla_text(text: str, min_bangla_ratio: float = 0.6) -> bool:
165
+ """
166
+ Heuristic check: returns True if the majority of alphabetic characters
167
+ in `text` are Bangla (Unicode block \u0980–\u09FF).
168
+ """
169
+ if not text:
170
+ return False
171
+ bangla_chars = 0
172
+ alpha_chars = 0
173
+ for ch in text:
174
+ if ch.isalpha():
175
+ alpha_chars += 1
176
+ if "\u0980" <= ch <= "\u09FF":
177
+ bangla_chars += 1
178
+ if alpha_chars == 0:
179
+ return False
180
+ return (bangla_chars / alpha_chars) >= min_bangla_ratio
181
+
182
+
183
+ def _split_into_sentences(text: str, min_chars: int = MIN_SENTENCE_CHARS) -> List[str]:
184
+ """
185
+ Split text into sentences at [.!?] boundaries.
186
+ Segments shorter than `min_chars` characters are dropped to
187
+ prevent micro-fragment padding from gaming ratio-based scores.
188
+ """
189
+ if not text or not text.strip():
190
+ return []
191
+ parts = re.split(r"(?<=[.!?\u0964])\s+", text.strip())
192
+ return [s.strip() for s in parts if len(s.strip()) >= min_chars]
193
+
194
+
195
+ def _build_subclaim_extraction_prompt(medical_text: str) -> str:
196
+ """
197
+ Bangla subclaim-extraction prompt (same wording as `extract_bn_subclaims_vllm.py`,
198
+ generalized to "medical text" so it works for any generated explanation).
199
+ """
200
+ return f"""
201
+ You are an expert medical annotator. The following text is in Bangla (Bengali).
202
+
203
+ Your task is to extract granular, factual subclaims from the provided medical text.
204
+ A subclaim is the smallest standalone factual unit that can be independently verified.
205
+
206
+ Instructions:
207
+ 1. Read the Bangla medical text carefully.
208
+ 2. Extract factual statements explicitly stated in the text.
209
+ 3. Each subclaim must:
210
+ - Be in Bangla (same language as the input)
211
+ - Contain exactly ONE factual assertion
212
+ - Come directly from the text (no inference or interpretation)
213
+ - Preserve original wording as much as possible
214
+ - Include any negation, uncertainty, or qualifier
215
+ 4. Do NOT:
216
+ - Combine multiple facts into one subclaim
217
+ - Add new information
218
+ - Translate to another language
219
+ 5. Return ONLY a valid JSON array of strings.
220
+ 6. Use double quotes and valid JSON formatting only (no markdown, no commentary).
221
+
222
+ Medical Text (Bangla):
223
+ {medical_text}
224
+
225
+ Return format:
226
+ [
227
+ "subclaim 1",
228
+ "subclaim 2"
229
+ ]
230
+ """.strip()
231
+
232
+
233
+ def _strip_markdown_json_block(text: str) -> str:
234
+ """Strip optional markdown code fence (e.g. ```json\\n[...]\\n```), if present."""
235
+ text = (text or "").strip()
236
+ if not text:
237
+ return ""
238
+ if text.startswith("```json"):
239
+ text = text[7:].lstrip("\n")
240
+ elif text.startswith("```"):
241
+ text = text[3:].lstrip("\n")
242
+ if text.endswith("```"):
243
+ text = text[:-3].rstrip("\n")
244
+ return text.strip()
245
+
246
+
247
+ def _parse_subclaim_list_output(output_text: str) -> List[str]:
248
+ """Parse subclaim-extractor model output into a list of Bangla subclaims."""
249
+ output_text = (output_text or "").strip()
250
+ if not output_text:
251
+ return []
252
+
253
+ if "</think>" in output_text:
254
+ output_text = output_text.split("</think>")[-1].strip()
255
+
256
+ output_text = _strip_markdown_json_block(output_text)
257
+
258
+ start_idx = output_text.find("[")
259
+ end_idx = output_text.rfind("]") + 1
260
+ if start_idx != -1 and end_idx > start_idx:
261
+ content = output_text[start_idx:end_idx]
262
+ parsed = json.loads(content)
263
+ if isinstance(parsed, list):
264
+ return [str(s).strip() for s in parsed if str(s).strip()]
265
+
266
+ raise ValueError("Incomplete or invalid JSON list")
267
+
268
+
269
+ def _call_vllm_subclaim_extractor(
270
+ text: str,
271
+ max_tokens: int = 2048,
272
+ temperature: float = 0.2,
273
+ timeout: float = 120.0,
274
+ ) -> Optional[List[str]]:
275
+ """
276
+ Call Bangla subclaim-extractor model via vLLM (OpenAI /chat/completions).
277
+
278
+ Returns a list of subclaims on success, or None on total failure.
279
+ """
280
+ if not text or not text.strip():
281
+ return []
282
+
283
+ base = VLLM_SUBCLAIM_EXTRACTOR_BN_API_BASE.rstrip("/")
284
+ url = f"{base}/chat/completions"
285
+
286
+ prompt = _build_subclaim_extraction_prompt(text)
287
+ payload = {
288
+ "model": os.getenv("VLLM_SUBCLAIM_EXTRACTOR_MODEL_NAME", "subclaim-extractor"),
289
+ "messages": [{"role": "user", "content": prompt}],
290
+ "temperature": temperature,
291
+ "max_tokens": max_tokens,
292
+ }
293
+
294
+ try:
295
+ resp = requests.post(url, json=payload, timeout=timeout)
296
+ resp.raise_for_status()
297
+ data = resp.json()
298
+ choices = data.get("choices") or []
299
+ if not choices:
300
+ return None
301
+ content = (choices[0].get("message", {}) or {}).get("content", "") or ""
302
+ # import ipdb; ipdb.set_trace()
303
+ return _parse_subclaim_list_output(content)
304
+ except Exception:
305
+ return None
306
+
307
+
308
+ def _extract_subclaims_from_text(text: str) -> List[str]:
309
+ """
310
+ Extract Bangla subclaims from generated text using the vLLM subclaim-extractor.
311
+
312
+ On failure (e.g., server down or parse error), falls back to sentence splitting
313
+ so the rest of the reward logic can still operate.
314
+ """
315
+ subclaims = _call_vllm_subclaim_extractor(text)
316
+ if subclaims is None:
317
+ # Fallback: keep system running even if extractor is unavailable.
318
+ return _split_into_sentences(text)
319
+ return subclaims
320
+
321
+
322
+ # ---------------------------------------------------------------------------
323
+ # Two reward signals:
324
+ # 1. Factuality — summary subclaims vs gen_text (how much summary info is in gen_text)
325
+ # 2. Hallucination — gen_segments vs fulltext (how much gen info is NOT in fulltext)
326
+ # ---------------------------------------------------------------------------
327
+
328
+ def compute_rewards(
329
+ fulltext: str,
330
+ generated_text: str,
331
+ target_level: str,
332
+ summary_subclaims: Optional[List[str]] = None,
333
+ summary_text: Optional[str] = None,
334
+ threshold: float = 0.5,
335
+ batch_size: int = 128,
336
+ ) -> Dict[str, Optional[float]]:
337
+ """
338
+ Compute two independent reward signals.
339
+
340
+ 1. **Factuality** (summary_subclaims → gen_text):
341
+ Use pre-extracted *summary_subclaims*, check how many are supported
342
+ by the generated text. Measures "how much of the summary's information
343
+ made it into the output".
344
+
345
+ 2. **Hallucination** (gen_segments → fulltext):
346
+ Extract subclaims from the *generated text* (gen_segments), then check
347
+ how many are supported by the source fulltext. The *unsupported*
348
+ fraction is the hallucination score (lower is better).
349
+
350
+ Returns dict with:
351
+ factuality_score : [0,1] fraction of summary subclaims supported by gen_text
352
+ factuality_supported : int count
353
+ total_summary_subclaims : int
354
+ hallucination_score : [0,1] fraction of gen_segments NOT supported by fulltext
355
+ hallucination_supported : int count of gen_segments supported by fulltext
356
+ total_gen_segments : int
357
+ """
358
+ result: Dict[str, Any] = {
359
+ "factuality_score": None,
360
+ "factuality_supported": 0,
361
+ "total_summary_subclaims": 0,
362
+ "hallucination_score": None,
363
+ "hallucination_supported": 0,
364
+ "total_gen_segments": 0,
365
+ }
366
+
367
+ gen_segments = _extract_subclaims_from_text(generated_text)
368
+
369
+ if not gen_segments:
370
+ result.update({
371
+ "hallucination_score": 0.0,
372
+ "factuality_score": 0.0,
373
+ })
374
+ return result
375
+
376
+ total_gen = len(gen_segments)
377
+ result["total_gen_segments"] = total_gen
378
+
379
+ # =====================================================================
380
+ # 1. FACTUALITY — summary subclaims checked against gen_text
381
+ # "How much information from the summary exists in the generated text?"
382
+ # =====================================================================
383
+ factuality_score = None
384
+ if summary_subclaims and len(summary_subclaims) > 0:
385
+ result["total_summary_subclaims"] = len(summary_subclaims)
386
+
387
+ labels_summary_vs_gen = _call_support_api(
388
+ context=generated_text,
389
+ subclaims=summary_subclaims,
390
+ threshold=threshold,
391
+ batch_size=batch_size,
392
+ )
393
+ if labels_summary_vs_gen is not None:
394
+ valid = [l for l in labels_summary_vs_gen if str(l).strip().lower() != "invalid"]
395
+ if valid:
396
+ sup = sum(1 for l in valid if str(l).strip().lower() == "supported")
397
+ factuality_score = sup / len(summary_subclaims)
398
+ result["factuality_supported"] = sup
399
+ else:
400
+ factuality_score = 0.0
401
+
402
+ result["factuality_score"] = factuality_score
403
+
404
+ # =====================================================================
405
+ # 2. HALLUCINATION — gen_segments checked against fulltext
406
+ # "How much info in gen_segments is NOT supported by the fulltext?"
407
+ # =====================================================================
408
+ hallucination_score = None
409
+ if fulltext and fulltext.strip():
410
+ labels_gen_vs_full = _call_support_api(
411
+ context=fulltext,
412
+ subclaims=gen_segments,
413
+ threshold=threshold,
414
+ batch_size=batch_size,
415
+ )
416
+ if labels_gen_vs_full is not None and len(labels_gen_vs_full) > 0:
417
+ sup_full = sum(
418
+ 1 for l in labels_gen_vs_full
419
+ if str(l).strip().lower() == "supported"
420
+ )
421
+
422
+ unsupported_indices = [
423
+ i for i, l in enumerate(labels_gen_vs_full)
424
+ if str(l).strip().lower() != "supported"
425
+ ]
426
+
427
+ if unsupported_indices and summary_text and summary_text.strip():
428
+ unsup_segments = [gen_segments[i] for i in unsupported_indices]
429
+ rescue_labels = _call_support_api(
430
+ context=summary_text,
431
+ subclaims=unsup_segments,
432
+ threshold=threshold,
433
+ batch_size=batch_size,
434
+ )
435
+ if rescue_labels:
436
+ rescued = sum(
437
+ 1 for l in rescue_labels
438
+ if str(l).strip().lower() == "supported"
439
+ )
440
+ sup_full += rescued
441
+
442
+ hallucination_score = max(0.0, (total_gen - sup_full) / total_gen)
443
+ result["hallucination_supported"] = sup_full
444
+ else:
445
+ hallucination_score = 0.0
446
+
447
+ result["hallucination_score"] = hallucination_score
448
+
449
+ return result
450
+
451
+
452
+ # ---------------------------------------------------------------------------
453
+ # BN health-literacy classifier via vLLM (Gemma-3 fine-tuned model)
454
+ # Uses Bangla prompt; model is assumed running in vLLM.
455
+ # ---------------------------------------------------------------------------
456
+
457
+
458
+ def build_classification_user_prompt(fulltext: str, gen_text: str) -> str:
459
+ """Build the classification user prompt in English (matches gemma3-finetune.py). Full text is reference; generated text is what to classify."""
460
+ return (
461
+ "You will be given a medical case description as reference (full text) and a generated text to classify. "
462
+ "Determine the patient's health literacy level based only on the generated text.\n\n"
463
+ f"Reference (full text):\n{fulltext}\n\n"
464
+ f"Generated text (to classify):\n{gen_text}\n\n"
465
+ "Reply with exactly one label from this set:\n"
466
+ "low_health_literacy, intermediate_health_literacy, proficient_health_literacy"
467
+ )
468
+
469
+
470
+ def format_gemma3_prompt(user_message: str) -> str:
471
+ """Format user message for Gemma-3 chat (vLLM expects this)."""
472
+ return (
473
+ f"<start_of_turn>user\n{user_message}<end_of_turn>\n"
474
+ "<start_of_turn>model\n"
475
+ )
476
+
477
+
478
+ def _call_vllm_classifier(prompt: str, max_tokens: int = 64, timeout: float = 60.0) -> Optional[str]:
479
+ """
480
+ Call vLLM completions API. Returns generated text or None on failure.
481
+ """
482
+ url = f"{VLLM_CLASSIFIER_BN_API_BASE.rstrip('/')}/completions"
483
+ payload = {
484
+ "prompt": prompt,
485
+ "max_tokens": max_tokens,
486
+ "temperature": 0.0,
487
+ "stop": ["<end_of_turn>", "<start_of_turn>", "\n\n","<eos>"],
488
+ }
489
+ try:
490
+ resp = requests.post(url, json=payload, timeout=timeout)
491
+ resp.raise_for_status()
492
+ data = resp.json()
493
+ choices = data.get("choices")
494
+ # import ipdb; ipdb.set_trace()
495
+ if choices and len(choices) > 0:
496
+ text = choices[0].get("text", "")
497
+ return (text or "").strip()
498
+ return None
499
+ except requests.exceptions.RequestException as exc:
500
+ return None
501
+
502
+
503
+ def _parse_classifier_output(raw: str) -> str:
504
+ """
505
+ Extract health literacy label from model output. Normalize to
506
+ low_health_literacy | intermediate_health_literacy | proficient_health_literacy.
507
+ Returns empty string if no valid label found.
508
+ """
509
+ if not raw:
510
+ return ""
511
+ raw = raw.strip().lower()
512
+ # Take first line and clean
513
+ first_line = raw.split("\n")[0].strip()
514
+ for label in ["low_health_literacy", "intermediate_health_literacy", "proficient_health_literacy"]:
515
+ if label in first_line or label in raw:
516
+ # import ipdb; ipdb.set_trace()
517
+ return label
518
+ return ""
519
+
520
+
521
+ _CLASSIFIER_ERROR_LOGGED = False
522
+
523
+
524
+ def _predict_label(input_text: str, generated_text: str) -> str:
525
+ """
526
+ Run BN health-literacy classifier via vLLM (Gemma-3 fine-tuned).
527
+ Uses fulltext=input_text and gen_text=generated_text; returns normalized label or "".
528
+ """
529
+ global _CLASSIFIER_ERROR_LOGGED
530
+ try:
531
+ user_prompt = build_classification_user_prompt(input_text or "", generated_text or "")
532
+ prompt = format_gemma3_prompt(user_prompt)
533
+ raw = _call_vllm_classifier(prompt)
534
+ # import ipdb; ipdb.set_trace()
535
+ if raw is None:
536
+ if not _CLASSIFIER_ERROR_LOGGED:
537
+ print("Warning: BN classifier vLLM call failed, continuing without it.")
538
+ _CLASSIFIER_ERROR_LOGGED = True
539
+ return ""
540
+ return _parse_classifier_output(raw)
541
+ except Exception as exc:
542
+ if not _CLASSIFIER_ERROR_LOGGED:
543
+ print(f"Warning: BN literacy classifier unavailable, continuing without it: {exc}")
544
+ _CLASSIFIER_ERROR_LOGGED = True
545
+ return ""
546
+
547
+
548
+ def _parse_solution_json(solution_str):
549
+ if isinstance(solution_str, (dict, list)):
550
+ return solution_str
551
+ try:
552
+ cleaned_str = str(solution_str).strip()
553
+ if "```json" in cleaned_str:
554
+ cleaned_str = cleaned_str.split("```json")[1].split("```")[0].strip()
555
+ elif "```" in cleaned_str:
556
+ cleaned_str = cleaned_str.split("```")[1].split("```")[0].strip()
557
+ return json.loads(cleaned_str)
558
+ except Exception:
559
+ return None
560
+
561
+
562
+ def _compute_classifier_reward(target_level: str, gen_text: str, input_text: str) -> float:
563
+ """
564
+ Soft classifier score in [0, 1] (NOT binary +1/-1).
565
+
566
+ 1.0 — predicted label matches target level (correct style)
567
+ 0.0 — predicted label does not match (wrong style)
568
+ 0.5 — classifier unavailable; neutral / no signal
569
+
570
+ Uses BN classifier via vLLM (Gemma-3); needs input_text (fulltext) and gen_text.
571
+ """
572
+ result = _predict_label(input_text, gen_text)
573
+ if result == "": # unavailable → neutral, no penalty
574
+ return 0.5
575
+ if result.strip().lower() == target_level.strip().lower():
576
+ # import ipdb; ipdb.set_trace()
577
+ return 1.0 # correct literacy style
578
+ return 0.0 # wrong literacy style (penalty-free cliff avoided)
579
+
580
+
581
+ # ---------------------------------------------------------------------------
582
+ # Copy-paste penalty (prevent trivial copy of input_text)
583
+ # ---------------------------------------------------------------------------
584
+
585
+ def _approx_copy_ratio(input_text: str, gen_text: str) -> float:
586
+ """
587
+ Rough similarity estimate between input and generated text.
588
+
589
+ - Detects near-verbatim copy via substring + length ratio.
590
+ - Otherwise uses token overlap (gen tokens that also appear in input).
591
+ Returns value in [0, 1], where 1 ≈ almost exact copy.
592
+ """
593
+ a = (input_text or "").strip()
594
+ b = (gen_text or "").strip()
595
+ if not a or not b:
596
+ return 0.0
597
+
598
+ len_a, len_b = len(a), len(b)
599
+ shorter, longer = (a, b) if len_a <= len_b else (b, a)
600
+
601
+ # Near-verbatim copy: one string almost fully contained in the other.
602
+ if shorter and shorter in longer:
603
+ ratio = len(shorter) / max(1, len(longer))
604
+ if ratio >= 0.9:
605
+ return 1.0
606
+
607
+ # Fallback: 3-gram (trigram) token overlap to reduce false positives
608
+ # from shared medical vocabulary (drug names, symptoms, etc.).
609
+ def _tokens(t: str):
610
+ return [tok for tok in re.split(r"\s+", t) if tok]
611
+
612
+ def _shingles(tokens, n=3):
613
+ if len(tokens) < n:
614
+ return set()
615
+ return {" ".join(tokens[i : i + n]) for i in range(len(tokens) - n + 1)}
616
+
617
+ toks_a = _tokens(a)
618
+ toks_b = _tokens(b)
619
+ if not toks_a or not toks_b:
620
+ return 0.0
621
+
622
+ sh_a = _shingles(toks_a, n=3)
623
+ sh_b = _shingles(toks_b, n=3)
624
+ if not sh_a or not sh_b:
625
+ return 0.0
626
+
627
+ overlap = len(sh_a & sh_b) / max(1, len(sh_b))
628
+ return max(0.0, min(1.0, overlap))
629
+
630
+
631
+ def _compute_copy_penalty(input_text: str, gen_text: str) -> float:
632
+ """
633
+ Map copy ratio → penalty in [0, 1].
634
+
635
+ - ≤ 0.7 similarity → no penalty
636
+ - 0.7–1.0 → linearly ramp penalty up to 1.0
637
+ """
638
+ ratio = _approx_copy_ratio(input_text, gen_text)
639
+ if ratio <= 0.7:
640
+ return 0.0
641
+ # Scale [0.7, 1.0] → [0, 1]
642
+ return max(0.0, min(1.0, (ratio - 0.7) / 0.3))
643
+
644
+
645
+ # ---------------------------------------------------------------------------
646
+ # Main scoring function
647
+ # ---------------------------------------------------------------------------
648
+ def _nonlinear_grounding(h_score: float) -> float:
649
+ """
650
+ Sharper penalty for hallucination.
651
+
652
+ h_score=0.00 → 1.00 (perfect)
653
+ h_score=0.05 → 0.95 (mild)
654
+ h_score=0.10 → 0.82 (noticeable)
655
+ h_score=0.17 → 0.65 (significant — was 0.83 before!)
656
+ h_score=0.30 → 0.36 (harsh)
657
+ h_score=0.50 → 0.13 (near zero)
658
+ """
659
+ return max(0.0, (1.0 - h_score) ** 2.5)
660
+ def compute_score(data_source, solution_str, ground_truth, extra_info=None):
661
+ """
662
+ Reward = weighted sum of three components (all in [0, 1]):
663
+
664
+ W_FACTUALITY × factuality_score (summary info present in gen_text)
665
+ W_HALLU × (1 - hallucination_score) (gen_segments grounded in fulltext)
666
+ W_CLASSIFIER × classifier_score (style match)
667
+
668
+ 1. Factuality : extract subclaims from *summary*, check how many are
669
+ supported by the generated text.
670
+ 2. Hallucination: extract subclaims from *generated text*, check how many
671
+ are NOT supported by the fulltext.
672
+ """
673
+ W_FACTUALITY = 0.40
674
+ W_HALLU = 0.25
675
+ W_CLASSIFIER = 0.35
676
+
677
+ FAIL = {
678
+ "score": -1.0,
679
+ "factuality_score": 0.0,
680
+ "hallucination_score": 0.0,
681
+ "classifier_score": 0.0,
682
+ "factuality_supported": 0,
683
+ "hallucination_supported": 0,
684
+ "total_gen_segments": 0,
685
+ }
686
+
687
+ # 1. Parse & validate
688
+ data = _parse_solution_json(solution_str)
689
+ if not data:
690
+ return FAIL
691
+
692
+ target_level = extra_info.get("target_level") if extra_info else None
693
+ gen_text = data.get(target_level, "") if target_level else ""
694
+
695
+ if not gen_text or len(gen_text.strip()) < 10:
696
+ return FAIL
697
+
698
+ if not _is_bangla_text(gen_text):
699
+ return FAIL
700
+
701
+ fulltext = ground_truth.get("fulltext") or ground_truth.get("input_text", "")
702
+ input_text = ground_truth.get("input_text", "")
703
+ summary_subclaims = ground_truth.get("summary_subclaims")
704
+ summary_text = ground_truth.get("summary_text", "")
705
+
706
+ # 2. Compute the two core rewards
707
+ rewards = compute_rewards(
708
+ fulltext=fulltext,
709
+ generated_text=gen_text,
710
+ target_level=target_level,
711
+ summary_subclaims=summary_subclaims,
712
+ summary_text=summary_text,
713
+ )
714
+
715
+ factuality_score = rewards["factuality_score"]
716
+ h_score = rewards["hallucination_score"]
717
+ total_gen_units = rewards.get("total_gen_segments", 0)
718
+
719
+ if factuality_score is None:
720
+ factuality_score = 0.5
721
+ if h_score is None:
722
+ h_score = 0.5
723
+
724
+ grounding_score = _nonlinear_grounding(h_score)
725
+
726
+ # 3. Classifier (style match)
727
+ class_score = _compute_classifier_reward(target_level, gen_text, input_text)
728
+
729
+ # 4. Final weighted sum
730
+ final_reward = (
731
+ W_FACTUALITY * factuality_score +
732
+ W_HALLU * grounding_score +
733
+ W_CLASSIFIER * class_score
734
+ )
735
+
736
+ # 5. Copy-paste penalty
737
+ copy_penalty = _compute_copy_penalty(input_text, gen_text)
738
+ if copy_penalty > 0.0:
739
+ final_reward = max(0.0, final_reward * (1.0 - copy_penalty))
740
+
741
+ return {
742
+ "score": float(final_reward),
743
+ "factuality_score": float(factuality_score),
744
+ "hallucination_score": float(h_score),
745
+ "classifier_score": float(class_score),
746
+ "factuality_supported": int(rewards.get("factuality_supported", 0)),
747
+ "hallucination_supported": int(rewards.get("hallucination_supported", 0)),
748
+ "total_gen_segments": int(total_gen_units),
749
+ }
750
+
751
+
752
+ # ---------------------------------------------------------------------------
753
+ # Test mode
754
+ # ---------------------------------------------------------------------------
755
+
756
+ test_mode = True
757
+ if test_mode:
758
+ import time
759
+
760
+ def run_actual_api_test():
761
+ # Bangla medical example (support-check and classifier use Bangla prompts)
762
+ ground_truth = {
763
+ "summary_subclaims": [
764
+ "লিসিনোপ্রিল উচ্চ রক্তচাপের চিকিৎসায় ব্যবহৃত হয়।",
765
+ "এটি ACE ইনহিবিটর যা আপনার হৃদযন্ত্রকে ভালো কাজ করতে সাহায্য করে।",
766
+ "সাধারণ পার্শ্বপ্রতিক্রিয়ার মধ্যে শুষ্ক কাশি রয়েছে।",
767
+ "গর্ভবতী হলে ব্যবহার করবেন না।",
768
+ ],
769
+ "input_text": (
770
+ "লিসিনোপ্রিল উচ্চ রক্তচাপের চিকিৎসায় ব্যবহৃত হয়। "
771
+ "এটি ACE ইনহিবিটর নামক ওষুধ। "
772
+ "এটি আপনার হৃদযন্ত্রকে ভালো কাজ করতে সাহায্য করে।"
773
+ ),
774
+ "summary_text": (
775
+ "লিসিনোপ্রিল উচ্চ রক্তচাপের চিকিৎসায় ব্যবহৃত হয়। "
776
+ "এটি ACE ইনহিবিটর যা আপনার হৃদযন্ত্রকে ভালো কাজ করতে সাহায্য করে। "
777
+ "সাধারণ পার্শ্বপ্রতিক্রিয়ার মধ্যে শুষ্ক কাশি রয়েছে। "
778
+ "গর্ভবতী হলে ব্যবহার করবেন না।"
779
+ ),
780
+ }
781
+
782
+ # LLM output: low_health_literacy style, grounded in summary
783
+ generated_response = {
784
+ "low_health_literacy": (
785
+ "এই ওষুধ আপনার উচ্চ রক্তচাপের জন্য। "
786
+ "এটি ACE ইনহিবিটর ধরনের ওষুধ। "
787
+ "এটি আপনার ��ৃদযন্ত্রকে ভালো কাজ করতে সাহায্য করে। "
788
+ "গর্ভবতী হলে এই ওষুধ খাবেন না।"
789
+ )
790
+ }
791
+
792
+ solution_str = f"```json\n{json.dumps(generated_response)}\n```"
793
+ extra_info = {"target_level": "low_health_literacy"}
794
+
795
+ print("📡 Running BN reward test (Bangla example)...")
796
+ start_time = time.time()
797
+
798
+ try:
799
+ score = compute_score(
800
+ data_source="real_api_test",
801
+ solution_str=solution_str,
802
+ ground_truth=ground_truth,
803
+ extra_info=extra_info,
804
+ )
805
+
806
+ final_score = score["score"] if isinstance(score, dict) else score
807
+
808
+ duration = time.time() - start_time
809
+ print(f"\nAPI Call Successful ({round(duration, 2)}s)")
810
+ print("-" * 50)
811
+ print(f"Target Level : {extra_info['target_level']}")
812
+ print(f"Final Reward : {round(final_score, 4)}")
813
+ print(f"factuality_score : {round(score.get('factuality_score', 0), 4)} (summary subclaims in gen_text)")
814
+ print(f"hallucination_score : {round(score.get('hallucination_score', 0), 4)} (gen_segments NOT in fulltext)")
815
+ print(f"classifier_score : {round(score.get('classifier_score', 0), 4)}")
816
+ print(f"factuality_supported : {score.get('factuality_supported', 0)}")
817
+ print(f"hallucination_supported: {score.get('hallucination_supported', 0)}")
818
+ print(f"total_gen_segments : {score.get('total_gen_segments', 0)}")
819
+ print("-" * 50)
820
+ print("\nReward definitions:")
821
+ print("- factuality_score : fraction of *summary* subclaims supported by gen_text [0,1]")
822
+ print("- hallucination_score : fraction of *gen_segments* NOT supported by fulltext [0,1] (lower=better)")
823
+ print("- classifier_score : 1.0 match, 0.0 mismatch, 0.5 unavailable")
824
+ print("- Weights: factuality=0.35, grounding=0.30, classifier=0.35")
825
+
826
+ except Exception as e:
827
+ print(f"\n❌ API Call Failed!")
828
+ print(f"Error Type: {type(e).__name__}")
829
+ print(f"Details: {str(e)}")
830
+ print("\nPossible fixes:")
831
+ print("1. Check if the support-check vLLM server is running (VLLM_SUPPORT_CHECK_BN_API_BASE).")
832
+ print("2. Check if the classifier vLLM server is running (VLLM_CLASSIFIER_BN_API_BASE).")
833
+
834
+ if __name__ == "__main__":
835
+ run_actual_api_test()
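Putting the pieces of this new reward file together, here is a minimal sketch of how the final score is assembled. The weights, the nonlinear grounding transform, and the copy-penalty ramp are taken from the code above; the component scores fed in are made up for illustration.

# Sketch of the scoring pipeline in reward_new_v6_bn_v4_rmv_src_cov.py.
def grounding(h_score: float) -> float:
    # Mirrors _nonlinear_grounding: sharper penalty as hallucination grows.
    return max(0.0, (1.0 - h_score) ** 2.5)

def copy_penalty(copy_ratio: float) -> float:
    # Penalty ramp from _compute_copy_penalty, applied to a precomputed copy ratio.
    return 0.0 if copy_ratio <= 0.7 else min(1.0, (copy_ratio - 0.7) / 0.3)

factuality, hallucination, classifier = 0.75, 0.10, 1.0   # example component scores
reward = 0.40 * factuality + 0.25 * grounding(hallucination) + 0.35 * classifier
reward = max(0.0, reward * (1.0 - copy_penalty(0.75)))    # assume mild copy overlap
print(round(reward, 4))   # roughly 0.70 for these example inputs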
code/RL_model/verl/verl_train/script/run_qwen3-8b_v3(2G)_v2.sh CHANGED
@@ -10,7 +10,7 @@ PYTHONUNBUFFERED=1 python3 -m verl.trainer.main_ppo \
10
  algorithm.adv_estimator=grpo \
11
  data.train_files=/home/mshahidul/readctrl/code/RL_model/verl/verl_train/dataset/bn_dataset/train.parquet \
12
  data.val_files=/home/mshahidul/readctrl/code/RL_model/verl/verl_train/dataset/bn_dataset/test.parquet \
13
- custom_reward_function.path=/home/mshahidul/readctrl/code/RL_model/verl/verl_train/reward_func/reward_func/reward_new_v6_bn_v4.py \
14
  data.train_batch_size=256 \
15
  data.max_prompt_length=6144 \
16
  data.max_response_length=2048 \
 
10
  algorithm.adv_estimator=grpo \
11
  data.train_files=/home/mshahidul/readctrl/code/RL_model/verl/verl_train/dataset/bn_dataset/train.parquet \
12
  data.val_files=/home/mshahidul/readctrl/code/RL_model/verl/verl_train/dataset/bn_dataset/test.parquet \
13
+ custom_reward_function.path="/home/mshahidul/readctrl/code/RL_model/verl/verl_train/reward_func/reward_func/reward_new_v6_bn_v4_rmv_src_cov.py" \
14
  data.train_batch_size=256 \
15
  data.max_prompt_length=6144 \
16
  data.max_response_length=2048 \
code/RL_model/verl/verl_train/script/run_qwen3-8b_v3(2G)_v3.sh ADDED
@@ -0,0 +1,56 @@
1
+ set -x
2
+
3
+ unset PYTORCH_CUDA_ALLOC_CONF
4
+ export EXPERIMENT_NAME=qwen3-4b-instruct-bn
5
+ export WAND_PROJECT='readctrl-verl'
6
+ export CUDA_DEVICE_ORDER="PCI_BUS_ID"
7
+ export CUDA_VISIBLE_DEVICES=1,2
8
+
9
+ PYTHONUNBUFFERED=1 python3 -m verl.trainer.main_ppo \
10
+ algorithm.adv_estimator=grpo \
11
+ data.train_files=/home/mshahidul/readctrl/code/RL_model/verl/verl_train/dataset/bn_dataset/train.parquet \
12
+ data.val_files=/home/mshahidul/readctrl/code/RL_model/verl/verl_train/dataset/bn_dataset/test.parquet \
13
+ custom_reward_function.path="/home/mshahidul/readctrl/code/RL_model/verl/verl_train/reward_func/reward_func/reward_new_v6_bn_v4.py" \
14
+ data.train_batch_size=256 \
15
+ data.max_prompt_length=6144 \
16
+ data.max_response_length=2048 \
17
+ data.filter_overlong_prompts=True \
18
+ data.truncation='error' \
19
+ actor_rollout_ref.model.path=Qwen/Qwen3-4B-Instruct-2507 \
20
+ actor_rollout_ref.actor.optim.lr=1e-6 \
21
+ actor_rollout_ref.model.use_remove_padding=True \
22
+ actor_rollout_ref.actor.ppo_mini_batch_size=256 \
23
+ actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
24
+ actor_rollout_ref.actor.use_kl_loss=True \
25
+ actor_rollout_ref.actor.kl_loss_coef=0.001 \
26
+ actor_rollout_ref.actor.kl_loss_type=low_var_kl \
27
+ actor_rollout_ref.actor.entropy_coeff=0 \
28
+ actor_rollout_ref.model.enable_gradient_checkpointing=True \
29
+ actor_rollout_ref.actor.fsdp_config.param_offload=False \
30
+ actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
31
+ actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \
32
+ actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
33
+ actor_rollout_ref.rollout.name=vllm \
34
+ actor_rollout_ref.rollout.gpu_memory_utilization=0.35 \
35
+ actor_rollout_ref.rollout.enforce_eager=True \
36
+ actor_rollout_ref.rollout.max_model_len=8192 \
37
+ actor_rollout_ref.rollout.n=5 \
38
+ actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=2 \
39
+ actor_rollout_ref.ref.fsdp_config.param_offload=True \
40
+ algorithm.use_kl_in_reward=False \
41
+ trainer.critic_warmup=0 \
42
+ trainer.logger='["console","wandb"]' \
43
+ trainer.project_name=$WAND_PROJECT \
44
+ trainer.experiment_name=$EXPERIMENT_NAME \
45
+ trainer.n_gpus_per_node=2 \
46
+ trainer.nnodes=1 \
47
+ trainer.save_freq=20 \
48
+ trainer.test_freq=10 \
49
+ trainer.log_val_generations=1 \
50
+ +trainer.remove_previous_ckpt_in_save=true \
51
+ trainer.max_actor_ckpt_to_keep=1 \
52
+ trainer.max_critic_ckpt_to_keep=1 \
53
+ trainer.resume_mode=auto \
54
+ trainer.default_local_dir=/home/mshahidul/readctrl/code/RL_model/models/reward_new_v6_bn_v4_test2 \
55
+ trainer.total_epochs=45 $@ \
56
+ 2>&1 | tee $EXPERIMENT_NAME.log
code/RL_model/verl/verl_train/script/run_qwen3-8b_v3(4G)_v2.sh CHANGED
@@ -54,6 +54,6 @@ PYTHONUNBUFFERED=1 python3 -m verl.trainer.main_ppo \
54
  trainer.max_actor_ckpt_to_keep=1 \
55
  trainer.max_critic_ckpt_to_keep=1 \
56
  trainer.resume_mode=auto \
57
- trainer.default_local_dir=/home/mshahidul/readctrl/code/RL_model/models/bn_wo_summary \
58
  trainer.total_epochs=45 $@ \
59
  2>&1 | tee $EXPERIMENT_NAME.log
 
54
  trainer.max_actor_ckpt_to_keep=1 \
55
  trainer.max_critic_ckpt_to_keep=1 \
56
  trainer.resume_mode=auto \
57
+ trainer.default_local_dir=/home/mshahidul/readctrl/code/RL_model/models/readCtrl_RL_bn_srcCov_v1 \
58
  trainer.total_epochs=45 $@ \
59
  2>&1 | tee $EXPERIMENT_NAME.log
code/fine_tune_sft_dpo/best_of_n_qwen3_vllm_bn.py ADDED
@@ -0,0 +1,518 @@
1
+ import os
2
+ os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
3
+ os.environ["CUDA_VISIBLE_DEVICES"] = "6"
4
+
5
+ import argparse
6
+ import json
7
+ import re
8
+ from datetime import datetime
9
+ from typing import Any, Dict, List, Tuple
10
+
11
+ from vllm import LLM, SamplingParams
12
+ from transformers import AutoTokenizer
13
+
14
+
15
+ def strip_think_blocks(text: str) -> str:
16
+ """Remove <think>...</think> reasoning blocks from model output."""
17
+ cleaned = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL).strip()
18
+ return cleaned if cleaned else text
19
+
20
+
21
+ BASE_DIR = "/home/mshahidul/readctrl/code/fine_tune_sft_dpo"
22
+ FINETUNED_MODEL_DIR = os.path.join(BASE_DIR, "model", "bn")
23
+ PROMPT_DIR = os.path.join(BASE_DIR, "prompt_bn")
24
+ TEST_JSON = os.path.join(BASE_DIR, "dataset", "bn", "test_bn.json")
25
+ RESULTS_DIR = os.path.join(BASE_DIR, "results", "bn")
26
+
27
+ SOURCE_LANG = "Bengali"
28
+
29
+ LABEL_TO_PROMPT_FILE = {
30
+ "low_health_literacy": "prompt_low",
31
+ "intermediate_health_literacy": "prompt_intermediate",
32
+ "proficient_health_literacy": "prompt_proficient",
33
+ }
34
+
35
+ LABEL_TO_READABILITY = {
36
+ "low_health_literacy": (
37
+ "Low Health Literacy (High Readability): individuals needing the simplest "
38
+ "terms for immediate action, using 'living room' language, one idea per "
39
+ "sentence, and focusing only on need-to-know information from the Gold Summary."
40
+ ),
41
+ "intermediate_health_literacy": (
42
+ "Intermediate Health Literacy (Medium Readability): the general public at a "
43
+ "news-reading level, with standard vocabulary and some common medical terms, "
44
+ "and a balanced level of detail led by the Gold Summary."
45
+ ),
46
+ "proficient_health_literacy": (
47
+ "Proficient Health Literacy (Low Readability): researchers, clinicians, or "
48
+ "highly informed patients, using technical and academic language, high "
49
+ "information density, and full clinical nuance and terminology from the "
50
+ "Source Text."
51
+ ),
52
+ }
53
+
54
+
55
+ def load_prompts(prompt_dir: str) -> Dict[str, str]:
56
+ prompts: Dict[str, str] = {}
57
+ for label, filename in LABEL_TO_PROMPT_FILE.items():
58
+ path = os.path.join(prompt_dir, filename)
59
+ if os.path.isfile(path):
60
+ with open(path, "r", encoding="utf-8") as f:
61
+ prompts[label] = f.read()
62
+ else:
63
+ raise FileNotFoundError(f"Prompt file not found: {path}")
64
+ return prompts
65
+
66
+
67
+ def build_generation_user_message(
68
+ prompt_template: str,
69
+ full_text: str,
70
+ gold_summary: str,
71
+ source_lang: str = SOURCE_LANG,
72
+ ) -> str:
73
+ return (
74
+ prompt_template.replace("{full_text}", full_text)
75
+ .replace("{gold_summary}", gold_summary)
76
+ .replace("{source_lang}", source_lang)
77
+ )
78
+
79
+
80
+ def build_selection_user_message(
81
+ full_text: str,
82
+ label: str,
83
+ candidates: List[str],
84
+ source_lang: str = SOURCE_LANG,
85
+ ) -> str:
86
+ readability = LABEL_TO_READABILITY.get(label, label)
87
+ numbered = []
88
+ for i, cand in enumerate(candidates, start=1):
89
+ numbered.append(f"[{i}]\n{cand.strip()}")
90
+ candidates_block = "\n\n".join(numbered)
91
+
92
+ return (
93
+ "You are selecting the best patient-friendly summary of a medical case.\n\n"
94
+ f"Original text ({source_lang}):\n{full_text}\n\n"
95
+ f"Readability requirement: {readability}.\n\n"
96
+ f"Here are {len(candidates)} candidate summaries:\n\n"
97
+ f"{candidates_block}\n\n"
98
+ "Choose the single candidate that best matches the readability "
99
+ "requirement and accurately reflects the key clinical information.\n"
100
+ "Answer with exactly one line in the form:\n"
101
+ '"BEST_INDEX: k"\n'
102
+ f"where k is an integer from 1 to {len(candidates)}."
103
+ )
104
+
105
+
106
+ def parse_best_index(text: str, num_candidates: int) -> int:
107
+ # Look for an integer in the model output; default to 1 if parsing fails.
108
+ match = re.search(r"(\d+)", text)
109
+ if not match:
110
+ return 1
111
+ idx = int(match.group(1))
112
+ if idx < 1 or idx > num_candidates:
113
+ return 1
114
+ return idx
115
+
116
+
117
+ def build_generation_prompts_for_model(
118
+ tokenizer,
119
+ test_list: List[Dict[str, Any]],
120
+ prompts: Dict[str, str],
121
+ source_lang: str = SOURCE_LANG,
122
+ ) -> Tuple[List[str], List[Dict[str, Any]]]:
123
+ batched_prompts: List[str] = []
124
+ meta: List[Dict[str, Any]] = []
125
+
126
+ for idx, item in enumerate(test_list):
127
+ label = item.get("label")
128
+ doc_id = item.get("doc_id", idx)
129
+ fulltext = item.get("fulltext", "")
130
+ summary = item.get("summary", "")
131
+ gold_gen_text = item.get("gen_text", "")
132
+
133
+ if label not in prompts:
134
+ meta.append(
135
+ {
136
+ "doc_id": doc_id,
137
+ "label": label,
138
+ "gold_gen_text": gold_gen_text,
139
+ "fulltext": fulltext,
140
+ "summary": summary,
141
+ "error": f"Unknown label: {label}",
142
+ }
143
+ )
144
+ batched_prompts.append(None) # type: ignore[arg-type]
145
+ continue
146
+
147
+ user_prompt = build_generation_user_message(
148
+ prompts[label],
149
+ fulltext,
150
+ summary,
151
+ source_lang=source_lang,
152
+ )
153
+ chat = [{"role": "user", "content": user_prompt}]
154
+ formatted = tokenizer.apply_chat_template(
155
+ chat,
156
+ tokenize=False,
157
+ add_generation_prompt=True,
158
+ enable_thinking=False,
159
+ )
160
+
161
+ batched_prompts.append(formatted)
162
+ meta.append(
163
+ {
164
+ "doc_id": doc_id,
165
+ "label": label,
166
+ "gold_gen_text": gold_gen_text,
167
+ "fulltext": fulltext,
168
+ "summary": summary,
169
+ "error": None,
170
+ }
171
+ )
172
+
173
+ return batched_prompts, meta
174
+
175
+
176
+ def run_best_of_n_for_model(
177
+ model_id: str,
178
+ model_key: str,
179
+ test_list: List[Dict[str, Any]],
180
+ prompts: Dict[str, str],
181
+ max_new_tokens: int,
182
+ temperature: float,
183
+ num_candidates: int,
184
+ batch_size: int,
185
+ source_lang: str = SOURCE_LANG,
186
+ ) -> Dict[int, Dict[str, Any]]:
187
+ print(f"\n=== Running model {model_key}: {model_id} ===")
188
+
189
+ print("Loading tokenizer...")
190
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
191
+
192
+ print("Preparing prompts...")
193
+ batched_prompts, meta = build_generation_prompts_for_model(
194
+ tokenizer, test_list, prompts, source_lang=source_lang
195
+ )
196
+
197
+ print("Loading vLLM model...")
198
+ llm = LLM(
199
+ model=model_id,
200
+ trust_remote_code=True,
201
+ )
202
+
203
+ gen_sampling_params = SamplingParams(
204
+ temperature=temperature,
205
+ max_tokens=max_new_tokens,
206
+ n=num_candidates,
207
+ )
208
+
209
+ # Filter out None prompts (unknown labels) for generation
210
+ valid_indices = [i for i, p in enumerate(batched_prompts) if p is not None]
211
+ valid_prompts = [batched_prompts[i] for i in valid_indices]
212
+
213
+ total_valid = len(valid_prompts)
214
+ batch_size = max(1, batch_size)
215
+ print(
216
+ f"Running vLLM generation on {total_valid} samples "
217
+ f"in batches of {batch_size} with Best-of-{num_candidates}..."
218
+ )
219
+
220
+ candidates_per_idx: Dict[int, List[str]] = {}
221
+
222
+ num_batches = (total_valid + batch_size - 1) // batch_size
223
+ for batch_idx in range(num_batches):
224
+ start = batch_idx * batch_size
225
+ end = min(start + batch_size, total_valid)
226
+ batch_prompts = valid_prompts[start:end]
227
+ batch_indices = valid_indices[start:end]
228
+
229
+ print(
230
+ f"Generating batch {batch_idx + 1}/{num_batches} "
231
+ f"with {len(batch_prompts)} samples..."
232
+ )
233
+ outputs = llm.generate(batch_prompts, sampling_params=gen_sampling_params)
234
+
235
+ for idx_in_batch, output in enumerate(outputs):
236
+ original_idx = batch_indices[idx_in_batch]
237
+ cand_texts = [strip_think_blocks(o.text) for o in output.outputs]
238
+ candidates_per_idx[original_idx] = cand_texts
239
+
240
+ # Now build selection prompts to choose the best candidate for each valid sample.
241
+ print("Building selection prompts for Best-of-N choice...")
242
+ selection_prompts: List[str] = []
243
+ selection_indices: List[int] = []
244
+ reverse_map: Dict[int, int] = {}
245
+
246
+ for original_idx in valid_indices:
247
+ info = meta[original_idx]
248
+ if info["error"] is not None:
249
+ continue
250
+ cands = candidates_per_idx.get(original_idx, [])
251
+ if not cands:
252
+ continue
253
+ sel_user = build_selection_user_message(
254
+ info["fulltext"],
255
+ info["label"],
256
+ cands,
257
+ source_lang=source_lang,
258
+ )
259
+ chat = [{"role": "user", "content": sel_user}]
260
+ formatted = tokenizer.apply_chat_template(
261
+ chat,
262
+ tokenize=False,
263
+ add_generation_prompt=True,
264
+ enable_thinking=False,
265
+ )
266
+ reverse_map[len(selection_prompts)] = original_idx
267
+ selection_prompts.append(formatted)
268
+
269
+ select_sampling_params = SamplingParams(
270
+ temperature=0.0,
271
+ max_tokens=32,
272
+ n=1,
273
+ )
274
+
275
+ best_index_per_idx: Dict[int, int] = {}
276
+
277
+ total_select = len(selection_prompts)
278
+ if total_select > 0:
279
+ print(
280
+ f"Running selection passes on {total_select} samples "
281
+ f"in batches of {batch_size}..."
282
+ )
283
+ num_sel_batches = (total_select + batch_size - 1) // batch_size
284
+ for batch_idx in range(num_sel_batches):
285
+ start = batch_idx * batch_size
286
+ end = min(start + batch_size, total_select)
287
+ batch_prompts = selection_prompts[start:end]
288
+
289
+ print(
290
+ f"Selecting batch {batch_idx + 1}/{num_sel_batches} "
291
+ f"with {len(batch_prompts)} samples..."
292
+ )
293
+ outputs = llm.generate(
294
+ batch_prompts, sampling_params=select_sampling_params
295
+ )
296
+
297
+ for idx_in_batch, output in enumerate(outputs):
298
+ global_sel_idx = start + idx_in_batch
299
+ original_idx = reverse_map[global_sel_idx]
300
+ raw_text = strip_think_blocks(output.outputs[0].text)
301
+ best_idx = parse_best_index(raw_text, num_candidates)
302
+ best_index_per_idx[original_idx] = best_idx
303
+
304
+ # Build structured results per original index.
305
+ model_results: Dict[int, Dict[str, Any]] = {}
306
+ for idx, info in enumerate(meta):
307
+ if info["error"] is not None:
308
+ model_results[idx] = {
309
+ "error": info["error"],
310
+ }
311
+ continue
312
+
313
+ cands = candidates_per_idx.get(idx, [])
314
+ best_idx = best_index_per_idx.get(idx, 1 if cands else None)
315
+ best_summary = (
316
+ cands[best_idx - 1] if cands and best_idx is not None and 1 <= best_idx <= len(cands) else ""
317
+ )
318
+
319
+ model_results[idx] = {
320
+ "candidates": cands,
321
+ "best_index": best_idx,
322
+ "best_summary": best_summary,
323
+ }
324
+
325
+ return model_results
326
+
327
+
328
+ def parse_args():
329
+ p = argparse.ArgumentParser(
330
+ description=(
331
+ "Run vLLM inference with Best-of-N for both the finetuned "
332
+ "Qwen3 model and the base Qwen/Qwen3-4B-Instruct-2507 model "
333
+ "on test_bn.json (Bengali)."
334
+ )
335
+ )
336
+ p.add_argument(
337
+ "--prompt-dir",
338
+ type=str,
339
+ default=PROMPT_DIR,
340
+ help="Directory containing prompt files (prompt_low, prompt_intermediate, prompt_proficient).",
341
+ )
342
+ p.add_argument(
343
+ "--finetuned-model-dir",
344
+ type=str,
345
+ default=FINETUNED_MODEL_DIR,
346
+ help="Path to the merged finetuned model directory.",
347
+ )
348
+ p.add_argument(
349
+ "--test-data",
350
+ type=str,
351
+ default=TEST_JSON,
352
+ help="Path to the test data JSON file.",
353
+ )
354
+ p.add_argument(
355
+ "--src-lang",
356
+ type=str,
357
+ default=SOURCE_LANG,
358
+ help="Source language of the text (e.g. Bengali, English).",
359
+ )
360
+ p.add_argument(
361
+ "--base-model-id",
362
+ type=str,
363
+ default="Qwen/Qwen3-4B-Instruct-2507",
364
+ help="Hugging Face model id for the base Qwen3 instruct model.",
365
+ )
366
+ p.add_argument(
367
+ "--max-new-tokens",
368
+ type=int,
369
+ default=512,
370
+ help="Maximum number of new tokens to generate per candidate.",
371
+ )
372
+ p.add_argument(
373
+ "--temperature",
374
+ type=float,
375
+ default=0.7,
376
+ help="Sampling temperature for candidate generation.",
377
+ )
378
+ p.add_argument(
379
+ "--num-candidates",
380
+ type=int,
381
+ default=5,
382
+ help="Number of candidate summaries to generate per example (N in Best-of-N).",
383
+ )
384
+ p.add_argument(
385
+ "--batch-size",
386
+ type=int,
387
+ default=16,
388
+ help="Batch size for vLLM generation.",
389
+ )
390
+ p.add_argument(
391
+ "--output-file",
392
+ type=str,
393
+ default=None,
394
+ help=(
395
+ "Optional path for the main results JSON file. "
396
+ "If not set, a timestamped name in the results directory is used."
397
+ ),
398
+ )
399
+ p.add_argument(
400
+ "--model",
401
+ type=str,
402
+ choices=["base", "finetuned", "both"],
403
+ default="both",
404
+ help=(
405
+ "Which model(s) to run: 'base' (Qwen3-4B-Instruct), "
406
+ "'finetuned' (local SFT model), or 'both' (default)."
407
+ ),
408
+ )
409
+ return p.parse_args()
410
+
411
+
412
+ def main():
413
+ args = parse_args()
414
+
415
+ os.makedirs(RESULTS_DIR, exist_ok=True)
416
+
417
+ print("Loading prompts from", args.prompt_dir)
418
+ prompts = load_prompts(args.prompt_dir)
419
+
420
+ print("Loading test data from", args.test_data)
421
+ with open(args.test_data, "r", encoding="utf-8") as f:
422
+ test_list = json.load(f)
423
+
424
+ # Run Best-of-N for the selected model(s), one at a time to save GPU memory.
425
+ finetuned_results: Dict[int, Dict[str, Any]] = {}
426
+ base_results: Dict[int, Dict[str, Any]] = {}
427
+
428
+ if args.model in ("finetuned", "both"):
429
+ finetuned_results = run_best_of_n_for_model(
430
+ model_id=args.finetuned_model_dir,
431
+ model_key="qwen3_finetuned",
432
+ test_list=test_list,
433
+ prompts=prompts,
434
+ max_new_tokens=args.max_new_tokens,
435
+ temperature=args.temperature,
436
+ num_candidates=args.num_candidates,
437
+ batch_size=args.batch_size,
438
+ source_lang=args.src_lang,
439
+ )
440
+
441
+ if args.model in ("base", "both"):
442
+ base_results = run_best_of_n_for_model(
443
+ model_id=args.base_model_id,
444
+ model_key="qwen3_base",
445
+ test_list=test_list,
446
+ prompts=prompts,
447
+ max_new_tokens=args.max_new_tokens,
448
+ temperature=args.temperature,
449
+ num_candidates=args.num_candidates,
450
+ batch_size=args.batch_size,
451
+ source_lang=args.src_lang,
452
+ )
453
+
454
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
455
+ if args.output_file:
456
+ out_path = args.output_file
457
+ base, ext = os.path.splitext(out_path)
458
+ if not ext:
459
+ out_path = base + ".json"
460
+ base = out_path.rsplit(".", 1)[0]
461
+ summary_path = base + "_summary.json"
462
+ else:
463
+ out_path = os.path.join(RESULTS_DIR, f"test_best_of_n_vllm_{timestamp}.json")
464
+ summary_path = os.path.join(
465
+ RESULTS_DIR, f"inference_best_of_n_vllm_{timestamp}.json"
466
+ )
467
+
468
+ combined_results = []
469
+ for idx, item in enumerate(test_list):
470
+ label = item.get("label")
471
+ doc_id = item.get("doc_id", idx)
472
+ gold_gen_text = item.get("gen_text", "")
473
+
474
+ entry: Dict[str, Any] = {
475
+ "doc_id": doc_id,
476
+ "label": label,
477
+ "gold_gen_text": gold_gen_text,
478
+ "predicted_label": item.get("predicted_label", ""),
479
+ "prediction_correct": item.get("prediction_correct", None),
480
+ }
481
+
482
+ if args.model in ("finetuned", "both"):
483
+ entry["qwen3_finetuned"] = finetuned_results.get(idx, {})
484
+ if args.model in ("base", "both"):
485
+ entry["qwen3_base"] = base_results.get(idx, {})
486
+
487
+ combined_results.append(entry)
488
+
489
+ with open(out_path, "w", encoding="utf-8") as f:
490
+ json.dump(combined_results, f, ensure_ascii=False, indent=2)
491
+
492
+ summary_data: Dict[str, Any] = {
493
+ "model_run": args.model,
494
+ "test_json": args.test_data,
495
+ "prompt_dir": args.prompt_dir,
496
+ "src_lang": args.src_lang,
497
+ "num_test_samples": len(test_list),
498
+ "results_file": out_path,
499
+ "timestamp": timestamp,
500
+ "max_new_tokens": args.max_new_tokens,
501
+ "temperature": args.temperature,
502
+ "num_candidates": args.num_candidates,
503
+ }
504
+ if args.model in ("finetuned", "both"):
505
+ summary_data["finetuned_model_dir"] = args.finetuned_model_dir
506
+ if args.model in ("base", "both"):
507
+ summary_data["base_model_id"] = args.base_model_id
508
+
509
+ with open(summary_path, "w", encoding="utf-8") as f:
510
+ json.dump(summary_data, f, ensure_ascii=False, indent=2)
511
+
512
+ print(f"\nResults saved to {out_path}")
513
+ print(f"Summary saved to {summary_path}")
514
+
515
+
516
+ if __name__ == "__main__":
517
+ main()
518
+
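As a quick, illustrative sanity check of the two pure helpers in best_of_n_qwen3_vllm_bn.py above (parse_best_index and strip_think_blocks), the snippet below shows their expected fallback behaviour. It is a sketch, not part of the commit: it assumes the file can be imported as a module named best_of_n_qwen3_vllm_bn, which also imports vllm and pins CUDA_VISIBLE_DEVICES at import time.

from best_of_n_qwen3_vllm_bn import parse_best_index, strip_think_blocks

# A well-formed reply: the first integer found is used as the 1-based index.
assert parse_best_index("BEST_INDEX: 3", num_candidates=5) == 3
# No integer, or an out-of-range index, falls back to candidate 1.
assert parse_best_index("I cannot decide.", num_candidates=5) == 1
assert parse_best_index("BEST_INDEX: 9", num_candidates=5) == 1
# <think>...</think> reasoning blocks are stripped before candidates are stored.
assert strip_think_blocks("<think>internal notes</think>Final summary.") == "Final summary."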
code/fine_tune_sft_dpo/dataset/bn/old/test_bn_subclaims.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f99326a350c42c5012c10e6898e892ac2962230f627f91a1da814fe8a8f79bba
3
+ size 2249546
code/fine_tune_sft_dpo/dataset/bn/test_bn_subclaims.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6f58aca4d71f46ded80cabe51e3f6b96c0774eae5d4680e46eec0cadefe121e9
3
+ size 5741053
code/fine_tune_sft_dpo/eval.sh CHANGED
@@ -1,2 +0,0 @@
1
- python /home/mshahidul/readctrl/code/fine_tune_sft_dpo/test_classifier_with_subclaim_thresholds.py \
2
- --input-file /home/mshahidul/readctrl/code/fine_tune_sft_dpo/results/en/test_best_of_n_qwen3-4B_sft.json
code/fine_tune_sft_dpo/evaluate_scores_bn.py ADDED
@@ -0,0 +1,554 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Standalone evaluation script for computing factuality, hallucination, and
4
+ classifier scores on a JSON file.
5
+
6
+ Supports four input formats:
7
+
8
+ 1. **Standard format** — a list of objects, each with:
9
+ - fulltext, summary_text, summary_subclaims, generated_text, label
10
+
11
+ 2. **Best-of-N (BON) format** — a list of objects, each with:
12
+ - doc_id, label, qwen3_base.best_summary (JSON-wrapped generated text)
13
+ Requires a separate --subclaims file to supply fulltext, summary,
14
+ summary_subclaims, and fulltext_subclaims (keyed by doc_id).
15
+
16
+ 3. **Inference format** — a list of objects, each with:
17
+ - doc_id, label, predicted_gen_text (JSON-wrapped evaluated summary),
18
+ optionally gold_gen_text
19
+ predicted_gen_text is the summary to evaluate (same JSON key-by-label
20
+ format as best_summary). Requires --subclaims for fulltext and subclaims.
21
+
22
+ 4. **Self-refine format** — a list of objects, each with:
23
+ - doc_id, label, final_summary (the generated text to evaluate),
24
+ optionally gold_gen_text, gold_summary
25
+ final_summary is the summary to evaluate (plain text or JSON-wrapped by
26
+ label). Requires --subclaims for fulltext and subclaims.
27
+
28
+ The script reuses the reward functions from reward_new_v6_bn_v4_rmv_src_cov.py:
29
+ - factuality_score : fraction of summary subclaims supported by generated_text
30
+ - hallucination_score: fraction of gen subclaims NOT supported by fulltext
31
+ - classifier_score : whether generated_text matches the target literacy level
32
+
33
+ Requires the same vLLM endpoints as the reward file:
34
+ - Support checker : VLLM_SUPPORT_CHECK_BN_API_BASE (default http://localhost:8090/v1)
35
+ - Classifier : VLLM_CLASSIFIER_BN_API_BASE (default http://localhost:8040/v1)
36
+ - Subclaim extractor: VLLM_SUBCLAIM_EXTRACTOR_BN_API_BASE (default http://localhost:8050/v1)
37
+
38
+ Usage:
39
+ # Standard format
40
+ python evaluate_scores_bn.py --input data.json [--output results.json]
41
+
42
+ # BON format with subclaims file
43
+ python evaluate_scores_bn.py --input bon_results.json --subclaims subclaims.json --output-dir evaluation/bn/
44
+
45
+ # Inference format (predicted_gen_text as evaluated summary)
46
+ python evaluate_scores_bn.py --input test_inference_vllm_qwen3-4B_base.json --subclaims subclaims.json --output results.json
47
+
48
+ # Self-refine format (final_summary as evaluated summary)
49
+ python evaluate_scores_bn.py --input test_self_refine_vllm_qwen3_4B_base.json --subclaims subclaims.json --output-dir evaluation/bn/
50
+ """
51
+
52
+ import argparse
53
+ import json
54
+ import os
55
+ import re
56
+ import sys
57
+ import time
58
+ from typing import Any, Dict, List, Optional
59
+
60
+ from tqdm import tqdm
61
+
62
+ # Import scoring utilities from the reward module (same directory).
63
+ from reward_new_v6_bn_v4_rmv_src_cov import (
64
+ _call_support_api,
65
+ _compute_classifier_reward,
66
+ _extract_subclaims_from_text,
67
+ _is_bangla_text,
68
+ _nonlinear_grounding,
69
+ compute_rewards,
70
+ )
71
+
72
+
73
+ def extract_text_from_best_summary(best_summary: str, label: str) -> str:
74
+ """Extract the raw generated text from a BON best_summary string.
75
+
76
+ The best_summary is a (possibly truncated) JSON string like:
77
+ '{"proficient_health_literacy": "...actual text..."}'
78
+ We locate the value after the label key and strip JSON wrapping.
79
+ """
80
+ key_pattern = re.compile(re.escape(f'"{label}"') + r'\s*:\s*"')
81
+ m = key_pattern.search(best_summary)
82
+ if not m:
83
+ return best_summary.strip()
84
+ text = best_summary[m.end():]
85
+ if text.endswith('"\n}'):
86
+ text = text[:-3]
87
+ elif text.endswith('"}\n'):
88
+ text = text[:-3]
89
+ elif text.endswith('"}'):
90
+ text = text[:-2]
91
+ elif text.endswith('"'):
92
+ text = text[:-1]
93
+ text = text.replace("\\n", "\n").replace('\\"', '"')
94
+ return text.strip()
95
+
96
+
97
+ def prepare_bon_items(
98
+ bon_data: List[Dict[str, Any]],
99
+ subclaims_data: List[Dict[str, Any]],
100
+ model_key: str = "qwen3_base",
101
+ ) -> List[Dict[str, Any]]:
102
+ """Merge BON results with subclaims data into the standard evaluation format."""
103
+ sc_by_docid = {}
104
+ for item in subclaims_data:
105
+ sc_by_docid[item["doc_id"]] = item
106
+
107
+ prepared = []
108
+ for item in bon_data:
109
+ doc_id = item["doc_id"]
110
+ label = item["label"]
111
+ sc = sc_by_docid.get(doc_id, {})
112
+
113
+ model_data = item.get(model_key, {})
114
+ best_summary = model_data.get("best_summary", "") or model_data.get("predicted_gen_text", "")
115
+ generated_text = extract_text_from_best_summary(best_summary, label)
116
+
117
+ prepared.append({
118
+ "doc_id": doc_id,
119
+ "label": label,
120
+ "fulltext": sc.get("fulltext", ""),
121
+ "summary_text": sc.get("summary", ""),
122
+ "summary_subclaims": sc.get("summary_subclaims", []),
123
+ "fulltext_subclaims": sc.get("fulltext_subclaims", []),
124
+ "generated_text": generated_text,
125
+ "gold_gen_text": item.get("gold_gen_text", ""),
126
+ "predicted_label": item.get("predicted_label", ""),
127
+ "prediction_correct": item.get("prediction_correct", False),
128
+ })
129
+ return prepared
130
+
131
+
132
+ def prepare_inference_items(
133
+ inference_data: List[Dict[str, Any]],
134
+ subclaims_data: List[Dict[str, Any]],
135
+ ) -> List[Dict[str, Any]]:
136
+ """Merge inference-format results (doc_id, label, predicted_gen_text) with
137
+ subclaims data into the standard evaluation format. predicted_gen_text is
138
+ the JSON-wrapped evaluated summary; the raw text is extracted using the
139
+ item's label.
140
+ """
141
+ sc_by_docid = {}
142
+ for item in subclaims_data:
143
+ sc_by_docid[item["doc_id"]] = item
144
+
145
+ prepared = []
146
+ for item in inference_data:
147
+ doc_id = item["doc_id"]
148
+ label = item["label"]
149
+ sc = sc_by_docid.get(doc_id, {})
150
+
151
+ raw_pred = item.get("predicted_gen_text", "") or ""
152
+ generated_text = extract_text_from_best_summary(raw_pred, label)
153
+
154
+ prepared.append({
155
+ "doc_id": doc_id,
156
+ "label": label,
157
+ "fulltext": sc.get("fulltext", ""),
158
+ "summary_text": sc.get("summary", ""),
159
+ "summary_subclaims": sc.get("summary_subclaims", []),
160
+ "fulltext_subclaims": sc.get("fulltext_subclaims", []),
161
+ "generated_text": generated_text,
162
+ "gold_gen_text": item.get("gold_gen_text", ""),
163
+ })
164
+ return prepared
165
+
166
+
167
+ def prepare_self_refine_items(
168
+ self_refine_data: List[Dict[str, Any]],
169
+ subclaims_data: List[Dict[str, Any]],
170
+ ) -> List[Dict[str, Any]]:
171
+ """Merge self-refine format (doc_id, label, final_summary) with subclaims
172
+ data. final_summary is the generated text to evaluate (plain text or
173
+ JSON-wrapped by label); it is extracted and used as generated_text.
174
+ """
175
+ sc_by_docid = {}
176
+ for item in subclaims_data:
177
+ sc_by_docid[item["doc_id"]] = item
178
+
179
+ prepared = []
180
+ for item in self_refine_data:
181
+ doc_id = item["doc_id"]
182
+ label = item["label"]
183
+ sc = sc_by_docid.get(doc_id, {})
184
+
185
+ raw_final = item.get("final_summary", "") or ""
186
+ generated_text = extract_text_from_best_summary(raw_final, label)
187
+
188
+ prepared.append({
189
+ "doc_id": doc_id,
190
+ "label": label,
191
+ "fulltext": sc.get("fulltext", ""),
192
+ "summary_text": sc.get("summary", ""),
193
+ "summary_subclaims": sc.get("summary_subclaims", []),
194
+ "fulltext_subclaims": sc.get("fulltext_subclaims", []),
195
+ "generated_text": generated_text,
196
+ "gold_gen_text": item.get("gold_gen_text", ""),
197
+ })
198
+ return prepared
199
+
200
+
201
+ def evaluate_single(
202
+ item: Dict[str, Any],
203
+ target_level_override: Optional[str] = None,
204
+ ) -> Dict[str, Any]:
205
+ """
206
+ Evaluate a single item and return detailed scores.
207
+ """
208
+ fulltext = item.get("fulltext", "")
209
+ summary_text = item.get("summary_text") or item.get("summary", "")
210
+ summary_subclaims = item.get("summary_subclaims", [])
211
+ generated_text = item.get("generated_text") or item.get("predicted_gen_text", "")
212
+ target_level = target_level_override or item.get("label", "")
213
+
214
+ result: Dict[str, Any] = {
215
+ "doc_id": item.get("doc_id", ""),
216
+ "target_level": target_level,
217
+ "generated_text_len": len(generated_text.strip()) if generated_text else 0,
218
+ "factuality_score": None,
219
+ "hallucination_score": None,
220
+ "classifier_score": None,
221
+ "grounding_score": None,
222
+ "factuality_supported": 0,
223
+ "total_summary_subclaims": len(summary_subclaims) if summary_subclaims else 0,
224
+ "hallucination_supported": 0,
225
+ "total_gen_segments": 0,
226
+ "skipped": False,
227
+ "skip_reason": "",
228
+ }
229
+
230
+ if not generated_text or len(generated_text.strip()) < 10:
231
+ result["skipped"] = True
232
+ result["skip_reason"] = "generated_text missing or too short (<10 chars)"
233
+ return result
234
+
235
+ # -- Factuality & Hallucination via compute_rewards --
236
+ rewards = compute_rewards(
237
+ fulltext=fulltext,
238
+ generated_text=generated_text,
239
+ target_level=target_level,
240
+ summary_subclaims=summary_subclaims,
241
+ summary_text=summary_text,
242
+ )
243
+
244
+ factuality_score = rewards["factuality_score"]
245
+ h_score = rewards["hallucination_score"]
246
+
247
+ if factuality_score is None:
248
+ factuality_score = 0.5
249
+ if h_score is None:
250
+ h_score = 0.5
251
+
252
+ grounding_score = _nonlinear_grounding(h_score)
253
+
254
+ # -- Classifier --
255
+ input_text = fulltext or ""
256
+ class_score = _compute_classifier_reward(target_level, generated_text, input_text)
257
+
258
+ result.update({
259
+ "factuality_score": round(factuality_score, 4),
260
+ "hallucination_score": round(h_score, 4),
261
+ "grounding_score": round(grounding_score, 4),
262
+ "classifier_score": round(class_score, 4),
263
+ "factuality_supported": rewards.get("factuality_supported", 0),
264
+ "total_summary_subclaims": rewards.get("total_summary_subclaims", 0),
265
+ "hallucination_supported": rewards.get("hallucination_supported", 0),
266
+ "total_gen_segments": rewards.get("total_gen_segments", 0),
267
+ })
268
+
269
+ return result
270
+
271
+
272
+ def compute_aggregate(results: List[Dict[str, Any]]) -> Dict[str, Any]:
273
+ """Compute aggregate statistics over all evaluated items."""
274
+ scored = [r for r in results if not r.get("skipped", False)]
275
+ n = len(scored)
276
+ total = len(results)
277
+ skipped = total - n
278
+
279
+ if n == 0:
280
+ return {
281
+ "total_items": total,
282
+ "scored_items": 0,
283
+ "skipped_items": skipped,
284
+ "avg_factuality_score": None,
285
+ "avg_hallucination_score": None,
286
+ "avg_grounding_score": None,
287
+ "avg_classifier_score": None,
288
+ }
289
+
290
+ def safe_avg(key):
291
+ vals = [r[key] for r in scored if r[key] is not None]
292
+ return round(sum(vals) / len(vals), 4) if vals else None
293
+
294
+ return {
295
+ "total_items": total,
296
+ "scored_items": n,
297
+ "skipped_items": skipped,
298
+ "avg_factuality_score": safe_avg("factuality_score"),
299
+ "avg_hallucination_score": safe_avg("hallucination_score"),
300
+ "avg_grounding_score": safe_avg("grounding_score"),
301
+ "avg_classifier_score": safe_avg("classifier_score"),
302
+ }
303
+
304
+
305
+ def main():
306
+ parser = argparse.ArgumentParser(
307
+ description="Evaluate factuality, hallucination, and classifier scores on a JSON file."
308
+ )
309
+ parser.add_argument(
310
+ "--input", "-i", required=True,
311
+ help="Path to input JSON file (list of objects).",
312
+ )
313
+ parser.add_argument(
314
+ "--output", "-o", default=None,
315
+ help="Path to output JSON file with per-item scores. "
316
+ "Defaults to <input_stem>_eval_results.json.",
317
+ )
318
+ parser.add_argument(
319
+ "--output-dir", default=None,
320
+ help="Directory to save output files. If set, output filename is derived "
321
+ "from input filename and placed in this directory.",
322
+ )
323
+ parser.add_argument(
324
+ "--subclaims", "-s", default=None,
325
+ help="Path to subclaims JSON file (for BON format). Provides fulltext, "
326
+ "summary, summary_subclaims, and fulltext_subclaims keyed by doc_id.",
327
+ )
328
+ parser.add_argument(
329
+ "--model-key", default="qwen3_base",
330
+ help="Key in the BON data containing candidates/best_summary (default: qwen3_base).",
331
+ )
332
+ parser.add_argument(
333
+ "--target-level", "-t", default=None,
334
+ help="Override target literacy level for all items "
335
+ "(e.g. low_health_literacy). If not set, uses each item's 'label' field.",
336
+ )
337
+ parser.add_argument(
338
+ "--support-check-url", default=None,
339
+ help="Override VLLM_SUPPORT_CHECK_BN_API_BASE.",
340
+ )
341
+ parser.add_argument(
342
+ "--classifier-url", default=None,
343
+ help="Override VLLM_CLASSIFIER_BN_API_BASE.",
344
+ )
345
+ parser.add_argument(
346
+ "--subclaim-extractor-url", default=None,
347
+ help="Override VLLM_SUBCLAIM_EXTRACTOR_BN_API_BASE.",
348
+ )
349
+ args = parser.parse_args()
350
+
351
+ if args.support_check_url:
352
+ os.environ["VLLM_SUPPORT_CHECK_BN_API_BASE"] = args.support_check_url
353
+ if args.classifier_url:
354
+ os.environ["VLLM_CLASSIFIER_BN_API_BASE"] = args.classifier_url
355
+ if args.subclaim_extractor_url:
356
+ os.environ["VLLM_SUBCLAIM_EXTRACTOR_BN_API_BASE"] = args.subclaim_extractor_url
357
+
358
+ # Load input
359
+ with open(args.input, "r", encoding="utf-8") as f:
360
+ raw_data = json.load(f)
361
+
362
+ if not isinstance(raw_data, list):
363
+ print(f"Error: Expected a JSON list, got {type(raw_data).__name__}.", file=sys.stderr)
364
+ sys.exit(1)
365
+
366
+ # Detect BON format: items have a model key (e.g. qwen3_base) with best_summary
367
+ is_bon = (
368
+ len(raw_data) > 0
369
+ and args.model_key in raw_data[0]
370
+ and "best_summary" in raw_data[0].get(args.model_key, {})
371
+ )
372
+
373
+ # Detect inference format: top-level doc_id, label, predicted_gen_text; no fulltext/summary_subclaims
374
+ is_inference = (
375
+ len(raw_data) > 0
376
+ and "doc_id" in raw_data[0]
377
+ and "label" in raw_data[0]
378
+ and "predicted_gen_text" in raw_data[0]
379
+ and raw_data[0].get("fulltext") is None
380
+ and raw_data[0].get("summary_subclaims") is None
381
+ )
382
+
383
+ # Detect self-refine format: doc_id, label, final_summary as gen text; no fulltext/summary_subclaims
384
+ is_self_refine = (
385
+ len(raw_data) > 0
386
+ and "doc_id" in raw_data[0]
387
+ and "label" in raw_data[0]
388
+ and "final_summary" in raw_data[0]
389
+ and raw_data[0].get("fulltext") is None
390
+ and raw_data[0].get("summary_subclaims") is None
391
+ )
392
+
393
+ if is_bon:
394
+ if not args.subclaims:
395
+ print("Error: BON format detected but --subclaims file not provided.", file=sys.stderr)
396
+ sys.exit(1)
397
+ with open(args.subclaims, "r", encoding="utf-8") as f:
398
+ subclaims_data = json.load(f)
399
+ print(f"BON format detected (model_key={args.model_key})")
400
+ print(f"Loaded {len(raw_data)} BON items from {args.input}")
401
+ print(f"Loaded {len(subclaims_data)} subclaims entries from {args.subclaims}")
402
+ data = prepare_bon_items(raw_data, subclaims_data, model_key=args.model_key)
403
+ print(f"Prepared {len(data)} items for evaluation")
404
+ elif is_inference:
405
+ if not args.subclaims:
406
+ print("Error: Inference format detected (predicted_gen_text) but --subclaims file not provided.", file=sys.stderr)
407
+ sys.exit(1)
408
+ with open(args.subclaims, "r", encoding="utf-8") as f:
409
+ subclaims_data = json.load(f)
410
+ print("Inference format detected (predicted_gen_text as evaluated summary)")
411
+ print(f"Loaded {len(raw_data)} inference items from {args.input}")
412
+ print(f"Loaded {len(subclaims_data)} subclaims entries from {args.subclaims}")
413
+ data = prepare_inference_items(raw_data, subclaims_data)
414
+ print(f"Prepared {len(data)} items for evaluation")
415
+ elif is_self_refine:
416
+ if not args.subclaims:
417
+ print("Error: Self-refine format detected (final_summary) but --subclaims file not provided.", file=sys.stderr)
418
+ sys.exit(1)
419
+ with open(args.subclaims, "r", encoding="utf-8") as f:
420
+ subclaims_data = json.load(f)
421
+ print("Self-refine format detected (final_summary as evaluated summary)")
422
+ print(f"Loaded {len(raw_data)} self-refine items from {args.input}")
423
+ print(f"Loaded {len(subclaims_data)} subclaims entries from {args.subclaims}")
424
+ data = prepare_self_refine_items(raw_data, subclaims_data)
425
+ print(f"Prepared {len(data)} items for evaluation")
426
+ else:
427
+ data = raw_data
428
+ print(f"Loaded {len(data)} items from {args.input}")
429
+
430
+ print(f" Support check API : {os.getenv('VLLM_SUPPORT_CHECK_BN_API_BASE', 'http://localhost:8090/v1')}")
431
+ print(f" Classifier API : {os.getenv('VLLM_CLASSIFIER_BN_API_BASE', 'http://localhost:8040/v1')}")
432
+ print(f" Subclaim extractor: {os.getenv('VLLM_SUBCLAIM_EXTRACTOR_BN_API_BASE', 'http://localhost:8050/v1')}")
433
+ if args.target_level:
434
+ print(f" Target level override: {args.target_level}")
435
+ print("-" * 60)
436
+
437
+ # Evaluate each item
438
+ results = []
439
+ start_time = time.time()
440
+ for idx, item in enumerate(tqdm(data, desc="Evaluating")):
441
+ r = evaluate_single(item, target_level_override=args.target_level)
442
+ r["index"] = idx
443
+ r["doc_id"] = item.get("doc_id", "")
444
+ results.append(r)
445
+
446
+ if (idx + 1) % 10 == 0 or idx == 0:
447
+ partial_agg = compute_aggregate(results)
448
+ tqdm.write(
449
+ f" [{idx+1}/{len(data)}] "
450
+ f"fact={partial_agg['avg_factuality_score']} "
451
+ f"hallu={partial_agg['avg_hallucination_score']} "
452
+ f"cls={partial_agg['avg_classifier_score']}"
453
+ )
454
+
455
+ elapsed = time.time() - start_time
456
+
457
+ # --- Validation: all items must be evaluated with non-null scores ---
458
+ expected_count = len(data)
459
+ skipped_items = [r for r in results if r.get("skipped", False)]
460
+ null_score_items = []
461
+ for r in results:
462
+ if r.get("skipped", False):
463
+ continue
464
+ for key in ("factuality_score", "hallucination_score", "classifier_score", "grounding_score"):
465
+ if r.get(key) is None:
466
+ null_score_items.append((r.get("index"), r.get("doc_id"), key))
467
+
468
+ has_errors = False
469
+ if skipped_items:
470
+ has_errors = True
471
+ print(f"\nERROR: {len(skipped_items)} out of {expected_count} items were skipped:", file=sys.stderr)
472
+ for r in skipped_items:
473
+ print(f" index={r.get('index')} doc_id={r.get('doc_id')} reason={r.get('skip_reason')}", file=sys.stderr)
474
+
475
+ if null_score_items:
476
+ has_errors = True
477
+ print(f"\nERROR: {len(null_score_items)} null score(s) found:", file=sys.stderr)
478
+ for idx, doc_id, key in null_score_items:
479
+ print(f" index={idx} doc_id={doc_id} null_field={key}", file=sys.stderr)
480
+
481
+ if len(results) != expected_count:
482
+ has_errors = True
483
+ print(f"\nERROR: Expected {expected_count} results but got {len(results)}.", file=sys.stderr)
484
+
485
+ if has_errors:
486
+ print(f"\nAborting: will NOT save results. All {expected_count} items must be fully evaluated with non-null scores.", file=sys.stderr)
487
+ sys.exit(1)
488
+
489
+ # Aggregate
490
+ agg = compute_aggregate(results)
491
+
492
+ # Per-label aggregates
493
+ label_groups: Dict[str, List[Dict[str, Any]]] = {}
494
+ for r in results:
495
+ lbl = r.get("target_level", "unknown")
496
+ label_groups.setdefault(lbl, []).append(r)
497
+ per_label_agg = {lbl: compute_aggregate(items) for lbl, items in sorted(label_groups.items())}
498
+
499
+ # Output path
500
+ if args.output:
501
+ out_path = args.output
502
+ elif args.output_dir:
503
+ os.makedirs(args.output_dir, exist_ok=True)
504
+ stem = os.path.splitext(os.path.basename(args.input))[0]
505
+ out_path = os.path.join(args.output_dir, f"{stem}_eval_results.json")
506
+ else:
507
+ stem = os.path.splitext(os.path.basename(args.input))[0]
508
+ out_dir = os.path.dirname(args.input) or "."
509
+ out_path = os.path.join(out_dir, f"{stem}_eval_results.json")
510
+
511
+ os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
512
+
513
+ output = {
514
+ "input_file": os.path.abspath(args.input),
515
+ "subclaims_file": os.path.abspath(args.subclaims) if args.subclaims else None,
516
+ "model_key": args.model_key if is_bon else None,
517
+ "inference_format": is_inference if not is_bon else False,
518
+ "self_refine_format": is_self_refine if not is_bon and not is_inference else False,
519
+ "target_level_override": args.target_level,
520
+ "elapsed_seconds": round(elapsed, 2),
521
+ "aggregate": agg,
522
+ "per_label_aggregate": per_label_agg,
523
+ "per_item": results,
524
+ }
525
+
526
+ with open(out_path, "w", encoding="utf-8") as f:
527
+ json.dump(output, f, indent=2, ensure_ascii=False)
528
+
529
+ # Print summary
530
+ print("\n" + "=" * 60)
531
+ print("EVALUATION SUMMARY")
532
+ print("=" * 60)
533
+ print(f" Total items : {agg['total_items']}")
534
+ print(f" Scored items : {agg['scored_items']}")
535
+ print(f" Skipped items : {agg['skipped_items']}")
536
+ print(f" Elapsed time : {round(elapsed, 1)}s")
537
+ print("-" * 60)
538
+ print(f" Avg Factuality Score : {agg['avg_factuality_score']}")
539
+ print(f" Avg Hallucination Score: {agg['avg_hallucination_score']}")
540
+ print(f" Avg Grounding Score : {agg['avg_grounding_score']}")
541
+ print(f" Avg Classifier Score : {agg['avg_classifier_score']}")
542
+ print("-" * 60)
543
+ for lbl, la in per_label_agg.items():
544
+ print(f" [{lbl}] items={la['scored_items']}"
545
+ f" fact={la['avg_factuality_score']}"
546
+ f" hallu={la['avg_hallucination_score']}"
547
+ f" cls={la['avg_classifier_score']}")
548
+ print("-" * 60)
549
+ print(f" Results saved to: {out_path}")
550
+ print("=" * 60)
551
+
552
+
553
+ if __name__ == "__main__":
554
+ main()
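To make the best_summary unwrapping in evaluate_scores_bn.py above concrete, here is a small illustrative example of extract_text_from_best_summary on an invented JSON-wrapped candidate. The label and text are made up, and the import assumes the script's directory (including reward_new_v6_bn_v4_rmv_src_cov.py) is on PYTHONPATH; it is a sketch rather than a documented API.

from evaluate_scores_bn import extract_text_from_best_summary

wrapped = '{"low_health_literacy": "Take the medicine twice a day.\\nCall the doctor if the fever returns."}'
text = extract_text_from_best_summary(wrapped, "low_health_literacy")
# The label key and the trailing JSON quoting are stripped, and escaped
# newlines are converted back into real line breaks.
assert text == "Take the medicine twice a day.\nCall the doctor if the fever returns."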
code/fine_tune_sft_dpo/evaluate_scores_bn_vllm.py ADDED
@@ -0,0 +1,611 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Standalone evaluation script for computing factuality, hallucination, and
4
+ classifier scores on a JSON/JSONL file.
5
+
6
+ Supports these input formats:
7
+
8
+ 1. **Standard format** — a list of objects, each with:
9
+ - fulltext, summary_text, summary_subclaims, generated_text, label
10
+
11
+ 2. **Best-of-N (BON) format** — a list of objects, each with:
12
+ - doc_id, label, qwen3_base.best_summary (JSON-wrapped generated text)
13
+ Requires a separate --subclaims file to supply fulltext, summary,
14
+ summary_subclaims, and fulltext_subclaims (keyed by doc_id).
15
+
16
+ 3. **Inference format** — a list of objects, each with:
17
+ - doc_id, label, predicted_gen_text (JSON-wrapped evaluated summary),
18
+ optionally gold_gen_text
19
+ predicted_gen_text is the summary to evaluate (same JSON key-by-label
20
+ format as best_summary). Requires --subclaims for fulltext and subclaims.
21
+
22
+ 4. **Self-refine format** — a list of objects, each with:
23
+ - doc_id, label, final_summary (the generated text to evaluate),
24
+ optionally gold_gen_text, gold_summary
25
+ final_summary is the summary to evaluate (plain text or JSON-wrapped by
26
+ label). Requires --subclaims for fulltext and subclaims.
27
+
28
+ 5. **RL inference format** (JSONL) — one JSON object per line, each with:
29
+ - doc_id, gold_label, input_text, summary_text, subclaims, generated_text
30
+ Self-contained: no --subclaims file needed. Field mapping:
31
+ gold_label -> label, input_text -> fulltext, subclaims -> summary_subclaims
32
+
33
+ The script reuses the reward functions from reward_new_v6_bn_v4_rmv_src_cov.py:
34
+ - factuality_score : fraction of summary subclaims supported by generated_text
35
+ - hallucination_score: fraction of gen subclaims NOT supported by fulltext
36
+ - classifier_score : whether generated_text matches the target literacy level
37
+
38
+ Requires the same vLLM endpoints as the reward file:
39
+ - Support checker : VLLM_SUPPORT_CHECK_BN_API_BASE (default http://localhost:8090/v1)
40
+ - Classifier : VLLM_CLASSIFIER_BN_API_BASE (default http://localhost:8040/v1)
41
+ - Subclaim extractor: VLLM_SUBCLAIM_EXTRACTOR_BN_API_BASE (default http://localhost:8050/v1)
42
+
43
+ Usage:
44
+ # Standard format
45
+ python evaluate_scores_bn_vllm.py --input data.json [--output results.json]
46
+
47
+ # BON format with subclaims file
48
+ python evaluate_scores_bn_vllm.py --input bon_results.json --subclaims subclaims.json --output-dir evaluation/bn/
49
+
50
+ # Inference format (predicted_gen_text as evaluated summary)
51
+ python evaluate_scores_bn_vllm.py --input test_inference_vllm_qwen3-4B_base.json --subclaims subclaims.json --output results.json
52
+
53
+ # Self-refine format (final_summary as evaluated summary)
54
+ python evaluate_scores_bn_vllm.py --input test_self_refine_vllm_qwen3_4B_base.json --subclaims subclaims.json --output-dir evaluation/bn/
55
+
56
+ # RL inference format (JSONL, self-contained)
57
+ python evaluate_scores_bn_vllm.py --input bn_200.jsonl --output-dir evaluation/bn/
58
+ """
59
+
60
+ import argparse
61
+ import json
62
+ import os
63
+ import re
64
+ import sys
65
+ import time
66
+ from typing import Any, Dict, List, Optional
67
+
68
+ from tqdm import tqdm
69
+
70
+ # Import scoring utilities from the reward module (same directory).
71
+ from reward_new_v6_bn_v4_rmv_src_cov import (
72
+ _call_support_api,
73
+ _compute_classifier_reward,
74
+ _extract_subclaims_from_text,
75
+ _is_bangla_text,
76
+ _nonlinear_grounding,
77
+ compute_rewards,
78
+ )
79
+
80
+
81
+ def extract_text_from_best_summary(best_summary: str, label: str) -> str:
82
+ """Extract the raw generated text from a BON best_summary string.
83
+
84
+ The best_summary is a (possibly truncated) JSON string like:
85
+ '{"proficient_health_literacy": "...actual text..."}'
86
+ We locate the value after the label key and strip JSON wrapping.
87
+ """
88
+ key_pattern = re.compile(re.escape(f'"{label}"') + r'\s*:\s*"')
89
+ m = key_pattern.search(best_summary)
90
+ if not m:
91
+ return best_summary.strip()
92
+ text = best_summary[m.end():]
93
+ if text.endswith('"\n}'):
94
+ text = text[:-3]
95
+ elif text.endswith('"}\n'):
96
+ text = text[:-3]
97
+ elif text.endswith('"}'):
98
+ text = text[:-2]
99
+ elif text.endswith('"'):
100
+ text = text[:-1]
101
+ text = text.replace("\\n", "\n").replace('\\"', '"')
102
+ return text.strip()
103
+
104
+
105
+ def prepare_bon_items(
106
+ bon_data: List[Dict[str, Any]],
107
+ subclaims_data: List[Dict[str, Any]],
108
+ model_key: str = "qwen3_base",
109
+ ) -> List[Dict[str, Any]]:
110
+ """Merge BON results with subclaims data into the standard evaluation format."""
111
+ sc_by_docid = {}
112
+ for item in subclaims_data:
113
+ sc_by_docid[item["doc_id"]] = item
114
+
115
+ prepared = []
116
+ for item in bon_data:
117
+ doc_id = item["doc_id"]
118
+ label = item["label"]
119
+ sc = sc_by_docid.get(doc_id, {})
120
+
121
+ model_data = item.get(model_key, {})
122
+ best_summary = model_data.get("best_summary", "") or model_data.get("predicted_gen_text", "")
123
+ generated_text = extract_text_from_best_summary(best_summary, label)
124
+
125
+ prepared.append({
126
+ "doc_id": doc_id,
127
+ "label": label,
128
+ "fulltext": sc.get("fulltext", ""),
129
+ "summary_text": sc.get("summary", ""),
130
+ "summary_subclaims": sc.get("summary_subclaims", []),
131
+ "fulltext_subclaims": sc.get("fulltext_subclaims", []),
132
+ "generated_text": generated_text,
133
+ "gold_gen_text": item.get("gold_gen_text", ""),
134
+ "predicted_label": item.get("predicted_label", ""),
135
+ "prediction_correct": item.get("prediction_correct", False),
136
+ })
137
+ return prepared
138
+
139
+
140
+ def prepare_inference_items(
141
+ inference_data: List[Dict[str, Any]],
142
+ subclaims_data: List[Dict[str, Any]],
143
+ ) -> List[Dict[str, Any]]:
144
+ """Merge inference-format results (doc_id, label, predicted_gen_text) with
145
+ subclaims data into the standard evaluation format. predicted_gen_text is
146
+ the JSON-wrapped evaluated summary; the raw text is extracted using the
147
+ item's label.
148
+ """
149
+ sc_by_docid = {}
150
+ for item in subclaims_data:
151
+ sc_by_docid[item["doc_id"]] = item
152
+
153
+ prepared = []
154
+ for item in inference_data:
155
+ doc_id = item["doc_id"]
156
+ label = item["label"]
157
+ sc = sc_by_docid.get(doc_id, {})
158
+
159
+ raw_pred = item.get("predicted_gen_text", "") or ""
160
+ generated_text = extract_text_from_best_summary(raw_pred, label)
161
+
162
+ prepared.append({
163
+ "doc_id": doc_id,
164
+ "label": label,
165
+ "fulltext": sc.get("fulltext", ""),
166
+ "summary_text": sc.get("summary", ""),
167
+ "summary_subclaims": sc.get("summary_subclaims", []),
168
+ "fulltext_subclaims": sc.get("fulltext_subclaims", []),
169
+ "generated_text": generated_text,
170
+ "gold_gen_text": item.get("gold_gen_text", ""),
171
+ })
172
+ return prepared
173
+
174
+
175
+ def prepare_self_refine_items(
176
+ self_refine_data: List[Dict[str, Any]],
177
+ subclaims_data: List[Dict[str, Any]],
178
+ ) -> List[Dict[str, Any]]:
179
+ """Merge self-refine format (doc_id, label, final_summary) with subclaims
180
+ data. final_summary is the generated text to evaluate (plain text or
181
+ JSON-wrapped by label); it is extracted and used as generated_text.
182
+ """
183
+ sc_by_docid = {}
184
+ for item in subclaims_data:
185
+ sc_by_docid[item["doc_id"]] = item
186
+
187
+ prepared = []
188
+ for item in self_refine_data:
189
+ doc_id = item["doc_id"]
190
+ label = item["label"]
191
+ sc = sc_by_docid.get(doc_id, {})
192
+
193
+ raw_final = item.get("final_summary", "") or ""
194
+ generated_text = extract_text_from_best_summary(raw_final, label)
195
+
196
+ prepared.append({
197
+ "doc_id": doc_id,
198
+ "label": label,
199
+ "fulltext": sc.get("fulltext", ""),
200
+ "summary_text": sc.get("summary", ""),
201
+ "summary_subclaims": sc.get("summary_subclaims", []),
202
+ "fulltext_subclaims": sc.get("fulltext_subclaims", []),
203
+ "generated_text": generated_text,
204
+ "gold_gen_text": item.get("gold_gen_text", ""),
205
+ })
206
+ return prepared
207
+
208
+
209
+ def prepare_rl_inference_items(
210
+ rl_data: List[Dict[str, Any]],
211
+ ) -> List[Dict[str, Any]]:
212
+ """Convert RL inference JSONL items into the standard evaluation format.
213
+
214
+ Field mapping:
215
+ gold_label -> label
216
+ input_text -> fulltext
217
+ subclaims -> summary_subclaims
218
+ summary_text -> summary_text
219
+ generated_text -> generated_text (plain text, used as-is)
220
+ """
221
+ prepared = []
222
+ for item in rl_data:
223
+ prepared.append({
224
+ "doc_id": item.get("doc_id", ""),
225
+ "label": item.get("gold_label", ""),
226
+ "fulltext": item.get("input_text", ""),
227
+ "summary_text": item.get("summary_text", ""),
228
+ "summary_subclaims": item.get("subclaims", []),
229
+ "generated_text": item.get("generated_text", ""),
230
+ })
231
+ return prepared
232
+
233
+
234
+ def evaluate_single(
235
+ item: Dict[str, Any],
236
+ target_level_override: Optional[str] = None,
237
+ ) -> Dict[str, Any]:
238
+ """
239
+ Evaluate a single item and return detailed scores.
240
+ """
241
+ fulltext = item.get("fulltext", "")
242
+ summary_text = item.get("summary_text") or item.get("summary", "")
243
+ summary_subclaims = item.get("summary_subclaims", [])
244
+ generated_text = item.get("generated_text") or item.get("predicted_gen_text", "")
245
+ target_level = target_level_override or item.get("label", "")
246
+
247
+ result: Dict[str, Any] = {
248
+ "doc_id": item.get("doc_id", ""),
249
+ "target_level": target_level,
250
+ "generated_text_len": len(generated_text.strip()) if generated_text else 0,
251
+ "factuality_score": None,
252
+ "hallucination_score": None,
253
+ "classifier_score": None,
254
+ "grounding_score": None,
255
+ "factuality_supported": 0,
256
+ "total_summary_subclaims": len(summary_subclaims) if summary_subclaims else 0,
257
+ "hallucination_supported": 0,
258
+ "total_gen_segments": 0,
259
+ "skipped": False,
260
+ "skip_reason": "",
261
+ }
262
+
263
+ if not generated_text or len(generated_text.strip()) < 10:
264
+ result["skipped"] = True
265
+ result["skip_reason"] = "generated_text missing or too short (<10 chars)"
266
+ return result
267
+
268
+ # -- Factuality & Hallucination via compute_rewards --
269
+ rewards = compute_rewards(
270
+ fulltext=fulltext,
271
+ generated_text=generated_text,
272
+ target_level=target_level,
273
+ summary_subclaims=summary_subclaims,
274
+ summary_text=summary_text,
275
+ )
276
+
277
+ factuality_score = rewards["factuality_score"]
278
+ h_score = rewards["hallucination_score"]
279
+
280
+ if factuality_score is None:
281
+ factuality_score = 0.5
282
+ if h_score is None:
283
+ h_score = 0.5
284
+
285
+ grounding_score = _nonlinear_grounding(h_score)
286
+
287
+ # -- Classifier --
288
+ input_text = fulltext or ""
289
+ class_score = _compute_classifier_reward(target_level, generated_text, input_text)
290
+
291
+ result.update({
292
+ "factuality_score": round(factuality_score, 4),
293
+ "hallucination_score": round(h_score, 4),
294
+ "grounding_score": round(grounding_score, 4),
295
+ "classifier_score": round(class_score, 4),
296
+ "factuality_supported": rewards.get("factuality_supported", 0),
297
+ "total_summary_subclaims": rewards.get("total_summary_subclaims", 0),
298
+ "hallucination_supported": rewards.get("hallucination_supported", 0),
299
+ "total_gen_segments": rewards.get("total_gen_segments", 0),
300
+ })
301
+
302
+ return result
303
+
304
+
305
+ def compute_aggregate(results: List[Dict[str, Any]]) -> Dict[str, Any]:
306
+ """Compute aggregate statistics over all evaluated items."""
307
+ scored = [r for r in results if not r.get("skipped", False)]
308
+ n = len(scored)
309
+ total = len(results)
310
+ skipped = total - n
311
+
312
+ if n == 0:
313
+ return {
314
+ "total_items": total,
315
+ "scored_items": 0,
316
+ "skipped_items": skipped,
317
+ "avg_factuality_score": None,
318
+ "avg_hallucination_score": None,
319
+ "avg_grounding_score": None,
320
+ "avg_classifier_score": None,
321
+ }
322
+
323
+ def safe_avg(key):
324
+ vals = [r[key] for r in scored if r[key] is not None]
325
+ return round(sum(vals) / len(vals), 4) if vals else None
326
+
327
+ return {
328
+ "total_items": total,
329
+ "scored_items": n,
330
+ "skipped_items": skipped,
331
+ "avg_factuality_score": safe_avg("factuality_score"),
332
+ "avg_hallucination_score": safe_avg("hallucination_score"),
333
+ "avg_grounding_score": safe_avg("grounding_score"),
334
+ "avg_classifier_score": safe_avg("classifier_score"),
335
+ }
336
+
337
+
338
+ def main():
339
+ parser = argparse.ArgumentParser(
340
+ description="Evaluate factuality, hallucination, and classifier scores on a JSON file."
341
+ )
342
+ parser.add_argument(
343
+ "--input", "-i", required=True,
344
+ help="Path to input JSON file (list of objects).",
345
+ )
346
+ parser.add_argument(
347
+ "--output", "-o", default=None,
348
+ help="Path to output JSON file with per-item scores. "
349
+ "Defaults to <input_stem>_eval_results.json.",
350
+ )
351
+ parser.add_argument(
352
+ "--output-dir", default=None,
353
+ help="Directory to save output files. If set, output filename is derived "
354
+ "from input filename and placed in this directory.",
355
+ )
356
+ parser.add_argument(
357
+ "--subclaims", "-s", default=None,
358
+ help="Path to subclaims JSON file (for BON format). Provides fulltext, "
359
+ "summary, summary_subclaims, and fulltext_subclaims keyed by doc_id.",
360
+ )
361
+ parser.add_argument(
362
+ "--model-key", default="qwen3_base",
363
+ help="Key in the BON data containing candidates/best_summary (default: qwen3_base).",
364
+ )
365
+ parser.add_argument(
366
+ "--target-level", "-t", default=None,
367
+ help="Override target literacy level for all items "
368
+ "(e.g. low_health_literacy). If not set, uses each item's 'label' field.",
369
+ )
370
+ parser.add_argument(
371
+ "--support-check-url", default=None,
372
+ help="Override VLLM_SUPPORT_CHECK_BN_API_BASE.",
373
+ )
374
+ parser.add_argument(
375
+ "--classifier-url", default=None,
376
+ help="Override VLLM_CLASSIFIER_BN_API_BASE.",
377
+ )
378
+ parser.add_argument(
379
+ "--subclaim-extractor-url", default=None,
380
+ help="Override VLLM_SUBCLAIM_EXTRACTOR_BN_API_BASE.",
381
+ )
382
+ args = parser.parse_args()
383
+
384
+ if args.support_check_url:
385
+ os.environ["VLLM_SUPPORT_CHECK_BN_API_BASE"] = args.support_check_url
386
+ if args.classifier_url:
387
+ os.environ["VLLM_CLASSIFIER_BN_API_BASE"] = args.classifier_url
388
+ if args.subclaim_extractor_url:
389
+ os.environ["VLLM_SUBCLAIM_EXTRACTOR_BN_API_BASE"] = args.subclaim_extractor_url
390
+
391
+ # Load input (JSON list or JSONL)
392
+ if args.input.endswith(".jsonl"):
393
+ raw_data = []
394
+ with open(args.input, "r", encoding="utf-8") as f:
395
+ for line in f:
396
+ line = line.strip()
397
+ if line:
398
+ raw_data.append(json.loads(line))
399
+ else:
400
+ with open(args.input, "r", encoding="utf-8") as f:
401
+ raw_data = json.load(f)
402
+
403
+ if not isinstance(raw_data, list):
404
+ print(f"Error: Expected a JSON list, got {type(raw_data).__name__}.", file=sys.stderr)
405
+ sys.exit(1)
406
+
407
+ # Detect BON format: items have a model key (e.g. qwen3_base) with best_summary
408
+ is_bon = (
409
+ len(raw_data) > 0
410
+ and args.model_key in raw_data[0]
411
+ and "best_summary" in raw_data[0].get(args.model_key, {})
412
+ )
413
+
414
+ # Detect inference format: top-level doc_id, label, predicted_gen_text; no fulltext/summary_subclaims
415
+ is_inference = (
416
+ len(raw_data) > 0
417
+ and "doc_id" in raw_data[0]
418
+ and "label" in raw_data[0]
419
+ and "predicted_gen_text" in raw_data[0]
420
+ and raw_data[0].get("fulltext") is None
421
+ and raw_data[0].get("summary_subclaims") is None
422
+ )
423
+
424
+ # Detect self-refine format: doc_id, label, final_summary as gen text; no fulltext/summary_subclaims
425
+ is_self_refine = (
426
+ len(raw_data) > 0
427
+ and "doc_id" in raw_data[0]
428
+ and "label" in raw_data[0]
429
+ and "final_summary" in raw_data[0]
430
+ and raw_data[0].get("fulltext") is None
431
+ and raw_data[0].get("summary_subclaims") is None
432
+ )
433
+
434
+ # Detect RL inference format: gold_label, input_text, subclaims, generated_text
435
+ is_rl_inference = (
436
+ len(raw_data) > 0
437
+ and "doc_id" in raw_data[0]
438
+ and "gold_label" in raw_data[0]
439
+ and "input_text" in raw_data[0]
440
+ and "generated_text" in raw_data[0]
441
+ and "subclaims" in raw_data[0]
442
+ )
443
+
444
+ if is_rl_inference:
445
+ print("RL inference format detected (gold_label, input_text, subclaims, generated_text)")
446
+ print(f"Loaded {len(raw_data)} RL inference items from {args.input}")
447
+ data = prepare_rl_inference_items(raw_data)
448
+ print(f"Prepared {len(data)} items for evaluation")
449
+ elif is_bon:
450
+ if not args.subclaims:
451
+ print("Error: BON format detected but --subclaims file not provided.", file=sys.stderr)
452
+ sys.exit(1)
453
+ with open(args.subclaims, "r", encoding="utf-8") as f:
454
+ subclaims_data = json.load(f)
455
+ print(f"BON format detected (model_key={args.model_key})")
456
+ print(f"Loaded {len(raw_data)} BON items from {args.input}")
457
+ print(f"Loaded {len(subclaims_data)} subclaims entries from {args.subclaims}")
458
+ data = prepare_bon_items(raw_data, subclaims_data, model_key=args.model_key)
459
+ print(f"Prepared {len(data)} items for evaluation")
460
+ elif is_inference:
461
+ if not args.subclaims:
462
+ print("Error: Inference format detected (predicted_gen_text) but --subclaims file not provided.", file=sys.stderr)
463
+ sys.exit(1)
464
+ with open(args.subclaims, "r", encoding="utf-8") as f:
465
+ subclaims_data = json.load(f)
466
+ print("Inference format detected (predicted_gen_text as evaluated summary)")
467
+ print(f"Loaded {len(raw_data)} inference items from {args.input}")
468
+ print(f"Loaded {len(subclaims_data)} subclaims entries from {args.subclaims}")
469
+ data = prepare_inference_items(raw_data, subclaims_data)
470
+ print(f"Prepared {len(data)} items for evaluation")
471
+ elif is_self_refine:
472
+ if not args.subclaims:
473
+ print("Error: Self-refine format detected (final_summary) but --subclaims file not provided.", file=sys.stderr)
474
+ sys.exit(1)
475
+ with open(args.subclaims, "r", encoding="utf-8") as f:
476
+ subclaims_data = json.load(f)
477
+ print("Self-refine format detected (final_summary as evaluated summary)")
478
+ print(f"Loaded {len(raw_data)} self-refine items from {args.input}")
479
+ print(f"Loaded {len(subclaims_data)} subclaims entries from {args.subclaims}")
480
+ data = prepare_self_refine_items(raw_data, subclaims_data)
481
+ print(f"Prepared {len(data)} items for evaluation")
482
+ else:
483
+ data = raw_data
484
+ print(f"Loaded {len(data)} items from {args.input}")
485
+
486
+ print(f" Support check API : {os.getenv('VLLM_SUPPORT_CHECK_BN_API_BASE', 'http://localhost:8090/v1')}")
487
+ print(f" Classifier API : {os.getenv('VLLM_CLASSIFIER_BN_API_BASE', 'http://localhost:8040/v1')}")
488
+ print(f" Subclaim extractor: {os.getenv('VLLM_SUBCLAIM_EXTRACTOR_BN_API_BASE', 'http://localhost:8050/v1')}")
489
+ if args.target_level:
490
+ print(f" Target level override: {args.target_level}")
491
+ print("-" * 60)
492
+
493
+ # Evaluate each item
494
+ results = []
495
+ start_time = time.time()
496
+ for idx, item in enumerate(tqdm(data, desc="Evaluating")):
497
+ r = evaluate_single(item, target_level_override=args.target_level)
498
+ r["index"] = idx
499
+ r["doc_id"] = item.get("doc_id", "")
500
+ results.append(r)
501
+
502
+ if (idx + 1) % 10 == 0 or idx == 0:
503
+ partial_agg = compute_aggregate(results)
504
+ tqdm.write(
505
+ f" [{idx+1}/{len(data)}] "
506
+ f"fact={partial_agg['avg_factuality_score']} "
507
+ f"hallu={partial_agg['avg_hallucination_score']} "
508
+ f"cls={partial_agg['avg_classifier_score']}"
509
+ )
510
+
511
+ elapsed = time.time() - start_time
512
+
513
+ # --- Validation: all items must be evaluated with non-null scores ---
514
+ expected_count = len(data)
515
+ skipped_items = [r for r in results if r.get("skipped", False)]
516
+ null_score_items = []
517
+ for r in results:
518
+ if r.get("skipped", False):
519
+ continue
520
+ for key in ("factuality_score", "hallucination_score", "classifier_score", "grounding_score"):
521
+ if r.get(key) is None:
522
+ null_score_items.append((r.get("index"), r.get("doc_id"), key))
523
+
524
+ has_errors = False
525
+ if skipped_items:
526
+ has_errors = True
527
+ print(f"\nERROR: {len(skipped_items)} out of {expected_count} items were skipped:", file=sys.stderr)
528
+ for r in skipped_items:
529
+ print(f" index={r.get('index')} doc_id={r.get('doc_id')} reason={r.get('skip_reason')}", file=sys.stderr)
530
+
531
+ if null_score_items:
532
+ has_errors = True
533
+ print(f"\nERROR: {len(null_score_items)} null score(s) found:", file=sys.stderr)
534
+ for idx, doc_id, key in null_score_items:
535
+ print(f" index={idx} doc_id={doc_id} null_field={key}", file=sys.stderr)
536
+
537
+ if len(results) != expected_count:
538
+ has_errors = True
539
+ print(f"\nERROR: Expected {expected_count} results but got {len(results)}.", file=sys.stderr)
540
+
541
+ if has_errors:
542
+ print(f"\nAborting: will NOT save results. All {expected_count} items must be fully evaluated with non-null scores.", file=sys.stderr)
543
+ sys.exit(1)
544
+
545
+ # Aggregate
546
+ agg = compute_aggregate(results)
547
+
548
+ # Per-label aggregates
549
+ label_groups: Dict[str, List[Dict[str, Any]]] = {}
550
+ for r in results:
551
+ lbl = r.get("target_level", "unknown")
552
+ label_groups.setdefault(lbl, []).append(r)
553
+ per_label_agg = {lbl: compute_aggregate(items) for lbl, items in sorted(label_groups.items())}
554
+
555
+ # Output path
556
+ if args.output:
557
+ out_path = args.output
558
+ elif args.output_dir:
559
+ os.makedirs(args.output_dir, exist_ok=True)
560
+ stem = os.path.splitext(os.path.basename(args.input))[0]
561
+ out_path = os.path.join(args.output_dir, f"{stem}_eval_results.json")
562
+ else:
563
+ stem = os.path.splitext(os.path.basename(args.input))[0]
564
+ out_dir = os.path.dirname(args.input) or "."
565
+ out_path = os.path.join(out_dir, f"{stem}_eval_results.json")
566
+
567
+ os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
568
+
569
+ output = {
570
+ "input_file": os.path.abspath(args.input),
571
+ "subclaims_file": os.path.abspath(args.subclaims) if args.subclaims else None,
572
+ "model_key": args.model_key if is_bon else None,
573
+ "inference_format": is_inference if not is_bon else False,
574
+ "self_refine_format": is_self_refine if not is_bon and not is_inference else False,
575
+ "rl_inference_format": is_rl_inference,
576
+ "target_level_override": args.target_level,
577
+ "elapsed_seconds": round(elapsed, 2),
578
+ "aggregate": agg,
579
+ "per_label_aggregate": per_label_agg,
580
+ "per_item": results,
581
+ }
582
+
583
+ with open(out_path, "w", encoding="utf-8") as f:
584
+ json.dump(output, f, indent=2, ensure_ascii=False)
585
+
586
+ # Print summary
587
+ print("\n" + "=" * 60)
588
+ print("EVALUATION SUMMARY")
589
+ print("=" * 60)
590
+ print(f" Total items : {agg['total_items']}")
591
+ print(f" Scored items : {agg['scored_items']}")
592
+ print(f" Skipped items : {agg['skipped_items']}")
593
+ print(f" Elapsed time : {round(elapsed, 1)}s")
594
+ print("-" * 60)
595
+ print(f" Avg Factuality Score : {agg['avg_factuality_score']}")
596
+ print(f" Avg Hallucination Score: {agg['avg_hallucination_score']}")
597
+ print(f" Avg Grounding Score : {agg['avg_grounding_score']}")
598
+ print(f" Avg Classifier Score : {agg['avg_classifier_score']}")
599
+ print("-" * 60)
600
+ for lbl, la in per_label_agg.items():
601
+ print(f" [{lbl}] items={la['scored_items']}"
602
+ f" fact={la['avg_factuality_score']}"
603
+ f" hallu={la['avg_hallucination_score']}"
604
+ f" cls={la['avg_classifier_score']}")
605
+ print("-" * 60)
606
+ print(f" Results saved to: {out_path}")
607
+ print("=" * 60)
608
+
609
+
610
+ if __name__ == "__main__":
611
+ main()
code/fine_tune_sft_dpo/evaluate_scores_bn_vllm_rl.py ADDED
@@ -0,0 +1,567 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Standalone evaluation script for computing factuality, hallucination, and
4
+ classifier scores on a JSON file.
5
+
6
+ Supports four input formats:
7
+
8
+ 1. **Standard format** — a list of objects, each with:
9
+ - fulltext, summary_text, summary_subclaims, generated_text, label
10
+
11
+ 2. **Best-of-N (BON) format** — a list of objects, each with:
12
+ - doc_id, label, qwen3_base.best_summary (JSON-wrapped generated text)
13
+ Requires a separate --subclaims file to supply fulltext, summary,
14
+ summary_subclaims, and fulltext_subclaims (keyed by doc_id).
15
+
16
+ 3. **Inference format** — a list of objects, each with:
17
+ - doc_id, label, predicted_gen_text (JSON-wrapped evaluated summary),
18
+ optionally gold_gen_text
19
+ predicted_gen_text is the summary to evaluate (same JSON key-by-label
20
+ format as best_summary). Requires --subclaims for fulltext and subclaims.
21
+
22
+ 4. **Self-refine format** — a list of objects, each with:
23
+ - doc_id, label, final_summary (the generated text to evaluate),
24
+ optionally gold_gen_text, gold_summary
25
+ final_summary is the summary to evaluate (plain text or JSON-wrapped by
26
+ label). Requires --subclaims for fulltext and subclaims.
27
+
28
+ The script reuses the reward functions from reward_new_v6_bn_v4_rmv_src_cov.py:
29
+ - factuality_score : fraction of summary subclaims supported by generated_text
30
+ - hallucination_score: fraction of gen subclaims NOT supported by fulltext
31
+ - classifier_score : whether generated_text matches the target literacy level
32
+
33
+ Requires the same vLLM endpoints as the reward file:
34
+ - Support checker : VLLM_SUPPORT_CHECK_BN_API_BASE (default http://localhost:8090/v1)
35
+ - Classifier : VLLM_CLASSIFIER_BN_API_BASE (default http://localhost:8040/v1)
36
+ - Subclaim extractor: VLLM_SUBCLAIM_EXTRACTOR_BN_API_BASE (default http://localhost:8050/v1)
37
+
38
+ Usage:
39
+ # Standard format
40
+ python evaluate_scores_bn_vllm_rl.py --input data.json [--output results.json]
41
+
42
+ # BON format with subclaims file
43
+ python evaluate_scores_bn_vllm_rl.py --input bon_results.json --subclaims subclaims.json --output-dir evaluation/bn/
44
+
45
+ # Inference format (predicted_gen_text as evaluated summary)
46
+ python evaluate_scores_bn_vllm_rl.py --input test_inference_vllm_qwen3-4B_base.json --subclaims subclaims.json --output results.json
47
+
48
+ # Self-refine format (final_summary as evaluated summary)
49
+ python evaluate_scores_bn_vllm_rl.py --input test_self_refine_vllm_qwen3_4B_base.json --subclaims subclaims.json --output-dir evaluation/bn/
50
+ """
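As a rough illustration of the formats listed above (values are placeholders, not real data), a standard-format item and an inference-format item would look roughly like:

    {"fulltext": "...", "summary_text": "...", "summary_subclaims": ["..."], "generated_text": "...", "label": "low_health_literacy"}
    {"doc_id": "d001", "label": "low_health_literacy", "predicted_gen_text": "{\"low_health_literacy\": \"...\"}"}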
51
+
52
+ import argparse
53
+ import json
54
+ import os
55
+ import re
56
+ import sys
57
+ import time
58
+ from typing import Any, Dict, List, Optional
59
+
60
+ from tqdm import tqdm
61
+
62
+ # Import scoring utilities from the reward module (same directory).
63
+ from reward_new_v6_bn_v4_rmv_src_cov import (
64
+ _call_support_api,
65
+ _compute_classifier_reward,
66
+ _extract_subclaims_from_text,
67
+ _is_bangla_text,
68
+ _nonlinear_grounding,
69
+ compute_rewards,
70
+ )
71
+
72
+
73
+ def extract_text_from_best_summary(best_summary: str, label: str) -> str:
74
+ """Extract the raw generated text from a BON best_summary string.
75
+
76
+ The best_summary is a (possibly truncated) JSON string like:
77
+ '{"proficient_health_literacy": "...actual text..."}'
78
+ We locate the value after the label key and strip JSON wrapping.
79
+ """
80
+ key_pattern = re.compile(re.escape(f'"{label}"') + r'\s*:\s*"')
81
+ m = key_pattern.search(best_summary)
82
+ if not m:
83
+ return best_summary.strip()
84
+ text = best_summary[m.end():]
85
+ if text.endswith('"\n}'):
86
+ text = text[:-3]
87
+ elif text.endswith('"}\n'):
88
+ text = text[:-3]
89
+ elif text.endswith('"}'):
90
+ text = text[:-2]
91
+ elif text.endswith('"'):
92
+ text = text[:-1]
93
+ text = text.replace("\\n", "\n").replace('\\"', '"')
94
+ return text.strip()
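A minimal sketch of the behaviour described in the docstring (illustrative strings only):

    wrapped = '{"low_health_literacy": "ডায়াবেটিস মানে রক্তে চিনির মাত্রা বেশি থাকা।"}'
    extract_text_from_best_summary(wrapped, "low_health_literacy")
    # -> 'ডায়াবেটিস মানে রক্তে চিনির মাত্রা বেশি থাকা।' (label key, quotes, and closing brace stripped)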
95
+
96
+
97
+ def prepare_bon_items(
98
+ bon_data: List[Dict[str, Any]],
99
+ subclaims_data: List[Dict[str, Any]],
100
+ model_key: str = "qwen3_base",
101
+ ) -> List[Dict[str, Any]]:
102
+ """Merge BON results with subclaims data into the standard evaluation format."""
103
+ sc_by_docid = {}
104
+ for item in subclaims_data:
105
+ sc_by_docid[item["doc_id"]] = item
106
+
107
+ prepared = []
108
+ for item in bon_data:
109
+ doc_id = item["doc_id"]
110
+ label = item["label"]
111
+ sc = sc_by_docid.get(doc_id, {})
112
+
113
+ model_data = item.get(model_key, {})
114
+ best_summary = model_data.get("best_summary", "") or model_data.get("predicted_gen_text", "")
115
+ generated_text = extract_text_from_best_summary(best_summary, label)
116
+
117
+ prepared.append({
118
+ "doc_id": doc_id,
119
+ "label": label,
120
+ "fulltext": sc.get("fulltext", ""),
121
+ "summary_text": sc.get("summary", ""),
122
+ "summary_subclaims": sc.get("summary_subclaims", []),
123
+ "fulltext_subclaims": sc.get("fulltext_subclaims", []),
124
+ "generated_text": generated_text,
125
+ "gold_gen_text": item.get("gold_gen_text", ""),
126
+ "predicted_label": item.get("predicted_label", ""),
127
+ "prediction_correct": item.get("prediction_correct", False),
128
+ })
129
+ return prepared
130
+
131
+
132
+ def prepare_inference_items(
133
+ inference_data: List[Dict[str, Any]],
134
+ subclaims_data: List[Dict[str, Any]],
135
+ ) -> List[Dict[str, Any]]:
136
+ """Merge inference-format results (doc_id, label, predicted_gen_text) with
137
+ subclaims data into the standard evaluation format. predicted_gen_text is
138
+ the JSON-wrapped evaluated summary; the raw text is extracted using the
139
+ item's label.
140
+ """
141
+ sc_by_docid = {}
142
+ for item in subclaims_data:
143
+ sc_by_docid[item["doc_id"]] = item
144
+
145
+ prepared = []
146
+ for item in inference_data:
147
+ doc_id = item["doc_id"]
148
+ label = item["label"]
149
+ sc = sc_by_docid.get(doc_id, {})
150
+
151
+ raw_pred = item.get("predicted_gen_text", "") or item.get("generated_text", "") or ""
152
+ generated_text = extract_text_from_best_summary(raw_pred, label)
153
+
154
+ prepared.append({
155
+ "doc_id": doc_id,
156
+ "label": label,
157
+ "fulltext": sc.get("fulltext", ""),
158
+ "summary_text": sc.get("summary", ""),
159
+ "summary_subclaims": sc.get("summary_subclaims", []),
160
+ "fulltext_subclaims": sc.get("fulltext_subclaims", []),
161
+ "generated_text": generated_text,
162
+ "gold_gen_text": item.get("gold_gen_text", ""),
163
+ })
164
+ return prepared
165
+
166
+
167
+ def prepare_self_refine_items(
168
+ self_refine_data: List[Dict[str, Any]],
169
+ subclaims_data: List[Dict[str, Any]],
170
+ ) -> List[Dict[str, Any]]:
171
+ """Merge self-refine format (doc_id, label, final_summary) with subclaims
172
+ data. final_summary is the generated text to evaluate (plain text or
173
+ JSON-wrapped by label); it is extracted and used as generated_text.
174
+ """
175
+ sc_by_docid = {}
176
+ for item in subclaims_data:
177
+ sc_by_docid[item["doc_id"]] = item
178
+
179
+ prepared = []
180
+ for item in self_refine_data:
181
+ doc_id = item["doc_id"]
182
+ label = item["label"]
183
+ sc = sc_by_docid.get(doc_id, {})
184
+
185
+ raw_final = item.get("final_summary", "") or ""
186
+ generated_text = extract_text_from_best_summary(raw_final, label)
187
+
188
+ prepared.append({
189
+ "doc_id": doc_id,
190
+ "label": label,
191
+ "fulltext": sc.get("fulltext", ""),
192
+ "summary_text": sc.get("summary", ""),
193
+ "summary_subclaims": sc.get("summary_subclaims", []),
194
+ "fulltext_subclaims": sc.get("fulltext_subclaims", []),
195
+ "generated_text": generated_text,
196
+ "gold_gen_text": item.get("gold_gen_text", ""),
197
+ })
198
+ return prepared
199
+
200
+
201
+ def evaluate_single(
202
+ item: Dict[str, Any],
203
+ target_level_override: Optional[str] = None,
204
+ ) -> Dict[str, Any]:
205
+ """
206
+ Evaluate a single item and return detailed scores.
207
+ """
208
+ fulltext = item.get("fulltext", "")
209
+ summary_text = item.get("summary_text") or item.get("summary", "")
210
+ summary_subclaims = item.get("summary_subclaims", [])
211
+ generated_text = item.get("generated_text") or item.get("predicted_gen_text", "")
212
+ target_level = target_level_override or item.get("label", "")
213
+
214
+ result: Dict[str, Any] = {
215
+ "doc_id": item.get("doc_id", ""),
216
+ "target_level": target_level,
217
+ "generated_text_len": len(generated_text.strip()) if generated_text else 0,
218
+ "factuality_score": None,
219
+ "hallucination_score": None,
220
+ "classifier_score": None,
221
+ "grounding_score": None,
222
+ "factuality_supported": 0,
223
+ "total_summary_subclaims": len(summary_subclaims) if summary_subclaims else 0,
224
+ "hallucination_supported": 0,
225
+ "total_gen_segments": 0,
226
+ "skipped": False,
227
+ "skip_reason": "",
228
+ }
229
+
230
+ if not generated_text or len(generated_text.strip()) < 10:
231
+ result["skipped"] = True
232
+ result["skip_reason"] = "generated_text missing or too short (<10 chars)"
233
+ return result
234
+
235
+ # -- Factuality & Hallucination via compute_rewards --
236
+ rewards = compute_rewards(
237
+ fulltext=fulltext,
238
+ generated_text=generated_text,
239
+ target_level=target_level,
240
+ summary_subclaims=summary_subclaims,
241
+ summary_text=summary_text,
242
+ )
243
+
244
+ factuality_score = rewards["factuality_score"]
245
+ h_score = rewards["hallucination_score"]
246
+
247
+ if factuality_score is None:
248
+ factuality_score = 0.5
249
+ if h_score is None:
250
+ h_score = 0.5
251
+
252
+ grounding_score = _nonlinear_grounding(h_score)
253
+
254
+ # -- Classifier --
255
+ input_text = fulltext or ""
256
+ class_score = _compute_classifier_reward(target_level, generated_text, input_text)
257
+
258
+ result.update({
259
+ "factuality_score": round(factuality_score, 4),
260
+ "hallucination_score": round(h_score, 4),
261
+ "grounding_score": round(grounding_score, 4),
262
+ "classifier_score": round(class_score, 4),
263
+ "factuality_supported": rewards.get("factuality_supported", 0),
264
+ "total_summary_subclaims": rewards.get("total_summary_subclaims", 0),
265
+ "hallucination_supported": rewards.get("hallucination_supported", 0),
266
+ "total_gen_segments": rewards.get("total_gen_segments", 0),
267
+ })
268
+
269
+ return result
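A minimal usage sketch, assuming the support-check, classifier, and subclaim-extractor vLLM endpoints from the module docstring are reachable (the item contents are placeholders):

    item = prepare_inference_items(raw_data, subclaims_data)[0]
    scores = evaluate_single(item)
    print(scores["factuality_score"], scores["hallucination_score"], scores["classifier_score"])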
270
+
271
+
272
+ def compute_aggregate(results: List[Dict[str, Any]]) -> Dict[str, Any]:
273
+ """Compute aggregate statistics over all evaluated items."""
274
+ scored = [r for r in results if not r.get("skipped", False)]
275
+ n = len(scored)
276
+ total = len(results)
277
+ skipped = total - n
278
+
279
+ if n == 0:
280
+ return {
281
+ "total_items": total,
282
+ "scored_items": 0,
283
+ "skipped_items": skipped,
284
+ "avg_factuality_score": None,
285
+ "avg_hallucination_score": None,
286
+ "avg_grounding_score": None,
287
+ "avg_classifier_score": None,
288
+ }
289
+
290
+ def safe_avg(key):
291
+ vals = [r[key] for r in scored if r[key] is not None]
292
+ return round(sum(vals) / len(vals), 4) if vals else None
293
+
294
+ return {
295
+ "total_items": total,
296
+ "scored_items": n,
297
+ "skipped_items": skipped,
298
+ "avg_factuality_score": safe_avg("factuality_score"),
299
+ "avg_hallucination_score": safe_avg("hallucination_score"),
300
+ "avg_grounding_score": safe_avg("grounding_score"),
301
+ "avg_classifier_score": safe_avg("classifier_score"),
302
+ }
303
+
304
+
305
+ def main():
306
+ parser = argparse.ArgumentParser(
307
+ description="Evaluate factuality, hallucination, and classifier scores on a JSON file."
308
+ )
309
+ parser.add_argument(
310
+ "--input", "-i", required=True,
311
+ help="Path to input JSON file (list of objects).",
312
+ )
313
+ parser.add_argument(
314
+ "--output", "-o", default=None,
315
+ help="Path to output JSON file with per-item scores. "
316
+ "Defaults to <input_stem>_eval_results.json.",
317
+ )
318
+ parser.add_argument(
319
+ "--output-dir", default=None,
320
+ help="Directory to save output files. If set, output filename is derived "
321
+ "from input filename and placed in this directory.",
322
+ )
323
+ parser.add_argument(
324
+ "--subclaims", "-s", default=None,
325
+ help="Path to subclaims JSON file (for BON format). Provides fulltext, "
326
+ "summary, summary_subclaims, and fulltext_subclaims keyed by doc_id.",
327
+ )
328
+ parser.add_argument(
329
+ "--model-key", default="qwen3_base",
330
+ help="Key in the BON data containing candidates/best_summary (default: qwen3_base).",
331
+ )
332
+ parser.add_argument(
333
+ "--target-level", "-t", default=None,
334
+ help="Override target literacy level for all items "
335
+ "(e.g. low_health_literacy). If not set, uses each item's 'label' field.",
336
+ )
337
+ parser.add_argument(
338
+ "--support-check-url", default=None,
339
+ help="Override VLLM_SUPPORT_CHECK_BN_API_BASE.",
340
+ )
341
+ parser.add_argument(
342
+ "--classifier-url", default=None,
343
+ help="Override VLLM_CLASSIFIER_BN_API_BASE.",
344
+ )
345
+ parser.add_argument(
346
+ "--subclaim-extractor-url", default=None,
347
+ help="Override VLLM_SUBCLAIM_EXTRACTOR_BN_API_BASE.",
348
+ )
349
+ args = parser.parse_args()
350
+
351
+ if args.support_check_url:
352
+ os.environ["VLLM_SUPPORT_CHECK_BN_API_BASE"] = args.support_check_url
353
+ if args.classifier_url:
354
+ os.environ["VLLM_CLASSIFIER_BN_API_BASE"] = args.classifier_url
355
+ if args.subclaim_extractor_url:
356
+ os.environ["VLLM_SUBCLAIM_EXTRACTOR_BN_API_BASE"] = args.subclaim_extractor_url
357
+
358
+ # Load input (supports both JSON array and JSONL)
359
+ with open(args.input, "r", encoding="utf-8") as f:
360
+ content = f.read().strip()
361
+ if content.startswith("["):
362
+ raw_data = json.loads(content)
363
+ else:
364
+ raw_data = [json.loads(line) for line in content.splitlines() if line.strip()]
365
+
366
+ if not isinstance(raw_data, list):
367
+ print(f"Error: Expected a JSON list, got {type(raw_data).__name__}.", file=sys.stderr)
368
+ sys.exit(1)
369
+
370
+ # Normalise field names from RL-inference JSONL format
371
+ for item in raw_data:
372
+ if "label" not in item and "gold_label" in item:
373
+ item["label"] = item["gold_label"]
374
+ if "fulltext" not in item and "input_text" in item:
375
+ item["fulltext"] = item["input_text"]
376
+ if "summary_subclaims" not in item and "subclaims" in item:
377
+ item["summary_subclaims"] = item["subclaims"]
378
+
379
+ # Detect BON format: items have a model key (e.g. qwen3_base) with best_summary
380
+ is_bon = (
381
+ len(raw_data) > 0
382
+ and args.model_key in raw_data[0]
383
+ and "best_summary" in raw_data[0].get(args.model_key, {})
384
+ )
385
+
386
+ # Detect inference format: top-level doc_id, label, predicted_gen_text; no fulltext/summary_subclaims
387
+ is_inference = (
388
+ len(raw_data) > 0
389
+ and "doc_id" in raw_data[0]
390
+ and "label" in raw_data[0]
391
+ and "predicted_gen_text" in raw_data[0]
392
+ and raw_data[0].get("fulltext") is None
393
+ and raw_data[0].get("summary_subclaims") is None
394
+ )
395
+
396
+ # Detect self-refine format: doc_id, label, final_summary as gen text; no fulltext/summary_subclaims
397
+ is_self_refine = (
398
+ len(raw_data) > 0
399
+ and "doc_id" in raw_data[0]
400
+ and "label" in raw_data[0]
401
+ and "final_summary" in raw_data[0]
402
+ and raw_data[0].get("fulltext") is None
403
+ and raw_data[0].get("summary_subclaims") is None
404
+ )
405
+
406
+ if is_bon:
407
+ if not args.subclaims:
408
+ print("Error: BON format detected but --subclaims file not provided.", file=sys.stderr)
409
+ sys.exit(1)
410
+ with open(args.subclaims, "r", encoding="utf-8") as f:
411
+ subclaims_data = json.load(f)
412
+ print(f"BON format detected (model_key={args.model_key})")
413
+ print(f"Loaded {len(raw_data)} BON items from {args.input}")
414
+ print(f"Loaded {len(subclaims_data)} subclaims entries from {args.subclaims}")
415
+ data = prepare_bon_items(raw_data, subclaims_data, model_key=args.model_key)
416
+ print(f"Prepared {len(data)} items for evaluation")
417
+ elif is_inference:
418
+ if not args.subclaims:
419
+ print("Error: Inference format detected (predicted_gen_text) but --subclaims file not provided.", file=sys.stderr)
420
+ sys.exit(1)
421
+ with open(args.subclaims, "r", encoding="utf-8") as f:
422
+ subclaims_data = json.load(f)
423
+ print("Inference format detected (predicted_gen_text as evaluated summary)")
424
+ print(f"Loaded {len(raw_data)} inference items from {args.input}")
425
+ print(f"Loaded {len(subclaims_data)} subclaims entries from {args.subclaims}")
426
+ data = prepare_inference_items(raw_data, subclaims_data)
427
+ print(f"Prepared {len(data)} items for evaluation")
428
+ elif is_self_refine:
429
+ if not args.subclaims:
430
+ print("Error: Self-refine format detected (final_summary) but --subclaims file not provided.", file=sys.stderr)
431
+ sys.exit(1)
432
+ with open(args.subclaims, "r", encoding="utf-8") as f:
433
+ subclaims_data = json.load(f)
434
+ print("Self-refine format detected (final_summary as evaluated summary)")
435
+ print(f"Loaded {len(raw_data)} self-refine items from {args.input}")
436
+ print(f"Loaded {len(subclaims_data)} subclaims entries from {args.subclaims}")
437
+ data = prepare_self_refine_items(raw_data, subclaims_data)
438
+ print(f"Prepared {len(data)} items for evaluation")
439
+ else:
440
+ data = raw_data
441
+ print(f"Loaded {len(data)} items from {args.input}")
442
+
443
+ print(f" Support check API : {os.getenv('VLLM_SUPPORT_CHECK_BN_API_BASE', 'http://localhost:8090/v1')}")
444
+ print(f" Classifier API : {os.getenv('VLLM_CLASSIFIER_BN_API_BASE', 'http://localhost:8040/v1')}")
445
+ print(f" Subclaim extractor: {os.getenv('VLLM_SUBCLAIM_EXTRACTOR_BN_API_BASE', 'http://localhost:8050/v1')}")
446
+ if args.target_level:
447
+ print(f" Target level override: {args.target_level}")
448
+ print("-" * 60)
449
+
450
+ # Evaluate each item
451
+ results = []
452
+ start_time = time.time()
453
+ for idx, item in enumerate(tqdm(data, desc="Evaluating")):
454
+ r = evaluate_single(item, target_level_override=args.target_level)
455
+ r["index"] = idx
456
+ r["doc_id"] = item.get("doc_id", "")
457
+ results.append(r)
458
+
459
+ if (idx + 1) % 10 == 0 or idx == 0:
460
+ partial_agg = compute_aggregate(results)
461
+ tqdm.write(
462
+ f" [{idx+1}/{len(data)}] "
463
+ f"fact={partial_agg['avg_factuality_score']} "
464
+ f"hallu={partial_agg['avg_hallucination_score']} "
465
+ f"cls={partial_agg['avg_classifier_score']}"
466
+ )
467
+
468
+ elapsed = time.time() - start_time
469
+
470
+ # --- Validation: all items must be evaluated with non-null scores ---
471
+ expected_count = len(data)
472
+ skipped_items = [r for r in results if r.get("skipped", False)]
473
+ null_score_items = []
474
+ for r in results:
475
+ if r.get("skipped", False):
476
+ continue
477
+ for key in ("factuality_score", "hallucination_score", "classifier_score", "grounding_score"):
478
+ if r.get(key) is None:
479
+ null_score_items.append((r.get("index"), r.get("doc_id"), key))
480
+
481
+ has_errors = False
482
+ if skipped_items:
483
+ has_errors = True
484
+ print(f"\nERROR: {len(skipped_items)} out of {expected_count} items were skipped:", file=sys.stderr)
485
+ for r in skipped_items:
486
+ print(f" index={r.get('index')} doc_id={r.get('doc_id')} reason={r.get('skip_reason')}", file=sys.stderr)
487
+
488
+ if null_score_items:
489
+ has_errors = True
490
+ print(f"\nERROR: {len(null_score_items)} null score(s) found:", file=sys.stderr)
491
+ for idx, doc_id, key in null_score_items:
492
+ print(f" index={idx} doc_id={doc_id} null_field={key}", file=sys.stderr)
493
+
494
+ if len(results) != expected_count:
495
+ has_errors = True
496
+ print(f"\nERROR: Expected {expected_count} results but got {len(results)}.", file=sys.stderr)
497
+
498
+ if has_errors:
499
+ print(f"\nAborting: will NOT save results. All {expected_count} items must be fully evaluated with non-null scores.", file=sys.stderr)
500
+ sys.exit(1)
501
+
502
+ # Aggregate
503
+ agg = compute_aggregate(results)
504
+
505
+ # Per-label aggregates
506
+ label_groups: Dict[str, List[Dict[str, Any]]] = {}
507
+ for r in results:
508
+ lbl = r.get("target_level", "unknown")
509
+ label_groups.setdefault(lbl, []).append(r)
510
+ per_label_agg = {lbl: compute_aggregate(items) for lbl, items in sorted(label_groups.items())}
511
+
512
+ # Output path
513
+ if args.output:
514
+ out_path = args.output
515
+ elif args.output_dir:
516
+ os.makedirs(args.output_dir, exist_ok=True)
517
+ stem = os.path.splitext(os.path.basename(args.input))[0]
518
+ out_path = os.path.join(args.output_dir, f"{stem}_eval_results.json")
519
+ else:
520
+ stem = os.path.splitext(os.path.basename(args.input))[0]
521
+ out_dir = os.path.dirname(args.input) or "."
522
+ out_path = os.path.join(out_dir, f"{stem}_eval_results.json")
523
+
524
+ os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
525
+
526
+ output = {
527
+ "input_file": os.path.abspath(args.input),
528
+ "subclaims_file": os.path.abspath(args.subclaims) if args.subclaims else None,
529
+ "model_key": args.model_key if is_bon else None,
530
+ "inference_format": is_inference if not is_bon else False,
531
+ "self_refine_format": is_self_refine if not is_bon and not is_inference else False,
532
+ "target_level_override": args.target_level,
533
+ "elapsed_seconds": round(elapsed, 2),
534
+ "aggregate": agg,
535
+ "per_label_aggregate": per_label_agg,
536
+ "per_item": results,
537
+ }
538
+
539
+ with open(out_path, "w", encoding="utf-8") as f:
540
+ json.dump(output, f, indent=2, ensure_ascii=False)
541
+
542
+ # Print summary
543
+ print("\n" + "=" * 60)
544
+ print("EVALUATION SUMMARY")
545
+ print("=" * 60)
546
+ print(f" Total items : {agg['total_items']}")
547
+ print(f" Scored items : {agg['scored_items']}")
548
+ print(f" Skipped items : {agg['skipped_items']}")
549
+ print(f" Elapsed time : {round(elapsed, 1)}s")
550
+ print("-" * 60)
551
+ print(f" Avg Factuality Score : {agg['avg_factuality_score']}")
552
+ print(f" Avg Hallucination Score: {agg['avg_hallucination_score']}")
553
+ print(f" Avg Grounding Score : {agg['avg_grounding_score']}")
554
+ print(f" Avg Classifier Score : {agg['avg_classifier_score']}")
555
+ print("-" * 60)
556
+ for lbl, la in per_label_agg.items():
557
+ print(f" [{lbl}] items={la['scored_items']}"
558
+ f" fact={la['avg_factuality_score']}"
559
+ f" hallu={la['avg_hallucination_score']}"
560
+ f" cls={la['avg_classifier_score']}")
561
+ print("-" * 60)
562
+ print(f" Results saved to: {out_path}")
563
+ print("=" * 60)
564
+
565
+
566
+ if __name__ == "__main__":
567
+ main()
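For the RL-generated JSONL format normalised in main() (gold_label, input_text, and subclaims carried inline), a plausible invocation would be, for example (file names are placeholders only):

    python evaluate_scores_bn_vllm_rl.py --input rl_generations.jsonl --output rl_eval_results.json

No --subclaims file should be needed in that case, since each normalised item already carries its fulltext and summary_subclaims.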
code/fine_tune_sft_dpo/evaluation/bn/bn_200_eval_results.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0cb20710f7c4b519b77457d541d9a132ded57b6a9252bc5552788cf358e9e436
3
+ size 96048
code/fine_tune_sft_dpo/evaluation/bn/eval_gpt5_models/gpt_5_mini_eval_results_20260316_071029.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:18c09e4d51cd941941cbc3c585699e3f7e2e98b051a5a26396bc242effa5438a
3
+ size 97818
code/fine_tune_sft_dpo/evaluation/bn/eval_gpt5_models/gpt_5_mini_prepared_20260316_071029.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:97ac7c77d004ac04c4b06a848b28c43b9cdcd4edc8d1c97f757ec11769445aae
3
+ size 6569807
code/fine_tune_sft_dpo/evaluation/bn/eval_gpt5_models/gpt_5_nano_eval_results_20260316_071029.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4cf921a153e1d69f2671c93cfc4a1d368cbcc3d3db0c5b30b706bd354402cb44
3
+ size 97656
code/fine_tune_sft_dpo/evaluation/bn/eval_gpt5_models/gpt_5_nano_prepared_20260316_071029.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f6638c5e9c0feb85cf471fac04021473211fd645cc250b5e42a39483e3a6e1fd
3
+ size 6114304
code/fine_tune_sft_dpo/evaluation/bn/eval_gpt5_models/gpt_5_prepared_20260316_071029.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:afefe278ca421f9eb698e638b2fcca1c0814525ac5a9a4bf576c0d3294290033
3
+ size 6349574
code/fine_tune_sft_dpo/evaluation/bn/test_best_of_n_qwen3-4B_base_eval_results.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:85d23851eb74b15da0c466c3a11ffd43ae371db75523ee9e85a0cec6f1cf6b5e
3
+ size 97215
code/fine_tune_sft_dpo/evaluation/bn/test_best_of_n_qwen3-4B_sft_eval_results.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f89d77f48683f42239d9bcdcbd469810f3f11d3d06d8e06d3442e69e34729535
3
+ size 96710
code/fine_tune_sft_dpo/evaluation/bn/test_inference_vllm_qwen3-4B_base_eval_results.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2eb135210f5dc0f92170e415936ad34927979be1efdfb0bb6f629374c42b79e4
3
+ size 97550
code/fine_tune_sft_dpo/evaluation/bn/test_inference_vllm_qwen3-4B_sft_eval_results.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b54519f530b0894294b4ec86b0af317f521bc9baff38e230dc46dccfb46e7d19
3
+ size 97002
code/fine_tune_sft_dpo/evaluation/bn/test_self_refine_vllm_qwen3_4B_base_eval_results.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:346286565d0a6ec98ea218bd463a9322e6b1738bf5afb8235cddc2cfcb22d488
3
+ size 97347
code/fine_tune_sft_dpo/evaluation/bn/test_self_refine_vllm_qwen3_4B_sft_eval_results.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9630e15bd5d4b3a74de3c0d97a5dc5fc40ddffb96eb16ea55b3804cf5db5a419
3
+ size 97016
code/fine_tune_sft_dpo/evaluation/bn/v1/test_best_of_n_qwen3-4B_base_eval_results.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c121bcd6eae3ab09fcc37eaa3a6b49edf437b4700573f62febc5db58948d26e8
3
+ size 97136
code/fine_tune_sft_dpo/evaluation/bn/v1/test_best_of_n_qwen3-4B_sft_eval_results.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ccf84db8e83c982492ad99641f4a01b882ff7e9a259159f7fbcb04105745776b
3
+ size 96671
code/fine_tune_sft_dpo/evaluation/bn/v1/test_inference_vllm_qwen3-4B_base_eval_results.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:de2460d80063dd684264a474babc483a510e293e42d94e9392665b85e67f351c
3
+ size 97496
code/fine_tune_sft_dpo/evaluation/bn/v1/test_inference_vllm_qwen3-4B_sft_eval_results.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:32bc196d3288b6355e7b16f9dc1bd31a7dbce1c2c837a0059c29e255cca698e1
3
+ size 97006
code/fine_tune_sft_dpo/evaluation/bn/v1/test_self_refine_vllm_qwen3_4B_base_eval_results.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b163e821453592fbcdb922291a803431114d0dfa828df32366cd58340da949ef
3
+ size 97349
code/fine_tune_sft_dpo/evaluation/bn/v1/test_self_refine_vllm_qwen3_4B_sft_eval_results.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e822e39c927b8f75d411d42b5e556696481987b5589cd9f6453bee8e65d068d7
3
+ size 97013
code/fine_tune_sft_dpo/prompt_bn_wo_gs/prompt ADDED
@@ -0,0 +1,58 @@
1
+ **সিস্টেম ভূমিকা:**
2
+
3
+ আপনি একজন বিশেষজ্ঞ চিকিৎসা সম্পাদক এবং স্বাস্থ্য-সাক্ষরতা বিশেষজ্ঞ। আপনার কাজ হলো জটিল চিকিৎসাবিষয়ক লেখাকে পাঠকের স্বাস্থ্য-সাক্ষরতার স্তর অনুযায়ী তিনটি ভিন্ন সংস্করণে রূপান্তর করা। আপনাকে ইনপুটের মূল ভাষা অবশ্যই অপরিবর্তিত রাখতে হবে, তবে ভাষার জটিলতা স্তর অনুযায়ী সমন্বয় করতে হবে। সরলীকৃত সংস্করণগুলো যেন সঠিক ও গুরুত্বপূর্ণ তথ্য বজায় রাখে, সে জন্য আপনাকে মূল তথ্য ও বার্তাকে ভিত্তি হিসেবে ব্যবহার করতে হবে।
4
+
5
+ **ব্যবহারকারী নির্দেশনা:**
6
+
7
+ দয়া করে নিচের চিকিৎসাবিষয়ক সোর্স টেক্সট ব্যবহার করে স্বাস্থ্য‑সাক্ষরতার তিনটি ভিন্ন স্তরের জন্য আলাদা আলাদা সংস্করণ তৈরি করুন।
8
+
9
+ ### প্রতিটি স্তরের জন্য নির্দেশনা:
10
+
11
+ 1. স্তর: কম স্বাস্থ্য‑সাক্ষরতা (উচ্চ পাঠযোগ্যতা)
12
+
13
+ লক্ষ্য পাঠক: যারা খুব সহজ, দৈনন্দিন ভাষায় দ্রুত বোঝার মতো ব্যাখ্যা চান।
14
+
15
+ ভাষাগত লক্ষ্য: একদম ঘরোয়া/দৈনন্দিন কথাবার্তার ভাষা ব্যবহার করুন। সব ধরনের চিকিৎসাবিষয়ক জারগনকে সহজ ব্যাখ্যামূলক ভাষায় রূপান্তর করুন (যেমন, "renal" এর পরিবর্তে "কিডনির সমস্যা" বা "কিডনি" লিখুন)।
16
+
17
+ তথ্যের ঘনত্ব: কেবলমাত্র "যা অবশ্যই জানা দরকার" ধরনের মূল তথ্যগুলো রাখুন।
18
+
19
+ কৌশল: বেশি মাত্রায় পুনর্লিখন ও উদাহরণ/উপমা ব্যবহার করুন। প্রতি বাক্যে একটি করে মূল ধারণা রাখুন।
20
+
21
+ বিশ্বস্ততা: লেখাটি অবশ্যই গোল্ড সামারি‑র সঙ্গে সম্পূর্ণ সামঞ্জস্যপূর্ণ হতে হবে।
22
+
23
+ 2. স্তর: মাঝারি স্বাস্থ্য‑সাক্ষরতা (মাঝারি পাঠযোগ্যতা)
24
+
25
+ লক্ষ্য পাঠক: সাধারণ মানুষ, যারা সাধারণ সংবাদ বা তথ্যভিত্তিক লেখা পড়ে বুঝতে পারেন।
26
+
27
+ ভাষাগত লক্ষ্য: মানিকৃত/সাধারণ শব্দভাণ্ডার ব্যবহার করুন। সাধারণভাবে পরিচিত চিকিৎসাবিষয়ক শব্দ ব্যবহার করা যেতে পারে, তবে অতিরিক্ত টেকনিক্যাল "ডাক্তারি ভাষা" এড়িয়ে চলুন বা সহজভাবে ব্যাখ্যা করুন।
28
+
29
+ তথ্যের ঘনত্ব: ভারসাম্যপূর্ণ রাখুন। মূল বার্তাকে কেন্দ্র করে কাঠামো তৈরি করুন এবং প্রয়োজন অনুযায়ী সোর্স টেক্সট থেকে প্রাসঙ্গিক অতিরিক্ত প্রেক্ষাপট যোগ করুন।
30
+
31
+ কৌশল: মাঝারি মাত্রার পুনর্লিখন করুন। অপ্রয়োজনীয় টেকনিক্যাল খুঁটিনাটি বাদ দিন, যাতে পাঠক অতিরিক্ত তথ্যের চাপে না পড়েন।
32
+
33
+ বিশ্বস্ততা: লেখাটি যেন মূল বার্তা ও ধারাবাহিকতা বজায় রাখে।
34
+
35
+ 3. স্তর: উচ্চ স্বাস্থ্য‑সাক্ষরতা / প্রফিসিয়েন্ট (কম পাঠযোগ্যতা, উচ্চ জটিলতা)
36
+
37
+ লক্ষ্য পাঠক: গবেষক, ক্লিনিশিয়ান বা অত্যন্ত সচেতন/অভিজ্ঞ রোগী।
38
+
39
+ ভাষাগত লক্ষ্য: প্রয়োজনে টেকনিক্যাল ও একাডেমিক ভাষা ব্যবহার করুন। ক্লিনিক্যাল নির্ভুলতা ও চিকিৎসাবিজ্ঞানভিত্তিক সূক্ষ্ম দিকগুলোকে অগ্রাধিকার দিন।
40
+
41
+ তথ্যের ঘনত্ব: বেশি রাখুন। পুরো সোর্স টেক্সট ব্যবহার করে ডেটা, শারীরবৃত্তীয় প্রক্রিয়া, পরিসংখ্যান ইত্যাদি প্রাসঙ্গিক তথ্য অন্তর্ভুক্ত করুন।
42
+
43
+ কৌশল: যতটা সম্ভব কম পুনর্লিখন করুন। মূল টেকনিক্যাল পরিভাষা ও বাক্য গঠন অধিকাংশই অক্ষুণ্ণ রাখুন।
44
+
45
+ বিশ্বস্ততা: সোর্স টেক্সটের সাথে ঘনিষ্ঠভাবে অনুগত থাকুন; প্রয়োজন হলে বৈজ্ঞানিক প্রেক্ষাপট বাড়াতে সম্পর্কিত উপ‑দাবি বা ব্যাখ্যা যোগ করতে পারেন।
46
+
47
+
48
+ আমি আপনাকে নিম্নোক্ত তথ্যগুলো প্রদান করব:
49
+
50
+ - ইনপুট ভাষা: <<<SOURCE_LANGUAGE>>>
51
+ - সোর্স টেক্সট (বিস্তারিত মূল লেখা): <<<FULL_TEXT>>>
52
+
53
+ **আউটপুট ফরম্যাট (শুধু JSON):**
54
+ {{
55
+ "low_health_literacy": "...",
56
+ "intermediate_health_literacy": "...",
57
+ "proficient_health_literacy": "..."
58
+ }}
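The <<<SOURCE_LANGUAGE>>> and <<<FULL_TEXT>>> markers are placeholders to be filled before the prompt is sent to a model; the filling code for this prompt variant is not shown in this commit, so the following is only a sketch of the assumed substitution:

    prompt = template.replace("<<<SOURCE_LANGUAGE>>>", "Bangla").replace("<<<FULL_TEXT>>>", fulltext)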
code/fine_tune_sft_dpo/prompt_bn_wo_gs/prompt_intermediate ADDED
@@ -0,0 +1,31 @@
1
+ **সিস্টেম ভূমিকা:**
2
+
3
+ আপনি একজন বিশেষজ্ঞ চিকিৎসা সম্পাদক এবং স্বাস্থ্য‑সাক্ষরতা বিশেষজ্ঞ। আপনার কাজ হলো জটিল চিকিৎসাবিষয়ক লেখাকে এমন একটি সংস্করণে রূপান্তর করা, যা মাঝারি স্বাস্থ্য‑সাক্ষরতা সম্পন্ন পাঠকদের জন্য উপযোগী। আপনাকে ইনপুটের মূল ভাষা অপরিবর্তিত রেখে ভাষার জটিলতা ও তথ্যের ঘনত্বকে ভারসাম্যপূর্ণ করতে হবে। সরলীকৃত লেখাটি যেন সঠিক ও গুরুত্বপূর্ণ তথ্য বজায় রাখে, সে জন্য আপনাকে মূল তথ্য ও বার্তাকে ভিত্তি হিসেবে ব্যবহার করতে হবে।
4
+
5
+ **ব্যবহারকারী নির্দেশনা:**
6
+
7
+ দয়া করে নিচের চিকিৎসাবিষয়ক সোর্স টেক্সট ব্যবহার করে **মাঝারি স্বাস্থ্য‑সাক্ষরতা (মাঝারি পাঠযোগ্যতা)** স্তরের জন্য একটি সংস্করণ তৈরি করুন।
8
+
9
+ ### নির্দেশনা:
10
+
11
+ স্তর: মাঝারি স্বাস্থ্য‑সাক্ষরতা (মাঝারি পাঠযোগ্যতা)
12
+
13
+ লক্ষ্য পাঠক: সাধারণ জনগণ, যারা সাধারণ সংবাদ বা তথ্যভিত্তিক লেখা পড়ে বুঝতে পারেন।
14
+
15
+ ভাষাগত লক্ষ্য: মানিকৃত ও সহজবোধ্য শব্দভাণ্ডার ব্যবহার করুন। পরিচিত চিকিৎসাবিষয়ক শব্দ ব্যবহার করা যেতে পারে, তবে অতিরিক্ত টেকনিক্যাল "ডাক্তারি ভাষা" এলে তা সহজ ব্যাখ্যায় রূপান্তর করুন।
16
+
17
+ তথ্যের ঘনত্ব: ভারসাম্যপূর্ণ রাখুন। মূল বার্তাকে সামনে রেখে মূল কাঠামো তৈরি করুন এবং প্রয়োজন হলে সোর্স টেক্সট থেকে প্রাসঙ্গিক অতিরিক্ত তথ্য বা প্রেক্ষাপট যোগ করুন।
18
+
19
+ কৌশল: মাঝারি মাত্রার পুনর্লিখন করুন। অতি খুঁটিনাটি টেকনিক্যাল ডিটেইল বাদ দিন, যাতে পাঠক তথ্যের চাপে না পড়ে কিন্তু মূল বিষয়টি স্পষ্টভাবে বুঝতে পারে।
20
+
21
+ বিশ্বস্ততা: লেখাটি যেন মূল বার্তা, ক্রম এবং যুক্তি বজায় রাখে।
22
+
23
+ আমি আপনাকে নিম্নোক্ত তথ্যগুলো প্রদান করব:
24
+
25
+ - ইনপুট ভাষা: {source_lang}
26
+ - সোর্স টেক্সট (বিস্তারিত মূল লেখা): {full_text}
27
+
28
+ **আউটপুট ফরম্যাট (শুধু JSON):**
29
+ {{
30
+ "intermediate_health_literacy": "..."
31
+ }}
code/fine_tune_sft_dpo/prompt_bn_wo_gs/prompt_low ADDED
@@ -0,0 +1,31 @@
1
+ **সিস্টেম ভূমিকা:**
2
+
3
+ আপনি একজন বিশেষজ্ঞ চিকিৎসা সম্পাদক এবং স্বাস্থ্য‑সাক্ষরতা বিশেষজ্ঞ। আপনার কাজ হলো জটিল চিকিৎসাবিষয়ক লেখাকে এমনভাবে রূপান্তর করা, যা কম স্বাস্থ্য‑সাক্ষরতা সম্পন্ন পাঠকদের জন্য সহজে বোঝা যায়। আপনাকে ইনপুটের মূল ভাষা অপরিবর্তিত রাখতে হবে, তবে ভাষার জটিলতা কমিয়ে আনতে হবে। সরলীকৃত লেখাটি যেন সঠিক ও প্রয়োজনীয় থাকে, সে জন্য আপনাকে মূল তথ্য ও বার্তাকে ভিত্তি হিসেবে ব্যবহার করতে হবে।
4
+
5
+ **ব্যবহারকারী নির্দেশনা:**
6
+
7
+ দয়া করে নিচের চিকিৎসাবিষয়ক সোর্স টেক্সট ব্যবহার করে **কম স্বাস্থ্য‑সাক্ষরতা (উচ্চ পাঠযোগ্যতা)** স্তরের জন্য একটি সংস্করণ তৈরি করুন।
8
+
9
+ ### নির্দেশনা:
10
+
11
+ স্তর: কম স্বাস্থ্য‑সাক্ষরতা (উচ্চ পাঠযোগ্যতা)
12
+
13
+ লক্ষ্য পাঠক: এমন ব্যক্তি, যাঁরা খুব সহজ, সরাসরি ভাষায় তথ্য পেতে চান এবং তা থেকে দ্রুত পদক্ষেপ নিতে চান।
14
+
15
+ ভাষাগত লক্ষ্য: একদম ঘরোয়া/দৈনন্দিন কথাবার্তার ভাষা ব্যবহার করুন। সব ধরনের চিকিৎসাবিষয়ক জারগনকে সহজ বর্ণনামূলক শব্দে রূপান্তর করুন (যেমন, "renal" এর পরিবর্তে "কিডনির সমস্যা" বা "কিডনি" লিখুন)।
16
+
17
+ তথ্যের ঘনত্ব: কেবলমাত্র "অবশ্যই জানা দরকার" ধরনের মূল তথ্যগুলো রাখুন। অপ্রয়োজনীয় ব্যাখ্যা বা অতিরিক্ত ডেটা এড়িয়ে চলুন।
18
+
19
+ কৌশল: উচ্চ মাত্রার পুনর্লিখন করুন এবং প্রয়োজন হলে সহজ উপমা বা উদাহরণ ব্যবহার করুন। প্রতিটি বাক্যে একটি করে স্পষ্ট ধারণা রাখুন।
20
+
21
+ বিশ্বস্ততা: লেখাটি অবশ্যই গোল্ড সামারি‑র সাথে সম্পূর্ণ সামঞ্জস্যপূর্ণ হতে হবে; নতুন তথ্য যোগ করা যাবে না।
22
+
23
+ আমি আপনাকে নিম্নোক্ত তথ্যগুলো প্রদান করব:
24
+
25
+ - ইনপুট ভাষা: {source_lang}
26
+ - সোর্স টেক্সট (বিস্তারিত মূল লেখা): {full_text}
27
+
28
+ **আউটপুট ফরম্যাট (শুধু JSON):**
29
+ {{
30
+ "low_health_literacy": "..."
31
+ }}
code/fine_tune_sft_dpo/prompt_bn_wo_gs/prompt_proficient ADDED
@@ -0,0 +1,31 @@
1
+ **সিস্টেম ভূমিকা:**
2
+
3
+ আপনি একজন বিশেষজ্ঞ চিকিৎসা সম্পাদক এবং স্বাস্থ্য‑সাক্ষরতা বিশেষজ্ঞ। আপনার কাজ হলো জটিল চিকিৎসাবিষয়ক লেখাকে এমন একটি সংস্করণে রূপান্তর করা, যা উচ্চ/প্রফিসিয়েন্ট স্বাস্থ্য‑সাক্ষরতা সম্পন্ন পাঠকদের জন্য উপযোগী। আপনাকে ইনপুটের মূল ভাষা বজায় রেখে টেকনিক্যাল ও একাডেমিক ভাষার যথাযথ ব্যবহার করতে হবে। আপনি মূল তথ্যকে রেফারেন্স হিসেবে ব্যবহার করবেন, তবে প্রয়োজনে সোর্স টেক্সট থেকে গভীরতর বৈজ্ঞানিক প্রেক্ষাপটও যোগ করতে পারবেন।
4
+
5
+ **ব্যবহারকারী নির্দেশনা:**
6
+
7
+ দয়া করে নিচের চিকিৎসাবিষয়ক সোর্স টেক্সট ব্যবহার করে **উচ্চ/প্রফিসিয়েন্ট স্বাস্থ্য‑সাক্ষরতা (কম পাঠযোগ্যতা, উচ্চ জটিলতা)** স্তরের জন্য একটি সংস্করণ তৈরি করুন।
8
+
9
+ ### নির্দেশনা:
10
+
11
+ স্তর: প্রফিসিয়েন্ট স্বাস্থ্য‑সাক্ষরতা (কম পাঠযোগ্যতা)
12
+
13
+ লক্ষ্য পাঠক: গবেষক, ক্লিনিশিয়ান, বা অত্যন্ত সচেতন/অভিজ্ঞ রোগী।
14
+
15
+ ভাষাগত লক্ষ্য: প্রয়োজন অনুযায়ী টেকনিক্যাল ও একাডেমিক ভাষা ব্যবহার করুন। ক্লিনিক্যাল সূক্ষ্মতা, প্যাথোফিজিওলজি, ডায়াগনস্টিক মানদণ্ড ইত্যাদির নির্ভুল উপস্থাপনাকে অগ্রাধিকার দিন।
16
+
17
+ তথ্যের ঘনত্ব: উচ্চ রাখুন। সোর্স টেক্সট থেকে ডেটা, পরিসংখ্যান, শারীরবৃত্তীয় প্রক্রিয়া, চিকিৎসাপদ্ধতি এবং গবেষণালব্ধ তথ্য উপযুক্তভাবে অন্তর্ভুক্ত করুন।
18
+
19
+ কৌশল: কম মাত্রার পুনর্লিখন করুন। মূল টেকনিক্যাল পরিভাষা, গঠন এবং গুরুত্বপূর্ণ বাক্যগুলো যতটা সম্ভব অক্ষুণ্ণ রাখুন; প্রয়োজনে কেবল ব্যাকরণগত বা শৈলগত সামঞ্জস্যের জন্য পরিবর্তন করুন।
20
+
21
+ বিশ্বস্ততা: সোর্স টেক্সটের প্রতি ঘনিষ্ঠভাবে অনুগত থাকুন; প্রয়োজন হলে বৈজ্ঞানিক প্রেক্ষাপট ও ব্যাখ্যা সম্প্রসারণ করতে সম্পর্কিত উপ‑দাবি বা তথ্য যোগ করতে পারেন, তবে ভিত্তিহীন নতুন দাবি যোগ করবেন না।
22
+
23
+ আমি আপনাকে নিম্নোক্ত তথ্যগুলো প্রদান করব:
24
+
25
+ - ইনপুট ভাষা: {source_lang}
26
+ - সোর্স টেক্সট (বিস্তারিত মূল লেখা): {full_text}
27
+
28
+ **আউটপুট ফরম্যাট (শুধু JSON):**
29
+ {{
30
+ "proficient_health_literacy": "..."
31
+ }}
code/fine_tune_sft_dpo/qwen3-inference-vllm_bn.py CHANGED
@@ -7,16 +7,23 @@ merged model was saved to `/home/mshahidul/readctrl/code/fine_tune_sft_dpo/model
7
 
8
  import os
9
  os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
10
- os.environ["CUDA_VISIBLE_DEVICES"] = "6"
11
 
12
  import argparse
13
  import json
 
14
  from datetime import datetime
15
 
16
  from vllm import LLM, SamplingParams
17
  from transformers import AutoTokenizer
18
 
19
 
 
 
 
 
 
 
20
  # Paths
21
  BASE_DIR = "/home/mshahidul/readctrl/code/fine_tune_sft_dpo"
22
  MODEL_DIR = os.path.join(BASE_DIR, "model", "bn")
@@ -73,7 +80,7 @@ def parse_args():
73
  p.add_argument(
74
  "--temperature",
75
  type=float,
76
- default=0.0,
77
  help="Sampling temperature for generation.",
78
  )
79
  p.add_argument(
@@ -147,7 +154,10 @@ def main():
147
  user_prompt = build_user_message(prompts[label], fulltext, summary)
148
  chat = [{"role": "user", "content": user_prompt}]
149
  formatted = tokenizer.apply_chat_template(
150
- chat, tokenize=False, add_generation_prompt=True
 
 
 
151
  )
152
 
153
  batched_prompts.append(formatted)
@@ -189,7 +199,7 @@ def main():
189
  # Map generation results for this batch back to global indices
190
  for idx_in_batch, output in enumerate(outputs):
191
  original_idx = batch_indices[idx_in_batch]
192
- text = output.outputs[0].text.strip()
193
  generated_texts[original_idx] = text
194
 
195
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
 
7
 
8
  import os
9
  os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
10
+ os.environ["CUDA_VISIBLE_DEVICES"] = "5"
11
 
12
  import argparse
13
  import json
14
+ import re
15
  from datetime import datetime
16
 
17
  from vllm import LLM, SamplingParams
18
  from transformers import AutoTokenizer
19
 
20
 
21
+ def strip_think_blocks(text: str) -> str:
22
+ """Remove <think>...</think> reasoning blocks from model output."""
23
+ cleaned = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL).strip()
24
+ return cleaned if cleaned else text
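A quick illustration of the intended behaviour (example strings only):

    strip_think_blocks("<think>reasoning…</think>চূড়ান্ত সারসংক্ষেপ")  # -> "চূড়ান্ত সারসংক্ষেপ"; an output that is only a think block is returned unchanged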
25
+
26
+
27
  # Paths
28
  BASE_DIR = "/home/mshahidul/readctrl/code/fine_tune_sft_dpo"
29
  MODEL_DIR = os.path.join(BASE_DIR, "model", "bn")
 
80
  p.add_argument(
81
  "--temperature",
82
  type=float,
83
+ default=0.1,
84
  help="Sampling temperature for generation.",
85
  )
86
  p.add_argument(
 
154
  user_prompt = build_user_message(prompts[label], fulltext, summary)
155
  chat = [{"role": "user", "content": user_prompt}]
156
  formatted = tokenizer.apply_chat_template(
157
+ chat,
158
+ tokenize=False,
159
+ add_generation_prompt=True,
160
+ enable_thinking=False,
161
  )
162
 
163
  batched_prompts.append(formatted)
 
199
  # Map generation results for this batch back to global indices
200
  for idx_in_batch, output in enumerate(outputs):
201
  original_idx = batch_indices[idx_in_batch]
202
+ text = strip_think_blocks(output.outputs[0].text)
203
  generated_texts[original_idx] = text
204
 
205
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
code/fine_tune_sft_dpo/qwen3_best_of_n.py ADDED
@@ -0,0 +1,238 @@
1
+ """
2
+ Run inference with the finetuned Bangla Qwen3 model on test_bn.json
3
+ and save the generation results under results/bn.
4
+ """
5
+ import os
6
+ os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
7
+ os.environ["CUDA_VISIBLE_DEVICES"] = "5"
8
+ import argparse
9
+ import json
10
+ import os
11
+ from datetime import datetime
12
+ from typing import Any, Dict, List
13
+
14
+ import torch
15
+ from transformers import AutoModelForCausalLM, AutoTokenizer
16
+
17
+
18
+ # Paths (keep in sync with qwen3-finetune_bn.py)
19
+ BASE_DIR = "/home/mshahidul/readctrl/code/fine_tune_sft_dpo"
20
+ MODEL_SAVE_DIR = "/home/mshahidul/readctrl/code/fine_tune_sft_dpo/model/bn"
21
+ PROMPT_DIR = os.path.join(BASE_DIR, "prompt_bn")
22
+ TEST_JSON = "/home/mshahidul/readctrl/code/fine_tune_sft_dpo/dataset/bn/test_bn.json"
23
+ RESULTS_DIR = os.path.join(BASE_DIR, "results", "bn")
24
+ SOURCE_LANG = "Bangla"
25
+
26
+ LABEL_TO_PROMPT_FILE = {
27
+ "low_health_literacy": "prompt_low",
28
+ "intermediate_health_literacy": "prompt_intermediate",
29
+ "proficient_health_literacy": "prompt_proficient",
30
+ }
31
+
32
+
33
+ def load_prompts() -> Dict[str, str]:
34
+ """Load prompt templates from prompt_bn directory."""
35
+ prompts: Dict[str, str] = {}
36
+ for label, filename in LABEL_TO_PROMPT_FILE.items():
37
+ path = os.path.join(PROMPT_DIR, filename)
38
+ if os.path.isfile(path):
39
+ with open(path, "r", encoding="utf-8") as f:
40
+ prompts[label] = f.read()
41
+ else:
42
+ raise FileNotFoundError(f"Prompt file not found: {path}")
43
+ return prompts
44
+
45
+
46
+ def build_user_message(
47
+ prompt_template: str,
48
+ full_text: str,
49
+ gold_summary: str,
50
+ source_lang: str = SOURCE_LANG,
51
+ ) -> str:
52
+ """Fill prompt template with full_text, gold_summary, source_lang."""
53
+ return (
54
+ prompt_template.replace("{full_text}", full_text)
55
+ .replace("{gold_summary}", gold_summary)
56
+ .replace("{source_lang}", source_lang)
57
+ )
58
+
59
+
60
+ def parse_args() -> argparse.Namespace:
61
+ p = argparse.ArgumentParser(
62
+ description="Run inference with finetuned Qwen3-4B Bangla model on test_bn.json."
63
+ )
64
+ p.add_argument(
65
+ "--max-new-tokens",
66
+ type=int,
67
+ default=512,
68
+ help="Maximum number of new tokens to generate per sample.",
69
+ )
70
+ p.add_argument(
71
+ "--temperature",
72
+ type=float,
73
+ default=0.7,
74
+ help="Sampling temperature.",
75
+ )
76
+ p.add_argument(
77
+ "--top-p",
78
+ type=float,
79
+ default=0.9,
80
+ help="Top-p (nucleus) sampling value.",
81
+ )
82
+ p.add_argument(
83
+ "--output-json",
84
+ type=str,
85
+ default="test_bn_qwen3-4B_sft_inference.json",
86
+ help=(
87
+ "Output JSON filename (saved under results/bn). "
88
+ "If it already exists, a timestamp suffix is appended instead of overwriting."
89
+ ),
90
+ )
91
+ return p.parse_args()
92
+
93
+
94
+ def load_model_and_tokenizer(model_dir: str):
95
+ """Load the merged finetuned model and tokenizer for inference."""
96
+ if not os.path.isdir(model_dir):
97
+ raise FileNotFoundError(
98
+ f"Finetuned model directory not found: {model_dir}. "
99
+ "Make sure qwen3-finetune_bn.py was run with model saving enabled."
100
+ )
101
+
102
+ print(f"Loading tokenizer from {model_dir}")
103
+ tokenizer = AutoTokenizer.from_pretrained(model_dir)
104
+ if tokenizer.pad_token is None:
105
+ tokenizer.pad_token = tokenizer.eos_token
106
+
107
+ print(f"Loading model from {model_dir}")
108
+ if torch.cuda.is_available():
109
+ model = AutoModelForCausalLM.from_pretrained(
110
+ model_dir,
111
+ torch_dtype=torch.bfloat16,
112
+ device_map="auto",
113
+ )
114
+ else:
115
+ model = AutoModelForCausalLM.from_pretrained(model_dir)
116
+
117
+ model.eval()
118
+ return model, tokenizer
119
+
120
+
121
+ def run_inference(
122
+ model,
123
+ tokenizer,
124
+ test_items: List[Dict[str, Any]],
125
+ prompts: Dict[str, str],
126
+ max_new_tokens: int,
127
+ temperature: float,
128
+ top_p: float,
129
+ ) -> List[Dict[str, Any]]:
130
+ """Generate adapted texts for each test item."""
131
+ results: List[Dict[str, Any]] = []
132
+
133
+ device = next(model.parameters()).device
134
+
135
+ for idx, item in enumerate(test_items):
136
+ label = item.get("label")
137
+ fulltext = item.get("fulltext", "")
138
+ summary = item.get("summary", "")
139
+
140
+ if not fulltext or label not in prompts:
141
+ # Keep the original item, but note that generation was skipped.
142
+ out_item = dict(item)
143
+ out_item["model_gen_text"] = ""
144
+ out_item["model_gen_skipped"] = True
145
+ results.append(out_item)
146
+ continue
147
+
148
+ user_msg = build_user_message(prompts[label], fulltext, summary)
149
+ messages = [{"role": "user", "content": user_msg}]
150
+
151
+ text = tokenizer.apply_chat_template(
152
+ messages,
153
+ add_generation_prompt=True,
154
+ tokenize=False,
155
+ )
156
+ inputs = tokenizer(text, return_tensors="pt").to(device)
157
+ input_len = inputs["input_ids"].shape[-1]
158
+
159
+ with torch.no_grad():
160
+ generated_ids = model.generate(
161
+ **inputs,
162
+ max_new_tokens=max_new_tokens,
163
+ do_sample=True,
164
+ temperature=temperature,
165
+ top_p=top_p,
166
+ pad_token_id=tokenizer.eos_token_id,
167
+ )
168
+
169
+ gen_ids = generated_ids[0, input_len:]
170
+ gen_text = tokenizer.decode(gen_ids, skip_special_tokens=True).strip()
171
+
172
+ out_item = dict(item)
173
+ out_item["model_gen_text"] = gen_text
174
+ out_item["model_name"] = "qwen3-4B_sft_bn"
175
+ out_item["model_max_new_tokens"] = max_new_tokens
176
+ out_item["model_temperature"] = temperature
177
+ out_item["model_top_p"] = top_p
178
+ results.append(out_item)
179
+
180
+ if (idx + 1) % 10 == 0:
181
+ print(f"Processed {idx + 1} / {len(test_items)} samples")
182
+
183
+ return results
184
+
185
+
186
+ def main():
187
+ args = parse_args()
188
+
189
+ os.makedirs(RESULTS_DIR, exist_ok=True)
190
+
191
+ print("Loading prompts from", PROMPT_DIR)
192
+ prompts = load_prompts()
193
+
194
+ print("Loading test data from", TEST_JSON)
195
+ with open(TEST_JSON, "r", encoding="utf-8") as f:
196
+ test_items = json.load(f)
197
+
198
+ print(f"Test samples: {len(test_items)}")
199
+
200
+ model, tokenizer = load_model_and_tokenizer(MODEL_SAVE_DIR)
201
+
202
+ print(
203
+ f"Running inference with max_new_tokens={args.max_new_tokens}, "
204
+ f"temperature={args.temperature}, top_p={args.top_p}"
205
+ )
206
+ results = run_inference(
207
+ model=model,
208
+ tokenizer=tokenizer,
209
+ test_items=test_items,
210
+ prompts=prompts,
211
+ max_new_tokens=args.max_new_tokens,
212
+ temperature=args.temperature,
213
+ top_p=args.top_p,
214
+ )
215
+
216
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
217
+ output_filename = args.output_json
218
+ if not output_filename.endswith(".json"):
219
+ output_filename += ".json"
220
+
221
+ output_path = os.path.join(RESULTS_DIR, output_filename)
222
+
223
+ # If the filename already exists, append a timestamp to avoid silent overwrite.
224
+ if os.path.exists(output_path):
225
+ name, ext = os.path.splitext(output_filename)
226
+ output_filename = f"{name}_{timestamp}{ext}"
227
+ output_path = os.path.join(RESULTS_DIR, output_filename)
228
+
229
+ print("Saving results to", output_path)
230
+ with open(output_path, "w", encoding="utf-8") as f:
231
+ json.dump(results, f, ensure_ascii=False, indent=2)
232
+
233
+ print("Done.")
234
+
235
+
236
+ if __name__ == "__main__":
237
+ main()
238
+
code/fine_tune_sft_dpo/qwen3_infer_bn.py ADDED
@@ -0,0 +1,247 @@
1
+ """
2
+ Run inference with the finetuned Bangla Qwen3 model on test_bn.json
3
+ and save the generation results under results/bn.
4
+ """
5
+ import os
6
+ os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
7
+ os.environ["CUDA_VISIBLE_DEVICES"] = "5"
8
+ import argparse
9
+ import json
10
+ import os
11
+ import re
12
+ from datetime import datetime
13
+ from typing import Any, Dict, List
14
+
15
+ import torch
16
+ from transformers import AutoModelForCausalLM, AutoTokenizer
17
+
18
+
19
+ def strip_think_blocks(text: str) -> str:
20
+ """Remove <think>...</think> reasoning blocks from model output."""
21
+ cleaned = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL).strip()
22
+ return cleaned if cleaned else text
23
+
24
+
25
+ # Paths (keep in sync with qwen3-finetune_bn.py)
26
+ BASE_DIR = "/home/mshahidul/readctrl/code/fine_tune_sft_dpo"
27
+ MODEL_SAVE_DIR = "/home/mshahidul/readctrl/code/fine_tune_sft_dpo/model/bn"
28
+ PROMPT_DIR = os.path.join(BASE_DIR, "prompt_bn")
29
+ TEST_JSON = "/home/mshahidul/readctrl/code/fine_tune_sft_dpo/dataset/bn/test_bn.json"
30
+ RESULTS_DIR = os.path.join(BASE_DIR, "results", "bn")
31
+ SOURCE_LANG = "Bangla"
32
+
33
+ LABEL_TO_PROMPT_FILE = {
34
+ "low_health_literacy": "prompt_low",
35
+ "intermediate_health_literacy": "prompt_intermediate",
36
+ "proficient_health_literacy": "prompt_proficient",
37
+ }
38
+
39
+
40
+ def load_prompts() -> Dict[str, str]:
41
+ """Load prompt templates from prompt_bn directory."""
42
+ prompts: Dict[str, str] = {}
43
+ for label, filename in LABEL_TO_PROMPT_FILE.items():
44
+ path = os.path.join(PROMPT_DIR, filename)
45
+ if os.path.isfile(path):
46
+ with open(path, "r", encoding="utf-8") as f:
47
+ prompts[label] = f.read()
48
+ else:
49
+ raise FileNotFoundError(f"Prompt file not found: {path}")
50
+ return prompts
51
+
52
+
53
+ def build_user_message(
54
+ prompt_template: str,
55
+ full_text: str,
56
+ gold_summary: str,
57
+ source_lang: str = SOURCE_LANG,
58
+ ) -> str:
59
+ """Fill prompt template with full_text, gold_summary, source_lang."""
60
+ return (
61
+ prompt_template.replace("{full_text}", full_text)
62
+ .replace("{gold_summary}", gold_summary)
63
+ .replace("{source_lang}", source_lang)
64
+ )
65
+
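For illustration, with a hypothetical template and sample strings containing the three placeholders this function substitutes (source_lang defaults to SOURCE_LANG, i.e. "Bangla"):

    template = "Adapt the following {source_lang} text:\n{full_text}\n\nReference summary:\n{gold_summary}"
    msg = build_user_message(template, "রোগ সম্পর্কে মূল নিবন্ধ...", "সংক্ষিপ্তসার...")
    # msg now contains the full text and gold summary in place of the placeholders,
    # with "{source_lang}" replaced by "Bangla".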
66
+
67
+ def parse_args() -> argparse.Namespace:
68
+ p = argparse.ArgumentParser(
69
+ description="Run inference with finetuned Qwen3-4B Bangla model on test_bn.json."
70
+ )
71
+ p.add_argument(
72
+ "--max-new-tokens",
73
+ type=int,
74
+ default=512,
75
+ help="Maximum number of new tokens to generate per sample.",
76
+ )
77
+ p.add_argument(
78
+ "--temperature",
79
+ type=float,
80
+ default=0.7,
81
+ help="Sampling temperature.",
82
+ )
83
+ p.add_argument(
84
+ "--top-p",
85
+ type=float,
86
+ default=0.9,
87
+ help="Top-p (nucleus) sampling value.",
88
+ )
89
+ p.add_argument(
90
+ "--output-json",
91
+ type=str,
92
+ default="test_bn_qwen3-4B_sft_inference.json",
93
+ help=(
94
+ "Output JSON filename (saved under results/bn). "
95
+ "If it already exists, it will be overwritten."
96
+ "If it already exists, a timestamp suffix is appended to avoid overwriting."
97
+ )
98
+ return p.parse_args()
99
+
100
+
101
+ def load_model_and_tokenizer(model_dir: str):
102
+ """Load the merged finetuned model and tokenizer for inference."""
103
+ if not os.path.isdir(model_dir):
104
+ raise FileNotFoundError(
105
+ f"Finetuned model directory not found: {model_dir}. "
106
+ "Make sure qwen3-finetune_bn.py was run with model saving enabled."
107
+ )
108
+
109
+ print(f"Loading tokenizer from {model_dir}")
110
+ tokenizer = AutoTokenizer.from_pretrained(model_dir)
111
+ if tokenizer.pad_token is None:
112
+ tokenizer.pad_token = tokenizer.eos_token
113
+
114
+ print(f"Loading model from {model_dir}")
115
+ if torch.cuda.is_available():
116
+ model = AutoModelForCausalLM.from_pretrained(
117
+ model_dir,
118
+ torch_dtype=torch.bfloat16,
119
+ device_map="auto",
120
+ )
121
+ else:
122
+ model = AutoModelForCausalLM.from_pretrained(model_dir)
123
+
124
+ model.eval()
125
+ return model, tokenizer
126
+
127
+
128
+ def run_inference(
129
+ model,
130
+ tokenizer,
131
+ test_items: List[Dict[str, Any]],
132
+ prompts: Dict[str, str],
133
+ max_new_tokens: int,
134
+ temperature: float,
135
+ top_p: float,
136
+ ) -> List[Dict[str, Any]]:
137
+ """Generate adapted texts for each test item."""
138
+ results: List[Dict[str, Any]] = []
139
+
140
+ device = next(model.parameters()).device
141
+
142
+ for idx, item in enumerate(test_items):
143
+ label = item.get("label")
144
+ fulltext = item.get("fulltext", "")
145
+ summary = item.get("summary", "")
146
+
147
+ if not fulltext or label not in prompts:
148
+ # Keep the original item, but note that generation was skipped.
149
+ out_item = dict(item)
150
+ out_item["model_gen_text"] = ""
151
+ out_item["model_gen_skipped"] = True
152
+ results.append(out_item)
153
+ continue
154
+
155
+ user_msg = build_user_message(prompts[label], fulltext, summary)
156
+ messages = [{"role": "user", "content": user_msg}]
157
+
158
+ text = tokenizer.apply_chat_template(
159
+ messages,
160
+ add_generation_prompt=True,
161
+ tokenize=False,
162
+ enable_thinking=False,
163
+ )
164
+ inputs = tokenizer(text, return_tensors="pt").to(device)
165
+ input_len = inputs["input_ids"].shape[-1]
166
+
167
+ with torch.no_grad():
168
+ generated_ids = model.generate(
169
+ **inputs,
170
+ max_new_tokens=max_new_tokens,
171
+ do_sample=True,
172
+ temperature=temperature,
173
+ top_p=top_p,
174
+ pad_token_id=tokenizer.eos_token_id,
175
+ )
176
+
177
+ gen_ids = generated_ids[0, input_len:]
178
+ gen_text = tokenizer.decode(gen_ids, skip_special_tokens=True).strip()
179
+ gen_text = strip_think_blocks(gen_text)
180
+
181
+ out_item = dict(item)
182
+ out_item["model_gen_text"] = gen_text
183
+ out_item["model_name"] = "qwen3-4B_sft_bn"
184
+ out_item["model_max_new_tokens"] = max_new_tokens
185
+ out_item["model_temperature"] = temperature
186
+ out_item["model_top_p"] = top_p
187
+ results.append(out_item)
188
+
189
+ if (idx + 1) % 10 == 0:
190
+ print(f"Processed {idx + 1} / {len(test_items)} samples")
191
+
192
+ return results
193
+
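Each result mirrors its input item plus the generation metadata added above; a sketch of the expected shape, with illustrative values:

    example_output_item = {
        "label": "low_health_literacy",
        "fulltext": "...",                 # original document from test_bn.json
        "summary": "...",                  # gold summary carried over unchanged
        "model_gen_text": "...",           # generated adaptation (think blocks stripped)
        "model_name": "qwen3-4B_sft_bn",
        "model_max_new_tokens": 512,
        "model_temperature": 0.7,
        "model_top_p": 0.9,
    }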
194
+
195
+ def main():
196
+ args = parse_args()
197
+
198
+ os.makedirs(RESULTS_DIR, exist_ok=True)
199
+
200
+ print("Loading prompts from", PROMPT_DIR)
201
+ prompts = load_prompts()
202
+
203
+ print("Loading test data from", TEST_JSON)
204
+ with open(TEST_JSON, "r", encoding="utf-8") as f:
205
+ test_items = json.load(f)
206
+
207
+ print(f"Test samples: {len(test_items)}")
208
+
209
+ model, tokenizer = load_model_and_tokenizer(MODEL_SAVE_DIR)
210
+
211
+ print(
212
+ f"Running inference with max_new_tokens={args.max_new_tokens}, "
213
+ f"temperature={args.temperature}, top_p={args.top_p}"
214
+ )
215
+ results = run_inference(
216
+ model=model,
217
+ tokenizer=tokenizer,
218
+ test_items=test_items,
219
+ prompts=prompts,
220
+ max_new_tokens=args.max_new_tokens,
221
+ temperature=args.temperature,
222
+ top_p=args.top_p,
223
+ )
224
+
225
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
226
+ output_filename = args.output_json
227
+ if not output_filename.endswith(".json"):
228
+ output_filename += ".json"
229
+
230
+ output_path = os.path.join(RESULTS_DIR, output_filename)
231
+
232
+ # If the filename already exists, append a timestamp to avoid silent overwrite.
233
+ if os.path.exists(output_path):
234
+ name, ext = os.path.splitext(output_filename)
235
+ output_filename = f"{name}_{timestamp}{ext}"
236
+ output_path = os.path.join(RESULTS_DIR, output_filename)
237
+
238
+ print("Saving results to", output_path)
239
+ with open(output_path, "w", encoding="utf-8") as f:
240
+ json.dump(results, f, ensure_ascii=False, indent=2)
241
+
242
+ print("Done.")
243
+
244
+
245
+ if __name__ == "__main__":
246
+ main()
247
+
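Once the script has run, the saved JSON can be spot-checked with a few lines like the following (the path is illustrative; it matches the default --output-json name under results/bn):

    import json

    with open("results/bn/test_bn_qwen3-4B_sft_inference.json", "r", encoding="utf-8") as f:
        results = json.load(f)

    print(len(results), "items")
    first = results[0]
    print(first.get("label"), "->", first.get("model_gen_text", "")[:100])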
code/fine_tune_sft_dpo/results/bn/gpt5_inference_all_20260314_110445.jsonl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:72cc6e0f72af25f7b56ac4c623d23be4d57d35e09473e4a012eff27680942e85
3
+ size 16659336
code/fine_tune_sft_dpo/results/bn/gpt5_inference_all_wo_gs_20260314_173736.jsonl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:195bb9580d98d1a43f642f06f2c982397336c42cb0cf68bc641138997b699f9e
3
+ size 17385349
code/fine_tune_sft_dpo/results/bn/gpt5_inference_gpt-5-mini_20260314_110445.jsonl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:20bf9c12039c4bc33421452806e48ddc1ef7dca13403de06b95642cdf27e4334
3
+ size 5984175
code/fine_tune_sft_dpo/results/bn/gpt5_inference_gpt-5-mini_20260314_173736.jsonl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:49eef179000c958ac8986ee7a3fe0a79d44714b35e9a2a09ca68402eda9cd37f
3
+ size 6234645
code/fine_tune_sft_dpo/results/bn/gpt5_inference_gpt-5-nano_20260314_110445.jsonl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5a2a2acef3f83a916faee240a310e9595f6b6cb043c7e1081f014fce80c5d836
3
+ size 5237182
code/fine_tune_sft_dpo/results/bn/gpt5_inference_gpt-5-nano_20260314_173736.jsonl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b5439fb64537f7f159f61b97f46f4f400238fee1c65fd08f7b514ffc46f9160a
3
+ size 5355465
code/fine_tune_sft_dpo/results/bn/gpt5_inference_gpt-5_20260314_110445.jsonl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fbd39a27ea21e26c4d1ba7d4ef85c41964713b655ff03574c5d260ec7b94a5e4
3
+ size 5437979
code/fine_tune_sft_dpo/results/bn/gpt5_inference_gpt-5_20260314_173736.jsonl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3f16ff37f4c633d22a245b087dfb14583f14c3e27def8d896811e31a61af45e5
3
+ size 5795239
code/fine_tune_sft_dpo/results/bn/gpt5_inference_summary_20260314_110445.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3a0b61606edceb252580fdb18a15ac237d5d69bd896bea1d430cad5a6113a152
3
+ size 1684
code/fine_tune_sft_dpo/results/bn/gpt5_inference_summary_wo_gs_20260314_173736.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:eb177630eb8e44f478897e4e96af2ead55760791e59605ce091ff49c8bb7f403
3
+ size 1690
code/fine_tune_sft_dpo/results/bn/inference_summary_vllm_20260314_101627.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b5d5cf9a06af93f982c38c5901fbd8478e4bce6ed6bdb3b5480ad1ff586187c2
3
+ size 471
code/fine_tune_sft_dpo/results/bn/{inference_summary_vllm_20260311_044629.json → misc/inference_summary_vllm_20260311_044629.json} RENAMED
File without changes