# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random

import numpy as np
import pytest
import tensordict
import torch
from packaging.version import parse as parse_version
from tensordict import TensorDict

from verl import DataProto
from verl.protocol import (
    deserialize_single_tensor,
    deserialize_tensordict,
    serialize_single_tensor,
    serialize_tensordict,
    union_numpy_dict,
    union_tensor_dict,
)
from verl.utils import tensordict_utils as tu


def test_union_tensor_dict():
    """union_tensor_dict merges when shared keys alias the same tensor, and asserts otherwise."""
    obs = torch.randn(100, 10)

    data1 = TensorDict({"obs": obs, "act": torch.randn(100, 3)}, batch_size=[100])
    data2 = TensorDict({"obs": obs, "next_obs": torch.randn(100, 10), "rew": torch.randn(100)}, batch_size=[100])
    data_with_copied_obs = TensorDict(
        {"obs": obs.clone(), "next_obs": torch.randn(100, 10), "rew": torch.randn(100)}, batch_size=[100]
    )

    union_tensor_dict(data1, data2)
    # A cloned "obs" is a different object, so the union must be rejected.
    with pytest.raises(AssertionError):
        union_tensor_dict(data1, data_with_copied_obs)


def test_union_numpy_dict():
    """
    A comprehensive test suite for union_numpy_dict, covering standard use cases,
    N-dimensional arrays, object-dtype arrays, and NaN value handling.
    """
    arr_3d = np.arange(8).reshape((2, 2, 2))
    union_numpy_dict({"a": arr_3d}, {"a": arr_3d})

    arr1 = np.array([1, "hello", np.array([2, 3])], dtype=object)
    arr2 = np.array([1, "hello", np.array([2, 3])], dtype=object)
    union_numpy_dict({"a": arr1}, {"a": arr2})

    # --- Test Case 1: The original test with mixed object/float types ---
    # This test case from the original test file is preserved.
    data = np.random.random(100)
    # This array intentionally mixes float('nan') and the string 'nan'
    nan_data = [float("nan") for _ in range(99)]
    nan_data.append("nan")
    nan_data_arr = np.array(nan_data, dtype=object)

    dict1 = {"a": data, "b": nan_data_arr}
    dict2_same = {"a": data.copy(), "b": nan_data_arr.copy()}
    dict3_different = {"a": np.random.random(100)}

    union_numpy_dict(dict1, dict2_same)  # Should pass
    with pytest.raises(AssertionError):
        union_numpy_dict(dict1, dict3_different)

    # --- Test Case 2: Standard 3D arrays (fixes the core bug) ---
    arr_3d = np.arange(24, dtype=np.int32).reshape((2, 3, 4))
    dict_3d_1 = {"nd_array": arr_3d}
    dict_3d_2_same = {"nd_array": arr_3d.copy()}
    dict_3d_3_different = {"nd_array": arr_3d + 1}

    union_numpy_dict(dict_3d_1, dict_3d_2_same)  # Should pass
    with pytest.raises(AssertionError, match="`nd_array` in tensor_dict1 and tensor_dict2 are not the same object."):
        union_numpy_dict(dict_3d_1, dict_3d_3_different)

    # --- Test Case 3: Nested 2D and 4D object-dtype arrays ---
    sub_arr1 = np.array([1, 2])
    sub_arr2 = np.array([3.0, 4.0])

    # 2D object array
    arr_2d_obj = np.array([[sub_arr1, "text"], [sub_arr2, None]], dtype=object)
    arr_2d_obj_diff = np.array([[sub_arr1, "text"], [sub_arr2, "other"]], dtype=object)
    union_numpy_dict({"data": arr_2d_obj}, {"data": arr_2d_obj.copy()})  # Should pass
    with pytest.raises(AssertionError):
        union_numpy_dict({"data": arr_2d_obj}, {"data": arr_2d_obj_diff})

    # 4D object array to ensure deep recursion is robust
    arr_4d_obj = np.array([[[[sub_arr1]]], [[[sub_arr2]]]], dtype=object)
    arr_4d_obj_diff = np.array([[[[sub_arr1]]], [[[np.array([9, 9])]]]], dtype=object)
    union_numpy_dict({"data": arr_4d_obj}, {"data": arr_4d_obj.copy()})  # Should pass
    with pytest.raises(AssertionError):
        union_numpy_dict({"data": arr_4d_obj}, {"data": arr_4d_obj_diff})

    # --- Test Case 4: Explicit NaN value comparison ---
    # This verifies that our new _deep_equal logic correctly handles NaNs.
    nan_arr = np.array([1.0, np.nan, 3.0])
    dict_nan_1 = {"data": nan_arr}
    dict_nan_2_same = {"data": np.array([1.0, np.nan, 3.0])}  # A new array with same values
    dict_nan_3_different_val = {"data": np.array([1.0, 2.0, 3.0])}
    dict_nan_4_different_pos = {"data": np.array([np.nan, 1.0, 3.0])}

    # NaNs in the same position should be considered equal for merging.
    union_numpy_dict(dict_nan_1, dict_nan_2_same)  # Should pass
    with pytest.raises(AssertionError):
        union_numpy_dict(dict_nan_1, dict_nan_3_different_val)
    with pytest.raises(AssertionError):
        union_numpy_dict(dict_nan_1, dict_nan_4_different_pos)

    # --- Test Case 5: Circular reference handling ---
    # Create two separate, but structurally identical, circular references.
    # This should pass without a RecursionError.
    circ_arr_1 = np.array([None], dtype=object)
    circ_arr_1[0] = circ_arr_1
    circ_arr_2 = np.array([None], dtype=object)
    circ_arr_2[0] = circ_arr_2
    union_numpy_dict({"data": circ_arr_1}, {"data": circ_arr_2})  # Should pass

    # Create a circular reference and a non-circular one.
    # This should fail with an AssertionError because they are different.
    non_circ_arr = np.array([None], dtype=object)
    with pytest.raises(AssertionError):
        union_numpy_dict({"data": circ_arr_1}, {"data": non_circ_arr})


def test_tensor_dict_constructor():
    """from_dict infers a single batch dim; extra num_batch_dims must be rejected."""
    obs = torch.randn(100, 10)
    act = torch.randn(100, 10, 3)
    data = DataProto.from_dict(tensors={"obs": obs, "act": act})
    assert data.batch.batch_size == torch.Size([100])

    with pytest.raises(AssertionError):
        data = DataProto.from_dict(tensors={"obs": obs, "act": act}, num_batch_dims=2)

    with pytest.raises(AssertionError):
        data = DataProto.from_dict(tensors={"obs": obs, "act": act}, num_batch_dims=3)


