File size: 2,513 Bytes
142a1ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import torch
from torch.utils.data import Dataset
from omegaconf import DictConfig
from pathlib import Path


class DummyVideoDataset(Dataset):
    def __init__(self, cfg: DictConfig, split: str = "training") -> None:
        super().__init__()
        self.cfg = cfg
        self.split = split
        self.height = cfg.height
        self.width = cfg.width
        self.n_frames = cfg.n_frames
        self.load_video_latent = cfg.load_video_latent
        self.load_prompt_embed = cfg.load_prompt_embed
        self.image_to_video = cfg.image_to_video
        self.max_text_tokens = cfg.max_text_tokens

    @property
    def metadata_path(self):
        raise ValueError("Dummy dataset does not have a metadata path")

    @property
    def data_root(self):
        raise ValueError("Dummy dataset does not have a data root path")

    def __len__(self) -> int:
        return 10000000  # Return fixed size of 10000000

    def __getitem__(self, idx: int) -> dict:
        # Generate dummy video tensor [T, C, H, W]
        videos = torch.randn(self.n_frames, 3, self.height, self.width)

        # Generate dummy image if needed
        images = videos[:1].clone() if self.image_to_video else None

        output = {
            "prompts": f"A dummy video caption for debugging purpose",
            "videos": videos,
            "video_metadata": {
                "num_frames": self.n_frames,
                "height": self.height,
                "width": self.width,
                "has_caption": True,
            },
            "has_bbox": torch.tensor([False, False]),
            "bbox_render": torch.zeros(2, self.height, self.width),
        }

        if images is not None:
            output["images"] = images

        if self.load_prompt_embed:
            # Generate dummy prompt embeddings [self.max_text_tokens, 4096]
            output["prompt_embeds"] = torch.randn(self.max_text_tokens, 4096)
            output["prompt_embed_len"] = self.max_text_tokens

        if self.load_video_latent:
            # Generate dummy latents
            if self.image_to_video:
                output["image_latents"] = torch.randn(
                    4,
                    self.n_frames // 4,
                    self.height // 8,
                    self.width // 8,
                )
            output["video_latents"] = torch.randn(
                4,
                self.n_frames // 4,
                self.height // 8,
                self.width // 8,
            )

        return output