import torch
from datasets import load_dataset
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
TrainingArguments,
)
from peft import LoraConfig
from trl import SFTTrainer

# --- CONFIGURATION ---
# Base model: a 4-bit-quantized Llama 3 or Mistral is recommended for consumer GPUs.
# Ensure you have access to the model on Hugging Face (might need login).
# The Instruct variant ships a chat template, which apply_chat_template below
# requires; the raw base model's tokenizer does not define one.
MODEL_NAME = "meta-llama/Meta-Llama-3-8B-Instruct"
DATASET_NAME = "ceperaltab/elixir-golden-dataset"
OUTPUT_DIR = "elixir-model-adapter"
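
# Authenticate first if the model or dataset are gated, e.g. run
# `huggingface-cli login` or set the HF_TOKEN environment variable.
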
def main():
print(f"Loading dataset from {DATASET_NAME}...")
# 1. Load Dataset
try:
# Load directly from HF Hub
dataset = load_dataset(DATASET_NAME, split="train")
except Exception as e:
print(f"Error loading dataset: {e}")
return
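    # Each record is expected to look roughly like (assumed schema; verify
    # against the actual dataset before training):
    #   {"messages": [{"role": "user", "content": "..."},
    #                 {"role": "assistant", "content": "..."}]}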

    # 2. Quantization Config (4-bit for memory efficiency)
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
    )
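    # bnb_4bit_use_double_quant=True can also be passed above to quantize the
    # quantization constants themselves and save a little more memory (optional).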
print(f"Loading base model: {MODEL_NAME}...")
# 3. Load Model
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
quantization_config=bnb_config,
device_map="auto",
trust_remote_code=True
)

    # 4. Load Tokenizer
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"  # Critical for stable fp16 training

    # 5. LoRA Config (Parameter-Efficient Fine-Tuning)
    peft_config = LoraConfig(
        lora_alpha=16,
        lora_dropout=0.1,
        r=64,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
    )
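    # Note: the effective LoRA scaling factor is lora_alpha / r (16 / 64 = 0.25),
    # and target_modules covers every attention and MLP projection in a
    # Llama-style transformer block.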

    # 6. Formatting Function for the Chat Dataset
    # Converts {"messages": [...]} into the model's expected prompt format.
    def formatting_prompts_func(examples):
        output_texts = []
        for messages in examples["messages"]:
            # Apply the chat template (e.g., <|begin_of_text|><|start_header_id|>user...)
            # without tokenizing; SFTTrainer handles tokenization itself.
            text = tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=False
            )
            output_texts.append(text)
        return output_texts
print("Starting SFTTrainer setup...")
# 7. Trainer
trainer = SFTTrainer(
model=model,
train_dataset=dataset,
peft_config=peft_config,
formatting_func=formatting_prompts_func,
max_seq_length=2048,
tokenizer=tokenizer,
args=TrainingArguments(
output_dir=OUTPUT_DIR,
per_device_train_batch_size=2,
gradient_accumulation_steps=4, # Simulate larger batch size
learning_rate=2e-4,
logging_steps=10,
num_train_epochs=1,
optim="paged_adamw_32bit",
fp16=True,
group_by_length=True,
save_strategy="epoch",
report_to="none", # Change to "wandb" if desired
push_to_hub=True,
hub_model_id=f"ceperaltab/{OUTPUT_DIR}", # Pushes to your namespace
),
)
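
    # Optional sanity check: with peft_config supplied, SFTTrainer wraps the
    # model in a PeftModel, so we can report how few parameters LoRA trains.
    if hasattr(trainer.model, "print_trainable_parameters"):
        trainer.model.print_trainable_parameters()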
print("Starting training...")
trainer.train()
print(f"Saving model to {OUTPUT_DIR}...")
trainer.save_model(OUTPUT_DIR)
print("Done!")

if __name__ == "__main__":
    main()
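
# After training, the saved LoRA adapter can be loaded for inference along
# these lines (a minimal sketch; merge_and_unload() bakes the adapter into
# the base weights and is optional):
#
#   from peft import PeftModel
#   base = AutoModelForCausalLM.from_pretrained(MODEL_NAME, device_map="auto")
#   model = PeftModel.from_pretrained(base, OUTPUT_DIR)
#   model = model.merge_and_unload()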