# AmorCoderAI-Train / Train.py
# Author: Andro0s — "Create Train.py" (rev 8819d2a, 2.57 kB)
# ===============================
# AmorCoder AI - Entrenamiento LoRA Avanzado
# ===============================
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, TaskType
# -------------------------------
# 1️⃣ Base model
# -------------------------------
MODEL_NAME = "codellama/CodeLlama-7b-hf"

print("Cargando modelo base...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# LLaMA-family tokenizers ship without a pad token; any padding-based
# collation or batched tokenization fails without one. Reuse EOS as padding.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",          # shard across available GPUs / offload to CPU
    torch_dtype=torch.float16,  # half precision so the 7B model fits in memory
)
# -------------------------------
# 2️⃣ LoRA configuration
# -------------------------------
print("Aplicando LoRA...")

# Low-rank adapters on the attention query/value projections — the
# standard target choice for LLaMA-style causal LMs.
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=16,            # adapter rank
    lora_alpha=32,   # scaling factor (alpha / r = 2.0)
    lora_dropout=0.05,
    bias="none",
    target_modules=["q_proj", "v_proj"],
)
model = get_peft_model(model, lora_config)
# -------------------------------
# 3️⃣ Dataset
# -------------------------------
print("Cargando dataset...")
dataset = load_dataset("json", data_files={"train": "tu_dataset.json"}, split="train")

MAX_LEN = 512  # shared token budget for prompt + code


def preprocess(example):
    """Build fixed-length input_ids / attention_mask / labels for causal-LM training.

    The previous version tokenized prompt and code independently, returning
    `input_ids` (prompt only) and `labels` (code only) of different lengths —
    they never line up token-for-token, which a causal LM requires. Here the
    full sequence is `prompt + code + EOS`, and the prompt portion of the
    labels is masked with -100 so loss is computed only on the code tokens.
    """
    prompt = f"# Instrucción:\n{example['instruction']}\n\n# Código:\n"
    prompt_ids = tokenizer(prompt, truncation=True, max_length=MAX_LEN)["input_ids"]
    code_ids = tokenizer(
        example["code"], truncation=True, max_length=MAX_LEN, add_special_tokens=False
    )["input_ids"]

    input_ids = (prompt_ids + code_ids + [tokenizer.eos_token_id])[:MAX_LEN]
    prompt_len = min(len(prompt_ids), MAX_LEN)
    labels = [-100] * prompt_len + input_ids[prompt_len:]  # -100 is ignored by the loss

    # Right-pad to a fixed length so the default collator can stack tensors.
    pad = MAX_LEN - len(input_ids)
    attention_mask = [1] * len(input_ids) + [0] * pad
    input_ids = input_ids + [tokenizer.eos_token_id] * pad
    labels = labels + [-100] * pad
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}


# Drop the raw string columns so only tensor-ready fields reach the collator.
dataset = dataset.map(preprocess, remove_columns=dataset.column_names)
# -------------------------------
# 4️⃣ Training arguments
# -------------------------------
# Effective batch size = 1 × 4 via gradient accumulation.
training_args = TrainingArguments(
    output_dir="./lora_codellama",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    num_train_epochs=3,      # bump to 5 for extra accuracy if needed
    learning_rate=3e-4,
    fp16=True,               # matches the fp16 base model
    logging_steps=10,
    save_steps=50,
    save_total_limit=3,      # keep only the three most recent checkpoints
    report_to="none",        # no dependency on wandb or other trackers
)
# -------------------------------
# 5️⃣ Training
# -------------------------------
trainer = Trainer(
    args=training_args,
    model=model,
    train_dataset=dataset,
)

print("Entrenando LoRA...")
trainer.train()
# -------------------------------
# 6️⃣ Save weights
# -------------------------------
OUTPUT_DIR = "lora_codellama"
# Save the LoRA adapter AND the tokenizer: without the tokenizer the output
# directory is not self-contained and cannot be reloaded for inference alone.
model.save_pretrained(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)
print("✅ Entrenamiento completado. Pesos guardados en 'lora_codellama'.")