bebechien committed on
Commit
7e41311
·
verified ·
1 Parent(s): 2cfd82c

Upload folder using huggingface_hub

Browse files
Files changed (3) hide show
  1. config.py +13 -4
  2. engine.py +86 -73
  3. ui.py +31 -9
config.py CHANGED
@@ -1,7 +1,7 @@
1
  import os
2
  from pathlib import Path
3
- from typing import Final, Optional
4
- from dataclasses import dataclass
5
 
6
  @dataclass
7
  class AppConfig:
@@ -14,8 +14,17 @@ class AppConfig:
14
 
15
  # Model & Data
16
  HF_TOKEN: Final[Optional[str]] = os.getenv('HF_TOKEN')
17
- # Defaulting to a real model ID for safety, original was local path '../hf/270m'
18
- MODEL_NAME: Final[str] = '../hf/270m'
 
 
 
 
 
 
 
 
 
19
  DEFAULT_DATASET: Final[str] = 'bebechien/SimpleToolCalling'
20
 
21
  def __post_init__(self):
 
1
  import os
2
  from pathlib import Path
3
+ from typing import Final, Optional, List
4
+ from dataclasses import dataclass, field
5
 
6
  @dataclass
7
  class AppConfig:
 
14
 
15
  # Model & Data
16
  HF_TOKEN: Final[Optional[str]] = os.getenv('HF_TOKEN')
17
+
18
+ # Model Configuration
19
+ # Mutable: User can change this in the UI
20
+ MODEL_NAME: str = '../hf/270m'
21
+
22
+ AVAILABLE_MODELS: List[str] = field(default_factory=lambda: [
23
+ '../hf/270m',
24
+ '../hf/gemma-3-270m-it',
25
+ 'google/gemma-3-270m-it'
26
+ ])
27
+
28
  DEFAULT_DATASET: Final[str] = 'bebechien/SimpleToolCalling'
29
 
30
  def __post_init__(self):
engine.py CHANGED
@@ -5,7 +5,7 @@ import json
5
  import queue
6
  import matplotlib.pyplot as plt
7
  from functools import partial
8
- from typing import Generator, Optional, List, Dict
9
  from datasets import Dataset, load_dataset
10
  from trl import SFTConfig, SFTTrainer
11
  from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl
@@ -30,8 +30,8 @@ class AbortCallback(TrainerCallback):
30
 
31
  class LogStreamingCallback(TrainerCallback):
32
  """
33
- NEW: Intercepts training logs and pushes them to a queue
34
- so the main thread can display them in the UI.
35
  """
36
  def __init__(self, log_queue: queue.Queue):
37
  self.log_queue = log_queue
@@ -56,7 +56,6 @@ class LogStreamingCallback(TrainerCallback):
56
  for key, label in metrics_map.items():
57
  if key in logs:
58
  val = logs[key]
59
- # Format floats: use scientific notation for very small numbers (like LR)
60
  if isinstance(val, (float, int)):
61
  val_str = f"{val:.4f}" if val > 1e-4 else f"{val:.2e}"
62
  else:
@@ -64,17 +63,20 @@ class LogStreamingCallback(TrainerCallback):
64
 
65
  log_parts.append(f"{label}: {val_str}")
66
 
67
- self.log_queue.put(" | ".join(log_parts))
 
 
 
 
68
 
69
  class FunctionGemmaEngine:
70
  def __init__(self, config: AppConfig):
71
  self.config = config
72
  self.model = None
73
  self.tokenizer = None
 
74
  self.imported_dataset = []
75
  self.stop_event = threading.Event()
76
-
77
- # NEW: State for tools
78
  self.current_tools = DEFAULT_TOOLS
79
 
80
  authenticate_hf(self.config.HF_TOKEN)
@@ -83,7 +85,7 @@ class FunctionGemmaEngine:
83
  except Exception as e:
84
  print(f"Initial load warning: {e}")
