Andro0s's picture
Update app.py
932f265 verified
raw
history blame
8.9 kB
import os
import gradio as gr
from huggingface_hub import login
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, DataCollatorForLanguageModeling, pipeline
from peft import get_peft_model, LoraConfig, TaskType, PeftModel
# ============================================================
# ⚙️ CONFIGURACIÓN GLOBAL
# ============================================================
# Modelo base para generación de código
BASE_MODEL = "bigcode/santacoder"
LORA_PATH = "./lora_output" # Directorio para guardar los adaptadores LoRA
DATASET_PATH = "tu_dataset.json" # ¡Cambia esto por el nombre de tu archivo JSON!
MAX_TOKEN_LENGTH = 256 # Longitud de secuencia uniforme (corrige errores de tamaño)
# Variables globales para acceso a recursos
tokenizer = None
lora_model = None
tokenized_dataset = None
lora_generator = None
# ============================================================
# 🔐 AUTENTICACIÓN Y PRE-CARGA DE RECURSOS (SINGLETON)
# ============================================================
def setup_resources():
"""Carga y configura todos los recursos (modelo, tokenizer, dataset) una sola vez."""
global tokenizer, lora_model, tokenized_dataset
# 1. Autenticación con Hugging Face
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
login(token=hf_token)
else:
print("⚠️ Token no encontrado. La app intentará correr sin autenticación de escritura.")
# 2. Carga del Tokenizer y Modelo Base
print("\n🔄 Cargando modelo y tokenizer una sola vez...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, device_map="auto")
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# 3. Configuración y Aplicación LoRA (PEFT)
peft_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
r=8,
lora_alpha=32,
lora_dropout=0.1,
target_modules=["c_proj", "c_attn"], # Adaptado para modelos causales tipo GPT/Santacoder
)
# Envuelve el modelo base con la configuración LoRA
lora_model = get_peft_model(base_model, peft_config)
print(f"✅ Modelo LoRA preparado. Parámetros entrenables: {lora_model.print_trainable_parameters()}")
# 4. Carga y Tokenización del Dataset (para eficiencia)
print("📚 Cargando y tokenizando dataset (esto solo se hace una vez)...")
try:
raw_dataset = load_dataset("json", data_files=DATASET_PATH)
# Tokenización rápida con batched=True
tokenized_dataset = raw_dataset.map(
lambda e: tokenizer(
e["prompt"] + e["completion"],
truncation=True,
padding="max_length",
max_length=MAX_TOKEN_LENGTH
),
batched=True,
remove_columns=raw_dataset["train"].column_names
)
print("✅ Dataset tokenizado correctamente.")
except Exception as e:
tokenized_dataset = None
print(f"❌ Error al cargar o tokenizar el dataset. Asegúrate que '{DATASET_PATH}' exista. {e}")
# ============================================================
# 🧩 FUNCIÓN DE ENTRENAMIENTO
# ============================================================
def train_lora(epochs, batch_size, learning_rate):
"""Ejecuta el entrenamiento del modelo LoRA."""
global lora_model, tokenized_dataset, lora_generator
# Verifica si el dataset está disponible
if tokenized_dataset is None or "train" not in tokenized_dataset:
return "❌ Error: El dataset no pudo cargarse o está vacío. No se puede entrenar."
try:
# Reinicia el generador para que cargue el modelo entrenado en el siguiente test
lora_generator = None
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
training_args = TrainingArguments(
output_dir=LORA_PATH,
per_device_train_batch_size=int(batch_size),
num_train_epochs=float(epochs),
learning_rate=float(learning_rate),
save_total_limit=1,
logging_steps=10,
push_to_hub=False,
)
trainer = Trainer(
model=lora_model, # Usa el modelo LoRA global
args=training_args,
train_dataset=tokenized_dataset["train"],
data_collator=data_collator,
)
trainer.train()
# Guarda SOLAMENTE los adaptadores LoRA (archivos pequeños)
lora_model.save_pretrained(LORA_PATH)
tokenizer.save_pretrained(LORA_PATH)
return f"✅ Entrenamiento completado. Adaptadores LoRA guardados en **{LORA_PATH}**"
except Exception as e:
return f"❌ Error durante el entrenamiento: {e}"
# ============================================================
# 🤖 FUNCIÓN DE GENERACIÓN (INFERENCIA)
# ============================================================
def generate_text(prompt_text):
"""Genera texto usando el modelo base + adaptadores LoRA."""
global lora_generator
try:
# Carga el generador SOLAMENTE si no ha sido cargado aún
if lora_generator is None:
# 1. Cargar el modelo base
base_model_gen = AutoModelForCausalLM.from_pretrained(BASE_MODEL, device_map="auto")
# 2. Aplicar adaptadores LoRA si existen
if os.path.exists(LORA_PATH):
model_with_lora = PeftModel.from_pretrained(base_model_gen, LORA_PATH)
else:
# Si no hay LoRA, usa el modelo base (solo si está cargado globalmente)
model_with_lora = base_model_gen
# 3. Fusionar el modelo base y los adaptadores para una inferencia más rápida
final_model = model_with_lora.merge_and_unload()
final_model.eval() # Poner en modo evaluación
# 4. Inicializar el pipeline global
lora_generator = pipeline("text-generation", model=final_model, tokenizer=tokenizer)
# Generar la respuesta
output = lora_generator(prompt_text, max_new_tokens=150, temperature=0.7, top_p=0.9)
return output[0]["generated_text"]
except Exception as e:
return f"❌ Error generando texto (Asegúrate de que el modelo base/LoRA esté cargado): {e}"
# ============================================================
# 💻 INTERFAZ GRADIO
# ============================================================
with gr.Blocks(title="AmorCoderAI - LoRA") as demo:
gr.Markdown("# 💙 AmorCoderAI - Entrenamiento y Pruebas LoRA")
gr.Markdown(f"Modelo base: `{BASE_MODEL}`. Adaptadores guardados en `{LORA_PATH}`. **El auto-entrenamiento se ejecuta al iniciar.**")
with gr.Tab("🧠 Entrenar (Manual)"):
gr.Markdown("--- **¡CUIDADO!** El entrenamiento es lento y consume muchos recursos (VRAM/RAM). ---")
epochs = gr.Number(value=1, label="Épocas", precision=0)
batch_size = gr.Number(value=2, label="Tamaño de lote (ajusta según tu VRAM)", precision=0)
learning_rate = gr.Number(value=5e-5, label="Tasa de aprendizaje")
train_button = gr.Button("🚀 Iniciar Entrenamiento Manual")
train_output = gr.Textbox(label="Resultado del Entrenamiento Manual")
train_button.click(
train_lora,
inputs=[epochs, batch_size, learning_rate],
outputs=train_output
)
with gr.Tab("✨ Probar modelo"):
prompt = gr.Textbox(label="Escribe código (ej: 'def fibonacci(n):')", lines=4)
generate_button = gr.Button("💬 Generar código")
output_box = gr.Textbox(label="Salida generada", lines=10)
generate_button.click(generate_text, inputs=prompt, outputs=output_box)
# ============================================================
# 🚀 LANZAR APP Y AUTO-ENTRENAMIENTO (¡AQUÍ SUCEDE LA MAGIA!)
# ============================================================
if __name__ == "__main__":
# 1. Cargar todos los recursos globales
setup_resources()
# 2. AUTO-ENTRENAMIENTO (Se ejecuta con valores por defecto)
print("\n=============================================")
print("🤖 INICIANDO AUTO-ENTRENAMIENTO (1 Época, 2 Batch Size)")
print("=============================================")
# Parámetros por defecto para el auto-entrenamiento
auto_train_result = train_lora(epochs=1, batch_size=2, learning_rate=5e-5)
print(f"\nFIN DEL AUTO-ENTRENAMIENTO: {auto_train_result}")
# 3. Lanzar la Interfaz Gradio
print("\n=============================================")
print("💻 LANZANDO INTERFAZ GRADIO")
print("=============================================")
demo.launch()