Chat / app.py
Andro0s's picture
Create app.py
dbf0a6c verified
import os
import gradio as gr
import threading
import time
from huggingface_hub import login
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, DataCollatorForLanguageModeling, pipeline, AutoModelForSeq2SeqLM
from peft import get_peft_model, LoraConfig, TaskType, PeftModel
import json
# --- CONFIGURACIÓN DEL MODELO Y ENTRENAMIENTO ---
BASE_MODEL = "bigcode/santacoder" # Modelo base de programación
LORA_PATH = "./lora_output" # Ruta donde se guarda el modelo adaptado
DATASET_FILE = "codesearchnet_lora_dataset.json"
MAX_TOKEN_LENGTH = 256
NUM_SAMPLES_TO_PROCESS = 1000
DEFAULT_EPOCHS = 10
# Configuración del ciclo AUTÓNOMO (Inicia reentrenamiento cada 5 interacciones)
GENERATION_LIMIT_TO_TRAIN = 5
AUTONOMOUS_EPOCHS = 3
# Modelo de chat pre-entrenado en español
CHAT_MODEL_NAME = "bigscience/bloom"
chat_tokenizer = None
chat_model = None
# --- ESTADO GLOBAL Y THREADING ---
tokenizer = None
lora_model = None
tokenized_dataset = None
lora_generator = None
# Variables de estado
version_number = 1.0
is_trained = os.path.exists(LORA_PATH)
generations_since_last_train = 0
training_status_message = "Esperando la inicialización V1.0..."
# Lock para proteger las variables compartidas entre hilos (CRÍTICO para estabilidad)
global_lock = threading.Lock()
# --- LÓGICA DE PREPARACIÓN Y SETUP ---
def prepare_codesearchnet():
"""Descarga y prepara el dataset inicial si no existe."""
if os.path.exists(DATASET_FILE):
return
try:
raw_csn = load_dataset('Nan-Do/code-search-net-python', split=f'train[:{NUM_SAMPLES_TO_PROCESS}]')
def format_for_lora(example):
# Formato que entrena a la IA a enlazar descripción (español) con código (inglés)
prompt_text = (
f"# Descripción: {example['docstring_summary']}\n"
f"# Completa la siguiente función:\n"
f"def {example['func_name']}("
)
completion_text = example['code']
return {"prompt": prompt_text, "completion": completion_text}
lora_dataset = raw_csn.map(format_for_lora, batched=False, remove_columns=raw_csn["train"].column_names)
lora_dataset.to_json(DATASET_FILE)
except Exception as e:
print(f"Error al cargar dataset. Usando datos mínimos. Error: {e}")
minimal_dataset = [{"prompt": "# Error de carga. Intenta de nuevo.", "completion": "pass\n"}] * 10
with open(DATASET_FILE, 'w') as f:
json.dump(minimal_dataset, f)
def setup_resources():
"""Configura el tokenizer, el modelo base y el adaptador LoRA."""
global tokenizer, lora_model, tokenized_dataset, chat_tokenizer, chat_model
prepare_codesearchnet()
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
login(token=hf_token)
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, device_map="auto")
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
peft_config = LoraConfig(
task_type=TaskType.CAUSAL_LM, r=8, lora_alpha=32, lora_dropout=0.1, target_modules=["c_proj", "c_attn"],
)
lora_model = get_peft_model(base_model, peft_config)
try:
raw_dataset = load_dataset("json", data_files=DATASET_FILE)
def tokenize_function(examples):
return tokenizer(
examples["prompt"] + examples["completion"],
truncation=True,
padding="max_length",
max_length=MAX_TOKEN_LENGTH
)
tokenized_dataset = raw_dataset.map(tokenize_function, batched=True, remove_columns=raw_dataset["train"].column_names if "train" in raw_dataset else [],)
except Exception:
tokenized_dataset = None
# Configuración del modelo de chat
chat_tokenizer = AutoTokenizer.from_pretrained(CHAT_MODEL_NAME)
chat_model = AutoModelForSeq2SeqLM.from_pretrained(CHAT_MODEL_NAME)
# --- FUNCIÓN DE ENTRENAMIENTO (EJECUTADA EN HILO SEPARADO) ---
def autonomous_train_lora(epochs, batch_size, learning_rate):
"""Ejecuta el entrenamiento en un hilo separado para la autonomía."""
global lora_model, tokenized_dataset, lora_generator, version_number, is_trained, training_status_message
try:
with global_lock:
if tokenized_dataset is None or "train" not in tokenized_dataset:
training_status_message = "ERROR: No se puede entrenar. Dataset no disponible."
return
# 1. ACTUALIZAR VERSIÓN (Pre-incremento)
if is_trained:
version_number += 0.1
else:
version_number = 1.0
# 2. CONFIGURACIÓN E INICIO DEL ENTRENAMIENTO
training_status_message = f"🧠 ENTRENANDO V{version_number:.1f} (Epochs: {epochs})...."
print(f"\n[AUTÓNOMO] {training_status_message}")
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
training_args = TrainingArguments(
output_dir=LORA_PATH,
per_device_train_batch_size=int(batch_size),
num_train_epochs=float(epochs),
learning_rate=float(learning_rate),
save_total_limit=1,
logging_steps=10,
push_to_hub=False,
disable_tqdm=True,
report_to="none"
)
trainer = Trainer(model=lora_model, args=training_args, train_dataset=tokenized_dataset["train"], data_collator=data_collator)
trainer.train()
lora_model.save_pretrained(LORA_PATH)
tokenizer.save_pretrained(LORA_PATH)
# 3. Marcar como entrenado
is_trained = True
training_status_message = f"✅ ENTRENAMIENTO V{version_number:.1f} COMPLETADO. Modelo listo para Hot Swap."
print(f"[AUTÓNOMO] {training_status_message}")
except Exception as e:
training_status_message = f"ERROR CRÍTICO durante el entrenamiento autónomo: {e}"
print(f"[AUTÓNOMO] {training_status_message}")
# --- FUNCIÓN DE GENERACIÓN (CORREGIDA PARA RETORNAR 2 VALORES) ---
def generate_text(prompt_text):
"""Genera código y dispara el ciclo de reentrenamiento autónomo si es necesario."""
global lora_generator, generations_since_last_train, is_trained, version_number, training_status_message
if not is_trained:
# Si el entrenamiento V1.0 no ha terminado, retorna el mensaje de error y el estado actual
return "ERROR: El modelo LoRA no ha sido entrenado. Por favor, espere mientras la IA se inicializa con el entrenamiento V1.0.", update_status()
# 1. HOT SWAP (Verifica si el modelo necesita recargarse con la nueva versión)
if lora_generator is None:
with global_lock:
try:
# Recarga el modelo solo si está vacío
base_model_gen = AutoModelForCausalLM.from_pretrained(BASE_MODEL, device_map="auto")
model_with_lora = PeftModel.from_pretrained(base_model_gen, LORA_PATH)
final_model = model_with_lora.merge_and_unload()
final_model.eval()
lora_generator = pipeline("text-generation", model=final_model, tokenizer=tokenizer)
print(f"[HOT SWAP] 🔄 Modelo de inferencia V{version_number:.1f} recargado y listo.")
except Exception as e:
# Si la recarga falla, retorna un error
return f"Error al cargar el modelo V{version_number:.1f} para inferencia: {e}", update_status()
# 2. Generación de texto (Lógica de inferencia)
try:
# Prepara el prompt para guiar la generación del código
prompt_with_indent = prompt_text.strip() + "\n    "
output = lora_generator(prompt_with_indent, max_new_tokens=150, temperature=0.7, top_p=0.9, clean_up_tokenization_spaces=True)
full_output = output[0]["generated_text"]
# Extrae solo la parte de la compleción (el código generado)
start_index = full_output.find(prompt_with_indent)
completion = full_output[start_index + len(prompt_with_indent):] if start_index != -1 else full_output
# 3. Aumentar contador de autonomía
with global_lock:
generations_since_last_train += 1
current_count = generations_since_last_train
current_version = version_number
# 4. Verificar si se requiere reentrenamiento (y dispararlo en un nuevo hilo)
notification = ""
if current_count >= GENERATION_LIMIT_TO_TRAIN:
# Verifica que no haya otro hilo de entrenamiento ya corriendo
if not any(isinstance(t, threading.Thread) and t.name == 'AutonomousTrainer' for t in threading.enumerate()):
print(f"[AUTONOMÍA] Generación #{current_count} alcanzada. Disparando reentrenamiento autónomo en segundo plano...")
# Reiniciar el contador de generaciones y forzar Hot Swap en la próxima interacción
with global_lock:
generations_since_last_train = 0
lora_generator = None
trainer_thread = threading.Thread(
target=autonomous_train_lora,
args=(AUTONOMOUS_EPOCHS, 2, 5e-5),
name='AutonomousTrainer'
)
trainer_thread.daemon = True
trainer_thread.start()
notification = f"\n\n--- [AUTONOMÍA] La IA ha iniciado el reentrenamiento V{current_version+0.1:.1f} para mejorar la traducción de tu diálogo. La próxima generación cargará la nueva versión. ---"
# CORRECCIÓN CLAVE: Retorna el código Y el estado actualizado
return completion + notification, update_status()
except Exception as e:
# Si falla la generación, retorna el mensaje de error y el estado actual
return f"Error generando texto: {e}", update_status()
# --- FUNCIÓN PARA INICIALIZACIÓN Y ENTRENAMIENTO V1.0 (Obligatorio) ---
def initialize_and_train_v1():
"""Ejecuta el entrenamiento inicial V1.0 de forma autónoma al iniciar."""
if not is_trained:
autonomous_train_lora(epochs=DEFAULT_EPOCHS, batch_size=2, learning_rate=5e-5)
else:
global training_status_message
training_status_message = f"✅ Modelo V{version_number:.1f} ya entrenado. Listo."
print(f"[INICIALIZACIÓN] {training_status_message}")
# --- FUNCIÓN PARA ACTUALIZAR EL ESTADO EN LA UI ---
def update_status():
"""Actualiza la versión y el estado del entrenamiento en la interfaz de Gradio."""
global training_status_message, version_number
# Retorna un texto en Markdown que se actualiza constantemente
return f"**Versión de Comprensión:** V{version_number:.1f} | **Estado del Entrenador:** {training_status_message}"
# --- FUNCIÓN DE CHAT ---
def chat_response(user_input):
"""Genera una respuesta de chat basado en el modelo pre-entrenado."""
inputs = chat_tokenizer(user_input, return_tensors="pt")
outputs = chat_model.generate(**inputs)
response = chat_tokenizer.decode(outputs[0], skip_special_tokens=True)
return response
# --- INTERFAZ GRADIO ---
with gr.Blocks(title="AmorCoderAI - Aprendizaje Continuo") as demo:
gr.Markdown("# 💙 AmorCoderAI - Asistente de Código con Aprendizaje Continuo")
# Muestra la versión y el estado.
version_and_status = gr.Markdown(
f"**Versión de Comprensión:** V{version_number:.1f} | **Estado del Entrenador:** {training_status_message}",
elem_id="status_display"
)
gr.Markdown(f"**Modo Autónomo:** La IA se reentrena automáticamente cada **{GENERATION_LIMIT_TO_TRAIN}** códigos generados. Esto mejora su capacidad para traducir tu español conversacional a código.")
with gr.Tab("✨ Generación de Código"):
gr.Markdown("## Escribe tu idea en palabras (¡Usa español fluido!)")
gr.Markdown("Recomendación inicial: Usa el siguiente formato para obtener el mejor código mientras la IA aprende tu idioma:")
prompt = gr.Textbox(
label="Instrucción de Programación:",
lines=4,
placeholder="# Descripción: Quiero que me hagas un código similar a Google Gemini.\n# Completa la siguiente función:\ndef generar_contenido(prompt, modelo):"
)
generate_button = gr.Button("💬 Generar código y disparar Aprendizaje")
output_box = gr.Textbox(label="Código generado", lines=10)
# Conexión del botón con la función principal
# IMPORTANTE: Ahora generate_text retorna DOS valores para coincidir con [output_box, version_and_status]
generate_button.click(
generate_text,
inputs=prompt,
outputs=[output_box, version_and_status],
)
with gr.Tab("🗣️ Chat en Español"):
gr.Markdown("## Habla con la IA en español")
user_input = gr.Textbox(label="Tu mensaje:", lines=2)
chat_button = gr.Button("Enviar")
chat_output = gr.Textbox(label="Respuesta de la IA", lines=5)
chat_button.click(
chat_response,
inputs=user_input,
outputs=chat_output
)
# El estado se actualiza solo al cargar la página.
demo.load(update_status, None, version_and_status)
# --- INICIO DE LA APLICACIÓN ---
if __name__ == "__main__":
setup_resources()
# Lanza el entrenamiento V1.0 inicial en un hilo para que no congele la UI
initialization_thread = threading.Thread(target=initialize_and_train_v1, name='InitializationTrainer')
initialization_thread.daemon = True
initialization_thread.start()
print(f"\n💻 LANZANDO INTERFAZ GRADIO (El entrenamiento V1.0 se ejecuta en segundo plano)")
demo.launch()