bebechien committed on
Commit
c055e6e
·
verified ·
1 Parent(s): 6be610b

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +455 -204
app.py CHANGED
@@ -2,36 +2,40 @@ import gradio as gr
2
  import os
3
  import json
4
  import torch
 
 
 
 
5
 
6
- from typing import Final, Optional, List
7
  from pathlib import Path
 
 
8
  from huggingface_hub import login
9
  from trl import SFTConfig, SFTTrainer
10
- from transformers import AutoTokenizer, AutoModelForCausalLM
 
 
 
 
 
 
 
11
  from datasets import Dataset, load_dataset
12
- from transformers.utils import get_json_schema
13
-
14
 
15
- ARTIFACTS_DIR: Final[Path] = Path("artifacts")
16
-
17
- def authenticate_hf(token: Optional[str]) -> None:
18
- """Logs into the Hugging Face Hub."""
19
- if token:
20
- print("Logging into Hugging Face Hub...")
21
- login(token=token)
22
- else:
23
- print("Skipping Hugging Face login: HF_TOKEN not set.")
 
 
 
24
 
25
- def load_model(model_name: str):
26
- print(f"Loading Transformer model: {model_name}")
27
- try:
28
- tokenizer = AutoTokenizer.from_pretrained(model_name)
29
- model = AutoModelForCausalLM.from_pretrained(model_name)
30
- print("Model loaded successfully.")
31
- return (model, tokenizer)
32
- except Exception as e:
33
- print(f"Error loading Transformer model {model_name}: {e}")
34
- raise
35
 
36
  # --- Tool Definitions ---
37
  def search_knowledge_base(query: str) -> str:
