Spaces:
Runtime error
Runtime error
| from datasets import load_dataset | |
| from transformers import Trainer, TrainingArguments, DataCollatorForLanguageModeling | |
| # Se asume que peft, tokenizer, base_model, etc., están definidos globalmente. | |
| def train_lora(epochs, batch_size, learning_rate, model_to_train, tokenizer, dataset_path, lora_path): | |
| """ | |
| Ejecuta el entrenamiento del modelo LoRA de forma eficiente. | |
| :param model_to_train: El modelo PEFT (LoRA) ya envuelto y listo para entrenar. | |
| :param tokenizer: El tokenizer cargado. | |
| :param dataset_path: Ruta al archivo JSON del dataset. | |
| :param lora_path: Ruta donde se guardarán los adaptadores LoRA. | |
| """ | |
| try: | |
| # 1. Carga del Dataset (Asegúrate de que 'tu_dataset.json' exista) | |
| print(f"🔄 Cargando dataset desde: {dataset_path}") | |
| dataset = load_dataset("json", data_files=dataset_path) | |
| # 2. Tokenización eficiente | |
| def tokenize_fn(example): | |
| return tokenizer( | |
| example["prompt"] + example["completion"], | |
| truncation=True, | |
| padding="max_length", | |
| max_length=256, | |
| ) | |
| # 🟢 MEJORA: batched=True para tokenización más rápida | |
| tokenized = dataset.map(tokenize_fn, batched=True, remove_columns=dataset["train"].column_names) | |
| # 3. Preparación final de los datos | |
| # No es estrictamente necesario si ya se usa DataCollator, pero es buena práctica. | |
| tokenized.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"]) | |
| # El DataCollatorForLanguageModeling se encarga de clonar 'input_ids' a 'labels' | |
| data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) | |
| # 4. Argumentos de Entrenamiento | |
| training_args = TrainingArguments( | |
| output_dir=lora_path, | |
| per_device_train_batch_size=int(batch_size), | |
| num_train_epochs=float(epochs), # 🟢 MEJORA: Usar float para aceptar épocas decimales | |
| learning_rate=float(learning_rate), # 🟢 MEJORA: Usar float | |
| save_total_limit=1, | |
| logging_steps=10, | |
| push_to_hub=False | |
| ) | |
| # 5. Inicialización y Entrenamiento del Trainer | |
| trainer = Trainer( | |
| # 🟢 CORRECCIÓN CRÍTICA: Debe usarse el modelo PEFT (LoRA) ya envuelto | |
| model=model_to_train, | |
| args=training_args, | |
| train_dataset=tokenized["train"], | |
| data_collator=data_collator, | |
| ) | |
| print("🚀 Iniciando entrenamiento...") | |
| trainer.train() | |
| # 6. Guardado Correcto de los Adaptadores | |
| # 🟢 CORRECCIÓN CRÍTICA: Guardar solo los adaptadores LoRA (peft) | |
| model_to_train.save_pretrained(lora_path) | |
| tokenizer.save_pretrained(lora_path) | |
| return f"✅ Entrenamiento completado. Adaptadores LoRA guardados en {lora_path}" | |
| except Exception as e: | |
| return f"❌ Error durante el entrenamiento: {e}" | |