def test_tensor_dict_make_iterator():
    """Two iterators built with the same seed must yield identical batches."""
    obs = torch.randn(100, 10)
    labels = [random.choice(["abc", "cde"]) for _ in range(100)]
    dataset = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"labels": labels})

    data_iter_1 = dataset.make_iterator(mini_batch_size=10, epochs=2, seed=1)
    data_list_1 = []
    for data in data_iter_1:
        data_list_1.append(data)
    data_iter_2 = dataset.make_iterator(mini_batch_size=10, epochs=2, seed=1)
    data_list_2 = []
    for data in data_iter_2:
        data_list_2.append(data)

    for data1, data2 in zip(data_list_1, data_list_2, strict=True):
        assert isinstance(data1, DataProto)
        assert isinstance(data2, DataProto)
        result = torch.all(torch.eq(data1.batch["obs"], data2.batch["obs"]))
        if not result.item():
            print(data1.batch["obs"])
            print(data2.batch["obs"])
            raise AssertionError()
        non_tensor_result = np.all(np.equal(data1.non_tensor_batch["labels"], data2.non_tensor_batch["labels"]))
        if not non_tensor_result.item():
            print(data1.non_tensor_batch["labels"])
            print(data2.non_tensor_batch["labels"])
            # BUG FIX: a non-tensor mismatch was previously only printed and never
            # failed the test, unlike the tensor branch above. Raise for parity.
            raise AssertionError()


def test_reorder():
    obs = torch.tensor([1, 2, 3, 4, 5, 6])
    labels = ["a", "b", "c", "d", "e", "f"]
    data = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"labels": labels}, meta_info={"name": "abdce"})

    data.reorder(torch.tensor([3, 4, 2, 0, 1, 5]))

    assert torch.all(torch.eq(data.batch["obs"], torch.tensor([4, 5, 3, 1, 2, 6])))
    assert np.all(data.non_tensor_batch["labels"] == np.array(["d", "e", "c", "a", "b", "f"]))
    assert data.meta_info == {"name": "abdce"}


def test_chunk_concat():
    obs = torch.tensor([1, 2, 3, 4, 5, 6])
    labels = ["a", "b", "c", "d", "e", "f"]
    data = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"labels": labels}, meta_info={"name": "abdce"})

    # 6 is not divisible by 5, so chunking must fail.
    with pytest.raises(AssertionError):
        data.chunk(5)

    data_split = data.chunk(2)
    assert len(data_split) == 2
    assert torch.all(torch.eq(data_split[0].batch["obs"], torch.tensor([1, 2, 3])))
    assert np.all(data_split[0].non_tensor_batch["labels"] == np.array(["a", "b", "c"]))
    assert data_split[0].meta_info == {"name": "abdce"}

    assert torch.all(torch.eq(data_split[1].batch["obs"], torch.tensor([4, 5, 6])))
    assert np.all(data_split[1].non_tensor_batch["labels"] == np.array(["d", "e", "f"]))
    assert data_split[1].meta_info == {"name": "abdce"}

    concat_data = DataProto.concat(data_split)
    assert torch.all(torch.eq(concat_data.batch["obs"], data.batch["obs"]))
    assert np.all(concat_data.non_tensor_batch["labels"] == data.non_tensor_batch["labels"])
    assert concat_data.meta_info == data.meta_info


def test_concat_metrics_from_multiple_workers():
    """Test that concat() properly merges metrics from all workers in distributed training."""
    # Simulate 3 workers each with their own metrics
    obs1 = torch.tensor([1, 2])
    obs2 = torch.tensor([3, 4])
    obs3 = torch.tensor([5, 6])

    # Each worker has different metrics (as list of dict format)
    worker1_metrics = [{"loss": 0.5, "accuracy": 0.9}]
    worker2_metrics = [{"loss": 0.6, "accuracy": 0.85}]
    worker3_metrics = [{"loss": 0.55, "accuracy": 0.88}]

    data1 = DataProto.from_dict(tensors={"obs": obs1}, meta_info={"metrics": worker1_metrics, "config_flag": True})
    data2 = DataProto.from_dict(tensors={"obs": obs2}, meta_info={"metrics": worker2_metrics, "config_flag": True})
    data3 = DataProto.from_dict(tensors={"obs": obs3}, meta_info={"metrics": worker3_metrics, "config_flag": True})

    # Concat all workers' data
    concat_data = DataProto.concat([data1, data2, data3])

    # Verify tensors are concatenated
    assert torch.all(torch.eq(concat_data.batch["obs"], torch.tensor([1, 2, 3, 4, 5, 6])))

    # Verify ALL workers' metrics are flattened to dict of lists
    expected_metrics = {"loss": [0.5, 0.6, 0.55], "accuracy": [0.9, 0.85, 0.88]}
    assert concat_data.meta_info["metrics"] == expected_metrics

    # Verify config flags are preserved from first worker
    assert concat_data.meta_info["config_flag"] is True


def test_concat_with_empty_and_non_list_meta_info():
    """Test concat() handles edge cases: empty meta_info, non-list values, and None."""
    obs1 = torch.tensor([1, 2])
    obs2 = torch.tensor([3, 4])

    # Worker 1 has metrics, worker 2 doesn't
    data1 = DataProto.from_dict(tensors={"obs": obs1}, meta_info={"metrics": [{"loss": 0.5}], "flag": True})
    data2 = DataProto.from_dict(tensors={"obs": obs2}, meta_info={"flag": True})

    concat_data = DataProto.concat([data1, data2])

    # Should flatten worker1's metrics to dict of lists
    assert concat_data.meta_info["metrics"] == {"loss": [0.5]}
    assert concat_data.meta_info["flag"] is True

    # Test with non-list meta_info value
    data3 = DataProto.from_dict(tensors={"obs": obs1}, meta_info={"single_value": 42})
    data4 = DataProto.from_dict(tensors={"obs": obs2}, meta_info={"single_value": 42})

    concat_data2 = DataProto.concat([data3, data4])
    assert concat_data2.meta_info["single_value"] == 42


def test_concat_first_worker_missing_metrics():
    """Test that metrics from other workers are preserved even when first worker has no metrics.

    This is a critical edge case - the old buggy implementation only checked data[0].meta_info
    and would lose all metrics if the first worker didn't have any.
    """
    obs1 = torch.tensor([1, 2])
    obs2 = torch.tensor([3, 4])
    obs3 = torch.tensor([5, 6])

    # First worker has NO metrics, but workers 2 and 3 do
    data1 = DataProto.from_dict(tensors={"obs": obs1}, meta_info={"config_flag": True})
    data2 = DataProto.from_dict(tensors={"obs": obs2}, meta_info={"metrics": {"loss": 0.6}, "config_flag": True})
    data3 = DataProto.from_dict(tensors={"obs": obs3}, meta_info={"metrics": {"loss": 0.55}, "config_flag": True})

    concat_data = DataProto.concat([data1, data2, data3])

    # Should flatten metrics from workers 2 and 3 into dict of lists
    expected_metrics = {"loss": [0.6, 0.55]}
    assert concat_data.meta_info["metrics"] == expected_metrics
    assert concat_data.meta_info["config_flag"] is True


