File size: 18,988 Bytes
3e462dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
679ca41
cc5bab9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3e462dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
679ca41
3e462dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a326636
 
 
 
 
3e462dd
a326636
679ca41
a326636
 
 
 
 
 
 
3e462dd
 
679ca41
a326636
 
 
 
3e462dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cc5bab9
3e462dd
 
cc5bab9
3e462dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cc5bab9
3e462dd
 
 
 
 
 
 
 
 
cc5bab9
3e462dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
#!/usr/bin/env python3
from __future__ import annotations

import torch
import io
import json
import os
from pathlib import Path
from typing import Any, Dict, List, Union, Optional, Tuple
from datetime import datetime

import aiohttp
import numpy as np
import requests
import torch

from dataset_types import PreferenceSample, ProgressSample, Trajectory


def pad_trajectory_to_max_frames_np(
    frames: np.ndarray, progress: List[float], max_frames: int, pad_from: str = "right"
) -> Tuple[np.ndarray, List[float]]:
    """Pad trajectory frames and progress up to ``max_frames`` by repeating an edge frame.

    With ``pad_from="right"`` (the default) the LAST frame/progress value is
    repeated at the end; with ``pad_from="left"`` the FIRST frame/progress value
    is repeated at the start. Inputs that already have at least ``max_frames``
    frames are returned unchanged (no truncation).

    Args:
        frames: Trajectory frames; axis 0 is the time axis.
        progress: Per-frame progress values aligned with ``frames``.
        max_frames: Target number of frames.
        pad_from: Which end to pad, ``"left"`` or ``"right"``.

    Returns:
        Tuple[np.ndarray, List[float]]: (padded_frames, padded_progress)

    Raises:
        ValueError: If ``pad_from`` is not ``"left"``/``"right"``, or if padding is
            needed but ``frames``/``progress`` are empty (there is nothing to repeat).
    """
    if pad_from not in ("left", "right"):
        raise ValueError(f"pad_from must be 'left' or 'right', got {pad_from!r}")

    current_frames = frames.shape[0]
    if current_frames >= max_frames:
        # No padding needed
        return frames, progress

    if current_frames == 0 or len(progress) == 0:
        raise ValueError("Cannot pad an empty trajectory: there is no frame to repeat")

    frames_to_pad = max_frames - current_frames
    # Accept any sequence (e.g. tuple) for progress; list concatenation needs a list.
    progress = list(progress)

    if pad_from == "left":
        # frames[:1] keeps the leading time dimension so np.repeat stacks full frames.
        pad_block = np.repeat(frames[:1], frames_to_pad, axis=0)
        padded_frames = np.concatenate([pad_block, frames], axis=0)
        padded_progress = [progress[0]] * frames_to_pad + progress
    else:
        pad_block = np.repeat(frames[-1:], frames_to_pad, axis=0)
        padded_frames = np.concatenate([frames, pad_block], axis=0)
        padded_progress = progress + [progress[-1]] * frames_to_pad

    return padded_frames, padded_progress


def linspace_subsample_frames(
    frames: np.ndarray, num_frames: int = 8, end_idx: Optional[int] = None
) -> Tuple[np.ndarray, List[int]]:
    """Uniformly subsample frames from a trajectory and return the chosen indices.

    Takes the trajectory (optionally truncated to ``frames[: end_idx + 1]``) and
    picks ``num_frames`` evenly spaced, rounded indices. When ``num_frames > 1``
    the first and last frames are always included; when ``num_frames == 1`` only
    the last frame is taken. If the (truncated) trajectory has no more than
    ``num_frames`` frames, all of them are returned unchanged.

    Args:
        frames: Full trajectory frames (N frames along axis 0).
        num_frames: Number of frames to subsample (default: 8); must be >= 1.
        end_idx: Optional inclusive end index to subsample up to; values past the
            end are clamped to ``N - 1``. Must be non-negative if given.

    Returns:
        Tuple[np.ndarray, List[int]]: (subsampled_frames, subsampled_indices)

    Raises:
        ValueError: If ``num_frames < 1`` or ``end_idx`` is negative.
    """
    if num_frames < 1:
        raise ValueError(f"num_frames must be >= 1, got {num_frames}")
    if end_idx is not None and end_idx < 0:
        raise ValueError(f"end_idx must be non-negative, got {end_idx}")

    total_frames = frames.shape[0] if hasattr(frames, "shape") else len(frames)
    if total_frames <= 0:
        return frames, []

    # Restrict to frames[0 .. end_idx] when an end index is given.
    if end_idx is not None:
        effective_total = min(end_idx, total_frames - 1) + 1
        frames_to_subsample = frames[:effective_total]
    else:
        effective_total = total_frames
        frames_to_subsample = frames

    if effective_total <= num_frames:
        # Fewer (or equal) frames than requested: return everything.
        return frames_to_subsample, list(range(effective_total))

    if num_frames == 1:
        indices = [effective_total - 1]
        return frames_to_subsample[indices], indices

    # Evenly spaced indices from 0 to effective_total-1, inclusive.
    picked = np.rint(np.linspace(0, effective_total - 1, num_frames)).astype(int)
    picked[0] = 0
    picked[-1] = effective_total - 1
    # Defensive: keep indices non-decreasing and within bounds.
    picked = np.clip(np.maximum.accumulate(picked), 0, effective_total - 1)
    indices = picked.tolist()

    return frames_to_subsample[indices], indices