85
 
86
- # NEW: Methods to handle Tool Schema updates
87
  def get_tools_json(self) -> str:
88
  return json.dumps(self.current_tools, indent=2)
89
 
@@ -99,14 +101,24 @@ class FunctionGemmaEngine:
99
  except Exception as e:
100
  return f"❌ Error: {e}"
101
 
 
 
 
 
 
 
 
 
102
  def refresh_data_and_model(self) -> str:
 
103
  self.imported_dataset = []
104
  try:
105
- self.model, self.tokenizer = load_model_and_tokenizer(self.config.MODEL_NAME)
106
- return "Model and data reloaded. Ready."
107
  except Exception as e:
108
  self.model = None
109
  self.tokenizer = None
 
110
  return f"CRITICAL ERROR: Model failed to load. {e}"
111
 
112
  def load_csv(self, file_path: str) -> str:
@@ -122,13 +134,31 @@ class FunctionGemmaEngine:
122
  def trigger_stop(self):
123
  self.stop_event.set()
124
 
125
- def run_training_pipeline(self, epochs: int, learning_rate: float, test_size: float, shuffle_data: bool) -> Generator[str, None, None]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  if self.model is None:
127
- yield "Training failed: Model is not loaded.", None
128
  return
129
 
130
- self.stop_event.clear()
131
- output_buffer = f"⏳ Preparing Dataset (Test Split: {test_size}, Shuffle: {shuffle_data})...\n"
132
  yield output_buffer, None
133
 
134
  dataset, log = self._prepare_dataset()
@@ -161,86 +191,85 @@ class FunctionGemmaEngine:
161
  output_buffer += pre_training_report
162
 
163
  # --- Phase 2: Training (Threaded) ---
164
- output_buffer += "\n\nπŸš€ Starting Fine-tuning (Epochs: {epochs}, LR: {learning_rate})...\n"
165
  yield output_buffer, None
166
 
167
  log_queue = queue.Queue()
168
  training_error = None
169
- training_history = []
170
 
171
- # Function to run in the thread
172
  def train_wrapper():
173
- nonlocal training_error, training_history
174
  try:
175
- training_history = self._execute_trainer(dataset, log_queue, epochs, learning_rate)
176
  except Exception as e:
177
  training_error = e
178
 
179
- # Start training thread
180
  train_thread = threading.Thread(target=train_wrapper)
181
  train_thread.start()
182
 
183
- # Monitor loop: Yields logs while training runs
184
  while train_thread.is_alive():
185
- # Drain the queue
186
  while not log_queue.empty():
187
- log_msg = log_queue.get()
188
- output_buffer += f"{log_msg}\n"
189
- yield output_buffer, None
 
 
 
 
 
 
 
 
 
 
190
 
191
- # Check for stop signal
192
  if self.stop_event.is_set():
193
- yield f"{output_buffer}πŸ›‘ Stop signal sent. Waiting for trainer to wrap up...\n", None
194
- # We don't break here, we wait for thread to finish cleanly
195
 
196
- time.sleep(0.1) # Prevent CPU spinning
197
 
198
- train_thread.join() # Ensure thread is completely done
199
 
200
- # Flush any remaining logs
201
  while not log_queue.empty():
202
- log_msg = log_queue.get()
203
- output_buffer += f"{log_msg}\n"
204
- yield output_buffer, None
 
 
 
 
 
 
205
 
206
  if training_error:
207
  output_buffer += f"❌ Error during training: {training_error}\n"
208
- yield output_buffer, None
209
  return
210
 
211
  if self.stop_event.is_set():
212
  output_buffer += "πŸ›‘ Training manually stopped.\n"
213
- yield output_buffer, None
214
  return
215
 
216
  output_buffer += "βœ… Training finished.\n"
