# Spaces: Runtime error
| import os | |
| import gradio as gr | |
| from huggingface_hub import login | |
| from datasets import load_dataset | |
| from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, DataCollatorForLanguageModeling, pipeline | |
| from peft import get_peft_model, LoraConfig, TaskType, PeftModel | |
| import json | |
| # ============================================================ | |
| # ⚙️ CONFIGURACIÓN GLOBAL | |
| # ============================================================ | |
# Base checkpoint used for code generation.
BASE_MODEL = "bigcode/santacoder"
# Directory where the LoRA adapters (and tokenizer files) are written.
LORA_PATH = "./lora_output"
# File that stores the pre-processed prompt/completion dataset.
DATASET_FILE = "codesearchnet_lora_dataset.json"
# Every tokenized sequence is padded/truncated to this many tokens.
MAX_TOKEN_LENGTH = 256
# Number of CodeSearchNet rows requested when the dataset file is built.
NUM_SAMPLES_TO_PROCESS = 5000
# Epoch count pre-filled in the UI and used by the start-up training run.
DEFAULT_EPOCHS = 5

# Shared module state; all four start empty and are filled in later by
# setup_resources() (tokenizer, model, dataset) and generate_text() (pipeline).
tokenizer = lora_model = tokenized_dataset = lora_generator = None
| # ============================================================ | |
| # 🚨 LÓGICA DE PRE-PROCESAMIENTO DE DATOS (INTEGRADA) 🚨 | |
| # ============================================================ | |
def prepare_codesearchnet():
    """Download, reformat and cache the CodeSearchNet subset used for training.

    Does nothing when ``DATASET_FILE`` already exists. Otherwise the first
    ``NUM_SAMPLES_TO_PROCESS`` rows of the ``train`` split are mapped to
    ``{"prompt", "completion"}`` records and written to ``DATASET_FILE``.

    If the download or the processing fails, the error is printed and a tiny
    placeholder dataset is written instead, so the rest of the app can still
    start (best-effort behaviour kept from the original implementation).
    """
    if os.path.exists(DATASET_FILE):
        print(f"✅ Dataset '{DATASET_FILE}' ya existe.")
        return

    print(f"🔄 Descargando y procesando CodeSearchNet ({NUM_SAMPLES_TO_PROCESS} muestras)...")
    try:
        raw_csn = load_dataset(
            'Nan-Do/code-search-net-python',
            split=f'train[:{NUM_SAMPLES_TO_PROCESS}]',
        )

        def format_for_lora(example):
            """Turn one raw row into a prompt/completion pair."""
            # NOTE(review): the summary column name may differ between
            # CodeSearchNet mirrors -- the first non-empty candidate is used;
            # confirm the real schema of this dataset.
            summary = ""
            for column in ("docstring_summary", "summary", "docstring"):
                candidate = example.get(column)
                if candidate:
                    summary = candidate
                    break
            prompt_text = (
                f"# Descripción: {summary}\n"
                f"# Completa la siguiente función:\n"
                f"def {example['func_name']}("
            )
            return {
                "prompt": prompt_text,
                "completion": example['code'],
            }

        lora_dataset = raw_csn.map(
            format_for_lora,
            batched=False,
            # A `split=` load returns a single Dataset, not a DatasetDict, so
            # the column names live directly on it (indexing with "train"
            # would raise and force the placeholder fallback every time).
            remove_columns=raw_csn.column_names,
        )
        lora_dataset.to_json(DATASET_FILE)
        print(f"✅ Pre-procesamiento completado. {len(lora_dataset)} ejemplos guardados en '{DATASET_FILE}'.")
    except Exception as e:
        # Deliberate best-effort: keep the app bootable with a dummy dataset.
        print(f"❌ Error CRÍTICO al descargar/procesar CodeSearchNet. Error: {e}")
        minimal_dataset = [{"prompt": "# Error de carga. Intenta de nuevo.", "completion": "pass\n"}] * 10
        with open(DATASET_FILE, 'w', encoding="utf-8") as f:
            json.dump(minimal_dataset, f)