def test_concat_non_list_metrics():
    """Test that concat() handles non-list metrics (single dict) correctly.

    In some cases, metrics might be a single dict instead of a list.
    The implementation should flatten them into a dict of lists.
    """
    obs1 = torch.tensor([1, 2])
    obs2 = torch.tensor([3, 4])

    # Metrics as single dict (not wrapped in list)
    data1 = DataProto.from_dict(tensors={"obs": obs1}, meta_info={"metrics": {"loss": 0.5, "accuracy": 0.9}})
    data2 = DataProto.from_dict(tensors={"obs": obs2}, meta_info={"metrics": {"loss": 0.6, "accuracy": 0.85}})

    concat_data = DataProto.concat([data1, data2])

    # Should flatten to dict of lists
    expected_metrics = {"loss": [0.5, 0.6], "accuracy": [0.9, 0.85]}
    assert concat_data.meta_info["metrics"] == expected_metrics


def test_concat_merge_different_non_metric_keys():
    """Test that concat() merges non-metric meta_info keys from all workers.

    When different workers have different non-metric keys, all keys should be preserved.
    This prevents silent data loss and aligns with the docstring stating meta_info is "merged".
    """
    obs1 = torch.tensor([1, 2])
    obs2 = torch.tensor([3, 4])
    obs3 = torch.tensor([5, 6])

    # Each worker has some unique non-metric keys
    data1 = DataProto.from_dict(tensors={"obs": obs1}, meta_info={"config": "A", "shared_key": "X"})
    data2 = DataProto.from_dict(tensors={"obs": obs2}, meta_info={"extra_key": "B", "shared_key": "X"})
    data3 = DataProto.from_dict(tensors={"obs": obs3}, meta_info={"another_key": "C", "shared_key": "X"})

    concat_data = DataProto.concat([data1, data2, data3])

    # All unique keys should be preserved
    assert concat_data.meta_info["config"] == "A"
    assert concat_data.meta_info["extra_key"] == "B"
    assert concat_data.meta_info["another_key"] == "C"
    assert concat_data.meta_info["shared_key"] == "X"


def test_concat_conflicting_non_metric_keys():
    """Test that concat() raises an assertion error when non-metric keys have conflicting values.

    This ensures data integrity by catching cases where workers have different values
    for what should be the same configuration parameter.
    """
    obs1 = torch.tensor([1, 2])
    obs2 = torch.tensor([3, 4])

    # Same key "config" but different values
    data1 = DataProto.from_dict(tensors={"obs": obs1}, meta_info={"config": "A"})
    data2 = DataProto.from_dict(tensors={"obs": obs2}, meta_info={"config": "B"})

    # Should raise an assertion error due to conflicting values
    with pytest.raises(AssertionError, match="Conflicting values for meta_info key 'config'"):
        DataProto.concat([data1, data2])


def test_pop():
    """pop() moves the requested keys into a new DataProto and removes them from the source."""
    obs = torch.randn(100, 10)
    act = torch.randn(100, 3)
    dataset = DataProto.from_dict({"obs": obs, "act": act}, meta_info={"2": 2, "1": 1})

    poped_dataset = dataset.pop(batch_keys=["obs"], meta_info_keys=["2"])

    assert poped_dataset.batch.keys() == {"obs"}
    assert poped_dataset.meta_info.keys() == {"2"}
    assert dataset.batch.keys() == {"act"}
    assert dataset.meta_info.keys() == {"1"}


def test_repeat():
    # Create a DataProto object with some batch and non-tensor data
    obs = torch.tensor([[1, 2], [3, 4], [5, 6]])
    labels = ["a", "b", "c"]
    data = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"labels": labels}, meta_info={"info": "test_info"})

    # Test interleave=True
    repeated_data_interleave = data.repeat(repeat_times=2, interleave=True)
    expected_obs_interleave = torch.tensor([[1, 2], [1, 2], [3, 4], [3, 4], [5, 6], [5, 6]])
    expected_labels_interleave = ["a", "a", "b", "b", "c", "c"]

    assert torch.all(torch.eq(repeated_data_interleave.batch["obs"], expected_obs_interleave))
    assert (repeated_data_interleave.non_tensor_batch["labels"] == expected_labels_interleave).all()
    assert repeated_data_interleave.meta_info == {"info": "test_info"}

    # Test interleave=False
    repeated_data_no_interleave = data.repeat(repeat_times=2, interleave=False)
    expected_obs_no_interleave = torch.tensor([[1, 2], [3, 4], [5, 6], [1, 2], [3, 4], [5, 6]])
    expected_labels_no_interleave = ["a", "b", "c", "a", "b", "c"]

    assert torch.all(torch.eq(repeated_data_no_interleave.batch["obs"], expected_obs_no_interleave))
    assert (repeated_data_no_interleave.non_tensor_batch["labels"] == expected_labels_no_interleave).all()
    assert repeated_data_no_interleave.meta_info == {"info": "test_info"}


def test_dataproto_pad_unpad():
    obs = torch.tensor([[1, 2], [3, 4], [5, 6]])
    labels = ["a", "b", "c"]
    data = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"labels": labels}, meta_info={"info": "test_info"})

    from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto

    # Pad 3 rows up to the next multiple of 2 (wraps around from the start).
    padded_data, pad_size = pad_dataproto_to_divisor(data, size_divisor=2)
    assert pad_size == 1

    expected_obs = torch.tensor([[1, 2], [3, 4], [5, 6], [1, 2]])
    expected_labels = ["a", "b", "c", "a"]

    assert torch.all(torch.eq(padded_data.batch["obs"], expected_obs))
    assert (padded_data.non_tensor_batch["labels"] == expected_labels).all()
    assert padded_data.meta_info == {"info": "test_info"}

    unpadd_data = unpad_dataproto(padded_data, pad_size=pad_size)
    assert torch.all(torch.eq(unpadd_data.batch["obs"], obs))
    assert (unpadd_data.non_tensor_batch["labels"] == labels).all()
    assert unpadd_data.meta_info == {"info": "test_info"}

    # Already divisible: no padding expected.
    padded_data, pad_size = pad_dataproto_to_divisor(data, size_divisor=3)
    assert pad_size == 0

    expected_obs = torch.tensor([[1, 2], [3, 4], [5, 6]])
    expected_labels = ["a", "b", "c"]

    assert torch.all(torch.eq(padded_data.batch["obs"], expected_obs))
    assert (padded_data.non_tensor_batch["labels"] == expected_labels).all()
    assert padded_data.meta_info == {"info": "test_info"}

    unpadd_data = unpad_dataproto(padded_data, pad_size=pad_size)
    assert torch.all(torch.eq(unpadd_data.batch["obs"], obs))
    assert (unpadd_data.non_tensor_batch["labels"] == labels).all()
    assert unpadd_data.meta_info == {"info": "test_info"}

    # Divisor larger than the batch: pad wraps around more than once.
    padded_data, pad_size = pad_dataproto_to_divisor(data, size_divisor=7)
    assert pad_size == 4

    expected_obs = torch.tensor([[1, 2], [3, 4], [5, 6], [1, 2], [3, 4], [5, 6], [1, 2]])
    expected_labels = ["a", "b", "c", "a", "b", "c", "a"]

    assert torch.all(torch.eq(padded_data.batch["obs"], expected_obs))
    assert (padded_data.non_tensor_batch["labels"] == expected_labels).all()
    assert padded_data.meta_info == {"info": "test_info"}

    unpadd_data = unpad_dataproto(padded_data, pad_size=pad_size)
    assert torch.all(torch.eq(unpadd_data.batch["obs"], obs))
    assert (unpadd_data.non_tensor_batch["labels"] == labels).all()
    assert unpadd_data.meta_info == {"info": "test_info"}


