#!/usr/bin/env python3
"""
Pydantic models for RBM model dataset trajectory structures.
Defines the standard format for HuggingFace dataset trajectories.
"""
from typing import Any, Union, List, Dict, Tuple, Optional
import numpy as np
from pydantic import BaseModel, ConfigDict
import torch
class Trajectory(BaseModel):
    """Trajectory structure containing frames, metadata, and progress information.

    Every field is optional and defaults to ``None``, so a ``Trajectory`` can be
    built from whatever subset of information is available.
    """

    # Allows np.ndarray and torch.Tensor, which pydantic has no built-in
    # validators for, to be used as field types.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Core trajectory fields
    # Frames are either a list of strings or an array.
    # NOTE(review): the strings are presumably file paths - confirm with the loader.
    frames: Union[List[str], np.ndarray, None] = None
    frames_shape: Optional[Tuple] = None

    # If embeddings are precomputed
    embeddings_path: Optional[str] = None
    video_embeddings: Union[torch.Tensor, np.ndarray, None] = None
    text_embedding: Union[torch.Tensor, np.ndarray, None] = None

    id: Optional[str] = None
    task: Optional[str] = None
    lang_vector: Union[np.ndarray, List[float], None] = None
    data_source: Optional[str] = None
    quality_label: Optional[str] = None
    is_robot: Optional[bool] = None

    # Progress and metadata
    # List[float] for continuous progress; torch.Tensor or List[torch.Tensor]
    # for C51 discrete distributions.
    # NOTE(review): np.ndarray is not listed in this annotation - confirm
    # whether arrays are meant to be accepted here.
    target_progress: Union[List[float], List[torch.Tensor], torch.Tensor, None] = None
    # float for continuous, Tensor for C51 discrete
    partial_success: Union[float, torch.Tensor, None] = None
    # Success labels for each frame (1.0 for success, 0.0 for failure)
    success_label: Optional[List[float]] = None
    # 1.0 per frame for inference (no masking); server requires a list
    predict_last_frame_mask: Optional[List[float]] = None
    metadata: Optional[Dict[str, Any]] = None
    data_gen_strategy: Optional[str] = None
class ProgressSample(BaseModel):
    """Sample structure for progress evaluation, wrapping a single trajectory."""

    # The trajectory this sample evaluates (required).
    trajectory: Trajectory
    # Tag identifying the sample kind; defaults to "progress".
    sample_type: str = "progress"
    # Name of the data generation strategy, when one was recorded.
    data_gen_strategy: Union[str, None] = None
    # Defaults to 1.
    # NOTE(review): presumably the number of sampling attempts - confirm with the sampler.
    resample_attempts: int = 1
class PreferenceSample(BaseModel):
    """Sample structure for preference prediction.

    Holds a chosen and a rejected trajectory, where the chosen one is preferred.
    """

    # Trajectories (both required)
    chosen_trajectory: Trajectory
    rejected_trajectory: Trajectory

    # Tag identifying the sample kind; defaults to "preference".
    sample_type: str = "preference"
    # Name of the data generation strategy, when one was recorded.
    data_gen_strategy: Union[str, None] = None
    # Defaults to 1.
    # NOTE(review): presumably the number of sampling attempts - confirm with the sampler.
    resample_attempts: int = 1
# Type alias covering both sample models; use it to annotate values that may be
# either a PreferenceSample or a ProgressSample.
SampleType = Union[PreferenceSample, ProgressSample]
|