Miracle12345 committed on
Commit
c228698
·
1 Parent(s): 3d5e034
Files changed (2) hide show
  1. README.md +3 -43
  2. inference_test.ipynb +4 -42
README.md CHANGED
@@ -10,12 +10,8 @@ tags:
10
  - math-reasoning
11
  language:
12
  - en
13
- base_model: google/gemma-3-1b-it
14
- datasets:
15
- - custom-instruction-dataset
16
- metrics:
17
- - accuracy
18
- - f1
19
  ---
20
 
21
  # Gemma-3 1B IT LoRA Fine-tuned with GRPO
@@ -62,7 +58,7 @@ from peft import PeftModel
62
 
63
  # Model identifiers
64
  base_model_name = "google/gemma-3-1b-it"
65
- adapter_repo_id = "your-username/gemma-3-GRPO" # Replace with your Hugging Face repo
66
 
67
  # Load base model and tokenizer
68
  model = AutoModelForCausalLM.from_pretrained(
@@ -122,28 +118,6 @@ The model was fine-tuned using the Unsloth framework with the following approach
122
  - Reward Function: Accuracy-based with format compliance bonus
123
  - KL Divergence Penalty: 0.01
124
 
125
- ### Dataset
126
-
127
- The model was trained on a curated dataset of mathematical problems and reasoning tasks, including:
128
- - Arithmetic problems
129
- - Word problems
130
- - Algebraic equations
131
- - Geometric calculations
132
-
133
- *Note: Replace placeholders with actual training details if available.*
134
-
135
- ## Evaluation
136
-
137
- The model's performance was evaluated on a held-out test set of mathematical problems. Key metrics include:
138
-
139
- - **Accuracy**: Percentage of correct final answers
140
- - **Format Compliance**: Adherence to specified output format
141
- - **Reasoning Quality**: Coherence and correctness of intermediate steps
142
-
143
- Example evaluation results:
144
- - Simple arithmetic: 95% accuracy
145
- - Complex word problems: 78% accuracy
146
- - Overall improvement over base model: +15-20% on reasoning tasks
147
 
148
  ## Limitations
149
 
@@ -154,26 +128,12 @@ Example evaluation results:
154
  - **Hallucinations**: Like all language models, can generate incorrect information
155
  - **Bias**: May reflect biases present in the training data
156
 
157
- ## Ethical Considerations
158
-
159
- - Use outputs as a tool, not as definitive answers
160
- - Verify critical information independently
161
- - Be aware of potential biases in generated content
162
- - Consider the environmental impact of large language model usage
163
 
164
  ## Citation
165
 
166
  If you use this model in your research or applications, please cite:
167
 
168
  ```bibtex
169
- @misc{gemma3-grpo-lora,
170
- title={Gemma-3 1B IT LoRA Fine-tuned with GRPO},
171
- author={Your Name},
172
- year={2025},
173
- publisher={Hugging Face},
174
- url={https://huggingface.co/your-username/gemma-3-GRPO}
175
- }
176
-
177
  @article{shao2024deepseekmath,
178
  title={DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open-source Large Language Models},
179
  author={Shao, Zhihong and Wang, Peiyi and Zhu, Qihao and Xu, Runxin and Song, Junxiao and Bi, Xiao and Zhang, Haowei and Zhang, Mingchuan and Li, Y. K. and Wu, Y. K. and Guo, Daya},
 
10
  - math-reasoning
11
  language:
12
  - en
13
+ base_model: google/gemma-3-1b-it
14
+
 
 
 
 
15
  ---
16
 
17
  # Gemma-3 1B IT LoRA Fine-tuned with GRPO
 
58
 
59
  # Model identifiers
60
  base_model_name = "google/gemma-3-1b-it"
61
+ adapter_repo_id = "Miracle12345/gemma-3-GRPO"
62
 
63
  # Load base model and tokenizer
64
  model = AutoModelForCausalLM.from_pretrained(
 
118
  - Reward Function: Accuracy-based with format compliance bonus
119
  - KL Divergence Penalty: 0.01
120
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
 
122
  ## Limitations
123
 
 
128
  - **Hallucinations**: Like all language models, can generate incorrect information
129
  - **Bias**: May reflect biases present in the training data
130
 
 
 
 
 
 
 
131
 
132
  ## Citation
133
 
134
  If you use this model in your research or applications, please cite:
135
 
136
  ```bibtex
 
 
 
 
 
 
 
 
137
  @article{shao2024deepseekmath,
138
  title={DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open-source Large Language Models},
139
  author={Shao, Zhihong and Wang, Peiyi and Zhu, Qihao and Xu, Runxin and Song, Junxiao and Bi, Xiao and Zhang, Haowei and Zhang, Mingchuan and Li, Y. K. and Wu, Y. K. and Guo, Daya},
inference_test.ipynb CHANGED
@@ -55,74 +55,48 @@
55
  "from transformers import AutoTokenizer\n",
56
  "from peft import PeftModel\n",
57
  "\n",
58
- "# ===============================\n",
59
- "# Tags for GRPO training format\n",
60
- "# ===============================\n",
61
  "reasoning_start = \"<start_working_out>\"\n",
62
  "reasoning_end = \"<end_working_out>\"\n",
63
  "solution_start = \"<SOLUTION>\"\n",
64
  "solution_end = \"</SOLUTION>\"\n",
65
  "\n",
66
  "def _normalize_numeric(s: str) -> str:\n",
67
- " \"\"\"Normalize numeric-like strings:\n",
68
- " - remove commas and stray whitespace\n",
69
- " - convert '60.0' -> '60'\n",
70
- " - keep decimals if they are not integer-valued\n",
71
- " - otherwise return the cleaned string as-is\n",
72
- " \"\"\"\n",
73
  " s = s.strip().replace(\",\", \"\")\n",
74
- " # remove trailing punctuation commonly found in outputs\n",
75
  " s = s.rstrip(\".;)\")\n",
76
- " # try float conversion\n",
77
  " try:\n",
78
  " f = float(s)\n",
79
  " except Exception:\n",
80
- " return s # not a pure number, return raw cleaned string\n",
81
- " # if it's actually an integer value, return integer form\n",
82
  " if f.is_integer():\n",
83
  " return str(int(f))\n",
84
- " # else return float without unnecessary trailing zeros\n",
85
  " s_float = repr(f)\n",
86
- " # strip trailing zeros like '2.500000' -> '2.5'\n",
87
  " if \".\" in s_float:\n",
88
  " s_float = s_float.rstrip(\"0\").rstrip(\".\")\n",
89
  " return s_float\n",
90
  "\n",
91
  "def extract_solution(text: str) -> str | None:\n",
92
- " \"\"\"\n",
93
- " Extract the final solution from `text`.\n",
94
- " 1) Look for <SOLUTION>...</SOLUTION>\n",
95
- " 2) Otherwise take the last numeric token in the text\n",
96
- " Returns the cleaned numeric answer as a string, or None.\n",
97
- " \"\"\"\n",
98
- " # 1) Strict tag-based extraction (safe escaping)\n",
99
  " try:\n",
100
  " tag_pattern = re.escape(solution_start) + r\"(.*?)\" + re.escape(solution_end)\n",
101
  " m = re.search(tag_pattern, text, flags=re.DOTALL)\n",
102
  " except NameError:\n",
103
- " # If solution_start/solution_end not defined for some reason\n",
104
  " m = None\n",
105
  "\n",
106
  " if m:\n",
107
  " ans = m.group(1).strip()\n",
108
  " return _normalize_numeric(ans)\n",
109
  "\n",
110
- " # 2) Fallback: find all numeric tokens and return the last one\n",
111
  " nums = re.findall(r\"-?\\d+(?:\\.\\d+)?\", text)\n",
112
  " if not nums:\n",
113
  " return None\n",
114
  " return _normalize_numeric(nums[-1])\n",
115
  "\n",
116
  "\n",
117
- "\n",
118
- "# ===============================\n",
119
- "# Model setup\n",
120
- "# ===============================\n",
121
  "model_name = \"unsloth/gemma-3-1b-it\"\n",
122
  "lora_repo_id = \"Miracle12345/gemma-3-GRPO\"\n",
123
  "max_seq_len = 4096\n",
124
  "\n",
125
- "# Load base model + tokenizer\n",
126
  "base_model, tokenizer = FastLanguageModel.from_pretrained(\n",
127
  " model_name=model_name,\n",
128
  " max_seq_length=max_seq_len,\n",
@@ -137,18 +111,12 @@
137
  " is_trainable=False,\n",
138
  ")\n",
139
  "\n",
140
- "# ===============================\n",
141
- "# Prompt setup\n",
142
- "# ===============================\n",
143
  "system_prompt = f\"\"\"You are given a problem.\n",
144
  "Think about it and provide your working out.\n",
145
  "Put your reasoning between {reasoning_start} and {reasoning_end}.\n",
146
  "Then, provide ONLY the final numerical solution between {solution_start}{solution_end}.\n",
147
  "Do not output anything else.\"\"\"\n",
148
  "\n",
149
- "# ===============================\n",
150
- "# Inference function\n",
151
- "# ===============================\n",
152
  "def run_inference(model, tokenizer, question, label=\"\"):\n",
153
  " messages = [\n",
154
  " {\"role\": \"system\", \"content\": system_prompt},\n",
@@ -177,9 +145,8 @@
177
  "\n",
178
  " return extract_solution(generated_text)\n",
179
  "\n",
180
- "# ===============================\n",
181
  "# Test set (easy → hard)\n",
182
- "# ===============================\n",
183
  "test_problems = [\n",
184
  " (\"What is 12 + 8 - 4 ?\", \"16\"),\n",
185
  " (\"If you buy 5 pens at $12 each and 3 notebooks at $20 each, what is the total cost?\", \"120\"),\n",
@@ -191,9 +158,7 @@
191
  " (\"Solve: A boat goes 30 km downstream in 2 hours and the same distance upstream in 3 hours. Find the speed of the boat in still water.\", \"12\"),\n",
192
  "]\n",
193
  "\n",
194
- "# ===============================\n",
195
  "# Run diagnostic test\n",
196
- "# ===============================\n",
197
  "for q, correct in test_problems:\n",
198
  " ans_lora = run_inference(lora_model, tokenizer, q, label=\"(LoRA)\")\n",
199
  " ans_base = run_inference(base_model, tokenizer, q, label=\"(Base)\")\n",
@@ -207,9 +172,6 @@
207
  " print(\"\\n✅ Found case where LoRA is correct and Base is wrong!\")\n",
208
  " break\n",
209
  "\n",
210
- "# ===============================\n",
211
- "# Debug memory usage\n",
212
- "# ===============================\n",
213
  "print(\"\\nGPU memory allocated:\", torch.cuda.memory_allocated() / 1024**3, \"GB\")"
214
  ]
215
  }
 
