# Scraped from a Hugging Face file view (author: Andro0s).
# Commit ceb558e (verified): "Update app.py" — 8.39 kB.
import os
import gradio as gr
from huggingface_hub import login
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, DataCollatorForLanguageModeling, pipeline
from peft import get_peft_model, LoraConfig, TaskType, PeftModel
import json
# --- Configuration -----------------------------------------------------------
BASE_MODEL: str = "bigcode/santacoder"                  # Hub id of the causal LM to adapt
LORA_PATH: str = "./lora_output"                        # Trainer output dir and adapter save dir
DATASET_FILE: str = "codesearchnet_lora_dataset.json"   # cached prompt/completion records
MAX_TOKEN_LENGTH: int = 256                             # truncation/padding length for tokenization
NUM_SAMPLES_TO_PROCESS: int = 1_000                     # rows taken from the CodeSearchNet train split
DEFAULT_EPOCHS: int = 10                                # UI default and auto-training epoch count

# --- Runtime state (populated lazily) ----------------------------------------
# setup_resources() fills tokenizer / lora_model / tokenized_dataset;
# generate_text() builds lora_generator on first use and train_lora() resets it.
tokenizer = lora_model = tokenized_dataset = lora_generator = None
def prepare_codesearchnet():
    """Build the prompt/completion training file from CodeSearchNet (Python).

    Does nothing if ``DATASET_FILE`` already exists. Otherwise it:

    - downloads the first ``NUM_SAMPLES_TO_PROCESS`` rows of the ``train``
      split of ``Nan-Do/code-search-net-python``;
    - turns each row into a ``{"prompt", "completion"}`` record;
    - writes the result to ``DATASET_FILE`` with ``Dataset.to_json``.

    If anything fails, a tiny placeholder dataset (10 identical records) is
    written instead, so the rest of the app can still start.

    NOTE(review): the placeholder goes to ``DATASET_FILE`` itself, so the
    early-exit check keeps reusing it on later runs. Delete the file to
    retry the download.
    """
    if os.path.exists(DATASET_FILE):
        print(f"Dataset '{DATASET_FILE}' ya existe.")
        return
    print(f"Descargando y procesando CodeSearchNet ({NUM_SAMPLES_TO_PROCESS} muestras)...")
    try:
        raw_csn = load_dataset('Nan-Do/code-search-net-python', split=f'train[:{NUM_SAMPLES_TO_PROCESS}]')

        def format_for_lora(example):
            # Prompt = one-line description + the opening of the signature;
            # the completion is the function's source code.
            # NOTE(review): confirm the dataset exposes `docstring_summary`;
            # a KeyError here ends up in the fallback branch below.
            # NOTE(review): `code` presumably already starts with `def name(`,
            # so prompt + completion repeats the signature opening. Confirm.
            prompt_text = (
                f"# Descripción: {example['docstring_summary']}\n"
                f"# Completa la siguiente función:\n"
                f"def {example['func_name']}("
            )
            completion_text = example['code']
            return {"prompt": prompt_text, "completion": completion_text}

        lora_dataset = raw_csn.map(
            format_for_lora,
            batched=False,
            # Passing `split=` makes load_dataset return a single Dataset,
            # not a DatasetDict, so read its columns directly.
            # raw_csn["train"] would look up a *column* named "train" and fail.
            remove_columns=raw_csn.column_names,
        )
        lora_dataset.to_json(DATASET_FILE)
        print(f"Pre-procesamiento completado. {len(lora_dataset)} ejemplos guardados en '{DATASET_FILE}'.")
    except Exception as e:
        print(f"Error CRÍTICO al descargar/procesar CodeSearchNet. Error: {e}")
        minimal_dataset = [{"prompt": "# Error de carga. Intenta de nuevo.", "completion": "pass\n"}] * 10
        with open(DATASET_FILE, 'w', encoding='utf-8') as f:
            json.dump(minimal_dataset, f)
def setup_resources():
global tokenizer, lora_model, tokenized_dataset
prepare_codesearchnet()
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
login(token=hf_token)
print("\nCargando modelo base y tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, device_map="auto")
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
peft_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
r=8,
lora_alpha=32,
lora_dropout=0.1,
target_modules=["c_proj", "c_attn"],
)
lora_model = get_peft_model(base_model, peft_config)
print(f"Modelo LoRA preparado. Parámetros entrenables listos.")
print(f"Cargando y tokenizando dataset: {DATASET_FILE}...")
try:
raw_dataset = load_dataset("json", data_files=DATASET_FILE)
def tokenize_function(examples):
return tokenizer(
examples["prompt"] + examples["completion"],
truncation=True,
padding="max_length",
max_length=MAX_TOKEN_LENGTH
)
tokenized_dataset = raw_dataset.map(
tokenize_function,
batched=True,
remove_columns=raw_dataset["train"].column_names if "train" in raw_dataset else [],
)
print("Dataset tokenizado correctamente.")
except Exception as e:
tokenized_dataset = None
print(f"Error al cargar o tokenizar el dataset. {e}")
def train_lora(epochs, batch_size, learning_rate):
    """Fine-tune the LoRA adapters on the tokenized dataset.

    Args:
        epochs: number of training epochs (converted with ``float``).
        batch_size: per-device batch size (converted with ``int``).
        learning_rate: optimizer learning rate (converted with ``float``).

    Returns:
        A status message for the UI. On success it names the save path.
        On failure it describes the error: exceptions raised while training
        are caught and reported, not re-raised.

    Side effects:
        Clears the cached ``lora_generator`` so the next generation reloads
        the model, and saves the adapters and tokenizer to ``LORA_PATH``.
    """
    global lora_model, tokenized_dataset, lora_generator
    dataset_ready = tokenized_dataset is not None and "train" in tokenized_dataset
    if not dataset_ready:
        return "Error: El dataset no pudo cargarse o está vacío. No se puede entrenar."
    try:
        # The inference pipeline must be rebuilt from the freshly trained adapters.
        lora_generator = None
        collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
        args = TrainingArguments(
            output_dir=LORA_PATH,
            per_device_train_batch_size=int(batch_size),
            num_train_epochs=float(epochs),
            learning_rate=float(learning_rate),
            save_total_limit=1,
            logging_steps=10,
            push_to_hub=False,
        )
        Trainer(
            model=lora_model,
            args=args,
            train_dataset=tokenized_dataset["train"],
            data_collator=collator,
        ).train()
        for artifact in (lora_model, tokenizer):
            artifact.save_pretrained(LORA_PATH)
    except Exception as e:
        return f"Error durante el entrenamiento: {e}"
    return f"Entrenamiento completado. Adaptadores LoRA guardados en **{LORA_PATH}**"
