# python_ai_coder / train.py
# Fine-tune a causal language model on a JSONL dataset of prompt/completion pairs.
# Author: Percy3822 · last commit 01be04f ("Update train.py", verified) · 1.67 kB
import inspect
import os

from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)
# === CONFIG ===
DATASET_PATH = "python_ai_dataset.jsonl"  # Your .jsonl file
MODEL_ID = "bigcode/starcoderbase-7b"
OUTPUT_DIR = "train_output"
# === Load Dataset ===
# Each JSONL record must provide "prompt" and "completion" string fields; the
# preprocessing step reads exactly those two keys.
dataset = load_dataset("json", data_files=DATASET_PATH, split="train")
# === Load Tokenizer and Model ===
# NOTE(security): trust_remote_code=True lets the model repo execute arbitrary
# Python on this machine. StarCoderBase uses the GPTBigCode architecture, which
# transformers supports natively, so the flag is likely unnecessary here —
# consider dropping it unless MODEL_ID is switched to a repo with custom code.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
# DataCollatorForLanguageModeling pads every batch (even at batch size 1) and
# raises "Asking to pad but the tokenizer does not have a padding token" when
# pad_token is unset, as it is for many causal-LM tokenizers. Reuse EOS so
# padding works. Caveat: the collator masks every pad-id position in the labels
# to -100, so with pad == eos any EOS token in the text is excluded from the loss.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, trust_remote_code=True)
# === Preprocessing ===
def tokenize(example):
    """Tokenize one dataset record as a single causal-LM training sequence.

    The prompt and the completion are joined by a newline and encoded as one
    string, truncated to at most 512 tokens.

    Args:
        example: One dataset row carrying string fields "prompt" and "completion".

    Returns:
        Whatever the module-level ``tokenizer`` returns for the joined text.
    """
    joined_text = "\n".join((example["prompt"], example["completion"]))
    return tokenizer(joined_text, max_length=512, truncation=True)
# Tokenize every record and drop the raw prompt/completion text columns.
tokenized_dataset = dataset.map(
    tokenize,
    remove_columns=["prompt", "completion"],
)
# === Data Collator ===
# mlm=False selects the causal-LM objective: labels are derived from input_ids
# rather than from randomly masked tokens.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,
)
# === Training Arguments ===
# Effective batch per device = per_device_train_batch_size (1) x
# gradient_accumulation_steps (4) = 4 sequences per optimizer step.
training_args = TrainingArguments(
    # Output location for checkpoints.
    output_dir=OUTPUT_DIR,
    overwrite_output_dir=True,
    # Optimisation schedule.
    num_train_epochs=2,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    # Mixed precision: fp16 on, bf16 off.
    # NOTE(review): fp16=True requires a CUDA-class accelerator; TrainingArguments
    # rejects it on CPU — confirm the target hardware.
    bf16=False,
    fp16=True,
    # Checkpointing: one save per epoch, keeping at most two checkpoints.
    save_total_limit=2,
    save_strategy="epoch",
    # Logging.
    logging_steps=10,
    logging_dir="./logs",
    report_to="none",  # Prevent HF integration logs
)
# === Trainer ===
# Trainer's `tokenizer=` keyword was deprecated in transformers 4.46 in favour of
# `processing_class=` and is slated for removal in 5.0. Pass the tokenizer under
# whichever name the installed Trainer accepts, so this runs on old and new releases.
_tokenizer_kwarg = (
    "processing_class"
    if "processing_class" in inspect.signature(Trainer.__init__).parameters
    else "tokenizer"
)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    data_collator=data_collator,
    **{_tokenizer_kwarg: tokenizer},
)
# === Start Training ===
trainer.train()
# === Save Final Model ===
trainer.save_model(OUTPUT_DIR)
# Trainer.save_model typically writes the tokenizer too when one was passed in.
# The explicit call keeps the tokenizer next to the weights regardless.
tokenizer.save_pretrained(OUTPUT_DIR)