Spaces:
Running
Running
Fix issue with stop if queued
Browse files
engine.py
CHANGED
|
@@ -150,50 +150,53 @@ class FunctionGemmaEngine:
|
|
| 150 |
yield "β Error: No model loaded.\n"
|
| 151 |
return False
|
| 152 |
return True
|
| 153 |
-
|
| 154 |
# --- Evaluation Pipeline ---
|
| 155 |
|
| 156 |
def run_evaluation(self, test_size: float, shuffle_data: bool) -> Generator[str, None, None]:
|
| 157 |
self.stop_event.clear()
|
| 158 |
output_buffer = ""
|
| 159 |
|
| 160 |
-
# 1. Check Model
|
| 161 |
-
gen = self._ensure_model_consistency()
|
| 162 |
try:
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
|
|
|
|
|
|
|
|
|
| 173 |
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
yield output_buffer
|
| 177 |
-
|
| 178 |
-
if not dataset:
|
| 179 |
-
output_buffer += "β Dataset creation failed.\n"
|
| 180 |
yield output_buffer
|
| 181 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 182 |
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
|
|
|
|
|
|
| 197 |
|
| 198 |
# --- Training Pipeline ---
|
| 199 |
|
|
@@ -202,98 +205,102 @@ class FunctionGemmaEngine:
|
|
| 202 |
output_buffer = ""
|
| 203 |
last_plot = None
|
| 204 |
|
| 205 |
-
# 1. Check Model
|
| 206 |
-
gen = self._ensure_model_consistency()
|
| 207 |
try:
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 222 |
|
| 223 |
-
|
| 224 |
-
|
|
|
|
| 225 |
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 245 |
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 250 |
while not log_queue.empty():
|
| 251 |
payload = log_queue.get()
|
| 252 |
if isinstance(payload, tuple):
|
| 253 |
msg, log_data = payload
|
| 254 |
output_buffer += f"{msg}\n"
|
| 255 |
running_history.append(log_data)
|
| 256 |
-
|
| 257 |
-
last_plot = self._generate_loss_plot(running_history)
|
| 258 |
-
yield output_buffer, last_plot
|
| 259 |
-
except Exception:
|
| 260 |
-
yield output_buffer, last_plot
|
| 261 |
else:
|
| 262 |
output_buffer += f"{payload}\n"
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
if
|
| 266 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 267 |
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
train_thread.join()
|
| 271 |
-
|
| 272 |
-
self.has_model_tuned = True
|
| 273 |
-
|
| 274 |
-
while not log_queue.empty():
|
| 275 |
-
payload = log_queue.get()
|
| 276 |
-
if isinstance(payload, tuple):
|
| 277 |
-
msg, log_data = payload
|
| 278 |
-
output_buffer += f"{msg}\n"
|
| 279 |
-
running_history.append(log_data)
|
| 280 |
-
last_plot = self._generate_loss_plot(running_history)
|
| 281 |
-
else:
|
| 282 |
-
output_buffer += f"{payload}\n"
|
| 283 |
-
yield output_buffer, last_plot
|
| 284 |
-
|
| 285 |
-
if training_error:
|
| 286 |
-
output_buffer += f"β Error during training: {training_error}\n"
|
| 287 |
yield output_buffer, last_plot
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
output_buffer += "π Training manually stopped.\n"
|
| 292 |
-
yield output_buffer, last_plot
|
| 293 |
-
return
|
| 294 |
-
|
| 295 |
-
output_buffer += "β
Training finished.\n"
|
| 296 |
-
yield output_buffer, last_plot
|
| 297 |
|
| 298 |
def _prepare_dataset(self):
|
| 299 |
formatting_fn = partial(create_conversation_format, tools_list=self.current_tools)
|
|
@@ -433,4 +440,4 @@ class FunctionGemmaEngine:
|
|
| 433 |
|
| 434 |
return f"β
Success! Model uploaded to: {repo_url}"
|
| 435 |
except Exception as e:
|
| 436 |
-
return f"β Upload failed: {str(e)}"
|
|
|
|
| 150 |
yield "β Error: No model loaded.\n"
|
| 151 |
return False
|
| 152 |
return True
|
| 153 |
+
|
| 154 |
# --- Evaluation Pipeline ---
|
| 155 |
|
| 156 |
def run_evaluation(self, test_size: float, shuffle_data: bool) -> Generator[str, None, None]:
|
| 157 |
self.stop_event.clear()
|
| 158 |
output_buffer = ""
|
| 159 |
|
|
|
|
|
|
|
| 160 |
try:
|
| 161 |
+
# 1. Check Model
|
| 162 |
+
gen = self._ensure_model_consistency()
|
| 163 |
+
try:
|
| 164 |
+
while True:
|
| 165 |
+
msg = next(gen)
|
| 166 |
+
output_buffer += msg
|
| 167 |
+
yield output_buffer
|
| 168 |
+
except StopIteration as e:
|
| 169 |
+
if not e.value: return # Failed to load
|
| 170 |
+
|
| 171 |
+
# 2. Prepare Data
|
| 172 |
+
output_buffer += f"β³ Preparing Dataset for Eval (Test Split: {test_size})...\n"
|
| 173 |
+
yield output_buffer
|
| 174 |
|
| 175 |
+
dataset, log = self._prepare_dataset()
|
| 176 |
+
output_buffer += log
|
|
|
|
|
|
|
|
|
|
|
|
|
| 177 |
yield output_buffer
|
| 178 |
+
|
| 179 |
+
if not dataset:
|
| 180 |
+
output_buffer += "β Dataset creation failed.\n"
|
| 181 |
+
yield output_buffer
|
| 182 |
+
return
|
| 183 |
|
| 184 |
+
if len(dataset) > 1:
|
| 185 |
+
dataset = dataset.train_test_split(test_size=test_size, shuffle=shuffle_data)
|
| 186 |
+
else:
|
| 187 |
+
dataset = {"train": dataset, "test": dataset}
|
| 188 |
+
|
| 189 |
+
# 3. Run Inference
|
| 190 |
+
output_buffer += "\nπ Evaluating Model Success Rate on Test Split...\n"
|
| 191 |
+
yield output_buffer
|
| 192 |
|
| 193 |
+
for update in self._evaluate_model(dataset["test"]):
|
| 194 |
+
yield f"{output_buffer}{update}"
|
| 195 |
+
if self.stop_event.is_set():
|
| 196 |
+
yield f"{output_buffer}{update}\n\nπ Evaluation interrupted by user."
|
| 197 |
+
break
|
| 198 |
+
finally:
|
| 199 |
+
self.stop_event.set() # Ensure loop breaks if generator cancelled
|
| 200 |
|
| 201 |
# --- Training Pipeline ---
|
| 202 |
|
|
|
|
| 205 |
output_buffer = ""
|
| 206 |
last_plot = None
|
| 207 |
|
|
|
|
|
|
|
| 208 |
try:
|
| 209 |
+
# 1. Check Model
|
| 210 |
+
gen = self._ensure_model_consistency()
|
| 211 |
+
try:
|
| 212 |
+
while True:
|
| 213 |
+
msg = next(gen)
|
| 214 |
+
output_buffer += f"{msg}"
|
| 215 |
+
yield output_buffer, None
|
| 216 |
+
except StopIteration as e:
|
| 217 |
+
if not e.value: return
|
| 218 |
+
|
| 219 |
+
output_buffer += f"β³ Preparing Dataset (Test Split: {test_size}, Shuffle: {shuffle_data})...\n"
|
| 220 |
+
yield output_buffer, None
|
| 221 |
+
|
| 222 |
+
dataset, log = self._prepare_dataset()
|
| 223 |
+
if not dataset:
|
| 224 |
+
yield "Dataset creation failed.", None
|
| 225 |
+
return
|
| 226 |
+
|
| 227 |
+
output_buffer += log
|
| 228 |
+
yield output_buffer, None
|
| 229 |
+
|
| 230 |
+
if len(dataset) > 1:
|
| 231 |
+
dataset = dataset.train_test_split(test_size=test_size, shuffle=shuffle_data)
|
| 232 |
+
else:
|
| 233 |
+
dataset = {"train": dataset, "test": dataset}
|
| 234 |
|
| 235 |
+
# --- Training (Threaded) ---
|
| 236 |
+
output_buffer += f"\nπ Starting Fine-tuning (Epochs: {epochs}, LR: {learning_rate})...\n"
|
| 237 |
+
yield output_buffer, None
|
| 238 |
|
| 239 |
+
log_queue = queue.Queue()
|
| 240 |
+
training_error = None
|
| 241 |
+
running_history = []
|
| 242 |
+
|
| 243 |
+
def train_wrapper():
|
| 244 |
+
nonlocal training_error
|
| 245 |
+
try:
|
| 246 |
+
self._execute_trainer(dataset, log_queue, epochs, learning_rate)
|
| 247 |
+
except Exception as e:
|
| 248 |
+
training_error = e
|
| 249 |
+
|
| 250 |
+
train_thread = threading.Thread(target=train_wrapper)
|
| 251 |
+
train_thread.start()
|
| 252 |
+
|
| 253 |
+
while train_thread.is_alive():
|
| 254 |
+
while not log_queue.empty():
|
| 255 |
+
payload = log_queue.get()
|
| 256 |
+
if isinstance(payload, tuple):
|
| 257 |
+
msg, log_data = payload
|
| 258 |
+
output_buffer += f"{msg}\n"
|
| 259 |
+
running_history.append(log_data)
|
| 260 |
+
try:
|
| 261 |
+
last_plot = self._generate_loss_plot(running_history)
|
| 262 |
+
yield output_buffer, last_plot
|
| 263 |
+
except Exception:
|
| 264 |
+
yield output_buffer, last_plot
|
| 265 |
+
else:
|
| 266 |
+
output_buffer += f"{payload}\n"
|
| 267 |
+
yield output_buffer, last_plot
|
| 268 |
|
| 269 |
+
if self.stop_event.is_set():
|
| 270 |
+
yield f"{output_buffer}π Stop signal sent. Waiting for trainer to wrap up...\n", last_plot
|
| 271 |
+
|
| 272 |
+
time.sleep(0.1)
|
| 273 |
+
|
| 274 |
+
train_thread.join()
|
| 275 |
+
|
| 276 |
+
self.has_model_tuned = True
|
| 277 |
+
|
| 278 |
while not log_queue.empty():
|
| 279 |
payload = log_queue.get()
|
| 280 |
if isinstance(payload, tuple):
|
| 281 |
msg, log_data = payload
|
| 282 |
output_buffer += f"{msg}\n"
|
| 283 |
running_history.append(log_data)
|
| 284 |
+
last_plot = self._generate_loss_plot(running_history)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 285 |
else:
|
| 286 |
output_buffer += f"{payload}\n"
|
| 287 |
+
yield output_buffer, last_plot
|
| 288 |
+
|
| 289 |
+
if training_error:
|
| 290 |
+
output_buffer += f"β Error during training: {training_error}\n"
|
| 291 |
+
yield output_buffer, last_plot
|
| 292 |
+
return
|
| 293 |
+
|
| 294 |
+
if self.stop_event.is_set():
|
| 295 |
+
output_buffer += "π Training manually stopped.\n"
|
| 296 |
+
yield output_buffer, last_plot
|
| 297 |
+
return
|
| 298 |
|
| 299 |
+
output_buffer += "β
Training finished.\n"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 300 |
yield output_buffer, last_plot
|
| 301 |
+
|
| 302 |
+
finally:
|
| 303 |
+
self.stop_event.set() # Ensure background thread stops if generator cancelled
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 304 |
|
| 305 |
def _prepare_dataset(self):
|
| 306 |
formatting_fn = partial(create_conversation_format, tools_list=self.current_tools)
|
|
|
|
| 440 |
|
| 441 |
return f"β
Success! Model uploaded to: {repo_url}"
|
| 442 |
except Exception as e:
|
| 443 |
+
return f"β Upload failed: {str(e)}"
|