File size: 7,163 Bytes
9c6961c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "4"
import torch
from unsloth import FastLanguageModel
import json

# ===========================
# GPU SETTINGS
# ===========================


# ===========================
# MODEL LOADING (CACHED)
# ===========================
# Cache of loaded (model, tokenizer) pairs, keyed by model path.
# Keying by path (rather than a single slot) means a later call with a
# different checkpoint loads that checkpoint instead of silently
# returning whichever model happened to be loaded first.
_model_cache: dict = {}

def load_finetuned_model(model_path: str):
    """Load and cache the fine-tuned model + tokenizer for *model_path*.

    Parameters
    ----------
    model_path : str
        Local path (or hub id) of the fine-tuned checkpoint.

    Returns
    -------
    tuple
        ``(model, tokenizer)`` as produced by Unsloth's
        ``FastLanguageModel.from_pretrained``. Repeated calls with the
        same path reuse the cached pair instead of reloading.
    """
    if model_path in _model_cache:
        return _model_cache[model_path]

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_path,
        max_seq_length=4096,       # matches the 4k context used at inference
        load_in_4bit=False,
        load_in_8bit=False,
        full_finetuning=False,
    )
    _model_cache[model_path] = (model, tokenizer)
    return model, tokenizer


# ===========================
# INFERENCE FUNCTION
# ===========================
def infer_reasonableness(
    reference_summary: str,
    generated_summary: str,
    readability_level: str,
    subclaim_text: str,
    result: int,
    model_path: str = "/home/mshahidul/readctrl_model/qwen3-32B_subclaims-completeness_resonability_check_8kCtx_v3",
):
    """
    Given the reference summary, generated summary, readability level, subclaim, and its result (0/1),
    predict reasonableness: reasonable / partially_reasonable / unreasonable, plus justification.

    Returns the parsed JSON dict on success, or the raw decoded string when the
    model output is not valid JSON.
    """
    model, tokenizer = load_finetuned_model(model_path)

    # ---- Build inference prompt (same structure as training) ----
    prompt = f"""
You are an impartial medical summarization evaluator.

Goal:
Decide whether the inclusion or omission of ONE specific subclaim from the reference summary is *reasonable*, given the readability level of the generated summary.

Readability Criteria:
- Easy: for non-medical readers; emphasize main story and outcomes; omit numerical data, anatomy, and test details.
- Intermediate: for general educated readers; keep main findings but simplify phrasing.
- Hard: for clinical or technical readers; maintain diagnostic accuracy and essential quantitative or anatomic content.

Judging rules:
* Base your decision strictly on what appears in the generated summary.
* If result = 0 (subclaim omitted) and the omitted detail is clearly technical or numerical for the given level, choose "reasonable".
* If result = 0 and the subclaim is essential to the main story, choose "unreasonable".
* Stay consistent between `result`, justification, and readability level.

### Inputs
Readability Level: {readability_level}
Reference Summary: {reference_summary}
Generated Summary: {generated_summary}
Subclaim: "{subclaim_text}"
Result: {result}   # 1 = supported (included), 0 = omitted

### Task
Respond **only** with the following JSON object:

{{
  "reasonableness": "<reasonable | partially_reasonable | unreasonable>",
  "justification": "<short clear explanation>"
}}
""".strip()

    messages = [{"role": "user", "content": prompt + "\n"}]

    chat_text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,  # important for Unsloth chat template
    )

    # Use the model's own device rather than a hard-coded "cuda" string so the
    # call also works if the model was placed on a specific GPU or on CPU.
    inputs = tokenizer(chat_text, return_tensors="pt").to(model.device)

    # ---- Generate output (greedy) ----
    # NOTE: with do_sample=False, sampling knobs (temperature/top_p/top_k) are
    # ignored by transformers and only emit warnings, so they are omitted here.
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=150,
            do_sample=False,
        )

    output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
    # With enable_thinking=False the output usually contains NO "</think>" tag;
    # split(...)[-1] handles both the tagged and untagged cases, whereas the
    # previous split(...)[1] raised IndexError when the tag was absent.
    output_text = (
        output_text.split("</think>")[-1]
        .strip()
        .replace("```json", "")
        .replace("```", "")
    )
    # ---- Extract model JSON output ----
    try:
        parsed = json.loads(output_text)
    except (json.JSONDecodeError, ValueError):
        # Model did not emit valid JSON; fall back to returning the raw text
        # so the caller can inspect or post-process it.
        parsed = output_text
    return parsed


