# readctrl/code/classifier/classifier.py
# Uploaded by shahidul034 using the upload-large-folder tool (revision c7a6fe6, verified).
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
import torch
from unsloth import FastLanguageModel
import json
import tqdm
import re
# -----------------------------
# MODEL CACHE
# -----------------------------
# Module-level singleton cache so the (large) model/tokenizer pair is loaded
# at most once per process; populated lazily by load_finetuned_model().
_model_cache = {"model": None, "tokenizer": None}
def load_finetuned_model(model_path: str):
    """Load (or return the cached) fine-tuned model and tokenizer.

    The pair is cached in the module-level ``_model_cache`` so repeated calls
    in the same process do not reload the weights. BUG FIX: the cache is now
    keyed by ``model_path`` — previously any cached model was returned even
    when a *different* path was requested, silently serving the wrong model.

    Args:
        model_path: Local path or hub id of the fine-tuned classifier model.

    Returns:
        Tuple of (model, tokenizer) ready for inference.
    """
    if _model_cache["model"] is not None and _model_cache.get("path") == model_path:
        return _model_cache["model"], _model_cache["tokenizer"]
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_path,
        max_seq_length=8192,
        load_in_4bit=False,  # Set to True if you want 4bit inference for speed/memory
        load_in_8bit=False,
        full_finetuning=False,
    )
    # Enable native 2x faster inference
    FastLanguageModel.for_inference(model)
    _model_cache["model"], _model_cache["tokenizer"] = model, tokenizer
    _model_cache["path"] = model_path  # remember which weights are cached
    return model, tokenizer
# -----------------------------
# READABILITY CLASSIFICATION PROMPT
# -----------------------------
def classification_prompt(full_text: str, summary: str) -> str:
    """Build the readability-classification prompt for the evaluator model.

    Embeds the source document and the generated summary into a fixed
    instruction template that asks for a 1-5 readability score returned
    as a JSON object.
    """
    template = """You are a medical readability evaluator.
### Task
Compare the "GENERATED TEXT" against the "FULL TEXT" to determine its readability for a general, non-medical audience.
### Input Data
- **FULL TEXT:** {full_text}
- **GENERATED TEXT (Evaluate this):** {summary}
### Readability Scale
1: Very Easy - Minimal medical language, uses simple terms.
2: Easy - Accessible to most, minor jargon explained.
3: Medium - Some technical terms, moderate complexity.
4: Hard - Clinical tone, assumes some prior knowledge.
5: Very Hard - Extremely technical, requires medical expertise.
### Constraints
- Evaluate ONLY the "GENERATED TEXT".
- Use "FULL TEXT" only for context of the subject matter.
- Do NOT assess factual accuracy.
### Output Format
Return ONLY a valid JSON object:
{{
"readability_score": <integer_1_to_5>
}}"""
    return template.format(full_text=full_text, summary=summary)
# -----------------------------
# INFERENCE FUNCTION
# -----------------------------
def infer_readability(full_text: str,
                      summary: str,
                      model_path: str) -> dict:
    """Classify the readability (1-5) of `summary` with the fine-tuned model.

    Args:
        full_text: Source document, used only as subject-matter context.
        summary: Generated text whose readability is being scored.
        model_path: Path to the classifier model checkpoint.

    Returns:
        Parsed JSON dict, e.g. ``{"readability_score": 3}``; on a parse
        failure, ``{"readability_score": "error", "raw": <model output>}``.
    """
    model, tokenizer = load_finetuned_model(model_path)
    prompt = classification_prompt(full_text, summary)
    messages = [{"role": "user", "content": prompt}]
    chat_text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    # Use the model's own device rather than hard-coding "cuda" so the same
    # code also runs on CPU-only hosts or a different GPU index.
    inputs = tokenizer(chat_text, return_tensors="pt").to(model.device)
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=50,  # Classification only needs a few tokens
            do_sample=False,    # Greedy decoding: deterministic labels
            # BUG FIX: `temperature` removed — transformers ignores (and
            # warns about) temperature when do_sample=False.
        )
    # Decode only the newly generated tokens, skipping the echoed prompt.
    output_text = tokenizer.decode(
        output_ids[0][len(inputs.input_ids[0]):], skip_special_tokens=True
    ).strip()
    return _parse_score_json(output_text)


def _parse_score_json(output_text: str) -> dict:
    """Extract the readability JSON object from raw model output.

    Drops any chain-of-thought preamble (everything up to a trailing
    ``</think>`` tag), then pulls the first ``{...}`` span and parses it.
    """
    # Clean up output (remove thinking or markdown)
    if "</think>" in output_text:
        output_text = output_text.split("</think>")[-1].strip()
    # Simple regex to extract JSON if the model adds conversational filler
    try:
        match = re.search(r"\{.*\}", output_text, re.DOTALL)
        if match:
            return json.loads(match.group())
        return {"readability_score": "error", "raw": output_text}
    except json.JSONDecodeError:
        # Narrowed from bare `except Exception`: only json.loads can raise
        # here, and anything else should surface as a real error.
        return {"readability_score": "error", "raw": output_text}
# -----------------------------
# MAIN EXECUTION
# -----------------------------
if __name__ == "__main__":
    # Settings based on your paths
    INPUT_FILE = "/home/mshahidul/readctrl/data/processed_raw_data/multiclinsum_test_en.json"
    SAVE_FOLDER = "/home/mshahidul/readctrl/data/classified_readability"
    # Note: Ensure this path points to your CLASSIFIER model, not the subclaim extractor
    MODEL_PATH = "/home/mshahidul/readctrl_model/qwen3-32B_classifier_en"

    os.makedirs(SAVE_FOLDER, exist_ok=True)
    file_name = os.path.basename(INPUT_FILE).split(".json")[0]
    OUTPUT_FILE = os.path.join(SAVE_FOLDER, f"classified_{file_name}.json")

    # Load input dataset
    with open(INPUT_FILE, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Resume mode. BUG FIX: existing_ids must be bound even when no output
    # file exists yet — previously it was only assigned inside the
    # os.path.exists branch, so a fresh run raised NameError in the loop.
    result = []
    existing_ids = set()
    if os.path.exists(OUTPUT_FILE):
        with open(OUTPUT_FILE, "r", encoding="utf-8") as f:
            result = json.load(f)
        existing_ids = {item["id"] for item in result}

    print(f"Starting classification. Saving to: {OUTPUT_FILE}")
    for item in tqdm.tqdm(data):
        if item["id"] in existing_ids:
            continue  # already classified in a previous (resumed) run
        full_text = item.get("fulltext", "")
        summary = item.get("summary", "")
        classification_res = infer_readability(
            full_text=full_text,
            summary=summary,
            model_path=MODEL_PATH,
        )
        result.append({
            "id": item["id"],
            "readability_score": classification_res.get("readability_score"),
            "fulltext": full_text,
            "summary": summary,
        })
        # Checkpoint every 50 items so a crash loses at most 50 results
        if len(result) % 50 == 0:
            with open(OUTPUT_FILE, "w", encoding="utf-8") as f:
                json.dump(result, f, indent=4, ensure_ascii=False)

    # Final save
    with open(OUTPUT_FILE, "w", encoding="utf-8") as f:
        json.dump(result, f, indent=4, ensure_ascii=False)
    print(f"Classification completed. {len(result)} items processed.")