# Pin GPU selection BEFORE importing any CUDA-aware library (torch/unsloth
# read CUDA_VISIBLE_DEVICES at initialization time).
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # enumerate GPUs in PCI bus order (matches nvidia-smi)
os.environ["CUDA_VISIBLE_DEVICES"] = "3"  # expose only physical GPU 3 to this process
from unsloth import FastLanguageModel
import torch
from health_classifier import classifier  # project-local health-literacy classifier used by the reward below
max_seq_length = 8192  # token budget per sequence; also used as max_prompt_length in the GRPO config
| |
|
# Load the SFT'd LoRA checkpoint as the policy model for GRPO training.
# 4-bit loading and fast (vLLM) inference are both disabled — full-precision
# weights with standard generation; presumably intentional for training — TODO confirm.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "/home/mshahidul/readctrl_model/RL_model/readability_sft_lora_model",
    max_seq_length = max_seq_length,
    load_in_4bit = False,
    fast_inference = False,
)

# Switch unsloth's patched model into training mode (re-enables gradients
# and disables inference-only optimizations).
model = FastLanguageModel.for_training(model)
| |
|
| | |
# Build the GRPO dataset: each record gets a chat-style "prompt" (system
# template + user message embedding the gold summary and source text) and an
# "answer" payload of pre-extracted subclaims consumed by claim_reward_func.
with open("/home/mshahidul/readctrl/data/extracting_subclaim/extracted_subclaims_multiclinsum_test_en_full.json", "r") as f:
    import json
    data = json.load(f)
from datasets import Dataset
dataset = Dataset.from_list(data)
# The system prompt lives in a separate plain-text file so it can be tweaked
# without touching code.
with open('/home/mshahidul/readctrl/code/RL_model/prompt', 'r') as f:
    prompt_template = f.read()
dataset = dataset.map(lambda x: {
    "prompt" : [
        {"role": "system", "content": prompt_template},
        {"role": "user", "content": f'''
- Input Language: English
- Gold Summary (the anchor reference summary): {x['summary']}
- Source Text (detailed content): {x['fulltext']}
'''},
    ],
    # Ground-truth subclaims (never shown to the model) used for reward scoring.
    "answer": {
        "fulltext_subclaims": x['fulltext_subclaims'],
        "summary_subclaims": x['summary_subclaims'],
    },
})
import requests
import json
import re

# Claim-level factuality verifier; its get_reward_score() drives
# claim_reward_func below.
from claim_verifier import MedicalClaimVerifier

verifier = MedicalClaimVerifier()
| |
|
def claim_reward_func(prompts, completions, answer, **kwargs):
    """GRPO reward: claim-level factuality score for each completion.

    Args:
        prompts: batch of chat prompts (unused; required by the GRPO reward API).
        completions: batch of model completions to score.
        answer: per-sample dicts carrying the ground-truth
            'summary_subclaims' and 'fulltext_subclaims' extracted offline.

    Returns:
        list: one reward per completion, as produced by
        MedicalClaimVerifier.get_reward_score().
    """
    # Pair each completion with its ground-truth subclaims via zip instead of
    # indexing two parallel lists by hand.
    return [
        verifier.get_reward_score(
            completion,
            gold["summary_subclaims"],
            gold["fulltext_subclaims"],
        )
        for completion, gold in zip(completions, answer)
    ]
| |
|
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|
| | import json |
| |
|
def literacy_classifier_reward_func(completions, **kwargs):
    """GRPO reward: does each literacy-level summary match its target label?

    Each completion's content is expected to be a JSON object with keys
    "low_health_literacy", "intermediate_health_literacy" and
    "proficient_health_literacy". Every present, non-empty summary is run
    through the external `classifier`: +1.0 when the predicted label equals
    the key, -0.5 on a mismatch, and -0.3 when the summary is missing/empty.
    Any parsing or classifier failure yields a flat -1.0 for that completion.

    Returns:
        list[float]: one alignment score per completion.
    """
    scores = []
    for completion in completions:
        try:
            cleaned_content = completion[0]['content'].strip()
            # Strip an optional markdown code fence (``` or ```json ... ```).
            if cleaned_content.startswith("```"):
                cleaned_content = cleaned_content.split("```")[1]
                if cleaned_content.startswith("json"):
                    cleaned_content = cleaned_content[4:]

            data = json.loads(cleaned_content.strip())

            alignment_score = 0.0
            for label in ("low", "intermediate", "proficient"):
                key = f"{label}_health_literacy"
                text_to_test = data.get(key, "")

                if not text_to_test:
                    # Missing or empty summary for this literacy level.
                    alignment_score -= 0.3
                    continue

                result = classifier(summary_text=text_to_test)
                if result.label == key:
                    alignment_score += 1.0
                else:
                    alignment_score -= 0.5

            scores.append(alignment_score)
        except Exception:
            # `except Exception` alone replaces the original redundant
            # (json.JSONDecodeError, Exception) tuple — Exception already
            # subsumes JSONDecodeError. Malformed output is expected during
            # RL exploration; penalize it rather than crash the trainer.
            scores.append(-1.0)

    return scores
| |
|
| |
|
from trl import GRPOConfig, GRPOTrainer

# GRPO hyperparameters. Prompt budget (8192) matches max_seq_length above;
# effective batch = 4 (per-device) * 4 (grad accum) prompts per optimizer
# step, with 4 sampled generations per prompt.
training_args = GRPOConfig(
    learning_rate = 5e-6,
    lr_scheduler_type = "cosine",
    weight_decay = 0.1,
    max_prompt_length = 8192,
    max_completion_length = 4096,

    num_generations = 4,
    per_device_train_batch_size = 4,
    gradient_accumulation_steps = 4,
    max_steps = 500,
    bf16 = True,
    output_dir = "medical_grpo_outputs",
)
| |
|
# The two rewards are combined by the trainer: claim factuality + literacy
# alignment, one score each per completion.
trainer = GRPOTrainer(
    model = model,
    reward_funcs = [
        claim_reward_func,

        literacy_classifier_reward_func
    ],
    args = training_args,
    train_dataset = dataset,
    tokenizer = tokenizer,  # NOTE(review): newer trl renamed this kwarg to processing_class — confirm the installed version still accepts `tokenizer`
)

trainer.train()

# Persist the trained adapter + tokenizer for downstream inference.
model.save_pretrained("/home/mshahidul/readctrl_model/readability_GRPO_model_v1")
tokenizer.save_pretrained("/home/mshahidul/readctrl_model/readability_GRPO_model_v1")