| # Copyright (c) Facebook, Inc. and its affiliates. | |
| import io | |
| import tempfile | |
| import unittest | |
| from contextlib import ExitStack | |
| import torch | |
| import torch.distributed as dist | |
| import torch.multiprocessing as mp | |
| from detectron2.utils import comm | |
| from densepose.evaluation.tensor_storage import ( | |
| SingleProcessFileTensorStorage, | |
| SingleProcessRamTensorStorage, | |
| SizeData, | |
| storage_gather, | |
| ) | |
class TestSingleProcessRamTensorStorage(unittest.TestCase):
    def test_read_write_1(self):
        """Round-trip several records through a RAM-backed single-process storage."""
        schema = {
            "tf": SizeData(dtype="float32", shape=(112, 112)),
            "ti": SizeData(dtype="int32", shape=(4, 64, 64)),
        }
        # Generate deterministic data that matches the schema.
        torch.manual_seed(23)
        data_elts = [
            {
                "tf": torch.rand((112, 112), dtype=torch.float32),
                "ti": (torch.rand(4, 64, 64) * 1000).to(dtype=torch.int32),
            }
            for _ in range(3)
        ]
        storage = SingleProcessRamTensorStorage(schema, io.BytesIO())
        # Store each element; record IDs are expected to be assigned sequentially.
        for expected_id, elt in enumerate(data_elts):
            self.assertEqual(storage.put(elt), expected_id)
        # Read everything back and verify the contents field by field.
        for idx, elt in enumerate(data_elts):
            record = storage.get(idx)
            self.assertEqual(len(record), len(schema))
            for key in schema:
                self.assertTrue(key in record)
                self.assertEqual(elt[key].shape, record[key].shape)
                self.assertEqual(elt[key].dtype, record[key].dtype)
                self.assertTrue(torch.allclose(elt[key], record[key]))
class TestSingleProcessFileTensorStorage(unittest.TestCase):
    def test_read_write_1(self):
        """Round-trip several records through a file-backed single-process storage."""
        schema = {
            "tf": SizeData(dtype="float32", shape=(112, 112)),
            "ti": SizeData(dtype="int32", shape=(4, 64, 64)),
        }
        # Generate deterministic data that matches the schema.
        torch.manual_seed(23)
        data_elts = [
            {
                "tf": torch.rand((112, 112), dtype=torch.float32),
                "ti": (torch.rand(4, 64, 64) * 1000).to(dtype=torch.int32),
            }
            for _ in range(3)
        ]
        # WARNING: opens the file several times! may not work on all platforms
        with tempfile.NamedTemporaryFile() as hFile:
            storage = SingleProcessFileTensorStorage(schema, hFile.name, "wb")
            # Store each element; record IDs are expected to be assigned sequentially.
            for expected_id, elt in enumerate(data_elts):
                self.assertEqual(storage.put(elt), expected_id)
            hFile.seek(0)
            storage = SingleProcessFileTensorStorage(schema, hFile.name, "rb")
            # Read everything back and verify the contents field by field.
            for idx, elt in enumerate(data_elts):
                record = storage.get(idx)
                self.assertEqual(len(record), len(schema))
                for key in schema:
                    self.assertTrue(key in record)
                    self.assertEqual(elt[key].shape, record[key].shape)
                    self.assertEqual(elt[key].dtype, record[key].dtype)
                    self.assertTrue(torch.allclose(elt[key], record[key]))
| def _find_free_port(): | |
| """ | |
| Copied from detectron2/engine/launch.py | |
| """ | |
| import socket | |
| sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) | |
| # Binding to port 0 will cause the OS to find an available port for us | |
| sock.bind(("", 0)) | |
| port = sock.getsockname()[1] | |
| sock.close() | |
| # NOTE: there is still a chance the port could be taken by other processes. | |
| return port | |
def launch(main_func, nprocs, args=()):
    """Spawn `nprocs` worker processes that run `main_func(*args)` after distributed init."""
    dist_url = "tcp://127.0.0.1:{}".format(_find_free_port())
    mp.spawn(
        distributed_worker,
        nprocs=nprocs,
        args=(main_func, nprocs, dist_url, args),
        daemon=False,
    )
