Spaces:
Runtime error
Runtime error
File size: 8,901 Bytes
7b02281 0462c81 2636834 8a66252 1527ee9 8a66252 a7bd5fa 0462c81 8a66252 0462c81 932f265 a7bd5fa 932f265 8a66252 7b02281 3410ef1 932f265 3410ef1 8a66252 932f265 8a66252 932f265 8a66252 932f265 8a66252 932f265 8a66252 932f265 0462c81 8a66252 932f265 8a66252 1527ee9 413d0ff 932f265 1527ee9 8a66252 413d0ff 8a66252 932f265 8a66252 0462c81 8a66252 932f265 8a66252 932f265 8a66252 932f265 8a66252 1527ee9 0462c81 1527ee9 413d0ff 8a66252 413d0ff 1527ee9 8a66252 0462c81 8a66252 0462c81 8a66252 0462c81 8a66252 932f265 8a66252 413d0ff 932f265 0462c81 1527ee9 0462c81 8a66252 0462c81 1527ee9 8a66252 932f265 1527ee9 8a66252 932f265 8a66252 932f265 8a66252 932f265 8a66252 932f265 8a66252 932f265 8a66252 932f265 413d0ff 8a66252 413d0ff 932f265 7b02281 0462c81 8a66252 0462c81 8a66252 932f265 0462c81 932f265 8a66252 932f265 0462c81 932f265 8a66252 932f265 8a66252 0462c81 932f265 8a66252 0462c81 932f265 0462c81 2636834 932f265 8a66252 932f265 8a66252 932f265 8a66252 932f265 8a66252 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 |
import os
import gradio as gr
from huggingface_hub import login
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, DataCollatorForLanguageModeling, pipeline
from peft import get_peft_model, LoraConfig, TaskType, PeftModel
# ============================================================
# ⚙️ CONFIGURACIÓN GLOBAL
# ============================================================
# Modelo base para generación de código
BASE_MODEL = "bigcode/santacoder"
LORA_PATH = "./lora_output" # Directorio para guardar los adaptadores LoRA
DATASET_PATH = "tu_dataset.json" # ¡Cambia esto por el nombre de tu archivo JSON!
MAX_TOKEN_LENGTH = 256 # Longitud de secuencia uniforme (corrige errores de tamaño)
# Variables globales para acceso a recursos
tokenizer = None
lora_model = None
tokenized_dataset = None
lora_generator = None
# ============================================================
# 🔐 AUTENTICACIÓN Y PRE-CARGA DE RECURSOS (SINGLETON)
# ============================================================
def setup_resources():
"""Carga y configura todos los recursos (modelo, tokenizer, dataset) una sola vez."""
global tokenizer, lora_model, tokenized_dataset
# 1. Autenticación con Hugging Face
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
login(token=hf_token)
else:
print("⚠️ Token no encontrado. La app intentará correr sin autenticación de escritura.")
# 2. Carga del Tokenizer y Modelo Base
print("\n🔄 Cargando modelo y tokenizer una sola vez...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, device_map="auto")
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# 3. Configuración y Aplicación LoRA (PEFT)
peft_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
r=8,
lora_alpha=32,
lora_dropout=0.1,
target_modules=["c_proj", "c_attn"], # Adaptado para modelos causales tipo GPT/Santacoder
)
# Envuelve el modelo base con la configuración LoRA
lora_model = get_peft_model(base_model, peft_config)
print(f"✅ Modelo LoRA preparado. Parámetros entrenables: {lora_model.print_trainable_parameters()}")
# 4. Carga y Tokenización del Dataset (para eficiencia)
print("📚 Cargando y tokenizando dataset (esto solo se hace una vez)...")
try:
raw_dataset = load_dataset("json", data_files=DATASET_PATH)
# Tokenización rápida con batched=True
tokenized_dataset = raw_dataset.map(
lambda e: tokenizer(
e["prompt"] + e["completion"],
truncation=True,
padding="max_length",
max_length=MAX_TOKEN_LENGTH
),
batched=True,
remove_columns=raw_dataset["train"].column_names
)
print("✅ Dataset tokenizado correctamente.")
except Exception as e:
tokenized_dataset = None
print(f"❌ Error al cargar o tokenizar el dataset. Asegúrate que '{DATASET_PATH}' exista. {e}")
# ============================================================
# 🧩 FUNCIÓN DE ENTRENAMIENTO
# ============================================================
def train_lora(epochs, batch_size, learning_rate):
"""Ejecuta el entrenamiento del modelo LoRA."""
global lora_model, tokenized_dataset, lora_generator
# Verifica si el dataset está disponible
if tokenized_dataset is None or "train" not in tokenized_dataset:
return "❌ Error: El dataset no pudo cargarse o está vacío. No se puede entrenar."
try:
# Reinicia el generador para que cargue el modelo entrenado en el siguiente test
lora_generator = None
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
training_args = TrainingArguments(
output_dir=LORA_PATH,
per_device_train_batch_size=int(batch_size),
num_train_epochs=float(epochs),
learning_rate=float(learning_rate),
save_total_limit=1,
logging_steps=10,
push_to_hub=False,
)
trainer = Trainer(
model=lora_model, # Usa el modelo LoRA global
args=training_args,
train_dataset=tokenized_dataset["train"],
data_collator=data_collator,
)
trainer.train()
# Guarda SOLAMENTE los adaptadores LoRA (archivos pequeños)
lora_model.save_pretrained(LORA_PATH)
tokenizer.save_pretrained(LORA_PATH)
return f"✅ Entrenamiento completado. Adaptadores LoRA guardados en **{LORA_PATH}**"
except Exception as e:
return f"❌ Error durante el entrenamiento: {e}"
# ============================================================
# 🤖 FUNCIÓN DE GENERACIÓN (INFERENCIA)
# ============================================================
def generate_text(prompt_text):
"""Genera texto usando el modelo base + adaptadores LoRA."""
global lora_generator
try:
# Carga el generador SOLAMENTE si no ha sido cargado aún
if lora_generator is None:
# 1. Cargar el modelo base
base_model_gen = AutoModelForCausalLM.from_pretrained(BASE_MODEL, device_map="auto")
# 2. Aplicar adaptadores LoRA si existen
if os.path.exists(LORA_PATH):
model_with_lora = PeftModel.from_pretrained(base_model_gen, LORA_PATH)
else:
# Si no hay LoRA, usa el modelo base (solo si está cargado globalmente)
model_with_lora = base_model_gen
# 3. Fusionar el modelo base y los adaptadores para una inferencia más rápida
final_model = model_with_lora.merge_and_unload()
final_model.eval() # Poner en modo evaluación
# 4. Inicializar el pipeline global
lora_generator = pipeline("text-generation", model=final_model, tokenizer=tokenizer)
# Generar la respuesta
output = lora_generator(prompt_text, max_new_tokens=150, temperature=0.7, top_p=0.9)
return output[0]["generated_text"]
except Exception as e:
return f"❌ Error generando texto (Asegúrate de que el modelo base/LoRA esté cargado): {e}"
# ============================================================
# 💻 INTERFAZ GRADIO
# ============================================================
with gr.Blocks(title="AmorCoderAI - LoRA") as demo:
gr.Markdown("# 💙 AmorCoderAI - Entrenamiento y Pruebas LoRA")
gr.Markdown(f"Modelo base: `{BASE_MODEL}`. Adaptadores guardados en `{LORA_PATH}`. **El auto-entrenamiento se ejecuta al iniciar.**")
with gr.Tab("🧠 Entrenar (Manual)"):
gr.Markdown("--- **¡CUIDADO!** El entrenamiento es lento y consume muchos recursos (VRAM/RAM). ---")
epochs = gr.Number(value=1, label="Épocas", precision=0)
batch_size = gr.Number(value=2, label="Tamaño de lote (ajusta según tu VRAM)", precision=0)
learning_rate = gr.Number(value=5e-5, label="Tasa de aprendizaje")
train_button = gr.Button("🚀 Iniciar Entrenamiento Manual")
train_output = gr.Textbox(label="Resultado del Entrenamiento Manual")
train_button.click(
train_lora,
inputs=[epochs, batch_size, learning_rate],
outputs=train_output
)
with gr.Tab("✨ Probar modelo"):
prompt = gr.Textbox(label="Escribe código (ej: 'def fibonacci(n):')", lines=4)
generate_button = gr.Button("💬 Generar código")
output_box = gr.Textbox(label="Salida generada", lines=10)
generate_button.click(generate_text, inputs=prompt, outputs=output_box)
# ============================================================
# 🚀 LANZAR APP Y AUTO-ENTRENAMIENTO (¡AQUÍ SUCEDE LA MAGIA!)
# ============================================================
if __name__ == "__main__":
# 1. Cargar todos los recursos globales
setup_resources()
# 2. AUTO-ENTRENAMIENTO (Se ejecuta con valores por defecto)
print("\n=============================================")
print("🤖 INICIANDO AUTO-ENTRENAMIENTO (1 Época, 2 Batch Size)")
print("=============================================")
# Parámetros por defecto para el auto-entrenamiento
auto_train_result = train_lora(epochs=1, batch_size=2, learning_rate=5e-5)
print(f"\nFIN DEL AUTO-ENTRENAMIENTO: {auto_train_result}")
# 3. Lanzar la Interfaz Gradio
print("\n=============================================")
print("💻 LANZANDO INTERFAZ GRADIO")
print("=============================================")
demo.launch()
|