bebechien commited on
Commit
c07c868
Β·
verified Β·
1 Parent(s): f2bf1da

Split testing and tuning

Browse files
Files changed (2) hide show
  1. engine.py +70 -45
  2. ui.py +12 -1
engine.py CHANGED
@@ -135,6 +135,66 @@ class FunctionGemmaEngine:
135
  def trigger_stop(self):
136
  self.stop_event.set()
137
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  # --- Training Pipeline ---
139
 
140
  def run_training_pipeline(self, epochs: int, learning_rate: float, test_size: float, shuffle_data: bool) -> Generator[Tuple[str, Any], None, None]:
@@ -142,21 +202,15 @@ class FunctionGemmaEngine:
142
  output_buffer = ""
143
  last_plot = None
144
 
145
- if self.config.MODEL_NAME != self.loaded_model_name:
146
- output_buffer += f"πŸ”„ Model changed. Switching from '{self.loaded_model_name}' to '{self.config.MODEL_NAME}'...\n"
147
- yield output_buffer, None
148
- try:
149
- self._load_model_weights()
150
- output_buffer += "βœ… Model reloaded successfully.\n"
151
- yield output_buffer, None
152
- except Exception as e:
153
- output_buffer += f"❌ Failed to load model '{self.config.MODEL_NAME}': {e}\n"
154
  yield output_buffer, None
155
- return
156
-
157
- if self.model is None:
158
- yield "Training failed: No model loaded.", None
159
- return
160
 
161
  output_buffer += f"⏳ Preparing Dataset (Test Split: {test_size}, Shuffle: {shuffle_data})...\n"
162
  yield output_buffer, None
@@ -174,24 +228,8 @@ class FunctionGemmaEngine:
174
  else:
175
  dataset = {"train": dataset, "test": dataset}
176
 
177
- # --- Phase 1: Pre-Training Eval ---
178
- output_buffer += "\nπŸ“Š Evaluating Pre-Training Success Rate...\n"
179
- yield output_buffer, None
180
-
181
- pre_training_report = ""
182
- for update in self._evaluate_model(dataset["test"]):
183
- pre_training_report = update
184
- if self.stop_event.is_set():
185
- pre_training_report += "\n\nπŸ›‘ Manual Eval interrupted by user.\n"
186
- yield f"{output_buffer}{pre_training_report}", None
187
- break
188
- yield f"{output_buffer}{pre_training_report}", None
189
-
190
- if self.stop_event.is_set(): return
191
- output_buffer += pre_training_report
192
-
193
- # --- Phase 2: Training (Threaded) ---
194
- output_buffer += f"\n\nπŸš€ Starting Fine-tuning (Epochs: {epochs}, LR: {learning_rate})...\n"
195
  yield output_buffer, None
196
 
197
  log_queue = queue.Queue()
@@ -257,19 +295,6 @@ class FunctionGemmaEngine:
257
  output_buffer += "βœ… Training finished.\n"
258
  yield output_buffer, last_plot
259
 
260
- # --- Phase 3: Post-Training Eval ---
261
- output_buffer += "\nπŸ“Š Evaluating Post-Training Success Rate...\n"
262
- yield output_buffer, last_plot
263
-
264
- post_training_report = ""
265
- for update in self._evaluate_model(dataset["test"]):
266
- post_training_report = update
267
- if self.stop_event.is_set():
268
- post_training_report += "\n\nπŸ›‘ Manual Eval interrupted by user.\n"
269
- yield f"{output_buffer}{post_training_report}", last_plot
270
- break
271
- yield f"{output_buffer}{post_training_report}", last_plot
272
-
273
  def _prepare_dataset(self):
274
  formatting_fn = partial(create_conversation_format, tools_list=self.current_tools)
275
 
 
135
  def trigger_stop(self):
136
  self.stop_event.set()
137
 
