Spaces:
Paused
Paused
| import gradio as gr | |
| import requests | |
| # import json | |
| import os | |
| from typing import Optional | |
| import numpy as np | |
| import cv2 | |
| from PIL import Image | |
| # 后端API配置(可配置化) | |
| BACKEND_URL = os.getenv("BACKEND_URL", "http://your-backend-server:5000") | |
| API_ENDPOINTS = { | |
| "submit_task": f"{BACKEND_URL}/api/v1/submit", | |
| "query_status": f"{BACKEND_URL}/api/v1/status", | |
| "get_result": f"{BACKEND_URL}/api/v1/result" | |
| } | |
| # 全局缓存原始图像 | |
| #ORIGINAL_IMAGE = cv2.imread("scene.png") | |
| ORIGINAL_IMAGE = np.array(Image.open("scene.png").convert("RGB")) | |
| if ORIGINAL_IMAGE is None: | |
| raise RuntimeError("❌ 无法加载 scene.png,请确保图片文件与 app.py 同目录,并命名正确。") | |
| # 模拟场景配置 | |
| SCENE_CONFIGS = { | |
| "default_desk": { | |
| "description": "标准实验桌", | |
| "objects": ["番茄酱", "盐瓶", "餐刀", "杯子"] | |
| }, | |
| "cluttered_desk": { | |
| "description": "杂乱桌面场景", | |
| "objects": ["书本", "笔", "手机", "水杯", "零食袋"] | |
| }, | |
| "industrial_table": { | |
| "description": "工业工作台", | |
| "objects": ["扳手", "螺丝", "电路板", "润滑剂"] | |
| } | |
| } | |
| # 可用模型列表 | |
| MODEL_CHOICES = [ | |
| "GRManipulation-v1.0", | |
| "GR00T-N1", | |
| "GR00T-1.5", | |
| "Pi0", | |
| "DP+CLIP", | |
| "AcT+CLIP" | |
| ] | |
| def image_to_position(image: np.ndarray, evt: gr.SelectData) -> tuple[np.ndarray, str]: | |
| h, w = image.shape[:2] | |
| px, py = evt.index # 点击位置 (x, y) | |
| # 坐标转换 | |
| x = (px / w) * 2 - 1 | |
| y = -((py / h) * 2 - 1) | |
| z = 0.1 | |
| coord_str = f"{x:.2f}, {y:.2f}, {z:.2f}" | |
| # 使用原始图像绘制新图(保证每次只有一个点) | |
| marked = ORIGINAL_IMAGE.copy() | |
| cv2.circle(marked, center=(px, py), radius=8, color=(255, 0, 0), thickness=-1) | |
| return marked, coord_str | |
| def submit_to_backend( | |
| scene: str, | |
| prompt: str, | |
| start_position: str, | |
| max_steps: int = 100, | |
| visualize: bool = True | |
| ) -> dict: | |
| """ | |
| 提交任务到后端API | |
| """ | |
| payload = { | |
| "scene_config": scene, | |
| "prompt": prompt, | |
| "start_position": start_position, | |
| "params": { | |
| "max_steps": max_steps, | |
| "visualize": visualize | |
| }, | |
| "metadata": { | |
| "submit_from": "gradio_ui" | |
| } | |
| } | |
| try: | |
| response = requests.post( | |
| API_ENDPOINTS["submit_task"], | |
| json=payload, | |
| timeout=10 | |
| ) | |
| return response.json() | |
| except Exception as e: | |
| return {"status": "error", "message": str(e)} | |
| def get_task_status(task_id: str) -> dict: | |
| """ | |
| 查询任务状态 | |
| """ | |
| try: | |
| response = requests.get( | |
| f"{API_ENDPOINTS['query_status']}/{task_id}", | |
| timeout=5 | |
| ) | |
| return response.json() | |
| except Exception as e: | |
| return {"status": "error", "message": str(e)} | |
| def get_task_result(task_id: str) -> Optional[dict]: | |
| """ | |
| 获取任务结果 | |
| """ | |
| try: | |
| response = requests.get( | |
| f"{API_ENDPOINTS['get_result']}/{task_id}", | |
| timeout=5 | |
| ) | |
| return response.json() | |
| except Exception as e: | |
| print(f"Error fetching result: {e}") | |
| return None | |
| def run_simulation( | |
| scene: str, | |
| prompt: str, | |
| model: str, | |
| progress=gr.Progress() | |
| ) -> dict: | |
| """ | |
| 运行仿真的主函数 | |
| """ | |
| # 提交任务到后端 | |
| progress(0.1, desc="提交任务到后端...") | |
| submission = submit_to_backend(scene, prompt, model) | |
| if submission.get("status") != "success": | |
| raise gr.Error(f"提交失败: {submission.get('message', '未知错误')}") | |
| task_id = submission["task_id"] | |
| progress(0.3, desc="任务已提交,等待执行...") | |
| # 轮询任务状态 | |
| max_checks = 20 | |
| for i in range(max_checks): | |
| status = get_task_status(task_id) | |
| if status.get("status") == "completed": | |
| progress(0.9, desc="获取结果...") | |
| result = get_task_result(task_id) | |
| if result: | |
| return { | |
| "video": result.get("video_path"), | |
| "metrics": result.get("metrics"), | |
| "log": result.get("log") | |
| } | |
| else: | |
| raise gr.Error("获取结果失败") | |
| elif status.get("status") == "failed": | |
| raise gr.Error(f"任务执行失败: {status.get('message')}") | |
| progress(0.3 + 0.6 * (i/max_checks), desc=f"任务执行中...({status.get('progress', 0)}%)") | |
| raise gr.Error("任务执行超时") | |
| # 自定义CSS样式 | |
| custom_css = """ | |
| #simulation-panel { | |
| border-radius: 8px; | |
| padding: 20px; | |
| background: #f9f9f9; | |
| box-shadow: 0 2px 4px rgba(0,0,0,0.1); | |
| } | |
| #result-panel { | |
| border-radius: 8px; | |
| padding: 20px; | |
| background: #f0f8ff; | |
| } | |
| .dark #simulation-panel { background: #2a2a2a; } | |
| .dark #result-panel { background: #1a2a3a; } | |
| /* 强力隐藏图像组件底部工具栏 */ | |
| .gr-image .absolute.bottom-0, | |
| .gr-image .flex.justify-between.items-center.px-2.pb-2 { | |
| display: none !important; | |
| } | |
| """ | |
| with gr.Blocks(title="机器人导航仿真系统", css=custom_css) as demo: | |
| # 标题和描述 | |
| gr.Markdown(""" | |
| # 🧭 GRNavigation 机器人导航仿真平台 | |
| ### 基于 GRNavigation 框架的多场景路径规划与自主导航训练 | |
| """) | |
| with gr.Row(): | |
| # 左侧控制面板 | |
| with gr.Column(elem_id="simulation-panel"): | |
| gr.Markdown("### 仿真任务配置") | |
| # 场景选择 | |
| scene_dropdown = gr.Dropdown( | |
| label="选择导航环境", | |
| choices=list(SCENE_CONFIGS.keys()), | |
| value="default_desk", | |
| interactive=True | |
| ) | |
| def update_scene_desc(scene): | |
| config = SCENE_CONFIGS.get(scene, {}) | |
| desc = config.get("description", "无描述") | |
| objects = "、".join(config.get("objects", [])) | |
| return f"**{desc}** \n包含物体: {objects}" | |
| # 场景描述预览 | |
| scene_description = gr.Markdown("") | |
| # 动态更新场景描述(函数不变) | |
| # 操作指令输入 | |
| prompt_input = gr.Textbox( | |
| label="导航指令(自然语言)", | |
| placeholder="例如:'从桌角出发,穿过障碍物,前往水杯位置'", | |
| lines=2, | |
| max_lines=4 | |
| ) | |
| # 起始坐标输入 | |
| start_pos_input = gr.Textbox( | |
| label="起始位置坐标 (x, y, z)", | |
| placeholder="例如:0.0, 0.0, 0.2", | |
| lines=1 | |
| ) | |
| # 高级参数 | |
| with gr.Accordion("高级设置", open=False): | |
| max_steps = gr.Slider( | |
| minimum=50, | |
| maximum=500, | |
| value=100, | |
| step=10, | |
| label="最大导航步数" | |
| ) | |
| visualize = gr.Checkbox( | |
| value=True, | |
| label="显示可视化界面(Isaac Sim)" | |
| ) | |
| # 提交按钮 | |
| submit_btn = gr.Button("开始导航仿真", variant="primary") | |
| # 右侧结果面板 | |
| with gr.Column(elem_id="result-panel"): | |
| gr.Markdown("### 仿真结果预览") | |
| # 视频输出 | |
| video_output = gr.Video( | |
| label="导航过程回放", | |
| interactive=False, | |
| format="mp4" | |
| ) | |
| # 场景俯视图图像(点击获取起点) | |
| scene_image = gr.Image( | |
| value="/scene.png", # 占位图路径 | |
| label="点击选择起点位置(场景俯视图)", | |
| type="numpy", # 获取坐标 | |
| interactive=True, | |
| height=300, | |
| show_share_button=False # ✅ 关闭底部按钮(上传、拍照、复制) | |
| ) | |
| # ✅ 添加“刷新场景图像”按钮 | |
| def reload_scene_image(): | |
| new_image = np.array(Image.open("scene.png").convert("RGB")) | |
| global ORIGINAL_IMAGE | |
| ORIGINAL_IMAGE = new_image | |
| return new_image | |
| refresh_btn = gr.Button("🔁 刷新场景图像") | |
| refresh_btn.click(fn=reload_scene_image, outputs=scene_image) | |
| # 指标展示 | |
| metrics_output = gr.JSON( | |
| label="导航性能指标", | |
| visible=False | |
| ) | |
| # 日志输出 | |
| log_output = gr.Textbox( | |
| label="任务执行日志", | |
| visible=False, | |
| lines=10, | |
| max_lines=20 | |
| ) | |
| # 示例任务 | |
| gr.Examples( | |
| examples=[ | |
| ["default_desk", "从桌角出发,前往番茄酱附近", "0.0, 0.0, 0.1"], | |
| ["cluttered_desk", "从水杯出发,移动到手机旁", "1.0, -0.5, 0.0"], | |
| ["industrial_table", "避开扳手,从台边移动到润滑剂", "0.5, 0.2, 0.0"] | |
| ], | |
| inputs=[scene_dropdown, prompt_input, start_pos_input], | |
| label="导航任务示例" | |
| ) | |
| # 提交处理逻辑 | |
| submit_btn.click( | |
| fn=run_simulation, | |
| inputs=[scene_dropdown, prompt_input, start_pos_input], | |
| outputs=[video_output, metrics_output, log_output], | |
| api_name="run_simulation" | |
| ) | |
| # 初始场景文字描述 | |
| demo.load( | |
| fn=lambda: (update_scene_desc("default_desk"), reload_scene_image()), | |
| outputs=[scene_description, scene_image] | |
| ) | |
| # ✅ 添加点击图片 → 自动设置起始位置 | |
| scene_image.select( | |
| fn=image_to_position, | |
| inputs=[scene_image], | |
| outputs=[scene_image, start_pos_input] | |
| ) | |
| # 启动应用 | |
| if __name__ == "__main__": | |
| demo.launch(server_name="0.0.0.0", server_port=7860, share=True, debug=True) |