| import uuid |
| from pathlib import Path |
|
|
| import numpy as np |
| from PIL import Image as PILImage |
|
|
| try: |
| from trackio.file_storage import FileStorage |
| from trackio.utils import TRACKIO_DIR |
| except ImportError: |
| from file_storage import FileStorage |
| from utils import TRACKIO_DIR |
|
|
|
|
| class TrackioImage: |
| """ |
| Creates an image that can be logged with trackio. |
| |
| Demo: fake-training-images |
| """ |
|
|
| TYPE = "trackio.image" |
|
|
| def __init__( |
| self, value: str | np.ndarray | PILImage.Image, caption: str | None = None |
| ): |
| """ |
| Parameters: |
| value: A string path to an image, a numpy array, or a PIL Image. |
| caption: A string caption for the image. |
| """ |
| self.caption = caption |
| self._pil = TrackioImage._as_pil(value) |
| self._file_path: Path | None = None |
| self._file_format: str | None = None |
|
|
| @staticmethod |
| def _as_pil(value: str | np.ndarray | PILImage.Image) -> PILImage.Image: |
| try: |
| if isinstance(value, str): |
| return PILImage.open(value).convert("RGBA") |
| elif isinstance(value, np.ndarray): |
| arr = np.asarray(value).astype("uint8") |
| return PILImage.fromarray(arr).convert("RGBA") |
| elif isinstance(value, PILImage.Image): |
| return value.convert("RGBA") |
| except Exception as e: |
| raise ValueError(f"Failed to process image data: {value}") from e |
|
|
| def _save(self, project: str, run: str, step: int = 0, format: str = "PNG") -> str: |
| if not self._file_path: |
| |
| filename = f"{uuid.uuid4()}.{format.lower()}" |
| path = FileStorage.save_image( |
| self._pil, project, run, step, filename, format=format |
| ) |
| self._file_path = path.relative_to(TRACKIO_DIR) |
| self._file_format = format |
| return str(self._file_path) |
|
|
| def _get_relative_file_path(self) -> Path | None: |
| return self._file_path |
|
|
| def _get_absolute_file_path(self) -> Path | None: |
| return TRACKIO_DIR / self._file_path |
|
|
| def _to_dict(self) -> dict: |
| if not self._file_path: |
| raise ValueError("Image must be saved to file before serialization") |
| return { |
| "_type": self.TYPE, |
| "file_path": str(self._get_relative_file_path()), |
| "file_format": self._file_format, |
| "caption": self.caption, |
| } |
|
|
| @classmethod |
| def _from_dict(cls, obj: dict) -> "TrackioImage": |
| if not isinstance(obj, dict): |
| raise TypeError(f"Expected dict, got {type(obj).__name__}") |
| if obj.get("_type") != cls.TYPE: |
| raise ValueError(f"Wrong _type: {obj.get('_type')!r}") |
|
|
| file_path = obj.get("file_path") |
| if not isinstance(file_path, str): |
| raise TypeError( |
| f"'file_path' must be string, got {type(file_path).__name__}" |
| ) |
|
|
| absolute_path = TRACKIO_DIR / file_path |
| try: |
| if not absolute_path.is_file(): |
| raise ValueError(f"Image file not found: {file_path}") |
| pil = PILImage.open(absolute_path).convert("RGBA") |
| instance = cls(pil, caption=obj.get("caption")) |
| instance._file_path = Path(file_path) |
| instance._file_format = obj.get("file_format") |
| return instance |
| except Exception as e: |
| raise ValueError(f"Failed to load image from file: {absolute_path}") from e |
|
|