AmorCoderAI-Train / Train.py
Andro0s's picture
Update Train.py
8eae154 verified
raw
history blame
3 kB
from datasets import load_dataset
from transformers import Trainer, TrainingArguments, DataCollatorForLanguageModeling
# Se asume que peft, tokenizer, base_model, etc., están definidos globalmente.
def train_lora(epochs, batch_size, learning_rate, model_to_train, tokenizer, dataset_path, lora_path):
"""
Ejecuta el entrenamiento del modelo LoRA de forma eficiente.
:param model_to_train: El modelo PEFT (LoRA) ya envuelto y listo para entrenar.
:param tokenizer: El tokenizer cargado.
:param dataset_path: Ruta al archivo JSON del dataset.
:param lora_path: Ruta donde se guardarán los adaptadores LoRA.
"""
try:
# 1. Carga del Dataset (Asegúrate de que 'tu_dataset.json' exista)
print(f"🔄 Cargando dataset desde: {dataset_path}")
dataset = load_dataset("json", data_files=dataset_path)
# 2. Tokenización eficiente
def tokenize_fn(example):
return tokenizer(
example["prompt"] + example["completion"],
truncation=True,
padding="max_length",
max_length=256,
)
# 🟢 MEJORA: batched=True para tokenización más rápida
tokenized = dataset.map(tokenize_fn, batched=True, remove_columns=dataset["train"].column_names)
# 3. Preparación final de los datos
# No es estrictamente necesario si ya se usa DataCollator, pero es buena práctica.
tokenized.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
# El DataCollatorForLanguageModeling se encarga de clonar 'input_ids' a 'labels'
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
# 4. Argumentos de Entrenamiento
training_args = TrainingArguments(
output_dir=lora_path,
per_device_train_batch_size=int(batch_size),
num_train_epochs=float(epochs), # 🟢 MEJORA: Usar float para aceptar épocas decimales
learning_rate=float(learning_rate), # 🟢 MEJORA: Usar float
save_total_limit=1,
logging_steps=10,
push_to_hub=False
)
# 5. Inicialización y Entrenamiento del Trainer
trainer = Trainer(
# 🟢 CORRECCIÓN CRÍTICA: Debe usarse el modelo PEFT (LoRA) ya envuelto
model=model_to_train,
args=training_args,
train_dataset=tokenized["train"],
data_collator=data_collator,
)
print("🚀 Iniciando entrenamiento...")
trainer.train()
# 6. Guardado Correcto de los Adaptadores
# 🟢 CORRECCIÓN CRÍTICA: Guardar solo los adaptadores LoRA (peft)
model_to_train.save_pretrained(lora_path)
tokenizer.save_pretrained(lora_path)
return f"✅ Entrenamiento completado. Adaptadores LoRA guardados en {lora_path}"
except Exception as e:
return f"❌ Error durante el entrenamiento: {e}"