217
- yield output_buffer, None
218
-
219
- output_buffer += "\nπŸ“ˆ Generating Loss Plot...\n"
220
- yield output_buffer, None
221
-
222
- try:
223
- final_plot = self._generate_loss_plot(training_history)
224
- yield output_buffer, final_plot
225
- except Exception as e:
226
- output_buffer += f"⚠️ Could not generate plot: {e}\n"
227
- yield output_buffer, None
228
 
229
  # --- Phase 3: Post-Training Eval ---
230
  output_buffer += "\nπŸ“Š Evaluating Post-Training Success Rate...\n"
231
- yield output_buffer, final_plot
232
 
233
  post_training_report = ""
234
  for update in self._evaluate_model(dataset["test"]):
235
  post_training_report = update
236
  if self.stop_event.is_set():
237
  post_training_report += "\n\nπŸ›‘ Manual Eval interrupted by user.\n"
238
- yield f"{output_buffer}{post_training_report}", final_plot
239
  break
240
- yield f"{output_buffer}{post_training_report}", final_plot
241
 
242
  def _prepare_dataset(self):
243
- # NEW: Use partial to inject self.current_tools into the formatting function
244
  formatting_fn = partial(create_conversation_format, tools_list=self.current_tools)
245
 
246
  if not self.imported_dataset:
@@ -286,27 +315,22 @@ class FunctionGemmaEngine:
286
  )
287
  trainer.train()
288
  trainer.save_model()
289
-
290
  return trainer.state.log_history
291
 
292
  def _generate_loss_plot(self, history: list):
293
- if not history:
294
- return None
295
-
296
- # Extract Training Loss
297
- # log_history format: [{'loss': 0.5, 'step': 1}, {'eval_loss': 0.4, 'step': 1}, ...]
298
  train_steps = [x['step'] for x in history if 'loss' in x]
299
  train_loss = [x['loss'] for x in history if 'loss' in x]
300
-
301
- # Extract Validation Loss
302
  eval_steps = [x['step'] for x in history if 'eval_loss' in x]
303
  eval_loss = [x['eval_loss'] for x in history if 'eval_loss' in x]
304
 
305
  fig, ax = plt.subplots(figsize=(10, 5))
306
-
307
  if train_steps:
308
  ax.plot(train_steps, train_loss, label='Training Loss', linestyle='-', marker=None)
309
-
310
  if eval_steps:
311
  ax.plot(eval_steps, eval_loss, label='Validation Loss', linestyle='--', marker='o')
312
 
@@ -315,51 +339,40 @@ class FunctionGemmaEngine:
315
  ax.set_title("Training & Validation Loss")
316
  ax.legend()
317
  ax.grid(True, linestyle=':', alpha=0.6)
318
-
319
  plt.tight_layout()
320
  return fig
321
 
322
  def _evaluate_model(self, test_dataset) -> Generator[str, None, None]:
323
  results = []
324
  success_count = 0
325
-
326
  for idx, item in enumerate(test_dataset):
327
  messages = item["messages"][:2]
328
  try:
329
- # NEW: Pass self.current_tools to the template
330
  inputs = self.tokenizer.apply_chat_template(
331
  messages, tools=self.current_tools, add_generation_prompt=True, return_dict=True, return_tensors="pt"
332
  )
333
-
334
  device = self.model.device
335
  inputs = {k: v.to(device) for k, v in inputs.items()}
336
-
337
  out = self.model.generate(
338
  **inputs,
339
  pad_token_id=self.tokenizer.eos_token_id,
340
  max_new_tokens=128
341
  )
342
  output = self.tokenizer.decode(out[0][len(inputs["input_ids"][0]):], skip_special_tokens=True)
343
-
344
  log_entry = f"{idx+1}. Prompt: {messages[1]['content']}\n Output: {output[:100]}..."
345
-
346
- # Check tool correctness
347
  expected_tool = item['messages'][2]['tool_calls'][0]['function']['name']
348
  if expected_tool in output:
