bebechien committed on
Commit
fc2675b
·
verified ·
1 Parent(s): 5fd743d

Fix issue with stop if queued

Browse files
Files changed (1) hide show
  1. engine.py +119 -112
engine.py CHANGED
@@ -150,50 +150,53 @@ class FunctionGemmaEngine:
150
  yield "❌ Error: No model loaded.\n"
151
  return False
152
  return True
153
-
154
  # --- Evaluation Pipeline ---
155
 
156
  def run_evaluation(self, test_size: float, shuffle_data: bool) -> Generator[str, None, None]:
157
  self.stop_event.clear()
158
  output_buffer = ""
159
 
160
- # 1. Check Model
161
- gen = self._ensure_model_consistency()
162
  try:
163
- while True:
164
- msg = next(gen)
165
- output_buffer += msg
166
- yield output_buffer
167
- except StopIteration as e:
168
- if not e.value: return # Failed to load
169
-
170
- # 2. Prepare Data
171
- output_buffer += f"⏳ Preparing Dataset for Eval (Test Split: {test_size})...\n"
172
- yield output_buffer
 
 
 
173
 
174
- dataset, log = self._prepare_dataset()
175
- output_buffer += log
176
- yield output_buffer
177
-
178
- if not dataset:
179
- output_buffer += "❌ Dataset creation failed.\n"
180
  yield output_buffer
181
- return
 
 
 
 
182
 
183
- if len(dataset) > 1:
184
- dataset = dataset.train_test_split(test_size=test_size, shuffle=shuffle_data)
185
- else:
186
- dataset = {"train": dataset, "test": dataset}
187
-
188
- # 3. Run Inference
189
- output_buffer += "\nπŸ“Š Evaluating Model Success Rate on Test Split...\n"
190
- yield output_buffer
191
 
192
- for update in self._evaluate_model(dataset["test"]):
193
- yield f"{output_buffer}{update}"
194
- if self.stop_event.is_set():
195
- yield f"{output_buffer}{update}\n\nπŸ›‘ Evaluation interrupted by user."
196
- break
 
 
197
 
198
  # --- Training Pipeline ---
199
 
@@ -202,98 +205,102 @@ class FunctionGemmaEngine:
202
  output_buffer = ""
203
  last_plot = None
204
 
205
- # 1. Check Model
206
- gen = self._ensure_model_consistency()
207
  try:
208
- while True:
209
- msg = next(gen)
210
- output_buffer += f"{msg}"
211
- yield output_buffer, None
212
- except StopIteration as e:
213
- if not e.value: return
214
-
215
- output_buffer += f"⏳ Preparing Dataset (Test Split: {test_size}, Shuffle: {shuffle_data})...\n"
216
- yield output_buffer, None
217
-
218
- dataset, log = self._prepare_dataset()
219
- if not dataset:
220
- yield "Dataset creation failed.", None
221
- return
 
 
 
 
 
 
 
 
 
 
 
222
 
223
- output_buffer += log
224
- yield output_buffer, None
 
225
 
226
- if len(dataset) > 1:
227
- dataset = dataset.train_test_split(test_size=test_size, shuffle=shuffle_data)
228
- else:
229
- dataset = {"train": dataset, "test": dataset}
230
-
231
- # --- Training (Threaded) ---
232
- output_buffer += f"\nπŸš€ Starting Fine-tuning (Epochs: {epochs}, LR: {learning_rate})...\n"
233
- yield output_buffer, None
234
-
235
- log_queue = queue.Queue()
236
- training_error = None
237
- running_history = []
238
-
239
- def train_wrapper():
240
- nonlocal training_error
241
- try:
242
- self._execute_trainer(dataset, log_queue, epochs, learning_rate)
243
- except Exception as e:
244
- training_error = e
 
 
 
 
 
 
 
 
 
 
245
 
246
- train_thread = threading.Thread(target=train_wrapper)
247
- train_thread.start()
248
-
249
- while train_thread.is_alive():
 
 
 
 
 
250
  while not log_queue.empty():