138
+ def _ensure_model_consistency(self) -> Generator[str, None, bool]:
139
+ """Checks if the requested model matches the loaded one. Reloads if necessary."""
140
+ if self.config.MODEL_NAME != self.loaded_model_name:
141
+ yield f"πŸ”„ Model changed. Switching from '{self.loaded_model_name}' to '{self.config.MODEL_NAME}'...\n"
142
+ try:
143
+ self._load_model_weights()
144
+ yield "βœ… Model reloaded successfully.\n"
145
+ return True
146
+ except Exception as e:
147
+ yield f"❌ Failed to load model '{self.config.MODEL_NAME}': {e}\n"
148
+ return False
149
+ if self.model is None:
150
+ yield "❌ Error: No model loaded.\n"
151
+ return False
152
+ return True
153
+
154
+ # --- Evaluation Pipeline ---
155
+
156
+ def run_evaluation(self, test_size: float, shuffle_data: bool) -> Generator[str, None, None]:
157
+ self.stop_event.clear()
158
+ output_buffer = ""
159
+
160
+ # 1. Check Model
161
+ gen = self._ensure_model_consistency()
162
+ try:
163
+ while True:
164
+ msg = next(gen)
165
+ output_buffer += msg
166
+ yield output_buffer
167
+ except StopIteration as e:
168
+ if not e.value: return # Failed to load
169
+
170
+ # 2. Prepare Data
171
+ output_buffer += f"⏳ Preparing Dataset for Eval (Test Split: {test_size})...\n"
172
+ yield output_buffer
173
+
174
+ dataset, log = self._prepare_dataset()
175
+ output_buffer += log
176
+ yield output_buffer
177
+
178
+ if not dataset:
179
+ output_buffer += "❌ Dataset creation failed.\n"
180
+ yield output_buffer
181
+ return
182
+
183
+ if len(dataset) > 1:
184
+ dataset = dataset.train_test_split(test_size=test_size, shuffle=shuffle_data)
185
+ else:
186
+ dataset = {"train": dataset, "test": dataset}
187
+
188
+ # 3. Run Inference
189
+ output_buffer += "\nπŸ“Š Evaluating Model Success Rate on Test Split...\n"
190
+ yield output_buffer
191
+
192
+ for update in self._evaluate_model(dataset["test"]):
193
+ yield f"{output_buffer}{update}"
194
+ if self.stop_event.is_set():
195
+ yield f"{output_buffer}{update}\n\nπŸ›‘ Evaluation interrupted by user."
196
+ break
197
+
198
  # --- Training Pipeline ---
199
 
200
  def run_training_pipeline(self, epochs: int, learning_rate: float, test_size: float, shuffle_data: bool) -> Generator[Tuple[str, Any], None, None]:
 
202
  output_buffer = ""
203
  last_plot = None
204
 
205
+ # 1. Check Model
206
+ gen = self._ensure_model_consistency()
207
+ try:
208
+ while True:
209
+ msg = next(gen)
210
+ output_buffer += f"{msg}"
 
 
 
211
  yield output_buffer, None
212
+ except StopIteration as e:
213
+ if not e.value: return
 
 
 
214
 
215
  output_buffer += f"⏳ Preparing Dataset (Test Split: {test_size}, Shuffle: {shuffle_data})...\n"
216
  yield output_buffer, None
 
228
  else:
229
  dataset = {"train": dataset, "test": dataset}
230
 
231
+ # --- Training (Threaded) ---
232
+ output_buffer += f"\nπŸš€ Starting Fine-tuning (Epochs: {epochs}, LR: {learning_rate})...\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
233
  yield output_buffer, None
234
 
235
  log_queue = queue.Queue()
 
295
  output_buffer += "βœ… Training finished.\n"
296
  yield output_buffer, last_plot
297
 
 
 
 
 
 
 
 
 
 
 
 
 
 
298
  def _prepare_dataset(self):
299
  formatting_fn = partial(create_conversation_format, tools_list=self.current_tools)
300
 
ui.py CHANGED
@@ -41,6 +41,15 @@ class UIController:
41
  engine.config.MODEL_NAME = model_name.strip()
