| import time |
| import numpy as np |
| import xml.etree.ElementTree as ET |
|
|
| from jax import jit, vmap |
| import jax.numpy as jnp |
|
|
|
|
def rot_x(o):
    """Return the 3x3 rotation matrix about the x axis for angle ``o``.

    The matrix is ``[[1, 0, 0], [0, cos, sin], [0, -sin, cos]]``, i.e. the
    transpose of the usual counter-clockwise rotation matrix.

    Parameters:
        o: Rotation angle in radians (passed straight to ``jnp.cos``/``jnp.sin``).

    Returns:
        jnp.ndarray: The 3x3 matrix.
    """
    cos_o = jnp.cos(o)
    sin_o = jnp.sin(o)
    return jnp.array([
        [1, 0, 0],
        [0, cos_o, sin_o],
        [0, -sin_o, cos_o],
    ])
|
|
|
|
def rot_y(p):
    """Return the 3x3 rotation matrix about the y axis for angle ``p``.

    The matrix is ``[[cos, 0, -sin], [0, 1, 0], [sin, 0, cos]]``, i.e. the
    transpose of the usual counter-clockwise rotation matrix.

    Parameters:
        p: Rotation angle in radians (passed straight to ``jnp.cos``/``jnp.sin``).

    Returns:
        jnp.ndarray: The 3x3 matrix.
    """
    cos_p = jnp.cos(p)
    sin_p = jnp.sin(p)
    return jnp.array([
        [cos_p, 0, -sin_p],
        [0, 1, 0],
        [sin_p, 0, cos_p],
    ])
|
|
|
|
def rot_z(k):
    """Return the 3x3 rotation matrix about the z axis for angle ``k``.

    The matrix is ``[[cos, sin, 0], [-sin, cos, 0], [0, 0, 1]]``, i.e. the
    transpose of the usual counter-clockwise rotation matrix.

    Parameters:
        k: Rotation angle in radians (passed straight to ``jnp.cos``/``jnp.sin``).

    Returns:
        jnp.ndarray: The 3x3 matrix.
    """
    cos_k = jnp.cos(k)
    sin_k = jnp.sin(k)
    return jnp.array([
        [cos_k, sin_k, 0],
        [-sin_k, cos_k, 0],
        [0, 0, 1],
    ])
|
|
|
|
def rot_zyx(o, p, k):
    """Compose the three axis rotations as ``rot_z(k) @ rot_y(p) @ rot_x(o)``.

    Parameters:
        o: Angle handed to ``rot_x``.
        p: Angle handed to ``rot_y``.
        k: Angle handed to ``rot_z``.

    Returns:
        jnp.ndarray: The 3x3 product matrix.
    """
    z_then_y = rot_z(k) @ rot_y(p)
    return z_then_y @ rot_x(o)
|
|
|
|
def rms_dict(_s):
    """Rotate the vector from the station (Xs, Ys, Zs) to the point (X, Y, Z).

    Parameters:
        _s: Mapping providing ``omega``, ``phi``, ``kappa`` (rotation angles),
            ``X``, ``Y``, ``Z`` and ``Xs``, ``Ys``, ``Zs``.
            NOTE(review): values are presumably scalars per call (batched via
            ``vmap`` by the caller) - confirm.

    Returns:
        jnp.ndarray: ``rot_zyx(omega, phi, kappa) @ ([X, Y, Z] - [Xs, Ys, Zs])``.
    """
    rotation = rot_zyx(_s['omega'], _s['phi'], _s['kappa'])
    point = jnp.array([_s[key] for key in ('X', 'Y', 'Z')])
    station = jnp.array([_s[key] for key in ('Xs', 'Ys', 'Zs')])
    return rotation @ (point - station)
|
|
|
|
def xy_frame(_s):
    """Divide the rotated point vector by its negated third component.

    Parameters:
        _s: Mapping forwarded unchanged to ``rms_dict``.

    Returns:
        tuple: ``(x, y, z)`` where ``x`` and ``y`` are the first two components
        of ``-RMS / RMS[2]`` and ``z`` is ``-RMS[2]``.
    """
    rotated = rms_dict(_s)
    third = rotated[2]
    x, y, _ = -rotated / third
    return x, y, -third
