Commit ·
c228698
1
Parent(s): 3d5e034
finalize
Browse files- README.md +3 -43
- inference_test.ipynb +4 -42
README.md
CHANGED
|
@@ -10,12 +10,8 @@ tags:
|
|
| 10 |
- math-reasoning
|
| 11 |
language:
|
| 12 |
- en
|
| 13 |
-
base_model: google/gemma-3-1b-it
|
| 14 |
-
|
| 15 |
-
- custom-instruction-dataset
|
| 16 |
-
metrics:
|
| 17 |
-
- accuracy
|
| 18 |
-
- f1
|
| 19 |
---
|
| 20 |
|
| 21 |
# Gemma-3 1B IT LoRA Fine-tuned with GRPO
|
|
@@ -62,7 +58,7 @@ from peft import PeftModel
|
|
| 62 |
|
| 63 |
# Model identifiers
|
| 64 |
base_model_name = "google/gemma-3-1b-it"
|
| 65 |
-
adapter_repo_id = "
|
| 66 |
|
| 67 |
# Load base model and tokenizer
|
| 68 |
model = AutoModelForCausalLM.from_pretrained(
|
|
@@ -122,28 +118,6 @@ The model was fine-tuned using the Unsloth framework with the following approach
|
|
| 122 |
- Reward Function: Accuracy-based with format compliance bonus
|
| 123 |
- KL Divergence Penalty: 0.01
|
| 124 |
|
| 125 |
-
### Dataset
|
| 126 |
-
|
| 127 |
-
The model was trained on a curated dataset of mathematical problems and reasoning tasks, including:
|
| 128 |
-
- Arithmetic problems
|
| 129 |
-
- Word problems
|
| 130 |
-
- Algebraic equations
|
| 131 |
-
- Geometric calculations
|
| 132 |
-
|
| 133 |
-
*Note: Replace placeholders with actual training details if available.*
|
| 134 |
-
|
| 135 |
-
## Evaluation
|
| 136 |
-
|
| 137 |
-
The model's performance was evaluated on a held-out test set of mathematical problems. Key metrics include:
|
| 138 |
-
|
| 139 |
-
- **Accuracy**: Percentage of correct final answers
|
| 140 |
-
- **Format Compliance**: Adherence to specified output format
|
| 141 |
-
- **Reasoning Quality**: Coherence and correctness of intermediate steps
|
| 142 |
-
|
| 143 |
-
Example evaluation results:
|
| 144 |
-
- Simple arithmetic: 95% accuracy
|
| 145 |
-
- Complex word problems: 78% accuracy
|
| 146 |
-
- Overall improvement over base model: +15-20% on reasoning tasks
|
| 147 |
|
| 148 |
## Limitations
|
| 149 |
|
|
@@ -154,26 +128,12 @@ Example evaluation results:
|
|
| 154 |
- **Hallucinations**: Like all language models, can generate incorrect information
|
| 155 |
- **Bias**: May reflect biases present in the training data
|
| 156 |
|
| 157 |
-
## Ethical Considerations
|
| 158 |
-
|
| 159 |
-
- Use outputs as a tool, not as definitive answers
|
| 160 |
-
- Verify critical information independently
|
| 161 |
-
- Be aware of potential biases in generated content
|
| 162 |
-
- Consider the environmental impact of large language model usage
|
| 163 |
|
| 164 |
## Citation
|
| 165 |
|
| 166 |
If you use this model in your research or applications, please cite:
|
| 167 |
|
| 168 |
```bibtex
|
| 169 |
-
@misc{gemma3-grpo-lora,
|
| 170 |
-
title={Gemma-3 1B IT LoRA Fine-tuned with GRPO},
|
| 171 |
-
author={Your Name},
|
| 172 |
-
year={2025},
|
| 173 |
-
publisher={Hugging Face},
|
| 174 |
-
url={https://huggingface.co/your-username/gemma-3-GRPO}
|
| 175 |
-
}
|
| 176 |
-
|
| 177 |
@article{shao2024deepseekmath,
|
| 178 |
title={DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open-source Large Language Models},
|
| 179 |
author={Shao, Zhihong and Wang, Peiyi and Zhu, Qihao and Xu, Runxin and Song, Junxiao and Bi, Xiao and Zhang, Haowei and Zhang, Mingchuan and Li, Y. K. and Wu, Y. K. and Guo, Daya},
|
|
|
|
| 10 |
- math-reasoning
|
| 11 |
language:
|
| 12 |
- en
|
| 13 |
+
base_model: google/gemma-3-1b-it
|
| 14 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
---
|
| 16 |
|
| 17 |
# Gemma-3 1B IT LoRA Fine-tuned with GRPO
|
|
|
|
| 58 |
|
| 59 |
# Model identifiers
|
| 60 |
base_model_name = "google/gemma-3-1b-it"
|
| 61 |
+
adapter_repo_id = "Miracle12345/gemma-3-GRPO"
|
| 62 |
|
| 63 |
# Load base model and tokenizer
|
| 64 |
model = AutoModelForCausalLM.from_pretrained(
|
|
|
|
| 118 |
- Reward Function: Accuracy-based with format compliance bonus
|
| 119 |
- KL Divergence Penalty: 0.01
|
| 120 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
|
| 122 |
## Limitations
|
| 123 |
|
|
|
|
| 128 |
- **Hallucinations**: Like all language models, can generate incorrect information
|
| 129 |
- **Bias**: May reflect biases present in the training data
|
| 130 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 131 |
|
| 132 |
## Citation
|
| 133 |
|
| 134 |
If you use this model in your research or applications, please cite:
|
| 135 |
|
| 136 |
```bibtex
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
@article{shao2024deepseekmath,
|
| 138 |
title={DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open-source Large Language Models},
|
| 139 |
author={Shao, Zhihong and Wang, Peiyi and Zhu, Qihao and Xu, Runxin and Song, Junxiao and Bi, Xiao and Zhang, Haowei and Zhang, Mingchuan and Li, Y. K. and Wu, Y. K. and Guo, Daya},
|
inference_test.ipynb
CHANGED
|
@@ -55,74 +55,48 @@
|
|
| 55 |
"from transformers import AutoTokenizer\n",
|
| 56 |
"from peft import PeftModel\n",
|
| 57 |
"\n",
|
| 58 |
-
"# ===============================\n",
|
| 59 |
-
"# Tags for GRPO training format\n",
|
| 60 |
-
"# ===============================\n",
|
| 61 |
"reasoning_start = \"<start_working_out>\"\n",
|
| 62 |
"reasoning_end = \"<end_working_out>\"\n",
|
| 63 |
"solution_start = \"<SOLUTION>\"\n",
|
| 64 |
"solution_end = \"</SOLUTION>\"\n",
|
| 65 |
"\n",
|
| 66 |
"def _normalize_numeric(s: str) -> str:\n",
|
| 67 |
-
"
|
| 68 |
-
" - remove commas and stray whitespace\n",
|
| 69 |
-
" - convert '60.0' -> '60'\n",
|
| 70 |
-
" - keep decimals if they are not integer-valued\n",
|
| 71 |
-
" - otherwise return the cleaned string as-is\n",
|
| 72 |
-
" \"\"\"\n",
|
| 73 |
" s = s.strip().replace(\",\", \"\")\n",
|
| 74 |
-
" # remove trailing punctuation commonly found in outputs\n",
|
| 75 |
" s = s.rstrip(\".;)\")\n",
|
| 76 |
-
" # try float conversion\n",
|
| 77 |
" try:\n",
|
| 78 |
" f = float(s)\n",
|
| 79 |
" except Exception:\n",
|
| 80 |
-
" return s
|
| 81 |
-
" # if it's actually an integer value, return integer form\n",
|
| 82 |
" if f.is_integer():\n",
|
| 83 |
" return str(int(f))\n",
|
| 84 |
-
" # else return float without unnecessary trailing zeros\n",
|
| 85 |
" s_float = repr(f)\n",
|
| 86 |
-
" # strip trailing zeros like '2.500000' -> '2.5'\n",
|
| 87 |
" if \".\" in s_float:\n",
|
| 88 |
" s_float = s_float.rstrip(\"0\").rstrip(\".\")\n",
|
| 89 |
" return s_float\n",
|
| 90 |
"\n",
|
| 91 |
"def extract_solution(text: str) -> str | None:\n",
|
| 92 |
-
"
|
| 93 |
-
" Extract the final solution from `text`.\n",
|
| 94 |
-
" 1) Look for <SOLUTION>...</SOLUTION>\n",
|
| 95 |
-
" 2) Otherwise take the last numeric token in the text\n",
|
| 96 |
-
" Returns the cleaned numeric answer as a string, or None.\n",
|
| 97 |
-
" \"\"\"\n",
|
| 98 |
-
" # 1) Strict tag-based extraction (safe escaping)\n",
|
| 99 |
" try:\n",
|
| 100 |
" tag_pattern = re.escape(solution_start) + r\"(.*?)\" + re.escape(solution_end)\n",
|
| 101 |
" m = re.search(tag_pattern, text, flags=re.DOTALL)\n",
|
| 102 |
" except NameError:\n",
|
| 103 |
-
" # If solution_start/solution_end not defined for some reason\n",
|
| 104 |
" m = None\n",
|
| 105 |
"\n",
|
| 106 |
" if m:\n",
|
| 107 |
" ans = m.group(1).strip()\n",
|
| 108 |
" return _normalize_numeric(ans)\n",
|
| 109 |
"\n",
|
| 110 |
-
" # 2) Fallback: find all numeric tokens and return the last one\n",
|
| 111 |
" nums = re.findall(r\"-?\\d+(?:\\.\\d+)?\", text)\n",
|
| 112 |
" if not nums:\n",
|
| 113 |
" return None\n",
|
| 114 |
" return _normalize_numeric(nums[-1])\n",
|
| 115 |
"\n",
|
| 116 |
"\n",
|
| 117 |
-
"\n",
|
| 118 |
-
"# ===============================\n",
|
| 119 |
-
"# Model setup\n",
|
| 120 |
-
"# ===============================\n",
|
| 121 |
"model_name = \"unsloth/gemma-3-1b-it\"\n",
|
| 122 |
"lora_repo_id = \"Miracle12345/gemma-3-GRPO\"\n",
|
| 123 |
"max_seq_len = 4096\n",
|
| 124 |
"\n",
|
| 125 |
-
"# Load base model + tokenizer\n",
|
| 126 |
"base_model, tokenizer = FastLanguageModel.from_pretrained(\n",
|
| 127 |
" model_name=model_name,\n",
|
| 128 |
" max_seq_length=max_seq_len,\n",
|
|
@@ -137,18 +111,12 @@
|
|
| 137 |
" is_trainable=False,\n",
|
| 138 |
")\n",
|
| 139 |
"\n",
|
| 140 |
-
"# ===============================\n",
|
| 141 |
-
"# Prompt setup\n",
|
| 142 |
-
"# ===============================\n",
|
| 143 |
"system_prompt = f\"\"\"You are given a problem.\n",
|
| 144 |
"Think about it and provide your working out.\n",
|
| 145 |
"Put your reasoning between {reasoning_start} and {reasoning_end}.\n",
|
| 146 |
"Then, provide ONLY the final numerical solution between {solution_start}{solution_end}.\n",
|
| 147 |
"Do not output anything else.\"\"\"\n",
|
| 148 |
"\n",
|
| 149 |
-
"# ===============================\n",
|
| 150 |
-
"# Inference function\n",
|
| 151 |
-
"# ===============================\n",
|
| 152 |
"def run_inference(model, tokenizer, question, label=\"\"):\n",
|
| 153 |
" messages = [\n",
|
| 154 |
" {\"role\": \"system\", \"content\": system_prompt},\n",
|
|
@@ -177,9 +145,8 @@
|
|
| 177 |
"\n",
|
| 178 |
" return extract_solution(generated_text)\n",
|
| 179 |
"\n",
|
| 180 |
-
"# ===============================\n",
|
| 181 |
"# Test set (easy → hard)\n",
|
| 182 |
-
"
|
| 183 |
"test_problems = [\n",
|
| 184 |
" (\"What is 12 + 8 - 4 ?\", \"16\"),\n",
|
| 185 |
" (\"If you buy 5 pens at $12 each and 3 notebooks at $20 each, what is the total cost?\", \"120\"),\n",
|
|
@@ -191,9 +158,7 @@
|
|
| 191 |
" (\"Solve: A boat goes 30 km downstream in 2 hours and the same distance upstream in 3 hours. Find the speed of the boat in still water.\", \"12\"),\n",
|
| 192 |
"]\n",
|
| 193 |
"\n",
|
| 194 |
-
"# ===============================\n",
|
| 195 |
"# Run diagnostic test\n",
|
| 196 |
-
"# ===============================\n",
|
| 197 |
"for q, correct in test_problems:\n",
|
| 198 |
" ans_lora = run_inference(lora_model, tokenizer, q, label=\"(LoRA)\")\n",
|
| 199 |
" ans_base = run_inference(base_model, tokenizer, q, label=\"(Base)\")\n",
|
|
@@ -207,9 +172,6 @@
|
|
| 207 |
" print(\"\\n✅ Found case where LoRA is correct and Base is wrong!\")\n",
|
| 208 |
" break\n",
|
| 209 |
"\n",
|
| 210 |
-
"# ===============================\n",
|
| 211 |
-
"# Debug memory usage\n",
|
| 212 |
-
"# ===============================\n",
|
| 213 |
"print(\"\\nGPU memory allocated:\", torch.cuda.memory_allocated() / 1024**3, \"GB\")"
|
| 214 |
]
|
| 215 |
}
|
|
|
|
| 55 |
"from transformers import AutoTokenizer\n",
|
| 56 |
"from peft import PeftModel\n",
|
| 57 |
"\n",
|
|
|
|
|
|
|
|
|
|
| 58 |
"reasoning_start = \"<start_working_out>\"\n",
|
| 59 |
"reasoning_end = \"<end_working_out>\"\n",
|
| 60 |
"solution_start = \"<SOLUTION>\"\n",
|
| 61 |
"solution_end = \"</SOLUTION>\"\n",
|
| 62 |
"\n",
|
| 63 |
"def _normalize_numeric(s: str) -> str:\n",
|
| 64 |
+
"\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
" s = s.strip().replace(\",\", \"\")\n",
|
|
|
|
| 66 |
" s = s.rstrip(\".;)\")\n",
|
|
|
|
| 67 |
" try:\n",
|
| 68 |
" f = float(s)\n",
|
| 69 |
" except Exception:\n",
|
| 70 |
+
" return s \n",
|
|
|
|
| 71 |
" if f.is_integer():\n",
|
| 72 |
" return str(int(f))\n",
|
|
|
|
| 73 |
" s_float = repr(f)\n",
|
|
|
|
| 74 |
" if \".\" in s_float:\n",
|
| 75 |
" s_float = s_float.rstrip(\"0\").rstrip(\".\")\n",
|
| 76 |
" return s_float\n",
|
| 77 |
"\n",
|
| 78 |
"def extract_solution(text: str) -> str | None:\n",
|
| 79 |
+
"\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
" try:\n",
|
| 81 |
" tag_pattern = re.escape(solution_start) + r\"(.*?)\" + re.escape(solution_end)\n",
|
| 82 |
" m = re.search(tag_pattern, text, flags=re.DOTALL)\n",
|
| 83 |
" except NameError:\n",
|
|
|
|
| 84 |
" m = None\n",
|
| 85 |
"\n",
|
| 86 |
" if m:\n",
|
| 87 |
" ans = m.group(1).strip()\n",
|
| 88 |
" return _normalize_numeric(ans)\n",
|
| 89 |
"\n",
|
|
|
|
| 90 |
" nums = re.findall(r\"-?\\d+(?:\\.\\d+)?\", text)\n",
|
| 91 |
" if not nums:\n",
|
| 92 |
" return None\n",
|
| 93 |
" return _normalize_numeric(nums[-1])\n",
|
| 94 |
"\n",
|
| 95 |
"\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
"model_name = \"unsloth/gemma-3-1b-it\"\n",
|
| 97 |
"lora_repo_id = \"Miracle12345/gemma-3-GRPO\"\n",
|
| 98 |
"max_seq_len = 4096\n",
|
| 99 |
"\n",
|
|
|
|
| 100 |
"base_model, tokenizer = FastLanguageModel.from_pretrained(\n",
|
| 101 |
" model_name=model_name,\n",
|
| 102 |
" max_seq_length=max_seq_len,\n",
|
|
|
|
| 111 |
" is_trainable=False,\n",
|
| 112 |
")\n",
|
| 113 |
"\n",
|
|
|
|
|
|
|
|
|
|
| 114 |
"system_prompt = f\"\"\"You are given a problem.\n",
|
| 115 |
"Think about it and provide your working out.\n",
|
| 116 |
"Put your reasoning between {reasoning_start} and {reasoning_end}.\n",
|
| 117 |
"Then, provide ONLY the final numerical solution between {solution_start}{solution_end}.\n",
|
| 118 |
"Do not output anything else.\"\"\"\n",
|
| 119 |
"\n",
|
|
|
|
|
|
|
|
|
|
| 120 |
"def run_inference(model, tokenizer, question, label=\"\"):\n",
|
| 121 |
" messages = [\n",
|
| 122 |
" {\"role\": \"system\", \"content\": system_prompt},\n",
|
|
|
|
| 145 |
"\n",
|
| 146 |
" return extract_solution(generated_text)\n",
|
| 147 |
"\n",
|
|
|
|
| 148 |
"# Test set (easy → hard)\n",
|
| 149 |
+
"\n",
|
| 150 |
"test_problems = [\n",
|
| 151 |
" (\"What is 12 + 8 - 4 ?\", \"16\"),\n",
|
| 152 |
" (\"If you buy 5 pens at $12 each and 3 notebooks at $20 each, what is the total cost?\", \"120\"),\n",
|
|
|
|
| 158 |
" (\"Solve: A boat goes 30 km downstream in 2 hours and the same distance upstream in 3 hours. Find the speed of the boat in still water.\", \"12\"),\n",
|
| 159 |
"]\n",
|
| 160 |
"\n",
|
|
|
|
| 161 |
"# Run diagnostic test\n",
|
|
|
|
| 162 |
"for q, correct in test_problems:\n",
|
| 163 |
" ans_lora = run_inference(lora_model, tokenizer, q, label=\"(LoRA)\")\n",
|
| 164 |
" ans_base = run_inference(base_model, tokenizer, q, label=\"(Base)\")\n",
|
|
|
|
| 172 |
" print(\"\\n✅ Found case where LoRA is correct and Base is wrong!\")\n",
|
| 173 |
" break\n",
|
| 174 |
"\n",
|
|
|
|
|
|
|
|
|
|
| 175 |
"print(\"\\nGPU memory allocated:\", torch.cuda.memory_allocated() / 1024**3, \"GB\")"
|
| 176 |
]
|
| 177 |
}
|