| # ============================================================ | |
| # 🔐 AUTENTICACIÓN Y PRE-CARGA DE RECURSOS (SINGLETON) | |
| # ============================================================ | |
def setup_resources():
    """Load every shared resource once and store it in the module globals.

    Steps, in order:
      1. Make sure the dataset file exists (``prepare_codesearchnet``).
      2. Log in to the Hugging Face Hub when ``HF_TOKEN`` is set.
      3. Load the tokenizer and base model, then wrap the model with LoRA.
      4. Load ``DATASET_FILE`` and tokenize it.

    Sets the globals ``tokenizer``, ``lora_model`` and ``tokenized_dataset``.
    If loading/tokenizing the dataset fails, ``tokenized_dataset`` is set to
    ``None`` and the error is printed; model-loading errors propagate.
    """
    global tokenizer, lora_model, tokenized_dataset

    prepare_codesearchnet()

    hf_token = os.environ.get("HF_TOKEN")
    if hf_token:
        login(token=hf_token)

    # 1. Tokenizer and base model
    print("\n🔄 Cargando modelo base y tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    # NOTE(review): some Hub checkpoints ship custom modelling code and only
    # load with trust_remote_code=True -- confirm whether BASE_MODEL needs it.
    base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, device_map="auto")
    if tokenizer.pad_token is None:
        # Padding to max_length needs a pad token; reuse EOS when none is set.
        tokenizer.pad_token = tokenizer.eos_token

    # 2. LoRA (PEFT) configuration
    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=8,
        lora_alpha=32,
        lora_dropout=0.1,
        target_modules=["c_proj", "c_attn"],
    )
    lora_model = get_peft_model(base_model, peft_config)
    print(f"✅ Modelo LoRA preparado. Parámetros entrenables listos.")

    # 3. Dataset loading and tokenization
    print(f"📚 Cargando y tokenizando dataset: {DATASET_FILE}...")
    try:
        raw_dataset = load_dataset("json", data_files=DATASET_FILE)

        def tokenize_function(examples):
            """Tokenize one batch of prompt+completion texts."""
            # With batched=True every column arrives as a list, so prompt and
            # completion must be joined pair-wise. `list + list` would instead
            # append the completions after the prompts as separate rows.
            texts = [
                prompt + completion
                for prompt, completion in zip(examples["prompt"], examples["completion"])
            ]
            return tokenizer(
                texts,
                truncation=True,
                padding="max_length",
                max_length=MAX_TOKEN_LENGTH,
            )

        tokenized_dataset = raw_dataset.map(
            tokenize_function,
            batched=True,
            remove_columns=raw_dataset["train"].column_names if "train" in raw_dataset else [],
        )
        print("✅ Dataset tokenizado correctamente.")
    except Exception as e:
        # train_lora() checks for None and reports the problem to the user.
        tokenized_dataset = None
        print(f"❌ Error al cargar o tokenizar el dataset. {e}")
| # ============================================================ | |
| # 🧩 FUNCIÓN DE ENTRENAMIENTO | |
| # ============================================================ | |
def train_lora(epochs, batch_size, learning_rate):
    """Fine-tune the LoRA model on the tokenized dataset.

    Args:
        epochs: Number of training epochs (converted with ``float``).
        batch_size: Per-device batch size (converted with ``int``).
        learning_rate: Optimizer learning rate (converted with ``float``).

    Returns:
        A status message: success (adapters and tokenizer saved under
        ``LORA_PATH``) or a description of what went wrong. Exceptions raised
        while training are reported in the message rather than propagated.
    """
    global lora_model, tokenized_dataset, lora_generator

    dataset_ready = tokenized_dataset is not None and "train" in tokenized_dataset
    if not dataset_ready:
        return "❌ Error: El dataset no pudo cargarse o está vacío. No se puede entrenar."

    try:
        # Drop the cached inference pipeline so the next generation call
        # rebuilds it from whatever adapters are on disk.
        lora_generator = None

        collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
        run_settings = dict(
            output_dir=LORA_PATH,
            per_device_train_batch_size=int(batch_size),
            num_train_epochs=float(epochs),
            learning_rate=float(learning_rate),
            save_total_limit=1,
            logging_steps=10,
            push_to_hub=False,
        )
        Trainer(
            model=lora_model,
            args=TrainingArguments(**run_settings),
            train_dataset=tokenized_dataset["train"],
            data_collator=collator,
        ).train()

        # Persist adapters and tokenizer side by side.
        lora_model.save_pretrained(LORA_PATH)
        tokenizer.save_pretrained(LORA_PATH)
    except Exception as e:
        return f"❌ Error durante el entrenamiento: {e}"
    else:
        return f"✅ Entrenamiento completado. Adaptadores LoRA guardados en **{LORA_PATH}**"
