| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """Implementation of data preprocessing ops for VTAB. |
| |
| All preprocessing ops should return a data processing functor. A data example |
| is represented as a dictionary of tensors, where the field "image" is reserved |
| for 3D images (height x width x channels). The functors output a dictionary with |
| the field "image" modified. Potentially, other fields can also be modified |
| or added. |
| """ |
| import numpy as np |
| from scenic.dataset_lib.big_transfer.registry import Registry |
| import tensorflow.compat.v1 as tf |
|
|
|
|
@Registry.register("preprocess_ops.dsprites_pp", "function")
def get_dsprites_pp(predicted_attribute, num_classes=None):
  """Builds the preprocessing functor for the dsprites dataset.

  Args:
    predicted_attribute: Name of the field in the example dictionary that is
      turned into the label. It must also be one of the keys of the internal
      attribute-to-class-count table ("label_shape", "label_scale",
      "label_orientation", "label_x_position", "label_y_position").
    num_classes: Desired number of output classes. None keeps the attribute's
      original class count; otherwise it must be an int in
      [2, original class count].

  Returns:
    A function mapping an example dictionary to the same dictionary with
    "image" multiplied by 255 and a new "label" field.

  Note:
    All validation happens when the returned function is called, not when
    this factory is called: an unknown `predicted_attribute` raises KeyError
    and an invalid `num_classes` raises ValueError at that point.
  """
  # Number of distinct values each supported attribute takes.
  classes_per_attribute = {
      "label_shape": 3,
      "label_scale": 6,
      "label_orientation": 40,
      "label_x_position": 32,
      "label_y_position": 32,
  }

  def _dsprites_pp(data):
    # Rescale pixel values; presumably the raw image is in {0, 1} so this
    # brings it to the [0, 255] range -- TODO confirm against the dataset.
    data["image"] = data["image"] * 255

    original_count = classes_per_attribute[predicted_attribute]
    if num_classes is None:
      target_count = original_count
    else:
      target_count = num_classes

    count_is_valid = (
        isinstance(target_count, int) and 1 < target_count <= original_count)
    if not count_is_valid:
      raise ValueError(
          "The number of classes should be None or in [2, ..., num_classes].")

    # Width (in original class ids) of one output bucket; may be fractional.
    bucket_width = float(original_count) / target_count

    # Floor-divide in float32, then cast back to the attribute's own dtype so
    # the label keeps the type of the field it was derived from.
    attribute = data[predicted_attribute]
    bucket = tf.math.floordiv(tf.cast(attribute, tf.float32), bucket_width)
    data["label"] = tf.cast(bucket, attribute.dtype)
    return data

  return _dsprites_pp
|
|
|
|
@Registry.register("preprocess_ops.clevr_pp", "function")
def get_clevr_pp(task, outkey="label"):
  """Selects the preprocessing functor for a clevr task.

  Args:
    task: Either "count_all" or "closest_object_distance".
    outkey: Name of the field in which the computed label is stored.

  Returns:
    A function that takes an example dictionary, writes the label for the
    requested task under `outkey`, and returns the same dictionary.

  Raises:
    KeyError: If `task` is not one of the supported task names.
  """

  def _count_preprocess_fn(data):
    # Label is the number of entries in objects/size shifted down by 3;
    # presumably every scene holds at least 3 objects -- TODO confirm.
    num_objects = tf.size(data["objects"]["size"])
    data[outkey] = num_objects - 3
    return data

  def _closest_object_preprocess_fn(data):
    # Smallest value in the third column of pixel_coords (presumably the
    # depth of the nearest object -- verify against the dataset spec).
    third_coord = data["objects"]["pixel_coords"][:, 2]
    nearest = tf.reduce_min(third_coord)

    # The label is the largest index whose threshold lies strictly below the
    # nearest distance.
    bin_edges = np.array([0.0, 8.0, 8.5, 9.0, 9.5, 10.0, 100.0])
    edge_is_below = (bin_edges - nearest) < 0
    data[outkey] = tf.reduce_max(tf.where(edge_is_below))
    return data

  return {
      "count_all": _count_preprocess_fn,
      "closest_object_distance": _closest_object_preprocess_fn,
  }[task]
|
|
|
|
@Registry.register("preprocess_ops.kitti_pp", "function")
def get_kitti_pp(task):
  """Selects the preprocessing functor for a kitti task.

  Args:
    task: Name of the task; only "closest_vehicle_distance" is supported.

  Returns:
    A function mapping an example dictionary to a new dictionary that holds
    only the "image" and "label" fields.

  Raises:
    KeyError: If `task` is not a supported task name.
  """

  def _closest_vehicle_distance_pp(data):
    """Bins the distance to the closest vehicle into a class label."""
    objects = data["objects"]

    # Rows whose type id is below 3 are treated as vehicles (presumably the
    # vehicle categories of the dataset -- TODO confirm the id mapping).
    vehicle_rows = tf.where(objects["type"] < 3)
    vehicle_z = tf.gather(
        params=objects["location"][:, 2], indices=vehicle_rows)

    # Append a far-away sentinel so the minimum is defined even when no
    # vehicle is present; 1000.0 exceeds every threshold below, which puts
    # that case in the last bin.
    far_sentinel = tf.constant([[1000.0]])
    closest = tf.reduce_min(tf.concat([vehicle_z, far_sentinel], axis=0))

    # The label is the largest index whose threshold lies strictly below the
    # closest distance.
    bin_edges = np.array([-100.0, 8.0, 20.0, 999.0])
    edge_is_below = (bin_edges - closest) < 0
    label = tf.reduce_max(tf.where(edge_is_below))
    return {"image": data["image"], "label": label}

  supported_tasks = {
      "closest_vehicle_distance": _closest_vehicle_distance_pp,
  }
  return supported_tasks[task]
|
|