#!/usr/bin/env python3
"""
FunctionGemma SFT fine-tuning script.

Runs TRL SFTTrainer for FunctionGemma with two modes:
  1) LoRA (recommended): faster, lower memory, less overfit
  2) Full-parameter: higher cost, maximal capacity

Usage:
    # LoRA (default)
    python -m src.train \
        --model_path /path/to/model \
        --dataset_path ./data/training_data.json \
        --bf16
    
    # Full-parameter
    python -m src.train \
        --model_path /path/to/model \
        --dataset_path ./data/training_data.json \
        --no-use-lora \
        --bf16
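
    # QLoRA (4-bit quantized base + LoRA adapters; LoRA is on by default)
    python -m src.train \
        --model_path /path/to/model \
        --dataset_path ./data/training_data.json \
        --use_4bit \
        --bf16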
"""

import os
import json
import hashlib
import argparse
import logging
from datetime import datetime
from pathlib import Path

import torch
from datasets import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)
from peft import LoraConfig, TaskType, prepare_model_for_kbit_training
from trl import SFTTrainer, SFTConfig

# Paths and logging
PROJECT_ROOT = Path(__file__).resolve().parent.parent
DEFAULT_DATA_PATH = PROJECT_ROOT / "data" / "training_data.json"
DEFAULT_OUTPUT_DIR = PROJECT_ROOT / "runs"

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


def parse_args():
    """Parse CLI arguments."""
    parser = argparse.ArgumentParser(description="FunctionGemma SFT fine-tuning (LoRA / full)")
    
    # Model
    parser.add_argument(
        "--model_path",
        type=str,
        default="google/functiongemma-270m-it",
        help="Model path or HF model id"
    )
    parser.add_argument(
        "--tokenizer_path",
        type=str,
        default=None,
        help="Tokenizer path (defaults to model_path)"
    )
    
    # Dataset
    parser.add_argument(
        "--dataset_path",
        type=str,
        default=str(DEFAULT_DATA_PATH),
        help="Training dataset path"
    )
    parser.add_argument(
        "--val_split",
        type=float,
        default=0.1,
        help="Validation split ratio"
    )
    
    # Output
    parser.add_argument(
        "--output_dir",
        type=str,
        default=str(DEFAULT_OUTPUT_DIR),
        help="Root output directory"
    )
    parser.add_argument(
        "--run_name",
        type=str,
        default=None,
        help="Run name for logging and saving"
    )
    
    # Fine-tuning mode
    parser.add_argument(
        "--use_lora",
        action="store_true",
        default=True,
        help="Enable LoRA (recommended). Add --no-use-lora for full-parameter finetune"
    )
    parser.add_argument("--no-use-lora", dest="use_lora", action="store_false", help="Disable LoRA, run full-parameter finetune")
    
    # LoRA (only when use_lora=True)
    parser.add_argument("--lora_r", type=int, default=16, help="LoRA rank")
    parser.add_argument("--lora_alpha", type=int, default=32, help="LoRA alpha")
    parser.add_argument("--lora_dropout", type=float, default=0.05, help="LoRA dropout")
    parser.add_argument(
        "--target_modules",
        type=str,
        nargs="+",
        default=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        help="Target modules for LoRA"
    )
    
    # Training (aligned with FunctionGemma guidance)
    parser.add_argument("--num_train_epochs", type=int, default=6, help="Training epochs (official rec: 8)")
    parser.add_argument("--max_steps", type=int, default=-1, help="Max training steps (-1 to use epochs)")
    parser.add_argument("--per_device_train_batch_size", type=int, default=4, help="Train batch size per device")
    parser.add_argument("--per_device_eval_batch_size", type=int, default=2, help="Eval batch size")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=8, help="Grad accumulation steps")
    parser.add_argument("--learning_rate", type=float, default=5e-5, help="Learning rate")
    parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay")
    parser.add_argument("--warmup_ratio", type=float, default=0.0, help="Warmup ratio (constant scheduler usually skips warmup)")
    parser.add_argument("--max_seq_length", type=int, default=2048, help="Max sequence length (model supports up to 32768)")
    parser.add_argument("--lr_scheduler_type", type=str, default="constant", help="LR scheduler type (default constant)")
    
    # Precision & optimization
    parser.add_argument("--bf16", action="store_true", help="Use BF16")
    parser.add_argument("--fp16", action="store_true", help="Use FP16")
    parser.add_argument("--use_4bit", action="store_true", help="Enable 4-bit quant (QLoRA)")
    parser.add_argument("--use_8bit", action="store_true", help="Enable 8-bit quant")
    parser.add_argument("--use_flash_attention", action="store_true", help="Enable Flash Attention 2")
    parser.add_argument("--gradient_checkpointing", action="store_true", help="Enable gradient checkpointing")
    
    # Logging & saving
    parser.add_argument("--logging_steps", type=int, default=10, help="Log every N steps")
    parser.add_argument("--save_steps", type=int, default=100, help="Save checkpoint every N steps")
    parser.add_argument("--eval_steps", type=int, default=100, help="Eval every N steps")
    parser.add_argument("--save_total_limit", type=int, default=3, help="Max checkpoints to keep")
    
    # Misc
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--resume_from_checkpoint", type=str, default=None, help="Resume from checkpoint")
    parser.add_argument("--push_to_hub", action="store_true", help="Push to Hugging Face Hub")
    parser.add_argument("--hub_model_id", type=str, default=None, help="Hub model id")
    
    return parser.parse_args()


def load_and_prepare_dataset(dataset_path: str, val_split: float = 0.1):
    """Load and normalize dataset structure for SFT."""
    logger.info(f"Loading dataset: {dataset_path}")
    
    # Load JSON dataset
    with open(dataset_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    
    logger.info(f"Dataset size: {len(data)} samples")
    
    # Normalize nested structures:
    # if an item has input.messages/tools, lift them to top-level
    processed_data = []
    for idx, item in enumerate(data):
        if 'input' in item and 'messages' in item['input']:
            # Deep copy messages to avoid mutating original
            messages = json.loads(json.dumps(item['input']['messages']))
            
            # Fix tool_calls formatting if present
            for msg in messages:
                if 'tool_calls' in msg and msg['tool_calls']:
                    for tc in msg['tool_calls']:
                        if 'function' in tc and 'arguments' in tc['function']:
                            args = tc['function']['arguments']
                            # ensure arguments is a string
                            if not isinstance(args, str):
                                tc['function']['arguments'] = json.dumps(args)
            
            # Convert expected field into assistant response if present
            if 'expected' in item and item['expected']:
                expected = item['expected']
                # If the last message is not from the assistant, append one
                if messages and messages[-1]['role'] != 'assistant':
                    # Decide between function call or refusal
                    function_name = expected.get('function_name')
                    arguments = expected.get('arguments')
                    response = expected.get('response', '')
                    
                    if function_name is not None and arguments is not None:
                        # Case 1: function call -> add tool_calls
                        arguments_str = json.dumps(arguments) if isinstance(arguments, dict) else str(arguments)
                        
                        assistant_msg = {
                            "role": "assistant",
                            "content": None,
                            "tool_calls": [{
                                "id": f"call_{hash(function_name + arguments_str) % 1000000}",  # generate unique id
                                "type": "function",
                                "function": {
                                    "name": function_name,
                                    "arguments": arguments_str
                                }
                            }]
                        }
                        messages.append(assistant_msg)
                        logger.debug(f"Added assistant tool_calls: {function_name}")
                    elif function_name is None and arguments is None and response:
                        # Case 2: refusal -> plain text response
                        assistant_msg = {
                            "role": "assistant",
                            "content": response
                        }
                        messages.append(assistant_msg)
                        logger.debug(f"Added assistant refusal response: {response[:50]}")
                    else:
                        logger.warning(f"Unknown expected format: {expected}")
            
            processed_item = {
                'messages': messages
            }
            
            # include tools if present
            if 'tools' in item['input']:
                processed_item['tools'] = item['input']['tools']
            
            # preserve id
            if 'id' in item:
                processed_item['id'] = item['id']
            
            # Final check: tool_calls arguments must be strings
            for msg in processed_item['messages']:
                if 'tool_calls' in msg and msg['tool_calls']:
                    for tc in msg['tool_calls']:
                        if 'function' in tc and 'arguments' in tc['function']:
                            if not isinstance(tc['function']['arguments'], str):
                                logger.error(f"Sample {idx} arguments not string: {type(tc['function']['arguments'])}")
                                tc['function']['arguments'] = json.dumps(tc['function']['arguments'])
            
            processed_data.append(processed_item)
            
        elif 'messages' in item:
            # Already proper format, just normalize tool_calls
            messages = json.loads(json.dumps(item['messages']))
            for msg in messages:
                if 'tool_calls' in msg and msg['tool_calls']:
                    for tc in msg['tool_calls']:
                        if 'function' in tc and 'arguments' in tc['function']:
                            if not isinstance(tc['function']['arguments'], str):
                                tc['function']['arguments'] = json.dumps(tc['function']['arguments'])
            item_copy = dict(item)
            item_copy['messages'] = messages
            processed_data.append(item_copy)
        else:
            logger.warning(f"Skip malformed item: {item.get('id', 'unknown')}")
    
    logger.info(f"Processed dataset size: {len(processed_data)}")
    
    # Validate format
    tool_calls_count = 0
    for item in processed_data:
        for msg in item['messages']:
            if 'tool_calls' in msg and msg['tool_calls']:
                tool_calls_count += 1
                for tc in msg['tool_calls']:
                    if 'function' in tc and 'arguments' in tc['function']:
                        if not isinstance(tc['function']['arguments'], str):
                            logger.error(f"Found non-string arguments: {type(tc['function']['arguments'])}")
    logger.info(f"Messages containing tool_calls: {tool_calls_count}")
    
    # Convert to Hugging Face Dataset
    dataset = Dataset.from_list(processed_data)
    
    # Split train/val
    if val_split > 0:
        dataset = dataset.train_test_split(test_size=val_split, seed=42)
        train_dataset = dataset['train']
        eval_dataset = dataset['test']
        logger.info(f"Train: {len(train_dataset)}, Eval: {len(eval_dataset)}")
    else:
        train_dataset = dataset
        eval_dataset = None
        logger.info(f"Train: {len(train_dataset)}, no eval split")
    
    return train_dataset, eval_dataset


def get_quantization_config(use_4bit: bool, use_8bit: bool):
    """Build quantization config if requested."""
    if use_4bit:
        logger.info("Using 4-bit quantization (QLoRA)")
        return BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
        )
    elif use_8bit:
        logger.info("Using 8-bit quantization")
        return BitsAndBytesConfig(
            load_in_8bit=True,
        )
    return None


def load_model_and_tokenizer(args):
    """Load model and tokenizer."""
    logger.info(f"Loading model: {args.model_path}")
    
    tokenizer_path = args.tokenizer_path or args.model_path
    
    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_path,
        trust_remote_code=True,
        padding_side="right",
    )
    
    # Ensure pad token exists
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id
    
    # Quantization config
    quantization_config = get_quantization_config(args.use_4bit, args.use_8bit)
    
    # Model kwargs
    model_kwargs = {
        "trust_remote_code": True,
        "device_map": "auto",
    }
    
    if quantization_config:
        model_kwargs["quantization_config"] = quantization_config
    
    # Precision
    if args.bf16 and not (args.use_4bit or args.use_8bit):
        model_kwargs["torch_dtype"] = torch.bfloat16
    elif args.fp16 and not (args.use_4bit or args.use_8bit):
        model_kwargs["torch_dtype"] = torch.float16
    
    # Flash Attention
    if args.use_flash_attention:
        model_kwargs["attn_implementation"] = "flash_attention_2"
        logger.info("Using Flash Attention 2")
    
    # Load model
    model = AutoModelForCausalLM.from_pretrained(
        args.model_path,
        **model_kwargs
    )
    
    # Prepare for k-bit training when quantized
    if args.use_4bit or args.use_8bit:
        model = prepare_model_for_kbit_training(model)
    
    # Gradient checkpointing
    if args.gradient_checkpointing:
        model.gradient_checkpointing_enable()
        logger.info("Enabled gradient checkpointing")
    
    logger.info(f"Model parameters: {model.num_parameters():,}")
    
    return model, tokenizer


def get_lora_config(args):
    """Build LoRA config."""
    logger.info(f"LoRA config: r={args.lora_r}, alpha={args.lora_alpha}, dropout={args.lora_dropout}")
    logger.info(f"Target modules: {args.target_modules}")
    
    return LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        target_modules=args.target_modules,
        bias="none",
        task_type=TaskType.CAUSAL_LM,
    )


def formatting_func(example):
    """
    Identity formatter, kept for reference only: it is not passed to the
    SFTTrainer below, because recent TRL versions apply the chat template
    to conversational datasets (ones with a 'messages' column) automatically.
    
    Dataset format:
    {
        "messages": [
            {"role": "developer", "content": "..."},
            {"role": "user", "content": "..."},
            {"role": "assistant", "tool_calls": [...]} or {"role": "assistant", "content": "..."}
        ],
        "tools": [...]
    }
    """
    return example


def main():
    args = parse_args()
    
    # Set run name
    if args.run_name is None:
        mode = "lora" if args.use_lora else "full"
        args.run_name = f"functiongemma-{mode}-{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    
    # Create output directory
    output_dir = os.path.join(args.output_dir, args.run_name)
    os.makedirs(output_dir, exist_ok=True)
    
    logger.info("=" * 60)
    logger.info("FunctionGemma SFT LoRA training")
    logger.info("=" * 60)
    logger.info(f"Output dir: {output_dir}")
    
    # Save config
    config_path = os.path.join(output_dir, "training_config.json")
    with open(config_path, 'w') as f:
        json.dump(vars(args), f, indent=2)
    logger.info(f"Config saved to: {config_path}")
    
    # Load dataset
    train_dataset, eval_dataset = load_and_prepare_dataset(
        args.dataset_path,
        args.val_split
    )
    
    # Load model + tokenizer
    model, tokenizer = load_model_and_tokenizer(args)
    
    # Build LoRA config if enabled
    if args.use_lora:
        logger.info("=" * 60)
        logger.info("LoRA fine-tuning mode")
        logger.info("=" * 60)
        lora_config = get_lora_config(args)
    else:
        logger.info("=" * 60)
        logger.info("Full-parameter fine-tuning mode")
        logger.info("Warning: full fine-tuning needs more memory and time!")
        logger.info("=" * 60)
        lora_config = None
    
    # SFTTrainer config
    training_args = SFTConfig(
        output_dir=output_dir,
        run_name=args.run_name,
        
        # Sequence length / packing
        max_length=args.max_seq_length,
        packing=False,
        
        # Training
        num_train_epochs=args.num_train_epochs,
        max_steps=args.max_steps,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        
        # Optimizer
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        warmup_ratio=args.warmup_ratio,
        lr_scheduler_type=args.lr_scheduler_type,
        optim="adamw_torch_fused",
        
        # Precision
        bf16=args.bf16,
        fp16=args.fp16,
        
        # Logging / saving
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        eval_steps=args.eval_steps if eval_dataset else None,
        eval_strategy="steps" if eval_dataset else "no",
        save_total_limit=args.save_total_limit,
        load_best_model_at_end=eval_dataset is not None,
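        # Note: load_best_model_at_end requires save_steps to be a round multiple
        # of eval_steps (both default to 100 here).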
        
        # Misc
        seed=args.seed,
        report_to=["tensorboard"],
        
        # Hub
        push_to_hub=args.push_to_hub,
        hub_model_id=args.hub_model_id,
        
        # Gradient checkpointing
        gradient_checkpointing=args.gradient_checkpointing,
        gradient_checkpointing_kwargs={"use_reentrant": False} if args.gradient_checkpointing else None,
    )
    
    # Create SFTTrainer
    # Dataset should include 'messages' and 'tools'; SFTTrainer applies chat template automatically
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        processing_class=tokenizer,  # newer TRL uses processing_class instead of tokenizer
        peft_config=lora_config,
    )
    
    # Parameter stats (count on trainer.model, which is the PEFT-wrapped model when LoRA is enabled)
    trainable_params = sum(p.numel() for p in trainer.model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in trainer.model.parameters())
    trainable_percentage = 100 * trainable_params / total_params if total_params > 0 else 0
    
    logger.info("=" * 60)
    logger.info("Model parameter stats:")
    logger.info(f"  Total params: {total_params:,}")
    logger.info(f"  Trainable params: {trainable_params:,}")
    logger.info(f"  Trainable ratio: {trainable_percentage:.2f}%")
    logger.info(f"  Mode: {'LoRA' if args.use_lora else 'Full fine-tune'}")
    logger.info("=" * 60)
    
    # Train
    logger.info("Start training...")
    
    if args.resume_from_checkpoint:
        trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)
    else:
        trainer.train()
    
    # Save final model
    logger.info("Saving final model...")
    final_model_path = os.path.join(output_dir, "final_model")
    trainer.save_model(final_model_path)
    tokenizer.save_pretrained(final_model_path)
    
    logger.info("=" * 60)
    logger.info("Training done.")
    logger.info(f"Model saved at: {final_model_path}")
    
    if args.use_lora:
        # LoRA: also save the adapter alone. Use trainer.model (the PEFT-wrapped
        # model) so that only the adapter weights are written, not the full model.
        lora_path = os.path.join(output_dir, "lora_adapter")
        trainer.model.save_pretrained(lora_path)
        tokenizer.save_pretrained(lora_path)
        logger.info(f"LoRA adapter saved to: {lora_path}")
        logger.info("")
        logger.info("Usage:")
        logger.info(f"  1. LoRA adapter: {lora_path}")
        logger.info(f"  2. Merge adapters with your base model before inference")
    else:
        # Full fine-tune: final_model is ready to use
        logger.info("")
        logger.info("Usage:")
        logger.info(f"  Use model directly from: {final_model_path}")
    
    logger.info("=" * 60)


if __name__ == "__main__":
    main()