from datasets import load_dataset from transformers import Trainer, TrainingArguments, DataCollatorForLanguageModeling # Se asume que peft, tokenizer, base_model, etc., están definidos globalmente. def train_lora(epochs, batch_size, learning_rate, model_to_train, tokenizer, dataset_path, lora_path): """ Ejecuta el entrenamiento del modelo LoRA de forma eficiente. :param model_to_train: El modelo PEFT (LoRA) ya envuelto y listo para entrenar. :param tokenizer: El tokenizer cargado. :param dataset_path: Ruta al archivo JSON del dataset. :param lora_path: Ruta donde se guardarán los adaptadores LoRA. """ try: # 1. Carga del Dataset (Asegúrate de que 'tu_dataset.json' exista) print(f"🔄 Cargando dataset desde: {dataset_path}") dataset = load_dataset("json", data_files=dataset_path) # 2. Tokenización eficiente def tokenize_fn(example): return tokenizer( example["prompt"] + example["completion"], truncation=True, padding="max_length", max_length=256, ) # 🟢 MEJORA: batched=True para tokenización más rápida tokenized = dataset.map(tokenize_fn, batched=True, remove_columns=dataset["train"].column_names) # 3. Preparación final de los datos # No es estrictamente necesario si ya se usa DataCollator, pero es buena práctica. tokenized.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"]) # El DataCollatorForLanguageModeling se encarga de clonar 'input_ids' a 'labels' data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) # 4. Argumentos de Entrenamiento training_args = TrainingArguments( output_dir=lora_path, per_device_train_batch_size=int(batch_size), num_train_epochs=float(epochs), # 🟢 MEJORA: Usar float para aceptar épocas decimales learning_rate=float(learning_rate), # 🟢 MEJORA: Usar float save_total_limit=1, logging_steps=10, push_to_hub=False ) # 5. Inicialización y Entrenamiento del Trainer trainer = Trainer( # 🟢 CORRECCIÓN CRÍTICA: Debe usarse el modelo PEFT (LoRA) ya envuelto model=model_to_train, args=training_args, train_dataset=tokenized["train"], data_collator=data_collator, ) print("🚀 Iniciando entrenamiento...") trainer.train() # 6. Guardado Correcto de los Adaptadores # 🟢 CORRECCIÓN CRÍTICA: Guardar solo los adaptadores LoRA (peft) model_to_train.save_pretrained(lora_path) tokenizer.save_pretrained(lora_path) return f"✅ Entrenamiento completado. Adaptadores LoRA guardados en {lora_path}" except Exception as e: return f"❌ Error durante el entrenamiento: {e}"