@@ -52,178 +56,213 @@ def search_google(query: str) -> str:
52
  """
53
  return "Public Result"
54
 
55
-
56
- TOOLS = [get_json_schema(search_knowledge_base), get_json_schema(search_google)]
57
-
58
- DEFAULT_SYSTEM_MSG = "You are a model that can do function calling with the following functions"
59
-
60
-
61
- def create_conversation(sample):
62
- return {
63
- "messages": [
64
- {"role": "developer", "content": DEFAULT_SYSTEM_MSG},
65
- {"role": "user", "content": sample["user_content"]},
66
- {"role": "assistant", "tool_calls": [{"type": "function", "function": {"name": sample["tool_name"], "arguments": json.loads(sample["tool_arguments"])}}]},
67
- ],
68
- "tools": TOOLS
 
 
 
 
 
 
69
  }
70
-
71
-
72
- def train_with_dataset(
73
- model: AutoModelForCausalLM,
74
- tokenizer: AutoTokenizer,
75
- dataset: Dataset,
76
- output_dir: Path,
77
- learning_rate: float = 5e-5
78
- ) -> None:
79
-
80
- torch_dtype = model.dtype
81
-
82
- args = SFTConfig(
83
- output_dir=output_dir, # directory to save and repository id
84
- max_length=512, # max sequence length for model and packing of the dataset
85
- packing=False, # Groups multiple samples in the dataset into a single sequence
86
- num_train_epochs=5, # number of training epochs
87
- per_device_train_batch_size=4, # batch size per device during training
88
- gradient_checkpointing=False, # Caching is incompatible with gradient checkpointing
89
- optim="adamw_torch_fused", # use fused adamw optimizer
90
- logging_steps=1, # log every step
91
- #save_strategy="epoch", # save checkpoint every epoch
92
- eval_strategy="epoch", # evaluate checkpoint every epoch
93
- learning_rate=learning_rate, # learning rate
94
- fp16=True if torch_dtype == torch.float16 else False, # use float16 precision
95
- bf16=True if torch_dtype == torch.bfloat16 else False, # use bfloat16 precision
96
- lr_scheduler_type="constant", # use constant learning rate scheduler
97
- push_to_hub=False, # push model to hub
98
- report_to="none", # report metrics to tensorboard
99
- dataset_kwargs={
100
- "add_special_tokens": False, # Template with special tokens
101
- "append_concat_token": True, # Add EOS token as separator token between examples
102
  }
103
- )
 
 
 
 
 
 
 
 
 
104
 
105
- # Create Trainer object
106
- trainer = SFTTrainer(
107
- model=model,
108
- args=args,
109
- train_dataset=dataset['train'],
110
- eval_dataset=dataset['test'],
111
- processing_class=tokenizer,
112
- )
113
 
114
- trainer.train()
 
 
 
 
 
 
115
 
116
- print("Training finished. Model weights are updated in memory.")
 
 
 
117
 
118
- # Save the final fine-tuned model
119
- trainer.save_model()
120
 
121
- print(f"Model saved locally to: {output_dir}")
 
 
 
 
 
 
 
122
 
123
- class AppConfig:
124
- """
125
- Central configuration class for the Fine-Tuner application.
126
- """
127
- ARTIFACTS_DIR: Final[Path] = ARTIFACTS_DIR
128
- HF_TOKEN: Final[str | None] = os.getenv('HF_TOKEN')
129
- MODEL_NAME: Final[str] = '../hf/270m'
130
- DEFAULT_DATASET: Final[str] = 'bebechien/SimpleToolCalling'
131
- OUTPUT_DIR: Final[Path] = ARTIFACTS_DIR.joinpath("functiongemma-270m-it-modkit-demo")
132
 
 
 
 
 
 
 
 
133
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  class FunctionGemmaTuner:
135
  def __init__(self, config: AppConfig = AppConfig):
136
  self.config = config
137
-
138
- os.makedirs(self.config.ARTIFACTS_DIR, exist_ok=True)
139
- print(f"Created artifact directory: {self.config.ARTIFACTS_DIR}")
 
 
 
140
 
141
  authenticate_hf(self.config.HF_TOKEN)
142
-
143
- self._initial_load()
144
-
145
- def _initial_load(self):
146
- """Helper to run the refresh function once at startup."""
147
  print("--- Running Initial Data Load ---")
148
- self.refresh_data_and_model()
149
- print("--- Initial Load Complete ---")
 
 
 
150
 
151
  def refresh_data_and_model(self):
 
152
  print("\n" + "=" * 50)
153
  print("RELOADING MODEL and RE-FETCHING DATA")
154
 
155
- # Reset dataset state
156
  self.imported_dataset = []
157
 
158
- # 1. Reload the base model
159
  try:
160
- self.model, self.tokenizer = load_model(self.config.MODEL_NAME)
 
161
  except Exception as e:
162
- gr.Error(f"Model load failed: {e}")
163
  self.model = None
164
  self.tokenizer = None
165
- return gr.update(value=f"CRITICAL ERROR: Model failed to load. {e}")
166
-
167
- status_value: str = f"Model and data reloaded. Click 'Run Fine-Tuning' to begin."
168
-
169
- # Return Gradio updates for CheckboxGroup and Textbox
170
- return gr.update(value=status_value)
171
 
172
- # --- Import Dataset/Export ---
173
  def import_additional_dataset(self, file_path: str) -> str:
 
174
  if not file_path:
175
  return "Please upload a CSV file."
176
- new_dataset, num_imported = [], 0
 
 
 
177
  try:
 
178
  with open(file_path, 'r', newline='', encoding='utf-8') as f:
179
  reader = csv.reader(f)
 
 
180
  try:
181
  header = next(reader)
182
- if not (header and header[0].lower().strip() == 'anchor'):
 
183
  f.seek(0)
184
  except StopIteration:
185
  return "Error: Uploaded file is empty."
186
 
187
  for row in reader:
188
- if len(row) == 3:
189
- new_dataset.append([s.strip() for s in row])
 
190
  num_imported += 1
 
191
  if num_imported == 0:
192
- raise ValueError("No valid [Anchor, Positive, Negative] rows found in the CSV.")
 
193
  self.imported_dataset = new_dataset
194
- return f"Successfully imported {num_imported} additional training triplets."
 
195
  except Exception as e:
196
- gr.Error(f"Import failed. Ensure the CSV format is: [Anchor, Positive, Negative]. Error: {e}")
197
- return "Import failed. Check console for details."
198
 
199
- def download_model(self) -> Optional[str]:
200
- if not os.path.exists(self.config.OUTPUT_DIR):
201
- gr.Warning(f"The model directory '{self.config.OUTPUT_DIR}' does not exist. Please run training first.")
202
- return None
203
- timestamp = int(time.time())
204
- try:
205
- base_name = os.path.join(self.config.ARTIFACTS_DIR, f"embedding_gemma_finetuned_{timestamp}")
206
- archive_path = shutil.make_archive(
207
- base_name=base_name,
208
- format='zip',
209
- root_dir=self.config.OUTPUT_DIR,
210
- )
211
- gr.Info(f"Model files successfully zipped to: {archive_path}")
212
- return archive_path
213
- except Exception as e:
214
- gr.Error(f"Failed to create the model ZIP file. Error: {e}")
215
- return None
216
 
217
- def training(self, test_size: float = 0.5) -> str:
218
  """