|
|
|
|
def corr_dist_agi(x, y, _s):
    """Apply lens distortion and the affine pixel mapping to normalized coordinates.

    The model follows the frame-camera equations published in the Agisoft
    manual: a radial polynomial in r^2, a tangential (decentering) term, then
    conversion to pixels with focal length, principal-point offset and the
    affinity/skew coefficients.

    The factor ``(1 + p3 * r^2 + p4 * r^4)`` scales the *whole* tangential
    term. Previously, operator precedence applied it only to the ``2 * p * x * y``
    cross term; results are unchanged whenever ``p3 == p4 == 0``.
    NOTE(review): the ``k5`` term goes beyond the k1..k4 of the published
    model; it is kept as-is and is inert when ``k5 == 0``.

    Parameters:
        x, y: Normalized image coordinates (scalars or arrays supporting
            arithmetic operators).
        _s: Mapping providing ``k1``..``k5``, ``p1``..``p4``, ``b1``, ``b2``,
            ``f``, ``cx``, ``cy``, ``width`` and ``height``.

    Returns:
        tuple: ``(fx, fy)`` pixel coordinates.
    """
    r2 = x ** 2 + y ** 2

    # Radial distortion factor, polynomial in r^2.
    radial = (
        1
        + _s['k1'] * r2
        + _s['k2'] * r2 ** 2
        + _s['k3'] * r2 ** 3
        + _s['k4'] * r2 ** 4
        + _s['k5'] * r2 ** 5
    )

    # Tangential distortion; the higher-order factor multiplies both terms.
    tangential_scale = 1 + _s['p3'] * r2 + _s['p4'] * r2 ** 2
    dtx = (_s['p1'] * (r2 + 2 * x ** 2) + 2 * _s['p2'] * x * y) * tangential_scale
    dty = (_s['p2'] * (r2 + 2 * y ** 2) + 2 * _s['p1'] * x * y) * tangential_scale

    xp = x * radial + dtx
    yp = y * radial + dty

    # Pixel coordinates measured from the image corner (hence the half-size shift).
    fx = _s['width'] * 0.5 + _s['cx'] + xp * _s['f'] + xp * _s['b1'] + yp * _s['b2']
    fy = _s['height'] * 0.5 + _s['cy'] + yp * _s['f']
    return fx, fy
|
|
|
|
def f_frame_agi(_s):
    """Project a point into pixel coordinates and flag whether it is in frame.

    The in-frame flag is evaluated on the undistorted normalized coordinates
    returned by ``xy_frame`` (before ``y`` is negated), against half extents of
    ``width / f / 2`` and ``height / f / 2``, and additionally requires ``z > 0``.

    Parameters:
        _s: Mapping forwarded to ``xy_frame`` and ``corr_dist_agi``; ``width``,
            ``height`` and ``f`` are also read here.

    Returns:
        tuple: ``(fx, fy, z, ins)`` - pixel coordinates from ``corr_dist_agi``
        (called with ``-y``), the ``z`` value from ``xy_frame`` and the boolean
        in-frame flag.
    """
    x, y, z = xy_frame(_s)

    half_w = _s['width'] / _s['f'] / 2
    half_h = _s['height'] / _s['f'] / 2
    inside_x = (x >= -half_w) & (x < half_w)
    inside_y = (y >= -half_h) & (y < half_h)
    ins = inside_x & inside_y & (z > 0)

    # The distortion model is fed the y coordinate with flipped sign.
    fx, fy = corr_dist_agi(x, -y, _s)
    return fx, fy, z, ins
|
|
|
|
def read_camera_file(filepath, offset):
    """Read per-photo exterior orientation from a whitespace-separated text file.

    Lines starting with ``#`` and lines with fewer than 16 fields are skipped.
    Of each kept line only the first seven fields are used: the photo id,
    three position values and three angles given in degrees.

    Parameters:
        filepath: Path of the text file to read.
        offset: Indexable with three numbers, subtracted from the position
            fields (indices 0, 1, 2 for Xs, Ys, Zs).

    Returns:
        dict: ``photo_id -> {"Xs", "Ys", "Zs", "omega", "phi", "kappa"}`` with
        positions shifted by ``offset`` and angles converted to radians.
    """
    cameras = {}
    with open(filepath, 'r') as handle:
        for raw_line in handle:
            fields = raw_line.strip().split()
            if raw_line.startswith("#") or len(fields) < 16:
                continue
            name = fields[0]
            xs, ys, zs = fields[1], fields[2], fields[3]
            omega, phi, kappa = fields[4], fields[5], fields[6]
            cameras[name] = {
                "Xs": float(xs) - offset[0],
                "Ys": float(ys) - offset[1],
                "Zs": float(zs) - offset[2],
                "omega": np.radians(float(omega)),
                "phi": np.radians(float(phi)),
                "kappa": np.radians(float(kappa)),
            }
    return cameras
|
|
def parse_calibration_xml(file_path):
    """Read camera calibration parameters from an XML file.

    Only direct children of the root element whose tag is one of the known
    parameter names are read; every other tag (for instance ``projection`` or
    ``date``) is ignored. Parameters missing from the file, or present as an
    empty element, keep their default of zero.

    Parameters:
        file_path: Path (or file object) accepted by ``ET.parse``.

    Returns:
        dict: ``width`` and ``height`` as ``int``; ``f``, ``cx``, ``cy``,
        ``k1``..``k5``, ``p1``..``p4``, ``b1`` and ``b2`` as ``float``.

    Raises:
        ValueError: If a known element holds text that is not a valid number.
    """
    int_fields = ('width', 'height')
    float_fields = (
        'f', 'cx', 'cy',
        'k1', 'k2', 'k3', 'k4', 'k5',
        'p1', 'p2', 'p3', 'p4',
        'b1', 'b2',
    )

    calibration_data = {name: 0 for name in int_fields}
    calibration_data.update({name: 0.0 for name in float_fields})

    root = ET.parse(file_path).getroot()
    for element in root:
        tag = element.tag
        if tag not in calibration_data:
            continue  # not a calibration parameter we track
        text = element.text
        if text is None or not text.strip():
            continue  # empty element such as <k2/>: keep the default
        calibration_data[tag] = int(text) if tag in int_fields else float(text)

    return calibration_data
