File size: 8,619 Bytes
7b02281
0462c81
2636834
8a66252
1527ee9
8a66252
a7bd5fa
0462c81
8a66252
0462c81
8a66252
 
 
a7bd5fa
8a66252
 
 
 
 
7b02281
3410ef1
8a66252
3410ef1
 
8a66252
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0462c81
8a66252
 
1527ee9
 
413d0ff
 
8a66252
1527ee9
8a66252
 
413d0ff
8a66252
 
 
 
 
0462c81
8a66252
 
 
 
 
 
 
 
 
 
 
 
 
 
1527ee9
 
0462c81
1527ee9
413d0ff
8a66252
 
413d0ff
1527ee9
8a66252
 
 
0462c81
 
 
8a66252
0462c81
8a66252
0462c81
 
 
 
8a66252
 
 
 
413d0ff
8a66252
0462c81
1527ee9
0462c81
 
8a66252
0462c81
1527ee9
8a66252
 
1527ee9
8a66252
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
413d0ff
8a66252
413d0ff
8a66252
7b02281
0462c81
8a66252
0462c81
8a66252
 
 
0462c81
 
8a66252
 
 
0462c81
8a66252
 
 
 
 
 
 
0462c81
 
8a66252
 
 
0462c81
 
 
8a66252
0462c81
2636834
8a66252
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
import os
import gradio as gr
from huggingface_hub import login
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, DataCollatorForLanguageModeling, pipeline
from peft import get_peft_model, LoraConfig, TaskType, PeftModel

# ============================================================
# ⚙️ CONFIGURACIÓN GLOBAL
# ============================================================
BASE_MODEL = "bigcode/santacoder"  # Modelo a refinar
LORA_PATH = "./lora_output"        # Directorio para guardar los adaptadores
DATASET_PATH = "tu_dataset.json"   # ¡Asegúrate de que este archivo exista!

# Variables globales inicializadas como None
tokenizer = None
lora_model = None
tokenized_dataset = None
lora_generator = None

# ============================================================
# 🔐 AUTENTICACIÓN Y PRE-CARGA
# ============================================================

def setup_resources():
    """Carga y configura todos los recursos (modelo, tokenizer, dataset) una sola vez."""
    global tokenizer, lora_model, tokenized_dataset

    # 1. Autenticación
    hf_token = os.environ.get("HF_TOKEN")
    if hf_token:
        login(token=hf_token)
    else:
        print("⚠️ Token no encontrado. La app intentará correr sin autenticación de escritura.")

    # 2. Carga del Tokenizer y Modelo Base
    print("\n🔄 Cargando modelo y tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    # Usa device_map="auto" para cargar el modelo de forma eficiente en la(s) GPU
    base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, device_map="auto") 

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # 3. Configuración y Aplicación LoRA (PEFT)
    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=8,
        lora_alpha=32,
        lora_dropout=0.1,
        # 'c_proj' y 'c_attn' son comunes en modelos GPT/causales
        target_modules=["c_proj", "c_attn"], 
    )
    lora_model = get_peft_model(base_model, peft_config)
    print(f"✅ Modelo LoRA preparado. Parámetros entrenables: {lora_model.print_trainable_parameters()}")

    # 4. Carga y Tokenización del Dataset (para evitar errores de longitud)
    print("📚 Cargando y tokenizando dataset...")
    try:
        raw_dataset = load_dataset("json", data_files=DATASET_PATH)
        tokenized_dataset = raw_dataset.map(
            lambda e: tokenizer(
                e["prompt"] + e["completion"],
                truncation=True,
                padding="max_length",
                max_length=256 # Mantener esta longitud consistente para evitar errores
            ),
            batched=True,
            remove_columns=raw_dataset["train"].column_names
        )
        print("✅ Dataset tokenizado correctamente.")
    except Exception as e:
        tokenized_dataset = None
        print(f"❌ Error al cargar o tokenizar el dataset. El auto-entrenamiento fallará. {e}")


