| import gradio as gr |
| import os |
| import json |
| import torch |
| import csv |
| import shutil |
| import time |
| import threading |
|
|
| from typing import Final, Optional, List, Any, Generator |
| from pathlib import Path |
| from dataclasses import dataclass |
|
|
| from huggingface_hub import login |
| from trl import SFTConfig, SFTTrainer |
| from transformers import ( |
| AutoTokenizer, |
| AutoModelForCausalLM, |
| TrainerCallback, |
| TrainingArguments, |
| TrainerControl, |
| TrainerState |
| ) |
| from datasets import Dataset, load_dataset |
|
|
| |
class AppConfig:
    """
    Central configuration class.

    Attributes are read directly off the class object (no instantiation
    required); the whole class is passed around as a config namespace.
    """
    # Root directory for all generated files (checkpoints, zip exports).
    ARTIFACTS_DIR: Final[Path] = Path("artifacts")
    # NOTE: side effect at class-definition (import) time — ensures the
    # artifacts directory exists before anything tries to write into it.
    ARTIFACTS_DIR.mkdir(parents=True, exist_ok=True)

    # Hugging Face token from the environment; login is skipped when unset.
    HF_TOKEN: Final[Optional[str]] = os.getenv('HF_TOKEN')
    # Local checkpoint path; load_model_and_tokenizer() falls back to a hub
    # model when this relative path does not exist on disk.
    MODEL_NAME: Final[str] = '../hf/270m'
    # Hub dataset used when the user has not imported a CSV.
    DEFAULT_DATASET: Final[str] = 'bebechien/SimpleToolCalling'
    # Where the fine-tuned model is saved (and later zipped from).
    OUTPUT_DIR: Final[Path] = ARTIFACTS_DIR.joinpath("functiongemma-modkit-demo")
|
|
|
|
| |
def search_knowledge_base(query: str) -> str:
    """
    Search internal company documents, policies and project data.

    Stub implementation used as a tool-calling target during fine-tuning;
    it always returns a fixed placeholder result.

    Args:
        query: query string
    """
    # Fixed typo in the placeholder result ("Interal" -> "Internal").
    return "Internal Result"
|
|
def search_google(query: str) -> str:
    """
    Search public information.

    Stub tool-calling target; the query is accepted but ignored and a
    fixed placeholder result is returned.

    Args:
        query: query string
    """
    return "Public Result"
|
|
# JSON-schema style declaration of the search_knowledge_base tool, in the
# OpenAI-compatible function-calling format consumed by the chat template.
search_knowledge_base_schema = {
    "type": "function",
    "function": {
        "name": "search_knowledge_base",
        "description": "Search internal company documents, policies and project data.",
        "parameters": {
            "type": "object",
            "properties": {
                "query": {"type": "string", "description": "query string"},
            },
            "required": ["query"],
        },
        "return": {"type": "string"},
    },
}
|
|
# JSON-schema style declaration of the search_google tool, mirroring the
# structure of search_knowledge_base_schema.
search_google_schema = {
    "type": "function",
    "function": {
        "name": "search_google",
        "description": "Search public information.",
        "parameters": {
            "type": "object",
            "properties": {
                "query": {"type": "string", "description": "query string"},
            },
            "required": ["query"],
        },
        "return": {"type": "string"},
    },
}
|
|
# All tool schemas advertised to the model during training and evaluation.
TOOLS = [search_knowledge_base_schema, search_google_schema]
# Developer/system message prepended to every training conversation.
DEFAULT_SYSTEM_MSG = "You are a model that can do function calling with the following functions"
|
|
| |
class AbortCallback(TrainerCallback):
    """
    Trainer callback that ends a training run early when a threading.Event
    is set (e.g. by the UI's Stop button).
    """

    def __init__(self, stop_event: threading.Event):
        """
        Args:
            stop_event: event the UI sets to request a graceful stop.
        """
        self.stop_event = stop_event

    def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """Checked after every optimizer step; flags the trainer to stop."""
        if not self.stop_event.is_set():
            return
        print("π Stop signal received. Stopping training...")
        control.should_training_stop = True