def generate_text(prompt_text):
global lora_generator
try:
if lora_generator is None:
base_model_gen = AutoModelForCausalLM.from_pretrained(BASE_MODEL, device_map="auto")
if os.path.exists(LORA_PATH):
print("Cargando adaptadores LoRA...")
model_with_lora = PeftModel.from_pretrained(base_model_gen, LORA_PATH)
final_model = model_with_lora.merge_and_unload()
else:
print("No se encontraron adaptadores LoRA. Usando modelo base.")
final_model = base_model_gen
final_model.eval()
lora_generator = pipeline("text-generation", model=final_model, tokenizer=tokenizer)
print("Modelo de inferencia listo.")
prompt_with_indent = prompt_text.strip() + "\n "
output = lora_generator(
prompt_with_indent,
max_new_tokens=150,
temperature=0.7,
top_p=0.9,
clean_up_tokenization_spaces=True
)
full_output = output[0]["generated_text"]
start_index = full_output.find(prompt_with_indent)
if start_index != -1:
completion = full_output[start_index + len(prompt_with_indent):]
else:
completion = full_output
return completion
except Exception as e:
return f"Error generando texto (Asegúrate de que el modelo base/LoRA esté cargado): {e}"
# Gradio UI: one tab to launch training manually, one to try the model.
with gr.Blocks(title="AmorCoderAI - LoRA") as demo:
    gr.Markdown("# 💙 AmorCoderAI - Entrenamiento y Pruebas LoRA")
    # Epoch count comes from DEFAULT_EPOCHS (was hard-coded "10"), so the
    # banner stays correct if the constant changes.
    gr.Markdown(f"Modelo base: `{BASE_MODEL}`. Usando **{NUM_SAMPLES_TO_PROCESS}** ejemplos de CodeSearchNet ({DEFAULT_EPOCHS} Épocas).")
    with gr.Tab("🧠 Entrenar (Manual)"):
        gr.Markdown(f"--- ¡CUIDADO! El auto-entrenamiento usará {DEFAULT_EPOCHS} épocas para aprender la sintaxis. ---")
        epochs = gr.Number(value=DEFAULT_EPOCHS, label="Épocas", precision=0)
        batch_size = gr.Number(value=2, label="Tamaño de lote (ajusta según tu VRAM)", precision=0)
        learning_rate = gr.Number(value=5e-5, label="Tasa de aprendizaje")
        train_button = gr.Button("🚀 Iniciar Entrenamiento Manual")
        train_output = gr.Textbox(label="Resultado del Entrenamiento Manual")
        # The three Number fields map positionally to train_lora(epochs, batch_size, learning_rate).
        train_button.click(
            train_lora,
            inputs=[epochs, batch_size, learning_rate],
            outputs=train_output,
        )
    with gr.Tab("✨ Probar modelo"):
        prompt = gr.Textbox(
            label="Escribe código (ej: # Descripción: Calcula el factorial de N. \n# Completa la siguiente función:\ndef factorial(n):)",
            lines=4,
        )
        generate_button = gr.Button("💬 Generar código")
        output_box = gr.Textbox(label="Salida generada (SOLO CÓDIGO)", lines=10)
        generate_button.click(generate_text, inputs=prompt, outputs=output_box)
if __name__ == "__main__":
    # Startup sequence: load everything, train once with the default
    # hyper-parameters, then serve the UI.
    auto_batch_size = 2
    auto_learning_rate = 5e-5
    banner = "============================================="

    setup_resources()

    print("\n" + banner)
    print(f"🤖 INICIANDO AUTO-ENTRENAMIENTO ({DEFAULT_EPOCHS} Épocas, {auto_batch_size} Batch Size) usando {NUM_SAMPLES_TO_PROCESS} ejemplos")
    print(banner)
    result = train_lora(epochs=DEFAULT_EPOCHS, batch_size=auto_batch_size, learning_rate=auto_learning_rate)
    print(f"\nFIN DEL AUTO-ENTRENAMIENTO: {result}")

    print("\n" + banner)
    print("💻 LANZANDO INTERFAZ GRADIO")
    print(banner)
    demo.launch()