| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| r"""Main file for PCT Scenic. |
| |
| """ |
|
|
| from absl import flags |
| from absl import logging |
| from clu import metric_writers |
| import jax |
| import jax.numpy as jnp |
| import ml_collections |
| from scenic import app |
| from scenic.projects.pointcloud import segmentation_model |
| from scenic.projects.pointcloud import segmentation_trainer |
| from scenic.projects.pointcloud import shapenet_dataset |
|
|
| FLAGS = flags.FLAGS |
|
|
|
|
| def main(rng: jnp.ndarray, config: ml_collections.ConfigDict, workdir: str, |
| writer: metric_writers.MetricWriter): |
| """Main function for Scenic.""" |
|
|
| model_cls = segmentation_model.PointCloudTransformerSegmentationModel |
| data_rng, rng = jax.random.split(rng) |
|
|
| |
| device_count = jax.device_count() |
| logging.info('device_count: %d', device_count) |
| logging.info('num_hosts : %d', jax.process_count()) |
| logging.info('host_id : %d', jax.process_index()) |
|
|
| logging.info('rng: %s', rng) |
| batch_size = config.batch_size |
| if batch_size % device_count > 0: |
| raise ValueError(f'Batch size ({batch_size}) must be divisible by the ' |
| f'number of devices ({device_count})') |
|
|
| eval_batch_size = config.get('eval_batch_size', batch_size) |
| if eval_batch_size % device_count > 0: |
| raise ValueError(f'Eval batch size ({eval_batch_size}) must be divisible ' |
| f'by the number of devices ({device_count})') |
|
|
| local_batch_size = batch_size // jax.process_count() |
| eval_local_batch_size = eval_batch_size // jax.process_count() |
| device_batch_size = batch_size // device_count |
| logging.info('local_batch_size : %d', local_batch_size) |
| logging.info('device_batch_size : %d', device_batch_size) |
|
|
| dataset = shapenet_dataset.get_dataset( |
| batch_size=local_batch_size, |
| eval_batch_size=eval_local_batch_size, |
| num_shards=jax.local_device_count(), |
| dtype_str=config.data_dtype_str, |
| shuffle_seed=0, |
| rng=data_rng, |
| prefetch_buffer_size=2, |
| dataset_configs=None, |
| dataset_service_address=FLAGS.dataset_service_address, |
| seq_length=config.max_seq_len) |
|
|
| segmentation_trainer.train( |
| rng=rng, |
| config=config, |
| model_cls=model_cls, |
| dataset=dataset, |
| workdir=workdir, |
| writer=writer) |
|
|
|
|
| if __name__ == '__main__': |
| app.run(main=main) |
|
|