219
- Generates a training dataset from user selection and runs the fine-tuning process.
220
  """
 
221
  if self.model is None:
222
- raise gr.Error("Training failed: Model is not loaded.")
 
223
 
 
 
 
 
224
  if not self.imported_dataset:
225
- print("No imported dataset, use the default")
226
- dataset = load_dataset(self.config.DEFAULT_DATASET, split="train")
 
 
 
 
227
  else:
228
  dataset_as_dicts = [{
229
  "user_content": row[0], "tool_name": row[1], "tool_arguments": row[2]}
@@ -231,81 +270,293 @@ class FunctionGemmaTuner:
231
  ]
232
  dataset = Dataset.from_list(dataset_as_dicts)
233
 
234
- dataset = dataset.map(create_conversation, batched=False)
235
- dataset = dataset.train_test_split(test_size=test_size, shuffle=False)
236
- print(dataset)
237
- print("--- dataset input ---")
238
- print(json.dumps(dataset["train"][0], indent=2))
239
- debug_msg = self.tokenizer.apply_chat_template(dataset["train"][0]["messages"], tools=dataset["train"][0]["tools"], add_generation_prompt=False, tokenize=False)
240
- print("--- Formatted prompt ---")
241
- print(debug_msg)
242
-
243
- result = "### Success Rate (Before Training):\n" + f"{self.check_success_rate(dataset["test"])}\n\n"
244
- print("-" * 50 + "\nStarting Fine-tuning...")
245
- train_with_dataset(model=self.model, tokenizer=self.tokenizer, dataset=dataset, output_dir=self.config.OUTPUT_DIR)
246
- print("Fine-tuning Complete.\n" + "-" * 50)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
247
 
248
- result += "### Success Rate (After Training):\n" + f"{self.check_success_rate(dataset["test"])}\n\n"
249
- return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
250
 
251
  def check_success_rate(self, test_dataset):
252
- result = []
 
253
  success_count = 0
 
 
254
  for idx, item in enumerate(test_dataset):
255
- messages = [
256
- item["messages"][0],
257
- item["messages"][1],
258
- ]
259
-
260
- inputs = self.tokenizer.apply_chat_template(messages, tools=TOOLS, add_generation_prompt=True, return_dict=True, return_tensors="pt")
261
-
262
- out = self.model.generate(**inputs.to(self.model.device), pad_token_id=self.tokenizer.eos_token_id, max_new_tokens=128)
263
- output = self.tokenizer.decode(out[0][len(inputs["input_ids"][0]) :], skip_special_tokens=False)
264
-
265
- result.append(f"{idx+1} Prompt: {item['messages'][1]['content']}")
266
- result.append(f" Output: {output}")
267
- if item['messages'][2]['tool_calls'][0]['function']['name'] in output:
268
- result.append(" `-> ✅ correct!")
269
- success_count += 1
270
- else:
271
- result.append(" `-> ❌ wrong tool")
272
-
273
- result.append(f"Success : {success_count} / {len(test_dataset)}")
274
-
275
- return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
276
 
 
277
  def build_interface(self) -> gr.Blocks:
278
  with gr.Blocks(title="FunctionGemma Modkit") as demo:
279
  gr.Markdown("# 🤖 FunctionGemma Modkit: Fine-Tuning")
280
- gr.Markdown("This project provides a set of tools to fine-tune FunctionGemma to understand your personal needs.<br>See [README](https://huggingface.co/spaces/google/functiongemma-modkit/blob/main/README.md) for more details.")
281
- self._build_training_interface()
282
- return demo
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
283
 
284
- def _build_training_interface(self):
285
- with gr.Column():
286
- gr.Markdown("## Fine-Tuning")
287
- with gr.Row():
288
- output = gr.Textbox(lines=14, label="Training and Search Results", value="Click 'Run Fine-Tuning' to begin.")
289
- with gr.Row():
290
- clear_reload_btn = gr.Button("Clear & Reload Model/Data")
291
- run_training_btn = gr.Button("🚀 Run Fine-Tuning", variant="primary")
292
- gr.Markdown("--- \n ## Dataset & Model Management")
293
- import_file = gr.File(label="Upload Additional Dataset (.csv)", file_types=[".csv"], height=50)
294
- with gr.Row():
295
- download_model_btn = gr.Button("⬇️ Download Fine-Tuned Model")
296
- download_status = gr.Markdown("Ready.")
297
- with gr.Row():
298
- model_output = gr.File(label="Download Model ZIP", height=50, visible=False, interactive=False)
299
-
300
- run_training_btn.click(fn=self.training, outputs=output)
301
- clear_reload_btn.click(fn=self.refresh_data_and_model, inputs=None, outputs=[output], queue=False)
302
- import_file.change(fn=self.import_additional_dataset, inputs=[import_file], outputs=download_status)
303
- download_model_btn.click(lambda: [gr.update(value=None, visible=False), "Zipping..."], None, [model_output, download_status], queue=False).then(self.download_model, None, model_output).then(lambda p: [gr.update(visible=p is not None, value=p), "ZIP ready." if p else "Zipping failed."], [model_output], [model_output, download_status])
 
 
 
 
 
 
304
 
 
305
 
306
  if __name__ == "__main__":
307
  app = FunctionGemmaTuner(AppConfig)
308
  demo = app.build_interface()
309
  print("Starting Gradio App...")
310
  demo.launch()
311
-
 
2
  import os
3
  import json
4
  import torch
5
+ import csv
6
+ import shutil
7
+ import time
8
+ import threading
9
 
10
+ from typing import Final, Optional, List, Any, Generator
11
  from pathlib import Path
12
+ from dataclasses import dataclass
13
+
14
  from huggingface_hub import login
15
  from trl import SFTConfig, SFTTrainer
16
+ from transformers import (
17
+ AutoTokenizer,
18
+ AutoModelForCausalLM,
19
+ TrainerCallback,
20
+ TrainingArguments,
21
+ TrainerControl,
22
+ TrainerState
23
+ )
24
  from datasets import Dataset, load_dataset
 
 
25
 
26
+ # --- Configuration ---
27
+ class AppConfig:
28
+ """
29
+ Central configuration class.
30
+ """
31
+ ARTIFACTS_DIR: Final[Path] = Path("artifacts")
32
+ ARTIFACTS_DIR.mkdir(parents=True, exist_ok=True)
33
+
34
+ HF_TOKEN: Final[Optional[str]] = os.getenv('HF_TOKEN')
35
+ MODEL_NAME: Final[str] = '../hf/270m'
36
+ DEFAULT_DATASET: Final[str] = 'bebechien/SimpleToolCalling'
37
+ OUTPUT_DIR: Final[Path] = ARTIFACTS_DIR.joinpath("functiongemma-modkit-demo")
38
 
 
 
 
 
 
 
 
 
 
 
39
 
40
  # --- Tool Definitions ---
41
  def search_knowledge_base(query: str) -> str:
 
56
  """