55
  "from transformers import AutoTokenizer\n",
56
  "from peft import PeftModel\n",
57
  "\n",
 
 
 
58
  "reasoning_start = \"<start_working_out>\"\n",
59
  "reasoning_end = \"<end_working_out>\"\n",
60
  "solution_start = \"<SOLUTION>\"\n",
61
  "solution_end = \"</SOLUTION>\"\n",
62
  "\n",
63
  "def _normalize_numeric(s: str) -> str:\n",
64
+ "\n",
 
 
 
 
 
65
  " s = s.strip().replace(\",\", \"\")\n",
 
66
  " s = s.rstrip(\".;)\")\n",
 
67
  " try:\n",
68
  " f = float(s)\n",
69
  " except Exception:\n",
70
+ " return s \n",
 
71
  " if f.is_integer():\n",
72
  " return str(int(f))\n",
 
73
  " s_float = repr(f)\n",
 
74
  " if \".\" in s_float:\n",
75
  " s_float = s_float.rstrip(\"0\").rstrip(\".\")\n",
76
  " return s_float\n",
77
  "\n",
78
  "def extract_solution(text: str) -> str | None:\n",
79
+ "\n",
 
 
 
 
 
 
80
  " try:\n",
81
  " tag_pattern = re.escape(solution_start) + r\"(.*?)\" + re.escape(solution_end)\n",
