Andro0s's picture
Create app.py
03b3c9a verified
raw
history blame
10.4 kB
import os
import gradio as gr
from huggingface_hub import login
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, DataCollatorForLanguageModeling, pipeline
from peft import get_peft_model, LoraConfig, TaskType, PeftModel
import json
# ============================================================
# ⚙️ CONFIGURACIÓN GLOBAL
# ============================================================
# Modelo base para generación de código
BASE_MODEL = "bigcode/santacoder" 
LORA_PATH = "./lora_output"        # Directorio para guardar los adaptadores LoRA
# Nombre del archivo donde se guardará el dataset procesado
DATASET_FILE = "codesearchnet_lora_dataset.json"  
MAX_TOKEN_LENGTH = 256             # Longitud de secuencia uniforme
NUM_SAMPLES_TO_PROCESS = 5000     
DEFAULT_EPOCHS = 5 # <--- ¡ENTRENAMIENTO PROFUNDO!
# Variables globales
tokenizer = None
lora_model = None
tokenized_dataset = None
lora_generator = None
# ============================================================
# 🚨 LÓGICA DE PRE-PROCESAMIENTO DE DATOS (INTEGRADA) 🚨
# ============================================================
def prepare_codesearchnet():
    """Descarga, procesa y guarda el dataset CodeSearchNet si no existe."""
    if os.path.exists(DATASET_FILE):
        print(f"✅ Dataset '{DATASET_FILE}' ya existe.")
        return
    print(f"🔄 Descargando y procesando CodeSearchNet ({NUM_SAMPLES_TO_PROCESS} muestras)...")
   
    try:
        raw_csn = load_dataset('Nan-Do/code-search-net-python', split=f'train[:{NUM_SAMPLES_TO_PROCESS}]')
        def format_for_lora(example):
            prompt_text = (
                f"# Descripción: {example['docstring_summary']}\n"
                f"# Completa la siguiente función:\n"
                f"def {example['func_name']}("
            )
            completion_text = example['code']
           
            return {
                "prompt": prompt_text,
                "completion": completion_text
            }
        lora_dataset = raw_csn.map(
            format_for_lora,
            batched=False,
            remove_columns=raw_csn["train"].column_names,
        )
        lora_dataset.to_json(DATASET_FILE)
        print(f"✅ Pre-procesamiento completado. {NUM_SAMPLES_TO_PROCESS} ejemplos guardados en '{DATASET_FILE}'.")
    except Exception as e:
        print(f"❌ Error CRÍTICO al descargar/procesar CodeSearchNet. Error: {e}")
        minimal_dataset = [{"prompt": "# Error de carga. Intenta de nuevo.", "completion": "pass\n"}] * 10
        with open(DATASET_FILE, 'w') as f:
            json.dump(minimal_dataset, f)
# ============================================================
# 🔐 AUTENTICACIÓN Y PRE-CARGA DE RECURSOS (SINGLETON)
# ============================================================
def setup_resources():
    """Carga y configura todos los recursos (modelo, tokenizer, dataset) una sola vez."""
    global tokenizer, lora_model, tokenized_dataset
    prepare_codesearchnet()
   
    hf_token = os.environ.get("HF_TOKEN")
    if hf_token:
        login(token=hf_token)
    # 1. Carga del Tokenizer y Modelo Base
    print("\n🔄 Cargando modelo base y tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, device_map="auto")
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # 2. Configuración y Aplicación LoRA (PEFT)
    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=8,
        lora_alpha=32,
        lora_dropout=0.1,
        target_modules=["c_proj", "c_attn"],
    )
    lora_model = get_peft_model(base_model, peft_config)
   
    print(f"✅ Modelo LoRA preparado. Parámetros entrenables listos.")
    # 3. Carga y Tokenización del Dataset
    print(f"📚 Cargando y tokenizando dataset: {DATASET_FILE}...")
    try:
        raw_dataset = load_dataset("json", data_files=DATASET_FILE)
       
        def tokenize_function(examples):
            return tokenizer(
                examples["prompt"] + examples["completion"],
                truncation=True,
                padding="max_length",
                max_length=MAX_TOKEN_LENGTH
            )
        tokenized_dataset = raw_dataset.map(
            tokenize_function,
            batched=True,
            remove_columns=raw_dataset["train"].column_names if "train" in raw_dataset else [],
        )
        print("✅ Dataset tokenizado correctamente.")
    except Exception as e:
        tokenized_dataset = None
        print(f"❌ Error al cargar o tokenizar el dataset. {e}")
