linoy
initial commit
ebfc6b3
import io
import os
import subprocess
from pathlib import Path

import torch
from PIL import ExifTags, Image, ImageCms, ImageOps
from PIL.Image import Image as PilImage

from ltx_trainer import logger
def _physical_gpu_id(device_index: int) -> str:
    """Translate a torch CUDA device index into an ID accepted by ``nvidia-smi -i``.

    Torch numbers GPUs relative to ``CUDA_VISIBLE_DEVICES``, whereas ``nvidia-smi``
    addresses physical GPUs. When the variable is set, the entry at ``device_index``
    (a physical index or a GPU UUID, both of which ``-i`` accepts) is returned.
    Otherwise, or if the index is out of range, the index itself is used unchanged.

    NOTE(review): integer entries are interpreted by CUDA under ``CUDA_DEVICE_ORDER``
    (fastest-first by default), while nvidia-smi uses PCI bus order. On machines with
    mixed GPU models the mapping can therefore still be off unless
    ``CUDA_DEVICE_ORDER=PCI_BUS_ID`` is set.
    """
    visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
    if visible_devices:
        entries = [entry.strip() for entry in visible_devices.split(",")]
        if 0 <= device_index < len(entries) and entries[device_index]:
            return entries[device_index]
    return str(device_index)


def get_gpu_memory_gb(device: torch.device) -> float:
    """Get current GPU memory usage in GiB, preferring nvidia-smi.

    The primary source is ``nvidia-smi --query-gpu=memory.used``, which reports the
    device-wide usage (all processes) in MiB. The torch device index is first mapped
    to the physical GPU ID through ``CUDA_VISIBLE_DEVICES``.

    If nvidia-smi is missing, cannot be executed, exits with an error, times out, or
    prints something that is not a number, the error is logged. The function then falls
    back to ``torch.cuda.memory_allocated(device)``. That fallback counts only memory
    held by tensors of the current process, so it is usually lower than the nvidia-smi
    figure.

    Args:
        device: Torch device to query. A device without an explicit index
            (e.g. ``torch.device("cuda")``) is treated as index 0.

    Returns:
        Used memory in GiB: MiB / 1024 from nvidia-smi, or bytes / 1024**3 from
        the torch fallback.
    """
    device_index = device.index if device.index is not None else 0
    gpu_id = _physical_gpu_id(device_index)
    try:
        result = subprocess.check_output(
            [
                "nvidia-smi",
                "--query-gpu=memory.used",
                "--format=csv,nounits,noheader",
                "-i",
                gpu_id,
            ],
            encoding="utf-8",
            # Guard against nvidia-smi hanging (e.g. on a wedged driver) and
            # blocking the caller indefinitely.
            timeout=10,
        )
        return float(result.strip()) / 1024  # Convert MiB to GiB
    except (subprocess.CalledProcessError, subprocess.TimeoutExpired, OSError, ValueError) as e:
        # OSError covers FileNotFoundError (nvidia-smi not installed) as well as
        # PermissionError and other exec failures.
        logger.error(f"Failed to get GPU memory from nvidia-smi: {e}")
        # Fallback to torch
        return torch.cuda.memory_allocated(device) / 1024**3
def open_image_as_srgb(image_path: str | Path | io.BytesIO) -> PilImage:
    """
    Open an image, apply its EXIF orientation and convert it to the sRGB color space.

    The source color space is determined as follows:

    * If the image embeds an ICC profile, the pixels are transformed from that
      profile to sRGB with ``ImageCms.profileToProfile``.
    * Otherwise the EXIF ``ColorSpace`` tag is consulted. A missing tag is assumed
      to mean the image is already sRGB, and it is only converted to RGB mode.
      Any other value than 1 (sRGB), e.g. 0xFFFF ("Uncalibrated"), is rejected,
      because the real color space cannot be determined.

    Args:
        image_path: Path to the image file, or an in-memory buffer holding its bytes.

    Returns:
        A new RGB-mode image (any alpha channel is dropped) whose
        ``info["icc_profile"]`` holds a serialized sRGB profile.

    Raises:
        ValueError: If the image has no ICC profile and its EXIF ``ColorSpace``
            tag is present but not set to sRGB.
    """
    exif_colorspace_srgb = 1
    # EXIF-specific tags such as ColorSpace live in the Exif sub-IFD (pointer tag
    # 0x8769, i.e. ExifTags.IFD.Exif), not in the root IFD0 that getexif() exposes.
    exif_sub_ifd_tag = 0x8769
    # LittleCMS transforms cannot take palette or gray+alpha data, so such images
    # are converted to an equivalent supported mode first. Since the output is RGB,
    # no information that would have been kept is lost.
    cms_input_modes = {"P": "RGB", "PA": "RGBA", "LA": "L"}

    with Image.open(image_path) as img_raw:
        img = ImageOps.exif_transpose(img_raw)
        input_icc_profile = img.info.get("icc_profile")
        srgb_profile = ImageCms.createProfile(colorSpace="sRGB")
        # An empty ICC payload cannot be parsed, so treat it like a missing one.
        if input_icc_profile:
            if img.mode in cms_input_modes:
                img = img.convert(cms_input_modes[img.mode])
            input_profile = ImageCms.ImageCmsProfile(io.BytesIO(input_icc_profile))
            srgb_img = ImageCms.profileToProfile(img, input_profile, srgb_profile, outputMode="RGB")
        else:
            # No ICC profile: fall back to the EXIF ColorSpace tag, and assume sRGB
            # when it is absent.
            exif_data = img.getexif()
            color_space_tag = ExifTags.Base.ColorSpace.value
            color_space_value = exif_data.get_ifd(exif_sub_ifd_tag).get(color_space_tag)
            if color_space_value is None:
                # Some writers store the tag in IFD0; keep honoring that location.
                color_space_value = exif_data.get(color_space_tag)
            if color_space_value is not None and color_space_value != exif_colorspace_srgb:
                raise ValueError(
                    "Image has colorspace tag in EXIF but it isn't set to sRGB,"
                    " conversion is not supported."
                    f" EXIF ColorSpace tag value is {color_space_value}",
                )
            srgb_img = img.convert("RGB")

    # Tag the result with an sRGB profile, since its pixels are now (assumed to be) sRGB.
    srgb_profile_data = ImageCms.ImageCmsProfile(srgb_profile).tobytes()
    srgb_img.info["icc_profile"] = srgb_profile_data
    return srgb_img