# Source: RobotCleanPupusas / backend / train_controller.py
# Commit 5ed0597 ("backend done") by Trujasx
# backend/train_robot.py
import json
import os
import shlex
import subprocess

import gradio as gr
import torch
# --- Helper Functions (reused from record_controller for consistency) ---
def run_command(command: str, description: str):
    """Run a shell command, streaming its combined output to the console.

    Args:
        command: Shell command line to execute (runs with ``shell=True``).
        description: Human-readable label used in the log messages.

    Returns:
        Tuple ``(success, output)`` — ``success`` is True when the command
        exited with code 0; ``output`` is the captured stdout+stderr text,
        or an error message if an unexpected exception occurred.
    """
    print(f"\n--- {description} ---")
    process_output = []
    try:
        # Popen streams output line by line so long-running commands
        # (e.g. training) show progress in real time.
        process = subprocess.Popen(
            command,
            shell=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,  # merge stderr into stdout
            text=True,
        )
        for line in iter(process.stdout.readline, ''):
            print(line, end='')  # echo to console
            process_output.append(line)
        process.wait()  # wait for the process to complete
        if process.returncode == 0:
            print(f"Éxito: {description}")
            return True, "".join(process_output)
        error_message = (
            f"Error durante '{description}': El comando devolvió el "
            f"código de salida {process.returncode}"
        )
        print(error_message)
        # Return everything collected so far, including error output.
        return False, "".join(process_output)
    except Exception as e:
        error_message = f"Ocurrió un error inesperado durante '{description}': {e}"
        print(error_message)
        return False, error_message
def login_to_huggingface(token: str):
    """Log in to the Hugging Face CLI with the provided token.

    Args:
        token: A Hugging Face access token with write permissions.

    Returns:
        Tuple ``(success, message)``.
    """
    if not token or token == "hf_YOUR_ACTUAL_WRITE_TOKEN_HERE":
        return False, "Error: Por favor, proporciona un token de Hugging Face válido."
    # shlex.quote keeps the token intact (and the shell safe) even if it
    # ever contains shell metacharacters.
    success, output = run_command(
        f"huggingface-cli login --token {shlex.quote(token)} --add-to-git-credential",
        "Iniciando sesión en Hugging Face CLI"
    )
    if success:
        return True, "¡Inicio de sesión en Hugging Face exitoso!"
    return False, output
def get_huggingface_user():
    """Return the username of the currently authenticated Hugging Face user.

    Returns:
        Tuple ``(success, username_or_error_message)``.
    """
    success, output = run_command(
        "huggingface-cli whoami | head -n 1",
        "Obteniendo nombre de usuario de Hugging Face"
    )
    if not success:
        return False, output
    # The CLI may emit warnings before the actual username; take the first
    # non-empty line that is not a warning/deprecation notice.
    for raw in output.splitlines():
        candidate = raw.strip()
        if candidate and not candidate.startswith("warnings.") and "deprecated" not in candidate.lower():
            return True, candidate
    return False, "No se pudo extraer el nombre de usuario de la salida de 'whoami'."