57
  return "Public Result"
58
 
59
+ search_knowledge_base_schema = {
60
+ "type": "function",
61
+ "function": {
62
+ "name": "search_knowledge_base",
63
+ "description": "Search internal company documents, policies and project data.",
64
+ "parameters": {
65
+ "type": "object",
66
+ "properties": {
67
+ "query": {
68
+ "type": "string",
69
+ "description": "query string"
70
+ }
71
+ },
72
+ "required": [
73
+ "query"
74
+ ]
75
+ },
76
+ "return": {
77
+ "type": "string"
78
+ }
79
  }
80
+ }
81
+
82
+ search_google_schema = {
83
+ "type": "function",
84
+ "function": {
85
+ "name": "search_google",
86
+ "description": "Search public information.",
87
+ "parameters": {
88
+ "type": "object",
89
+ "properties": {
90
+ "query": {
91
+ "type": "string",
92
+ "description": "query string"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  }
94
+ },
95
+ "required": [
96
+ "query"
97
+ ]
98
+ },
99
+ "return": {
100
+ "type": "string"
101
+ }
102
+ }
103
+ }
104
 
105
+ TOOLS = [search_knowledge_base_schema, search_google_schema]
106
+ DEFAULT_SYSTEM_MSG = "You are a model that can do function calling with the following functions"
 
 
 
 
 
 
107
 
