"""GRPO fine-tuning on GSM8K — flattened export of three notebook variants.

Section 1: Unsloth-style run using <start_working_out>/<SOLUTION> tags with
           format + answer + numeric reward functions (trains for 500 steps).
Section 2: near-duplicate variant with softer reward weights and different
           optimizer settings (defines a config but never builds a trainer).
Section 3: <reasoning>/<answer> XML-tag variant with its own reward suite
           (trains for 250 steps).

NOTE(review): in the exported file every tag-like token (e.g. "<SOLUTION>")
had been stripped to "" and the model/tokenizer were loaded only mid-file,
after their first use.  The tags are restored below and model setup is
hoisted to the top so each section is actually runnable.  Bare notebook
display expressions (`dataset`, `system_prompt`, ...) were no-ops and were
dropped.
"""

import re

import torch
from datasets import Dataset, load_dataset
from trl import GRPOConfig, GRPOTrainer
from unsloth import FastLanguageModel

# ---------------------------------------------------------------------------
# Model / tokenizer setup — must precede every section below, since they all
# reference `model`, `tokenizer`, and `max_seq_length`.
# ---------------------------------------------------------------------------
max_seq_length = 1024  # Can increase for longer reasoning traces
lora_rank = 32         # Larger rank = smarter, but slower

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/meta-Llama-3.1-8B-Instruct",
    max_seq_length = max_seq_length,
    load_in_4bit = True,        # False for LoRA 16bit
    fast_inference = True,      # Enable vLLM fast inference
    max_lora_rank = lora_rank,
    gpu_memory_utilization = 0.9,  # Reduce if out of memory
)

model = FastLanguageModel.get_peft_model(
    model,
    r = lora_rank,  # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],  # Remove QKVO if out of memory
    lora_alpha = lora_rank,
    use_gradient_checkpointing = "unsloth",  # Enable long context finetuning
    random_state = 3407,
)

# ===========================================================================
# Section 1: <start_working_out>/<SOLUTION> tag format, full reward suite.
# ===========================================================================
dataset = load_dataset("openai/gsm8k", "main", split = "train")


def extract_hash_answer(text):
    """Return the final GSM8K answer after the '####' marker, or None."""
    if "####" not in text:
        return None
    return text.split("####")[1].strip()


# Tag tokens the model is trained to emit around its reasoning and answer.
# NOTE(review): these were stripped to "" in the export; restored here.
reasoning_start = "<start_working_out>"
reasoning_end   = "<end_working_out>"
solution_start  = "<SOLUTION>"
solution_end    = "</SOLUTION>"

system_prompt = \
f"""You are given a problem.
Think about the problem and provide your working out.
Place it between {reasoning_start} and {reasoning_end}.
Then, provide your solution between {solution_start}{solution_end}"""

dataset = dataset.map(lambda x: {
    "prompt": [
        {"role": "system", "content": system_prompt},
        {"role": "user",   "content": x["question"]},
    ],
    "answer": extract_hash_answer(x["answer"]),
})

# Full-response format matcher: optional leading space, the reasoning block,
# then the solution block (group 1 captures the solution text).
match_format = re.compile(
    rf"^[\s]{{0,}}"
    rf"{reasoning_start}.+?{reasoning_end}.*?"
    rf"{solution_start}(.+?){solution_end}"
    rf"[\s]{{0,}}$",
    flags = re.MULTILINE | re.DOTALL,
)

# Sanity check that a well-formed response matches.
match_format.search(
    f"{reasoning_start}Let me think!{reasoning_end}"
    f"{solution_start}2{solution_end}",
)


def match_format_exactly(completions, **kwargs):
    """Reward 3.0 when the completion matches the full tag format exactly."""
    scores = []
    for completion in completions:
        score = 0
        response = completion[0]["content"]
        # Match if format is seen exactly!
        if match_format.search(response) is not None:
            score += 3.0
        scores.append(score)
    return scores


