import threading
import torch
import time
import json
import queue
import matplotlib.pyplot as plt
from functools import partial
from typing import Generator, Optional, List, Dict

from datasets import Dataset, load_dataset
from trl import SFTConfig, SFTTrainer
from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl

from config import AppConfig
from tools import DEFAULT_TOOLS
from utils import (
    authenticate_hf,
    load_model_and_tokenizer,
    create_conversation_format,
    parse_csv_dataset,
    zip_directory
)


class AbortCallback(TrainerCallback):
    def __init__(self, stop_event: threading.Event):
        self.stop_event = stop_event

    def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        if self.stop_event.is_set():
            control.should_training_stop = True


class LogStreamingCallback(TrainerCallback):
    """
    NEW: Intercepts training logs and pushes them to a queue so the main thread
    can display them in the UI.
    """

    def __init__(self, log_queue: queue.Queue):
        self.log_queue = log_queue

    def _get_string(self, value):
        if isinstance(value, float):
            return f"{value:.4f}"
        return str(value)

    def on_log(self, args, state, control, logs=None, **kwargs):
        if not logs:
            return
        metrics_map = {
            "loss": "Loss",
            "eval_loss": "Eval Loss",
            "learning_rate": "LR",
            "epoch": "Epoch"
        }
        log_parts = [f"šŸ“ [Step {state.global_step}]"]
        for key, label in metrics_map.items():
            if key in logs:
                val = logs[key]
                # Format floats: use scientific notation for very small numbers (like LR)
                if isinstance(val, (float, int)):
                    val_str = f"{val:.4f}" if val > 1e-4 else f"{val:.2e}"
                else:
                    val_str = str(val)
                log_parts.append(f"{label}: {val_str}")
        self.log_queue.put(" | ".join(log_parts))


class FunctionGemmaEngine:
    def __init__(self, config: AppConfig):
        self.config = config
        self.model = None
        self.tokenizer = None
        self.imported_dataset = []
        self.stop_event = threading.Event()
        # NEW: State for tools
        self.current_tools = DEFAULT_TOOLS
        authenticate_hf(self.config.HF_TOKEN)
        try:
            self.refresh_data_and_model()
        except Exception as e:
            print(f"Initial load warning: {e}")

    # NEW: Methods to handle Tool Schema updates
    def get_tools_json(self) -> str:
        return json.dumps(self.current_tools, indent=2)

    def update_tools(self, json_str: str) -> str:
        try:
            new_tools = json.loads(json_str)
            if not isinstance(new_tools, list):
                return "Error: Schema must be a list of tool definitions."
            self.current_tools = new_tools
            return "āœ… Tool Schema Updated successfully."
        except json.JSONDecodeError as e:
            return f"āŒ JSON Error: {e}"
        except Exception as e:
            return f"āŒ Error: {e}"

    def refresh_data_and_model(self) -> str:
        self.imported_dataset = []
        try:
            self.model, self.tokenizer = load_model_and_tokenizer(self.config.MODEL_NAME)
            return "Model and data reloaded. Ready."
        except Exception as e:
            self.model = None
            self.tokenizer = None
            return f"CRITICAL ERROR: Model failed to load. {e}"

    def load_csv(self, file_path: str) -> str:
        try:
            new_data = parse_csv_dataset(file_path)
            if not new_data:
                return "Error: File empty or format invalid."
            self.imported_dataset = new_data
            return f"Successfully imported {len(new_data)} samples."
except Exception as e: return f"Import failed: {e}" def trigger_stop(self): self.stop_event.set() def run_training_pipeline(self, epochs: int, learning_rate: float, test_size: float, shuffle_data: bool) -> Generator[str, None, None]: if self.model is None: yield "Training failed: Model is not loaded.", None return self.stop_event.clear() output_buffer = f"ā³ Preparing Dataset (Test Split: {test_size}, Shuffle: {shuffle_data})...\n" yield output_buffer, None dataset, log = self._prepare_dataset() if not dataset: yield "Dataset creation failed.", None return output_buffer += log yield output_buffer, None if len(dataset) > 1: dataset = dataset.train_test_split(test_size=test_size, shuffle=shuffle_data) else: dataset = {"train": dataset, "test": dataset} # --- Phase 1: Pre-Training Eval --- output_buffer += "\nšŸ“Š Evaluating Pre-Training Success Rate...\n" yield output_buffer, None pre_training_report = "" for update in self._evaluate_model(dataset["test"]): pre_training_report = update if self.stop_event.is_set(): pre_training_report += "\n\nšŸ›‘ Manual Eval interrupted by user.\n" yield f"{output_buffer}{pre_training_report}", None break yield f"{output_buffer}{pre_training_report}", None if self.stop_event.is_set(): return output_buffer += pre_training_report # --- Phase 2: Training (Threaded) --- output_buffer += "\n\nšŸš€ Starting Fine-tuning (Epochs: {epochs}, LR: {learning_rate})...\n" yield output_buffer, None log_queue = queue.Queue() training_error = None training_history = [] # Function to run in the thread def train_wrapper(): nonlocal training_error, training_history try: training_history = self._execute_trainer(dataset, log_queue, epochs, learning_rate) except Exception as e: training_error = e # Start training thread train_thread = threading.Thread(target=train_wrapper) train_thread.start() # Monitor loop: Yields logs while training runs while train_thread.is_alive(): # Drain the queue while not log_queue.empty(): log_msg = log_queue.get() output_buffer += f"{log_msg}\n" yield output_buffer, None # Check for stop signal if self.stop_event.is_set(): yield f"{output_buffer}šŸ›‘ Stop signal sent. 
                # We don't break here, we wait for thread to finish cleanly
            time.sleep(0.1)  # Prevent CPU spinning

        train_thread.join()  # Ensure thread is completely done

        # Flush any remaining logs
        while not log_queue.empty():
            log_msg = log_queue.get()
            output_buffer += f"{log_msg}\n"
        yield output_buffer, None

        if training_error:
            output_buffer += f"āŒ Error during training: {training_error}\n"
            yield output_buffer, None
            return
        if self.stop_event.is_set():
            output_buffer += "šŸ›‘ Training manually stopped.\n"
            yield output_buffer, None
            return
        output_buffer += "āœ… Training finished.\n"
        yield output_buffer, None

        output_buffer += "\nšŸ“ˆ Generating Loss Plot...\n"
        yield output_buffer, None
        final_plot = None  # Fallback so Phase 3 can still yield if plotting fails
        try:
            final_plot = self._generate_loss_plot(training_history)
            yield output_buffer, final_plot
        except Exception as e:
            output_buffer += f"āš ļø Could not generate plot: {e}\n"
            yield output_buffer, None

        # --- Phase 3: Post-Training Eval ---
        output_buffer += "\nšŸ“Š Evaluating Post-Training Success Rate...\n"
        yield output_buffer, final_plot
        post_training_report = ""
        for update in self._evaluate_model(dataset["test"]):
            post_training_report = update
            if self.stop_event.is_set():
                post_training_report += "\n\nšŸ›‘ Manual Eval interrupted by user.\n"
                yield f"{output_buffer}{post_training_report}", final_plot
                break
            yield f"{output_buffer}{post_training_report}", final_plot

    def _prepare_dataset(self):
        # NEW: Use partial to inject self.current_tools into the formatting function
        formatting_fn = partial(create_conversation_format, tools_list=self.current_tools)
        if not self.imported_dataset:
            ds = load_dataset(self.config.DEFAULT_DATASET, split="train").map(formatting_fn)
            log = f" `-> using default dataset (size:{len(ds)})\n"
        else:
            dataset_as_dicts = [
                {"user_content": row[0], "tool_name": row[1], "tool_arguments": row[2]}
                for row in self.imported_dataset
            ]
            ds = Dataset.from_list(dataset_as_dicts).map(formatting_fn)
            log = f" `-> using custom dataset (size:{len(ds)})\n"
        return ds, log

    def _execute_trainer(self, dataset, log_queue: queue.Queue, epochs: int, learning_rate: float) -> List[Dict]:
        torch_dtype = self.model.dtype
        args = SFTConfig(
            output_dir=str(self.config.OUTPUT_DIR),
            max_length=512,
            packing=False,
            num_train_epochs=epochs,
            per_device_train_batch_size=4,
            logging_steps=1,
            save_strategy="no",
            eval_strategy="epoch",
            learning_rate=learning_rate,
            fp16=(torch_dtype == torch.float16),
            bf16=(torch_dtype == torch.bfloat16),
            report_to="none",
            dataset_kwargs={"add_special_tokens": False, "append_concat_token": True}
        )
        trainer = SFTTrainer(
            model=self.model,
            args=args,
            train_dataset=dataset['train'],
            eval_dataset=dataset['test'],
            processing_class=self.tokenizer,
            callbacks=[
                AbortCallback(self.stop_event),
                LogStreamingCallback(log_queue)
            ]
        )
        trainer.train()
        trainer.save_model()
        return trainer.state.log_history

    def _generate_loss_plot(self, history: list):
        if not history:
            return None
        # Extract Training Loss
        # log_history format: [{'loss': 0.5, 'step': 1}, {'eval_loss': 0.4, 'step': 1}, ...]
        train_steps = [x['step'] for x in history if 'loss' in x]
        train_loss = [x['loss'] for x in history if 'loss' in x]
        # Extract Validation Loss
        eval_steps = [x['step'] for x in history if 'eval_loss' in x]
        eval_loss = [x['eval_loss'] for x in history if 'eval_loss' in x]

        fig, ax = plt.subplots(figsize=(10, 5))
        if train_steps:
            ax.plot(train_steps, train_loss, label='Training Loss', linestyle='-', marker=None)
        if eval_steps:
            ax.plot(eval_steps, eval_loss, label='Validation Loss', linestyle='--', marker='o')
        ax.set_xlabel("Steps")
        ax.set_ylabel("Loss")
        ax.set_title("Training & Validation Loss")
        ax.legend()
        ax.grid(True, linestyle=':', alpha=0.6)
        plt.tight_layout()
        return fig

    def _evaluate_model(self, test_dataset) -> Generator[str, None, None]:
        results = []
        success_count = 0
        for idx, item in enumerate(test_dataset):
            messages = item["messages"][:2]
            try:
                # NEW: Pass self.current_tools to the template
                inputs = self.tokenizer.apply_chat_template(
                    messages,
                    tools=self.current_tools,
                    add_generation_prompt=True,
                    return_dict=True,
                    return_tensors="pt"
                )
                device = self.model.device
                inputs = {k: v.to(device) for k, v in inputs.items()}
                out = self.model.generate(
                    **inputs,
                    pad_token_id=self.tokenizer.eos_token_id,
                    max_new_tokens=128
                )
                output = self.tokenizer.decode(out[0][len(inputs["input_ids"][0]):], skip_special_tokens=True)
                log_entry = f"{idx+1}. Prompt: {messages[1]['content']}\n Output: {output[:100]}..."
                # Check tool correctness
                expected_tool = item['messages'][2]['tool_calls'][0]['function']['name']
                if expected_tool in output:
                    log_entry += "\n -> āœ… Correct Tool"
                    success_count += 1
                else:
                    log_entry += f"\n -> āŒ Wrong Tool (Expected: {expected_tool})"
                results.append(log_entry)
                yield "\n".join(results) + f"\n\nRunning Success Rate: {success_count}/{idx+1}"
            except Exception as e:
                yield f"Error during inference: {e}"

    def get_zip_path(self) -> Optional[str]:
        if not self.config.OUTPUT_DIR.exists():
            return None
        timestamp = int(time.time())
        base_name = str(self.config.ARTIFACTS_DIR.joinpath(f"functiongemma_finetuned_{timestamp}"))
        return zip_directory(str(self.config.OUTPUT_DIR), base_name)
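
# Illustrative usage sketch, not part of the UI flow. Assumes AppConfig() is
# constructible with defaults and exposes the fields referenced above (HF_TOKEN,
# MODEL_NAME, DEFAULT_DATASET, OUTPUT_DIR, ARTIFACTS_DIR); the hyperparameters
# below are placeholders.
if __name__ == "__main__":
    engine = FunctionGemmaEngine(AppConfig())
    print(engine.get_tools_json())
    # The pipeline is a generator yielding (status_text, plot) tuples for streaming UIs.
    for status, _plot in engine.run_training_pipeline(
        epochs=1, learning_rate=2e-4, test_size=0.2, shuffle_data=True
    ):
        print(status)
    print("Artifacts:", engine.get_zip_path())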