File size: 5,491 Bytes
c7a6fe6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

import torch
from unsloth import FastLanguageModel
import json
import tqdm
import re

# -----------------------------
#  MODEL CACHE
# -----------------------------
# Process-wide singleton cache so the (large) model and tokenizer are loaded
# once and reused across infer_readability() calls within this process.
_model_cache = {"model": None, "tokenizer": None}

def load_finetuned_model(model_path: str):
    """Load the fine-tuned classifier model, reusing a process-wide cache.

    Args:
        model_path: Path (or hub name) of the fine-tuned model to load.

    Returns:
        (model, tokenizer) tuple, with the model switched to unsloth's
        inference mode.

    Note:
        The original cache ignored ``model_path`` entirely, so a second call
        with a *different* path silently returned the first model. The cache
        now remembers which path it holds and reloads when the path changes.
    """
    if _model_cache["model"] is not None and _model_cache.get("path") == model_path:
        return _model_cache["model"], _model_cache["tokenizer"]

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_path,
        max_seq_length=8192,
        load_in_4bit=False,  # Set to True if you want 4bit inference for speed/memory
        load_in_8bit=False,
        full_finetuning=False,
    )
    # Enable native 2x faster inference
    FastLanguageModel.for_inference(model)

    _model_cache["model"], _model_cache["tokenizer"] = model, tokenizer
    _model_cache["path"] = model_path
    return model, tokenizer

# -----------------------------
#  READABILITY CLASSIFICATION PROMPT
# -----------------------------
def classification_prompt(full_text: str, summary: str) -> str:
    """Build the readability-classification prompt for the evaluator model.

    The FULL TEXT supplies subject-matter context only; the model is asked to
    score the GENERATED TEXT (the summary) on a 1-5 readability scale and to
    reply with a single JSON object.
    """
    template = """You are a medical readability evaluator.

### Task
Compare the "GENERATED TEXT" against the "FULL TEXT" to determine its readability for a general, non-medical audience.

### Input Data
- **FULL TEXT:** {full_text}
- **GENERATED TEXT (Evaluate this):** {summary}

### Readability Scale
1: Very Easy - Minimal medical language, uses simple terms.
2: Easy - Accessible to most, minor jargon explained.
3: Medium - Some technical terms, moderate complexity.
4: Hard - Clinical tone, assumes some prior knowledge.
5: Very Hard - Extremely technical, requires medical expertise.

### Constraints
- Evaluate ONLY the "GENERATED TEXT".
- Use "FULL TEXT" only for context of the subject matter.
- Do NOT assess factual accuracy.

### Output Format
Return ONLY a valid JSON object:
{{
  "readability_score": <integer_1_to_5>
}}"""
    return template.format(full_text=full_text, summary=summary)

# -----------------------------
#  INFERENCE FUNCTION
# -----------------------------
def infer_readability(full_text: str,
                      summary: str,
                      model_path: str) -> dict:
    """Classify the readability of `summary` with the fine-tuned model.

    Args:
        full_text: Source document, used only as subject-matter context.
        summary: Generated text whose readability is being scored.
        model_path: Path of the fine-tuned classifier model.

    Returns:
        The parsed JSON dict from the model, e.g. {"readability_score": 3}.
        On parse failure: {"readability_score": "error", "raw": <output>}.
    """
    model, tokenizer = load_finetuned_model(model_path)
    prompt = classification_prompt(full_text, summary)

    messages = [{"role": "user", "content": prompt}]

    chat_text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    # Place inputs on the model's own device rather than hard-coding "cuda":
    # the hard-coded string crashed when the model was sharded or mapped to a
    # different device.
    inputs = tokenizer(chat_text, return_tensors="pt").to(model.device)

    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=50,  # Classification only needs a few tokens
            do_sample=False,    # Greedy decoding; a temperature here would be
                                # ignored (and warned about) by transformers.
        )

    # Decode only the newly generated tail, skipping the prompt tokens.
    prompt_len = inputs.input_ids.shape[1]
    output_text = tokenizer.decode(output_ids[0][prompt_len:], skip_special_tokens=True).strip()

    # Clean up output: drop any "<think>...</think>" reasoning prefix.
    if "</think>" in output_text:
        output_text = output_text.split("</think>")[-1].strip()

    # Extract the first {...} span in case the model adds conversational
    # filler; only JSON decoding errors are expected here, so catch narrowly.
    match = re.search(r"\{.*\}", output_text, re.DOTALL)
    if match:
        try:
            return json.loads(match.group())
        except json.JSONDecodeError:
            pass
    return {"readability_score": "error", "raw": output_text}

# -----------------------------
#  MAIN EXECUTION
# -----------------------------
if __name__ == "__main__":
    # Settings based on your paths
    INPUT_FILE = "/home/mshahidul/readctrl/data/processed_raw_data/multiclinsum_test_en.json"
    SAVE_FOLDER = "/home/mshahidul/readctrl/data/classified_readability"
    # Note: Ensure this path points to your CLASSIFIER model, not the subclaim extractor
    MODEL_PATH = "/home/mshahidul/readctrl_model/qwen3-32B_classifier_en"

    os.makedirs(SAVE_FOLDER, exist_ok=True)
    # splitext is robust even if ".json" ever appears mid-filename, which
    # would have truncated the name with the old split(".json")[0] approach.
    file_name = os.path.splitext(os.path.basename(INPUT_FILE))[0]
    OUTPUT_FILE = os.path.join(SAVE_FOLDER, f"classified_{file_name}.json")

    # Load input dataset. Explicit UTF-8 everywhere: the dumps below use
    # ensure_ascii=False, so relying on the locale default encoding could
    # corrupt or crash on non-UTF-8 systems.
    with open(INPUT_FILE, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Resume mode: reload previous results and skip already-classified ids.
    result = []
    if os.path.exists(OUTPUT_FILE):
        with open(OUTPUT_FILE, "r", encoding="utf-8") as f:
            result = json.load(f)

    existing_ids = {item["id"] for item in result}

    print(f"Starting classification. Saving to: {OUTPUT_FILE}")

    for item in tqdm.tqdm(data):
        if item["id"] in existing_ids:
            continue

        full_text = item.get("fulltext", "")
        summary = item.get("summary", "")

        classification_res = infer_readability(
            full_text=full_text,
            summary=summary,
            model_path=MODEL_PATH
        )

        result.append({
            "id": item["id"],
            "readability_score": classification_res.get("readability_score"),
            "fulltext": full_text,
            "summary": summary
        })

        # Checkpoint every 50 items
        if len(result) % 50 == 0:
            with open(OUTPUT_FILE, "w", encoding="utf-8") as f:
                json.dump(result, f, indent=4, ensure_ascii=False)

    # Final save
    with open(OUTPUT_FILE, "w", encoding="utf-8") as f:
        json.dump(result, f, indent=4, ensure_ascii=False)

    print(f"Classification completed. {len(result)} items processed.")