Andro0s committed
Commit bc19ef1 · verified · 1 Parent(s): 3410ef1

Update Train.py

Files changed (1)
  1. Train.py +44 -81
Train.py CHANGED
@@ -1,81 +1,44 @@
- # ===============================
- # AmorCoder AI - Advanced LoRA Training
- # ===============================
- import torch
- from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
- from datasets import load_dataset
- from peft import LoraConfig, get_peft_model, TaskType
-
- # -------------------------------
- # 1️⃣ Base model
- # -------------------------------
- MODEL_NAME = "codellama/CodeLlama-7b-hf"
- print("Loading base model...")
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
- model = AutoModelForCausalLM.from_pretrained(
-     MODEL_NAME,
-     device_map="auto",
-     torch_dtype=torch.float16
- )
-
- # -------------------------------
- # 2️⃣ LoRA configuration
- # -------------------------------
- print("Applying LoRA...")
- lora_config = LoraConfig(
-     task_type=TaskType.CAUSAL_LM,
-     r=16,
-     lora_alpha=32,
-     target_modules=["q_proj", "v_proj"],  # modules commonly recommended for LLMs
-     lora_dropout=0.05,
-     bias="none"
- )
- model = get_peft_model(model, lora_config)
-
- # -------------------------------
- # 3️⃣ Dataset
- # -------------------------------
- print("Loading dataset...")
- dataset = load_dataset("json", data_files={"train": "tu_dataset.json"}, split="train")
-
- def preprocess(example):
-     prompt = f"# Instrucción:\n{example['instruction']}\n\n# Código:\n"
-     input_ids = tokenizer(prompt, truncation=True, max_length=512)["input_ids"]
-     labels = tokenizer(example['code'], truncation=True, max_length=512)["input_ids"]
-     return {"input_ids": input_ids, "labels": labels}
-
- dataset = dataset.map(preprocess)
-
- # -------------------------------
- # 4️⃣ Training arguments
- # -------------------------------
- training_args = TrainingArguments(
-     output_dir="./lora_codellama",
-     per_device_train_batch_size=1,  # use gradient accumulation to simulate larger batches
-     gradient_accumulation_steps=4,
-     num_train_epochs=3,  # can be raised to 5 for more accuracy
-     learning_rate=3e-4,
-     fp16=True,
-     logging_steps=10,
-     save_steps=50,
-     save_total_limit=3,
-     report_to="none",  # avoids depending on wandb or another tracker
- )
-
- # -------------------------------
- # 5️⃣ Training
- # -------------------------------
- trainer = Trainer(
-     model=model,
-     train_dataset=dataset,
-     args=training_args
- )
-
- print("Training LoRA...")
- trainer.train()
-
- # -------------------------------
- # 6️⃣ Save weights
- # -------------------------------
- model.save_pretrained("lora_codellama")
- print("✅ Training complete. Weights saved to 'lora_codellama'.")
 
+ def train_lora(epochs, batch_size, learning_rate):
+     try:
+         dataset = load_dataset("json", data_files=DATASET_PATH)
+
+         # Tokenize prompt and completion together as one sequence
+         def tokenize_fn(example):
+             return tokenizer(
+                 example["prompt"] + example["completion"],
+                 truncation=True,
+                 padding="max_length",
+                 max_length=256,
+             )
+
+         tokenized = dataset.map(tokenize_fn, batched=False)
+
+         # Make sure only the expected columns reach the model
+         tokenized.set_format(type="torch", columns=["input_ids", "attention_mask"])
+
+         data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
+
+         training_args = TrainingArguments(
+             output_dir=LORA_PATH,
+             per_device_train_batch_size=int(batch_size),
+             num_train_epochs=int(epochs),
+             learning_rate=learning_rate,
+             save_total_limit=1,
+             logging_steps=10,
+             push_to_hub=False
+         )
+
+         trainer = Trainer(
+             model=base_model,
+             args=training_args,
+             train_dataset=tokenized["train"],
+             data_collator=data_collator,
+         )
+
+         trainer.train()
+         base_model.save_pretrained(LORA_PATH)
+         tokenizer.save_pretrained(LORA_PATH)
+
+         return f"✅ Training complete and saved to {LORA_PATH}"
+     except Exception as e:
+         return f"❌ Error during training: {e}"
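Note: the rewritten train_lora() references module-level names that this hunk does not define (load_dataset, tokenizer, base_model, DataCollatorForLanguageModeling, TrainingArguments, Trainer, DATASET_PATH, LORA_PATH). Below is a minimal sketch of the setup the function appears to assume, reusing the imports and LoRA wiring from the removed version; the dataset path and LoRA hyperparameters are assumptions, not part of this commit, and LORA_PATH follows the ./lora_output named in the function's original success message. Without the get_peft_model wrapping, Trainer(model=base_model, ...) would fine-tune the full 7B model rather than LoRA adapters.

# Sketch of the module-level setup assumed by train_lora(); values marked
# "assumption" are illustrative and not taken from this commit.
import torch
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)
from peft import LoraConfig, get_peft_model, TaskType

MODEL_NAME = "codellama/CodeLlama-7b-hf"   # base model from the removed version
DATASET_PATH = "tu_dataset.json"           # assumption: JSON with "prompt"/"completion" fields
LORA_PATH = "./lora_output"                # matches the path in the original success message

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
    # CodeLlama ships without a pad token; padding="max_length" in tokenize_fn needs one.
    tokenizer.pad_token = tokenizer.eos_token

base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",
    torch_dtype=torch.float16,
)

# Re-attach LoRA adapters (hyperparameters taken from the removed version) so that
# only adapter weights are trained and written by save_pretrained(LORA_PATH).
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
)
base_model = get_peft_model(base_model, lora_config)

With that setup in place, a call like train_lora(epochs=3, batch_size=1, learning_rate=3e-4) mirrors the defaults of the removed script.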