def raw_dict_to_sample(
    raw_data: Union[Tuple[Dict[str, Any], Dict[str, Any]], Dict[str, Any]],
    max_frames: int = 16,
    sample_type: str = "progress",
) -> Union[ProgressSample, PreferenceSample]:
    """
    Convert raw trajectory dict(s) to a ProgressSample or PreferenceSample.

    Each trajectory dict is normalised the same way. Frames are converted from
    channels-first to channels-last when ``shape[1] == 3``. They are then uniformly
    subsampled to at most ``max_frames`` and right-padded (repeating the last frame)
    up to ``max_frames``. Optional ``video_embeddings`` get the same subsample/pad
    treatment. Numpy embeddings are converted to torch tensors.

    Args:
        raw_data: For "progress": a dict with 'frames' and 'task', plus optional
            'metadata', 'video_embeddings', 'text_embedding'. For "preference": a
            2-tuple (or 2-list) of such dicts ordered (chosen, rejected).
        max_frames: Number of frames every trajectory is resampled to (default: 16)
        sample_type: Either "progress" or "preference" (default: "progress")

    Returns:
        ProgressSample or PreferenceSample

    Raises:
        TypeError: If ``raw_data`` is the wrong container type for ``sample_type``.
        ValueError: If ``sample_type`` is unsupported, a preference input does not
            hold exactly two trajectories, frames are not 4-D, or frames are empty.
    """

    def _build_trajectory(traj_data: Dict[str, Any], num_frames: int) -> Trajectory:
        """Normalise one raw trajectory dict into a ``Trajectory``."""
        frames_array = traj_data["frames"]

        # Frames must be 4-D: (T, H, W, C) or channels-first (T, C, H, W).
        if len(frames_array.shape) != 4:
            raise ValueError(f"Expected 4D array (T, H, W, C), got shape {frames_array.shape}")

        # Reject empty input up front: subsampling/padding an empty trajectory has
        # no frame to repeat and would otherwise fail with an unrelated error.
        if frames_array.size == 0:
            raise ValueError("No frames processed for example")

        # Channels-first heuristic: (T, 3, H, W) -> (T, H, W, 3).
        # NOTE(review): ambiguous for channels-last frames whose height is 3.
        if frames_array.shape[1] == 3:
            frames_array = np.transpose(frames_array, (0, 2, 3, 1))

        frames_array, _ = linspace_subsample_frames(frames_array, num_frames)
        # Progress is irrelevant here; a dummy list satisfies the padding helper.
        dummy_progress = [0.0] * len(frames_array)
        frames_array, _ = pad_trajectory_to_max_frames_np(frames_array, dummy_progress, num_frames, pad_from="right")

        # Video embeddings are resampled with the same helpers so they stay aligned with frames.
        video_embeddings = traj_data.get("video_embeddings")
        if video_embeddings is not None:
            video_embeddings, _ = linspace_subsample_frames(video_embeddings, num_frames)
            dummy_progress_emb = [0.0] * len(video_embeddings)
            video_embeddings, _ = pad_trajectory_to_max_frames_np(
                video_embeddings, dummy_progress_emb, num_frames, pad_from="right"
            )
            if isinstance(video_embeddings, np.ndarray):
                video_embeddings = torch.tensor(video_embeddings)

        text_embedding = traj_data.get("text_embedding")
        if isinstance(text_embedding, np.ndarray):
            text_embedding = torch.tensor(text_embedding)

        return Trajectory(
            frames=frames_array,
            frames_shape=frames_array.shape,
            task=traj_data["task"],
            lang_vector=None,
            metadata=traj_data.get("metadata", None),
            video_embeddings=video_embeddings,
            text_embedding=text_embedding,
            video_shape=video_embeddings.shape if video_embeddings is not None else None,
            text_shape=text_embedding.shape if text_embedding is not None else None,
        )

    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if sample_type == "progress":
        if not isinstance(raw_data, dict):
            raise TypeError("raw_data must be a dictionary")
        trajectory = _build_trajectory(raw_data, max_frames)
        return ProgressSample(trajectory=trajectory)

    if sample_type == "preference":
        if not isinstance(raw_data, (tuple, list)):
            raise TypeError("raw_data must be a tuple")
        if len(raw_data) != 2:
            raise ValueError("raw_data must be a tuple of two dictionaries")
        chosen, rejected = (_build_trajectory(item, max_frames) for item in raw_data)
        return PreferenceSample(chosen_trajectory=chosen, rejected_trajectory=rejected)

    raise ValueError(f"Unsupported sample_type: {sample_type}")