349
  log_entry += "\n -> βœ… Correct Tool"
350
  success_count += 1
351
  else:
352
  log_entry += f"\n -> ❌ Wrong Tool (Expected: {expected_tool})"
353
-
354
  results.append(log_entry)
355
  yield "\n".join(results) + f"\n\nRunning Success Rate: {success_count}/{idx+1}"
356
-
357
  except Exception as e:
358
  yield f"Error during inference: {e}"
359
 
360
  def get_zip_path(self) -> Optional[str]:
361
- if not self.config.OUTPUT_DIR.exists():
362
- return None
363
  timestamp = int(time.time())
364
  base_name = str(self.config.ARTIFACTS_DIR.joinpath(f"functiongemma_finetuned_{timestamp}"))
365
  return zip_directory(str(self.config.OUTPUT_DIR), base_name)
 
5
  import queue
6
  import matplotlib.pyplot as plt
7
  from functools import partial
8
+ from typing import Generator, Optional, List, Dict, Any, Tuple
9
  from datasets import Dataset, load_dataset
10
  from trl import SFTConfig, SFTTrainer
11
  from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl
 
30
 
31
  class LogStreamingCallback(TrainerCallback):
32
  """
33
+ Intercepts training logs and pushes them to a queue.
34
+ Sends tuple: (formatted_string, raw_dict)
35
  """
36
  def __init__(self, log_queue: queue.Queue):
37
  self.log_queue = log_queue
 
56
  for key, label in metrics_map.items():
57
  if key in logs:
58
  val = logs[key]
 
59
  if isinstance(val, (float, int)):
60
  val_str = f"{val:.4f}" if val > 1e-4 else f"{val:.2e}"
61
  else:
 
63
 
64
  log_parts.append(f"{label}: {val_str}")
65
 
66
+ # Structure for plotting
67
+ log_payload = logs.copy()
68
+ log_payload['step'] = state.global_step
69
+
70
+ self.log_queue.put((" | ".join(log_parts), log_payload))
71
 
72
  class FunctionGemmaEngine:
73
  def __init__(self, config: AppConfig):
74
  self.config = config
75
  self.model = None
76
  self.tokenizer = None
77
+ self.loaded_model_name = None
78
  self.imported_dataset = []
79
  self.stop_event = threading.Event()
 
 
80
  self.current_tools = DEFAULT_TOOLS
81
 
82
  authenticate_hf(self.config.HF_TOKEN)
 
85
  except Exception as e:
86
  print(f"Initial load warning: {e}")
87
 
88
+ # --- Tool Schema Methods ---
89
  def get_tools_json(self) -> str:
90
  return json.dumps(self.current_tools, indent=2)
91
 
 
101
  except Exception as e:
102
  return f"❌ Error: {e}"
103
 
104
+ # --- Model & Data Management ---
105
+
106
+ def _load_model_weights(self):
107
+ """Internal helper to load model based on current config."""
108
+ print(f"Loading model: {self.config.MODEL_NAME}...")
109
+ self.model, self.tokenizer = load_model_and_tokenizer(self.config.MODEL_NAME)
110
+ self.loaded_model_name = self.config.MODEL_NAME
111
+
112
  def refresh_data_and_model(self) -> str:
113
+ """Full reset: Reloads model and clears dataset."""
114
  self.imported_dataset = []
115
  try:
116
+ self._load_model_weights()
117
+ return f"Model loaded: {self.loaded_model_name}\nData cleared.\nReady."
118
  except Exception as e:
119
  self.model = None
120
  self.tokenizer = None
121
+ self.loaded_model_name = None
122
  return f"CRITICAL ERROR: Model failed to load. {e}"
123
 
124
  def load_csv(self, file_path: str) -> str:
 
134
  def trigger_stop(self):
135
  self.stop_event.set()
136
 
