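"""Gradio front end for DART-LLM.

Provides a chat interface backed by GradioLlmInterface, voice input
transcribed with OpenAI Whisper, and an editable task-dependency (DAG)
visualization.
"""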
import gradio as gr
from loguru import logger
from gradio_llm_interface import GradioLlmInterface
from config import GRADIO_MESSAGE_MODES, MODE_CONFIG
import openai
import os
from dotenv import load_dotenv

# Load environment variables from .env file
load_dotenv()

# Speech-to-text function using OpenAI Whisper
def audio_to_text(audio):
    if audio is None:
        return "No audio file provided."
    try:
        # Get OpenAI API key from environment variable
        openai_api_key = os.getenv("OPENAI_API_KEY")
        if not openai_api_key:
            return "Error: OpenAI API key not found. Please set the OPENAI_API_KEY environment variable."
        # Initialize OpenAI client
        client = openai.OpenAI(api_key=openai_api_key)
        # Open and transcribe the audio file
        with open(audio, "rb") as audio_file:
            transcript = client.audio.transcriptions.create(
                model="whisper-1",
                file=audio_file
            )
        return transcript.text
    except FileNotFoundError:
        return "Error: Audio file not found."
    except openai.AuthenticationError:
        return "Error: Invalid OpenAI API key."
    except openai.RateLimitError:
        return "Error: OpenAI API rate limit exceeded."
    except Exception as e:
        logger.error(f"Speech-to-text error: {str(e)}")
        return f"Error during speech recognition: {str(e)}"

def main():
    gradio_ros_interface = GradioLlmInterface()

    title_markdown = ("""
    # 🤖 DART-LLM: Dependency-Aware Multi-Robot Task Decomposition and Execution using Large Language Models
    [[Project Page](https://wyd0817.github.io/project-dart-llm/)] [[Code](https://github.com/wyd0817/gradio_gpt_interface)] [[Model](https://artificialanalysis.ai/)] | 📝 [[RoboQA](https://www.overleaf.com/project/6614a987ae2994cae02efcb2)]
    """)
| with gr.Blocks(css=""" | |
| #text-input, #audio-input { | |
| height: 100px; /* Unified height */ | |
| max-height: 100px; | |
| width: 100%; /* Full container width */ | |
| margin: 0; | |
| } | |
| .input-container { | |
| display: flex; /* Flex layout */ | |
| gap: 10px; /* Spacing */ | |
| align-items: center; /* Vertical alignment */ | |
| } | |
| #voice-input-container { | |
| display: flex; | |
| align-items: center; | |
| gap: 15px; | |
| margin: 15px 0; | |
| padding: 15px; | |
| background: linear-gradient(135deg, #ffeef8 0%, #fff5f5 100%); | |
| border-radius: 20px; | |
| border: 1px solid #ffe4e6; | |
| } | |
| #voice-btn { | |
| width: 50px !important; | |
| height: 50px !important; | |
| border-radius: 50% !important; | |
| font-size: 20px !important; | |
| background: linear-gradient(135deg, #ff6b9d 0%, #c44569 100%) !important; | |
| color: white !important; | |
| border: none !important; | |
| box-shadow: 0 4px 15px rgba(255, 107, 157, 0.3) !important; | |
| transition: all 0.3s ease !important; | |
| } | |
| #voice-btn:hover { | |
| transform: scale(1.05) !important; | |
| box-shadow: 0 6px 20px rgba(255, 107, 157, 0.4) !important; | |
| } | |
| #voice-btn:active { | |
| transform: scale(0.95) !important; | |
| } | |
| .voice-recording { | |
| background: linear-gradient(135deg, #ff4757 0%, #ff3742 100%) !important; | |
| animation: pulse 1.5s infinite !important; | |
| } | |
| @keyframes pulse { | |
| 0% { box-shadow: 0 4px 15px rgba(255, 71, 87, 0.3); } | |
| 50% { box-shadow: 0 4px 25px rgba(255, 71, 87, 0.6); } | |
| 100% { box-shadow: 0 4px 15px rgba(255, 71, 87, 0.3); } | |
| } | |
| #voice-status { | |
| color: #ff6b9d; | |
| font-size: 14px; | |
| font-weight: 500; | |
| text-align: center; | |
| margin-top: 10px; | |
| } | |
| /* Enhanced layout for left-right split */ | |
| .gradio-container .gradio-row { | |
| gap: 20px; /* Add spacing between columns */ | |
| } | |
| .gradio-column { | |
| padding: 10px; | |
| border-radius: 8px; | |
| background-color: var(--panel-background-fill); | |
| } | |
| /* Chat interface styling */ | |
| .chat-column { | |
| border: 1px solid var(--border-color-primary); | |
| } | |
| /* DAG visualization column styling */ | |
| .dag-column { | |
| border: 1px solid var(--border-color-primary); | |
| } | |
| """) as demo: | |
        gr.Markdown(title_markdown)

        mode_choices = [MODE_CONFIG[mode]["display_name"] for mode in GRADIO_MESSAGE_MODES]
        mode_selector = gr.Radio(choices=mode_choices, label="Backend model", value=mode_choices[0])
        clear_button = gr.Button("Clear Chat")

        logger.info("Starting Gradio GPT Interface...")
        initial_mode = GRADIO_MESSAGE_MODES[0]
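
        # Map the selected display name back to its MODE_CONFIG key and
        # rebuild the chatbot for that backend.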
        def update_mode(selected_mode, state):
            mode_key = [key for key, value in MODE_CONFIG.items() if value["display_name"] == selected_mode][0]
            return gradio_ros_interface.update_chatbot(mode_key, state)

        # Main content area with left-right layout
        with gr.Row():
            # Left column: Chat interface
            with gr.Column(scale=1, elem_classes=["chat-column"]):
                gr.Markdown("### 🤖 DART-LLM Chat Interface")
                # Create chatbot component in the left column
                chatbot_container = gr.Chatbot(label="DART-LLM", type="messages")

                # Initialize the interface and get state data
                state_data = gradio_ros_interface.initialize_interface(initial_mode)
                state = gr.State(state_data)

                # Add input area in the left column
                with gr.Row(elem_id="input-container"):
                    txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter", elem_id="text-input", container=False)

                with gr.Row(elem_id="voice-input-container"):
                    with gr.Column(scale=4):
                        # Hidden audio component
                        audio_input = gr.Audio(
                            sources=["microphone"],
                            type="filepath",
                            elem_id="audio-input",
                            show_label=False,
                            interactive=True,
                            streaming=False,
                            visible=False
                        )
                        # Voice input status display
                        voice_status = gr.Markdown("", elem_id="voice-status", visible=False)
                    with gr.Column(scale=1, min_width=80):
                        # Main voice button
                        voice_btn = gr.Button(
                            "🎙️",
                            elem_id="voice-btn",
                            variant="secondary",
                            size="sm",
                            scale=1
                        )

                # Example prompts in the left column
                gr.Examples(
                    examples=[
                        "Dump truck 1 goes to the puddle for inspection, after which all robots avoid the puddle",
                        "Send Excavator 1 and Dump Truck 1 to the soil area; Excavator 1 will excavate and unload, followed by Dump Truck 1 proceeding to the puddle for unloading."
                    ],
                    inputs=txt
                )

            # Right column: DAG visualization and controls
            with gr.Column(scale=1, elem_classes=["dag-column"]):
                gr.Markdown("### 📊 Task Dependency Visualization")
                # DAG visualization display
                dag_image = gr.Image(label="Task Dependency Graph", visible=True, height=600)
                # Task plan editing section
                task_editor = gr.Code(
                    label="Task Plan JSON Editor",
                    language="json",
                    visible=False,
                    lines=15,
                    interactive=True
                )
                # Control buttons section
                with gr.Row():
                    with gr.Column(scale=2):
                        deployment_status = gr.Markdown("", visible=True)
                    with gr.Column(scale=1):
                        with gr.Row():
                            edit_task_btn = gr.Button(
                                "📝 Edit Task Plan",
                                variant="secondary",
                                visible=False,
                                size="sm"
                            )
                            update_dag_btn = gr.Button(
                                "🔄 Update DAG Visualization",
                                variant="secondary",
                                visible=False,
                                size="sm"
                            )
                            validate_deploy_btn = gr.Button(
                                "🚀 Validate & Deploy Task Plan",
                                variant="primary",
                                visible=False,
                                size="sm"
                            )

        mode_selector.change(update_mode, inputs=[mode_selector, state], outputs=[chatbot_container, state])
        clear_button.click(gradio_ros_interface.clear_chat, inputs=[state], outputs=[chatbot_container])

        # Handle text input submission
        async def handle_text_submit(text, state):
            messages, state, dag_image_path, validate_btn_update = await gradio_ros_interface.predict(text, state)
            # Show edit button when task plan is generated
            edit_btn_visible = validate_btn_update.get('visible', False)
            return (
                "",  # Clear the text input after submission
                messages,
                state,
                dag_image_path,
                validate_btn_update,
                gr.update(visible=edit_btn_visible)  # Show edit button
            )
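
        # Outputs must match the handler's return order: cleared textbox,
        # chat history, state, DAG image, deploy button, edit button.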
        txt.submit(handle_text_submit, [txt, state], [txt, chatbot_container, state, dag_image, validate_deploy_btn, edit_task_btn])

        # Voice input state management
        voice_recording = gr.State(False)

        # Voice button click handler: the first click starts recording, the
        # second stops it and transcribes. The status text and its visibility
        # are combined into one gr.update so that voice_status appears only
        # once in the outputs list.
        def handle_voice_input(audio, is_recording):
            logger.info(f"Voice button clicked, current recording state: {is_recording}")
            if not is_recording:
                # Start recording state
                logger.info("Starting recording...")
                return (
                    gr.update(value="🔴", elem_classes=["voice-recording"]),  # Switch to recording button style
                    gr.update(value="🎤 Recording in progress...", visible=True),  # Status message
                    gr.update(visible=True),  # Show audio component
                    True,  # Update recording state
                    ""  # Clear text box
                )
            # Stop recording and transcribe
            logger.info("Stopping recording, starting transcription...")
            if audio:
                try:
                    text = audio_to_text(audio)
                    logger.info(f"Transcription completed: {text}")
                    return (
                        gr.update(value="🎙️", elem_classes=[]),  # Restore button style
                        gr.update(value="✨ Transcription completed!", visible=True),  # Success message
                        gr.update(visible=False),  # Hide audio component
                        False,  # Reset recording state
                        text  # Fill in transcribed text
                    )
                except Exception as e:
                    logger.error(f"Transcription error: {e}")
                    return (
                        gr.update(value="🎙️", elem_classes=[]),  # Restore button style
                        gr.update(value=f"❌ Transcription failed: {str(e)}", visible=True),
                        gr.update(visible=False),
                        False,
                        ""
                    )
            logger.warning("No audio detected")
            return (
                gr.update(value="🎙️", elem_classes=[]),  # Restore button style
                gr.update(value="⚠️ No audio detected, please record again", visible=True),
                gr.update(visible=False),
                False,
                ""
            )

        # Voice button event handling
        voice_btn.click(
            handle_voice_input,
            inputs=[audio_input, voice_recording],
            outputs=[voice_btn, voice_status, audio_input, voice_recording, txt]
        )

        # Audio state change listener - automatic prompt
        def on_audio_change(audio):
            if audio is not None:
                logger.info("Audio file detected")
                return "🎵 Audio detected, you can click the button to complete transcription"
            return ""

        audio_input.change(
            on_audio_change,
            inputs=[audio_input],
            outputs=[voice_status]
        )

        # Handle task plan editing
        edit_task_btn.click(
            gradio_ros_interface.show_task_plan_editor,
            inputs=[state],
            outputs=[task_editor, update_dag_btn, validate_deploy_btn, deployment_status]
        )

        # Handle DAG update from editor
        update_dag_btn.click(
            gradio_ros_interface.update_dag_from_editor,
            inputs=[task_editor, state],
            outputs=[dag_image, validate_deploy_btn, task_editor, update_dag_btn, deployment_status, state]
        )

        # Handle validation and deployment
        validate_deploy_btn.click(
            gradio_ros_interface.validate_and_deploy_task_plan,
            inputs=[state],
            outputs=[deployment_status, dag_image, validate_deploy_btn, state]
        )
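
    # share=True publishes a temporary public gradio.live link in addition to
    # the local server on port 8080; set share=False to keep the app local.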
    demo.launch(server_port=8080, share=True)


if __name__ == "__main__":
    main()