108
+ # --- Callbacks ---
109
+ class AbortCallback(TrainerCallback):
110
+ """
111
+ A custom callback to check a threading Event to stop training on user request.
112
+ """
113
+ def __init__(self, stop_event: threading.Event):
114
+ self.stop_event = stop_event
115
 
116
+ def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
117
+ if self.stop_event.is_set():
118
+ print("🛑 Stop signal received. Stopping training...")
119
+ control.should_training_stop = True
120
 
 
 
121
 
122
+ # --- Helper Functions ---
123
+ def authenticate_hf(token: Optional[str]) -> None:
124
+ """Logs into the Hugging Face Hub."""
125
+ if token:
126
+ print("Logging into Hugging Face Hub...")
127
+ login(token=token)
128
+ else:
129
+ print("Skipping Hugging Face login: HF_TOKEN not set.")
130
 
131
+ def load_model_and_tokenizer(model_name: str):
132
+ print(f"Loading Transformer model: {model_name}")
133
+ try:
134
+ # Check if local path exists, otherwise treat as HF Hub ID
135
+ if model_name.startswith("..") and not os.path.exists(model_name):
136
+ print(f"Warning: Local path {model_name} not found. Falling back to default hub model.")
137
+ model_name = "google/gemma-2b-it" # Fallback example
 
 
138
 
139
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
140
+ model = AutoModelForCausalLM.from_pretrained(model_name)
141
+ print("Model loaded successfully.")
142
+ return model, tokenizer
143
+ except Exception as e:
144
+ print(f"Error loading Transformer model {model_name}: {e}")
145
+ raise e
146
 
147
+ def create_conversation_format(sample):
148
+ """Formats a dataset row into the conversational format required for SFT."""
149
+ try:
150
+ tool_args = json.loads(sample["tool_arguments"])
151
+ except (json.JSONDecodeError, TypeError):
152
+ tool_args = {}
153
+
154
+ return {
155
+ "messages": [
156
+ {"role": "developer", "content": DEFAULT_SYSTEM_MSG},
157
+ {"role": "user", "content": sample["user_content"]},
158
+ {"role": "assistant", "tool_calls": [{"type": "function", "function": {"name": sample["tool_name"], "arguments": tool_args}}]},
159
+ ],
160
+ "tools": TOOLS
161
+ }
162
+
163
+
164
+ # --- Main Application Logic ---
165
  class FunctionGemmaTuner:
166
  def __init__(self, config: AppConfig = AppConfig):
167
  self.config = config
168
+ self.model = None
169
+ self.tokenizer = None
170
+ self.imported_dataset = []
171
+
172
+ # Threading event to control stopping
173
+ self.stop_event = threading.Event()
174
 
175
  authenticate_hf(self.config.HF_TOKEN)
176
+
177
+ # Initial load attempt
 
 
 
178
  print("--- Running Initial Data Load ---")
179
+ try:
180
+ self.refresh_data_and_model()
181
+ print("--- Initial Load Complete ---")
182
+ except Exception as e:
183
+ print(f"Initial load failed (this is common if model path is invalid): {e}")
184
 
185
  def refresh_data_and_model(self):
186
+ """Reloads the model and clears imported data."""
187
  print("\n" + "=" * 50)
188
  print("RELOADING MODEL and RE-FETCHING DATA")
189
 
 
190
  self.imported_dataset = []
191
 
 
192
  try:
193
+ self.model, self.tokenizer = load_model_and_tokenizer(self.config.MODEL_NAME)
194
+ status_value = "Model and data reloaded. Ready."
195
  except Exception as e:
 
196
  self.model = None
197
  self.tokenizer = None
198
+ status_value = f"CRITICAL ERROR: Model failed to load. {e}"
199
+ # We don't raise here to allow the UI to render the error message
200
+
201
+ return status_value
 
 
202
 
 
203
  def import_additional_dataset(self, file_path: str) -> str:
204
+ """Parses an uploaded CSV file."""
205
  if not file_path:
206
  return "Please upload a CSV file."
207
+
208
+ new_dataset = []
209
+ num_imported = 0
210
+
211
  try:
212
+ # Open file handle properly
213
  with open(file_path, 'r', newline='', encoding='utf-8') as f:
214
  reader = csv.reader(f)
215
+
216
+ # Basic header validation
217
  try:
218
  header = next(reader)