137
+ # --- Training Pipeline ---
138
+
139
+ def run_training_pipeline(self, epochs: int, learning_rate: float, test_size: float, shuffle_data: bool) -> Generator[Tuple[str, Any], None, None]:
140
+ self.stop_event.clear()
141
+ output_buffer = ""
142
+ last_plot = None
143
+
144
+ # 1. Check if model name changed since last load
145
+ if self.config.MODEL_NAME != self.loaded_model_name:
146
+ output_buffer += f"πŸ”„ Model changed. Switching from '{self.loaded_model_name}' to '{self.config.MODEL_NAME}'...\n"
147
+ yield output_buffer, None
148
+ try:
149
+ self._load_model_weights()
150
+ output_buffer += "βœ… Model reloaded successfully.\n"
151
+ yield output_buffer, None
152
+ except Exception as e:
153
+ output_buffer += f"❌ Failed to load model '{self.config.MODEL_NAME}': {e}\n"
154
+ yield output_buffer, None
155
+ return
156
+
157
  if self.model is None:
158
+ yield "Training failed: No model loaded.", None
159
  return
160
 
161
+ output_buffer += f"⏳ Preparing Dataset (Test Split: {test_size}, Shuffle: {shuffle_data})...\n"
 
162
  yield output_buffer, None
163
 
164
  dataset, log = self._prepare_dataset()
 
191
  output_buffer += pre_training_report
192
 
193
  # --- Phase 2: Training (Threaded) ---
194
+ output_buffer += f"\n\nπŸš€ Starting Fine-tuning (Epochs: {epochs}, LR: {learning_rate})...\n"
195
  yield output_buffer, None
196
 
197
  log_queue = queue.Queue()
198
  training_error = None
199
+ running_history = []
200
 
 
201
  def train_wrapper():
202
+ nonlocal training_error
203
  try:
204
+ self._execute_trainer(dataset, log_queue, epochs, learning_rate)
205
  except Exception as e:
206
  training_error = e
207
 
 
208
  train_thread = threading.Thread(target=train_wrapper)
209
  train_thread.start()
210
 
 
211
  while train_thread.is_alive():
 
212
  while not log_queue.empty():
213
+ payload = log_queue.get()
214
+ if isinstance(payload, tuple):
215
+ msg, log_data = payload
216
+ output_buffer += f"{msg}\n"
217
+ running_history.append(log_data)
218
+ try:
219
+ last_plot = self._generate_loss_plot(running_history)
220
+ yield output_buffer, last_plot
221
+ except Exception:
222
+ yield output_buffer, last_plot
223
+ else:
224
+ output_buffer += f"{payload}\n"
225
+ yield output_buffer, last_plot
226
 
 
227
  if self.stop_event.is_set():
228
+ yield f"{output_buffer}πŸ›‘ Stop signal sent. Waiting for trainer to wrap up...\n", last_plot
 
229
 
230
+ time.sleep(0.1)
231
 
232
+ train_thread.join()
233
 
234
+ # Flush logs
235
  while not log_queue.empty():
236
+ payload = log_queue.get()
237
+ if isinstance(payload, tuple):
238
+ msg, log_data = payload
239
+ output_buffer += f"{msg}\n"
240
+ running_history.append(log_data)
241
+ last_plot = self._generate_loss_plot(running_history)
242
+ else:
243
+ output_buffer += f"{payload}\n"
244
+ yield output_buffer, last_plot
245
 
246
  if training_error:
247
  output_buffer += f"❌ Error during training: {training_error}\n"
248
+ yield output_buffer, last_plot
249
  return
250
 
251
  if self.stop_event.is_set():
252
  output_buffer += "πŸ›‘ Training manually stopped.\n"
253
+ yield output_buffer, last_plot
254
  return
255
 
256
  output_buffer += "βœ… Training finished.\n"
257
+ yield output_buffer, last_plot
 
 
 
 
 
 
 
 
 
 
258
 
259
  # --- Phase 3: Post-Training Eval ---
