```python
from datasets import load_dataset

dataset = load_dataset("openai/gsm8k", "main", split = "train")
dataset
dataset[0]["question"]

def extract_hash_answer(text):
    if "####" not in text: return None
    return text.split("####")[1].strip()

extract_hash_answer(dataset[0]["answer"])
```
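GSM8K stores the gold answer after a `####` marker, and the helper above keeps only what follows it. A minimal sanity check (the strings are made up):

```python
assert extract_hash_answer("She sold 48 + 24 clips. #### 72") == "72"
assert extract_hash_answer("no marker here") is None
```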
```python
reasoning_start = "<start_working_out>"
reasoning_end   = "<end_working_out>"
solution_start  = "<SOLUTION>"
solution_end    = "</SOLUTION>"

system_prompt = f"""You are given a problem.
Think about the problem and provide your working out.
Place it between {reasoning_start} and {reasoning_end}.
Then, provide your solution between {solution_start}{solution_end}"""
system_prompt

dataset = dataset.map(lambda x: {
    "prompt" : [
        {"role": "system", "content": system_prompt},
        {"role": "user",   "content": x["question"]},
    ],
    "answer": extract_hash_answer(x["answer"]),
})
dataset[0]
```
```python
import re

# Match the full response format: optional whitespace, the working-out
# section, then the solution section (group 1 captures the solution).
match_format = re.compile(
    rf"^\s*"
    rf"{reasoning_start}.+?{reasoning_end}.*?"
    rf"{solution_start}(.+?){solution_end}"
    rf"\s*$",
    flags = re.MULTILINE | re.DOTALL,
)

match_format.search(
    "<start_working_out>Let me think!<end_working_out>"
    "<SOLUTION>2</SOLUTION>",
)
```
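The pattern should also reject malformed outputs; two quick negative checks (hypothetical strings):

```python
# Missing working-out section -> no match
assert match_format.search("<SOLUTION>2</SOLUTION>") is None
# Trailing text after the solution tag -> no match
assert match_format.search(
    "<start_working_out>hmm<end_working_out><SOLUTION>2</SOLUTION> trailing"
) is None
```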
```python
def match_format_exactly(completions, **kwargs):
    scores = []
    for completion in completions:
        score = 0
        response = completion[0]["content"]
        # Give 3 points if the exact format is matched
        if match_format.search(response) is not None: score += 3.0
        scores.append(score)
    return scores

def match_format_approximately(completions, **kwargs):
    scores = []
    for completion in completions:
        score = 0
        response = completion[0]["content"]
        # Count each keyword: reward it appearing exactly once,
        # penalize it appearing zero times or more than once.
        score += 0.5 if response.count(reasoning_start) == 1 else -1.0
        score += 0.5 if response.count(reasoning_end)   == 1 else -1.0
        score += 0.5 if response.count(solution_start)  == 1 else -1.0
        score += 0.5 if response.count(solution_end)    == 1 else -1.0
        scores.append(score)
    return scores
```
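A quick smoke test with a made-up completion, in the `[[{"content": ...}]]` shape that TRL hands to reward functions:

```python
fake = [[{"content":
    "<start_working_out>2 * 3 = 6<end_working_out><SOLUTION>6</SOLUTION>"}]]
print(match_format_exactly(fake))        # [3.0]
print(match_format_approximately(fake))  # [2.0] - every tag appears exactly once
```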
```python
def check_answer(prompts, completions, answer, **kwargs):
    question = prompts[0][-1]["content"]
    responses = [completion[0]["content"] for completion in completions]

    extracted_responses = [
        guess.group(1)
        if (guess := match_format.search(r)) is not None else None
        for r in responses
    ]

    scores = []
    for guess, true_answer in zip(extracted_responses, answer):
        score = 0
        if guess is None:
            scores.append(0)
            continue
        # Exact answer gets 3 points
        if guess == true_answer:
            score += 3.0
        # Match after stripping whitespace gets a smaller reward
        elif guess.strip() == true_answer.strip():
            score += 1.5
        else:
            # Also reward answers that are numerically close via ratios:
            # within 10% gets 1.0, within 20% gets 0.5, otherwise penalize.
            try:
                ratio = float(guess) / float(true_answer)
                if   0.9 <= ratio <= 1.1: score += 1.0
                elif 0.8 <= ratio <= 1.2: score += 0.5
                else: score -= 1.5  # Penalize wrong answers
            except Exception:
                score -= 1.5  # Penalize non-numeric answers
        scores.append(score)
    return scores
```
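And a corresponding check for the answer reward (dummy prompt and gold answers):

```python
prompts = [[{"role": "user", "content": "What is 2 + 2?"}]]
comps   = [[{"content":
    "<start_working_out>2 + 2 = 4<end_working_out><SOLUTION>4</SOLUTION>"}]]
print(check_answer(prompts, comps, ["4"]))  # [3.0] - exact match
print(check_answer(prompts, comps, ["5"]))  # [0.5] - ratio 4/5 = 0.8 is within 20%
```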
```python
# Grab the first number (digits, dots, commas) that follows the solution tag
match_numbers = re.compile(
    solution_start + r".*?([\d\.,]+)",
    flags = re.MULTILINE | re.DOTALL,
)
print(match_numbers.findall("<SOLUTION> 0.34 </SOLUTION>"))
print(match_numbers.findall("<SOLUTION> 123,456 </SOLUTION>"))

# Counters so sample generations are printed only every few steps
PRINTED_TIMES = 0
PRINT_EVERY_STEPS = 5
```
```python
def check_numbers(prompts, completions, answer, **kwargs):
    question = prompts[0][-1]["content"]
    responses = [completion[0]["content"] for completion in completions]

    extracted_responses = [
        guess.group(1)
        if (guess := match_numbers.search(r)) is not None else None
        for r in responses
    ]

    scores = []
    # Print a sample generation only every few steps
    global PRINTED_TIMES, PRINT_EVERY_STEPS
    if PRINTED_TIMES % PRINT_EVERY_STEPS == 0:
        print(
            "*" * 20,
            f"Question:\n{question}",
            f"\nAnswer:\n{answer[0]}",
            f"\nResponse:\n{responses[0]}",
            f"\nExtracted:\n{extracted_responses[0]}",
        )
    PRINTED_TIMES += 1

    for guess, true_answer in zip(extracted_responses, answer):
        if guess is None:
            scores.append(0)
            continue
        # Convert both to floats; strip commas like in 123,456
        try:
            true_answer = float(true_answer.strip())
            guess = float(guess.strip().replace(",", ""))
            scores.append(1.5 if guess == true_answer else -0.5)
        except Exception:
            scores.append(0)
            continue
    return scores
```
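For example, with a comma-grouped number the extractor keeps the comma and the float conversion strips it (dummy inputs; the function also prints its periodic sample):

```python
print(check_numbers(
    [[{"role": "user", "content": "dummy"}]],
    [[{"content": "<SOLUTION> 123,456 </SOLUTION>"}]],
    ["123456"],
))  # [1.5]
```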
```python
# Find the longest tokenized prompt, to pick a safe max_prompt_length
# (tokenizer comes from the FastLanguageModel setup shown further below)
max(dataset.map(
    lambda x: {"tokens" : tokenizer.apply_chat_template(x["prompt"], add_generation_prompt = True, tokenize = True)},
    batched = True,
).map(lambda x: {"length" : len(x["tokens"])})["length"])

max_prompt_length = 287 + 1  # longest prompt above was 287; + 1 just in case!
```
```python
from trl import GRPOConfig, GRPOTrainer

training_args = GRPOConfig(
    learning_rate = 5e-6,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type = "cosine",
    optim = "adamw_8bit",
    logging_steps = 1,
    per_device_train_batch_size = 1,
    gradient_accumulation_steps = 4, # Increased to 4 for smoother training
    num_generations = 4, # Decrease if out of memory
    max_prompt_length = max_prompt_length,
    # max_seq_length is set in the model-loading cell further below
    max_completion_length = max_seq_length - max_prompt_length,
    # num_train_epochs = 1, # Set to 1 for a full training run
    max_steps = 500,
    save_steps = 250,
    max_grad_norm = 1.0,
    report_to = "none", # Can use Weights & Biases
    output_dir = "outputs",
)

trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = [
        match_format_exactly,
        match_format_approximately,
        check_answer,
        check_numbers,
    ],
    args = training_args,
    train_dataset = dataset,
)
trainer.train()
```
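After training, you would typically save the LoRA adapter and try the model on a fresh question. A minimal sketch, assuming Unsloth's `save_lora` / `load_lora` / `fast_generate` helpers and vLLM's `SamplingParams` are available in your install:

```python
model.save_lora("grpo_saved_lora")

from vllm import SamplingParams
text = tokenizer.apply_chat_template([
    {"role": "system", "content": system_prompt},
    {"role": "user",   "content": "What is the sqrt of 101?"},
], tokenize = False, add_generation_prompt = True)

output = model.fast_generate(
    text,
    sampling_params = SamplingParams(temperature = 0.8, top_p = 0.95, max_tokens = 512),
    lora_request = model.load_lora("grpo_saved_lora"),
)[0].outputs[0].text
print(output)
```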
```python
from unsloth import FastLanguageModel
import torch

max_seq_length = 1024  # Can increase for longer reasoning traces
lora_rank = 32  # Larger rank = smarter, but slower

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/meta-Llama-3.1-8B-Instruct",
    max_seq_length = max_seq_length,
    load_in_4bit = True,  # False for LoRA 16bit
    fast_inference = True,  # Enable vLLM fast inference
    max_lora_rank = lora_rank,
    gpu_memory_utilization = 0.9,  # Reduce if out of memory
)

model = FastLanguageModel.get_peft_model(
    model,
    r = lora_rank,  # Choose any number > 0! Suggested 8, 16, 32, 64, 128
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],  # Remove QKVO if out of memory
    lora_alpha = lora_rank,
    use_gradient_checkpointing = "unsloth",  # Enable long context finetuning
    random_state = 3407,
)
```
```python
import re
from datasets import load_dataset, Dataset

# Load and prep dataset
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""

def extract_xml_answer(text: str) -> str:
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()
```
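A one-line check of the tag parser (made-up string):

```python
assert extract_xml_answer("<reasoning>\n6 * 7\n</reasoning>\n<answer>\n42\n</answer>") == "42"
```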
```python
def extract_hash_answer(text: str) -> str | None:
    if "####" not in text:
        return None
    return text.split("####")[1].strip()

# uncomment middle messages for 1-shot prompting
def get_gsm8k_questions(split = "train") -> Dataset:
    data = load_dataset('openai/gsm8k', 'main')[split]  # type: ignore
    data = data.map(lambda x: {  # type: ignore
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': x['question']},
        ],
        'answer': extract_hash_answer(x['answer']),
    })  # type: ignore
    return data  # type: ignore

dataset = get_gsm8k_questions()
```
```python
# Reward functions
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    q = prompts[0][-1]['content']
    extracted_responses = [extract_xml_answer(r) for r in responses]
    print('-' * 20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}",
          f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]

def int_reward_func(completions, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward completions that match the format exactly, newlines included."""
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    # re.DOTALL lets .*? span multi-line reasoning
    matches = [re.match(pattern, r, flags = re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward completions that loosely match the format, ignoring exact newlines."""
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r, flags = re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]
```
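Both format rewards can be smoke-tested with a hand-written completion:

```python
fake = [[{"content": "<reasoning>\n6 * 7 = 42\n</reasoning>\n<answer>\n42\n</answer>\n"}]]
print(strict_format_reward_func(fake))  # [0.5]
print(soft_format_reward_func(fake))    # [0.5]
```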
```python
def count_xml(text) -> float:
    """Give partial credit for each correctly placed tag, and subtract a small
    penalty per character of trailing text after the closing </answer> tag."""
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        count -= len(text.split("\n</answer>\n")[-1]) * 0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        count -= (len(text.split("\n</answer>")[-1]) - 1) * 0.001
    return count
```
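A perfectly formatted response earns 4 × 0.125 = 0.5, and chatter after the closing `</answer>` tag is charged 0.001 per character (example string is made up):

```python
good = "<reasoning>\n6 * 7 = 42\n</reasoning>\n<answer>\n42\n</answer>\n"
print(count_xml(good))                     # 0.5
print(count_xml(good + "Hope it helps!"))  # ~0.472 - trailing text is penalized
```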
```python
def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]

max_prompt_length = 256

from trl import GRPOConfig, GRPOTrainer

training_args = GRPOConfig(
    learning_rate = 5e-6,
    adam_beta1 = 0.9,
    adam_beta2 = 0.99,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type = "cosine",
    optim = "paged_adamw_8bit",
    logging_steps = 1,
    per_device_train_batch_size = 1,
    gradient_accumulation_steps = 1, # Increase to 4 for smoother training
    num_generations = 6, # Decrease if out of memory
    max_prompt_length = max_prompt_length,
    max_completion_length = max_seq_length - max_prompt_length,
    # num_train_epochs = 1, # Set to 1 for a full training run
    max_steps = 250,
    save_steps = 250,
    max_grad_norm = 0.1,
    report_to = "none", # Can use Weights & Biases
    output_dir = "outputs",
)

trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = [
        xmlcount_reward_func,
        soft_format_reward_func,
        strict_format_reward_func,
        int_reward_func,
        correctness_reward_func,
    ],
    args = training_args,
    train_dataset = dataset,
)
trainer.train()
```