"""GRPO fine-tuning of a readability-controlled summarization LoRA model.

Loads an SFT'd Unsloth model, builds a prompt dataset from extracted
sub-claims, and trains with two reward functions:
  1. claim_reward_func            — factual-claim verification reward.
  2. literacy_classifier_reward_func — rewards outputs whose per-literacy-level
     summaries are classified back to the intended literacy level.
"""

import os

# Must be set BEFORE importing torch/unsloth so the CUDA device selection
# actually takes effect in the process.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

import json
import re

import requests
import torch
from unsloth import FastLanguageModel
from datasets import Dataset
from trl import GRPOConfig, GRPOTrainer

from health_classifier import classifier
from claim_verifier import MedicalClaimVerifier

max_seq_length = 8192

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "/home/mshahidul/readctrl_model/RL_model/readability_sft_lora_model",
    max_seq_length = max_seq_length,
    load_in_4bit = False,  # Set to False if you have enough VRAM
    fast_inference = False,
)

# Enable gradient checkpointing and prepare the model for training.
model = FastLanguageModel.for_training(model)

# Each record carries the gold summary, the source full text, and the
# sub-claims extracted from both (used as the reward "answer").
with open("/home/mshahidul/readctrl/data/extracting_subclaim/extracted_subclaims_multiclinsum_test_en_full.json", "r") as f:
    data = json.load(f)

dataset = Dataset.from_list(data)

with open('/home/mshahidul/readctrl/code/RL_model/prompt', 'r') as f:
    prompt_template = f.read()

dataset = dataset.map(lambda x: {
    "prompt": [
        {"role": "system", "content": prompt_template},
        {"role": "user", "content": f'''
- Input Language: English
- Gold Summary (the anchor reference summary): {x['summary']}
- Source Text (detailed content): {x['fulltext']}
'''},
    ],
    "answer": {
        "fulltext_subclaims": x['fulltext_subclaims'],
        "summary_subclaims": x['summary_subclaims'],
    },
})

verifier = MedicalClaimVerifier()


def claim_reward_func(prompts, completions, answer, **kwargs):
    """GRPO reward: score each completion against the reference sub-claims.

    Expects each entry of ``answer`` to contain 'summary_subclaims' and
    'fulltext_subclaims' (propagated from the dataset's "answer" column).
    Returns one float reward per completion in the GRPO group.
    """
    rewards = []
    for completion, ref in zip(completions, answer):
        reward = verifier.get_reward_score(
            completion,
            ref["summary_subclaims"],
            ref["fulltext_subclaims"],
        )
        rewards.append(reward)
    return rewards


def literacy_classifier_reward_func(completions, **kwargs):
    """GRPO reward: check that each literacy-level summary classifies correctly.

    The model is expected to emit a JSON object with the keys
    low/intermediate/proficient ``*_health_literacy``; each text is run
    through the DSPy classifier and scored on whether the predicted label
    matches the key it was written under.
    """
    scores = []
    for completion in completions:
        try:
            # 1. Strip potential Markdown code fencing (```json ... ```).
            cleaned_content = completion[0]['content'].strip()
            if cleaned_content.startswith("```"):
                # Take the fenced payload; drop a leading "json" language tag.
                cleaned_content = cleaned_content.split("```")[1]
                if cleaned_content.startswith("json"):
                    cleaned_content = cleaned_content[4:]

            # 2. Parse the JSON payload.
            data = json.loads(cleaned_content.strip())

            alignment_score = 0.0
            target_labels = ["low", "intermediate", "proficient"]
            for label in target_labels:
                key = f"{label}_health_literacy"
                text_to_test = data.get(key, "")
                if text_to_test:
                    # Run the DSPy classifier; expected label format
                    # mirrors the JSON key, e.g. "low_health_literacy".
                    result = classifier(summary_text=text_to_test)
                    predicted = result.label
                    if predicted == key:
                        alignment_score += 1.0
                    else:
                        # Soft penalty for misclassification.
                        alignment_score -= 0.5
                else:
                    # Penalty if a literacy level is missing from the JSON.
                    alignment_score -= 0.3
            scores.append(alignment_score)
        except Exception:
            # Deliberately broad: a reward function must never crash
            # training. Malformed JSON or a classifier failure gets a
            # significant fixed penalty instead.
            scores.append(-1.0)
    return scores


training_args = GRPOConfig(
    learning_rate = 5e-6,
    lr_scheduler_type = "cosine",
    weight_decay = 0.1,
    max_prompt_length = 8192,
    max_completion_length = 4096,
    num_generations = 4,  # GRPO group size
    per_device_train_batch_size = 4,
    gradient_accumulation_steps = 4,
    max_steps = 500,
    bf16 = True,
    output_dir = "medical_grpo_outputs",
)

trainer = GRPOTrainer(
    model = model,
    reward_funcs = [
        claim_reward_func,
        literacy_classifier_reward_func,
    ],
    args = training_args,
    train_dataset = dataset,  # Same dataset shape as the SFT prep
    tokenizer = tokenizer,
)

trainer.train()

model.save_pretrained("/home/mshahidul/readctrl_model/readability_GRPO_model_v1")
tokenizer.save_pretrained("/home/mshahidul/readctrl_model/readability_GRPO_model_v1")