|
|
|
|
| |
def authenticate_hf(token: Optional[str]) -> None:
    """Logs into the Hugging Face Hub.

    Args:
        token: HF access token; any falsy value skips the login entirely.
    """
    if not token:
        print("Skipping Hugging Face login: HF_TOKEN not set.")
        return
    print("Logging into Hugging Face Hub...")
    login(token=token)
|
|
def load_model_and_tokenizer(model_name: str):
    """Load a causal-LM checkpoint and its tokenizer.

    Args:
        model_name: local path or hub id of the model to load.

    Returns:
        (model, tokenizer) tuple.

    Raises:
        Exception: re-raises whatever from_pretrained raises after logging.
    """
    print(f"Loading Transformer model: {model_name}")
    try:
        # A relative "../..." path that is missing on disk cannot be loaded;
        # substitute a known-good hub checkpoint instead.
        if model_name.startswith("..") and not os.path.exists(model_name):
            print(f"Warning: Local path {model_name} not found. Falling back to default hub model.")
            model_name = "google/gemma-2b-it"

        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(model_name)
        print("Model loaded successfully.")
        return model, tokenizer
    except Exception as err:
        print(f"Error loading Transformer model {model_name}: {err}")
        raise
|
|
def create_conversation_format(sample):
    """Formats a dataset row into the conversational format required for SFT.

    Args:
        sample: mapping with "user_content", "tool_name" and
            "tool_arguments" (a JSON-encoded string).

    Returns:
        Dict with "messages" (developer/user/assistant turns, the assistant
        turn carrying the expected tool call) and "tools" (shared schemas).
    """
    try:
        parsed_args = json.loads(sample["tool_arguments"])
    except (json.JSONDecodeError, TypeError):
        # Malformed or non-string arguments degrade to an empty call.
        parsed_args = {}

    expected_call = {
        "type": "function",
        "function": {"name": sample["tool_name"], "arguments": parsed_args},
    }
    return {
        "messages": [
            {"role": "developer", "content": DEFAULT_SYSTEM_MSG},
            {"role": "user", "content": sample["user_content"]},
            {"role": "assistant", "tool_calls": [expected_call]},
        ],
        "tools": TOOLS,
    }