# --- Core Training and Upload Logic ---
def train_policy_core(hf_user: str,
                      dataset_repo_id: str,
                      policy_type: str,
                      output_dir: str,
                      job_name: str,
                      policy_device: str,
                      wandb_enable: bool,
                      resume: bool,
                      resume_config_path: str):
    """Train a robot policy using the ``lerobot.scripts.train`` entry point.

    Args:
        hf_user: Hugging Face username; prefixed onto the dataset repo id.
        dataset_repo_id: Dataset repo id; normalized to ``<hf_user>/<name>``.
        policy_type: Policy architecture to train (e.g. ``act``).
        output_dir: Directory where training checkpoints are written.
        job_name: Label identifying this training run.
        policy_device: Device for training (``cuda``/``mps``/``cpu``).
        wandb_enable: Enable Weights & Biases metric tracking.
        resume: When True (and a config path is given), resume a prior run.
        resume_config_path: Path to the checkpoint's ``train_config.json``.

    Returns:
        Tuple ``(success, message)``.
    """
    if not hf_user:
        return False, "Error: Nombre de usuario de Hugging Face no disponible. Por favor, inicia sesión primero."
    if not dataset_repo_id.startswith(f"{hf_user}/"):
        # Normalize to the canonical "<user>/<repo>" form.
        dataset_repo_id = f"{hf_user}/{dataset_repo_id.split('/')[-1]}"
    print(f"\nPreparando para entrenar la política '{policy_type}' con el dataset '{dataset_repo_id}'...")
    command = ["python", "-m", "lerobot.scripts.train"]
    if resume and resume_config_path:
        # When resuming, the saved config already carries dataset/policy settings.
        command.extend([
            f"--config_path={resume_config_path}",
            "--resume=true",
        ])
    else:
        command.extend([
            f"--dataset.repo_id={dataset_repo_id}",
            f"--policy.type={policy_type}",
            f"--output_dir={output_dir}",
            f"--job_name={job_name}",
            f"--policy.device={policy_device}",
        ])
    if wandb_enable:
        command.append("--wandb.enable=true")
    success, output = run_command(" ".join(command), "Entrenamiento de la Política")
    if success:
        final_message = "¡Entrenamiento de la política completado exitosamente!\n"
        final_message += f"Los checkpoints se guardaron en: {output_dir}/checkpoints\n"
        if wandb_enable:
            final_message += "Revisa Weights & Biases para los gráficos de entrenamiento.\n"
        # Append the full command output for visibility.
        return True, final_message + "\n" + output
    return False, f"Error durante el entrenamiento: {output}"
def upload_policy_core(hf_user: str, policy_repo_name: str, checkpoint_path: str, is_intermediate: bool = False):
    """Upload a policy checkpoint directory to the Hugging Face Hub.

    Args:
        hf_user: Hugging Face username (owner of the target repository).
        policy_repo_name: Repository name, without the user prefix.
        checkpoint_path: Local path to the checkpoint (``pretrained_model``) dir.
        is_intermediate: Reserved flag for intermediate checkpoints; currently
            a no-op, kept for interface compatibility. A suffixed repo
            (e.g. ``<name>CKPT``) could be used for these in the future.

    Returns:
        Tuple ``(success, message)``.
    """
    if not hf_user:
        return False, "Error: Nombre de usuario de Hugging Face no disponible. Por favor, inicia sesión primero."
    if not policy_repo_name:
        return False, "Error: El nombre del repositorio de la política no puede estar vacío."
    if not checkpoint_path:
        return False, "Error: La ruta al checkpoint no puede estar vacía."
    if not os.path.exists(checkpoint_path):
        return False, f"Error: La ruta del checkpoint '{checkpoint_path}' no existe."
    full_repo_id = f"{hf_user}/{policy_repo_name}"
    # huggingface-cli upload <repo_id> <local_path>; policies live in model repos.
    command = [
        "huggingface-cli", "upload",
        full_repo_id,
        checkpoint_path,
        "--repo-type=model",
    ]
    success, output = run_command(" ".join(command), f"Subiendo política a {full_repo_id}")
    if success:
        return True, f"¡Política subida exitosamente a https://huggingface.co/{full_repo_id}!"
    return False, f"Error al subir política: {output}"