def test_dataproto_fold_unfold():
    from verl.protocol import DataProto, fold_batch_dim, unfold_batch_dim

    obs = torch.tensor([[1, 2], [3, 4], [5, 6]])
    labels = ["a", "b", "c"]
    data = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"labels": labels}, meta_info={"info": "test_info"})

    data1 = data.repeat(repeat_times=2, interleave=True)

    data2 = fold_batch_dim(data1, new_batch_size=3)

    torch.testing.assert_close(data2.batch["obs"], torch.tensor([[[1, 2], [1, 2]], [[3, 4], [3, 4]], [[5, 6], [5, 6]]]))
    assert (data2.non_tensor_batch["labels"] == [["a", "a"], ["b", "b"], ["c", "c"]]).all()

    data2.reorder(indices=torch.tensor([1, 2, 0]))

    data3 = unfold_batch_dim(data2, batch_dims=2)

    torch.testing.assert_close(data3.batch["obs"], torch.tensor([[3, 4], [3, 4], [5, 6], [5, 6], [1, 2], [1, 2]]))
    assert (data3.non_tensor_batch["labels"] == ["b", "b", "c", "c", "a", "a"]).all()
    assert data3.meta_info == {"info": "test_info"}


def test_torch_save_data_proto():
    """Round-trip a DataProto through save_to_disk / load_from_disk."""
    obs = torch.tensor([[1, 2], [3, 4], [5, 6]])
    labels = ["a", "b", "c"]
    data = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"labels": labels}, meta_info={"info": "test_info"})

    data.save_to_disk("test_data.pt")
    loaded_data = DataProto.load_from_disk("test_data.pt")

    assert torch.all(torch.eq(loaded_data.batch["obs"], data.batch["obs"]))
    assert (loaded_data.non_tensor_batch["labels"] == data.non_tensor_batch["labels"]).all()
    assert loaded_data.meta_info == data.meta_info

    import os

    os.remove("test_data.pt")


def test_len():
    obs = torch.tensor([[1, 2], [3, 4], [5, 6]])
    labels = np.array(["a", "b", "c"], dtype=object)
    data = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"labels": labels}, meta_info={"info": "test_info"})

    assert len(data) == 3

    # Length can fall back to the non-tensor batch when there are no tensors.
    data = DataProto(batch=None, non_tensor_batch={"labels": labels}, meta_info={"info": "test_info"})

    assert len(data) == 3

    data = DataProto(batch=None, non_tensor_batch={}, meta_info={"info": "test_info"})

    assert len(data) == 0

    data = DataProto(batch=None, non_tensor_batch=None, meta_info={"info": "test_info"})

    assert len(data) == 0


def test_dataproto_index():
    """Indexing a DataProto with int/bool indices in numpy, torch, and list form."""
    data_len = 100
    idx_num = 10

    obs = torch.randn(data_len, 10)
    labels = [random.choice(["abc", "cde"]) for _ in range(data_len)]
    data = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"labels": labels})
    labels_np = np.array(labels)

    idx_np_int = np.random.randint(0, data_len, size=(idx_num,))
    result_np_int = data[idx_np_int]
    assert result_np_int.batch.keys() == data.batch.keys()
    assert result_np_int.non_tensor_batch.keys() == data.non_tensor_batch.keys()
    assert result_np_int.batch["obs"].shape[0] == idx_num
    assert result_np_int.non_tensor_batch["labels"].shape[0] == idx_num
    assert np.array_equal(result_np_int.batch["obs"].cpu().numpy(), obs[idx_np_int].numpy())
    assert np.array_equal(result_np_int.non_tensor_batch["labels"], labels_np[idx_np_int])

    idx_torch_int = torch.randint(0, data_len, size=(idx_num,))
    result_torch_int = data[idx_torch_int]
    assert result_torch_int.batch.keys() == data.batch.keys()
    assert result_torch_int.non_tensor_batch.keys() == data.non_tensor_batch.keys()
    assert result_torch_int.batch["obs"].shape[0] == idx_num
    assert result_torch_int.non_tensor_batch["labels"].shape[0] == idx_num
    assert np.array_equal(result_torch_int.batch["obs"].cpu().numpy(), obs[idx_torch_int].cpu().numpy())
    assert np.array_equal(result_torch_int.non_tensor_batch["labels"], labels_np[idx_torch_int.cpu().numpy()])

    idx_list_int = [np.random.randint(0, data_len) for _ in range(idx_num)]
    result_list_int = data[idx_list_int]
    assert result_list_int.batch.keys() == data.batch.keys()
    assert result_list_int.non_tensor_batch.keys() == data.non_tensor_batch.keys()
    assert result_list_int.batch["obs"].shape[0] == idx_num
    assert result_list_int.non_tensor_batch["labels"].shape[0] == idx_num
    assert np.array_equal(result_list_int.batch["obs"].cpu().numpy(), obs[idx_list_int].cpu().numpy())
    assert np.array_equal(result_list_int.non_tensor_batch["labels"], labels_np[idx_list_int])

    idx_np_bool = np.random.randint(0, 2, size=(data_len,), dtype=bool)
    result_np_bool = data[idx_np_bool]
    assert result_np_bool.batch.keys() == data.batch.keys()
    assert result_np_bool.non_tensor_batch.keys() == data.non_tensor_batch.keys()
    assert result_np_bool.batch["obs"].shape[0] == idx_np_bool.sum()
    assert result_np_bool.non_tensor_batch["labels"].shape[0] == idx_np_bool.sum()
    assert np.array_equal(result_np_bool.batch["obs"].cpu().numpy(), obs[idx_np_bool].cpu().numpy())
    assert np.array_equal(result_np_bool.non_tensor_batch["labels"], labels_np[idx_np_bool])

    idx_torch_bool = torch.randint(0, 2, size=(data_len,), dtype=torch.bool)
    result_torch_bool = data[idx_torch_bool]
    assert result_torch_bool.batch.keys() == data.batch.keys()
    assert result_torch_bool.non_tensor_batch.keys() == data.non_tensor_batch.keys()
    assert result_torch_bool.batch["obs"].shape[0] == idx_torch_bool.sum().item()
    assert result_torch_bool.non_tensor_batch["labels"].shape[0] == idx_torch_bool.sum().item()
    assert np.array_equal(result_torch_bool.batch["obs"].cpu().numpy(), obs[idx_torch_bool].cpu().numpy())
    assert np.array_equal(result_torch_bool.non_tensor_batch["labels"], labels_np[idx_torch_bool])

    idx_list_bool = [np.random.randint(0, 2, dtype=bool) for _ in range(data_len)]
    result_list_bool = data[idx_list_bool]
    assert result_list_bool.batch.keys() == data.batch.keys()
    assert result_list_bool.non_tensor_batch.keys() == data.non_tensor_batch.keys()
    assert result_list_bool.batch["obs"].shape[0] == sum(idx_list_bool)
    assert result_list_bool.non_tensor_batch["labels"].shape[0] == sum(idx_list_bool)
    assert np.array_equal(result_list_bool.batch["obs"].cpu().numpy(), obs[idx_list_bool].cpu().numpy())
    assert np.array_equal(result_list_bool.non_tensor_batch["labels"], labels_np[idx_list_bool])


