File size: 2,343 Bytes
3e462dd
 
679ca41
3e462dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5ae9357
3e462dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
679ca41
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
#!/usr/bin/env python3
"""
Dataclasses for RBM model dataset trajectory structures.
Defines the standard format for HuggingFace dataset trajectories.
"""

from typing import Any, Union, List, Dict, Tuple, Optional

import numpy as np
from pydantic import BaseModel, ConfigDict
import torch


class Trajectory(BaseModel):
    """Trajectory structure containing frames, metadata, and progress information.

    Several fields hold ``np.ndarray`` / ``torch.Tensor`` values directly. Pydantic
    only allows such non-pydantic types because ``model_config`` sets
    ``arbitrary_types_allowed=True``; those members are validated with a plain
    ``isinstance`` check rather than being coerced.
    """

    # --- Core trajectory fields ---
    frames: Union[List[str], np.ndarray, None] = None
    frames_shape: Optional[Tuple] = None

    # --- If embeddings are precomputed ---
    embeddings_path: Optional[str] = None
    video_embeddings: Union[torch.Tensor, np.ndarray, None] = None
    text_embedding: Union[torch.Tensor, np.ndarray, None] = None

    # --- Identification / task description ---
    id: Optional[str] = None
    task: Optional[str] = None
    lang_vector: Union[np.ndarray, List[float], None] = None
    data_source: Optional[str] = None
    quality_label: Optional[str] = None
    is_robot: Optional[bool] = None

    # --- Progress and metadata ---
    # List[float] for continuous progress; a torch.Tensor / np.ndarray, or a list of
    # them, for C51 discrete distributions. The np.ndarray variants must be listed
    # explicitly: pydantic v2 does not coerce an ndarray into List[float], so without
    # them the documented ndarray form fails validation. List[float] is kept first so
    # plain float lists (and ambiguous inputs such as []) resolve exactly as before.
    target_progress: Union[
        List[float],
        List[torch.Tensor],
        List[np.ndarray],
        torch.Tensor,
        np.ndarray,
        None,
    ] = None
    partial_success: Optional[Union[float, torch.Tensor]] = None  # float for continuous, Tensor for C51 discrete
    success_label: Optional[List[float]] = None  # Success labels for each frame (1.0 for success, 0.0 for failure)
    predict_last_frame_mask: Optional[List[float]] = None  # 1.0 per frame for inference (no masking); server requires a list
    metadata: Optional[Dict[str, Any]] = None
    data_gen_strategy: Optional[str] = None

    # Required so pydantic accepts the np.ndarray / torch.Tensor annotations above.
    model_config = ConfigDict(arbitrary_types_allowed=True)


class ProgressSample(BaseModel):
    """A single-trajectory sample used for progress evaluation."""

    # The trajectory this sample wraps.
    trajectory: Trajectory
    # Sample-kind tag; this model defaults it to "progress".
    sample_type: str = "progress"
    # Optional tag naming the data-generation strategy.
    data_gen_strategy: Union[str, None] = None
    # Resampling attempt count; defaults to 1.
    resample_attempts: int = 1


class PreferenceSample(BaseModel):
    """A pairwise preference sample: ``chosen_trajectory`` is preferred over ``rejected_trajectory``."""

    # The preferred trajectory of the pair.
    chosen_trajectory: Trajectory
    # The non-preferred trajectory of the pair.
    rejected_trajectory: Trajectory

    # Sample-kind tag; this model defaults it to "preference".
    sample_type: str = "preference"
    # Optional tag naming the data-generation strategy.
    data_gen_strategy: Union[str, None] = None
    # Resampling attempt count; defaults to 1.
    resample_attempts: int = 1


# Type alias covering every sample model defined in this module. Member order is
# left as-is: pydantic may use union member order to break ties during validation.
SampleType = Union[PreferenceSample, ProgressSample]