42
  yield from engine.run_training_pipeline(epochs, lr, test_size, shuffle)
43
 
 
 
 
 
 
 
 
 
 
44
  @staticmethod
45
  def handle_reset(engine: FunctionGemmaEngine, model_name: str) -> str:
46
  engine.config.MODEL_NAME = model_name.strip()
@@ -151,6 +160,7 @@ def _render_training_tab(engine_state):
151
  param_shuffle = gr.Checkbox(value=True, label="Shuffle Data", info="Randomize before split")
152
 
153
  with gr.Row():
 
154
  run_training_btn = gr.Button("πŸš€ Run Fine-Tuning", variant="primary", scale=1)
155
  stop_training_btn = gr.Button("πŸ›‘ Stop", variant="stop", visible=False, scale=1)
156
  clear_reload_btn = gr.Button("πŸ”„ Reload Model & Reset Data", variant="secondary", scale=1)
@@ -161,7 +171,7 @@ def _render_training_tab(engine_state):
161
 
162
  return {
163
  "params": [param_epochs, param_lr, param_test_size, param_shuffle, param_model],
164
- "buttons": [run_training_btn, stop_training_btn, clear_reload_btn],
165
  "outputs": [output_display, loss_plot],
166
  "model_input": param_model # specifically needed for initialization
167
  }
@@ -175,6 +185,7 @@ def _render_export_tab(engine_state, username_state):
175
  gr.Markdown("Download the model weights locally.")
176
  zip_btn = gr.Button("⬇️ Prepare Model ZIP", variant="secondary", interactive=False)
177
  download_file = gr.File(label="Download Archive", interactive=False)
 
178
 
179
  with gr.Column():
180
  gr.Markdown("#### Option B: Save to Hugging Face Hub")
 
41
  engine.config.MODEL_NAME = model_name.strip()
42
  yield from engine.run_training_pipeline(epochs, lr, test_size, shuffle)
43
 
44
+ @staticmethod
45
+ def run_evaluation(engine: FunctionGemmaEngine, test_size: float, shuffle: bool, model_name: str) -> Generator:
46
+ if not engine:
47
+ yield "⚠️ Engine not initialized."
48
+ return
49
+
50
+ engine.config.MODEL_NAME = model_name.strip()
51
+ yield from engine.run_evaluation(test_size, shuffle)
52
+
53
  @staticmethod
54
  def handle_reset(engine: FunctionGemmaEngine, model_name: str) -> str:
55
  engine.config.MODEL_NAME = model_name.strip()
 
160
  param_shuffle = gr.Checkbox(value=True, label="Shuffle Data", info="Randomize before split")
161
 
162
  with gr.Row():
163
+ run_eval_btn = gr.Button("πŸ§ͺ Run Evaluation", variant="secondary", scale=1)
164
  run_training_btn = gr.Button("πŸš€ Run Fine-Tuning", variant="primary", scale=1)
165
  stop_training_btn = gr.Button("πŸ›‘ Stop", variant="stop", visible=False, scale=1)
166
  clear_reload_btn = gr.Button("πŸ”„ Reload Model & Reset Data", variant="secondary", scale=1)
 
171
 
172
  return {
173
  "params": [param_epochs, param_lr, param_test_size, param_shuffle, param_model],
174
+ "buttons": [run_training_btn, stop_training_btn, clear_reload_btn, run_eval_btn],
175
  "outputs": [output_display, loss_plot],
176
  "model_input": param_model # specifically needed for initialization
177
  }
 
185
  gr.Markdown("Download the model weights locally.")
186
  zip_btn = gr.Button("⬇️ Prepare Model ZIP", variant="secondary", interactive=False)
187
  download_file = gr.File(label="Download Archive", interactive=False)
188
+ gr.Markdown("NOTE: Zipping usually takes 1~2 min.")
189
 
190
  with gr.Column():
191
  gr.Markdown("#### Option B: Save to Hugging Face Hub")