# AmorCoderAI-Train / Train.py
# Author: Andro0s — commit bc19ef1 ("Update Train.py"), 1.48 kB.
def train_lora(epochs, batch_size, learning_rate):
    """Fine-tune ``base_model`` on the JSON dataset and save the result.

    Loads ``DATASET_PATH`` with ``load_dataset("json", ...)`` and tokenizes
    each record as ``prompt + completion`` (truncated/padded to 256 tokens).
    It then trains with a Hugging Face ``Trainer`` using a causal-LM collator
    (``mlm=False``) and saves the model and tokenizer to ``LORA_PATH``.

    Relies on module-level names defined elsewhere in the file:
    ``DATASET_PATH``, ``LORA_PATH``, ``tokenizer``, ``base_model`` and the
    ``load_dataset`` / ``DataCollatorForLanguageModeling`` /
    ``TrainingArguments`` / ``Trainer`` imports.

    Args:
        epochs: Number of training epochs; converted with ``int()``.
        batch_size: Per-device train batch size; converted with ``int()``.
        learning_rate: Optimizer learning rate; converted with ``float()``
            so string/int values coming from UI widgets are accepted.

    Returns:
        A user-facing status string (Spanish): a success message naming the
        directory the model was saved to, or an error message. Exceptions
        are not propagated to the caller; the full traceback is logged.
    """
    import logging

    try:
        # Convert numeric inputs up front so bad values fail before any
        # expensive dataset/model work starts.
        num_epochs = int(epochs)
        train_batch_size = int(batch_size)
        lr = float(learning_rate)

        dataset = load_dataset("json", data_files=DATASET_PATH)

        # padding="max_length" needs a pad token; many causal-LM tokenizers
        # ship without one, so fall back to the EOS token in that case.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        def tokenize_fn(example):
            # Prompt and completion are concatenated as-is (no separator).
            return tokenizer(
                example["prompt"] + example["completion"],
                truncation=True,
                padding="max_length",
                max_length=256,
            )

        tokenized = dataset.map(tokenize_fn, batched=False)
        # Expose only the model inputs as torch tensors; with mlm=False the
        # collator builds `labels` from `input_ids` (pad positions ignored).
        tokenized.set_format(type="torch", columns=["input_ids", "attention_mask"])

        data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

        training_args = TrainingArguments(
            output_dir=LORA_PATH,
            per_device_train_batch_size=train_batch_size,
            num_train_epochs=num_epochs,
            learning_rate=lr,
            save_total_limit=1,
            logging_steps=10,
            push_to_hub=False,
        )

        trainer = Trainer(
            model=base_model,
            args=training_args,
            train_dataset=tokenized["train"],
            data_collator=data_collator,
        )
        trainer.train()

        base_model.save_pretrained(LORA_PATH)
        tokenizer.save_pretrained(LORA_PATH)
        # Report the directory actually written to, not a hard-coded path.
        return f"✅ Entrenamiento completado y guardado en {LORA_PATH}"
    except Exception as e:
        # UI boundary: return a status string instead of raising, but keep
        # the traceback in the logs so failures remain debuggable.
        logging.getLogger(__name__).exception("train_lora failed")
        return f"❌ Error durante el entrenamiento: {e}"