def test_old_vs_new_from_single_dict():
    class CustomProto(DataProto):
        """Uses the new, fixed from_single_dict."""

        pass

    class OriginProto(DataProto):
        """Mimics the *old* from_single_dict (always returns a DataProto)."""

        @classmethod
        def from_single_dict(cls, data, meta_info=None, auto_padding=False):
            tensors, non_tensors = {}, {}
            for k, v in data.items():
                if torch.is_tensor(v):
                    tensors[k] = v
                else:
                    non_tensors[k] = v
            # always calls DataProto.from_dict, ignoring `cls`
            return DataProto.from_dict(
                tensors=tensors,
                non_tensors=non_tensors,
                meta_info=meta_info,
                auto_padding=auto_padding,
            )

    sample = {"x": torch.tensor([0])}

    orig = OriginProto.from_single_dict(sample)
    # old behavior: always DataProto, not a CustomOriginProto
    assert type(orig) is DataProto
    assert type(orig) is not OriginProto

    cust = CustomProto.from_single_dict(sample)
    # new behavior: respects subclass
    assert type(cust) is CustomProto


def test_dataproto_no_batch():
    """select/pop must work when the DataProto holds only non-tensor data."""
    labels = ["a", "b", "c"]
    data = DataProto.from_dict(non_tensors={"labels": labels}, meta_info={"info": "test_info"})
    selected = data.select(non_tensor_batch_keys=["labels"])
    assert (selected.non_tensor_batch["labels"] == labels).all()
    pop_data = data.pop(non_tensor_batch_keys=["labels"])
    assert (pop_data.non_tensor_batch["labels"] == labels).all()
    assert data.non_tensor_batch == {}


def test_sample_level_repeat():
    # Create a DataProto object with some batch and non-tensor data
    obs = torch.tensor([[1, 2], [3, 4], [5, 6]])
    labels = ["a", "b", "c"]
    data = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"labels": labels}, meta_info={"info": "test_info"})

    # list
    repeated_data_interleave = data.sample_level_repeat(repeat_times=[3, 1, 2])
    expected_obs_interleave = torch.tensor([[1, 2], [1, 2], [1, 2], [3, 4], [5, 6], [5, 6]])
    expected_labels_interleave = ["a", "a", "a", "b", "c", "c"]

    assert torch.all(torch.eq(repeated_data_interleave.batch["obs"], expected_obs_interleave))
    assert (repeated_data_interleave.non_tensor_batch["labels"] == expected_labels_interleave).all()
    assert repeated_data_interleave.meta_info == {"info": "test_info"}

    # torch.tensor
    repeated_data_no_interleave = data.sample_level_repeat(repeat_times=torch.tensor([1, 2, 3]))
    expected_obs_no_interleave = torch.tensor([[1, 2], [3, 4], [3, 4], [5, 6], [5, 6], [5, 6]])
    expected_labels_no_interleave = ["a", "b", "b", "c", "c", "c"]

    assert torch.all(torch.eq(repeated_data_no_interleave.batch["obs"], expected_obs_no_interleave))
    assert (repeated_data_no_interleave.non_tensor_batch["labels"] == expected_labels_no_interleave).all()
    assert repeated_data_no_interleave.meta_info == {"info": "test_info"}


def test_dataproto_unfold_column_chunks():
    obs1 = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])
    obs2 = torch.tensor([[1, 2], [5, 6], [9, 10]])
    labels = ["a", "b", "c"]
    data = DataProto.from_dict(
        tensors={"obs1": obs1, "obs2": obs2}, non_tensors={"labels": labels}, meta_info={"name": "abc"}
    )
    ret = data.unfold_column_chunks(2, split_keys=["obs1"])
    expect_obs1 = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]])
    expect_obs2 = torch.tensor([[1, 2], [1, 2], [5, 6], [5, 6], [9, 10], [9, 10]])
    expect_labels = ["a", "a", "b", "b", "c", "c"]
    assert torch.all(torch.eq(ret.batch["obs1"], expect_obs1))
    assert torch.all(torch.eq(ret.batch["obs2"], expect_obs2))
    assert (ret.non_tensor_batch["labels"] == expect_labels).all()
    assert ret.meta_info == {"name": "abc"}

    obs1 = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])
    obs2 = torch.tensor([[1, 2], [5, 6], [9, 10]])
    labels = [["a1", "a2"], ["b1", "b2"], ["c1", "c2"]]
    data = DataProto.from_dict(
        tensors={"obs1": obs1, "obs2": obs2}, non_tensors={"labels": labels}, meta_info={"name": "abc"}
    )
    ret = data.unfold_column_chunks(2, split_keys=["obs1", "labels"])
    expect_obs1 = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]])
    expect_obs2 = torch.tensor([[1, 2], [1, 2], [5, 6], [5, 6], [9, 10], [9, 10]])
    expect_labels = [["a1"], ["a2"], ["b1"], ["b2"], ["c1"], ["c2"]]
    assert torch.all(torch.eq(ret.batch["obs1"], expect_obs1))
    assert torch.all(torch.eq(ret.batch["obs2"], expect_obs2))
    assert (ret.non_tensor_batch["labels"] == expect_labels).all()
    assert ret.meta_info == {"name": "abc"}

    obs1 = torch.tensor(
        [[[1, 1], [2, 2], [3, 3], [4, 4]], [[5, 5], [6, 6], [7, 7], [8, 8]], [[9, 9], [10, 10], [11, 11], [12, 12]]]
    )
    obs2 = torch.tensor([[[1, 1], [2, 2]], [[5, 5], [6, 6]], [[9, 9], [10, 10]]])
    labels = ["a", "b", "c"]
    data = DataProto.from_dict(
        tensors={"obs1": obs1, "obs2": obs2}, non_tensors={"labels": labels}, meta_info={"name": "abc"}
    )
    ret = data.unfold_column_chunks(2, split_keys=["obs1"])
    expect_obs1 = torch.tensor(
        [
            [[1, 1], [2, 2]],
            [[3, 3], [4, 4]],
            [[5, 5], [6, 6]],
            [[7, 7], [8, 8]],
            [[9, 9], [10, 10]],
            [[11, 11], [12, 12]],
        ]
    )
    expect_obs2 = torch.tensor(
        [[[1, 1], [2, 2]], [[1, 1], [2, 2]], [[5, 5], [6, 6]], [[5, 5], [6, 6]], [[9, 9], [10, 10]], [[9, 9], [10, 10]]]
    )
    expect_labels = ["a", "a", "b", "b", "c", "c"]
    assert torch.all(torch.eq(ret.batch["obs1"], expect_obs1))
    assert torch.all(torch.eq(ret.batch["obs2"], expect_obs2))
    assert (ret.non_tensor_batch["labels"] == expect_labels).all()
    assert ret.meta_info == {"name": "abc"}