def distributed_worker(local_rank, main_func, nprocs, dist_url, args):
    """
    Per-process entry point used by `launch` via `mp.spawn`.

    Initializes the default (gloo) process group, registers a group covering
    all ranks as detectron2's local process group, then runs `main_func(*args)`.
    """
    dist.init_process_group(
        backend="gloo", init_method=dist_url, world_size=nprocs, rank=local_rank
    )
    # Barrier: make sure every rank finished process-group init before proceeding.
    comm.synchronize()
    # The local process group must be set exactly once per process.
    assert comm._LOCAL_PROCESS_GROUP is None
    pg = dist.new_group(list(range(nprocs)))
    comm._LOCAL_PROCESS_GROUP = pg
    main_func(*args)
def ram_read_write_worker():
    """
    Worker for the multi-process RAM storage test.

    Each rank writes `rank + 1` records to its own single-process RAM storage;
    the storages are then gathered on rank 0, which verifies every record.
    """
    schema = {
        "tf": SizeData(dtype="float32", shape=(112, 112)),
        "ti": SizeData(dtype="int32", shape=(4, 64, 64)),
    }
    storage = SingleProcessRamTensorStorage(schema, io.BytesIO())
    world_size = comm.get_world_size()
    rank = comm.get_rank()
    data_elts = []
    # prepare different number of tensors in different processes;
    # record i of rank r holds the constant value (r + i * world_size)
    for i in range(rank + 1):
        data_elt = {
            "tf": torch.ones((112, 112), dtype=torch.float32) * (rank + i * world_size),
            "ti": torch.ones((4, 64, 64), dtype=torch.int32) * (rank + i * world_size),
        }
        data_elts.append(data_elt)
    # write data to the single process storage
    for i in range(rank + 1):
        record_id = storage.put(data_elts[i])
        assert record_id == i, f"Process {rank}: record ID {record_id}, expected {i}"
    comm.synchronize()
    # gather all data in process rank 0
    multi_storage = storage_gather(storage)
    if rank != 0:
        return
    # read and check data from the multiprocess storage
    for j in range(world_size):
        # rank j wrote j + 1 records (ids 0..j); the previous loop bound
        # `range(j)` skipped the last record of each rank and rank 0 entirely
        for i in range(j + 1):
            record = multi_storage.get(j, i)
            record_gt = {
                "tf": torch.ones((112, 112), dtype=torch.float32) * (j + i * world_size),
                "ti": torch.ones((4, 64, 64), dtype=torch.int32) * (j + i * world_size),
            }
            assert len(record) == len(schema), (
                f"Process {rank}: multi storage record, rank {j}, id {i}: "
                f"expected {len(schema)} fields in the record, got {len(record)}"
            )
            for field_name in schema:
                assert field_name in record, (
                    f"Process {rank}: multi storage record, rank {j}, id {i}: "
                    f"field {field_name} not in the record"
                )
                assert record_gt[field_name].shape == record[field_name].shape, (
                    f"Process {rank}: multi storage record, rank {j}, id {i}: "
                    f"field {field_name}, expected shape {record_gt[field_name].shape} "
                    f"got {record[field_name].shape}"
                )
                assert record_gt[field_name].dtype == record[field_name].dtype, (
                    f"Process {rank}: multi storage record, rank {j}, id {i}: "
                    f"field {field_name}, expected dtype {record_gt[field_name].dtype} "
                    f"got {record[field_name].dtype}"
                )
                assert torch.allclose(record_gt[field_name], record[field_name]), (
                    f"Process {rank}: multi storage record, rank {j}, id {i}: "
                    f"field {field_name}, tensors are not close enough:"
                    f"L-inf {(record_gt[field_name]-record[field_name]).abs_().max()} "
                    f"L1 {(record_gt[field_name]-record[field_name]).abs_().sum()} "
                )