219
+ # Simple heuristic check, allows skipping header or rewinding
220
+ if not (header and "anchor" in header[0].lower()):
221
  f.seek(0)
222
  except StopIteration:
223
  return "Error: Uploaded file is empty."
224
 
225
  for row in reader:
226
+ # Expecting: [User Prompt, Tool Name, Tool Args JSON/String]
227
+ if len(row) >= 3:
228
+ new_dataset.append([s.strip() for s in row[:3]])
229
  num_imported += 1
230
+
231
  if num_imported == 0:
232
+ return "No valid rows found. CSV format: [Anchor, Positive, Negative]"
233
+
234
  self.imported_dataset = new_dataset
235
+ return f"Successfully imported {num_imported} additional training samples."
236
+
237
  except Exception as e:
238
+ return f"Import failed. Error: {e}"
 
239
 
240
+ def stop_training(self):
241
+ """Signal the training loop to stop."""
242
+ print("Set stop event")
243
+ self.stop_event.set()
244
+ return "Stopping initiated... please wait for the current step to finish."
 
 
 
 
 
 
 
 
 
 
 
 
245
 
246
+ def run_training(self, test_size: float = 0.5) -> Generator[str, None, None]:
247
  """
248
+ Main training logic. Yields status strings to the UI.
249
  """
250
+ # 1. Validation
251
  if self.model is None:
252
+ yield "Training failed: Model is not loaded."
253
+ return
254
 
255
+ self.stop_event.clear() # Reset stop flag
256
+ yield "⏳ Preparing Dataset..."
257
+
258
+ # 2. Dataset Preparation
259
  if not self.imported_dataset:
260
+ print("No imported dataset, using default HF dataset")
261
+ try:
262
+ dataset = load_dataset(self.config.DEFAULT_DATASET, split="train")
263
+ except Exception as e:
264
+ yield f"Error loading default dataset: {e}"
265
+ return
266
  else:
267
  dataset_as_dicts = [{
268
  "user_content": row[0], "tool_name": row[1], "tool_arguments": row[2]}
 
270
  ]
271
  dataset = Dataset.from_list(dataset_as_dicts)
272
 
273
+ # Apply formatting
274
+ dataset = dataset.map(create_conversation_format, batched=False)
275
+
276
+ # Split
277
+ if len(dataset) > 1:
278
+ dataset = dataset.train_test_split(test_size=test_size, shuffle=False)
279
+ else:
280
+ # Fallback for very small datasets (mostly for debugging)
281
+ dataset = {"train": dataset, "test": dataset}
282
+
283
+ output_buffer = "📊 Evaluating Pre-Training Success Rate...\n### Success Rate (Before Training):\n"
284
+ yield output_buffer
285
+ pre_training_report = ""
286
+ gen = self.check_success_rate(dataset["test"])
287
+ while not self.stop_event.is_set():
288
+ try:
289
+ pre_training_report += f"{next(gen)}\n"
290
+ yield f"{output_buffer}{pre_training_report}"
291
+ except StopIteration as e:
292
+ pre_training_report = e.value
293
+ break
294
+
295
+ if self.stop_event.is_set():
296
+ output_buffer += f"{pre_training_report}\n\n🛑 Manual Eval interrupted by user.\n"
297
+ yield output_buffer
298
+ return
299
+
300
+ output_buffer += f"{pre_training_report}\n\n"
301
+ output_buffer += "-" * 30 + "\nStarting Fine-tuning...\n"
302
+ yield output_buffer
303
+
304
+ # 3. Training Setup
305
+ torch_dtype = self.model.dtype
306
+
307
+ args = SFTConfig(
308
+ output_dir=str(self.config.OUTPUT_DIR),
309
+ max_length=512,
310
+ packing=False,
311
+ num_train_epochs=5,
312
+ per_device_train_batch_size=4,
313
+ gradient_checkpointing=False,
314
+ optim="adamw_torch_fused",
315
+ logging_steps=1,
316
+ save_strategy="no", # Speed up demo
317
+ eval_strategy="epoch",
318
+ learning_rate=5e-5,
319
+ fp16=True if torch_dtype == torch.float16 else False,
320
+ bf16=True if torch_dtype == torch.bfloat16 else False,
321
+ lr_scheduler_type="constant",
322
+ push_to_hub=False,
323
+ report_to="none",
324
+ dataset_kwargs={
325
+ "add_special_tokens": False,
326
+ "append_concat_token": True,
327
+ }
328
+ )
329
+
330
+ trainer = SFTTrainer(
331
+ model=self.model,
332
+ args=args,
333
+ train_dataset=dataset['train'],
334
+ eval_dataset=dataset['test'],
335
+ processing_class=self.tokenizer,
336
+ callbacks=[AbortCallback(self.stop_event)] # Inject our stopper
337
+ )
338
+
339
+ # 4. Run Training
340
+ try:
341
+ output_buffer += "🚀 Training in progress... (Click Stop to interrupt)\n"
342
+ yield output_buffer
343
+ trainer.train()
344
+
345
+ if self.stop_event.is_set():
346
+ output_buffer += "\n🛑 Training interrupted by user.\n"
347
+ else:
348
+ output_buffer += "\n✅ Training finished. Model weights updated in memory.\n"
349
+ yield output_buffer
350
+
351
+ # Save locally
352
+ trainer.save_model()
353
+ output_buffer += f"Model saved locally to: {self.config.OUTPUT_DIR}\n"
354
+ yield output_buffer
355
 
