Spaces:
No application file
No application file
import json

import gradio as gr
import rclpy
from loguru import logger

from config import MODE_CONFIG, MODEL_CONFIG, ROBOTS_CONFIG
from dag_visualizer import DAGVisualizer
from json_processor import JsonProcessor
from llm_request_handler import LLMRequestHandler
from ros_node_publisher import RosNodePublisher
class GradioLlmInterface:
    """Gradio front end that turns chat queries into robot task plans.

    A query is sent to an LLM through ``LLMRequestHandler``. The reply is turned
    into a task plan by ``JsonProcessor.process_response`` and drawn as a
    dependency graph by ``DAGVisualizer``. The plan is published to ROS through
    ``RosNodePublisher`` only when the operator approves it via
    ``validate_and_deploy_task_plan``.

    Per-session data lives in the Gradio ``state`` dict. Keys read or written
    here: ``file_path``, ``initial_messages``, ``mode``, ``llm_config``,
    ``history``, ``pending_task_plan`` and ``deployed_task_plan``.
    """

    # Skeleton shown in the plan editor when there is no plan to edit.
    _TEMPLATE_PLAN_JSON = """{
  "tasks": [
    {
      "task": "example_task_1",
      "instruction_function": {
        "name": "example_function_name",
        "robot_ids": ["robot_dump_truck_01"],
        "dependencies": [],
        "object_keywords": ["object1", "object2"]
      }
    }
  ]
}"""

    def __init__(self):
        # Created by initialize_interface()/update_chatbot() (or lazily in predict()).
        self.node_publisher = None
        # Tasks injected into the next "complex"-mode query. Nothing in this
        # class appends to it; presumably filled by an external callback (verify).
        self.received_tasks = []
        self.json_processor = JsonProcessor()
        self.dag_visualizer = DAGVisualizer()

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _has_tasks(plan):
        """Return True if *plan* is a dict whose ``"tasks"`` entry is a list."""
        return isinstance(plan, dict) and isinstance(plan.get("tasks"), list)

    @staticmethod
    def _llm_settings(mode_config):
        """Return ``LLMRequestHandler`` keyword arguments for one MODE_CONFIG entry.

        Values missing from *mode_config* fall back to MODEL_CONFIG. The provider
        falls back to MODEL_CONFIG's ``provider`` if present, else ``"openai"``.
        """
        return {
            "model_name": mode_config.get("model_version", MODEL_CONFIG["default_model"]),
            "provider": mode_config.get("provider", MODEL_CONFIG.get("provider", "openai")),
            "max_tokens": mode_config.get("max_tokens", MODEL_CONFIG["max_tokens"]),
            "temperature": mode_config.get("temperature", MODEL_CONFIG["temperature"]),
            "frequency_penalty": mode_config.get("frequency_penalty", MODEL_CONFIG["frequency_penalty"]),
            "list_navigation_once": True,
        }

    def _build_session_state(self, mode, settings=None):
        """Build the LLM-related session entries for *mode*.

        Args:
            mode: Key into MODE_CONFIG.
            settings: Pre-computed ``_llm_settings`` result. Computed here if None.

        Returns:
            Dict with ``file_path``, ``initial_messages``, ``mode`` and
            ``llm_config``. The handler itself is not stored: predict() rebuilds
            one from ``llm_config`` on every request.
        """
        mode_config = MODE_CONFIG[mode]
        if settings is None:
            settings = self._llm_settings(mode_config)
        llm_handler = LLMRequestHandler(**settings)
        file_path = mode_config["prompt_file"]
        return {
            "file_path": file_path,
            "initial_messages": llm_handler.build_initial_messages(file_path, mode),
            "mode": mode,
            "llm_config": llm_handler.get_config_dict(),
        }

    @staticmethod
    def _editor_failure(error_msg, state):
        """Return the update_dag_from_editor outputs for a failed update."""
        return (
            None,                      # no new DAG image
            gr.update(visible=False),  # hide Validate & Deploy button
            gr.update(visible=True),   # keep editor visible
            gr.update(visible=True),   # keep Update DAG button visible
            error_msg,
            state,
        )

    # ------------------------------------------------------------- public API

    def initialize_interface(self, mode):
        """Start ROS (if needed), create the publisher, and build session state.

        Returns:
            The initial Gradio state dict (see ``_build_session_state``).
        """
        if not rclpy.ok():
            rclpy.init()
        self.node_publisher = RosNodePublisher(mode)
        return self._build_session_state(mode)

    async def predict(self, input, state):
        """Send one user query to the LLM and stage the resulting plan for approval.

        Args:
            input: Raw text typed by the user.
            state: Session dict from initialize_interface()/update_chatbot().

        Returns:
            ``(messages, state, dag_image_path, confirm_button_update)``.
            ``messages`` is the conversation after the initial prompt, as
            role/content dicts. ``dag_image_path`` is the rendered graph or None.
            The last item toggles the Validate & Deploy button.
        """
        mode = state.get('mode')
        if self.node_publisher is None:
            # predict() ran before initialize_interface(). Create the publisher
            # the same way initialize_interface() does instead of crashing.
            if not rclpy.ok():
                rclpy.init()
            self.node_publisher = RosNodePublisher(mode)
        if not self.node_publisher.is_initialized():
            self.node_publisher.initialize_node(mode)

        initial_messages = state['initial_messages']
        history = state.get('history', [])
        # List concatenation creates a new list, so stored prompts are not mutated.
        full_history = initial_messages + history
        full_history.append({"role": "user", "content": f"# Query: {input}"})

        mode_config = MODE_CONFIG[mode]
        if mode_config.get("type") == 'complex' and self.received_tasks:
            # Queued tasks are attached to exactly one query, then discarded.
            full_history.extend(
                {"role": "user", "content": f"# Task: {task}"} for task in self.received_tasks
            )
            self.received_tasks = []

        # A fresh handler per request, rebuilt from the config dict kept in state.
        llm_handler = LLMRequestHandler.create_from_config_dict(state['llm_config'])
        response = await llm_handler.make_completion(full_history)
        if not response:
            response = "Error: Unable to get response."
        full_history.append({"role": "assistant", "content": response})

        response_json = self.json_processor.process_response(response)
        # Always overwrite, so a reply without a plan clears any stale pending plan.
        state.update({'pending_task_plan': response_json})

        dag_image_path = None
        confirm_button_visible = False
        if self._has_tasks(response_json):
            try:
                dag_image_path = self.dag_visualizer.create_dag_visualization(
                    response_json,
                    title="Robot Task Dependency Graph - Pending Approval"
                )
                logger.info(f"DAG visualization generated: {dag_image_path}")
                confirm_button_visible = True
            except Exception as e:
                logger.error(f"Failed to generate DAG visualization: {e}")

        # Gradio "messages" format: only role/content, without the initial prompt.
        messages = [
            {"role": message["role"], "content": message["content"]}
            for message in full_history[len(initial_messages):]
        ]
        # History keeps the raw query (no "# Query:" prefix) and the reply only.
        state.update({'history': history + [
            {"role": "user", "content": input},
            {"role": "assistant", "content": response},
        ]})
        return messages, state, dag_image_path, gr.update(visible=confirm_button_visible)

    def clear_chat(self, state):
        """Reset the conversation history stored in *state* (in place)."""
        state['history'] = []

    def show_task_plan_editor(self, state):
        """Open the JSON editor on the current task plan, or on a template.

        A pending plan with tasks is preferred. Otherwise the last deployed plan
        is used and copied into ``pending_task_plan``, the slot that
        validate_and_deploy_task_plan reads.

        Returns:
            ``(editor_update, update_dag_button_update, deploy_button_update,
            status_markdown)``.
        """
        pending_plan = state.get('pending_task_plan')
        deployed_plan = state.get('deployed_task_plan')

        if self._has_tasks(pending_plan) and pending_plan["tasks"]:
            current_plan, plan_status = pending_plan, "pending"
        elif self._has_tasks(deployed_plan) and deployed_plan["tasks"]:
            current_plan, plan_status = deployed_plan, "deployed"
            state.update({'pending_task_plan': deployed_plan})
        else:
            logger.info("📝 Task plan editor opened with template")
            return (
                gr.update(visible=True, value=self._TEMPLATE_PLAN_JSON),  # Show template
                gr.update(visible=True),   # Show Update DAG button
                gr.update(visible=False),  # Hide Validate & Deploy button
                "⚠️ **No Task Plan Available**\n\nStarting with example template. Please edit the JSON structure and update."
            )

        formatted_json = json.dumps(current_plan, indent=2, ensure_ascii=False)
        logger.info(f"📝 Task plan editor opened with {plan_status} plan")
        return (
            gr.update(visible=True, value=formatted_json),  # Show editor with current JSON
            gr.update(visible=True),   # Show Update DAG button
            gr.update(visible=False),  # Hide Validate & Deploy button
            f"📝 **Task Plan Editor Opened**\n\nYou can now manually edit the task plan JSON below. {plan_status.title()} plan loaded for editing."
        )

    def update_dag_from_editor(self, edited_json, state):
        """Validate manually edited plan JSON, stage it, and redraw the DAG.

        Returns:
            ``(dag_image_path, deploy_button_update, editor_update,
            update_dag_button_update, status_markdown, state)``. On failure the
            image is None and the editor stays open.
        """
        try:
            edited_plan = json.loads(edited_json)
            # A bare `"tasks" in x` would do a substring test on JSON strings.
            if not self._has_tasks(edited_plan):
                raise ValueError("JSON must be an object containing a 'tasks' list")
            state.update({'pending_task_plan': edited_plan})
            dag_image_path = self.dag_visualizer.create_dag_visualization(
                edited_plan,
                title="Robot Task Dependency Graph - EDITED & PENDING APPROVAL"
            )
        except json.JSONDecodeError as e:  # must precede ValueError/Exception
            return self._editor_failure(
                f"❌ **JSON Parsing Error**\n\nInvalid JSON format: {str(e)}\n\nPlease fix the JSON syntax and try again.",
                state,
            )
        except Exception as e:
            logger.error(f"Failed to update DAG from editor: {e}")
            return self._editor_failure(f"❌ **Update Failed**\n\nError: {str(e)}", state)

        logger.info("🔄 DAG updated from manual edits")
        return (
            dag_image_path,
            gr.update(visible=True),   # Show Validate & Deploy button
            gr.update(visible=False),  # Hide editor
            gr.update(visible=False),  # Hide Update DAG button
            "✅ **DAG Updated Successfully**\n\nTask plan has been updated with your edits. Please review the visualization and click 'Validate & Deploy' to proceed.",
            state
        )

    def validate_and_deploy_task_plan(self, state):
        """Publish the approved pending task plan to ROS (operator confirmation step).

        Only a plan containing a ``tasks`` list is published. Once publishing
        succeeds, the plan is moved to ``deployed_task_plan`` before anything
        else can fail, so a retry can never send it twice. A failure to draw the
        "approved" graph is logged but does not report the deployment as failed.

        Returns:
            ``(status_markdown, dag_image_path, confirm_button_update, state)``.
        """
        pending_plan = state.get('pending_task_plan')
        if not self._has_tasks(pending_plan):
            return (
                "⚠️ **No Task Plan to Deploy**\n\nPlease generate a task plan first.",
                None,
                gr.update(visible=False),
                state
            )

        try:
            self.node_publisher.publish_response(pending_plan)
        except Exception as e:
            logger.error(f"Failed to deploy task plan: {e}")
            return (
                f"❌ **Deployment Failed**\n\nError: {str(e)}",
                None,
                gr.update(visible=True),  # Keep button visible for retry
                state
            )

        # Published: record it immediately and keep it for later re-editing.
        state.update({'deployed_task_plan': pending_plan, 'pending_task_plan': None})
        logger.info("✅ Task plan validated and deployed to construction site")

        approved_image_path = None
        try:
            approved_image_path = self.dag_visualizer.create_dag_visualization(
                pending_plan,
                title="Robot Task Dependency Graph - APPROVED & DEPLOYED"
            )
        except Exception as e:
            logger.error(f"Failed to generate DAG visualization: {e}")

        return (
            "✅ **Task Plan Successfully Deployed**\n\nThe validated task dependency graph has been sent to the construction site robots. All safety protocols confirmed.",
            approved_image_path,
            gr.update(visible=False),  # Hide confirmation button
            state
        )

    def update_chatbot(self, mode, state):
        """Switch to *mode*: recreate the ROS node and LLM config, reset the chat.

        Any pending plan from the previous mode is discarded, so it cannot be
        deployed through the new node.

        Returns:
            ``(chatbot_update, state)``. The chatbot update empties the chat.
        """
        if self.node_publisher is not None:
            self.node_publisher.destroy_node()
        if not rclpy.ok():
            rclpy.init()
        self.node_publisher = RosNodePublisher(mode)
        self.json_processor = JsonProcessor()

        settings = self._llm_settings(MODE_CONFIG[mode])
        session = self._build_session_state(mode, settings)
        logger.info(
            f"Updating chatbot with {session['file_path']}, "
            f"model {settings['model_name']}, provider {settings['provider']}"
        )
        state.update(session)
        state['history'] = []
        state['pending_task_plan'] = None
        logger.info(f"\033[33mMode updated to {mode}\033[0m")
        return gr.update(value=[]), state