# python_ai_coder/train.py
import os
import sys

from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForLanguageModeling

print("🔥 Python AI training script started!", file=sys.stderr)
DATASET_PATH = "python_ai_dataset.jsonl"
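# The dataset is expected to be JSONL with "prompt" and "completion" fields,
# the columns consumed in the tokenize step below.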
MODEL_ID = "bigcode/starcoderbase-7b"
OUTPUT_DIR = "train_output"
# === Step 1: Check dataset ===
if not os.path.exists(DATASET_PATH):
print(f"❌ Dataset file not found: {DATASET_PATH}", file=sys.stderr)
sys.exit(1)
# === Step 2: Load dataset (first 10 samples for fast test) ===
try:
    dataset = load_dataset("json", data_files=DATASET_PATH, split="train[:10]")  # load only 10 samples for a quick test
except Exception as e:
    print(f"❌ Failed to load dataset: {e}", file=sys.stderr)
    sys.exit(1)
# === Step 3: Load tokenizer and model ===
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, trust_remote_code=True)
    # StarCoder-style tokenizers may ship without a pad token; fall back to EOS
    # so the data collator can pad batches without raising an error.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
except Exception as e:
    print(f"❌ Failed to load model/tokenizer: {e}", file=sys.stderr)
    sys.exit(1)
# === Step 4: Preprocess data ===
def tokenize(example):
    # Join prompt and completion into a single sequence for causal LM training.
    return tokenizer(example["prompt"] + "\n" + example["completion"], truncation=True, max_length=512)

try:
    tokenized_dataset = dataset.map(tokenize, remove_columns=["prompt", "completion"])
except Exception as e:
    print(f"❌ Tokenization error: {e}", file=sys.stderr)
    sys.exit(1)
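# mlm=False configures the collator for causal (next-token) language modeling,
# so labels are derived from the input_ids rather than from random masking.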
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
# === Step 5: Training config ===
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    overwrite_output_dir=True,
    per_device_train_batch_size=1,
    num_train_epochs=1,
    logging_dir="./logs",
    logging_steps=1,
    save_strategy="epoch",
    save_total_limit=1,
    fp16=False,
    report_to="none",
)
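# These are deliberately minimal smoke-test settings (1 epoch, batch size 1);
# scale them up, and consider fp16 on a capable GPU, for a real training run.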
# === Step 6: Train the model ===
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
)
print("πŸš€ Starting training on 10 samples...", file=sys.stderr)
trainer.train()
# === Step 7: Save model ===
trainer.save_model(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)
print("βœ… Training finished and model saved!", file=sys.stderr)