linoy
initial commit
ebfc6b3
raw
history blame
2.65 kB
import io
import subprocess
from pathlib import Path
import torch
from PIL import ExifTags, Image, ImageCms, ImageOps
from PIL.Image import Image as PilImage
from ltx_trainer import logger
def get_gpu_memory_gb(device: torch.device) -> float:
    """Report how much memory is currently in use on a GPU, in gigabytes.

    ``nvidia-smi`` is asked first for the ``memory.used`` figure of the GPU selected
    by ``device.index`` (GPU 0 when the device carries no index). When that query
    cannot be run or its output cannot be parsed, the failure is logged and the
    value of ``torch.cuda.memory_allocated`` for ``device`` is returned instead.

    Args:
        device: The torch device whose memory usage is requested.

    Returns:
        The memory in use, expressed in GB.
    """
    gpu_index = 0 if device.index is None else device.index
    query = [
        "nvidia-smi",
        "--query-gpu=memory.used",
        "--format=csv,nounits,noheader",
        "-i",
        str(gpu_index),
    ]

    try:
        used_mb = float(subprocess.check_output(query, encoding="utf-8").strip())
    except (subprocess.CalledProcessError, FileNotFoundError, ValueError) as e:
        logger.error(f"Failed to get GPU memory from nvidia-smi: {e}")
        # Fallback: torch reports bytes, hence the division by 1024**3.
        # NOTE(review): this is presumably only the memory allocated through torch,
        # so it may differ from the nvidia-smi figure - confirm if that matters.
        return torch.cuda.memory_allocated(device) / 1024**3

    # nvidia-smi prints the value in MB (units are suppressed by "nounits").
    return used_mb / 1024
def open_image_as_srgb(image_path: str | Path | io.BytesIO) -> PilImage:
    """Open an image, apply its EXIF rotation and return it as an sRGB ``RGB`` image.

    The source colour space is determined as follows:

    1. If the image embeds an ICC profile, the pixels are converted from that
       profile to sRGB.
    2. Otherwise the EXIF ``ColorSpace`` tag is consulted. A missing tag or a value
       of 1 (sRGB) means the image is assumed to already be sRGB and is only
       converted to ``RGB`` mode. Any other value is rejected.

    The returned image always carries an sRGB ICC profile in ``info["icc_profile"]``.

    Args:
        image_path: Path to the image file, or an in-memory buffer holding it.

    Returns:
        A new ``RGB`` image in the sRGB colour space.

    Raises:
        ValueError: If there is no ICC profile and the EXIF ``ColorSpace`` tag is
            present with a value other than sRGB.
    """
    exif_colorspace_srgb = 1

    # exif_transpose returns a new, fully loaded image (rotated if the metadata asks
    # for it), so the underlying file can be closed as soon as it has run.
    with Image.open(image_path) as img_raw:
        img = ImageOps.exif_transpose(img_raw)

    srgb_profile = ImageCms.createProfile(colorSpace="sRGB")
    input_icc_profile = img.info.get("icc_profile")

    # An empty profile payload carries no information; treat it like a missing one
    # instead of failing while parsing it.
    if input_icc_profile:
        # Convert to sRGB using the embedded ICC profile.
        input_profile = ImageCms.ImageCmsProfile(io.BytesIO(input_icc_profile))
        srgb_img = ImageCms.profileToProfile(img, input_profile, srgb_profile, outputMode="RGB")
    else:
        # No ICC profile: fall back to the EXIF ColorSpace tag.
        exif_data = img.getexif()
        color_space_tag = ExifTags.Base.ColorSpace.value

        # The ColorSpace tag is stored in the Exif sub-IFD rather than in the top-level
        # IFD that getexif() exposes directly, so both places have to be checked.
        color_space_value = exif_data.get(color_space_tag)
        if color_space_value is None:
            color_space_value = exif_data.get_ifd(ExifTags.IFD.Exif).get(color_space_tag)

        # Assume sRGB if no ICC profile and EXIF has no ColorSpace tag.
        if color_space_value is not None and color_space_value != exif_colorspace_srgb:
            raise ValueError(
                "Image has colorspace tag in EXIF but it isn't set to sRGB,"
                " conversion is not supported."
                f" EXIF ColorSpace tag value is {color_space_value}",
            )

        srgb_img = img.convert("RGB")

    # Set sRGB profile in metadata since now the image is assumed to be in sRGB.
    srgb_img.info["icc_profile"] = ImageCms.ImageCmsProfile(srgb_profile).tobytes()

    return srgb_img