# ===========================
# EXAMPLE USAGE
# ===========================
if __name__ == "__main__":
    # Batch-evaluate reasonableness of omitted subclaims for the Spanish
    # MultiClinSum split, checkpointing results to disk every 10 items.
    import tqdm

    # ---- Reference summaries and full texts keyed by case id ----
    with open('/home/mshahidul/readctrl/data/testing_data_gs/multiclinsum_gs_train_es.json', 'r') as f:
        multiclinsum_gs_train_es_data = json.load(f)
    ref_summaries = {}
    fulltexts = {}
    for item in multiclinsum_gs_train_es_data:
        ref_summaries[item['id']] = item['summary']
        fulltexts[item['id']] = item['fulltext']

    # ---- Generated summaries keyed by (case id, readability version) ----
    generated_summaries = {}
    with open('/home/mshahidul/readctrl/data/hand_create_gpt5_other_model/synthetic_data_es_raw_592.json', 'r') as f:
        synthetic_data_es_raw_592 = json.load(f)
    for item in synthetic_data_es_raw_592:
        for version in ['easy', 'intermediate', 'hard']:
            generated_summaries[(item['id'], version)] = item['readability_versions'][version]['text']

    # ---- Subclaim verifier output to re-judge ----
    with open("/home/mshahidul/readctrl/results/dataset_quality_check/subclaim_verifier_results_100_qwen3-32B.json", 'r') as f:
        qwen3_32B_results = json.load(f)

    full_res = []
    save_path = "/home/mshahidul/readctrl/results/dataset_quality_check/completeness_resonability_check_100_qwen3-32B_v4.json"

    for idx, item in tqdm.tqdm(enumerate(qwen3_32B_results), total=len(qwen3_32B_results)):
        print(f"Processing item {idx + 1}/{len(qwen3_32B_results)}")
        reference_summary = ref_summaries[item['id']]
        fulltext = fulltexts[item['id']]
        generated_summary = generated_summaries[(item['id'], item['version'])]
        temp_res = []
        for item2 in item['completeness']['results']:
            subclaim_text = item2['subclaim']['subclaim']
            result = item2['result']
            # Skip supported (included) subclaims — only omissions are judged.
            # The verifier JSON may store the result as int 1 or string "1";
            # comparing via str() handles both (the old `result == "1"` check
            # never matched when the value was an int).
            if str(result) == "1":
                continue
            response = infer_reasonableness(
                reference_summary,
                generated_summary,
                item['version'],
                subclaim_text,
                result,
                model_path="/home/mshahidul/readctrl_model/qwen3-32B_subclaims-completeness_resonability_check_8kCtx_v3",
            )
            temp_res.append({
                'id': item2['subclaim']['id'],
                "subclaim": subclaim_text,
                "result": result,
                "reasonableness": response
            })
        full_res.append({
            "id": item['id'],
            "version": item['version'],
            "completeness": {
                "results": temp_res
            }
        })
        # Periodic checkpoint so a crash mid-run loses at most 10 items.
        if len(full_res) % 10 == 0:
            with open(save_path, 'w') as f:
                json.dump(full_res, f, indent=2, ensure_ascii=False)

    # Final save — previously this sat OUTSIDE the __main__ guard while
    # referencing save_path/full_res defined inside it, which raised
    # NameError whenever this module was imported rather than run.
    with open(save_path, 'w') as f:
        json.dump(full_res, f, indent=2, ensure_ascii=False)