def match_format_approximately(completions, **kwargs):
    """Partial credit per tag seen exactly once; penalize duplicates/misses."""
    scores = []
    for completion in completions:
        score = 0
        response = completion[0]["content"]
        # Count how many keywords are seen - we penalize if too many!
        # If we see 1, then plus some points!
        score += 0.5 if response.count(reasoning_start) == 1 else -1.0
        score += 0.5 if response.count(reasoning_end)   == 1 else -1.0
        score += 0.5 if response.count(solution_start)  == 1 else -1.0
        score += 0.5 if response.count(solution_end)    == 1 else -1.0
        scores.append(score)
    return scores


def check_answer(prompts, completions, answer, **kwargs):
    """Reward extracted solutions: exact > stripped match > close numeric ratio."""
    responses = [completion[0]["content"] for completion in completions]
    extracted_responses = [
        guess.group(1) if (guess := match_format.search(r)) is not None else None
        for r in responses
    ]
    scores = []
    for guess, true_answer in zip(extracted_responses, answer):
        score = 0
        if guess is None:
            scores.append(0)
            continue
        # Correct answer gets 3 points!
        if guess == true_answer:
            score += 3.0
        # Match if spaces are seen, but less reward
        elif guess.strip() == true_answer.strip():
            score += 1.5
        else:
            # We also reward it if the answer is close via ratios!
            # Ie if the answer is within some range, reward it!
            try:
                ratio = float(guess) / float(true_answer)
                if ratio >= 0.9 and ratio <= 1.1:
                    score += 1.0
                elif ratio >= 0.8 and ratio <= 1.2:
                    score += 0.5
                else:
                    score -= 1.5  # Penalize wrong answers
            except (ValueError, TypeError, ZeroDivisionError):
                score -= 1.5  # Penalize non-numeric / zero denominators
        scores.append(score)
    return scores


# Extracts the first number (digits, dots, commas) after the solution tag.
match_numbers = re.compile(
    solution_start + r".*?([\d\.\,]{1,})",
    flags = re.MULTILINE | re.DOTALL,
)
print(match_numbers.findall(f"{solution_start}  0.34  {solution_end}"))
print(match_numbers.findall(f"{solution_start}  123,456  {solution_end}"))

# Counters for throttled sample printing inside check_numbers.
# NOTE(review): the export had no-op module-level `global` statements here;
# plain assignments are all that is needed at module scope.
PRINTED_TIMES = 0
PRINT_EVERY_STEPS = 5


def check_numbers(prompts, completions, answer, **kwargs):
    """Reward 1.5 when the extracted number equals the true answer, else -0.5.

    Prints one sampled (question, answer, response) every PRINT_EVERY_STEPS
    calls for monitoring.
    """
    question = prompts[0][-1]["content"]
    responses = [completion[0]["content"] for completion in completions]
    extracted_responses = [
        guess.group(1) if (guess := match_numbers.search(r)) is not None else None
        for r in responses
    ]
    scores = []
    # Print only every few steps
    global PRINTED_TIMES
    global PRINT_EVERY_STEPS
    if PRINTED_TIMES % PRINT_EVERY_STEPS == 0:
        print('*'*20, f"Question:\n{question}", f"\nAnswer:\n{answer[0]}",
              f"\nResponse:\n{responses[0]}",
              f"\nExtracted:\n{extracted_responses[0]}")
    PRINTED_TIMES += 1
    for guess, true_answer in zip(extracted_responses, answer):
        if guess is None:
            scores.append(0)
            continue
        # Convert to numbers
        try:
            true_answer = float(true_answer.strip())
            # Remove commas like in 123,456
            guess = float(guess.strip().replace(",", ""))
            scores.append(1.5 if guess == true_answer else -0.5)
        except (ValueError, TypeError):
            scores.append(0)
            continue
    return scores


# Longest tokenized prompt in the dataset (informs max_prompt_length below).
max(dataset.map(
    lambda x: {"tokens": tokenizer.apply_chat_template(
        x["prompt"], add_generation_prompt = True, tokenize = True)},
    batched = True,
).map(lambda x: {"length": len(x["tokens"])})["length"])

max_prompt_length = 287 + 1  # + 1 just in case!

