File size: 8,901 Bytes
7b02281
0462c81
2636834
8a66252
1527ee9
8a66252
a7bd5fa
0462c81
8a66252
0462c81
932f265
 
 
 
 
a7bd5fa
932f265
8a66252
 
 
 
7b02281
3410ef1
932f265
3410ef1
 
8a66252
 
 
 
932f265
8a66252
 
 
 
 
 
 
932f265
8a66252
 
 
 
 
 
 
 
 
 
 
 
932f265
8a66252
932f265
8a66252
 
 
932f265
 
0462c81
8a66252
932f265
8a66252
1527ee9
 
413d0ff
 
932f265
1527ee9
8a66252
 
413d0ff
8a66252
 
 
932f265
8a66252
0462c81
8a66252
 
 
932f265
8a66252
 
 
932f265
8a66252
 
 
 
932f265
8a66252
 
1527ee9
 
0462c81
1527ee9
413d0ff
8a66252
 
413d0ff
1527ee9
8a66252
0462c81
 
 
8a66252
0462c81
8a66252
0462c81
 
 
 
8a66252
932f265
8a66252
 
413d0ff
932f265
0462c81
1527ee9
0462c81
 
8a66252
0462c81
1527ee9
8a66252
932f265
1527ee9
8a66252
932f265
8a66252
932f265
8a66252
 
932f265
8a66252
 
 
932f265
 
8a66252
932f265
 
 
 
 
8a66252
 
932f265
 
413d0ff
8a66252
413d0ff
932f265
7b02281
0462c81
8a66252
0462c81
8a66252
 
932f265
0462c81
932f265
 
8a66252
932f265
0462c81
932f265
8a66252
932f265
8a66252
 
 
 
 
0462c81
 
932f265
8a66252
 
0462c81
 
 
932f265
0462c81
2636834
932f265
8a66252
 
932f265
8a66252
932f265
8a66252
 
932f265
8a66252
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
import os
import gradio as gr
from huggingface_hub import login
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, DataCollatorForLanguageModeling, pipeline
from peft import get_peft_model, LoraConfig, TaskType, PeftModel

# ============================================================
# ⚙️ CONFIGURACIÓN GLOBAL
# ============================================================
# Modelo base para generación de código
BASE_MODEL = "bigcode/santacoder"  
LORA_PATH = "./lora_output"        # Directorio para guardar los adaptadores LoRA
DATASET_PATH = "tu_dataset.json"   # ¡Cambia esto por el nombre de tu archivo JSON!
MAX_TOKEN_LENGTH = 256             # Longitud de secuencia uniforme (corrige errores de tamaño)

# Variables globales para acceso a recursos
tokenizer = None
lora_model = None
tokenized_dataset = None
lora_generator = None

# ============================================================
# 🔐 AUTENTICACIÓN Y PRE-CARGA DE RECURSOS (SINGLETON)
# ============================================================

def setup_resources():
    """Carga y configura todos los recursos (modelo, tokenizer, dataset) una sola vez."""
    global tokenizer, lora_model, tokenized_dataset

    # 1. Autenticación con Hugging Face
    hf_token = os.environ.get("HF_TOKEN")
    if hf_token:
        login(token=hf_token)
    else:
        print("⚠️ Token no encontrado. La app intentará correr sin autenticación de escritura.")

    # 2. Carga del Tokenizer y Modelo Base
    print("\n🔄 Cargando modelo y tokenizer una sola vez...")
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, device_map="auto") 

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # 3. Configuración y Aplicación LoRA (PEFT)
    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=8,
        lora_alpha=32,
        lora_dropout=0.1,
        target_modules=["c_proj", "c_attn"], # Adaptado para modelos causales tipo GPT/Santacoder
    )
    # Envuelve el modelo base con la configuración LoRA
    lora_model = get_peft_model(base_model, peft_config)
    print(f"✅ Modelo LoRA preparado. Parámetros entrenables: {lora_model.print_trainable_parameters()}")

    # 4. Carga y Tokenización del Dataset (para eficiencia)
    print("📚 Cargando y tokenizando dataset (esto solo se hace una vez)...")
    try:
        raw_dataset = load_dataset("json", data_files=DATASET_PATH)
        # Tokenización rápida con batched=True
        tokenized_dataset = raw_dataset.map(
            lambda e: tokenizer(
                e["prompt"] + e["completion"],
                truncation=True,
                padding="max_length",
                max_length=MAX_TOKEN_LENGTH
            ),
            batched=True,
            remove_columns=raw_dataset["train"].column_names
        )
        print("✅ Dataset tokenizado correctamente.")
    except Exception as e:
        tokenized_dataset = None
        print(f"❌ Error al cargar o tokenizar el dataset. Asegúrate que '{DATASET_PATH}' exista. {e}")