def file_read_write_worker(rank_to_fpath):
    """
    Worker for the multi-process file storage test.

    Each rank writes `rank + 1` records to its own file storage (path taken
    from `rank_to_fpath[rank]`); the storages are then gathered on rank 0,
    which verifies every record.
    """
    schema = {
        "tf": SizeData(dtype="float32", shape=(112, 112)),
        "ti": SizeData(dtype="int32", shape=(4, 64, 64)),
    }
    world_size = comm.get_world_size()
    rank = comm.get_rank()
    storage = SingleProcessFileTensorStorage(schema, rank_to_fpath[rank], "wb")
    data_elts = []
    # prepare different number of tensors in different processes;
    # record i of rank r holds the constant value (r + i * world_size)
    for i in range(rank + 1):
        data_elt = {
            "tf": torch.ones((112, 112), dtype=torch.float32) * (rank + i * world_size),
            "ti": torch.ones((4, 64, 64), dtype=torch.int32) * (rank + i * world_size),
        }
        data_elts.append(data_elt)
    # write data to the single process storage
    for i in range(rank + 1):
        record_id = storage.put(data_elts[i])
        assert record_id == i, f"Process {rank}: record ID {record_id}, expected {i}"
    comm.synchronize()
    # gather all data in process rank 0
    multi_storage = storage_gather(storage)
    if rank != 0:
        return
    # read and check data from the multiprocess storage
    for j in range(world_size):
        # rank j wrote j + 1 records (ids 0..j); the previous loop bound
        # `range(j)` skipped the last record of each rank and rank 0 entirely
        for i in range(j + 1):
            record = multi_storage.get(j, i)
            record_gt = {
                "tf": torch.ones((112, 112), dtype=torch.float32) * (j + i * world_size),
                "ti": torch.ones((4, 64, 64), dtype=torch.int32) * (j + i * world_size),
            }
            assert len(record) == len(schema), (
                f"Process {rank}: multi storage record, rank {j}, id {i}: "
                f"expected {len(schema)} fields in the record, got {len(record)}"
            )
            for field_name in schema:
                assert field_name in record, (
                    f"Process {rank}: multi storage record, rank {j}, id {i}: "
                    f"field {field_name} not in the record"
                )
                assert record_gt[field_name].shape == record[field_name].shape, (
                    f"Process {rank}: multi storage record, rank {j}, id {i}: "
                    f"field {field_name}, expected shape {record_gt[field_name].shape} "
                    f"got {record[field_name].shape}"
                )
                assert record_gt[field_name].dtype == record[field_name].dtype, (
                    f"Process {rank}: multi storage record, rank {j}, id {i}: "
                    f"field {field_name}, expected dtype {record_gt[field_name].dtype} "
                    f"got {record[field_name].dtype}"
                )
                assert torch.allclose(record_gt[field_name], record[field_name]), (
                    f"Process {rank}: multi storage record, rank {j}, id {i}: "
                    f"field {field_name}, tensors are not close enough:"
                    f"L-inf {(record_gt[field_name]-record[field_name]).abs_().max()} "
                    f"L1 {(record_gt[field_name]-record[field_name]).abs_().sum()} "
                )
class TestMultiProcessRamTensorStorage(unittest.TestCase):
    def test_read_write_1(self):
        """Run the RAM-storage gather worker across 8 spawned processes."""
        launch(ram_read_write_worker, 8)
class TestMultiProcessFileTensorStorage(unittest.TestCase):
    def test_read_write_1(self):
        """Run the file-storage gather worker across 8 spawned processes."""
        # WARNING: opens the files several times! may not work on all platforms
        with ExitStack() as stack:
            rank_to_fpath = {
                rank: stack.enter_context(tempfile.NamedTemporaryFile()).name
                for rank in range(8)
            }
            launch(file_read_write_worker, 8, (rank_to_fpath,))