bdck's picture
Upload depthpro_wrapper/io.py
ab4b137 verified
Raw
History Blame Contribute Delete
2.82 kB
"""
I/O helpers for images and point-cloud files.
"""
from __future__ import annotations
from pathlib import Path
from typing import Optional, Tuple, Union
import numpy as np
from PIL import Image
def load_image(path: Union[str, Path, Image.Image, np.ndarray]) -> Image.Image:
    """
    Normalise any image input to a PIL RGB image.

    Parameters
    ----------
    path : str, Path, PIL.Image, or np.ndarray
        If a string/Path, loaded from disk. If a PIL image, converted to RGB.
        If an ndarray, it is expected to be (H, W), (H, W, 1), (H, W, 3) or
        (H, W, 4):

        * floating-point arrays are assumed to be normalised to [0, 1] and are
          scaled by 255, clipped and truncated to uint8;
        * boolean arrays map False/True to 0/255;
        * other integer arrays are taken as pixel values and clipped to
          [0, 255] (they are NOT rescaled).

    Returns
    -------
    PIL.Image.Image
        RGB image ready for DepthPro.

    Raises
    ------
    TypeError
        If ``path`` is not a str, Path, PIL image or ndarray.
    """
    if isinstance(path, (str, Path)):
        # Close the file handle deterministically; convert() returns a new,
        # fully loaded image that does not depend on the open file.
        with Image.open(str(path)) as img:
            return img.convert("RGB")
    if isinstance(path, np.ndarray):
        arr = path
        if arr.ndim == 3 and arr.shape[-1] == 1:
            # PIL cannot build an image from a trailing singleton channel axis.
            arr = arr[..., 0]
        if arr.dtype == np.bool_:
            arr = arr.astype(np.uint8) * 255
        elif np.issubdtype(arr.dtype, np.floating):
            # Float images are assumed to be normalised to [0, 1].
            arr = (arr * 255).clip(0, 255).astype(np.uint8)
        elif arr.dtype != np.uint8:
            # Integer data already holds pixel values: clip, don't rescale
            # (multiplying by 255 would saturate almost every pixel).
            arr = arr.clip(0, 255).astype(np.uint8)
        return Image.fromarray(arr).convert("RGB")
    if isinstance(path, Image.Image):
        return path.convert("RGB")
    raise TypeError(f"Unsupported image type: {type(path)}")
def _as_vertex_rows(
    array,
    name: str,
    dtype=None,
    count: Optional[int] = None,
) -> np.ndarray:
    """Return *array* as a 2-D ``(N, 3)`` array, raising ValueError on bad shape/length."""
    arr = np.asarray(array, dtype=dtype)
    if arr.size == 0:
        # Treat any empty input (e.g. ``[]``) as zero vertices.
        arr = arr.reshape(0, 3)
    elif arr.ndim < 2 or arr.shape[-1] < 3:
        raise ValueError(f"{name} must have shape (..., 3), got {arr.shape}")
    else:
        # Flatten leading axes so (H, W, 3) point maps are accepted; any
        # columns beyond the third are ignored, as the per-row writer did.
        arr = arr.reshape(-1, arr.shape[-1])[:, :3]
    if count is not None and len(arr) != count:
        raise ValueError(f"{name} has {len(arr)} rows but there are {count} points")
    return arr


def save_point_cloud(
    path: Union[str, Path],
    points: np.ndarray,
    colors: Optional[np.ndarray] = None,
    normals: Optional[np.ndarray] = None,
) -> None:
    """
    Save a point cloud to an ASCII PLY file.

    Parameters
    ----------
    path : str or Path
        Output file path. Should end with ``.ply``.
    points : np.ndarray
        (N, 3) float array of 3D positions. Arrays of shape (..., 3), such as
        an (H, W, 3) point map, are flattened to (N, 3). Positions are stored
        as float32.
    colors : np.ndarray, optional
        (N, 3) RGB colours. Values are clipped to [0, 255] and cast to uint8;
        they are not rescaled, so floats in [0, 1] must be multiplied by 255
        by the caller.
    normals : np.ndarray, optional
        (N, 3) float array of normals.

    Raises
    ------
    ValueError
        If an array does not have a trailing dimension of at least 3, or if
        ``colors``/``normals`` do not have one row per point. Validation
        happens before the file is opened, so no partial file is written.
    """
    path = Path(path)
    points = _as_vertex_rows(points, "points", dtype=np.float32)
    n = len(points)
    has_colors = colors is not None
    has_normals = normals is not None
    header_lines = [
        "ply",
        "format ascii 1.0",
        f"element vertex {n}",
        "property float x",
        "property float y",
        "property float z",
    ]
    # All columns go into one float64 table written by np.savetxt. The
    # float32 -> float64 widening is exact, so "%.6f" prints the same digits
    # as formatting the stored float32 values directly.
    columns = [points.astype(np.float64)]
    fmt = ["%.6f"] * 3
    if has_normals:
        normals = _as_vertex_rows(normals, "normals", dtype=np.float64, count=n)
        header_lines += [
            "property float nx",
            "property float ny",
            "property float nz",
        ]
        columns.append(normals)
        fmt += ["%.6f"] * 3
    if has_colors:
        colors = _as_vertex_rows(colors, "colors", count=n)
        # Clip before casting so out-of-range values saturate instead of
        # wrapping around in uint8.
        colors = np.clip(colors, 0, 255).astype(np.uint8)
        header_lines += [
            "property uchar red",
            "property uchar green",
            "property uchar blue",
        ]
        columns.append(colors.astype(np.float64))
        fmt += ["%d"] * 3
    header_lines.append("end_header")
    table = np.hstack(columns)
    # newline="\n" keeps LF line endings on every platform.
    with open(path, "w", encoding="ascii", newline="\n") as f:
        f.write("\n".join(header_lines) + "\n")
        np.savetxt(f, table, fmt=fmt)