def evaluate_policy_core(hf_user: str,
                         robot_type: str,
                         robot_port: str,
                         robot_cameras: str,  # raw camera-config string passed straight to the CLI
                         robot_id: str,
                         display_data: bool,
                         dataset_repo_id_eval: str,
                         single_task: str,
                         policy_path: str,
                         teleop_enable: bool = False,  # optional teleoperation during evaluation
                         teleop_type: str = "",
                         teleop_port: str = "",
                         teleop_id: str = ""):
    """Evaluate a trained policy on the real robot via ``lerobot.record``.

    Args:
        hf_user: Hugging Face username; prefixed onto the eval dataset repo id.
        robot_type: Robot type identifier (e.g. ``so100_follower``).
        robot_port: Serial port of the robot (e.g. ``/dev/ttyACM1``).
        robot_cameras: Camera configuration string, forwarded verbatim.
        robot_id: Identifier of the robot instance.
        display_data: Whether to display data during recording.
        dataset_repo_id_eval: Repo id where evaluation episodes are stored.
        single_task: Natural-language task description.
        policy_path: Local checkpoint path or Hub model id of the policy.
        teleop_enable: When True, add the teleoperation arguments below.
        teleop_type: Teleop device type (e.g. ``so100_leader``).
        teleop_port: Teleop device serial port.
        teleop_id: Teleop device identifier.

    Returns:
        Tuple ``(success, message)``.
    """
    if not hf_user:
        return False, "Error: Nombre de usuario de Hugging Face no disponible. Por favor, inicia sesión primero."
    if not policy_path:
        return False, "Error: La ruta a la política para evaluar no puede estar vacía."
    # Normalize the eval dataset repo id to "<user>/<repo>".
    if not dataset_repo_id_eval.startswith(f"{hf_user}/"):
        dataset_repo_id_eval = f"{hf_user}/{dataset_repo_id_eval.split('/')[-1]}"
    print(f"\nPreparando para evaluar la política '{policy_path}'...")
    command = [
        "python", "-m", "lerobot.record",
        f"--robot.type={robot_type}",
        f"--robot.port={robot_port}",
        f"--robot.cameras=\"{robot_cameras}\"",  # user-provided raw string
        f"--robot.id={robot_id}",
        f"--display_data={str(display_data).lower()}",
        f"--dataset.repo_id={dataset_repo_id_eval}",
        f"--dataset.single_task=\"{single_task}\"",
        f"--policy.path={policy_path}",
    ]
    if teleop_enable:
        command.extend([
            f"--teleop.type={teleop_type}",
            f"--teleop.port={teleop_port}",
            f"--teleop.id={teleop_id}",
        ])
    success, output = run_command(" ".join(command), "Evaluación de la Política")
    if success:
        final_message = "¡Evaluación de la política completada exitosamente!\n"
        final_message += f"Los datos de evaluación se guardaron en: ~/.cache/huggingface/lerobot/{dataset_repo_id_eval}\n"
        return True, final_message + "\n" + output
    return False, f"Error durante la evaluación: {output}"
# --- Gradio Interface Logic ---
# Global holding the authenticated Hugging Face username.
# Set by gradio_login; None while no user is logged in.
current_hf_user = None
def gradio_login(hf_token_input: str):
    """Gradio handler: authenticate with Hugging Face and resolve the username.

    Returns updates for the username textbox and the status textbox.
    """
    global current_hf_user
    logged_in, message = login_to_huggingface(hf_token_input)
    if not logged_in:
        current_hf_user = None
        return gr.update(value="", interactive=True), gr.update(visible=True, value=message)
    gr.Info(message)
    found_user, user_name = get_huggingface_user()
    if not found_user:
        # Login succeeded but 'whoami' failed; surface its error text.
        return gr.update(value="", interactive=True), gr.update(visible=True, value=user_name)
    current_hf_user = user_name
    return gr.update(value=user_name, interactive=False), gr.update(visible=True, value=message)
def gradio_train(dataset_repo_id_input: str,
                 policy_type_input: str,
                 output_dir_input: str,
                 job_name_input: str,
                 policy_device_input: str,
                 wandb_enable_input: bool,
                 resume_input: bool,
                 resume_config_path_input: str):
    """Gradio handler: run policy training and surface the log in the UI.

    Returns an update for the training-log textbox.
    """
    global current_hf_user
    if not current_hf_user:
        return gr.update(visible=True, value="Error: No se ha iniciado sesión en Hugging Face o no se pudo obtener el usuario. Por favor, inicia sesión primero.")
    gr.Info("Iniciando entrenamiento del modelo. Esto puede tardar mucho tiempo...")
    success, message = train_policy_core(
        hf_user=current_hf_user,
        dataset_repo_id=dataset_repo_id_input,
        policy_type=policy_type_input,
        output_dir=output_dir_input,
        job_name=job_name_input,
        policy_device=policy_device_input,
        wandb_enable=wandb_enable_input,
        resume=resume_input,
        resume_config_path=resume_config_path_input,
    )
    if success:
        gr.Info("Entrenamiento completado. Revisa la salida para los detalles.")
    else:
        gr.Info("Entrenamiento fallido. Revisa la salida para los errores.")
    return gr.update(visible=True, value=message)