def test_dataproto_chunk_after_index():
    data_len = 4
    obs = torch.randn(data_len, 4)
    labels = [f"label_{i}" for i in range(data_len)]
    data = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"labels": labels}, meta_info={"name": "abc"})

    # Test with boolean numpy array
    bool_mask = np.array([True, False, True, False])
    selected = data[bool_mask]
    assert isinstance(selected.batch.batch_size, torch.Size)
    assert all(isinstance(d, int) for d in selected.batch.batch_size)  # int or List[int]

    # Test with integer numpy array
    int_mask = np.array([0, 2])
    selected = data[int_mask]
    assert isinstance(selected.batch.batch_size, torch.Size)
    assert all(isinstance(d, int) for d in selected.batch.batch_size)

    # Test with boolean list
    list_mask = [True, False, True, False]
    selected = data[list_mask]
    assert isinstance(selected.batch.batch_size, torch.Size)
    assert all(isinstance(d, int) for d in selected.batch.batch_size)

    # Test with list
    list_mask = [0, 2]
    selected = data[list_mask]
    assert isinstance(selected.batch.batch_size, torch.Size)
    assert all(isinstance(d, int) for d in selected.batch.batch_size)

    # Test with torch tensor (bool)
    torch_bool_mask = torch.tensor([True, False, True, False])
    selected = data[torch_bool_mask]
    assert isinstance(selected.batch.batch_size, torch.Size)
    assert all(isinstance(d, int) for d in selected.batch.batch_size)

    # Test with torch tensor (int)
    torch_int_mask = torch.tensor([0, 2])
    selected = data[torch_int_mask]
    assert isinstance(selected.batch.batch_size, torch.Size)
    assert all(isinstance(d, int) for d in selected.batch.batch_size)


@pytest.mark.skipif(
    parse_version(tensordict.__version__) < parse_version("0.10"), reason="requires at least tensordict 0.10"
)
def test_to_tensordict():
    obs = torch.tensor([1, 2, 3, 4, 5, 6])
    labels = ["a", "b", "c", "d", "e", "f"]
    data = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"labels": labels}, meta_info={"name": "abdce"})
    output = data.to_tensordict()
    assert torch.all(torch.eq(output["obs"], obs)).item()
    assert output["labels"] == labels
    assert output["name"] == "abdce"


@pytest.mark.skipif(
    parse_version(tensordict.__version__) < parse_version("0.10"), reason="requires at least tensordict 0.10"
)
def test_from_tensordict():
    tensor_dict = {
        "obs": torch.tensor([1, 2, 3, 4, 5, 6]),
        "labels": ["a", "b", "c", "d", "e", "f"],
    }
    non_tensor_dict = {"name": "abdce"}
    # NOTE: renamed local from `tensordict` to `td` so it no longer shadows the
    # imported `tensordict` module used by the skipif decorators above.
    td = tu.get_tensordict(tensor_dict, non_tensor_dict)
    data = DataProto.from_tensordict(td)
    assert data.non_tensor_batch["labels"].tolist() == tensor_dict["labels"]
    assert torch.all(torch.eq(data.batch["obs"], tensor_dict["obs"])).item()
    assert data.meta_info["name"] == "abdce"


@pytest.mark.skipif(
    parse_version(tensordict.__version__) < parse_version("0.10"), reason="requires at least tensordict 0.10"
)
def test_to_tensordict_with_nested_lists():
    """Test converting DataProto with nested lists to TensorDict (lists of lists)."""
    obs = torch.tensor([1, 2, 3])
    # Simulate turn_scores or tool_rewards: array of lists with varying lengths
    turn_scores = [[], [0.5, 0.8], [0.9]]

    data = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"turn_scores": turn_scores})

    # This should not raise an error
    tensordict_output = data.to_tensordict()

    # Verify the data is preserved
    assert torch.all(torch.eq(tensordict_output["obs"], obs)).item()

    # Verify nested structure is accessible (TensorDict wraps NonTensorStack as LinkedList)
    retrieved_scores = tensordict_output["turn_scores"]
    assert len(retrieved_scores) == len(turn_scores)

    # Verify content matches
    assert list(retrieved_scores[0]) == []
    assert list(retrieved_scores[1]) == [0.5, 0.8]
    assert list(retrieved_scores[2]) == [0.9]


@pytest.mark.skipif(
    parse_version(tensordict.__version__) < parse_version("0.10"), reason="requires at least tensordict 0.10"
)
def test_to_tensordict_with_nested_dicts():
    """Test converting DataProto with lists of dicts to TensorDict."""
    obs = torch.tensor([1, 2, 3])
    # Simulate reward_extra_info: array of dicts
    reward_extra_info = [{"acc": 1.0}, {"acc": 0.0}, {"acc": 1.0}]

    data = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"reward_extra_info": reward_extra_info})

    # This should not raise an error - this was the original bug
    tensordict_output = data.to_tensordict()

    # Verify the data is preserved
    assert torch.all(torch.eq(tensordict_output["obs"], obs)).item()

    # Verify nested dicts are accessible
    retrieved_info = tensordict_output["reward_extra_info"]
    assert len(retrieved_info) == len(reward_extra_info)

    # Verify content matches
    for i, expected_dict in enumerate(reward_extra_info):
        assert dict(retrieved_info[i]) == expected_dict


@pytest.mark.skipif(
    parse_version(tensordict.__version__) < parse_version("0.10"), reason="requires at least tensordict 0.10"
)
def test_to_tensordict_with_complex_nested_structures():
    """Test converting DataProto with complex nested structures (lists of lists of dicts)."""
    obs = torch.tensor([1, 2, 3])
    # Simulate raw_prompt: array of lists containing dicts
    raw_prompt = [
        [{"content": "Question 1", "role": "user"}],
        [{"content": "Question 2", "role": "user"}, {"content": "Answer 2", "role": "assistant"}],
        [{"content": "Question 3", "role": "user"}],
    ]

    data = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"raw_prompt": raw_prompt})

    # This should not raise an error
    tensordict_output = data.to_tensordict()

    # Verify the data is preserved
    assert torch.all(torch.eq(tensordict_output["obs"], obs)).item()

    # Verify complex nested structure is accessible
    retrieved_prompt = tensordict_output["raw_prompt"]
    assert len(retrieved_prompt) == len(raw_prompt)

    # Spot check: verify first prompt has correct structure
    assert len(retrieved_prompt[0]) == 1
    assert dict(retrieved_prompt[0][0]) == {"content": "Question 1", "role": "user"}