356
+ except Exception as e:
357
+ output_buffer += f"\n❌ Error during training: {e}\n"
358
+ yield output_buffer
359
+ return
360
+
361
+ if self.stop_event.is_set():
362
+ return
363
+
364
+ # 5. Post-Evaluation
365
+ output_buffer += "📊 Evaluating Post-Training Success Rate...\n"
366
+ post_report = ""
367
+ yield output_buffer
368
+ gen = self.check_success_rate(dataset["test"])
369
+ while not self.stop_event.is_set():
370
+ try:
371
+ post_report += f"{next(gen)}\n"
372
+ yield f"{output_buffer}{post_report}"
373
+ except StopIteration as e:
374
+ post_report = e.value
375
+ break
376
+
377
+ if self.stop_event.is_set():
378
+ output_buffer += f"{post_report}\n\n🛑 Manual Eval interrupted by user.\n"
379
+ yield output_buffer
380
+ return
381
+
382
+ output_buffer += f"{post_report}\n\n"
383
+ yield output_buffer
384
 
385
  def check_success_rate(self, test_dataset):
386
+ """Runs inference on test set to calculate accuracy."""
387
+ results = []
388
  success_count = 0
389
+ total = len(test_dataset)
390
+
391
  for idx, item in enumerate(test_dataset):
