| import os |
| import torch |
| import intel_extension_for_pytorch as ipex |
| from datasets import load_dataset |
| from huggingface_hub import login |
| from transformers import ( |
| AutoModelForCausalLM, |
| AutoTokenizer, |
| TrainingArguments, |
| Trainer, |
| DataCollatorForLanguageModeling |
| ) |
|
|
| |
| |
# --- Configuration -------------------------------------------------------
MODEL_ID = "google/gemma-3-270m"
REPO_NAME = "gemma-3-270m-dolly-fft"
MAX_SEQ_LENGTH = 512

# --- Hugging Face Hub authentication (optional) --------------------------
# Without a token the run still works; the model is just not pushed.
hf_token = os.environ.get("HF_TOKEN")
if not hf_token:
    print("Warning: HF_TOKEN not found. Model will only save locally.")
else:
    login(token=hf_token)
|
|
| |
# --- Tokenizer -----------------------------------------------------------
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Gemma ships without a dedicated pad token; reuse EOS so the collator
# can pad batches.
tokenizer.pad_token = tokenizer.eos_token

# --- Model ---------------------------------------------------------------
base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float32,
    low_cpu_mem_usage=True,
)
# Switch to training mode first, then let Intel Extension for PyTorch
# apply its CPU-side fp32 optimizations to the module.
base_model.train()
model = ipex.optimize(base_model, dtype=torch.float32)
|
|
| |
# Dolly 15k: instruction/context/response triples (only a "train" split exists).
dataset = load_dataset("databricks/databricks-dolly-15k", split="train")
|
|
def format_prompt(sample):
    """Render one Dolly record with Gemma's chat markers.

    Fix: the original emitted Phi-style ``<|user|>``/``<|assistant|>``/``<|end|>``
    tags, but Gemma models are trained on ``<start_of_turn>{role}`` /
    ``<end_of_turn>`` delimiters with roles ``user`` and ``model``.

    Args:
        sample: mapping with ``instruction``, ``context`` and ``response`` keys
            (``context`` may be an empty string).

    Returns:
        dict with a single ``"text"`` key holding the formatted turn pair.
    """
    user_text = sample['instruction']
    if sample['context']:
        # Fold the optional grounding context into the user turn.
        user_text += f"\nContext: {sample['context']}"

    return {
        "text": (
            f"<start_of_turn>user\n{user_text}<end_of_turn>\n"
            f"<start_of_turn>model\n{sample['response']}<end_of_turn>"
        )
    }
|
|
| |
# Build the text column, then tokenize it away along with the raw columns.
dataset = dataset.map(format_prompt)


def _tokenize(batch):
    # Truncate to MAX_SEQ_LENGTH; padding is deferred to the collator.
    return tokenizer(batch["text"], truncation=True, max_length=MAX_SEQ_LENGTH)


tokenized_dataset = dataset.map(
    _tokenize,
    batched=True,
    remove_columns=dataset.column_names,
)

# mlm=False -> causal-LM labels (model predicts next token), not masked LM.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
|
|
| |
# --- Training configuration ----------------------------------------------
training_args = TrainingArguments(
    output_dir=REPO_NAME,
    use_cpu=True,                      # force CPU (no CUDA/XPU devices)
    optim="adafactor",                 # memory-lean optimizer for full FT
    learning_rate=1e-5,
    weight_decay=0.01,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,     # effective batch size of 8
    num_train_epochs=1,
    save_strategy="steps",
    save_steps=100,
    logging_steps=10,
    push_to_hub=bool(hf_token),        # idiom fix: was `True if hf_token else False`
    hub_model_id=REPO_NAME,
    hub_strategy="end",                # single upload after training completes
    report_to="none",                  # disable wandb/tensorboard reporting
)
|
|
| |
# Wire the model, arguments, data and collator into a Trainer instance.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=tokenized_dataset,
)
|
|
| |
# --- Train, save, and optionally publish ---------------------------------
print("--- Starting Full Fine-Tuning on CPU ---")
trainer.train()


print("--- Training Finished. Saving and Pushing ---")
# Save weights and tokenizer side-by-side so the local folder is directly
# loadable with from_pretrained().
trainer.save_model(f"./{REPO_NAME}")
tokenizer.save_pretrained(f"./{REPO_NAME}")


if hf_token:
    trainer.push_to_hub()
    # Fix: original URL was missing the '/' between domain and repo name.
    print(f"Model pushed to: https://huggingface.co/{REPO_NAME}")
|
|