@pytest.mark.skipif(
    parse_version(tensordict.__version__) < parse_version("0.10"), reason="requires at least tensordict 0.10"
)
def test_to_tensordict_and_back_with_nested_data():
    """Test round-trip conversion: DataProto → TensorDict → DataProto with nested structures."""
    obs = torch.tensor([1, 2, 3, 4])
    labels = ["a", "b", "c", "d"]

    # Multiple types of nested structures
    turn_scores = [[], [0.5], [0.8, 0.9], [0.7]]
    reward_extra_info = [
        {"acc": 1.0, "loss": 0.1},
        {"acc": 0.5, "loss": 0.3},
        {"acc": 1.0, "loss": 0.05},
        {"acc": 0.0, "loss": 0.9},
    ]
    raw_prompt = [
        [{"content": "Q1", "role": "user"}],
        [{"content": "Q2", "role": "user"}],
        [{"content": "Q3", "role": "user"}, {"content": "A3", "role": "assistant"}],
        [{"content": "Q4", "role": "user"}],
    ]

    # Create original DataProto
    original_data = DataProto.from_dict(
        tensors={"obs": obs},
        non_tensors={
            "labels": labels,
            "turn_scores": turn_scores,
            "reward_extra_info": reward_extra_info,
            "raw_prompt": raw_prompt,
        },
        meta_info={"experiment": "test_nested"},
    )

    # Convert to TensorDict
    tensordict_output = original_data.to_tensordict()

    # Convert back to DataProto
    reconstructed_data = DataProto.from_tensordict(tensordict_output)

    # Verify tensors are preserved
    assert torch.all(torch.eq(reconstructed_data.batch["obs"], obs)).item()

    # Verify non-tensor data is preserved
    assert reconstructed_data.non_tensor_batch["labels"].tolist() == labels

    # Verify nested structures are preserved
    assert len(reconstructed_data.non_tensor_batch["turn_scores"]) == len(turn_scores)
    for orig, recon in zip(turn_scores, reconstructed_data.non_tensor_batch["turn_scores"], strict=True):
        assert list(orig) == list(recon)

    assert len(reconstructed_data.non_tensor_batch["reward_extra_info"]) == len(reward_extra_info)
    for orig, recon in zip(reward_extra_info, reconstructed_data.non_tensor_batch["reward_extra_info"], strict=True):
        assert orig == recon

    # NOTE(review): the source chunk is truncated after this point ("assert ");
    # the following checks complete the round-trip pattern established above
    # (raw_prompt and meta_info preservation) — confirm against the full file.
    assert len(reconstructed_data.non_tensor_batch["raw_prompt"]) == len(raw_prompt)
    for orig, recon in zip(raw_prompt, reconstructed_data.non_tensor_batch["raw_prompt"], strict=True):
        assert list(orig) == list(recon)

    assert reconstructed_data.meta_info["experiment"] == "test_nested"