training_args = GRPOConfig(
    learning_rate = 5e-6,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type = "cosine",
    optim = "adamw_8bit",
    logging_steps = 1,
    per_device_train_batch_size = 1,
    gradient_accumulation_steps = 4,  # Increase to 4 for smoother training
    num_generations = 4,              # Decrease if out of memory
    max_prompt_length = max_prompt_length,
    max_completion_length = max_seq_length - max_prompt_length,
    # num_train_epochs = 1,  # Set to 1 for a full training run
    max_steps = 500,
    save_steps = 250,
    max_grad_norm = 1.0,
    report_to = "none",  # Can use Weights & Biases
    output_dir = "outputs",
)

trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = [
        match_format_exactly,
        match_format_approximately,
        check_answer,
        check_numbers,
    ],
    args = training_args,
    train_dataset = dataset,
)
trainer.train()

# ===========================================================================
# Section 2: duplicated variant — same tag format, softer reward weights,
# different optimizer settings.  NOTE(review): this section defined a config
# in the original export but never constructed a trainer; preserved as-is.
# ===========================================================================
dataset = load_dataset("openai/gsm8k", "main", split = "train")

# extract_hash_answer, the tag tokens, and system_prompt are redefined
# identically to Section 1 in the export; the Section 1 definitions remain
# in effect here.

dataset = dataset.map(lambda x: {
    "prompt": [
        {"role": "system", "content": system_prompt},
        {"role": "user",   "content": x["question"]},
    ],
    "answer": extract_hash_answer(x["answer"]),
})


def match_format_approximately(completions, **kwargs):
    """Variant of Section 1's approximate matcher with a -0.5 penalty."""
    scores = []
    for completion in completions:
        score = 0
        response = completion[0]["content"]
        # Count how many keywords are seen - we penalize if too many!
        # If we see 1, then plus some points!
        score += 0.5 if response.count(reasoning_start) == 1 else -0.5
        score += 0.5 if response.count(reasoning_end)   == 1 else -0.5
        score += 0.5 if response.count(solution_start)  == 1 else -0.5
        score += 0.5 if response.count(solution_end)    == 1 else -0.5
        scores.append(score)
    return scores


def check_answer(prompts, completions, answer, **kwargs):
    """Variant of Section 1's check_answer with smaller ratio rewards."""
    responses = [completion[0]["content"] for completion in completions]
    extracted_responses = [
        guess.group(1) if (guess := match_format.search(r)) is not None else None
        for r in responses
    ]
    scores = []
    for guess, true_answer in zip(extracted_responses, answer):
        score = 0
        if guess is None:
            scores.append(0)
            continue
        # Correct answer gets 3 points!
        if guess == true_answer:
            score += 3.0
        # Match if spaces are seen
        elif guess.strip() == true_answer.strip():
            score += 1.5
        else:
            # We also reward it if the answer is close via ratios!
            # Ie if the answer is within some range, reward it!
            try:
                ratio = float(guess) / float(true_answer)
                if ratio >= 0.9 and ratio <= 1.1:
                    score += 0.5
                elif ratio >= 0.8 and ratio <= 1.2:
                    score += 0.25
                else:
                    score -= 1.0  # Penalize wrong answers
            except (ValueError, TypeError, ZeroDivisionError):
                score -= 0.5  # Penalize non-numeric / zero denominators
        scores.append(score)
    return scores


# This variant does not accept commas in numbers.
match_numbers = re.compile(
    rf"{solution_start}.*?([\d\.]{{1,}})",
    flags = re.MULTILINE | re.DOTALL,
)
match_numbers.findall(f"{solution_start}  0.34  {solution_end}")


def check_numbers(prompts, completions, answer, **kwargs):
    """Variant of Section 1's check_numbers: prints every call, no comma
    handling, and scores wrong answers 0.0 instead of -0.5."""
    question = prompts[0][-1]["content"]
    responses = [completion[0]["content"] for completion in completions]
    extracted_responses = [
        guess.group(1) if (guess := match_numbers.search(r)) is not None else None
        for r in responses
    ]
    scores = []
    print('*'*20, f"Question:\n{question}", f"\nAnswer:\n{answer[0]}",
          f"\nResponse:\n{responses[0]}",
          f"\nExtracted:\n{extracted_responses[0]}")
    for guess, true_answer in zip(extracted_responses, answer):
        if guess is None:
            scores.append(0)
            continue
        # Convert to numbers
        try:
            true_answer = float(true_answer.strip())
            guess = float(guess.strip())
            scores.append(1.5 if guess == true_answer else 0.0)
        except (ValueError, TypeError):
            scores.append(0)
            continue
    return scores