def gradio_upload(policy_repo_name_input: str, checkpoint_path_input: str):
    """Gradio handler: upload a policy checkpoint to the Hub.

    Returns an update for the upload-log textbox.
    """
    global current_hf_user
    if not current_hf_user:
        return gr.update(visible=True, value="Error: No se ha iniciado sesión en Hugging Face. Por favor, inicia sesión primero.")
    gr.Info(f"Subiendo checkpoint '{checkpoint_path_input}' a '{policy_repo_name_input}'...")
    success, message = upload_policy_core(
        hf_user=current_hf_user,
        policy_repo_name=policy_repo_name_input,
        checkpoint_path=checkpoint_path_input,
    )
    if success:
        gr.Info("Subida completada.")
    else:
        gr.Info("Subida fallida. Revisa la salida.")
    return gr.update(visible=True, value=message)
def gradio_evaluate(robot_type_input: str,
                    robot_port_input: str,
                    robot_cameras_input: str,
                    robot_id_input: str,
                    display_data_input: bool,
                    dataset_repo_id_eval_input: str,
                    single_task_eval_input: str,
                    policy_path_input: str,
                    teleop_enable_input: bool,
                    teleop_type_input: str,
                    teleop_port_input: str,
                    teleop_id_input: str):
    """Gradio handler: evaluate a trained policy on the robot.

    Returns an update for the evaluation-log textbox.
    """
    global current_hf_user
    if not current_hf_user:
        return gr.update(visible=True, value="Error: No se ha iniciado sesión en Hugging Face. Por favor, inicia sesión primero.")
    gr.Info("Iniciando evaluación de la política...")
    success, message = evaluate_policy_core(
        hf_user=current_hf_user,
        robot_type=robot_type_input,
        robot_port=robot_port_input,
        robot_cameras=robot_cameras_input,
        robot_id=robot_id_input,
        display_data=display_data_input,
        dataset_repo_id_eval=dataset_repo_id_eval_input,
        single_task=single_task_eval_input,
        policy_path=policy_path_input,
        teleop_enable=teleop_enable_input,
        teleop_type=teleop_type_input,
        teleop_port=teleop_port_input,
        teleop_id=teleop_id_input,
    )
    if success:
        gr.Info("Evaluación completada.")
    else:
        gr.Info("Evaluación fallida. Revisa la salida.")
    return gr.update(visible=True, value=message)
