Spaces:
Paused
Paused
| ''' | |
| author: caishaofei <caishaofei@stu.pku.edu.cn> | |
| date: 2024-09-20 20:10:44 | |
| Copyright © Team CraftJarvis All rights reserved | |
| ''' | |
| import re | |
| import os | |
| import cv2 | |
| import time | |
| from pathlib import Path | |
| import argparse | |
| import requests | |
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| from io import BytesIO | |
| from PIL import Image, ImageDraw | |
| from rocket.arm.sessions import Session, Pointer | |
| COLORS = [ | |
| (255, 0, 0), (0, 255, 0), (0, 0, 255), | |
| (255, 255, 0), (255, 0, 255), (0, 255, 255), | |
| (255, 255, 255), (0, 0, 0), (128, 128, 128), | |
| (128, 0, 0), (128, 128, 0), (0, 128, 0), | |
| (128, 0, 128), (0, 128, 128), (0, 0, 128), | |
| ] | |
| SEGMENT_MAPPING = { | |
| "Hunt": 0, "Use": 3, "Mine": 2, "Interact": 3, "Craft": 4, "Switch": 5, "Approach": 6 | |
| } | |
| NOOP_ACTION = { | |
| "back": 0, | |
| "drop": 0, | |
| "forward": 0, | |
| "hotbar.1": 0, | |
| "hotbar.2": 0, | |
| "hotbar.3": 0, | |
| "hotbar.4": 0, | |
| "hotbar.5": 0, | |
| "hotbar.6": 0, | |
| "hotbar.7": 0, | |
| "hotbar.8": 0, | |
| "hotbar.9": 0, | |
| "inventory": 0, | |
| "jump": 0, | |
| "left": 0, | |
| "right": 0, | |
| "sneak": 0, | |
| "sprint": 0, | |
| "camera": np.array([0, 0]), | |
| "attack": 0, | |
| "use": 0, | |
| } | |
| def reset_fn(env_name, session): | |
| image = session.reset(env_name) | |
| return image, session | |
| def step_fn(act_key, session): | |
| action = NOOP_ACTION.copy() | |
| if act_key != "null": | |
| action[act_key] = 1 | |
| image = session.step(action) | |
| return image, session | |
| def loop_step_fn(steps, session): | |
| for i in range(steps): | |
| image = session.step() | |
| status = f"Running Agent `Rocket` steps: {i+1}/{steps}. " | |
| yield image, session.num_steps, status, session | |
| def clear_memory_fn(session): | |
| image = session.current_image | |
| session.clear_agent_memory() | |
| return image, "0", session | |
| def get_points_with_draw(image, label, session, evt: gr.SelectData): | |
| points = session.points | |
| point_label = session.points_label | |
| x, y = evt.index[0], evt.index[1] | |
| point_radius, point_color = 5, (0, 255, 0) if label == 'Add Points' else (255, 0, 0) | |
| points.append([x, y]) | |
| point_label.append(1 if label == 'Add Points' else 0) | |
| cv2.circle(image, (x, y), point_radius, point_color, -1) | |
| return image, session | |
| def clear_points_fn(session): | |
| session.clear_points() | |
| return session.current_image, session | |
| def segment_fn(session): | |
| if len(session.points) == 0: | |
| return session.current_image, session | |
| session.segment() | |
| image = session.apply_mask() | |
| return image, session | |
| def clear_segment_fn(session): | |
| session.clear_obj_mask() | |
| session.tracking_flag = False | |
| return session.current_image, False, session | |
| def set_tracking_mode(tracking_flag, session): | |
| session.tracking_flag = tracking_flag | |
| return session | |
| def set_segment_type(segment_type, session): | |
| session.segment_type = segment_type | |
| return session | |
| def play_fn(session): | |
| image = session.step() | |
| return image, session | |
| memory_length = gr.Textbox(value="0", interactive=False, show_label=False) | |
| def make_video_fn(session, make_video, save_video, progress=gr.Progress()): | |
| images = session.image_history | |
| if len(images) == 0: | |
| return session, make_video, save_video | |
| filepath = "rocket.mp4" | |
| h, w = images[0].shape[:2] | |
| fourcc = cv2.VideoWriter_fourcc(*'mp4v') | |
| video = cv2.VideoWriter(filepath, fourcc, 20.0, (w, h)) | |
| for image in progress.tqdm(images): | |
| image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) | |
| video.write(image) | |
| video.release() | |
| session.image_history = [] | |
| return session, gr.Button("Make Video", visible=False), gr.DownloadButton("Download!", value=filepath, visible=True) | |
| def save_video_fn(session, make_video, save_video): | |
| return session, gr.Button("Make Video", visible=True), gr.DownloadButton("Download!", visible=False) | |
| def choose_sam_fn(sam_choice, session): | |
| session.sam_choice = sam_choice | |
| session.load_sam() | |
| return session | |
| def molmo_fn(molmo_text, molmo_session, rocket_session, display_image): | |
| image = rocket_session.current_image.copy() | |
| points = molmo_session.gen_point(image=image, prompt=molmo_text) | |
| molmo_result = molmo_session.molmo_result | |
| for x, y in points: | |
| x, y = int(x), int(y) | |
| point_radius, point_color = 5, (0, 255, 0) | |
| rocket_session.points.append([x, y]) | |
| rocket_session.points_label.append(1) | |
| cv2.circle(display_image, (x, y), point_radius, point_color, -1) | |
| return molmo_result, display_image | |
| def extract_points(data): | |
| # 匹配 x 和 y 坐标的值,支持 <points> 和 <point> 标签 | |
| pattern = r'x\d?="([-+]?\d*\.\d+|\d+)" y\d?="([-+]?\d*\.\d+|\d+)"' | |
| points = re.findall(pattern, data) | |
| # 将提取到的坐标转换为浮点数 | |
| points = [(float(x)/100*640, float(y)/100*360) for x, y in points] | |
| return points | |
| def draw_gradio_components(args): | |
| with gr.Blocks() as demo: | |
| gr.Markdown( | |
| """ | |
| # Welcome to Explore ROCKET-1 in Minecraft!! | |
| ## Please follow next steps to interact with the agent: | |
| 1. Reset the environment by selecting an environment name. | |
| 2. Select a SAM2 checkpoint to load. | |
| 3. Use your mouse to add or remove points on the image. | |
| 4. Select the segment type you want to perform. | |
| 5. Enable `tracking` mode if you want to track objects while stepping actions. | |
| 6. Click `New Segment` to segment the image based on the points you added. | |
| 7. Call the agent by clicking `Call Rocket` to run the agent for a certain number of steps. | |
| ## Hints: | |
| 1. You can use the `Make Video` button to generate a video of the agent's actions. | |
| 2. You can use the `Clear Memory` button to clear the ROCKET-1's memory. | |
| 3. You can use the `Clear Segment` button to clear SAM's memory. | |
| 4. You can use the `Manually Step` button to manually step the agent. | |
| """ | |
| ) | |
| rocket_session = gr.State(Session( | |
| sam_path=args.sam_path, | |
| )) | |
| molmo_session = gr.State(Pointer( | |
| model_id="molmo-72b-0924", | |
| model_url="http://172.17.30.127:8000/v1", | |
| )) | |
| with gr.Row(): | |
| with gr.Column(scale=2): | |
| # start_image = Image.open("start.png").resize((640, 360)) | |
| start_image = np.zeros((360, 640, 3), dtype=np.uint8) | |
| with gr.Group(): | |
| display_image = gr.Image( | |
| value=np.array(start_image), | |
| interactive=False, | |
| show_label=False, | |
| label="Real-time Environment Observation", | |
| streaming=True | |
| ) | |
| display_status = gr.Textbox("Status Bar", interactive=False, show_label=False) | |
| with gr.Column(scale=1): | |
| sam_choice = gr.Radio( | |
| choices=["large", "base", "small", "tiny"], | |
| value="base", | |
| label="Select SAM2 checkpoint", | |
| ) | |
| sam_choice.select(fn=choose_sam_fn, inputs=[sam_choice, rocket_session], outputs=[rocket_session], show_progress=False) | |
| with gr.Group(): | |
| add_or_remove = gr.Radio( | |
| choices=["Add Points", "Remove Areas"], | |
| value="Add Points", | |
| label="Use you mouse to add or remove points", | |
| ) | |
| clear_points_btn = gr.Button("Clear Points") | |
| clear_points_btn.click(clear_points_fn, inputs=[rocket_session], outputs=[display_image, rocket_session], show_progress=True) | |
| with gr.Group(): | |
| segment_type = gr.Radio( | |
| choices=["Approach", "Interact", "Hunt", "Mine", "Craft", "Switch"], | |
| value="Approach", | |
| label="What do you want with this segment?", | |
| ) | |
| track_flag = gr.Checkbox(True, label="Enable tracking objects while steping actions") | |
| track_flag.select(fn=set_tracking_mode, inputs=[track_flag, rocket_session], outputs=[rocket_session], show_progress=False) | |
| with gr.Group(), gr.Row(): | |
| new_segment_btn = gr.Button("New Segment") | |
| clear_segment_btn = gr.Button("Clear Segment") | |
| new_segment_btn.click(segment_fn, inputs=[rocket_session], outputs=[display_image, rocket_session], show_progress=True) | |
| clear_segment_btn.click(clear_segment_fn, inputs=[rocket_session], outputs=[display_image, track_flag, rocket_session], show_progress=True) | |
| display_image.select(get_points_with_draw, inputs=[display_image, add_or_remove, rocket_session], outputs=[display_image, rocket_session]) | |
| segment_type.select(set_segment_type, inputs=[segment_type, rocket_session], outputs=[rocket_session], show_progress=False) | |
| with gr.Row(): | |
| with gr.Group(): | |
| env_list = [f"rocket/{x.stem}" for x in Path("../env_configs/rocket").glob("*.yaml") if 'base' not in x.name != 'base'] | |
| env_name = gr.Dropdown(env_list, multiselect=False, min_width=200, show_label=False, label="Env Name") | |
| reset_btn = gr.Button("Reset Environment") | |
| reset_btn.click(fn=reset_fn, inputs=[env_name, rocket_session], outputs=[display_image, rocket_session], show_progress=True) | |
| with gr.Group(): | |
| action_list = [x for x in NOOP_ACTION.keys()] | |
| act_key = gr.Dropdown(action_list, multiselect=False, min_width=200, show_label=False, label="Action") | |
| step_btn = gr.Button("Manually Step") | |
| step_btn.click(fn=step_fn, inputs=[act_key, rocket_session], outputs=[display_image, rocket_session], show_progress=False) | |
| with gr.Group(): | |
| steps = gr.Slider(1, 600, 30, 1, label="Steps", show_label=False) | |
| play_btn = gr.Button("Call Rocket") | |
| play_btn.click(fn=loop_step_fn, inputs=[steps, rocket_session], outputs=[display_image, memory_length, display_status, rocket_session], show_progress=False) | |
| with gr.Group(): | |
| memory_length.render() | |
| clear_states_btn = gr.Button("Clear Memory") | |
| clear_states_btn.click(fn=clear_memory_fn, inputs=rocket_session, outputs=[display_image, memory_length, rocket_session], show_progress=False) | |
| make_video_btn = gr.Button("Make Video") | |
| save_video_btn = gr.DownloadButton("Download!!", visible=False) | |
| make_video_btn.click(make_video_fn, inputs=[rocket_session, make_video_btn, save_video_btn], outputs=[rocket_session, make_video_btn, save_video_btn], show_progress=False) | |
| save_video_btn.click(save_video_fn, inputs=[rocket_session, make_video_btn, save_video_btn], outputs=[rocket_session, make_video_btn, save_video_btn], show_progress=False) | |
| with gr.Row(): | |
| with gr.Group(): | |
| molmo_text = gr.Textbox("pinpoint the", label="Molmo Text", show_label=True, min_width=200) | |
| molmo_btn = gr.Button("Generate") | |
| output_text = gr.Textbox("", label="Molmo Output", show_label=False, min_width=200) | |
| molmo_btn.click(molmo_fn, inputs=[molmo_text, molmo_session, rocket_session, display_image],outputs=[output_text, display_image],show_progress=False) | |
| demo.queue() | |
| demo.launch(share=False,server_port=args.port) | |
| if __name__ == '__main__': | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument("--port", type=int, default=7860) | |
| parser.add_argument("--sam-path", type=str, default="/app/ROCKET-1/rocket/realtime_sam/checkpoints") | |
| parser.add_argument("--molmo-id", type=str, default="molmo-72b-0924") | |
| parser.add_argument("--molmo-url", type=str, default="http://127.0.0.1:8000/v1") | |
| args = parser.parse_args() | |
| draw_gradio_components(args) | |