|
|
|
|
def get_pixel_values(image, i, j, in_bounds, nb_classes=10):
    """
    Gather per-pixel value vectors from `image` at (row `j`, column `i`),
    leaving rows flagged as out of bounds filled with `jnp.nan`.

    Parameters:
        image (jnp.ndarray): Array indexed as `image[j, i]`; each lookup must
            yield `nb_classes` values, i.e. a trailing axis of that length.
        i (jnp.ndarray): 1D integer array of column indices, one per point.
        j (jnp.ndarray): 1D integer array of row indices, one per point.
        in_bounds (jnp.ndarray): 1D boolean mask; only entries that are True
            are looked up in `image`.
        nb_classes (int): Width of each output row.

    Returns:
        jnp.ndarray: Array of shape `(i.shape[0], nb_classes)` holding the
        gathered values where `in_bounds` is True and `jnp.nan` elsewhere.

    NOTE(review): boolean-mask indexing needs a concrete mask, so this function
    is presumably meant to run outside `jit` - confirm against the callers.
    """
    pixel_values = jnp.full((i.shape[0], nb_classes), jnp.nan)
    # Gather once; only the in-bounds points are ever used to index the image.
    gathered = image[j[in_bounds], i[in_bounds]]
    return pixel_values.at[in_bounds].set(gathered)
|
|
|
|
def compute_depth_map(i, j, z, depth_map, buffer_size, threshold=0.05):
    """Splat point depths into a depth buffer and test each point's visibility.

    Every point writes its depth into a square window of side
    ``2 * buffer_size + 1`` centred on (row ``j``, column ``i``), with window
    indices clamped to the map; each cell keeps the minimum of its previous
    value and all depths written to it. A point counts as visible when the
    resulting depth at its own pixel is within ``threshold`` of its own depth.

    Parameters:
        i: 1D integer array of column indices.
        j: 1D integer array of row indices.
        z: 1D array of depths, one per point.
        depth_map: 2D array of shape ``(height, width)`` holding current depths.
        buffer_size: Half-width of the square window, in pixels.
        threshold: Maximum absolute depth difference for a point to be visible.

    Returns:
        tuple: ``(depth_map, visibility)`` - the updated map and a boolean
        array with one entry per point.
    """
    n_rows, n_cols = depth_map.shape

    window = jnp.arange(-buffer_size, buffer_size + 1)
    grid_i, grid_j = jnp.meshgrid(window, window, indexing='ij')
    shift_i = grid_i.ravel()
    shift_j = grid_j.ravel()

    # One row per point, one column per window offset, then flattened.
    cols = jnp.clip(i[:, None] + shift_i, 0, n_cols - 1).ravel()
    rows = jnp.clip(j[:, None] + shift_j, 0, n_rows - 1).ravel()
    depths = jnp.repeat(z[:, None], shift_i.shape[0], axis=1).ravel()

    updated = depth_map.at[rows, cols].min(depths)

    visible = jnp.abs(updated[j, i] - z) <= threshold
    return updated, visible
|
|
|
|
def project_classes_into_image(i, j, classes, img_classes, buffer_size):
    """Paint each point's class into a square window of a 2D class image.

    Every point writes its class value into a window of side
    ``2 * buffer_size + 1`` centred on (row ``j``, column ``i``), with window
    indices clamped to the image.

    Parameters:
        i: 1D integer array of column indices.
        j: 1D integer array of row indices.
        classes: 1D array of class values, one per point.
        img_classes: 2D array of shape ``(height, width)`` to write into.
        buffer_size: Half-width of the square window, in pixels.

    Returns:
        jnp.ndarray: The updated class image.
    """
    n_rows, n_cols = img_classes.shape

    window = jnp.arange(-buffer_size, buffer_size + 1)
    grid_i, grid_j = jnp.meshgrid(window, window, indexing='ij')
    shift_i = grid_i.ravel()
    shift_j = grid_j.ravel()

    # One row per point, one column per window offset, then flattened.
    cols = jnp.clip(i[:, None] + shift_i, 0, n_cols - 1).ravel()
    rows = jnp.clip(j[:, None] + shift_j, 0, n_rows - 1).ravel()
    labels = jnp.repeat(classes[:, None], shift_i.shape[0], axis=1).ravel()

    return img_classes.at[rows, cols].set(labels)
|
|