File size: 4,309 Bytes
142a1ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import pandas as pd
from tqdm import tqdm
from pathlib import Path
import decord
import shutil
import subprocess
import json
from typing import Dict, Any
from .video_base import VideoDataset


class DroidVideoDataset(VideoDataset):
    """Video dataset built from a local copy of the raw DROID release.

    ``download`` does not fetch anything (the ``gsutil`` command is kept only
    as a comment); it expects ``<data_root>/droid_raw/1.0.1`` to exist already
    and turns it into a metadata CSV with one row per (episode, camera view).
    """

    def __init__(self, cfg: Dict[str, Any], split: str = "training"):
        """Read the DROID-specific options from ``cfg`` and initialise the base.

        Args:
            cfg: Configuration object. It is read with attribute access
                (``cfg.download.override_fps``, ``cfg.download.views``), so it
                must support that, not only ``dict`` indexing.
            split: Dataset split name, forwarded to the base class.
        """
        # Set before super().__init__ -- presumably because the base
        # constructor may call download(), which reads both attributes.
        # NOTE(review): confirm against VideoDataset.__init__.
        self.override_fps = cfg.download.override_fps
        self.views = cfg.download.views
        super().__init__(cfg, split)

    def download(self):
        """Clean up the raw DROID tree and write the metadata CSV.

        Side effects on disk (all under ``self.data_root``):

        * every ``<lab>/failure`` directory is deleted;
        * episode directories whose name contains ``":"`` are renamed with
          ``"_"`` instead (or deleted when the renamed directory already
          exists);
        * a CSV is written to ``self.data_root / self.metadata_path`` with the
          columns ``video_path``, ``original_caption``, ``fps``, ``n_frames``,
          ``width``, ``height`` and ``has_caption``.

        Episodes without a ``*.json`` metadata file, and views whose video is
        missing or cannot be opened by decord, are skipped with a message.

        Raises:
            FileNotFoundError: If the aggregated annotation file is missing.
            KeyError: If an episode metadata file lacks ``uuid``,
                ``trajectory_length`` or ``<view>_mp4_path``.
        """
        self.data_root.mkdir(parents=True, exist_ok=True)

        # The raw data must already be present; the fetch step is disabled.
        # print("Downloading DROID dataset...")
        # cmd = f"gsutil -m cp -r gs://gresearch/robotics/droid_raw {self.data_root}"
        # subprocess.run(cmd, shell=True, check=True)
        # print("Download complete!")

        # build metadata
        raw_dir = self.data_root / "droid_raw"
        release_dir = raw_dir / "1.0.1"
        caption_file = release_dir / "aggregated-annotations-030724.json"
        with open(caption_file, encoding="utf-8") as fh:
            caption_data = json.load(fh)

        columns = [
            "video_path",
            "original_caption",
            "fps",
            "n_frames",
            "width",
            "height",
            "has_caption",
        ]
        # fps comes from the config and the frame size is hard-coded rather
        # than probed from each video.
        fps = self.override_fps
        width = 1280  # vr[0].shape[1]
        height = 720  # vr[0].shape[0]

        records = []
        for lab_dir in release_dir.glob("*/"):
            # Before Python 3.11 a trailing "/" in the pattern does not
            # restrict matches to directories, so plain files (such as the
            # annotation JSON itself) must be filtered out explicitly.
            if not lab_dir.is_dir():
                continue
            print("processing", lab_dir)
            print("=" * 100)

            # Delete failure directory and its contents if it exists
            failure_dir = lab_dir / "failure"
            success_dir = lab_dir / "success"
            if failure_dir.exists():
                shutil.rmtree(failure_dir)

            # Replace ":" in episode directory names. The listing is
            # materialised first because the loop renames/removes entries.
            for episode_dir in list(success_dir.glob("*/*")):
                if ":" not in episode_dir.name:
                    continue
                sanitized = episode_dir.with_name(episode_dir.name.replace(":", "_"))
                if sanitized.exists():
                    # A sanitised copy is already there; drop the duplicate.
                    shutil.rmtree(episode_dir)
                else:
                    episode_dir.rename(sanitized)

            for episode_dir in tqdm(list(success_dir.glob("*/*"))):
                annotation_files = list(episode_dir.glob("*.json"))
                if not annotation_files:
                    continue
                with open(annotation_files[0], encoding="utf-8") as fh:
                    episode_meta = json.load(fh)

                uuid = episode_meta["uuid"]
                n_frames = episode_meta["trajectory_length"]
                # The caption depends only on the episode, not on the view.
                has_caption = uuid in caption_data
                caption = caption_data[uuid] if has_caption else ""

                for view in self.views:
                    # Paths stored in the metadata are relative to the lab
                    # directory and still use ":"; mirror the on-disk rename.
                    video_path = lab_dir / episode_meta[view + "_mp4_path"].replace(":", "_")

                    if not video_path.exists():
                        print(f"Video file not found: {video_path}")
                        continue

                    try:
                        # Opened only to check that decord can read the file;
                        # the reader is discarded immediately.
                        decord.VideoReader(str(video_path))
                    except Exception as e:
                        # Best effort: an unreadable video is skipped rather
                        # than aborting the whole scan.
                        print(f"Error loading video {video_path}: {e}")
                        continue

                    records.append(
                        {
                            "video_path": str(video_path.relative_to(self.data_root)),
                            "original_caption": caption,
                            "fps": fps,
                            "n_frames": n_frames,
                            "width": width,
                            "height": height,
                            "has_caption": has_caption,
                        }
                    )

        metadata_path = self.data_root / self.metadata_path
        metadata_path.parent.mkdir(parents=True, exist_ok=True)
        # Explicit columns keep the header even when no video was found, so
        # the CSV stays parseable.
        df = pd.DataFrame(records, columns=columns)
        df.to_csv(metadata_path, index=False)
        print(f"Created metadata CSV with {len(records)} videos")