bebechien commited on
Commit
fdf7bd6
·
verified ·
1 Parent(s): c055e6e

Upload folder using huggingface_hub

Browse files
__pycache__/config.cpython-312.pyc ADDED
Binary file (1.33 kB). View file
 
__pycache__/engine.cpython-312.pyc ADDED
Binary file (17.8 kB). View file
 
__pycache__/tools.cpython-312.pyc ADDED
Binary file (851 Bytes). View file
 
__pycache__/ui.cpython-312.pyc ADDED
Binary file (7.69 kB). View file
 
__pycache__/utils.cpython-312.pyc ADDED
Binary file (4.03 kB). View file
 
app.py CHANGED
@@ -1,562 +1,15 @@
1
- import gradio as gr
2
- import os
3
- import json
4
- import torch
5
- import csv
6
- import shutil
7
- import time
8
- import threading
9
-
10
- from typing import Final, Optional, List, Any, Generator
11
- from pathlib import Path
12
- from dataclasses import dataclass
13
-
14
- from huggingface_hub import login
15
- from trl import SFTConfig, SFTTrainer
16
- from transformers import (
17
- AutoTokenizer,
18
- AutoModelForCausalLM,
19
- TrainerCallback,
20
- TrainingArguments,
21
- TrainerControl,
22
- TrainerState
23
- )
24
- from datasets import Dataset, load_dataset
25
-
26
- # --- Configuration ---
27
- class AppConfig:
28
- """
29
- Central configuration class.
30
- """
31
- ARTIFACTS_DIR: Final[Path] = Path("artifacts")
32
- ARTIFACTS_DIR.mkdir(parents=True, exist_ok=True)
33
-
34
- HF_TOKEN: Final[Optional[str]] = os.getenv('HF_TOKEN')
35
- MODEL_NAME: Final[str] = '../hf/270m'
36
- DEFAULT_DATASET: Final[str] = 'bebechien/SimpleToolCalling'
37
- OUTPUT_DIR: Final[Path] = ARTIFACTS_DIR.joinpath("functiongemma-modkit-demo")
38
-
39
-
40
- # --- Tool Definitions ---
41
- def search_knowledge_base(query: str) -> str:
42
- """
43
- Search internal company documents, policies and project data.
44
-
45
- Args:
46
- query: query string
47
- """
48
- return "Interal Result"
49
-
50
- def search_google(query: str) -> str:
51
- """
52
- Search public information.
53
-
54
- Args:
55
- query: query string
56
- """
57
- return "Public Result"
58
-
59
- search_knowledge_base_schema = {
60
- "type": "function",
61
- "function": {
62
- "name": "search_knowledge_base",
63
- "description": "Search internal company documents, policies and project data.",
64
- "parameters": {
65
- "type": "object",
66
- "properties": {
67
- "query": {
68
- "type": "string",
69
- "description": "query string"
70
- }
71
- },
72
- "required": [
73
- "query"
74
- ]
75
- },
76
- "return": {
77
- "type": "string"
78
- }
79
- }
80
- }
81
-
82
- search_google_schema = {
83
- "type": "function",
84
- "function": {
85
- "name": "search_google",
86
- "description": "Search public information.",
87
- "parameters": {
88
- "type": "object",
89
- "properties": {
90
- "query": {
91
- "type": "string",
92
- "description": "query string"
93
- }
94
- },
95
- "required": [
96
- "query"
97
- ]
98
- },
99
- "return": {
100
- "type": "string"
101
- }
102
- }
103
- }
104
-
105
- TOOLS = [search_knowledge_base_schema, search_google_schema]
106
- DEFAULT_SYSTEM_MSG = "You are a model that can do function calling with the following functions"
107
-
108
- # --- Callbacks ---
109
- class AbortCallback(TrainerCallback):
110
- """
111
- A custom callback to check a threading Event to stop training on user request.
112
- """
113
- def __init__(self, stop_event: threading.Event):
114
- self.stop_event = stop_event
115
-
116
- def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
117
- if self.stop_event.is_set():
118
- print("🛑 Stop signal received. Stopping training...")
119
- control.should_training_stop = True
120
-
121
-
122
- # --- Helper Functions ---
123
- def authenticate_hf(token: Optional[str]) -> None:
124
- """Logs into the Hugging Face Hub."""
125
- if token:
126
- print("Logging into Hugging Face Hub...")
127
- login(token=token)
128
- else:
129
- print("Skipping Hugging Face login: HF_TOKEN not set.")
130
-
131
- def load_model_and_tokenizer(model_name: str):
132
- print(f"Loading Transformer model: {model_name}")
133
- try:
134
- # Check if local path exists, otherwise treat as HF Hub ID
135
- if model_name.startswith("..") and not os.path.exists(model_name):
136
- print(f"Warning: Local path {model_name} not found. Falling back to default hub model.")
137
- model_name = "google/gemma-2b-it" # Fallback example
138
-
139
- tokenizer = AutoTokenizer.from_pretrained(model_name)
140
- model = AutoModelForCausalLM.from_pretrained(model_name)
141
- print("Model loaded successfully.")
142
- return model, tokenizer
143
- except Exception as e:
144
- print(f"Error loading Transformer model {model_name}: {e}")
145
- raise e
146
-
147
- def create_conversation_format(sample):
148
- """Formats a dataset row into the conversational format required for SFT."""
149
- try:
150
- tool_args = json.loads(sample["tool_arguments"])
151
- except (json.JSONDecodeError, TypeError):
152
- tool_args = {}
153
-
154
- return {
155
- "messages": [
156
- {"role": "developer", "content": DEFAULT_SYSTEM_MSG},
157
- {"role": "user", "content": sample["user_content"]},
158
- {"role": "assistant", "tool_calls": [{"type": "function", "function": {"name": sample["tool_name"], "arguments": tool_args}}]},
159
- ],
160
- "tools": TOOLS
161
- }
162
-
163
-
164
- # --- Main Application Logic ---
165
- class FunctionGemmaTuner:
166
- def __init__(self, config: AppConfig = AppConfig):
167
- self.config = config
168
- self.model = None
169
- self.tokenizer = None
170
- self.imported_dataset = []
171
-
172
- # Threading event to control stopping
173
- self.stop_event = threading.Event()
174
-
175
- authenticate_hf(self.config.HF_TOKEN)
176
-
177
- # Initial load attempt
178
- print("--- Running Initial Data Load ---")
179
- try:
180
- self.refresh_data_and_model()
181
- print("--- Initial Load Complete ---")
182
- except Exception as e:
183
- print(f"Initial load failed (this is common if model path is invalid): {e}")
184
-
185
- def refresh_data_and_model(self):
186
- """Reloads the model and clears imported data."""
187
- print("\n" + "=" * 50)
188
- print("RELOADING MODEL and RE-FETCHING DATA")
189
-
190
- self.imported_dataset = []
191
-
192
- try:
193
- self.model, self.tokenizer = load_model_and_tokenizer(self.config.MODEL_NAME)
194
- status_value = "Model and data reloaded. Ready."
195
- except Exception as e:
196
- self.model = None
197
- self.tokenizer = None
198
- status_value = f"CRITICAL ERROR: Model failed to load. {e}"
199
- # We don't raise here to allow the UI to render the error message
200
-
201
- return status_value
202
-
203
- def import_additional_dataset(self, file_path: str) -> str:
204
- """Parses an uploaded CSV file."""
205
- if not file_path:
206
- return "Please upload a CSV file."
207
-
208
- new_dataset = []
209
- num_imported = 0
210
-
211
- try:
212
- # Open file handle properly
213
- with open(file_path, 'r', newline='', encoding='utf-8') as f:
214
- reader = csv.reader(f)
215
-
216
- # Basic header validation
217
- try:
218
- header = next(reader)
219
- # Simple heuristic check, allows skipping header or rewinding
220
- if not (header and "anchor" in header[0].lower()):
221
- f.seek(0)
222
- except StopIteration:
223
- return "Error: Uploaded file is empty."
224
-
225
- for row in reader:
226
- # Expecting: [User Prompt, Tool Name, Tool Args JSON/String]
227
- if len(row) >= 3:
228
- new_dataset.append([s.strip() for s in row[:3]])
229
- num_imported += 1
230
-
231
- if num_imported == 0:
232
- return "No valid rows found. CSV format: [Anchor, Positive, Negative]"
233
-
234
- self.imported_dataset = new_dataset
235
- return f"Successfully imported {num_imported} additional training samples."
236
-
237
- except Exception as e:
238
- return f"Import failed. Error: {e}"
239
-
240
- def stop_training(self):
241
- """Signal the training loop to stop."""
242
- print("Set stop event")
243
- self.stop_event.set()
244
- return "Stopping initiated... please wait for the current step to finish."
245
-
246
- def run_training(self, test_size: float = 0.5) -> Generator[str, None, None]:
247
- """
248
- Main training logic. Yields status strings to the UI.
249
- """
250
- # 1. Validation
251
- if self.model is None:
252
- yield "Training failed: Model is not loaded."
253
- return
254
-
255
- self.stop_event.clear() # Reset stop flag
256
- yield "⏳ Preparing Dataset..."
257
-
258
- # 2. Dataset Preparation
259
- if not self.imported_dataset:
260
- print("No imported dataset, using default HF dataset")
261
- try:
262
- dataset = load_dataset(self.config.DEFAULT_DATASET, split="train")
263
- except Exception as e:
264
- yield f"Error loading default dataset: {e}"
265
- return
266
- else:
267
- dataset_as_dicts = [{
268
- "user_content": row[0], "tool_name": row[1], "tool_arguments": row[2]}
269
- for row in self.imported_dataset
270
- ]
271
- dataset = Dataset.from_list(dataset_as_dicts)
272
-
273
- # Apply formatting
274
- dataset = dataset.map(create_conversation_format, batched=False)
275
-
276
- # Split
277
- if len(dataset) > 1:
278
- dataset = dataset.train_test_split(test_size=test_size, shuffle=False)
279
- else:
280
- # Fallback for very small datasets (mostly for debugging)
281
- dataset = {"train": dataset, "test": dataset}
282
-
283
- output_buffer = "📊 Evaluating Pre-Training Success Rate...\n### Success Rate (Before Training):\n"
284
- yield output_buffer
285
- pre_training_report = ""
286
- gen = self.check_success_rate(dataset["test"])
287
- while not self.stop_event.is_set():
288
- try:
289
- pre_training_report += f"{next(gen)}\n"
290
- yield f"{output_buffer}{pre_training_report}"
291
- except StopIteration as e:
292
- pre_training_report = e.value
293
- break
294
-
295
- if self.stop_event.is_set():
296
- output_buffer += f"{pre_training_report}\n\n🛑 Manual Eval interrupted by user.\n"
297
- yield output_buffer
298
- return
299
-
300
- output_buffer += f"{pre_training_report}\n\n"
301
- output_buffer += "-" * 30 + "\nStarting Fine-tuning...\n"
302
- yield output_buffer
303
-
304
- # 3. Training Setup
305
- torch_dtype = self.model.dtype
306
-
307
- args = SFTConfig(
308
- output_dir=str(self.config.OUTPUT_DIR),
309
- max_length=512,
310
- packing=False,
311
- num_train_epochs=5,
312
- per_device_train_batch_size=4,
313
- gradient_checkpointing=False,
314
- optim="adamw_torch_fused",
315
- logging_steps=1,
316
- save_strategy="no", # Speed up demo
317
- eval_strategy="epoch",
318
- learning_rate=5e-5,
319
- fp16=True if torch_dtype == torch.float16 else False,
320
- bf16=True if torch_dtype == torch.bfloat16 else False,
321
- lr_scheduler_type="constant",
322
- push_to_hub=False,
323
- report_to="none",
324
- dataset_kwargs={
325
- "add_special_tokens": False,
326
- "append_concat_token": True,
327
- }
328
- )
329
-
330
- trainer = SFTTrainer(
331
- model=self.model,
332
- args=args,
333
- train_dataset=dataset['train'],
334
- eval_dataset=dataset['test'],
335
- processing_class=self.tokenizer,
336
- callbacks=[AbortCallback(self.stop_event)] # Inject our stopper
337
- )
338
-
339
- # 4. Run Training
340
- try:
341
- output_buffer += "🚀 Training in progress... (Click Stop to interrupt)\n"
342
- yield output_buffer
343
- trainer.train()
344
-
345
- if self.stop_event.is_set():
346
- output_buffer += "\n🛑 Training interrupted by user.\n"
347
- else:
348
- output_buffer += "\n✅ Training finished. Model weights updated in memory.\n"
349
- yield output_buffer
350
-
351
- # Save locally
352
- trainer.save_model()
353
- output_buffer += f"Model saved locally to: {self.config.OUTPUT_DIR}\n"
354
- yield output_buffer
355
-
356
- except Exception as e:
357
- output_buffer += f"\n❌ Error during training: {e}\n"
358
- yield output_buffer
359
- return
360
-
361
- if self.stop_event.is_set():
362
- return
363
-
364
- # 5. Post-Evaluation
365
- output_buffer += "📊 Evaluating Post-Training Success Rate...\n"
366
- post_report = ""
367
- yield output_buffer
368
- gen = self.check_success_rate(dataset["test"])
369
- while not self.stop_event.is_set():
370
- try:
371
- post_report += f"{next(gen)}\n"
372
- yield f"{output_buffer}{post_report}"
373
- except StopIteration as e:
374
- post_report = e.value
375
- break
376
-
377
- if self.stop_event.is_set():
378
- output_buffer += f"{post_report}\n\n🛑 Manual Eval interrupted by user.\n"
379
- yield output_buffer
380
- return
381
-
382
- output_buffer += f"{post_report}\n\n"
383
- yield output_buffer
384
-
385
- def check_success_rate(self, test_dataset):
386
- """Runs inference on test set to calculate accuracy."""
387
- results = []
388
- success_count = 0
389
- total = len(test_dataset)
390
-
391
- for idx, item in enumerate(test_dataset):
392
- if idx >= 5:
393
- break
394
- if self.stop_event.is_set():
395
- break
396
-
397
- messages = [item["messages"][0], item["messages"][1]] # System + User
398
-
399
- try:
400
- inputs = self.tokenizer.apply_chat_template(
401
- messages,
402
- tools=TOOLS,
403
- add_generation_prompt=True,
404
- return_dict=True,
405
- return_tensors="pt"
406
- )
407
-
408
- out = self.model.generate(
409
- **inputs.to(self.model.device),
410
- pad_token_id=self.tokenizer.eos_token_id,
411
- max_new_tokens=128
412
- )
413
-
414
- # Decode only the new tokens
415
- output = self.tokenizer.decode(out[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True)
416
-
417
- results.append(f"{idx+1}. Prompt: {item['messages'][1]['content']}")
418
- yield results[-1]
419
- results.append(f" Output: {output[:100]}...")
420
- yield results[-1]
421
-
422
- # Check for correct tool name usage
423
- expected_tool = item['messages'][2]['tool_calls'][0]['function']['name']
424
- if expected_tool in output:
425
- results.append(" -> ✅ Correct Tool")
426
- yield results[-1]
427
- success_count += 1
428
- else:
429
- results.append(f" -> ❌ Wrong Tool (Expected: {expected_tool})")
430
- yield results[-1]
431
-
432
- except Exception as e:
433
- results.append(f" -> Error: {e}")
434
- yield results[-1]
435
-
436
- summary = "\n".join(results)
437
- summary += f"\n\nTotal Success : {success_count} / {len(test_dataset)}"
438
- return summary
439
-
440
- def download_model_zip(self) -> Optional[str]:
441
- """Zips the output directory for download."""
442
- if not os.path.exists(self.config.OUTPUT_DIR):
443
- return None
444
-
445
- timestamp = int(time.time())
446
- try:
447
- base_name = self.config.ARTIFACTS_DIR.joinpath(f"functiongemma_finetuned_{timestamp}")
448
- archive_path = shutil.make_archive(
449
- base_name=str(base_name),
450
- format='zip',
451
- root_dir=str(self.config.OUTPUT_DIR),
452
- )
453
- return archive_path
454
- except Exception as e:
455
- print(f"Zip failed: {e}")
456
- return None
457
-
458
- # --- UI Builder ---
459
- def build_interface(self) -> gr.Blocks:
460
- with gr.Blocks(title="FunctionGemma Modkit") as demo:
461
- gr.Markdown("# 🤖 FunctionGemma Modkit: Fine-Tuning")
462
- gr.Markdown("Fine-tune FunctionGemma to understand your custom functions.")
463
-
464
- with gr.Column():
465
- gr.Markdown("## 1. Training Controls")
466
-
467
- with gr.Row():
468
- run_training_btn = gr.Button("🚀 Run Fine-Tuning", variant="primary")
469
- stop_training_btn = gr.Button("🛑 Stop Training", variant="stop", visible=False)
470
-
471
- output_display = gr.Textbox(
472
- lines=14,
473
- label="Training Logs & Search Results",
474
- value="Ready. Click 'Run' to begin.",
475
- interactive=False
476
- )
477
-
478
- clear_reload_btn = gr.Button("🔄 Reset Model & Data")
479
-
480
- gr.Markdown("--- \n ## 2. Data Management")
481
- import_file = gr.File(label="Upload Additional Dataset (.csv)", file_types=[".csv"], height=80)
482
- import_status = gr.Markdown("")
483
-
484
- gr.Markdown("--- \n ## 3. Export")
485
- with gr.Row():
486
- zip_btn = gr.Button("⬇️ Prepare Model ZIP")
487
- download_file = gr.File(label="Download ZIP", height=80, visible=True, interactive=False)
488
-
489
- # --- Event Wiring ---
490
-
491
- # Start Training (Generator updates output_display)
492
- run_training_btn.click(
493
- fn=lambda: (
494
- gr.update(visible=False),
495
- gr.update(interactive=False),
496
- gr.update(visible=True)
497
- ),
498
- inputs=None,
499
- outputs=[run_training_btn, clear_reload_btn, stop_training_btn]
500
- ).then(
501
- fn=self.run_training,
502
- inputs=[],
503
- outputs=[output_display],
504
- ).then(
505
- fn=lambda: (
506
- gr.update(visible=True),
507
- gr.update(interactive=True),
508
- gr.update(visible=False)
509
- ),
510
- inputs=None,
511
- outputs=[run_training_btn, clear_reload_btn, stop_training_btn]
512
- )
513
-
514
- # Stop Training
515
- stop_training_btn.click(
516
- fn=self.stop_training,
517
- inputs=None,
518
- outputs=None # We don't need to return anything, status updates via the training generator
519
- ).then(
520
- fn=lambda: (
521
- gr.update(visible=True),
522
- gr.update(interactive=True),
523
- gr.update(visible=False)
524
- ),
525
- inputs=None,
526
- outputs=[run_training_btn, clear_reload_btn, stop_training_btn]
527
- )
528
-
529
- # Reload
530
- clear_reload_btn.click(
531
- fn=self.refresh_data_and_model,
532
- inputs=None,
533
- outputs=[output_display]
534
- )
535
-
536
- # File Import
537
- import_file.upload(
538
- fn=self.import_additional_dataset,
539
- inputs=[import_file],
540
- outputs=[import_status]
541
- )
542
-
543
- # Download Logic
544
- def handle_zip():
545
- path = self.download_model_zip()
546
- if path:
547
- return gr.update(value=path, visible=True)
548
- return gr.update(value=None, visible=False)
549
-
550
- zip_btn.click(
551
- fn=handle_zip,
552
- inputs=None,
553
- outputs=[download_file]
554
- )
555
-
556
- return demo
557
 
558
  if __name__ == "__main__":
559
- app = FunctionGemmaTuner(AppConfig)
560
- demo = app.build_interface()
 
 
 
 
 
 
561
  print("Starting Gradio App...")
562
  demo.launch()
 
1
+ from config import AppConfig
2
+ from engine import FunctionGemmaEngine
3
+ from ui import build_interface
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
  if __name__ == "__main__":
6
+ # Initialize Config
7
+ config = AppConfig()
8
+
9
+ # Initialize Logic Engine
10
+ app_engine = FunctionGemmaEngine(config)
11
+
12
+ # Build and Launch UI
13
+ demo = build_interface(app_engine)
14
  print("Starting Gradio App...")
15
  demo.launch()
config.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+ from typing import Final, Optional
4
+ from dataclasses import dataclass
5
+
6
+ @dataclass
7
+ class AppConfig:
8
+ """
9
+ Central configuration class.
10
+ """
11
+ # Directory Setup
12
+ ARTIFACTS_DIR: Final[Path] = Path("artifacts")
13
+ OUTPUT_DIR: Final[Path] = ARTIFACTS_DIR.joinpath("functiongemma-modkit-demo")
14
+
15
+ # Model & Data
16
+ HF_TOKEN: Final[Optional[str]] = os.getenv('HF_TOKEN')
17
+ # Defaulting to a real model ID for safety, original was local path '../hf/270m'
18
+ MODEL_NAME: Final[str] = '../hf/270m'
19
+ DEFAULT_DATASET: Final[str] = 'bebechien/SimpleToolCalling'
20
+
21
+ def __post_init__(self):
22
+ self.ARTIFACTS_DIR.mkdir(parents=True, exist_ok=True)
engine.py ADDED
@@ -0,0 +1,365 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import threading
2
+ import torch
3
+ import time
4
+ import json
5
+ import queue
6
+ import matplotlib.pyplot as plt
7
+ from functools import partial
8
+ from typing import Generator, Optional, List, Dict
9
+ from datasets import Dataset, load_dataset
10
+ from trl import SFTConfig, SFTTrainer
11
+ from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl
12
+
13
+ from config import AppConfig
14
+ from tools import DEFAULT_TOOLS
15
+ from utils import (
16
+ authenticate_hf,
17
+ load_model_and_tokenizer,
18
+ create_conversation_format,
19
+ parse_csv_dataset,
20
+ zip_directory
21
+ )
22
+
23
+ class AbortCallback(TrainerCallback):
24
+ def __init__(self, stop_event: threading.Event):
25
+ self.stop_event = stop_event
26
+
27
+ def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
28
+ if self.stop_event.is_set():
29
+ control.should_training_stop = True
30
+
31
+ class LogStreamingCallback(TrainerCallback):
32
+ """
33
+ NEW: Intercepts training logs and pushes them to a queue
34
+ so the main thread can display them in the UI.
35
+ """
36
+ def __init__(self, log_queue: queue.Queue):
37
+ self.log_queue = log_queue
38
+
39
+ def _get_string(self, value):
40
+ if isinstance(value, float):
41
+ return f"{value:.4f}"
42
+ return str(value)
43
+
44
+ def on_log(self, args, state, control, logs=None, **kwargs):
45
+ if not logs:
46
+ return
47
+
48
+ metrics_map = {
49
+ "loss": "Loss",
50
+ "eval_loss": "Eval Loss",
51
+ "learning_rate": "LR",
52
+ "epoch": "Epoch"
53
+ }
54
+ log_parts = [f"📝 [Step {state.global_step}]"]
55
+
56
+ for key, label in metrics_map.items():
57
+ if key in logs:
58
+ val = logs[key]
59
+ # Format floats: use scientific notation for very small numbers (like LR)
60
+ if isinstance(val, (float, int)):
61
+ val_str = f"{val:.4f}" if val > 1e-4 else f"{val:.2e}"
62
+ else:
63
+ val_str = str(val)
64
+
65
+ log_parts.append(f"{label}: {val_str}")
66
+
67
+ self.log_queue.put(" | ".join(log_parts))
68
+
69
+ class FunctionGemmaEngine:
70
+ def __init__(self, config: AppConfig):
71
+ self.config = config
72
+ self.model = None
73
+ self.tokenizer = None
74
+ self.imported_dataset = []
75
+ self.stop_event = threading.Event()
76
+
77
+ # NEW: State for tools
78
+ self.current_tools = DEFAULT_TOOLS
79
+
80
+ authenticate_hf(self.config.HF_TOKEN)
81
+ try:
82
+ self.refresh_data_and_model()
83
+ except Exception as e:
84
+ print(f"Initial load warning: {e}")
85
+
86
+ # NEW: Methods to handle Tool Schema updates
87
+ def get_tools_json(self) -> str:
88
+ return json.dumps(self.current_tools, indent=2)
89
+
90
+ def update_tools(self, json_str: str) -> str:
91
+ try:
92
+ new_tools = json.loads(json_str)
93
+ if not isinstance(new_tools, list):
94
+ return "Error: Schema must be a list of tool definitions."
95
+ self.current_tools = new_tools
96
+ return "✅ Tool Schema Updated successfully."
97
+ except json.JSONDecodeError as e:
98
+ return f"❌ JSON Error: {e}"
99
+ except Exception as e:
100
+ return f"❌ Error: {e}"
101
+
102
+ def refresh_data_and_model(self) -> str:
103
+ self.imported_dataset = []
104
+ try:
105
+ self.model, self.tokenizer = load_model_and_tokenizer(self.config.MODEL_NAME)
106
+ return "Model and data reloaded. Ready."
107
+ except Exception as e:
108
+ self.model = None
109
+ self.tokenizer = None
110
+ return f"CRITICAL ERROR: Model failed to load. {e}"
111
+
112
+ def load_csv(self, file_path: str) -> str:
113
+ try:
114
+ new_data = parse_csv_dataset(file_path)
115
+ if not new_data:
116
+ return "Error: File empty or format invalid."
117
+ self.imported_dataset = new_data
118
+ return f"Successfully imported {len(new_data)} samples."
119
+ except Exception as e:
120
+ return f"Import failed: {e}"
121
+
122
+ def trigger_stop(self):
123
+ self.stop_event.set()
124
+
125
+ def run_training_pipeline(self, epochs: int, learning_rate: float, test_size: float, shuffle_data: bool) -> Generator[str, None, None]:
126
+ if self.model is None:
127
+ yield "Training failed: Model is not loaded.", None
128
+ return
129
+
130
+ self.stop_event.clear()
131
+ output_buffer = f"⏳ Preparing Dataset (Test Split: {test_size}, Shuffle: {shuffle_data})...\n"
132
+ yield output_buffer, None
133
+
134
+ dataset, log = self._prepare_dataset()
135
+ if not dataset:
136
+ yield "Dataset creation failed.", None
137
+ return
138
+
139
+ output_buffer += log
140
+ yield output_buffer, None
141
+
142
+ if len(dataset) > 1:
143
+ dataset = dataset.train_test_split(test_size=test_size, shuffle=shuffle_data)
144
+ else:
145
+ dataset = {"train": dataset, "test": dataset}
146
+
147
+ # --- Phase 1: Pre-Training Eval ---
148
+ output_buffer += "\n📊 Evaluating Pre-Training Success Rate...\n"
149
+ yield output_buffer, None
150
+
151
+ pre_training_report = ""
152
+ for update in self._evaluate_model(dataset["test"]):
153
+ pre_training_report = update
154
+ if self.stop_event.is_set():
155
+ pre_training_report += "\n\n🛑 Manual Eval interrupted by user.\n"
156
+ yield f"{output_buffer}{pre_training_report}", None
157
+ break
158
+ yield f"{output_buffer}{pre_training_report}", None
159
+
160
+ if self.stop_event.is_set(): return
161
+ output_buffer += pre_training_report
162
+
163
+ # --- Phase 2: Training (Threaded) ---
164
+ output_buffer += "\n\n🚀 Starting Fine-tuning (Epochs: {epochs}, LR: {learning_rate})...\n"
165
+ yield output_buffer, None
166
+
167
+ log_queue = queue.Queue()
168
+ training_error = None
169
+ training_history = []
170
+
171
+ # Function to run in the thread
172
+ def train_wrapper():
173
+ nonlocal training_error, training_history
174
+ try:
175
+ training_history = self._execute_trainer(dataset, log_queue, epochs, learning_rate)
176
+ except Exception as e:
177
+ training_error = e
178
+
179
+ # Start training thread
180
+ train_thread = threading.Thread(target=train_wrapper)
181
+ train_thread.start()
182
+
183
+ # Monitor loop: Yields logs while training runs
184
+ while train_thread.is_alive():
185
+ # Drain the queue
186
+ while not log_queue.empty():
187
+ log_msg = log_queue.get()
188
+ output_buffer += f"{log_msg}\n"
189
+ yield output_buffer, None
190
+
191
+ # Check for stop signal
192
+ if self.stop_event.is_set():
193
+ yield f"{output_buffer}🛑 Stop signal sent. Waiting for trainer to wrap up...\n", None
194
+ # We don't break here, we wait for thread to finish cleanly
195
+
196
+ time.sleep(0.1) # Prevent CPU spinning
197
+
198
+ train_thread.join() # Ensure thread is completely done
199
+
200
+ # Flush any remaining logs
201
+ while not log_queue.empty():
202
+ log_msg = log_queue.get()
203
+ output_buffer += f"{log_msg}\n"
204
+ yield output_buffer, None
205
+
206
+ if training_error:
207
+ output_buffer += f"❌ Error during training: {training_error}\n"
208
+ yield output_buffer, None
209
+ return
210
+
211
+ if self.stop_event.is_set():
212
+ output_buffer += "🛑 Training manually stopped.\n"
213
+ yield output_buffer, None
214
+ return
215
+
216
+ output_buffer += "✅ Training finished.\n"
217
+ yield output_buffer, None
218
+
219
+ output_buffer += "\n📈 Generating Loss Plot...\n"
220
+ yield output_buffer, None
221
+
222
+ try:
223
+ final_plot = self._generate_loss_plot(training_history)
224
+ yield output_buffer, final_plot
225
+ except Exception as e:
226
+ output_buffer += f"⚠️ Could not generate plot: {e}\n"
227
+ yield output_buffer, None
228
+
229
+ # --- Phase 3: Post-Training Eval ---
230
+ output_buffer += "\n📊 Evaluating Post-Training Success Rate...\n"
231
+ yield output_buffer, final_plot
232
+
233
+ post_training_report = ""
234
+ for update in self._evaluate_model(dataset["test"]):
235
+ post_training_report = update
236
+ if self.stop_event.is_set():
237
+ post_training_report += "\n\n🛑 Manual Eval interrupted by user.\n"
238
+ yield f"{output_buffer}{post_training_report}", final_plot
239
+ break
240
+ yield f"{output_buffer}{post_training_report}", final_plot
241
+
242
+ def _prepare_dataset(self):
243
+ # NEW: Use partial to inject self.current_tools into the formatting function
244
+ formatting_fn = partial(create_conversation_format, tools_list=self.current_tools)
245
+
246
+ if not self.imported_dataset:
247
+ ds = load_dataset(self.config.DEFAULT_DATASET, split="train").map(formatting_fn)
248
+ log = f" `-> using default dataset (size:{len(ds)})\n"
249
+ else:
250
+ dataset_as_dicts = [{
251
+ "user_content": row[0], "tool_name": row[1], "tool_arguments": row[2]}
252
+ for row in self.imported_dataset
253
+ ]
254
+ ds = Dataset.from_list(dataset_as_dicts).map(formatting_fn)
255
+ log = f" `-> using custom dataset (size:{len(ds)})\n"
256
+ return ds, log
257
+
258
+ def _execute_trainer(self, dataset, log_queue: queue.Queue, epochs: int, learning_rate: float) -> List[Dict]:
259
+ torch_dtype = self.model.dtype
260
+ args = SFTConfig(
261
+ output_dir=str(self.config.OUTPUT_DIR),
262
+ max_length=512,
263
+ packing=False,
264
+ num_train_epochs=epochs,
265
+ per_device_train_batch_size=4,
266
+ logging_steps=1,
267
+ save_strategy="no",
268
+ eval_strategy="epoch",
269
+ learning_rate=learning_rate,
270
+ fp16=(torch_dtype == torch.float16),
271
+ bf16=(torch_dtype == torch.bfloat16),
272
+ report_to="none",
273
+ dataset_kwargs={"add_special_tokens": False, "append_concat_token": True}
274
+ )
275
+
276
+ trainer = SFTTrainer(
277
+ model=self.model,
278
+ args=args,
279
+ train_dataset=dataset['train'],
280
+ eval_dataset=dataset['test'],
281
+ processing_class=self.tokenizer,
282
+ callbacks=[
283
+ AbortCallback(self.stop_event),
284
+ LogStreamingCallback(log_queue)
285
+ ]
286
+ )
287
+ trainer.train()
288
+ trainer.save_model()
289
+
290
+ return trainer.state.log_history
291
+
292
+ def _generate_loss_plot(self, history: list):
293
+ if not history:
294
+ return None
295
+
296
+ # Extract Training Loss
297
+ # log_history format: [{'loss': 0.5, 'step': 1}, {'eval_loss': 0.4, 'step': 1}, ...]
298
+ train_steps = [x['step'] for x in history if 'loss' in x]
299
+ train_loss = [x['loss'] for x in history if 'loss' in x]
300
+
301
+ # Extract Validation Loss
302
+ eval_steps = [x['step'] for x in history if 'eval_loss' in x]
303
+ eval_loss = [x['eval_loss'] for x in history if 'eval_loss' in x]
304
+
305
+ fig, ax = plt.subplots(figsize=(10, 5))
306
+
307
+ if train_steps:
308
+ ax.plot(train_steps, train_loss, label='Training Loss', linestyle='-', marker=None)
309
+
310
+ if eval_steps:
311
+ ax.plot(eval_steps, eval_loss, label='Validation Loss', linestyle='--', marker='o')
312
+
313
+ ax.set_xlabel("Steps")
314
+ ax.set_ylabel("Loss")
315
+ ax.set_title("Training & Validation Loss")
316
+ ax.legend()
317
+ ax.grid(True, linestyle=':', alpha=0.6)
318
+
319
+ plt.tight_layout()
320
+ return fig
321
+
322
+ def _evaluate_model(self, test_dataset) -> Generator[str, None, None]:
323
+ results = []
324
+ success_count = 0
325
+
326
+ for idx, item in enumerate(test_dataset):
327
+ messages = item["messages"][:2]
328
+ try:
329
+ # NEW: Pass self.current_tools to the template
330
+ inputs = self.tokenizer.apply_chat_template(
331
+ messages, tools=self.current_tools, add_generation_prompt=True, return_dict=True, return_tensors="pt"
332
+ )
333
+
334
+ device = self.model.device
335
+ inputs = {k: v.to(device) for k, v in inputs.items()}
336
+
337
+ out = self.model.generate(
338
+ **inputs,
339
+ pad_token_id=self.tokenizer.eos_token_id,
340
+ max_new_tokens=128
341
+ )
342
+ output = self.tokenizer.decode(out[0][len(inputs["input_ids"][0]):], skip_special_tokens=True)
343
+
344
+ log_entry = f"{idx+1}. Prompt: {messages[1]['content']}\n Output: {output[:100]}..."
345
+
346
+ # Check tool correctness
347
+ expected_tool = item['messages'][2]['tool_calls'][0]['function']['name']
348
+ if expected_tool in output:
349
+ log_entry += "\n -> ✅ Correct Tool"
350
+ success_count += 1
351
+ else:
352
+ log_entry += f"\n -> ❌ Wrong Tool (Expected: {expected_tool})"
353
+
354
+ results.append(log_entry)
355
+ yield "\n".join(results) + f"\n\nRunning Success Rate: {success_count}/{idx+1}"
356
+
357
+ except Exception as e:
358
+ yield f"Error during inference: {e}"
359
+
360
+ def get_zip_path(self) -> Optional[str]:
361
+ if not self.config.OUTPUT_DIR.exists():
362
+ return None
363
+ timestamp = int(time.time())
364
+ base_name = str(self.config.ARTIFACTS_DIR.joinpath(f"functiongemma_finetuned_{timestamp}"))
365
+ return zip_directory(str(self.config.OUTPUT_DIR), base_name)
requirements.txt CHANGED
@@ -1,5 +1,6 @@
1
  accelerate
2
  datasets
3
  gradio
 
4
  transformers
5
  trl
 
1
  accelerate
2
  datasets
3
  gradio
4
+ matplotlib
5
  transformers
6
  trl
tools.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# --- Tool Definitions ---
# JSON-schema descriptions of the tools the model should learn to call.
# This module holds only the schemas (what matters for the LLM); it does not
# contain Python implementations of the tools.


def _single_query_tool_schema(name, description):
    """Build the schema of a tool that takes one required string ``query``.

    Args:
        name: Function name the model is expected to emit.
        description: Human-readable description of what the tool searches.

    Returns:
        A fresh dict (nothing is shared between calls) in the
        ``{"type": "function", "function": {...}}`` layout.
    """
    return {
        "type": "function",
        "function": {
            "name": name,
            "description": description,
            "parameters": {
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "query string",
                    }
                },
                "required": ["query"],
            },
            "return": {"type": "string"},
        },
    }