len(reconstructed_data.non_tensor_batch["raw_prompt"]) == len(raw_prompt) for orig, recon in zip(raw_prompt, reconstructed_data.non_tensor_batch["raw_prompt"], strict=True): assert orig == list(recon) # Verify meta_info is preserved assert reconstructed_data.meta_info["experiment"] == "test_nested" @pytest.mark.skipif( parse_version(tensordict.__version__) < parse_version("0.10"), reason="requires at least tensordict 0.10" ) def test_to_tensordict_agent_loop_scenario(): """Test the exact scenario from agent loop: DataProto with tool rewards, acc, etc. This test reproduces the exact error from the agent loop where nested structures (lists of lists, lists of dicts) failed to convert to TensorDict. """ # Simulate real agent loop data structure prompts = torch.tensor([[1, 2, 3], [4, 5, 6]]) responses = torch.tensor([[7, 8], [9, 10]]) # Non-tensor data with nested structures from agent loop data_source = ["lighteval/MATH", "lighteval/MATH"] uid = ["uuid-1", "uuid-2"] num_turns = np.array([2, 4], dtype=np.int32) acc = np.array([1.0, 0.0]) turn_scores = [[], [0.5, 0.8]] # Lists of varying lengths reward_extra_info = [{"acc": 1.0}, {"acc": 0.0}] # List of dicts raw_prompt = [ [{"content": "Compute 4 @ 2", "role": "user"}], [{"content": "Compute 8 @ 7", "role": "user"}], ] tool_rewards = [[0.0], []] # List of lists data = DataProto.from_dict( tensors={"prompts": prompts, "responses": responses}, non_tensors={ "data_source": data_source, "uid": uid, "num_turns": num_turns, "acc": acc, "turn_scores": turn_scores, "reward_extra_info": reward_extra_info, "raw_prompt": raw_prompt, "tool_rewards": tool_rewards, }, meta_info={"global_steps": 42}, ) # THE KEY TEST: This should not raise ValueError about TensorDict conversion tensordict_output = data.to_tensordict() # Verify tensors are accessible assert torch.all(torch.eq(tensordict_output["prompts"], prompts)).item() assert torch.all(torch.eq(tensordict_output["responses"], responses)).item() # Verify all nested structures are 
accessible (content check, not type check) assert len(tensordict_output["turn_scores"]) == 2 assert list(tensordict_output["turn_scores"][0]) == [] assert list(tensordict_output["turn_scores"][1]) == [0.5, 0.8] assert len(tensordict_output["reward_extra_info"]) == 2 assert dict(tensordict_output["reward_extra_info"][0]) == {"acc": 1.0} assert len(tensordict_output["raw_prompt"]) == 2 assert dict(tensordict_output["raw_prompt"][0][0]) == {"content": "Compute 4 @ 2", "role": "user"} assert len(tensordict_output["tool_rewards"]) == 2 assert list(tensordict_output["tool_rewards"][0]) == [0.0] assert list(tensordict_output["tool_rewards"][1]) == [] # Verify round-trip conversion works perfectly reconstructed = DataProto.from_tensordict(tensordict_output) assert len(reconstructed) == 2 assert reconstructed.meta_info["global_steps"] == 42 assert torch.all(torch.eq(reconstructed.batch["prompts"], prompts)).item() def test_serialize_deserialize_single_tensor(): """Test serialization and deserialization of a single tensor""" # Create test tensor original_tensor = torch.randn(3, 4, 5) # Serialize dtype, shape, data = serialize_single_tensor(original_tensor) # Deserialize reconstructed_tensor = deserialize_single_tensor((dtype, shape, data)) # Verify results assert torch.allclose(original_tensor, reconstructed_tensor) assert original_tensor.shape == reconstructed_tensor.shape assert original_tensor.dtype == reconstructed_tensor.dtype def test_serialize_deserialize_tensordict_regular_tensors(): """Test serialization and deserialization of TensorDict with regular tensors""" # Create test data batch_size = (5, 3) tensor1 = torch.randn(*batch_size, 4) tensor2 = torch.randint(0, 10, (*batch_size, 2)) # Create TensorDict original_tensordict = TensorDict({"tensor1": tensor1, "tensor2": tensor2}, batch_size=batch_size) # Serialize batch_size_serialized, device, encoded_items = serialize_tensordict(original_tensordict) # Deserialize reconstructed_tensordict = 
deserialize_tensordict((batch_size_serialized, device, encoded_items)) # Verify results assert original_tensordict.batch_size == reconstructed_tensordict.batch_size assert set(original_tensordict.keys()) == set(reconstructed_tensordict.keys()) for key in original_tensordict.keys(): original_tensor = original_tensordict[key] reconstructed_tensor = reconstructed_tensordict[key] assert torch.allclose(original_tensor, reconstructed_tensor) assert original_tensor.shape == reconstructed_tensor.shape assert original_tensor.dtype == reconstructed_tensor.dtype def test_serialize_deserialize_tensordict_nested_tensors(): """Test serialization and deserialization of TensorDict with nested tensors""" # Create nested tensor tensor_list = [torch.randn(2, 3), torch.randn(3, 4), torch.randn(1, 5)] nested_tensor = torch.nested.as_nested_tensor(tensor_list) # Create regular tensor for comparison regular_tensor = torch.randn(3, 4, 5) # Create TensorDict original_tensordict = TensorDict({"nested": nested_tensor, "regular": regular_tensor}, batch_size=(3,)) # Serialize batch_size_serialized, device, encoded_items = serialize_tensordict(original_tensordict) # Deserialize reconstructed_tensordict = deserialize_tensordict((batch_size_serialized, device, encoded_items)) # Verify results assert original_tensordict.batch_size == reconstructed_tensordict.batch_size assert set(original_tensordict.keys()) == set(reconstructed_tensordict.keys()) # Verify regular tensor original_regular = original_tensordict["regular"] reconstructed_regular = reconstructed_tensordict["regular"] assert torch.allclose(original_regular, reconstructed_regular) assert original_regular.shape == reconstructed_regular.shape assert original_regular.dtype == reconstructed_regular.dtype # Verify nested tensor original_nested = original_tensordict["nested"] reconstructed_nested = reconstructed_tensordict["nested"] # Check if it's a nested tensor assert original_nested.is_nested assert reconstructed_nested.is_nested # Check 
layout assert original_nested.layout == reconstructed_nested.layout # Check each tensor after unbinding original_unbind = original_nested.unbind() reconstructed_unbind = reconstructed_nested.unbind() assert len(original_unbind) == len(reconstructed_unbind) for orig, recon in zip(original_unbind, reconstructed_unbind, strict=False): assert torch.allclose(orig, recon) assert orig.shape == recon.shape assert orig.dtype == recon.dtype def test_serialize_deserialize_tensordict_mixed_types(): """Test serialization and deserialization of TensorDict with mixed tensor types""" # Create tensors with different data types float_tensor = torch.randn(2, 3).float() double_tensor = torch.randn(2, 3).double() int_tensor = torch.randint(0, 10, (2, 3)).int() long_tensor = torch.randint(0, 10, (2, 3)).long() bool_tensor = torch.tensor([[True, False], [False, True]]) bfloat16_tensor = torch.randn(2, 3).bfloat16() # Add fp8 tensor (if available) # Note: FP8 is not natively supported in all PyTorch versions # We'll check if it's available and conditionally include it has_fp8 = hasattr(torch, "float8_e5m2") or hasattr(torch, "float8_e4m3fn") if has_fp8: try: # Try to create an FP8 tensor (implementation may vary) # This is a placeholder - actual FP8 support might require specific hardware fp8_tensor = torch.randn(2, 3) if hasattr(torch, "float8_e5m2"): fp8_tensor = fp8_tensor.to(torch.float8_e5m2) elif hasattr(torch, "float8_e4m3fn"): fp8_tensor = fp8_tensor.to(torch.float8_e4m3fn) except Exception: has_fp8 = False # Create nested tensor tensor_list = [ torch.randn(2, 3), torch.randn(3, 4), ] nested_tensor = torch.nested.as_nested_tensor(tensor_list) # Create TensorDict with all available types tensordict_data = { "float": float_tensor, "double": double_tensor, "int": int_tensor, "long": long_tensor, "bool": bool_tensor, "bfloat16": bfloat16_tensor, "nested": nested_tensor, } # Conditionally add fp8 tensor if available if has_fp8: tensordict_data["fp8"] = fp8_tensor original_tensordict = 
TensorDict( tensordict_data, batch_size=(2,), ) # Serialize batch_size_serialized, device, encoded_items = serialize_tensordict(original_tensordict) # Deserialize reconstructed_tensordict = deserialize_tensordict((batch_size_serialized, device, encoded_items)) # Verify results assert original_tensordict.batch_size == reconstructed_tensordict.batch_size assert set(original_tensordict.keys()) == set(reconstructed_tensordict.keys()) for key in original_tensordict.keys(): original_tensor = original_tensordict[key] reconstructed_tensor = reconstructed_tensordict[key] if original_tensor.is_nested: # For nested tensors, check each tensor after unbinding original_unbind = original_tensor.unbind() reconstructed_unbind = reconstructed_tensor.unbind() assert len(original_unbind) == len(reconstructed_unbind) for orig, recon in zip(original_unbind, reconstructed_unbind, strict=False): assert torch.allclose(orig, recon, equal_nan=True) assert orig.shape == recon.shape assert orig.dtype == recon.dtype else: # For regular tensors, compare directly assert torch.all(original_tensor == reconstructed_tensor) assert original_tensor.shape == reconstructed_tensor.shape assert original_tensor.dtype == reconstructed_tensor.dtype def test_serialize_deserialize_tensordict_with_device(): """Test serialization and deserialization of TensorDict with device information""" # Create test data batch_size = (2, 3) tensor1 = torch.randn(*batch_size, 4) tensor2 = torch.randint(0, 10, (*batch_size, 2)) # Create TensorDict with device information device = "cpu" original_tensordict = TensorDict({"tensor1": tensor1, "tensor2": tensor2}, batch_size=batch_size, device=device) # Serialize batch_size_serialized, device_serialized, encoded_items = serialize_tensordict(original_tensordict) # Deserialize reconstructed_tensordict = deserialize_tensordict((batch_size_serialized, device_serialized, encoded_items)) # Verify results assert original_tensordict.batch_size == reconstructed_tensordict.batch_size assert 
str(original_tensordict.device) == str(reconstructed_tensordict.device) assert set(original_tensordict.keys()) == set(reconstructed_tensordict.keys()) for key in original_tensordict.keys(): original_tensor = original_tensordict[key] reconstructed_tensor = reconstructed_tensordict[key] assert torch.allclose(original_tensor.cpu(), reconstructed_tensor.cpu()) assert original_tensor.shape == reconstructed_tensor.shape assert original_tensor.dtype == reconstructed_tensor.dtype def test_serialize_dataproto_with_empty_tensordict(): """Tests that serializing a DataProto with an empty TensorDict does not crash. This test verifies the fix for the torch.cat error that occurs when calling consolidate() on an empty TensorDict during serialization. """ import pickle # This test requires tensordict >= 0.5.0 to trigger the code path if parse_version(tensordict.__version__) < parse_version("0.5.0"): pytest.skip("Test requires tensordict>=0.5.0") # Create a DataProto with an empty TensorDict but with a batch size empty_td = TensorDict({}, batch_size=[10]) data = DataProto(batch=empty_td) # This would crash before the fix with: # RuntimeError: torch.cat(): expected a non-empty list of Tensors try: serialized_data = pickle.dumps(data) except Exception as e: pytest.fail(f"Serializing DataProto with empty TensorDict failed with: {e}") # Verify deserialization works as expected deserialized_data = pickle.loads(serialized_data) assert len(deserialized_data.batch.keys()) == 0 assert deserialized_data.batch.batch_size == torch.Size([10])