260
  output_buffer += "\nπŸ“Š Evaluating Post-Training Success Rate...\n"
261
+ yield output_buffer, last_plot
262
 
263
  post_training_report = ""
264
  for update in self._evaluate_model(dataset["test"]):
265
  post_training_report = update
266
  if self.stop_event.is_set():
267
  post_training_report += "\n\nπŸ›‘ Manual Eval interrupted by user.\n"
268
+ yield f"{output_buffer}{post_training_report}", last_plot
269
  break
270
+ yield f"{output_buffer}{post_training_report}", last_plot
271
 
272
  def _prepare_dataset(self):
 
273
  formatting_fn = partial(create_conversation_format, tools_list=self.current_tools)
274
 
275
  if not self.imported_dataset:
 
315
  )
316
  trainer.train()
317
  trainer.save_model()
 
318
  return trainer.state.log_history
319
 
320
  def _generate_loss_plot(self, history: list):
321
+ if not history: return None
322
+
323
+ # CHANGED: Close previous figures to prevent memory warning
324
+ plt.close('all')
325
+
326
  train_steps = [x['step'] for x in history if 'loss' in x]
327
  train_loss = [x['loss'] for x in history if 'loss' in x]
 
 
328
  eval_steps = [x['step'] for x in history if 'eval_loss' in x]
329
  eval_loss = [x['eval_loss'] for x in history if 'eval_loss' in x]
330
 
331
  fig, ax = plt.subplots(figsize=(10, 5))
 
332
  if train_steps:
333
  ax.plot(train_steps, train_loss, label='Training Loss', linestyle='-', marker=None)
 
334
  if eval_steps:
335
  ax.plot(eval_steps, eval_loss, label='Validation Loss', linestyle='--', marker='o')
336
 
 
339
  ax.set_title("Training & Validation Loss")
340
  ax.legend()
341
  ax.grid(True, linestyle=':', alpha=0.6)
 
342
  plt.tight_layout()
343
  return fig
344
 
345
  def _evaluate_model(self, test_dataset) -> Generator[str, None, None]:
346
  results = []
347
  success_count = 0
 
348
  for idx, item in enumerate(test_dataset):
349
  messages = item["messages"][:2]
350
  try:
 
351
  inputs = self.tokenizer.apply_chat_template(
352
  messages, tools=self.current_tools, add_generation_prompt=True, return_dict=True, return_tensors="pt"
353
  )
 
354
  device = self.model.device
355
  inputs = {k: v.to(device) for k, v in inputs.items()}
 
356
  out = self.model.generate(
357
  **inputs,
358
  pad_token_id=self.tokenizer.eos_token_id,
359
  max_new_tokens=128
360
  )
361
  output = self.tokenizer.decode(out[0][len(inputs["input_ids"][0]):], skip_special_tokens=True)
 
362
  log_entry = f"{idx+1}. Prompt: {messages[1]['content']}\n Output: {output[:100]}..."
 
 
363
  expected_tool = item['messages'][2]['tool_calls'][0]['function']['name']
364
  if expected_tool in output:
365
  log_entry += "\n -> βœ… Correct Tool"
366
  success_count += 1
367
  else:
368
  log_entry += f"\n -> ❌ Wrong Tool (Expected: {expected_tool})"
 
369
  results.append(log_entry)
370
  yield "\n".join(results) + f"\n\nRunning Success Rate: {success_count}/{idx+1}"
 
371
  except Exception as e:
372
  yield f"Error during inference: {e}"
373
 
374
  def get_zip_path(self) -> Optional[str]:
375
+ if not self.config.OUTPUT_DIR.exists(): return None
 
376
  timestamp = int(time.time())
377
  base_name = str(self.config.ARTIFACTS_DIR.joinpath(f"functiongemma_finetuned_{timestamp}"))
378
  return zip_directory(str(self.config.OUTPUT_DIR), base_name)