82
  " m = re.search(tag_pattern, text, flags=re.DOTALL)\n",
83
  " except NameError:\n",
 
84
  " m = None\n",
85
  "\n",
86
  " if m:\n",
87
  " ans = m.group(1).strip()\n",
88
  " return _normalize_numeric(ans)\n",
89
  "\n",
 
90
  " nums = re.findall(r\"-?\\d+(?:\\.\\d+)?\", text)\n",
91
  " if not nums:\n",
92
  " return None\n",
93
  " return _normalize_numeric(nums[-1])\n",
94
  "\n",
95
  "\n",
 
 
 
 
96
  "model_name = \"unsloth/gemma-3-1b-it\"\n",
97
  "lora_repo_id = \"Miracle12345/gemma-3-GRPO\"\n",
98
  "max_seq_len = 4096\n",
99
  "\n",
 
100
  "base_model, tokenizer = FastLanguageModel.from_pretrained(\n",
101
  " model_name=model_name,\n",
102
  " max_seq_length=max_seq_len,\n",
 
111
  " is_trainable=False,\n",
112
  ")\n",
113
  "\n",
 
 
 
114
  "system_prompt = f\"\"\"You are given a problem.\n",
115
  "Think about it and provide your working out.\n",
116
  "Put your reasoning between {reasoning_start} and {reasoning_end}.\n",
117
  "Then, provide ONLY the final numerical solution between {solution_start}{solution_end}.\n",
118
  "Do not output anything else.\"\"\"\n",
119
  "\n",
 
 
 
120
  "def run_inference(model, tokenizer, question, label=\"\"):\n",
121
  " messages = [\n",
122
  " {\"role\": \"system\", \"content\": system_prompt},\n",
 
145
  "\n",
146
  " return extract_solution(generated_text)\n",
147
  "\n",
 
148
  "# Test set (easy → hard)\n",
149
+ "\n",
150
  "test_problems = [\n",
151
  " (\"What is 12 + 8 - 4 ?\", \"16\"),\n",
152
  " (\"If you buy 5 pens at $12 each and 3 notebooks at $20 each, what is the total cost?\", \"120\"),\n",
 
158
  " (\"Solve: A boat goes 30 km downstream in 2 hours and the same distance upstream in 3 hours. Find the speed of the boat in still water.\", \"12\"),\n",
159
  "]\n",
160
  "\n",
 
161
  "# Run diagnostic test\n",
 
162
  "for q, correct in test_problems:\n",
163
  " ans_lora = run_inference(lora_model, tokenizer, q, label=\"(LoRA)\")\n",
164
  " ans_base = run_inference(base_model, tokenizer, q, label=\"(Base)\")\n",
 
172
  " print(\"\\n✅ Found case where LoRA is correct and Base is wrong!\")\n",
173
  " break\n",
174
  "\n",
 
 
 
175
  "print(\"\\nGPU memory allocated:\", torch.cuda.memory_allocated() / 1024**3, \"GB\")"
176
  ]
177
  }