Spaces:
Runtime error
Runtime error
File size: 8,619 Bytes
7b02281 0462c81 2636834 8a66252 1527ee9 8a66252 a7bd5fa 0462c81 8a66252 0462c81 8a66252 a7bd5fa 8a66252 7b02281 3410ef1 8a66252 3410ef1 8a66252 0462c81 8a66252 1527ee9 413d0ff 8a66252 1527ee9 8a66252 413d0ff 8a66252 0462c81 8a66252 1527ee9 0462c81 1527ee9 413d0ff 8a66252 413d0ff 1527ee9 8a66252 0462c81 8a66252 0462c81 8a66252 0462c81 8a66252 413d0ff 8a66252 0462c81 1527ee9 0462c81 8a66252 0462c81 1527ee9 8a66252 1527ee9 8a66252 413d0ff 8a66252 413d0ff 8a66252 7b02281 0462c81 8a66252 0462c81 8a66252 0462c81 8a66252 0462c81 8a66252 0462c81 8a66252 0462c81 8a66252 0462c81 2636834 8a66252 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 |
import os
import gradio as gr
from huggingface_hub import login
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, DataCollatorForLanguageModeling, pipeline
from peft import get_peft_model, LoraConfig, TaskType, PeftModel
# ============================================================
# ⚙️ CONFIGURACIÓN GLOBAL
# ============================================================
BASE_MODEL = "bigcode/santacoder" # Modelo a refinar
LORA_PATH = "./lora_output" # Directorio para guardar los adaptadores
DATASET_PATH = "tu_dataset.json" # ¡Asegúrate de que este archivo exista!
# Variables globales inicializadas como None
tokenizer = None
lora_model = None
tokenized_dataset = None
lora_generator = None
# ============================================================
# 🔐 AUTENTICACIÓN Y PRE-CARGA
# ============================================================
def setup_resources():
"""Carga y configura todos los recursos (modelo, tokenizer, dataset) una sola vez."""
global tokenizer, lora_model, tokenized_dataset
# 1. Autenticación
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
login(token=hf_token)
else:
print("⚠️ Token no encontrado. La app intentará correr sin autenticación de escritura.")
# 2. Carga del Tokenizer y Modelo Base
print("\n🔄 Cargando modelo y tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
# Usa device_map="auto" para cargar el modelo de forma eficiente en la(s) GPU
base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, device_map="auto")
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# 3. Configuración y Aplicación LoRA (PEFT)
peft_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
r=8,
lora_alpha=32,
lora_dropout=0.1,
# 'c_proj' y 'c_attn' son comunes en modelos GPT/causales
target_modules=["c_proj", "c_attn"],
)
lora_model = get_peft_model(base_model, peft_config)
print(f"✅ Modelo LoRA preparado. Parámetros entrenables: {lora_model.print_trainable_parameters()}")
# 4. Carga y Tokenización del Dataset (para evitar errores de longitud)
print("📚 Cargando y tokenizando dataset...")
try:
raw_dataset = load_dataset("json", data_files=DATASET_PATH)
tokenized_dataset = raw_dataset.map(
lambda e: tokenizer(
e["prompt"] + e["completion"],
truncation=True,
padding="max_length",
max_length=256 # Mantener esta longitud consistente para evitar errores
),
batched=True,
remove_columns=raw_dataset["train"].column_names
)
print("✅ Dataset tokenizado correctamente.")
except Exception as e:
tokenized_dataset = None
print(f"❌ Error al cargar o tokenizar el dataset. El auto-entrenamiento fallará. {e}")
# ============================================================
# 🧩 FUNCIÓN DE ENTRENAMIENTO
# ============================================================
def train_lora(epochs=1, batch_size=2, learning_rate=5e-5):
"""Ejecuta el entrenamiento del modelo LoRA."""
global lora_model, tokenized_dataset, lora_generator
if tokenized_dataset is None or "train" not in tokenized_dataset:
return "❌ Error: El dataset no pudo cargarse o está vacío. No se puede entrenar."
try:
# Re-inicializa el generador a None para que se recargue después del entrenamiento
lora_generator = None
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
training_args = TrainingArguments(
output_dir=LORA_PATH,
per_device_train_batch_size=int(batch_size),
num_train_epochs=float(epochs),
learning_rate=float(learning_rate),
save_total_limit=1,
logging_steps=10,
push_to_hub=False,
# Desactiva la evaluación para simplificar el auto-entrenamiento
disable_tqdm=True
)
trainer = Trainer(
model=lora_model, # Usa el modelo LoRA global
args=training_args,
train_dataset=tokenized_dataset["train"],
data_collator=data_collator,
)
trainer.train()
# Guardar solo los adaptadores LoRA (PEFT)
lora_model.save_pretrained(LORA_PATH)
tokenizer.save_pretrained(LORA_PATH)
return "✅ Entrenamiento completado y adaptadores LoRA guardados en **./lora_output**"
except Exception as e:
return f"❌ Error durante el entrenamiento: {e}"
# ============================================================
# 🤖 FUNCIÓN DE GENERACIÓN (INFERENCIA)
# ============================================================
def generate_text(prompt_text):
"""Genera texto usando el modelo base + adaptadores LoRA."""
global lora_generator, lora_model
try:
# Carga el generador SOLO la primera vez o después del entrenamiento
if lora_generator is None:
# Cargar el modelo base limpio (sin los adaptadores LoRA)
base_model_gen = AutoModelForCausalLM.from_pretrained(BASE_MODEL, device_map="auto")
# Aplicar los adaptadores guardados
if os.path.exists(LORA_PATH):
model_with_lora = PeftModel.from_pretrained(base_model_gen, LORA_PATH)
else:
# Si no hay adaptadores entrenados, usa el modelo base inicial
model_with_lora = lora_model if lora_model else base_model_gen
# Fusionar el modelo base y los adaptadores para una inferencia más rápida
final_model = model_with_lora.merge_and_unload()
lora_generator = pipeline("text-generation", model=final_model, tokenizer=tokenizer)
output = lora_generator(prompt_text, max_new_tokens=100, temperature=0.7, top_p=0.9)
return output[0]["generated_text"]
except Exception as e:
return f"❌ Error generando texto (Asegúrate de que el modelo base y/o LoRA estén cargados): {e}"
# ============================================================
# 💻 INTERFAZ GRADIO
# ============================================================
with gr.Blocks(title="AmorCoderAI - LoRA") as demo:
gr.Markdown("# 💙 AmorCoderAI - Entrenamiento y Pruebas LoRA")
gr.Markdown(f"Modelo base: `{BASE_MODEL}`. Adaptadores guardados en `{LORA_PATH}`.")
with gr.Tab("🧠 Entrenar"):
gr.Markdown("--- **¡CUIDADO!** El entrenamiento es lento y consume muchos recursos. ---")
epochs = gr.Number(value=1, label="Épocas", precision=0)
batch_size = gr.Number(value=2, label="Tamaño de lote (ajusta según tu RAM/VRAM)", precision=0)
learning_rate = gr.Number(value=5e-5, label="Tasa de aprendizaje")
train_button = gr.Button("🚀 Iniciar entrenamiento Manual")
train_output = gr.Textbox(label="Resultado del Entrenamiento Manual")
train_button.click(
train_lora,
inputs=[epochs, batch_size, learning_rate],
outputs=train_output
)
with gr.Tab("✨ Probar modelo"):
prompt = gr.Textbox(label="Escribe código (ej: 'def bubble_sort(arr):')", lines=4)
generate_button = gr.Button("💬 Generar código")
output_box = gr.Textbox(label="Salida generada", lines=10)
generate_button.click(generate_text, inputs=prompt, outputs=output_box)
# ============================================================
# 🚀 LANZAR APP Y AUTO-ENTRENAMIENTO
# ============================================================
if __name__ == "__main__":
# 1. Cargar recursos
setup_resources()
# 2. AUTO-ENTRENAMIENTO (¡El código se 'autocorre' aquí!)
print("\n=============================================")
print("🤖 INICIANDO AUTO-ENTRENAMIENTO...")
print("=============================================")
# Parámetros de auto-entrenamiento: 1 época, batch size 2, LR 5e-5
auto_train_result = train_lora(epochs=1, batch_size=2, learning_rate=5e-5)
print(f"\nFIN DEL AUTO-ENTRENAMIENTO: {auto_train_result}")
# 3. Lanzar la Interfaz Gradio
print("\n=============================================")
print("💻 LANZANDO INTERFAZ GRADIO")
print("=============================================")
demo.launch()
|