| # ============================================================ | |
| # 🤖 FUNCIÓN DE GENERACIÓN (INFERENCIA) | |
| # ============================================================ | |
def generate_text(prompt_text):
    """Generate code for ``prompt_text`` with the base model plus LoRA adapters.

    The text-generation pipeline is built on first use and cached in the
    global ``lora_generator``. Saved adapters under ``LORA_PATH`` are merged
    into the base model when present; otherwise the plain base model is used.

    Args:
        prompt_text: Prompt typed by the user.

    Returns:
        The generated text, a short warning when the prompt is empty, or an
        error message when loading/generation fails (exceptions are reported
        in the message rather than propagated).
    """
    global lora_generator

    # An empty prompt gives the model nothing to continue; answer right away.
    if not prompt_text or not prompt_text.strip():
        return "⚠️ Escribe un prompt antes de generar código."

    try:
        if lora_generator is None:
            base_model_gen = AutoModelForCausalLM.from_pretrained(BASE_MODEL, device_map="auto")

            # The output directory can exist with only Trainer checkpoints in
            # it, so look for the adapter config that save_pretrained() writes.
            adapter_config = os.path.join(LORA_PATH, "adapter_config.json")
            if os.path.exists(adapter_config):
                print("Cargando adaptadores LoRA...")
                model_with_lora = PeftModel.from_pretrained(base_model_gen, LORA_PATH)
                # merge_and_unload() exists only on PEFT-wrapped models.
                final_model = model_with_lora.merge_and_unload()
            else:
                print("No se encontraron adaptadores LoRA. Usando modelo base.")
                final_model = base_model_gen
            final_model.eval()

            # Fall back to a fresh tokenizer when setup_resources() has not
            # populated the shared one.
            gen_tokenizer = tokenizer
            if gen_tokenizer is None:
                gen_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

            lora_generator = pipeline("text-generation", model=final_model, tokenizer=gen_tokenizer)
            print("Modelo de inferencia listo.")

        # temperature/top_p only take effect when sampling is enabled.
        output = lora_generator(
            prompt_text,
            max_new_tokens=150,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
        )
        return output[0]["generated_text"]
    except Exception as e:
        return f"❌ Error generando texto (Asegúrate de que el modelo base/LoRA esté cargado): {e}"
| # ============================================================ | |
| # 💻 INTERFAZ GRADIO | |
| # ============================================================ | |
# Gradio UI: one tab to launch training by hand, one tab to try the model.
with gr.Blocks(title="AmorCoderAI - LoRA") as demo:
    header_text = "# 💙 AmorCoderAI - Entrenamiento y Pruebas LoRA"
    summary_text = f"Modelo base: `{BASE_MODEL}`. Usando **{NUM_SAMPLES_TO_PROCESS}** ejemplos de CodeSearchNet."
    gr.Markdown(header_text)
    gr.Markdown(summary_text)

    # --- Training tab: hyper-parameters feed straight into train_lora() ---
    with gr.Tab("🧠 Entrenar (Manual)"):
        warning_text = f"--- **¡CUIDADO!** El auto-entrenamiento usará {DEFAULT_EPOCHS} épocas para aprender la sintaxis. ---"
        gr.Markdown(warning_text)

        epochs = gr.Number(label="Épocas", value=DEFAULT_EPOCHS, precision=0)
        batch_size = gr.Number(
            label="Tamaño de lote (ajusta según tu VRAM)",
            value=2,
            precision=0,
        )
        learning_rate = gr.Number(label="Tasa de aprendizaje", value=5e-5)

        train_button = gr.Button("🚀 Iniciar Entrenamiento Manual")
        train_output = gr.Textbox(label="Resultado del Entrenamiento Manual")
        train_button.click(
            fn=train_lora,
            inputs=[epochs, batch_size, learning_rate],
            outputs=train_output,
        )

    # --- Inference tab: prompt in, generated text out ---
    with gr.Tab("✨ Probar modelo"):
        prompt = gr.Textbox(lines=4, label="Escribe código (ej: 'def fibonacci(n):')")
        generate_button = gr.Button("💬 Generar código")
        output_box = gr.Textbox(lines=10, label="Salida generada")
        generate_button.click(fn=generate_text, inputs=prompt, outputs=output_box)
| # ============================================================ | |
| # 🚀 LANZAR APP Y AUTO-ENTRENAMIENTO | |
| # ============================================================ | |
if __name__ == "__main__":
    # Divider line printed around each start-up stage.
    separator = "============================================="

    # Stage 1: load model, tokenizer and dataset into the module globals.
    setup_resources()

    # Stage 2: run one automatic training pass before the UI comes up.
    print("\n" + separator)
    print(f"🤖 INICIANDO AUTO-ENTRENAMIENTO ({DEFAULT_EPOCHS} Épocas, 2 Batch Size) usando {NUM_SAMPLES_TO_PROCESS} ejemplos")
    print(separator)
    auto_train_result = train_lora(
        epochs=DEFAULT_EPOCHS,
        batch_size=2,
        learning_rate=5e-5,
    )
    print(f"\nFIN DEL AUTO-ENTRENAMIENTO: {auto_train_result}")

    # Stage 3: serve the Gradio interface.
    print("\n" + separator)
    print("💻 LANZANDO INTERFAZ GRADIO")
    print(separator)
    demo.launch()