|
|
|
|
| |
class FunctionGemmaTuner:
    """
    End-to-end fine-tuning workflow behind the Gradio UI.

    Responsibilities: model/tokenizer lifecycle, CSV data import, SFT
    training with live log streaming, before/after success-rate evaluation,
    and zipped export of the trained model.
    """

    def __init__(self, config: AppConfig = AppConfig):
        """
        Args:
            config: configuration namespace; attributes are read directly,
                so the AppConfig class itself works as the default.
        """
        self.config = config
        self.model = None
        self.tokenizer = None
        # Rows from a user-uploaded CSV: [user_content, tool_name, tool_arguments].
        self.imported_dataset = []

        # Set by the UI Stop button; observed by training and eval loops.
        self.stop_event = threading.Event()

        authenticate_hf(self.config.HF_TOKEN)

        print("--- Running Initial Data Load ---")
        try:
            self.refresh_data_and_model()
            print("--- Initial Load Complete ---")
        except Exception as e:
            print(f"Initial load failed (this is common if model path is invalid): {e}")

    def refresh_data_and_model(self):
        """Reloads the model and clears imported data.

        Returns:
            Human-readable status string for the UI textbox.
        """
        print("\n" + "=" * 50)
        print("RELOADING MODEL and RE-FETCHING DATA")

        self.imported_dataset = []

        try:
            self.model, self.tokenizer = load_model_and_tokenizer(self.config.MODEL_NAME)
            status_value = "Model and data reloaded. Ready."
        except Exception as e:
            # Stay alive in a degraded state; run_training() guards on
            # self.model being None.
            self.model = None
            self.tokenizer = None
            status_value = f"CRITICAL ERROR: Model failed to load. {e}"

        return status_value

    def import_additional_dataset(self, file_path: str) -> str:
        """Parses an uploaded CSV file into self.imported_dataset.

        Expected columns, in order: user_content, tool_name, tool_arguments.
        An optional header row is skipped when recognized.

        Args:
            file_path: path of the uploaded file (Gradio temp file).

        Returns:
            Status message for the UI.
        """
        if not file_path:
            return "Please upload a CSV file."

        new_dataset = []
        num_imported = 0

        try:
            with open(file_path, 'r', newline='', encoding='utf-8') as f:
                reader = csv.reader(f)

                # Header sniffing: skip the first row only when it looks
                # like a header; otherwise rewind so it is parsed as data.
                # "anchor" is recognized for backward compatibility with an
                # older template; "user_content" matches the actual schema
                # (previously a user_content header row was ingested as data).
                try:
                    header = next(reader)

                    if not (header and any(key in header[0].lower() for key in ("anchor", "user_content"))):
                        f.seek(0)
                except StopIteration:
                    return "Error: Uploaded file is empty."

                for row in reader:
                    # Keep only the first three columns; drop short rows.
                    if len(row) >= 3:
                        new_dataset.append([s.strip() for s in row[:3]])
                        num_imported += 1

            if num_imported == 0:
                # Message fixed to describe the columns actually consumed.
                return "No valid rows found. CSV format: [user_content, tool_name, tool_arguments]"

            self.imported_dataset = new_dataset
            return f"Successfully imported {num_imported} additional training samples."

        except Exception as e:
            return f"Import failed. Error: {e}"

    def stop_training(self):
        """Signal the training loop to stop."""
        print("Set stop event")
        self.stop_event.set()
        return "Stopping initiated... please wait for the current step to finish."

    def run_training(self, test_size: float = 0.5) -> Generator[str, None, None]:
        """
        Main training logic. Yields status strings to the UI.

        Pipeline: build dataset -> pre-training eval -> SFT fine-tune ->
        save -> post-training eval, checking self.stop_event between phases.

        Args:
            test_size: fraction of the dataset held out for evaluation.
        """
        if self.model is None:
            yield "Training failed: Model is not loaded."
            return

        self.stop_event.clear()
        yield "β³ Preparing Dataset..."

        # Prefer user-imported CSV rows; otherwise pull the default dataset.
        if not self.imported_dataset:
            print("No imported dataset, using default HF dataset")
            try:
                dataset = load_dataset(self.config.DEFAULT_DATASET, split="train")
            except Exception as e:
                yield f"Error loading default dataset: {e}"
                return
        else:
            dataset_as_dicts = [
                {"user_content": row[0], "tool_name": row[1], "tool_arguments": row[2]}
                for row in self.imported_dataset
            ]
            dataset = Dataset.from_list(dataset_as_dicts)

        # Convert each row into the chat/tool-call format SFTTrainer expects.
        dataset = dataset.map(create_conversation_format, batched=False)

        if len(dataset) > 1:
            dataset = dataset.train_test_split(test_size=test_size, shuffle=False)
        else:
            # With a single sample, reuse it for both splits so the
            # trainer's eval step still has data.
            dataset = {"train": dataset, "test": dataset}

        # --- Pre-training evaluation, streamed line by line ---
        output_buffer = "π Evaluating Pre-Training Success Rate...\n### Success Rate (Before Training):\n"
        yield output_buffer
        pre_training_report = ""
        gen = self.check_success_rate(dataset["test"])
        while not self.stop_event.is_set():
            try:
                pre_training_report += f"{next(gen)}\n"
                yield f"{output_buffer}{pre_training_report}"
            except StopIteration as e:
                # The generator's return value is the complete summary.
                pre_training_report = e.value
                break

        if self.stop_event.is_set():
            output_buffer += f"{pre_training_report}\n\nπ Manual Eval interrupted by user.\n"
            yield output_buffer
            return

        output_buffer += f"{pre_training_report}\n\n"
        output_buffer += "-" * 30 + "\nStarting Fine-tuning...\n"
        yield output_buffer

        # Match mixed-precision flags to the dtype the model was loaded with.
        torch_dtype = self.model.dtype

        args = SFTConfig(
            output_dir=str(self.config.OUTPUT_DIR),
            max_length=512,
            packing=False,
            num_train_epochs=5,
            per_device_train_batch_size=4,
            gradient_checkpointing=False,
            optim="adamw_torch_fused",
            logging_steps=1,
            save_strategy="no",
            eval_strategy="epoch",
            learning_rate=5e-5,
            fp16=torch_dtype == torch.float16,
            bf16=torch_dtype == torch.bfloat16,
            lr_scheduler_type="constant",
            push_to_hub=False,
            report_to="none",
            dataset_kwargs={
                "add_special_tokens": False,
                "append_concat_token": True,
            }
        )

        trainer = SFTTrainer(
            model=self.model,
            args=args,
            train_dataset=dataset['train'],
            eval_dataset=dataset['test'],
            processing_class=self.tokenizer,
            callbacks=[AbortCallback(self.stop_event)]
        )

        try:
            output_buffer += "π Training in progress... (Click Stop to interrupt)\n"
            yield output_buffer
            trainer.train()

            if self.stop_event.is_set():
                output_buffer += "\nπ Training interrupted by user.\n"
            else:
                output_buffer += "\nβ Training finished. Model weights updated in memory.\n"
            yield output_buffer

            # Save even after an interrupted run so partial progress persists.
            trainer.save_model()
            output_buffer += f"Model saved locally to: {self.config.OUTPUT_DIR}\n"
            yield output_buffer

        except Exception as e:
            output_buffer += f"\nβ Error during training: {e}\n"
            yield output_buffer
            return

        if self.stop_event.is_set():
            return

        # --- Post-training evaluation, streamed line by line ---
        output_buffer += "π Evaluating Post-Training Success Rate...\n"
        post_report = ""
        yield output_buffer
        gen = self.check_success_rate(dataset["test"])
        while not self.stop_event.is_set():
            try:
                post_report += f"{next(gen)}\n"
                yield f"{output_buffer}{post_report}"
            except StopIteration as e:
                post_report = e.value
                break

        if self.stop_event.is_set():
            output_buffer += f"{post_report}\n\nπ Manual Eval interrupted by user.\n"
            yield output_buffer
            return

        output_buffer += f"{post_report}\n\n"
        yield output_buffer

    def check_success_rate(self, test_dataset):
        """Runs inference on the test set to calculate accuracy.

        At most 5 samples are evaluated to keep the UI responsive; the
        summary now reports the count actually evaluated (previously it
        divided by the full test-set size, which was misleading).

        Yields:
            Progress lines as they are produced.

        Returns:
            The full summary string (delivered via StopIteration.value).
        """
        results = []
        success_count = 0
        evaluated = 0
        total = len(test_dataset)

        for idx, item in enumerate(test_dataset):
            # Cap manual eval at 5 samples; also honor a requested stop.
            if idx >= 5:
                break
            if self.stop_event.is_set():
                break
            evaluated += 1

            # Re-prompt with only the developer + user turns; the model
            # must reproduce the expected tool call on its own.
            messages = [item["messages"][0], item["messages"][1]]

            try:
                inputs = self.tokenizer.apply_chat_template(
                    messages,
                    tools=TOOLS,
                    add_generation_prompt=True,
                    return_dict=True,
                    return_tensors="pt"
                )

                out = self.model.generate(
                    **inputs.to(self.model.device),
                    pad_token_id=self.tokenizer.eos_token_id,
                    max_new_tokens=128
                )

                # Decode only the newly generated tokens (skip the prompt).
                output = self.tokenizer.decode(out[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True)

                results.append(f"{idx+1}. Prompt: {item['messages'][1]['content']}")
                yield results[-1]
                results.append(f" Output: {output[:100]}...")
                yield results[-1]

                # Success = expected tool name appears anywhere in the output.
                expected_tool = item['messages'][2]['tool_calls'][0]['function']['name']
                if expected_tool in output:
                    results.append(" -> β Correct Tool")
                    yield results[-1]
                    success_count += 1
                else:
                    results.append(f" -> β Wrong Tool (Expected: {expected_tool})")
                    yield results[-1]

            except Exception as e:
                results.append(f" -> Error: {e}")
                yield results[-1]

        summary = "\n".join(results)
        # Divide by the number actually evaluated, not the whole test set.
        summary += f"\n\nTotal Success : {success_count} / {evaluated} (test set size: {total})"
        return summary

    def download_model_zip(self) -> Optional[str]:
        """Zips the output directory for download.

        Returns:
            Path to the created zip archive, or None when no trained model
            exists yet or archiving fails.
        """
        if not self.config.OUTPUT_DIR.exists():
            return None

        # Timestamp keeps successive exports from overwriting each other.
        timestamp = int(time.time())
        try:
            base_name = self.config.ARTIFACTS_DIR.joinpath(f"functiongemma_finetuned_{timestamp}")
            archive_path = shutil.make_archive(
                base_name=str(base_name),
                format='zip',
                root_dir=str(self.config.OUTPUT_DIR),
            )
            return archive_path
        except Exception as e:
            # Best effort: report and let the UI hide the download widget.
            print(f"Zip failed: {e}")
            return None

    def build_interface(self) -> gr.Blocks:
        """Builds the Gradio Blocks UI and wires up all event handlers."""
        with gr.Blocks(title="FunctionGemma Modkit") as demo:
            gr.Markdown("# π€ FunctionGemma Modkit: Fine-Tuning")
            gr.Markdown("Fine-tune FunctionGemma to understand your custom functions.")

            with gr.Column():
                gr.Markdown("## 1. Training Controls")

                with gr.Row():
                    run_training_btn = gr.Button("π Run Fine-Tuning", variant="primary")
                    stop_training_btn = gr.Button("π Stop Training", variant="stop", visible=False)

                output_display = gr.Textbox(
                    lines=14,
                    label="Training Logs & Search Results",
                    value="Ready. Click 'Run' to begin.",
                    interactive=False
                )

                clear_reload_btn = gr.Button("π Reset Model & Data")

            gr.Markdown("--- \n ## 2. Data Management")
            import_file = gr.File(label="Upload Additional Dataset (.csv)", file_types=[".csv"], height=80)
            import_status = gr.Markdown("")

            gr.Markdown("--- \n ## 3. Export")
            with gr.Row():
                zip_btn = gr.Button("β¬οΈ Prepare Model ZIP")
                download_file = gr.File(label="Download ZIP", height=80, visible=True, interactive=False)

            # Run: swap the buttons into "training" state, stream the
            # training logs, then restore the idle state.
            run_training_btn.click(
                fn=lambda: (
                    gr.update(visible=False),
                    gr.update(interactive=False),
                    gr.update(visible=True)
                ),
                inputs=None,
                outputs=[run_training_btn, clear_reload_btn, stop_training_btn]
            ).then(
                fn=self.run_training,
                inputs=[],
                outputs=[output_display],
            ).then(
                fn=lambda: (
                    gr.update(visible=True),
                    gr.update(interactive=True),
                    gr.update(visible=False)
                ),
                inputs=None,
                outputs=[run_training_btn, clear_reload_btn, stop_training_btn]
            )

            # Stop: set the event, then restore the idle button state.
            stop_training_btn.click(
                fn=self.stop_training,
                inputs=None,
                outputs=None
            ).then(
                fn=lambda: (
                    gr.update(visible=True),
                    gr.update(interactive=True),
                    gr.update(visible=False)
                ),
                inputs=None,
                outputs=[run_training_btn, clear_reload_btn, stop_training_btn]
            )

            clear_reload_btn.click(
                fn=self.refresh_data_and_model,
                inputs=None,
                outputs=[output_display]
            )

            import_file.upload(
                fn=self.import_additional_dataset,
                inputs=[import_file],
                outputs=[import_status]
            )

            def handle_zip():
                # Show the download widget only when an archive was produced.
                path = self.download_model_zip()
                if path:
                    return gr.update(value=path, visible=True)
                return gr.update(value=None, visible=False)

            zip_btn.click(
                fn=handle_zip,
                inputs=None,
                outputs=[download_file]
            )

        return demo
|
|
if __name__ == "__main__":
    # Wire up the tuner (loads the model as a side effect) and serve the UI.
    tuner = FunctionGemmaTuner(AppConfig)
    interface = tuner.build_interface()
    print("Starting Gradio App...")
    interface.launch()
|
|