# /// script
# dependencies = [
#     "trl>=0.12.0",
#     "peft>=0.7.0",
#     "trackio",
#     "datasets",
#     "transformers",
#     "accelerate",
#     "jinja2",
# ]
# ///

import inspect

import trl
import transformers
import trackio  # trackio must be installed for report_to="trackio" below
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoTokenizer
from trl import SFTConfig, SFTTrainer

print("🚀 Starting FunctionGemma 270M Fine-tuning (V4 - Diagnostic)")
print(f"📦 TRL Version: {trl.__version__}")
print(f"📦 Transformers Version: {transformers.__version__}")

model_id = "google/functiongemma-270m-it"
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Load dataset
dataset = load_dataset("epinfomax/vn-function-calling-dataset", split="train")


def format_conversation(example):
    # Render each conversation (messages + tool schemas) into a single training
    # string using the model's own chat template.
    text = tokenizer.apply_chat_template(
        example["messages"],
        tools=example["tools"],
        tokenize=False,
        add_generation_prompt=False,
    )
    return {"text": text}


print("🔄 Pre-processing dataset with chat template...")
dataset = dataset.map(format_conversation, remove_columns=dataset.column_names)

# Training configuration.
# The max-sequence-length argument was renamed across TRL releases, so check
# which field SFTConfig actually exposes before setting it.
sft_config_args = {
    "dataset_text_field": "text",
    "output_dir": "vn-function-gemma-270m-finetuned",
    "push_to_hub": True,
    "hub_model_id": "epinfomax/vn-function-gemma-270m-finetuned",
    "hub_strategy": "every_save",
    "num_train_epochs": 5,
    "per_device_train_batch_size": 4,
    "gradient_accumulation_steps": 4,
    "learning_rate": 5e-5,
    "logging_steps": 5,
    "save_strategy": "steps",
    "save_steps": 50,
    "report_to": "trackio",
    "project": "vn-function-calling",
    "run_name": "function-gemma-270m-v4-diag",
}

# Check which parameter name this TRL version uses
sft_fields = SFTConfig.__dataclass_fields__
if "max_seq_length" in sft_fields:
    print("✅ Using max_seq_length in SFTConfig")
    sft_config_args["max_seq_length"] = 1024
elif "max_length" in sft_fields:
    print("✅ Using max_length in SFTConfig")
    sft_config_args["max_length"] = 1024
else:
    print("⚠️ Neither max_seq_length nor max_length found in SFTConfig fields!")
    print("Fields:", list(sft_fields.keys()))

config = SFTConfig(**sft_config_args)

# LoRA adapter configuration (defined before it is referenced in trainer_kwargs)
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
    task_type="CAUSAL_LM",
)

# Initialize and train
print("🎯 Initializing SFTTrainer...")
trainer_kwargs = {
    "model": model_id,
    "train_dataset": dataset,
    "peft_config": peft_config,
    "args": config,
}

# Older TRL versions accept max_seq_length directly on SFTTrainer; add it there
# only if SFTConfig did not already take it.
trainer_params = inspect.signature(SFTTrainer.__init__).parameters
if "max_seq_length" in trainer_params and "max_seq_length" not in sft_config_args:
    print("✅ Adding max_seq_length to SFTTrainer")
    trainer_kwargs["max_seq_length"] = 1024

trainer = SFTTrainer(**trainer_kwargs)
trainer.train()
trainer.push_to_hub()

print("✅ Training complete and pushed to Hub!")
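
# Optional: quick post-training smoke test with the freshly trained adapter.
# This is a minimal sketch rather than part of the training run proper: the tool
# schema and user message below are made-up examples (not taken from the dataset),
# and it reuses the in-memory trainer.model / tokenizer pair from above.
trainer.model.eval()
smoke_tools = [{
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Get the current weather for a city",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string", "description": "City name"}},
            "required": ["city"],
        },
    },
}]
smoke_messages = [{"role": "user", "content": "What's the weather like in Hanoi right now?"}]
prompt = tokenizer.apply_chat_template(
    smoke_messages, tools=smoke_tools, tokenize=False, add_generation_prompt=True
)
inputs = tokenizer(prompt, return_tensors="pt").to(trainer.model.device)
outputs = trainer.model.generate(**inputs, max_new_tokens=128)
# Print only the newly generated tokens (the expected function-call output)
print(tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True))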