def build_payload(
    samples: list[PreferenceSample | ProgressSample],
) -> tuple[dict[str, Any], list[dict[str, Any]]]:
    """Build a multipart payload, moving array fields into in-memory ``.npy`` files.

    For every trajectory dict ('chosen_trajectory', 'rejected_trajectory',
    'trajectory') in each sample's ``model_dump()``, the fields 'frames',
    'lang_vector', 'video_embeddings' and 'text_embedding' are serialized when
    they hold a torch tensor or numpy array. Each such field is replaced by a
    ``{"__numpy_file__": <file_key>}`` reference.

    Args:
        samples: List of samples to convert

    Returns:
        Tuple of (files, sample_data) where:
        - files: Maps file_key -> (filename, BytesIO holding the .npy bytes,
          "application/octet-stream"), suitable for a multipart upload.
        - sample_data: List of sample dictionaries with arrays replaced by file references
    """
    trajectory_keys = ("chosen_trajectory", "rejected_trajectory", "trajectory")
    numpy_fields = ("frames", "lang_vector", "video_embeddings", "text_embedding")

    files: dict[str, Any] = {}
    sample_data: list[dict[str, Any]] = []

    for sample_idx, sample in enumerate(samples):
        # Only the nested trajectory dicts are rewritten below.
        processed_sample = sample.model_dump()

        for key in trajectory_keys:
            trajectory = processed_sample.get(key)
            if not isinstance(trajectory, dict):
                continue

            for field_name in numpy_fields:
                value = trajectory.get(field_name)
                if isinstance(value, torch.Tensor):
                    # A bare .numpy() raises for tensors that require grad or live on
                    # an accelerator; detach and move to CPU first.
                    value = value.detach().cpu().numpy()
                    trajectory[field_name] = value

                if isinstance(value, np.ndarray):
                    buf = io.BytesIO()
                    np.save(buf, value)
                    buf.seek(0)
                    file_key = f"sample_{sample_idx}_{key}_{field_name}"
                    files[file_key] = (
                        f"sample_{sample_idx}_{key}_{field_name}.npy",
                        buf,
                        "application/octet-stream",
                    )
                    trajectory[field_name] = {"__numpy_file__": file_key}

        sample_data.append(processed_sample)

    return files, sample_data


def post_batch(url: str, payload: dict[str, Any], timeout_s: float = 120.0) -> dict[str, Any]:
    """Send ``payload`` as a JSON body to the server's ``/evaluate_batch`` endpoint.

    Trailing slashes on ``url`` are stripped before the endpoint path is appended.
    Non-2xx responses raise via ``raise_for_status()``; otherwise the decoded JSON
    body is returned.
    """
    endpoint = url.rstrip("/") + "/evaluate_batch"
    response = requests.post(endpoint, json=payload, timeout=timeout_s)
    response.raise_for_status()
    return response.json()


