import os
# Pin the GPU selection BEFORE importing torch/unsloth: CUDA reads these
# environment variables at initialization, so they must be set first.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
import torch
from unsloth import FastLanguageModel
import json
import tqdm
import re
# -----------------------------
# MODEL CACHE
# -----------------------------
# Process-wide singleton cache so the (slow) model load happens at most once;
# populated by load_finetuned_model().
_model_cache = {"model": None, "tokenizer": None}
def load_finetuned_model(model_path: str):
    """Load (or reuse) the fine-tuned classifier model and tokenizer.

    The pair is memoized in the module-level ``_model_cache``. The cache also
    records which path it was loaded from, so a later call with a *different*
    ``model_path`` triggers a fresh load instead of silently returning the
    previously cached model (the original ignored ``model_path`` once
    anything was cached).

    Args:
        model_path: Filesystem path or hub name of the fine-tuned model.

    Returns:
        ``(model, tokenizer)`` ready for inference.
    """
    if _model_cache["model"] is not None and _model_cache.get("path") == model_path:
        return _model_cache["model"], _model_cache["tokenizer"]
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_path,
        max_seq_length=8192,
        load_in_4bit=False,  # Set to True if you want 4bit inference for speed/memory
        load_in_8bit=False,
        full_finetuning=False,
    )
    # Enable native 2x faster inference
    FastLanguageModel.for_inference(model)
    _model_cache["model"], _model_cache["tokenizer"] = model, tokenizer
    _model_cache["path"] = model_path
    return model, tokenizer
# -----------------------------
# READABILITY CLASSIFICATION PROMPT
# -----------------------------
def classification_prompt(full_text: str, summary: str) -> str:
    """Build the readability-classification prompt.

    Fills a fixed instruction template with the source document and the
    generated summary; the model is asked to return a single JSON object
    with an integer score from 1 (very easy) to 5 (very hard).
    """
    template = """You are a medical readability evaluator.
### Task
Compare the "GENERATED TEXT" against the "FULL TEXT" to determine its readability for a general, non-medical audience.
### Input Data
- **FULL TEXT:** {full_text}
- **GENERATED TEXT (Evaluate this):** {summary}
### Readability Scale
1: Very Easy - Minimal medical language, uses simple terms.
2: Easy - Accessible to most, minor jargon explained.
3: Medium - Some technical terms, moderate complexity.
4: Hard - Clinical tone, assumes some prior knowledge.
5: Very Hard - Extremely technical, requires medical expertise.
### Constraints
- Evaluate ONLY the "GENERATED TEXT".
- Use "FULL TEXT" only for context of the subject matter.
- Do NOT assess factual accuracy.
### Output Format
Return ONLY a valid JSON object:
{{
"readability_score": <integer_1_to_5>
}}"""
    return template.format(full_text=full_text, summary=summary)
# -----------------------------
# INFERENCE FUNCTION
# -----------------------------
def infer_readability(full_text: str,
                      summary: str,
                      model_path: str) -> dict:
    """Classify the readability (1-5) of *summary* given *full_text* context.

    Runs the fine-tuned classifier with greedy decoding and parses the JSON
    object out of the model output.

    Args:
        full_text: Source clinical document (context only, not evaluated).
        summary: Generated text to score.
        model_path: Path to the classifier model (loaded once, then cached).

    Returns:
        Parsed dict such as ``{"readability_score": 3}``; on parse failure,
        ``{"readability_score": "error", "raw": <model output>}``.
    """
    model, tokenizer = load_finetuned_model(model_path)
    prompt = classification_prompt(full_text, summary)
    messages = [{"role": "user", "content": prompt}]
    chat_text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    # Use the model's own device rather than hard-coding "cuda", so this
    # still works if the model was placed on a different device.
    inputs = tokenizer(chat_text, return_tensors="pt").to(model.device)
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=50,  # Classification only needs a few tokens
            # `temperature` was removed: it is ignored (with a transformers
            # warning) when do_sample=False, which is already deterministic.
            do_sample=False,
        )
    # Decode only the newly generated tokens (skip the echoed prompt).
    output_text = tokenizer.decode(
        output_ids[0][len(inputs.input_ids[0]):], skip_special_tokens=True
    ).strip()
    # Clean up output (remove thinking or markdown)
    if "</think>" in output_text:
        output_text = output_text.split("</think>")[-1].strip()
    # Simple regex to extract JSON if the model adds conversational filler
    try:
        match = re.search(r"\{.*\}", output_text, re.DOTALL)
        if match:
            return json.loads(match.group())
        return {"readability_score": "error", "raw": output_text}
    except Exception:
        return {"readability_score": "error", "raw": output_text}
# -----------------------------
# MAIN EXECUTION
# -----------------------------
if __name__ == "__main__":
    # --- Configuration ---
    INPUT_FILE = "/home/mshahidul/readctrl/data/processed_raw_data/multiclinsum_test_en.json"
    SAVE_FOLDER = "/home/mshahidul/readctrl/data/classified_readability"
    # Note: Ensure this path points to your CLASSIFIER model, not the subclaim extractor
    MODEL_PATH = "/home/mshahidul/readctrl_model/qwen3-32B_classifier_en"
    os.makedirs(SAVE_FOLDER, exist_ok=True)
    # splitext is more robust than split(".json") for deriving the stem.
    file_name = os.path.splitext(os.path.basename(INPUT_FILE))[0]
    OUTPUT_FILE = os.path.join(SAVE_FOLDER, f"classified_{file_name}.json")
    # Load input dataset
    with open(INPUT_FILE, "r", encoding="utf-8") as f:
        data = json.load(f)
    # Resume mode: reload prior results if a previous run was interrupted.
    result = []
    if os.path.exists(OUTPUT_FILE):
        with open(OUTPUT_FILE, "r", encoding="utf-8") as f:
            result = json.load(f)
    # BUGFIX: build existing_ids unconditionally so a fresh run (no output
    # file yet) does not hit a NameError at the skip check below.
    existing_ids = {item["id"] for item in result}
    print(f"Starting classification. Saving to: {OUTPUT_FILE}")
    for item in tqdm.tqdm(data):
        if item["id"] in existing_ids:
            continue
        full_text = item.get("fulltext", "")
        summary = item.get("summary", "")
        classification_res = infer_readability(
            full_text=full_text,
            summary=summary,
            model_path=MODEL_PATH
        )
        result.append({
            "id": item["id"],
            "readability_score": classification_res.get("readability_score"),
            "fulltext": full_text,
            "summary": summary
        })
        # Checkpoint every 50 items so an interrupted run loses little work.
        if len(result) % 50 == 0:
            with open(OUTPUT_FILE, "w", encoding="utf-8") as f:
                json.dump(result, f, indent=4, ensure_ascii=False)
    # Final save
    with open(OUTPUT_FILE, "w", encoding="utf-8") as f:
        json.dump(result, f, indent=4, ensure_ascii=False)
    print(f"Classification completed. {len(result)} items processed.")