import os import gradio as gr from huggingface_hub import login from datasets import load_dataset, Dataset from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, DataCollatorForLanguageModeling, pipeline from peft import get_peft_model, LoraConfig, TaskType, PeftModel # ============================================================ # ⚙️ CONFIGURACIÓN GLOBAL # ============================================================ BASE_MODEL = "bigcode/santacoder" # Modelo a refinar LORA_PATH = "./lora_output" # Directorio para guardar los adaptadores DATASET_PATH = "tu_dataset.json" # ¡Asegúrate de que este archivo exista! # Variables globales inicializadas como None tokenizer = None lora_model = None tokenized_dataset = None lora_generator = None # ============================================================ # 🔐 AUTENTICACIÓN Y PRE-CARGA # ============================================================ def setup_resources(): """Carga y configura todos los recursos (modelo, tokenizer, dataset) una sola vez.""" global tokenizer, lora_model, tokenized_dataset # 1. Autenticación hf_token = os.environ.get("HF_TOKEN") if hf_token: login(token=hf_token) else: print("⚠️ Token no encontrado. La app intentará correr sin autenticación de escritura.") # 2. Carga del Tokenizer y Modelo Base print("\n🔄 Cargando modelo y tokenizer...") tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL) # Usa device_map="auto" para cargar el modelo de forma eficiente en la(s) GPU base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, device_map="auto") if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token # 3. Configuración y Aplicación LoRA (PEFT) peft_config = LoraConfig( task_type=TaskType.CAUSAL_LM, r=8, lora_alpha=32, lora_dropout=0.1, # 'c_proj' y 'c_attn' son comunes en modelos GPT/causales target_modules=["c_proj", "c_attn"], ) lora_model = get_peft_model(base_model, peft_config) print(f"✅ Modelo LoRA preparado. Parámetros entrenables: {lora_model.print_trainable_parameters()}") # 4. Carga y Tokenización del Dataset (para evitar errores de longitud) print("📚 Cargando y tokenizando dataset...") try: raw_dataset = load_dataset("json", data_files=DATASET_PATH) tokenized_dataset = raw_dataset.map( lambda e: tokenizer( e["prompt"] + e["completion"], truncation=True, padding="max_length", max_length=256 # Mantener esta longitud consistente para evitar errores ), batched=True, remove_columns=raw_dataset["train"].column_names ) print("✅ Dataset tokenizado correctamente.") except Exception as e: tokenized_dataset = None print(f"❌ Error al cargar o tokenizar el dataset. El auto-entrenamiento fallará. {e}") # ============================================================ # 🧩 FUNCIÓN DE ENTRENAMIENTO # ============================================================ def train_lora(epochs=1, batch_size=2, learning_rate=5e-5): """Ejecuta el entrenamiento del modelo LoRA.""" global lora_model, tokenized_dataset, lora_generator if tokenized_dataset is None or "train" not in tokenized_dataset: return "❌ Error: El dataset no pudo cargarse o está vacío. No se puede entrenar." try: # Re-inicializa el generador a None para que se recargue después del entrenamiento lora_generator = None data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) training_args = TrainingArguments( output_dir=LORA_PATH, per_device_train_batch_size=int(batch_size), num_train_epochs=float(epochs), learning_rate=float(learning_rate), save_total_limit=1, logging_steps=10, push_to_hub=False, # Desactiva la evaluación para simplificar el auto-entrenamiento disable_tqdm=True ) trainer = Trainer( model=lora_model, # Usa el modelo LoRA global args=training_args, train_dataset=tokenized_dataset["train"], data_collator=data_collator, ) trainer.train() # Guardar solo los adaptadores LoRA (PEFT) lora_model.save_pretrained(LORA_PATH) tokenizer.save_pretrained(LORA_PATH) return "✅ Entrenamiento completado y adaptadores LoRA guardados en **./lora_output**" except Exception as e: return f"❌ Error durante el entrenamiento: {e}" # ============================================================ # 🤖 FUNCIÓN DE GENERACIÓN (INFERENCIA) # ============================================================ def generate_text(prompt_text): """Genera texto usando el modelo base + adaptadores LoRA.""" global lora_generator, lora_model try: # Carga el generador SOLO la primera vez o después del entrenamiento if lora_generator is None: # Cargar el modelo base limpio (sin los adaptadores LoRA) base_model_gen = AutoModelForCausalLM.from_pretrained(BASE_MODEL, device_map="auto") # Aplicar los adaptadores guardados if os.path.exists(LORA_PATH): model_with_lora = PeftModel.from_pretrained(base_model_gen, LORA_PATH) else: # Si no hay adaptadores entrenados, usa el modelo base inicial model_with_lora = lora_model if lora_model else base_model_gen # Fusionar el modelo base y los adaptadores para una inferencia más rápida final_model = model_with_lora.merge_and_unload() lora_generator = pipeline("text-generation", model=final_model, tokenizer=tokenizer) output = lora_generator(prompt_text, max_new_tokens=100, temperature=0.7, top_p=0.9) return output[0]["generated_text"] except Exception as e: return f"❌ Error generando texto (Asegúrate de que el modelo base y/o LoRA estén cargados): {e}" # ============================================================ # 💻 INTERFAZ GRADIO # ============================================================ with gr.Blocks(title="AmorCoderAI - LoRA") as demo: gr.Markdown("# 💙 AmorCoderAI - Entrenamiento y Pruebas LoRA") gr.Markdown(f"Modelo base: `{BASE_MODEL}`. Adaptadores guardados en `{LORA_PATH}`.") with gr.Tab("🧠 Entrenar"): gr.Markdown("--- **¡CUIDADO!** El entrenamiento es lento y consume muchos recursos. ---") epochs = gr.Number(value=1, label="Épocas", precision=0) batch_size = gr.Number(value=2, label="Tamaño de lote (ajusta según tu RAM/VRAM)", precision=0) learning_rate = gr.Number(value=5e-5, label="Tasa de aprendizaje") train_button = gr.Button("🚀 Iniciar entrenamiento Manual") train_output = gr.Textbox(label="Resultado del Entrenamiento Manual") train_button.click( train_lora, inputs=[epochs, batch_size, learning_rate], outputs=train_output ) with gr.Tab("✨ Probar modelo"): prompt = gr.Textbox(label="Escribe código (ej: 'def bubble_sort(arr):')", lines=4) generate_button = gr.Button("💬 Generar código") output_box = gr.Textbox(label="Salida generada", lines=10) generate_button.click(generate_text, inputs=prompt, outputs=output_box) # ============================================================ # 🚀 LANZAR APP Y AUTO-ENTRENAMIENTO # ============================================================ if __name__ == "__main__": # 1. Cargar recursos setup_resources() # 2. AUTO-ENTRENAMIENTO (¡El código se 'autocorre' aquí!) print("\n=============================================") print("🤖 INICIANDO AUTO-ENTRENAMIENTO...") print("=============================================") # Parámetros de auto-entrenamiento: 1 época, batch size 2, LR 5e-5 auto_train_result = train_lora(epochs=1, batch_size=2, learning_rate=5e-5) print(f"\nFIN DEL AUTO-ENTRENAMIENTO: {auto_train_result}") # 3. Lanzar la Interfaz Gradio print("\n=============================================") print("💻 LANZANDO INTERFAZ GRADIO") print("=============================================") demo.launch()