def post_batch_npy(
    url: str,
    files: dict[str, Any],
    sample_data: list[dict[str, Any]],
    timeout_s: float = 120.0,
    extra_form_data: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
    """Upload a batch as multipart form data to ``/evaluate_batch_npy``.

    Each sample dict is JSON-encoded into a form field named ``sample_<i>``.
    Extra form fields are sent verbatim when they are strings and JSON-encoded
    otherwise; an extra field with the same name as a sample field replaces it.

    Args:
        url: Server URL
        files: Multipart file entries (e.g. the ``files`` dict from build_payload)
        sample_data: List of sample dictionaries
        timeout_s: Request timeout in seconds
        extra_form_data: Optional extra form data to include (e.g., use_frame_steps)
    """
    form_fields: dict[str, Any] = {}
    for idx, sample in enumerate(sample_data):
        form_fields[f"sample_{idx}"] = json.dumps(sample)

    for name, value in (extra_form_data or {}).items():
        form_fields[name] = value if isinstance(value, str) else json.dumps(value)

    endpoint = url.rstrip("/") + "/evaluate_batch_npy"
    response = requests.post(endpoint, files=files, data=form_fields, timeout=timeout_s)
    response.raise_for_status()
    return response.json()


async def post_batch_npy_async(
    session: aiohttp.ClientSession,
    url: str,
    files: dict[str, Any],
    sample_data: list[dict[str, Any]],
    timeout_s: float = 120.0,
    extra_form_data: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
    """Async version of post_batch_npy using aiohttp.

    Args:
        session: aiohttp session used to send the request.
        url: Server URL; ``/evaluate_batch_npy`` is appended after stripping trailing slashes.
        files: Maps form-field name -> (filename, file object, content type),
            e.g. the ``files`` dict returned by build_payload.
        sample_data: List of sample dictionaries, each JSON-encoded as ``sample_<i>``.
        timeout_s: Total request timeout in seconds.
        extra_form_data: Optional extra form fields (e.g., use_frame_steps), sent
            verbatim when strings and JSON-encoded otherwise. This mirrors the
            synchronous post_batch_npy.

    Returns:
        The decoded JSON response body.
    """
    form_data = aiohttp.FormData()

    # Add files
    for key, (filename, file_obj, content_type) in files.items():
        form_data.add_field(key, file_obj, filename=filename, content_type=content_type)

    # Add sample data
    for i, sample in enumerate(sample_data):
        form_data.add_field(f"sample_{i}", json.dumps(sample))

    # Add extra form data if provided. NOTE(review): FormData appends rather than
    # overwrites, so a key clashing with a sample_<i> field is sent twice.
    if extra_form_data:
        for key, value in extra_form_data.items():
            form_data.add_field(key, value if isinstance(value, str) else json.dumps(value))

    # Ask the server to close the connection after this request.
    headers = {"Connection": "close"}
    timeout = aiohttp.ClientTimeout(total=timeout_s)
    async with session.post(
        url.rstrip("/") + "/evaluate_batch_npy", data=form_data, timeout=timeout, headers=headers
    ) as resp:
        resp.raise_for_status()
        return await resp.json()


async def parse_npy_form_data(form_data: Any) -> Tuple[Dict[str, np.ndarray], Dict[str, Any]]:
    """Parse multipart form data to extract numpy arrays and other data.

    Fields with a truthy ``filename`` attribute are treated as uploads. Uploads
    ending in ``.npy`` are read (awaited) and decoded with ``np.load``; other
    uploads are skipped. Every other field is JSON-decoded when possible and
    kept as-is otherwise.

    Security: the uploads come from the network, so arrays are loaded with
    ``allow_pickle=False``. Object arrays, which would require unpickling
    client-supplied bytes, are rejected with ``ValueError`` instead of executing
    arbitrary code.

    Args:
        form_data: FastAPI form data from request.form()

    Returns:
        Tuple of (numpy_arrays dict, other_data dict)

    Raises:
        ValueError: If an uploaded ``.npy`` file is malformed or contains pickled objects.
    """
    numpy_arrays: Dict[str, np.ndarray] = {}
    other_data: Dict[str, Any] = {}

    for key, value in form_data.items():
        filename = getattr(value, "filename", None)
        if filename:
            # File upload: only .npy files are understood; others are ignored.
            if not filename.endswith(".npy"):
                continue
            content = await value.read()
            numpy_arrays[key] = np.load(io.BytesIO(content), allow_pickle=False)
            continue

        # Plain form field: JSON-decode when possible, otherwise keep the raw value.
        try:
            other_data[key] = json.loads(value)
        except (json.JSONDecodeError, TypeError):
            other_data[key] = value

    return numpy_arrays, other_data


def reconstruct_payload_from_npy(
    numpy_arrays: Dict[str, np.ndarray],
    other_data: Dict[str, Any],
    trajectory_keys: Optional[List[str]] = None,
    convert_embeddings_to_torch: bool = False,
) -> List[Dict[str, Any]]:
    """Rebuild sample dicts by swapping ``{"__numpy_file__": key}`` references for arrays.

    The client sends data in this format:
    - Files: sample_0_chosen_trajectory_frames.npy, sample_0_trajectory_frames.npy, etc.
    - Data: sample_0, sample_1, etc. (each containing the full sample JSON with numpy file references)

    Sample fields are looked up as ``sample_0`` .. ``sample_<len(other_data)-1>``;
    missing indices are skipped. String samples are JSON-decoded. Already-decoded
    dicts are updated in place. References whose file is absent are left untouched.

    Args:
        numpy_arrays: Dictionary of numpy arrays loaded from .npy files
        other_data: Dictionary of other form data
        trajectory_keys: Trajectory keys to process (default: chosen/rejected/trajectory)
        convert_embeddings_to_torch: When True, numpy or list values of
            'video_embeddings'/'text_embedding' are turned into torch tensors

    Returns:
        List of reconstructed sample dictionaries
    """
    wanted_keys = (
        trajectory_keys
        if trajectory_keys is not None
        else ["chosen_trajectory", "rejected_trajectory", "trajectory"]
    )
    embedding_fields = ("video_embeddings", "text_embedding")

    def _resolve_trajectory(traj: Dict[str, Any]) -> None:
        # Reassigning existing keys while iterating is safe: the dict never changes size.
        for field in traj:
            current = traj[field]
            if isinstance(current, dict) and current.get("__numpy_file__"):
                ref = current["__numpy_file__"]
                if ref in numpy_arrays:
                    traj[field] = numpy_arrays[ref]
            if convert_embeddings_to_torch and field in embedding_fields:
                candidate = traj[field]
                if isinstance(candidate, (np.ndarray, list)):
                    traj[field] = torch.tensor(candidate)

    reconstructed: List[Dict[str, Any]] = []
    for idx in range(len(other_data)):
        name = f"sample_{idx}"
        if name not in other_data:
            continue
        entry = other_data[name]
        if isinstance(entry, str):
            entry = json.loads(entry)
        for section_key, section in entry.items():
            if section_key in wanted_keys and isinstance(section, dict):
                _resolve_trajectory(section)
        reconstructed.append(entry)

    return reconstructed


def find_video_files(directory: str) -> list[str]:
    """Return the sorted paths of video files directly inside ``directory``.

    Only regular files in the top level of the directory are considered
    (no recursion). The extension match is case-insensitive. A path that is
    not a directory yields an empty list.

    Args:
        directory: Path to directory containing video files

    Returns:
        Sorted list of paths (as strings) to video files
    """
    root = Path(directory)
    if not root.is_dir():
        return []

    allowed_suffixes = {".mp4", ".avi", ".mov", ".mkv", ".webm", ".flv", ".wmv", ".m4v"}
    return sorted(
        str(entry)
        for entry in root.iterdir()
        if entry.is_file() and entry.suffix.lower() in allowed_suffixes
    )


def infer_task_from_video_name(video_path: str) -> str:
    """Derive a human-readable task description from a video's filename.

    The extension is ignored. If the name contains a comma, only the part before
    the first comma is used. Otherwise every underscore-separated token equal
    (case-insensitively) to "success", "fail" or "failure" is dropped, wherever
    it appears. Underscores become spaces and the first character is uppercased.
    "Complete the task" is returned when nothing usable remains.

    Args:
        video_path: Path to video file

    Returns:
        Inferred task name
    """
    fallback = "Complete the task"
    stem = Path(video_path).stem

    if "," in stem:
        head = stem.partition(",")[0]
    else:
        outcome_words = ("success", "fail", "failure")
        kept = [token for token in stem.split("_") if token.lower() not in outcome_words]
        if not kept:
            return fallback
        head = "_".join(kept)

    task = head.replace("_", " ")
    if not task:
        return fallback
    return task[0].upper() + task[1:] if len(task) > 1 else task.upper()


def setup_output_directory(output_dir: Optional[str], video_path: Optional[str] = None) -> str:
    """Ensure an output directory exists and return its path.

    A truthy ``output_dir`` is used as given. Otherwise a per-second timestamped
    directory under ``./eval_outputs/`` is chosen. The directory (and parents)
    is created if missing. ``video_path`` is accepted but currently unused.
    """
    if not output_dir:
        stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_dir = os.path.join(".", f"eval_outputs/{stamp}")

    os.makedirs(output_dir, exist_ok=True)
    return output_dir