# --- Gradio Interface Definition ---
# Four tabs: HF login, training, checkpoint upload, and evaluation.
with gr.Blocks(title="Controlador de Entrenamiento y Evaluación LeRobot") as demo:
    gr.Markdown("# <center>Controlador de Entrenamiento y Evaluación de Políticas LeRobot</center>")
    gr.Markdown("Esta interfaz te permite entrenar, subir y evaluar políticas de robot con LeRobot.")

    with gr.Tab("1. Configuración de Hugging Face"):
        gr.Markdown("## Configuración de Hugging Face")
        gr.Markdown(
            "Introduce tu **token de Hugging Face con permisos de escritura**. "
            "Puedes generarlo en [huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)."
        )
        hf_token_input = gr.Textbox(
            label="Token de Hugging Face",
            type="password",
            placeholder="hf_YOUR_ACTUAL_WRITE_TOKEN_HERE",
            info="El token debe tener permisos de escritura (write)."
        )
        login_btn = gr.Button("Iniciar Sesión / Verificar Token")
        hf_user_output = gr.Textbox(label="Usuario de Hugging Face Actual", interactive=False, placeholder="No autenticado", show_copy_button=True)
        login_status_output = gr.Textbox(label="Estado de Autenticación", interactive=False, visible=False, lines=3)
        login_btn.click(
            fn=gradio_login,
            inputs=hf_token_input,
            outputs=[hf_user_output, login_status_output]
        )

    with gr.Tab("2. Entrenamiento de la Política"):
        gr.Markdown("## Entrenar una Política")
        gr.Markdown(
            "Configura los parámetros para entrenar tu política. Asegúrate de tener el dataset listo en Hugging Face Hub."
        )
        with gr.Row():
            dataset_repo_id_input = gr.Textbox(
                label="ID del Repositorio del Dataset (ej. YOUR_USER/so101_test)",
                value="YOUR_USER/so101_test",  # placeholder; normalized with the HF user at train time
                placeholder="Dataset para entrenar",
                info="Asegúrate de que este dataset ya haya sido subido con el script de grabación."
            )
            policy_type_input = gr.Dropdown(
                label="Tipo de Política",
                choices=["act", "diffusion", "rlds"],  # extend as LeRobot adds policy types
                value="act",
                info="Tipo de arquitectura de política a entrenar (e.g., ACT)."
            )
        with gr.Row():
            output_dir_input = gr.Textbox(
                label="Directorio de Salida para Checkpoints",
                value="outputs/train/act_so101_test",
                placeholder="Directorio donde se guardarán los resultados del entrenamiento."
            )
            job_name_input = gr.Textbox(
                label="Nombre del Trabajo (Job Name)",
                value="act_so101_test",
                placeholder="Nombre para identificar tu sesión de entrenamiento."
            )
        with gr.Row():
            policy_device_input = gr.Dropdown(
                label="Dispositivo de Entrenamiento",
                choices=["cuda", "mps", "cpu"],
                # Simple auto-detection: prefer CUDA, then Apple MPS, else CPU.
                value="cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu"),
                info="Dispositivo a usar para el entrenamiento (GPU Nvidia: cuda, Apple Silicon: mps, CPU: cpu)."
            )
            wandb_enable_input = gr.Checkbox(
                label="Habilitar Weights & Biases",
                value=True,
                info="Habilita el seguimiento de métricas con Weights & Biases (asegúrate de haber ejecutado 'wandb login')."
            )
        gr.Markdown("### Opciones de Reanudación")
        resume_input = gr.Checkbox(
            label="Reanudar Entrenamiento Existente",
            value=False,
            info="Marca esta casilla para continuar un entrenamiento desde un checkpoint."
        )
        resume_config_path_input = gr.Textbox(
            label="Ruta al train_config.json para Reanudar (ej. outputs/train/act_so101_test/checkpoints/last/pretrained_model/train_config.json)",
            placeholder="Ruta al archivo train_config.json del checkpoint a reanudar.",
            visible=False  # hidden until the resume checkbox is ticked
        )
        resume_input.change(
            lambda x: gr.update(visible=x),
            inputs=resume_input,
            outputs=resume_config_path_input
        )
        train_btn = gr.Button("🚀 Iniciar Entrenamiento 🚀", variant="primary")
        train_status_output = gr.Textbox(
            label="Log de Entrenamiento",
            interactive=False,
            visible=False,
            lines=20,
            autoscroll=True
        )
        train_btn.click(
            fn=gradio_train,
            inputs=[
                dataset_repo_id_input,
                policy_type_input,
                output_dir_input,
                job_name_input,
                policy_device_input,
                wandb_enable_input,
                resume_input,
                resume_config_path_input
            ],
            outputs=train_status_output
        )

    with gr.Tab("3. Subir Checkpoint de Política"):
        gr.Markdown("## Subir Checkpoint de Política al Hub")
        gr.Markdown(
            "Sube tus modelos entrenados a Hugging Face Hub para compartirlos o usarlos en evaluación."
        )
        policy_repo_name_input = gr.Textbox(
            label="Nombre del Repositorio de la Política (ej. act_so101_test)",
            value="act_so101_test",
            placeholder="Nombre del repositorio en Hugging Face Hub para tu política."
        )
        checkpoint_path_input = gr.Textbox(
            label="Ruta Local al Directorio del Checkpoint (ej. outputs/train/act_so101_test/checkpoints/last/pretrained_model)",
            placeholder="Ruta completa al directorio 'pretrained_model' del checkpoint."
        )
        upload_btn = gr.Button("⬆️ Subir Política ⬆️", variant="secondary")
        upload_status_output = gr.Textbox(
            label="Log de Subida",
            interactive=False,
            visible=False,
            lines=5
        )
        upload_btn.click(
            fn=gradio_upload,
            inputs=[policy_repo_name_input, checkpoint_path_input],
            outputs=upload_status_output
        )

    with gr.Tab("4. Evaluar Política"):
        gr.Markdown("## Evaluar una Política Entrenada")
        gr.Markdown(
            "Usa esta sección para probar tu política entrenada con el robot real. "
            "La teleoperación es opcional durante la evaluación."
        )
        with gr.Row():
            robot_type_eval_input = gr.Textbox(label="Tipo de Robot (e.g., so100_follower)", value="so100_follower")
            robot_port_eval_input = gr.Textbox(label="Puerto del Robot (e.g., /dev/ttyACM1)", value="/dev/ttyACM1")
        robot_cameras_eval_input = gr.Textbox(
            label="Configuración de Cámaras (JSON string)",
            # NOTE(review): this default has unquoted keys (lerobot's CLI-style
            # dict syntax, not strict JSON) — confirm the expected format.
            value='{ up: {type: opencv, index_or_path: /dev/video10, width: 640, height: 480, fps: 30}}',
            info="Define tus cámaras como un string JSON. Asegúrate de escapar las comillas internas si es necesario."
        )
        robot_id_eval_input = gr.Textbox(label="ID del Robot", value="my_awesome_follower_arm")
        display_data_eval_input = gr.Checkbox(label="Mostrar Datos (Display Data)", value=False)
        dataset_repo_id_eval_input = gr.Textbox(
            label="ID del Repositorio del Dataset de Evaluación (ej. YOUR_USER/eval_so100)",
            value="YOUR_USER/eval_so100",
            info="El nombre del dataset para guardar los resultados de la evaluación (suele empezar con 'eval_')."
        )
        single_task_eval_input = gr.Textbox(
            label="Descripción de la Tarea (Single Task)",
            value="Put lego brick into the transparent box"
        )
        policy_path_input = gr.Textbox(
            label="Ruta de la Política a Evaluar (local o Hugging Face Hub ID)",
            placeholder="ej. outputs/train/eval_act_so101_test/checkpoints/last/pretrained_model O YOUR_USER/my_policy",
            info="Puede ser una ruta local al checkpoint o el ID de un repositorio de modelo en Hugging Face Hub."
        )
        gr.Markdown("### Teleoperación (Opcional durante la Evaluación)")
        teleop_enable_eval_input = gr.Checkbox(label="Habilitar Teleoperación Durante Evaluación", value=False)
        with gr.Row(visible=False) as teleop_options_row:  # hidden until teleop is enabled
            teleop_type_eval_input = gr.Textbox(label="Tipo de Teleop (e.g., so100_leader)", value="so100_leader")
            teleop_port_eval_input = gr.Textbox(label="Puerto de Teleop (e.g., /dev/ttyACM0)", value="/dev/ttyACM0")
            teleop_id_eval_input = gr.Textbox(label="ID de Teleop", value="my_awesome_leader_arm")
        teleop_enable_eval_input.change(
            lambda x: gr.update(visible=x),
            inputs=teleop_enable_eval_input,
            outputs=teleop_options_row
        )
        evaluate_btn = gr.Button("📊 Iniciar Evaluación 📊", variant="primary")
        evaluate_status_output = gr.Textbox(
            label="Log de Evaluación",
            interactive=False,
            visible=False,
            lines=15,
            autoscroll=True
        )
        evaluate_btn.click(
            fn=gradio_evaluate,
            inputs=[
                robot_type_eval_input,
                robot_port_eval_input,
                robot_cameras_eval_input,
                robot_id_eval_input,
                display_data_eval_input,
                dataset_repo_id_eval_input,
                single_task_eval_input,
                policy_path_input,
                teleop_enable_eval_input,
                teleop_type_eval_input,
                teleop_port_eval_input,
                teleop_id_eval_input
            ],
            outputs=evaluate_status_output
        )

    gr.Markdown("---")
    gr.Markdown("Hecho con ❤️ para RobotCleanPupusas503")
# Auto-detect CUDA/MPS availability for default device selection (requires torch).
# NOTE(review): this fallback is ineffective where it stands — the top-of-file
# `import torch` (and the torch.cuda call inside the UI definition above) would
# already have raised ImportError long before this point. Move this guard above
# the Gradio UI definition (replacing the unconditional top-level import) for
# the fallback to actually take effect.
try:
    import torch
except ImportError:
    print("Advertencia: PyTorch no está instalado. No se podrá auto-detectar 'cuda' o 'mps'.")
    torch = None
# Launch the Gradio app when this file is executed directly
# (share=False keeps the server local, without a public Gradio link).
if __name__ == "__main__":
    demo.launch(share=False)