max_prompt_length = 256

training_args = GRPOConfig(
    learning_rate = 5e-6,
    adam_beta1 = 0.9,
    adam_beta2 = 0.99,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type = "cosine",
    optim = "adamw_torch_fused",
    logging_steps = 1,
    per_device_train_batch_size = 1,
    gradient_accumulation_steps = 1,  # Increase to 4 for smoother training
    num_generations = 4,              # Decrease if out of memory
    max_prompt_length = max_prompt_length,
    max_completion_length = max_seq_length - max_prompt_length,
    # num_train_epochs = 1,  # Set to 1 for a full training run
    max_steps = 50,
    save_steps = 50,
    max_grad_norm = 0.1,
    report_to = "none",  # Can use Weights & Biases
    output_dir = "outputs",
)

# ===========================================================================
# Section 3: <reasoning>/<answer> XML-tag variant with its own reward suite.
# ===========================================================================
# Load and prep dataset
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""


def extract_xml_answer(text: str) -> str:
    """Return the text inside the <answer>...</answer> tags, stripped."""
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()


def extract_hash_answer(text: str) -> str | None:
    """Return the final GSM8K answer after the '####' marker, or None."""
    if "####" not in text:
        return None
    return text.split("####")[1].strip()


# uncomment middle messages for 1-shot prompting
def get_gsm8k_questions(split = "train") -> Dataset:
    """Load GSM8K and map each row to a chat prompt + extracted answer."""
    data = load_dataset('openai/gsm8k', 'main')[split]  # type: ignore
    data = data.map(lambda x: {  # type: ignore
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': x['question']}
        ],
        'answer': extract_hash_answer(x['answer'])
    })  # type: ignore
    return data  # type: ignore


dataset = get_gsm8k_questions()


# Reward functions
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """Reward 2.0 when the extracted <answer> equals the reference answer."""
    responses = [completion[0]['content'] for completion in completions]
    q = prompts[0][-1]['content']
    extracted_responses = [extract_xml_answer(r) for r in responses]
    print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}",
          f"\nResponse:\n{responses[0]}",
          f"\nExtracted:\n{extracted_responses[0]}")
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]


def int_reward_func(completions, **kwargs) -> list[float]:
    """Reward 0.5 when the extracted answer is a plain integer string."""
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]


def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]


def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]


def count_xml(text) -> float:
    """Partial credit per well-placed tag; small penalty for trailing text."""
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        count -= len(text.split("\n</answer>\n")[-1]) * 0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        count -= (len(text.split("\n</answer>")[-1]) - 1) * 0.001
    return count


def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    """Apply count_xml to each completion's content."""
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]


max_prompt_length = 256

training_args = GRPOConfig(
    learning_rate = 5e-6,
    adam_beta1 = 0.9,
    adam_beta2 = 0.99,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type = "cosine",
    optim = "paged_adamw_8bit",
    logging_steps = 1,
    per_device_train_batch_size = 1,
    gradient_accumulation_steps = 1,  # Increase to 4 for smoother training
    num_generations = 6,              # Decrease if out of memory
    max_prompt_length = max_prompt_length,
    max_completion_length = max_seq_length - max_prompt_length,
    # num_train_epochs = 1,  # Set to 1 for a full training run
    max_steps = 250,
    save_steps = 250,
    max_grad_norm = 0.1,
    report_to = "none",  # Can use Weights & Biases
    output_dir = "outputs",
)

trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = [
        xmlcount_reward_func,
        soft_format_reward_func,
        strict_format_reward_func,
        int_reward_func,
        correctness_reward_func,
    ],
    args = training_args,
    train_dataset = dataset,
)
trainer.train()