import os

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

import json

import torch
import tqdm
from unsloth import FastLanguageModel

# Optional: cache model/tokenizer in a module-level singleton for repeated use.
_model_cache = {"model": None, "tokenizer": None}


def load_finetuned_model(model_path: str):
    """Load and cache the fine-tuned model + tokenizer."""
    if _model_cache["model"] is not None:
        return _model_cache["model"], _model_cache["tokenizer"]
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_path,
        max_seq_length=4092,
        load_in_4bit=False,
        load_in_8bit=False,
        full_finetuning=False,
    )
    _model_cache["model"], _model_cache["tokenizer"] = model, tokenizer
    return model, tokenizer


def infer_subclaim(
    text: str,
    subclaim: str,
    model_path: str = "/home/mshahidul/readctrl_model/qwen3-32B_subclaims-verifier_lora_nonreasoning",
    cuda_device: str = "0",  # unused; device selection is handled via CUDA_VISIBLE_DEVICES above
) -> str:
    """
    Given a medical text and a subclaim, return '1' if the text supports the
    subclaim, otherwise '0'.
    """
    model, tokenizer = load_finetuned_model(model_path)

    # Build the prompt with the same structure used during fine-tuning.
    prompt = f"""
Given the following medical text and subclaim, decide if the text supports the subclaim.

Text: {text}

Subclaim: {subclaim}

Respond only with 1 if the text supports the subclaim, otherwise 0.
""".strip()

    messages = [{"role": "user", "content": prompt + "\n"}]
    chat_text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,
    )
    inputs = tokenizer(chat_text, return_tensors="pt").to("cuda")

    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=10,
            temperature=0.1,
            top_p=0.8,
            top_k=5,
        )

    # Decode only the newly generated tokens so the returned string contains
    # just the model's verdict, not the echoed prompt.
    new_tokens = output_ids[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()


# Example inputs (kept commented out):
# example_text = (
#     "Una niña nacida a las 34 semanas de gestación precisó intubación y ventilación al nacer..."
# )
# example_subclaim = "La paciente es una recién nacida prematura."
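# Optional hardening (a minimal sketch, not wired into the pipeline below): the
# verifier is prompted to answer with a bare 0/1, but the decoded output may carry
# extra whitespace or stray tokens. The helper below illustrates one way to
# normalise the raw output to a strict "0"/"1" label; the name `parse_verdict`
# is hypothetical and not part of the original code.
import re


def parse_verdict(output_text: str) -> str:
    """Return the first '0' or '1' found in the verifier output, else ''."""
    match = re.search(r"[01]", output_text)
    return match.group(0) if match else ""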
def process_completeness(example, version):
    """Completeness: are the reference summary's subclaims supported by this readability version?"""
    example_text = example["readability_versions"][version]["text"]
    example_subclaims = example["ref_summary"]["subclaims"]
    # print("Input text:", example_text)
    res = []
    total = 0
    correct = 0
    for example_subclaim in example_subclaims:
        result = infer_subclaim(example_text, example_subclaim)
        if "1" in result:
            correct += 1
            total += 1
        elif "0" in result:
            total += 1
        res.append({"subclaim": example_subclaim, "result": result})
    return {
        "metric": "completeness",
        "version": version,
        "input_text": example_text,
        "results": res,
        "total": total,
        "correct": correct,
        "accuracy": (correct / total) * 100 if total > 0 else 0,
    }


def process_conciseness(example, version):
    """Conciseness: are this readability version's subclaims supported by the reference summary?"""
    example_text = example["ref_summary"]["text"]
    example_subclaims = example["readability_versions"][version]["subclaims"]
    # print("Input text:", example_text)
    res = []
    total = 0
    correct = 0
    for example_subclaim in example_subclaims:
        result = infer_subclaim(example_text, example_subclaim)
        if "1" in result:
            correct += 1
            total += 1
        elif "0" in result:
            total += 1
        res.append({"subclaim": example_subclaim, "result": result})
    return {
        "metric": "conciseness",
        "version": version,
        "input_text": example_text,
        "results": res,
        "total": total,
        "correct": correct,
        "accuracy": (correct / total) * 100 if total > 0 else 0,
    }


def process_attribution(example, version):
    """Attribution: are this readability version's subclaims supported by the full source text?"""
    example_text = example["full_text"]
    example_subclaims = example["readability_versions"][version]["subclaims"]
    # print("Input text:", example_text)
    res = []
    total = 0
    correct = 0
    for example_subclaim in example_subclaims:
        result = infer_subclaim(example_text, example_subclaim)
        if "1" in result:
            correct += 1
            total += 1
        elif "0" in result:
            total += 1
        res.append({"subclaim": example_subclaim, "result": result})
    return {
        "metric": "attribution",
        "version": version,
        "input_text": example_text,
        "results": res,
        "total": total,
        "correct": correct,
        "accuracy": (correct / total) * 100 if total > 0 else 0,
    }


if __name__ == "__main__":
    with open(
        "/home/mshahidul/readctrl/data/training_data_subclaim_verifier/synthetic_data_es_subclaims_100.json",
        "r",
        encoding="utf-8",
    ) as f:
        data = json.load(f)

    full_data_results = []
    save_path = "/home/mshahidul/readctrl/results/dataset_quality_check/subclaim_verifier_results_100_qwen3-32B.json"

    for item in tqdm.tqdm(data):
        print(f"Processing item ID: {item['id']}")
        for version in ["easy", "intermediate", "hard"]:
            completeness = process_completeness(item, version)
            conciseness = process_conciseness(item, version)
            attribution = process_attribution(item, version)
            full_data_results.append({
                "id": item["id"],
                "version": version,
                "completeness": completeness,
                "conciseness": conciseness,
                "attribution": attribution,
            })
            # Checkpoint intermediate results every 5 entries.
            if len(full_data_results) % 5 == 0:
                with open(save_path, "w", encoding="utf-8") as f:
                    json.dump(full_data_results, f, indent=4, ensure_ascii=False)

    # Final save after the full run.
    with open(save_path, "w", encoding="utf-8") as f:
        json.dump(full_data_results, f, indent=4, ensure_ascii=False)
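# A small follow-up sketch (assumption: the results file written above exists and
# follows the structure produced by the process_* functions). It reports the mean
# accuracy per readability version and metric; `summarize_results` is a
# hypothetical helper, not part of the original pipeline.
def summarize_results(results_path: str) -> dict:
    """Average the per-example accuracies by (version, metric)."""
    from collections import defaultdict

    with open(results_path, "r", encoding="utf-8") as fh:
        results = json.load(fh)

    grouped = defaultdict(list)
    for entry in results:
        for metric in ("completeness", "conciseness", "attribution"):
            grouped[(entry["version"], metric)].append(entry[metric]["accuracy"])

    return {
        f"{version}/{metric}": sum(vals) / len(vals)
        for (version, metric), vals in grouped.items()
    }


# Example usage (commented out so the main run above is unchanged):
# print(summarize_results("/home/mshahidul/readctrl/results/dataset_quality_check/subclaim_verifier_results_100_qwen3-32B.json"))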