392
+ if idx >= 5:
393
+ break
394
+ if self.stop_event.is_set():
395
+ break
396
+
397
+ messages = [item["messages"][0], item["messages"][1]] # System + User
398
+
399
+ try:
400
+ inputs = self.tokenizer.apply_chat_template(
401
+ messages,
402
+ tools=TOOLS,
403
+ add_generation_prompt=True,
404
+ return_dict=True,
405
+ return_tensors="pt"
406
+ )
407
+
408
+ out = self.model.generate(
409
+ **inputs.to(self.model.device),
410
+ pad_token_id=self.tokenizer.eos_token_id,
411
+ max_new_tokens=128
412
+ )
413
+
414
+ # Decode only the new tokens
415
+ output = self.tokenizer.decode(out[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True)
416
+
417
+ results.append(f"{idx+1}. Prompt: {item['messages'][1]['content']}")
418
+ yield results[-1]
419
+ results.append(f" Output: {output[:100]}...")
420
+ yield results[-1]
421
+
422
+ # Check for correct tool name usage
423
+ expected_tool = item['messages'][2]['tool_calls'][0]['function']['name']
424
+ if expected_tool in output:
425
+ results.append(" -> ✅ Correct Tool")
426
+ yield results[-1]
427
+ success_count += 1
428
+ else:
429
+ results.append(f" -> ❌ Wrong Tool (Expected: {expected_tool})")
430
+ yield results[-1]
431
+
432
+ except Exception as e:
433
+ results.append(f" -> Error: {e}")
434
+ yield results[-1]
435
+
436
+ summary = "\n".join(results)
437
+ summary += f"\n\nTotal Success : {success_count} / {len(test_dataset)}"
438
+ return summary
439
+
440
+ def download_model_zip(self) -> Optional[str]:
441
+ """Zips the output directory for download."""
442
+ if not os.path.exists(self.config.OUTPUT_DIR):
443
+ return None
444
+
445
+ timestamp = int(time.time())
446
+ try:
447
+ base_name = self.config.ARTIFACTS_DIR.joinpath(f"functiongemma_finetuned_{timestamp}")
448
+ archive_path = shutil.make_archive(
449
+ base_name=str(base_name),
450
+ format='zip',
451
+ root_dir=str(self.config.OUTPUT_DIR),
452
+ )
453
+ return archive_path
454
+ except Exception as e:
455
+ print(f"Zip failed: {e}")
456
+ return None
457
 
458
+ # --- UI Builder ---
459
  def build_interface(self) -> gr.Blocks:
460
  with gr.Blocks(title="FunctionGemma Modkit") as demo:
461
  gr.Markdown("# 🤖 FunctionGemma Modkit: Fine-Tuning")
462
+ gr.Markdown("Fine-tune FunctionGemma to understand your custom functions.")
463
+
464
+ with gr.Column():
465
+ gr.Markdown("## 1. Training Controls")
466
+
467
+ with gr.Row():
468
+ run_training_btn = gr.Button("🚀 Run Fine-Tuning", variant="primary")
469
+ stop_training_btn = gr.Button("🛑 Stop Training", variant="stop", visible=False)
470
+
471
+ output_display = gr.Textbox(
472
+ lines=14,
473
+ label="Training Logs & Search Results",
474
+ value="Ready. Click 'Run' to begin.",
475
+ interactive=False
476
+ )
477
+
478
+ clear_reload_btn = gr.Button("🔄 Reset Model & Data")
479
+
480
+ gr.Markdown("--- \n ## 2. Data Management")
481
+ import_file = gr.File(label="Upload Additional Dataset (.csv)", file_types=[".csv"], height=80)
482
+ import_status = gr.Markdown("")
483
+
484
+ gr.Markdown("--- \n ## 3. Export")
485
+ with gr.Row():
486
+ zip_btn = gr.Button("⬇️ Prepare Model ZIP")
487
+ download_file = gr.File(label="Download ZIP", height=80, visible=True, interactive=False)
488
+
489
+ # --- Event Wiring ---
490
+
491
+ # Start Training (Generator updates output_display)
492
+ run_training_btn.click(
493
+ fn=lambda: (
494
+ gr.update(visible=False),
495
+ gr.update(interactive=False),
496
+ gr.update(visible=True)
497
+ ),
498
+ inputs=None,
499
+ outputs=[run_training_btn, clear_reload_btn, stop_training_btn]
500
+ ).then(
501
+ fn=self.run_training,
502
+ inputs=[],
503
+ outputs=[output_display],
504
+ ).then(
505
+ fn=lambda: (
506
+ gr.update(visible=True),
507
+ gr.update(interactive=True),
508
+ gr.update(visible=False)
509
+ ),
510
+ inputs=None,
511
+ outputs=[run_training_btn, clear_reload_btn, stop_training_btn]
512
+ )
513
+
514
+ # Stop Training
515
+ stop_training_btn.click(
516
+ fn=self.stop_training,
517
+ inputs=None,
518
+ outputs=None # We don't need to return anything, status updates via the training generator
519
+ ).then(
520
+ fn=lambda: (
521
+ gr.update(visible=True),
522
+ gr.update(interactive=True),
523
+ gr.update(visible=False)
524
+ ),
525
+ inputs=None,
526
+ outputs=[run_training_btn, clear_reload_btn, stop_training_btn]
527
+ )
528
 
529
+ # Reload
530
+ clear_reload_btn.click(
531
+ fn=self.refresh_data_and_model,
532
+ inputs=None,
533
+ outputs=[output_display]
534
+ )
535
+
536
+ # File Import
537
+ import_file.upload(
538
+ fn=self.import_additional_dataset,
539
+ inputs=[import_file],
540
+ outputs=[import_status]
541
+ )
542
+
543
+ # Download Logic
544
+ def handle_zip():
545
+ path = self.download_model_zip()
546
+ if path:
547
+ return gr.update(value=path, visible=True)
548
+ return gr.update(value=None, visible=False)
549
+
550
+ zip_btn.click(
551
+ fn=handle_zip,
552
+ inputs=None,
553
+ outputs=[download_file]
554
+ )
555
 
556
+ return demo
557
 
558
  if __name__ == "__main__":
559
  app = FunctionGemmaTuner(AppConfig)
560
  demo = app.build_interface()
561
  print("Starting Gradio App...")
562
  demo.launch()