ui.py CHANGED
@@ -2,6 +2,17 @@ import gradio as gr
2
  from engine import FunctionGemmaEngine
3
 
4
  def build_interface(engine: FunctionGemmaEngine) -> gr.Blocks:
 
 
 
 
 
 
 
 
 
 
 
5
  with gr.Blocks(title="FunctionGemma Modkit") as demo:
6
  gr.Markdown("# πŸ€– FunctionGemma Modkit: Fine-Tuning")
7
  gr.Markdown("Fine-tune FunctionGemma to understand your custom functions.<br>See [README](https://huggingface.co/spaces/google/functiongemma-modkit/blob/main/README.md) for more details.")
@@ -41,10 +52,20 @@ def build_interface(engine: FunctionGemmaEngine) -> gr.Blocks:
41
  with gr.Group():
42
  gr.Markdown("**Hyperparameters**")
43
  with gr.Row():
 
 
 
 
 
 
 
 
 
44
  param_epochs = gr.Slider(
45
  minimum=1, maximum=20, value=5, step=1,
46
  label="Epochs", info="Total training passes"
47
  )
 
48
  param_lr = gr.Number(
49
  value=5e-5,
50
  label="Learning Rate",
@@ -52,18 +73,18 @@ def build_interface(engine: FunctionGemmaEngine) -> gr.Blocks:
52
  )
53
  param_test_size = gr.Slider(
54
  minimum=0.1, maximum=0.9, value=0.2, step=0.05,
55
- label="Test Split", info="Validation data ratio. Typical value is 0.2 (80% for training, 20% for testing)"
56
  )
57
  param_shuffle = gr.Checkbox(
58
  value=True,
59
  label="Shuffle Data",
60
- info="Randomize before split\nWhen unchecking this for your own custom datasets, ensure your source data is pre-mixed. If the distribution is unknown or sorted, you should check this to ensure the model learns a balanced representation of all tools during training."
61
  )
62
 
63
  with gr.Row():
64
  run_training_btn = gr.Button("πŸš€ Run Fine-Tuning", variant="primary", scale=2)
65
  stop_training_btn = gr.Button("πŸ›‘ Stop", variant="stop", visible=False, scale=1)
66
- clear_reload_btn = gr.Button("πŸ”„ Reset", variant="secondary", scale=1)
67
 
68
  with gr.Row():
69
  # Left column: Text Logs
@@ -74,7 +95,7 @@ def build_interface(engine: FunctionGemmaEngine) -> gr.Blocks:
74
  interactive=False,
75
  autoscroll=True
76
  )
77
- # Right column: Plot (NEW)
78
  loss_plot = gr.Plot(label="Training Metrics")
79
 
80
  # --- TAB 3: EXPORT ---
@@ -102,7 +123,7 @@ def build_interface(engine: FunctionGemmaEngine) -> gr.Blocks:
102
  outputs=[import_status]
103
  )
104
 