251
  payload = log_queue.get()
252
  if isinstance(payload, tuple):
253
  msg, log_data = payload
254
  output_buffer += f"{msg}\n"
255
  running_history.append(log_data)
256
- try:
257
- last_plot = self._generate_loss_plot(running_history)
258
- yield output_buffer, last_plot
259
- except Exception:
260
- yield output_buffer, last_plot
261
  else:
262
  output_buffer += f"{payload}\n"
263
- yield output_buffer, last_plot
264
-
265
- if self.stop_event.is_set():
266
- yield f"{output_buffer}πŸ›‘ Stop signal sent. Waiting for trainer to wrap up...\n", last_plot
 
 
 
 
 
 
 
267
 
268
- time.sleep(0.1)
269
-
270
- train_thread.join()
271
-
272
- self.has_model_tuned = True
273
-
274
- while not log_queue.empty():
275
- payload = log_queue.get()
276
- if isinstance(payload, tuple):
277
- msg, log_data = payload
278
- output_buffer += f"{msg}\n"
279
- running_history.append(log_data)
280
- last_plot = self._generate_loss_plot(running_history)
281
- else:
282
- output_buffer += f"{payload}\n"
283
- yield output_buffer, last_plot
284
-
285
- if training_error:
286
- output_buffer += f"❌ Error during training: {training_error}\n"
287
  yield output_buffer, last_plot
288
- return
289
-
290
- if self.stop_event.is_set():
291
- output_buffer += "πŸ›‘ Training manually stopped.\n"
292
- yield output_buffer, last_plot
293
- return
294
-
295
- output_buffer += "βœ… Training finished.\n"
296
- yield output_buffer, last_plot
297
 
298
  def _prepare_dataset(self):
299
  formatting_fn = partial(create_conversation_format, tools_list=self.current_tools)
@@ -433,4 +440,4 @@ class FunctionGemmaEngine:
433
 
434
  return f"βœ… Success! Model uploaded to: {repo_url}"
435
  except Exception as e:
436
- return f"❌ Upload failed: {str(e)}"
 
150
  yield "❌ Error: No model loaded.\n"
151
  return False
152
  return True
153
+
154
  # --- Evaluation Pipeline ---
155
 
156
  def run_evaluation(self, test_size: float, shuffle_data: bool) -> Generator[str, None, None]:
157
  self.stop_event.clear()
158
  output_buffer = ""
159
 
 
 
160
  try:
161
+ # 1. Check Model
162
+ gen = self._ensure_model_consistency()
163
+ try:
164
+ while True:
165
+ msg = next(gen)
166
+ output_buffer += msg
167
+ yield output_buffer
168
+ except StopIteration as e:
169
+ if not e.value: return # Failed to load
170
+
171
+ # 2. Prepare Data
172
+ output_buffer += f"⏳ Preparing Dataset for Eval (Test Split: {test_size})...\n"
173
+ yield output_buffer
174
 
175
+ dataset, log = self._prepare_dataset()
176
+ output_buffer += log
 
 
 
 
177
  yield output_buffer
178
+
179
+ if not dataset:
180
+ output_buffer += "❌ Dataset creation failed.\n"
181
+ yield output_buffer
182
+ return
183
 
184
+ if len(dataset) > 1:
185
+ dataset = dataset.train_test_split(test_size=test_size, shuffle=shuffle_data)
186
+ else:
187
+ dataset = {"train": dataset, "test": dataset}
188
+
189
+ # 3. Run Inference
190
+ output_buffer += "\nπŸ“Š Evaluating Model Success Rate on Test Split...\n"
191
+ yield output_buffer
192
 
193
+ for update in self._evaluate_model(dataset["test"]):
194
+ yield f"{output_buffer}{update}"
195
+ if self.stop_event.is_set():
196
+ yield f"{output_buffer}{update}\n\nπŸ›‘ Evaluation interrupted by user."
197
+ break
198
+ finally:
199
+ self.stop_event.set() # Ensure loop breaks if generator cancelled
200
 