# ============================================================
# 🧩 FUNCIÓN DE ENTRENAMIENTO
# ============================================================
def train_lora(epochs=1, batch_size=2, learning_rate=5e-5):
    """Ejecuta el entrenamiento del modelo LoRA."""
    global lora_model, tokenized_dataset, lora_generator

    if tokenized_dataset is None or "train" not in tokenized_dataset:
        return "❌ Error: El dataset no pudo cargarse o está vacío. No se puede entrenar."

    try:
        # Re-inicializa el generador a None para que se recargue después del entrenamiento
        lora_generator = None 
        
        data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

        training_args = TrainingArguments(
            output_dir=LORA_PATH,
            per_device_train_batch_size=int(batch_size),
            num_train_epochs=float(epochs),
            learning_rate=float(learning_rate),
            save_total_limit=1,
            logging_steps=10,
            push_to_hub=False,
            # Desactiva la evaluación para simplificar el auto-entrenamiento
            disable_tqdm=True 
        )

        trainer = Trainer(
            model=lora_model, # Usa el modelo LoRA global
            args=training_args,
            train_dataset=tokenized_dataset["train"],
            data_collator=data_collator,
        )

        trainer.train()
        
        # Guardar solo los adaptadores LoRA (PEFT)
        lora_model.save_pretrained(LORA_PATH) 
        tokenizer.save_pretrained(LORA_PATH) 

        return "✅ Entrenamiento completado y adaptadores LoRA guardados en **./lora_output**"
    except Exception as e:
        return f"❌ Error durante el entrenamiento: {e}"

# ============================================================
# 🤖 FUNCIÓN DE GENERACIÓN (INFERENCIA)
# ============================================================
def generate_text(prompt_text):
    """Genera texto usando el modelo base + adaptadores LoRA."""
    global lora_generator, lora_model

    try:
        # Carga el generador SOLO la primera vez o después del entrenamiento
        if lora_generator is None:
            
            # Cargar el modelo base limpio (sin los adaptadores LoRA)
            base_model_gen = AutoModelForCausalLM.from_pretrained(BASE_MODEL, device_map="auto")
            
            # Aplicar los adaptadores guardados
            if os.path.exists(LORA_PATH):
                model_with_lora = PeftModel.from_pretrained(base_model_gen, LORA_PATH)
            else:
                # Si no hay adaptadores entrenados, usa el modelo base inicial
                model_with_lora = lora_model if lora_model else base_model_gen
                
            # Fusionar el modelo base y los adaptadores para una inferencia más rápida
            final_model = model_with_lora.merge_and_unload()

            lora_generator = pipeline("text-generation", model=final_model, tokenizer=tokenizer)

        output = lora_generator(prompt_text, max_new_tokens=100, temperature=0.7, top_p=0.9)
        return output[0]["generated_text"]
        
    except Exception as e:
        return f"❌ Error generando texto (Asegúrate de que el modelo base y/o LoRA estén cargados): {e}"

# ============================================================
# 💻 INTERFAZ GRADIO
# ============================================================
with gr.Blocks(title="AmorCoderAI - LoRA") as demo:
    gr.Markdown("# 💙 AmorCoderAI - Entrenamiento y Pruebas LoRA")
    gr.Markdown(f"Modelo base: `{BASE_MODEL}`. Adaptadores guardados en `{LORA_PATH}`.")

    with gr.Tab("🧠 Entrenar"):
        gr.Markdown("--- **¡CUIDADO!** El entrenamiento es lento y consume muchos recursos. ---")
        epochs = gr.Number(value=1, label="Épocas", precision=0)
        batch_size = gr.Number(value=2, label="Tamaño de lote (ajusta según tu RAM/VRAM)", precision=0)
        learning_rate = gr.Number(value=5e-5, label="Tasa de aprendizaje")
        train_button = gr.Button("🚀 Iniciar entrenamiento Manual")
        train_output = gr.Textbox(label="Resultado del Entrenamiento Manual")
        train_button.click(
            train_lora, 
            inputs=[epochs, batch_size, learning_rate], 
            outputs=train_output
        )

    with gr.Tab("✨ Probar modelo"):
        prompt = gr.Textbox(label="Escribe código (ej: 'def bubble_sort(arr):')", lines=4)
        generate_button = gr.Button("💬 Generar código")
        output_box = gr.Textbox(label="Salida generada", lines=10)
        generate_button.click(generate_text, inputs=prompt, outputs=output_box)

# ============================================================
# 🚀 LANZAR APP Y AUTO-ENTRENAMIENTO
# ============================================================
if __name__ == "__main__":
    # 1. Cargar recursos
    setup_resources()
    
    # 2. AUTO-ENTRENAMIENTO (¡El código se 'autocorre' aquí!)
    print("\n=============================================")
    print("🤖 INICIANDO AUTO-ENTRENAMIENTO...")
    print("=============================================")
    
    # Parámetros de auto-entrenamiento: 1 época, batch size 2, LR 5e-5
    auto_train_result = train_lora(epochs=1, batch_size=2, learning_rate=5e-5) 
    
    print(f"\nFIN DEL AUTO-ENTRENAMIENTO: {auto_train_result}")
    
    # 3. Lanzar la Interfaz Gradio
    print("\n=============================================")
    print("💻 LANZANDO INTERFAZ GRADIO")
    print("=============================================")
    demo.launch()