# ============================================================
# 🧩 FUNCIÓN DE ENTRENAMIENTO
# ============================================================
def train_lora(epochs, batch_size, learning_rate):
    """Ejecuta el entrenamiento del modelo LoRA."""
    global lora_model, tokenized_dataset, lora_generator

    # Verifica si el dataset está disponible
    if tokenized_dataset is None or "train" not in tokenized_dataset:
        return "❌ Error: El dataset no pudo cargarse o está vacío. No se puede entrenar."

    try:
        # Reinicia el generador para que cargue el modelo entrenado en el siguiente test
        lora_generator = None 
        
        data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

        training_args = TrainingArguments(
            output_dir=LORA_PATH,
            per_device_train_batch_size=int(batch_size),
            num_train_epochs=float(epochs),
            learning_rate=float(learning_rate),
            save_total_limit=1,
            logging_steps=10,
            push_to_hub=False,
        )

        trainer = Trainer(
            model=lora_model, # Usa el modelo LoRA global
            args=training_args,
            train_dataset=tokenized_dataset["train"],
            data_collator=data_collator,
        )

        trainer.train()
        
        # Guarda SOLAMENTE los adaptadores LoRA (archivos pequeños)
        lora_model.save_pretrained(LORA_PATH) 
        tokenizer.save_pretrained(LORA_PATH) 

        return f"✅ Entrenamiento completado. Adaptadores LoRA guardados en **{LORA_PATH}**"
    except Exception as e:
        return f"❌ Error durante el entrenamiento: {e}"

# ============================================================
# 🤖 FUNCIÓN DE GENERACIÓN (INFERENCIA)
# ============================================================
def generate_text(prompt_text):
    """Genera texto usando el modelo base + adaptadores LoRA."""
    global lora_generator

    try:
        # Carga el generador SOLAMENTE si no ha sido cargado aún
        if lora_generator is None:
            # 1. Cargar el modelo base
            base_model_gen = AutoModelForCausalLM.from_pretrained(BASE_MODEL, device_map="auto")
            
            # 2. Aplicar adaptadores LoRA si existen
            if os.path.exists(LORA_PATH):
                model_with_lora = PeftModel.from_pretrained(base_model_gen, LORA_PATH)
            else:
                # Si no hay LoRA, usa el modelo base (solo si está cargado globalmente)
                model_with_lora = base_model_gen

            # 3. Fusionar el modelo base y los adaptadores para una inferencia más rápida
            final_model = model_with_lora.merge_and_unload()
            final_model.eval() # Poner en modo evaluación
            
            # 4. Inicializar el pipeline global
            lora_generator = pipeline("text-generation", model=final_model, tokenizer=tokenizer)

        # Generar la respuesta
        output = lora_generator(prompt_text, max_new_tokens=150, temperature=0.7, top_p=0.9)
        return output[0]["generated_text"]
        
    except Exception as e:
        return f"❌ Error generando texto (Asegúrate de que el modelo base/LoRA esté cargado): {e}"

# ============================================================
# 💻 INTERFAZ GRADIO
# ============================================================
with gr.Blocks(title="AmorCoderAI - LoRA") as demo:
    gr.Markdown("# 💙 AmorCoderAI - Entrenamiento y Pruebas LoRA")
    gr.Markdown(f"Modelo base: `{BASE_MODEL}`. Adaptadores guardados en `{LORA_PATH}`. **El auto-entrenamiento se ejecuta al iniciar.**")

    with gr.Tab("🧠 Entrenar (Manual)"):
        gr.Markdown("--- **¡CUIDADO!** El entrenamiento es lento y consume muchos recursos (VRAM/RAM). ---")
        epochs = gr.Number(value=1, label="Épocas", precision=0)
        batch_size = gr.Number(value=2, label="Tamaño de lote (ajusta según tu VRAM)", precision=0)
        learning_rate = gr.Number(value=5e-5, label="Tasa de aprendizaje")
        train_button = gr.Button("🚀 Iniciar Entrenamiento Manual")
        train_output = gr.Textbox(label="Resultado del Entrenamiento Manual")
        
        train_button.click(
            train_lora, 
            inputs=[epochs, batch_size, learning_rate], 
            outputs=train_output
        )

    with gr.Tab("✨ Probar modelo"):
        prompt = gr.Textbox(label="Escribe código (ej: 'def fibonacci(n):')", lines=4)
        generate_button = gr.Button("💬 Generar código")
        output_box = gr.Textbox(label="Salida generada", lines=10)
        generate_button.click(generate_text, inputs=prompt, outputs=output_box)

# ============================================================
# 🚀 LANZAR APP Y AUTO-ENTRENAMIENTO (¡AQUÍ SUCEDE LA MAGIA!)
# ============================================================
if __name__ == "__main__":
    # 1. Cargar todos los recursos globales
    setup_resources()
    
    # 2. AUTO-ENTRENAMIENTO (Se ejecuta con valores por defecto)
    print("\n=============================================")
    print("🤖 INICIANDO AUTO-ENTRENAMIENTO (1 Época, 2 Batch Size)")
    print("=============================================")
    
    # Parámetros por defecto para el auto-entrenamiento
    auto_train_result = train_lora(epochs=1, batch_size=2, learning_rate=5e-5) 
    
    print(f"\nFIN DEL AUTO-ENTRENAMIENTO: {auto_train_result}")
    
    # 3. Lanzar la Interfaz Gradio
    print("\n=============================================")
    print("💻 LANZANDO INTERFAZ GRADIO")
    print("=============================================")
    demo.launch()