105
- # Tab 2: Training
106
  run_training_btn.click(
107
  fn=lambda: (
108
  gr.update(visible=False), # Hide Run
@@ -111,8 +132,8 @@ def build_interface(engine: FunctionGemmaEngine) -> gr.Blocks:
111
  ),
112
  outputs=[run_training_btn, clear_reload_btn, stop_training_btn]
113
  ).then(
114
- fn=engine.run_training_pipeline,
115
- inputs=[param_epochs, param_lr, param_test_size, param_shuffle],
116
  outputs=[output_display, loss_plot],
117
  ).then(
118
  fn=lambda: (
@@ -129,9 +150,10 @@ def build_interface(engine: FunctionGemmaEngine) -> gr.Blocks:
129
  outputs=None
130
  )
131
 
132
- # Tab 2: Reset
133
  clear_reload_btn.click(
134
- fn=engine.refresh_data_and_model,
 
135
  outputs=[output_display]
136
  )
137
 
 
2
  from engine import FunctionGemmaEngine
3
 
4
  def build_interface(engine: FunctionGemmaEngine) -> gr.Blocks:
5
+
6
+ # Wrapper: Update config with selected model, then Run Training
7
+ def run_training_wrapper(epochs, lr, test_size, shuffle, model_name):
8
+ engine.config.MODEL_NAME = model_name.strip()
9
+ yield from engine.run_training_pipeline(epochs, lr, test_size, shuffle)
10
+
11
+ # Wrapper: Update config with selected model, then Reset/Reload
12
+ def handle_reset(model_name):
13
+ engine.config.MODEL_NAME = model_name.strip()
14
+ return engine.refresh_data_and_model()
15
+
16
  with gr.Blocks(title="FunctionGemma Modkit") as demo:
17
  gr.Markdown("# πŸ€– FunctionGemma Modkit: Fine-Tuning")
18
  gr.Markdown("Fine-tune FunctionGemma to understand your custom functions.<br>See [README](https://huggingface.co/spaces/google/functiongemma-modkit/blob/main/README.md) for more details.")
 
52
  with gr.Group():
53
  gr.Markdown("**Hyperparameters**")
54
  with gr.Row():
55
+ # Dropdown that allows custom typing
56
+ param_model = gr.Dropdown(
57
+ choices=engine.config.AVAILABLE_MODELS,
58
+ value=engine.config.MODEL_NAME,
59
+ allow_custom_value=True,
60
+ label="Base Model",
61
+ info="Select a preset OR type a custom Hugging Face model ID (e.g. 'google/gemma-3-1b-it')",
62
+ interactive=True
63
+ )
64
  param_epochs = gr.Slider(
65
  minimum=1, maximum=20, value=5, step=1,
66
  label="Epochs", info="Total training passes"
67
  )
68
+ with gr.Row():
69
  param_lr = gr.Number(
70
  value=5e-5,
71
  label="Learning Rate",
 
73
  )
74
  param_test_size = gr.Slider(
75
  minimum=0.1, maximum=0.9, value=0.2, step=0.05,
76
+ label="Test Split", info="Validation ratio (0.2 = 20%)"
77
  )
78
  param_shuffle = gr.Checkbox(
79
  value=True,
80
  label="Shuffle Data",
81
+ info="Randomize before split"
82
  )
83
 
84
  with gr.Row():
85
  run_training_btn = gr.Button("πŸš€ Run Fine-Tuning", variant="primary", scale=2)
86
  stop_training_btn = gr.Button("πŸ›‘ Stop", variant="stop", visible=False, scale=1)
87
+ clear_reload_btn = gr.Button("πŸ”„ Reload Model & Reset Data", variant="secondary", scale=1)
88
 
89
  with gr.Row():
90
  # Left column: Text Logs
 
95
  interactive=False,
96
  autoscroll=True
97
  )
98
+ # Right column: Plot
99
  loss_plot = gr.Plot(label="Training Metrics")
100
 
101
  # --- TAB 3: EXPORT ---
 
123
  outputs=[import_status]
124
  )
125
 
126
+ # Tab 2: Training (Uses Wrapper)
127
  run_training_btn.click(
128
  fn=lambda: (
129
  gr.update(visible=False), # Hide Run
 
132
  ),
133
  outputs=[run_training_btn, clear_reload_btn, stop_training_btn]
134
  ).then(
135
+ fn=run_training_wrapper,
136
+ inputs=[param_epochs, param_lr, param_test_size, param_shuffle, param_model],
137
  outputs=[output_display, loss_plot],
138
  ).then(
139
  fn=lambda: (
 
150
  outputs=None
151
  )
152
 
153
+ # Tab 2: Reset (Uses Wrapper to capture model name)
154
  clear_reload_btn.click(
155
+ fn=handle_reset,
156
+ inputs=[param_model],
157
  outputs=[output_display]
158
  )
159