search_knowledge_base_schema = _single_query_tool_schema(
    "search_knowledge_base",
    "Search internal company documents, policies and project data.",
)

search_google_schema = _single_query_tool_schema(
    "search_google",
    "Search public information.",
)

# Named DEFAULT_TOOLS to signal that this is only the initial tool list.
DEFAULT_TOOLS = [search_knowledge_base_schema, search_google_schema]
DEFAULT_SYSTEM_MSG = "You are a model that can do function calling with the following functions"
ui.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from engine import FunctionGemmaEngine
3
+
4
def build_interface(engine: FunctionGemmaEngine) -> gr.Blocks:
    """Build the three-tab Gradio UI and wire it to ``engine``.

    Tabs: (1) edit the tool schema and upload a CSV dataset, (2) configure
    and run fine-tuning while streaming logs and a loss plot, (3) export the
    trained model as a ZIP archive.

    Args:
        engine: Backend object. This function uses its ``get_tools_json``,
            ``update_tools``, ``load_csv``, ``run_training_pipeline``,
            ``trigger_stop``, ``refresh_data_and_model`` and
            ``get_zip_path`` methods as event handlers.

    Returns:
        The assembled ``gr.Blocks`` app (not yet launched).
    """
    with gr.Blocks(title="FunctionGemma Modkit") as demo:
        gr.Markdown("# 🤖 FunctionGemma Modkit: Fine-Tuning")
        gr.Markdown("Fine-tune FunctionGemma to understand your custom functions.")

        with gr.Tabs():

            # --- TAB 1: PREPARING DATASET ---
            with gr.TabItem("1. Preparing Dataset"):
                gr.Markdown("### 🛠️ Tool Schema & Data Import")

                with gr.Row():
                    with gr.Column(scale=1):
                        gr.Markdown("**Step 1: Define Functions**\n\nEdit the JSON schema below to define the tools the model should learn.")
                        tools_editor = gr.Code(
                            value=engine.get_tools_json(),
                            language="json",
                            label="Tool Definitions (JSON Schema)",
                            lines=15,
                        )
                        update_tools_btn = gr.Button("💾 Update Tool Schema")
                        tools_status = gr.Markdown("")

                    with gr.Column(scale=1):
                        gr.Markdown("**Step 2: Upload Data (Optional)**\n\nUpload a CSV file to replace the default dataset.\nFormat: `[User Prompt, Tool Name, Tool Args JSON]`")
                        import_file = gr.File(
                            label="Upload Dataset (.csv)",
                            file_types=[".csv"],
                            height=100,
                        )
                        import_status = gr.Markdown("")

            # --- TAB 2: TRAINING ---
            with gr.TabItem("2. Training"):
                gr.Markdown("### 🚀 Fine-Tuning Configuration")

                with gr.Group():
                    gr.Markdown("**Hyperparameters**")
                    with gr.Row():
                        param_epochs = gr.Slider(
                            minimum=1, maximum=20, value=5, step=1,
                            label="Epochs", info="Total training passes",
                        )
                        param_lr = gr.Number(
                            value=5e-5,
                            label="Learning Rate",
                            info="e.g. 5e-5",
                        )
                        param_test_size = gr.Slider(
                            minimum=0.1, maximum=0.9, value=0.2, step=0.05,
                            label="Test Split", info="Validation data ratio. Typical value is 0.2 (80% for training, 20% for testing)",
                        )
                        param_shuffle = gr.Checkbox(
                            value=True,
                            label="Shuffle Data",
                            info="Randomize before split",
                        )

                with gr.Row():
                    run_training_btn = gr.Button("🚀 Run Fine-Tuning", variant="primary", scale=2)
                    stop_training_btn = gr.Button("🛑 Stop", variant="stop", visible=False, scale=1)
                    clear_reload_btn = gr.Button("🔄 Reset", variant="secondary", scale=1)

                with gr.Row():
                    # Left column: text logs.
                    output_display = gr.Textbox(
                        lines=20,
                        label="Logs & Results",
                        value="Ready.",
                        interactive=False,
                        autoscroll=True,
                    )
                    # Right column: loss plot.
                    loss_plot = gr.Plot(label="Training Metrics")

            # --- TAB 3: EXPORT ---
            with gr.TabItem("3. Export"):
                gr.Markdown("### 📦 Export Trained Model")
                gr.Markdown("Download the fine-tuned LoRA adapters or full model weights (depending on configuration) as a ZIP file.")

                with gr.Row():
                    zip_btn = gr.Button("⬇️ Prepare Model ZIP", variant="primary", scale=1)
                    download_file = gr.File(label="Download Archive", interactive=False, scale=2)

        # --- EVENT WIRING ---

        # Tab 1: Tools
        update_tools_btn.click(
            fn=engine.update_tools,
            inputs=[tools_editor],
            outputs=[tools_status],
        )

        # Tab 1: File Import
        import_file.upload(
            fn=engine.load_csv,
            inputs=[import_file],
            outputs=[import_status],
        )

        # Tab 2: Training
        training_controls = [run_training_btn, clear_reload_btn, stop_training_btn]

        def enter_training_state():
            """Hide Run, disable Reset and show Stop while training runs."""
            return (
                gr.update(visible=False),
                gr.update(interactive=False),
                gr.update(visible=True),
            )

        def leave_training_state():
            """Show Run, re-enable Reset and hide Stop once training ends."""
            return (
                gr.update(visible=True),
                gr.update(interactive=True),
                gr.update(visible=False),
            )

        run_training_btn.click(
            fn=enter_training_state,
            outputs=training_controls,
        ).then(
            fn=engine.run_training_pipeline,
            inputs=[param_epochs, param_lr, param_test_size, param_shuffle],
            outputs=[output_display, loss_plot],
        ).then(
            fn=leave_training_state,
            outputs=training_controls,
        )

        # Tab 2: Stop
        def handle_stop() -> None:
            """Ask the engine to stop training.

            The handler has no output component, so it deliberately returns
            nothing; a returned status string would never be displayed.
            """
            engine.trigger_stop()

        stop_training_btn.click(
            fn=handle_stop,
            outputs=None,
        )

        # Tab 2: Reset
        clear_reload_btn.click(
            fn=engine.refresh_data_and_model,
            outputs=[output_display],
        )

        # Tab 3: Download
        def handle_zip():
            """Create the archive and reveal it, or hide the file widget."""
            path = engine.get_zip_path()
            if path:
                return gr.update(value=path, visible=True)
            return gr.update(value=None, visible=False)

        zip_btn.click(
            fn=handle_zip,
            outputs=[download_file],
        )

    return demo