# ============================================================
# 🧩 FUNCIÓN DE ENTRENAMIENTO
# ============================================================
def train_lora(epochs, batch_size, learning_rate):
    """Ejecuta el entrenamiento del modelo LoRA."""
    global lora_model, tokenized_dataset, lora_generator
    if tokenized_dataset is None or "train" not in tokenized_dataset:
        return f"❌ Error: El dataset no pudo cargarse o está vacío. No se puede entrenar."
    try:
        lora_generator = None
        data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
        training_args = TrainingArguments(
            output_dir=LORA_PATH,
            per_device_train_batch_size=int(batch_size),
            num_train_epochs=float(epochs),
            learning_rate=float(learning_rate),
            save_total_limit=1,
            logging_steps=10,
            push_to_hub=False,
        )
        trainer = Trainer(
            model=lora_model,
            args=training_args,
            train_dataset=tokenized_dataset["train"],
            data_collator=data_collator,
        )
        trainer.train()
       
        lora_model.save_pretrained(LORA_PATH)
        tokenizer.save_pretrained(LORA_PATH)
        return f"✅ Entrenamiento completado. Adaptadores LoRA guardados en **{LORA_PATH}**"
    except Exception as e:
        return f"❌ Error durante el entrenamiento: {e}"
# ============================================================
# 🤖 FUNCIÓN DE GENERACIÓN (INFERENCIA)
# ============================================================
def generate_text(prompt_text):
    """Genera texto usando el modelo base + adaptadores LoRA."""
    global lora_generator
    try:
        if lora_generator is None:
            base_model_gen = AutoModelForCausalLM.from_pretrained(BASE_MODEL, device_map="auto")
           
            if os.path.exists(LORA_PATH):
                print("Cargando adaptadores LoRA...")
                model_with_lora = PeftModel.from_pretrained(base_model_gen, LORA_PATH)
            else:
                print("No se encontraron adaptadores LoRA. Usando modelo base.")
                model_with_lora = base_model_gen
            final_model = model_with_lora.merge_and_unload()
            final_model.eval()
           
            lora_generator = pipeline("text-generation", model=final_model, tokenizer=tokenizer)
            print("Modelo de inferencia listo.")
        output = lora_generator(prompt_text, max_new_tokens=150, temperature=0.7, top_p=0.9)
        return output[0]["generated_text"]
       
    except Exception as e:
        return f"❌ Error generando texto (Asegúrate de que el modelo base/LoRA esté cargado): {e}"
# ============================================================
# 💻 INTERFAZ GRADIO
# ============================================================
with gr.Blocks(title="AmorCoderAI - LoRA") as demo:
    gr.Markdown("# 💙 AmorCoderAI - Entrenamiento y Pruebas LoRA")
    gr.Markdown(f"Modelo base: `{BASE_MODEL}`. Usando **{NUM_SAMPLES_TO_PROCESS}** ejemplos de CodeSearchNet.")
    with gr.Tab("🧠 Entrenar (Manual)"):
        gr.Markdown(f"--- **¡CUIDADO!** El auto-entrenamiento usará {DEFAULT_EPOCHS} épocas para aprender la sintaxis. ---")
        epochs = gr.Number(value=DEFAULT_EPOCHS, label="Épocas", precision=0)
        batch_size = gr.Number(value=2, label="Tamaño de lote (ajusta según tu VRAM)", precision=0)
        learning_rate = gr.Number(value=5e-5, label="Tasa de aprendizaje")
        train_button = gr.Button("🚀 Iniciar Entrenamiento Manual")
        train_output = gr.Textbox(label="Resultado del Entrenamiento Manual")
       
        train_button.click(
            train_lora,
            inputs=[epochs, batch_size, learning_rate],
            outputs=train_output
        )
    with gr.Tab("✨ Probar modelo"):
        prompt = gr.Textbox(label="Escribe código (ej: 'def fibonacci(n):')", lines=4)
        generate_button = gr.Button("💬 Generar código")
        output_box = gr.Textbox(label="Salida generada", lines=10)
        generate_button.click(generate_text, inputs=prompt, outputs=output_box)
# ============================================================
# 🚀 LANZAR APP Y AUTO-ENTRENAMIENTO
# ============================================================
if __name__ == "__main__":
    setup_resources()
   
    print("\n=============================================")
    print(f"🤖 INICIANDO AUTO-ENTRENAMIENTO ({DEFAULT_EPOCHS} Épocas, 2 Batch Size) usando {NUM_SAMPLES_TO_PROCESS} ejemplos")
    print("=============================================")
   
    auto_train_result = train_lora(epochs=DEFAULT_EPOCHS, batch_size=2, learning_rate=5e-5)
   
    print(f"\nFIN DEL AUTO-ENTRENAMIENTO: {auto_train_result}")
   
    print("\n=============================================")
    print("💻 LANZANDO INTERFAZ GRADIO")
    print("=============================================")
    demo.launch()