owlv2 / scenic /projects /pointcloud /main_seg.py
fcxfcx's picture
Upload 2446 files
1327f34 verified
# Copyright 2025 The Scenic Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Main file for PCT Scenic.
"""
from absl import flags
from absl import logging
from clu import metric_writers
import jax
import jax.numpy as jnp
import ml_collections
from scenic import app
from scenic.projects.pointcloud import segmentation_model
from scenic.projects.pointcloud import segmentation_trainer
from scenic.projects.pointcloud import shapenet_dataset
FLAGS = flags.FLAGS
def main(rng: jnp.ndarray, config: ml_collections.ConfigDict, workdir: str,
writer: metric_writers.MetricWriter):
"""Main function for Scenic."""
model_cls = segmentation_model.PointCloudTransformerSegmentationModel
data_rng, rng = jax.random.split(rng)
# Dataset loading
device_count = jax.device_count()
logging.info('device_count: %d', device_count)
logging.info('num_hosts : %d', jax.process_count())
logging.info('host_id : %d', jax.process_index())
logging.info('rng: %s', rng)
batch_size = config.batch_size
if batch_size % device_count > 0:
raise ValueError(f'Batch size ({batch_size}) must be divisible by the '
f'number of devices ({device_count})')
eval_batch_size = config.get('eval_batch_size', batch_size)
if eval_batch_size % device_count > 0:
raise ValueError(f'Eval batch size ({eval_batch_size}) must be divisible '
f'by the number of devices ({device_count})')
local_batch_size = batch_size // jax.process_count()
eval_local_batch_size = eval_batch_size // jax.process_count()
device_batch_size = batch_size // device_count
logging.info('local_batch_size : %d', local_batch_size)
logging.info('device_batch_size : %d', device_batch_size)
dataset = shapenet_dataset.get_dataset(
batch_size=local_batch_size,
eval_batch_size=eval_local_batch_size,
num_shards=jax.local_device_count(),
dtype_str=config.data_dtype_str,
shuffle_seed=0,
rng=data_rng,
prefetch_buffer_size=2,
dataset_configs=None,
dataset_service_address=FLAGS.dataset_service_address,
seq_length=config.max_seq_len)
segmentation_trainer.train(
rng=rng,
config=config,
model_cls=model_cls,
dataset=dataset,
workdir=workdir,
writer=writer)
if __name__ == '__main__':
app.run(main=main)