201
  # --- Training Pipeline ---
202
 
 
205
  output_buffer = ""
206
  last_plot = None
207
 
 
 
208
  try:
209
+ # 1. Check Model
210
+ gen = self._ensure_model_consistency()
211
+ try:
212
+ while True:
213
+ msg = next(gen)
214
+ output_buffer += f"{msg}"
215
+ yield output_buffer, None
216
+ except StopIteration as e:
217
+ if not e.value: return
218
+
219
+ output_buffer += f"⏳ Preparing Dataset (Test Split: {test_size}, Shuffle: {shuffle_data})...\n"
220
+ yield output_buffer, None
221
+
222
+ dataset, log = self._prepare_dataset()
223
+ if not dataset:
224
+ yield "Dataset creation failed.", None
225
+ return
226
+
227
+ output_buffer += log
228
+ yield output_buffer, None
229
+
230
+ if len(dataset) > 1:
231
+ dataset = dataset.train_test_split(test_size=test_size, shuffle=shuffle_data)
232
+ else:
233
+ dataset = {"train": dataset, "test": dataset}
234
 
235
+ # --- Training (Threaded) ---
236
+ output_buffer += f"\nπŸš€ Starting Fine-tuning (Epochs: {epochs}, LR: {learning_rate})...\n"
237
+ yield output_buffer, None
238
 
239
+ log_queue = queue.Queue()
240
+ training_error = None
241
+ running_history = []
242
+
243
+ def train_wrapper():
244
+ nonlocal training_error
245
+ try:
246
+ self._execute_trainer(dataset, log_queue, epochs, learning_rate)
247
+ except Exception as e:
248
+ training_error = e
249
+
250
+ train_thread = threading.Thread(target=train_wrapper)
251
+ train_thread.start()
252
+
253
+ while train_thread.is_alive():
254
+ while not log_queue.empty():
255
+ payload = log_queue.get()
256
+ if isinstance(payload, tuple):
257
+ msg, log_data = payload
258
+ output_buffer += f"{msg}\n"
259
+ running_history.append(log_data)
260
+ try:
261
+ last_plot = self._generate_loss_plot(running_history)
262
+ yield output_buffer, last_plot
263
+ except Exception:
264
+ yield output_buffer, last_plot
265
+ else:
266
+ output_buffer += f"{payload}\n"
267
+ yield output_buffer, last_plot
268
 
269
+ if self.stop_event.is_set():
270
+ yield f"{output_buffer}πŸ›‘ Stop signal sent. Waiting for trainer to wrap up...\n", last_plot
271
+
272
+ time.sleep(0.1)
273
+
274
+ train_thread.join()
275
+
276
+ self.has_model_tuned = True
277
+
278
  while not log_queue.empty():
279
  payload = log_queue.get()
280
  if isinstance(payload, tuple):
281
  msg, log_data = payload
282
  output_buffer += f"{msg}\n"
283
  running_history.append(log_data)
284
+ last_plot = self._generate_loss_plot(running_history)
 
 
 
 
285
  else:
286
  output_buffer += f"{payload}\n"
287
+ yield output_buffer, last_plot
288
+
289
+ if training_error:
290
+ output_buffer += f"❌ Error during training: {training_error}\n"
291
+ yield output_buffer, last_plot
292
+ return
293
+
294
+ if self.stop_event.is_set():
295
+ output_buffer += "πŸ›‘ Training manually stopped.\n"
296
+ yield output_buffer, last_plot
297
+ return
298
 
299
+ output_buffer += "βœ… Training finished.\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
300
  yield output_buffer, last_plot
301
+
302
+ finally:
303
+ self.stop_event.set() # Ensure background thread stops if generator cancelled
 
 
 
 
 
 
304
 
305
  def _prepare_dataset(self):
306
  formatting_fn = partial(create_conversation_format, tools_list=self.current_tools)
 
440
 
441
  return f"βœ… Success! Model uploaded to: {repo_url}"
442
  except Exception as e:
443
+ return f"❌ Upload failed: {str(e)}"