| import os |
| import shutil |
| from pathlib import Path |
|
|
| import numpy as np |
| from PIL import Image as PILImage |
|
|
| try: |
| from trackio.media.media import TrackioMedia |
| except ImportError: |
| from media.media import TrackioMedia |
|
|
|
|
# Union of every input kind TrackioImage accepts: a filesystem path given as
# `str` or `Path`, a numpy array, or a PIL image. It is used both as the
# constructor annotation and as the runtime `isinstance` target in
# TrackioImage.__init__, and it is interpolated into that method's
# "Invalid value type" error message.
| TrackioImageSourceType = str | Path | np.ndarray | PILImage.Image |
|
|
|
|
class TrackioImage(TrackioMedia):
    """
    Image media value built from a file path, a PIL image, or a numpy array.

    In-memory sources (numpy arrays and PIL images) are tagged with the
    ``"png"`` format at construction time and are converted to RGBA when
    written out. Path sources are copied to the destination unchanged.

    Example:
        ```python
        import trackio
        import numpy as np
        from PIL import Image

        # Create an image from numpy array
        image_data = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8)
        image = trackio.Image(image_data, caption="Random image")
        trackio.log({"my_image": image})

        # Create an image from PIL Image
        pil_image = Image.new('RGB', (100, 100), color='red')
        image = trackio.Image(pil_image, caption="Red square")
        trackio.log({"red_image": image})

        # Create an image from file path
        image = trackio.Image("path/to/image.jpg", caption="Photo from file")
        trackio.log({"file_image": image})
        ```

    Args:
        value (`str`, `Path`, `numpy.ndarray`, or `PIL.Image`):
            A path to an image, a PIL Image, or a numpy array, expected to
            have shape (height, width, channels). Numpy arrays must have
            dtype `np.uint8`; any other dtype is rejected.
        caption (`str`, *optional*):
            A string caption for the image.

    Raises:
        ValueError: If `value` is not one of the supported source types, or
            if it is a numpy array whose dtype is not `np.uint8`.
    """

    TYPE = "trackio.image"

    def __init__(self, value: TrackioImageSourceType, caption: str | None = None):
        super().__init__(value, caption)
        self._format: str | None = None

        # Reject unsupported sources before looking at array-specific details.
        if not isinstance(self._value, TrackioImageSourceType):
            raise ValueError(
                f"Invalid value type, expected {TrackioImageSourceType}, got {type(self._value)}"
            )

        is_array = isinstance(self._value, np.ndarray)
        if is_array and self._value.dtype != np.uint8:
            raise ValueError(
                f"Invalid value dtype, expected np.uint8, got {self._value.dtype}"
            )

        # Only in-memory sources get an explicit format; path sources keep
        # `_format` as None.
        in_memory = is_array or isinstance(self._value, PILImage.Image)
        if self._format is None and in_memory:
            self._format = "png"

    def _as_pil(self) -> PILImage.Image | None:
        """
        Return the wrapped value as an RGBA PIL image.

        Returns:
            An RGBA `PIL.Image.Image` when the value is a numpy array or a
            PIL image; `None` for any other value (i.e. path sources).

        Raises:
            ValueError: If building or converting the PIL image fails; the
                underlying exception is chained as the cause.
        """
        source = self._value
        if not isinstance(source, (np.ndarray, PILImage.Image)):
            return None

        try:
            if isinstance(source, np.ndarray):
                pixels = np.asarray(source).astype("uint8")
                source = PILImage.fromarray(pixels)
            return source.convert("RGBA")
        except Exception as e:
            raise ValueError(f"Failed to process image data: {self._value}") from e

    def _save_media(self, file_path: Path):
        """
        Write the image to `file_path`.

        In-memory sources are saved through PIL using `self._format`; path
        sources are copied with `shutil.copy`.

        Args:
            file_path (`Path`): Destination file to write.

        Raises:
            ValueError: If the value is a path that is not an existing file,
                or if in-memory image data cannot be converted.
        """
        rendered = self._as_pil()
        if rendered:
            rendered.save(file_path, format=self._format)
            return

        if not isinstance(self._value, (str, Path)):
            return

        if not os.path.isfile(self._value):
            raise ValueError(f"File not found: {self._value}")
        shutil.copy(self._value, file_path)
|
|