utils.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import csv
3
+ import json
4
+ import shutil
5
+ from typing import Optional, List, Any
6
+ from huggingface_hub import login
7
+ from transformers import AutoTokenizer, AutoModelForCausalLM
8
+ from tools import DEFAULT_SYSTEM_MSG
9
+ # Note: We do NOT import TOOLS here anymore to avoid stale data
10
+
11
+ def authenticate_hf(token: Optional[str]) -> None:
12
+ """Logs into the Hugging Face Hub."""
13
+ if token:
14
+ print("Logging into Hugging Face Hub...")
15
+ login(token=token)
16
+ else:
17
+ print("Skipping Hugging Face login: HF_TOKEN not set.")
18
+
19
+ def load_model_and_tokenizer(model_name: str):
20
+ print(f"Loading Transformer model: {model_name}")
21
+ try:
22
+ target_model = model_name
23
+ if model_name.startswith("..") and not os.path.exists(model_name):
24
+ print(f"Warning: Local path {model_name} not found. Falling back to default hub model.")
25
+ target_model = "google/gemma-2b-it"
26
+
27
+ tokenizer = AutoTokenizer.from_pretrained(target_model)
28
+ model = AutoModelForCausalLM.from_pretrained(target_model)
29
+ print("Model loaded successfully.")
30
+ return model, tokenizer
31
+ except Exception as e:
32
+ print(f"Error loading Transformer model {target_model}: {e}")
33
+ raise e
34
+
35
+ # UPDATED: Now accepts tools_list as an argument
36
+ def create_conversation_format(sample, tools_list):
37
+ """Formats a dataset row into the conversational format required for SFT."""
38
+ try:
39
+ tool_args = json.loads(sample["tool_arguments"])
40
+ except (json.JSONDecodeError, TypeError):
41
+ tool_args = {}
42
+
43
+ return {
44
+ "messages": [
45
+ {"role": "developer", "content": DEFAULT_SYSTEM_MSG},
46
+ {"role": "user", "content": sample["user_content"]},
47
+ {"role": "assistant", "tool_calls": [{"type": "function", "function": {"name": sample["tool_name"], "arguments": tool_args}}]},
48
+ ],
49
+ "tools": tools_list # Injects the dynamic tools
50
+ }
51
+
52
+ def parse_csv_dataset(file_path: str) -> List[List[str]]:
53
+ """Parses an uploaded CSV file."""
54
+ dataset = []
55
+ if not file_path:
56
+ return dataset
57
+
58
+ with open(file_path, 'r', newline='', encoding='utf-8') as f:
59
+ reader = csv.reader(f)
60
+ try:
61
+ header = next(reader)
62
+ if not (header and "user_content" in header[0].lower()):
63
+ f.seek(0)
64
+ except StopIteration:
65
+ return dataset
66
+
67
+ for row in reader:
68
+ if len(row) >= 3:
69
+ dataset.append([s.strip() for s in row[:3]])
70
+ return dataset
71
+
72
+ def zip_directory(source_dir: str, output_name_base: str) -> str:
73
+ """Zips a directory."""
74
+ return shutil.make_archive(
75
+ base_name=output_name_base,
76
+ format='zip',
77
+ root_dir=source_dir,
78
+ )