# GRPO fine-tuning script: readability-controlled medical summarization.
import os

# Pin the GPU BEFORE torch/unsloth are imported — CUDA_* env vars are only
# honoured if set prior to the first CUDA context creation.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
from unsloth import FastLanguageModel
import torch
from health_classifier import classifier  # project-local health-literacy classifier

# Context window (tokens) passed to the model loader below.
max_seq_length = 8192
|
|
# Load the SFT LoRA checkpoint as the GRPO policy. Full precision
# (load_in_4bit=False) and no vLLM fast-inference backend.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "/home/mshahidul/readctrl_model/RL_model/readability_sft_lora_model",
    max_seq_length = max_seq_length,
    load_in_4bit = False,
    fast_inference = False,
)


# Switch unsloth's internal state from inference to training mode (enables
# gradients on the adapter weights).
model = FastLanguageModel.for_training(model)
|
|
| |
import json

from datasets import Dataset

# Load the pre-extracted sub-claims for the MultiClinSum EN test split.
# (Originally `import json` sat inside the `with` block, holding the file
# handle open across the import; imports are hoisted above the I/O.)
with open("/home/mshahidul/readctrl/data/extracting_subclaim/extracted_subclaims_multiclinsum_test_en_full.json", "r") as f:
    data = json.load(f)
dataset = Dataset.from_list(data)

# System-prompt template shared by every example.
with open('/home/mshahidul/readctrl/code/RL_model/prompt', 'r') as f:
    prompt_template = f.read()


def _to_grpo_example(x):
    """Map one raw record to a GRPO row.

    Produces a chat-format "prompt" column plus an "answer" column carrying
    the gold sub-claims consumed by claim_reward_func.
    """
    return {
        "prompt": [
            {"role": "system", "content": prompt_template},
            {"role": "user", "content": f'''
- Input Language: English
- Gold Summary (the anchor reference summary): {x['summary']}
- Source Text (detailed content): {x['fulltext']}
'''},
        ],
        "answer": {
            "fulltext_subclaims": x['fulltext_subclaims'],
            "summary_subclaims": x['summary_subclaims'],
        },
    }


dataset = dataset.map(_to_grpo_example)
import requests
# NOTE(review): `json` is re-imported here (already imported above) — harmless
# but redundant.
import json
import re


from claim_verifier import MedicalClaimVerifier


# Project-local claim verifier; scores completions against gold sub-claims.
verifier = MedicalClaimVerifier()
|
|
def claim_reward_func(prompts, completions, answer, **kwargs):
    """GRPO reward: factual sub-claim verification score per completion.

    Args:
        prompts: Batch prompts (unused; required by the GRPO reward API).
        completions: One model completion per prompt.
        answer: Per-example dicts with 'summary_subclaims' and
            'fulltext_subclaims' (built in the dataset map above).
        **kwargs: Extra dataset columns forwarded by the trainer; ignored.

    Returns:
        list[float]: verifier reward score for each completion.
    """
    # Walk completions and their gold sub-claims in lockstep rather than
    # indexing with range(len(...)).
    return [
        verifier.get_reward_score(
            completion,
            gold["summary_subclaims"],
            gold["fulltext_subclaims"],
        )
        for completion, gold in zip(completions, answer)
    ]
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| import json |
|
|
| def literacy_classifier_reward_func(completions, **kwargs): |
| scores = [] |
| for completion in completions: |
| try: |
| |
| cleaned_content = completion[0]['content'].strip() |
| if cleaned_content.startswith("```"): |
| |
| cleaned_content = cleaned_content.split("```")[1] |
| if cleaned_content.startswith("json"): |
| cleaned_content = cleaned_content[4:] |
| |
| |
| data = json.loads(cleaned_content.strip()) |
| |
| alignment_score = 0.0 |
| target_labels = ["low", "intermediate", "proficient"] |
| |
| for label in target_labels: |
| key = f"{label}_health_literacy" |
| text_to_test = data.get(key, "") |
| |
| |
| if text_to_test: |
| |
| result = classifier(summary_text=text_to_test) |
| predicted = result.label |
| |
| |
| if predicted == key: |
| alignment_score += 1.0 |
| else: |
| |
| alignment_score -= 0.5 |
| else: |
| |
| alignment_score -= 0.3 |
| |
| scores.append(alignment_score) |
| |
| except (json.JSONDecodeError, Exception): |
| |
| scores.append(-1.0) |
| |
| return scores |
|
|
|
|
from trl import GRPOConfig, GRPOTrainer


# GRPO training hyper-parameters.
training_args = GRPOConfig(
    learning_rate = 5e-6,
    lr_scheduler_type = "cosine",
    weight_decay = 0.1,
    # NOTE(review): max_prompt_length + max_completion_length (8192 + 4096)
    # exceeds the model's max_seq_length of 8192 set at load time — confirm
    # long prompts cannot overflow the context window.
    max_prompt_length = 8192,
    max_completion_length = 4096,

    num_generations = 4,              # samples per prompt for the group baseline
    per_device_train_batch_size = 4,
    gradient_accumulation_steps = 4,  # effective batch = 4 * 4 = 16
    max_steps = 500,
    bf16 = True,
    output_dir = "medical_grpo_outputs",
)
|
|
# Combine both rewards: factual sub-claim verification plus health-literacy
# level alignment. TRL sums the per-function rewards for each completion.
trainer = GRPOTrainer(
    model = model,
    reward_funcs = [
        claim_reward_func,

        literacy_classifier_reward_func
    ],
    args = training_args,
    train_dataset = dataset,
    # NOTE(review): recent TRL versions renamed this kwarg to
    # `processing_class` — confirm against the pinned trl version.
    tokenizer = tokenizer,
)


trainer.train()


# Persist the trained adapter and tokenizer together.
model.save_pretrained("/home/mshahidul/readctrl_model/readability_GRPO_model_v1")
tokenizer.save_pretrained("/home/mshahidul/readctrl_model/readability_GRPO_model_v1")