# Provenance: uploaded by pyamy ("Upload 31 files", commit 0a0f923, verified)
# xray_generator/utils/processing.py
import os
import random
import torch
import numpy as np
import logging
import cv2
from PIL import Image
from torch.utils.data import DataLoader, random_split
import torchvision.transforms as T
logger = logging.getLogger(__name__)
def set_seed(seed=42):
    """Seed every RNG source (Python, hash, NumPy, torch CPU/CUDA) for reproducible runs.

    Args:
        seed: Integer seed applied to all random-number generators.
    """
    # Python-level randomness
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    # Library RNGs (CUDA calls are safe no-ops on CPU-only builds)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for deterministic kernels
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    logger.info(f"Random seed set to {seed} for reproducibility")
def get_device():
    """Return the best available torch device, falling back to CPU on any CUDA failure.

    Returns:
        torch.device: "cuda" when CUDA is present and a test allocation
        succeeds, otherwise "cpu" (the failure is logged).
    """
    cpu = torch.device("cpu")
    if not torch.cuda.is_available():
        logger.warning("CUDA not available, falling back to CPU. This will be very slow.")
        return cpu
    try:
        gpu = torch.device("cuda")
        # Smoke-test an actual allocation; is_available() alone can lie
        _ = torch.zeros(1).to(gpu)
        props = torch.cuda.get_device_properties(0)
        logger.info(f"Using GPU: {torch.cuda.get_device_name(0)}")
        logger.info(f"GPU Memory: {props.total_memory / 1e9:.2f} GB")
        logger.info(f"CUDA Capability: {props.major}.{props.minor}")
        return gpu
    except Exception as e:
        logger.error(f"Error initializing CUDA: {e}")
        logger.warning("Falling back to CPU")
        return cpu
def log_gpu_memory(message=""):
    """Log allocated/reserved/peak GPU memory in GB; silently no-op without CUDA.

    Args:
        message: Optional label included in the log line (e.g. "after step").
    """
    if not torch.cuda.is_available():
        return
    gb = 1e9
    allocated = torch.cuda.memory_allocated() / gb
    reserved = torch.cuda.memory_reserved() / gb
    max_allocated = torch.cuda.max_memory_allocated() / gb
    logger.info(f"GPU Memory {message}: Allocated: {allocated:.2f}GB, Reserved: {reserved:.2f}GB, Max: {max_allocated:.2f}GB")
    # Clear the peak counter so the next call reports a fresh maximum
    torch.cuda.reset_peak_memory_stats()
def custom_collate_fn(batch):
    """Collate dataset items into a batch dict, stacking tensor fields and listing the rest.

    Items that are None (e.g. failed loads) are dropped first; a batch that
    becomes empty yields None so the training loop can skip it.

    Args:
        batch: List of per-sample dicts (or None entries).

    Returns:
        Dict with 'image'/'input_ids'/'attention_mask' stacked into tensors
        and every other key collected into a plain list, or None.
    """
    valid = [sample for sample in batch if sample is not None]
    if not valid:
        return None
    # Only these keys hold fixed-shape tensors safe to stack
    stack_keys = {'image', 'input_ids', 'attention_mask'}
    collated = {}
    for key in valid[0]:
        values = [sample[key] for sample in valid]
        collated[key] = torch.stack(values) if key in stack_keys else values
    return collated
def verify_dataset_files(dataset_path, sample_size=100):
    """Spot-check that a directory exists and a random sample of its images open cleanly.

    Args:
        dataset_path: Directory expected to contain .png/.jpg/.jpeg files.
        sample_size: Maximum number of files to open and validate.

    Returns:
        bool: True when the path exists, contains image files, and every
        sampled file opens via PIL; False otherwise (details are logged).
    """
    logger.info(f"Verifying dataset files in {dataset_path}")
    if not os.path.exists(dataset_path):
        logger.error(f"Dataset path does not exist: {dataset_path}")
        return False
    try:
        image_names = [name for name in os.listdir(dataset_path)
                       if name.lower().endswith(('.png', '.jpg', '.jpeg'))]
    except Exception as e:
        logger.error(f"Error listing files in {dataset_path}: {e}")
        return False
    if not image_names:
        logger.error(f"No image files found in {dataset_path}")
        return False
    logger.info(f"Found {len(image_names)} image files")
    # Only open a random subset; opening everything would be slow on big sets
    chosen = random.sample(image_names, min(sample_size, len(image_names)))
    bad = 0
    for name in chosen:
        path = os.path.join(dataset_path, name)
        try:
            with Image.open(path) as img:
                # Touch a property to force PIL to actually parse the header
                _ = img.size
        except Exception as e:
            logger.error(f"Error opening {path}: {e}")
            bad += 1
    if bad > 0:
        logger.error(f"Found {bad} errors in {len(chosen)} sample files")
        return False
    logger.info(f"Successfully verified {len(chosen)} sample files")
    return True
def create_transforms(image_size=256):
    """Build (train, val) transform pipelines mapping [0, 1] pixels to [-1, 1].

    Both pipelines are currently identical so evaluation sees the same
    normalization as training. `image_size` is accepted for interface
    compatibility but not used here.

    Returns:
        tuple: (train_transform, val_transform) torchvision Compose objects.
    """
    # Normalize single-channel images to the [-1, 1] range diffusion models expect
    train_pipeline = T.Compose([
        T.Normalize([0.5], [0.5])
    ])
    # Kept as a separate object so the splits can diverge later without
    # changing this function's signature.
    val_pipeline = T.Compose([
        T.Normalize([0.5], [0.5])
    ])
    return train_pipeline, val_pipeline
def apply_clahe(image_tensor, clip_limit=2.0, grid_size=(8, 8)):
    """Apply CLAHE (contrast-limited adaptive histogram equalization) to one image.

    Args:
        image_tensor: torch.Tensor (CHW or HW, single channel after squeeze),
            PIL image, or numpy array. Values may be in [0, 1] or [0, 255].
        clip_limit: CLAHE contrast clip limit.
        grid_size: CLAHE tile grid size.

    Returns:
        torch.Tensor in [0, 1] (with a channel dim if the input tensor had
        one) when a tensor was passed; otherwise a uint8 numpy array.
    """
    is_tensor = isinstance(image_tensor, torch.Tensor)
    if is_tensor:
        img_np = image_tensor.cpu().numpy().squeeze()
    else:
        img_np = np.array(image_tensor)
    # cv2's CLAHE only accepts uint8 (or uint16) input. The original code
    # converted only when max <= 1.0, so float arrays already in [0, 255]
    # reached cv2 un-cast and crashed, and negative values (e.g. tensors
    # normalized to [-1, 1]) wrapped around under astype(np.uint8).
    if img_np.dtype != np.uint8:
        if img_np.max() <= 1.0:
            # Treat as [0, 1] data; clip guards against small negatives
            img_np = (np.clip(img_np, 0.0, 1.0) * 255).astype(np.uint8)
        else:
            # Treat as [0, 255] data stored in a wider dtype
            img_np = np.clip(img_np, 0, 255).astype(np.uint8)
    clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=grid_size)
    img_np = clahe.apply(img_np)
    if is_tensor:
        # Convert back to a float tensor in [0, 1]
        img_tensor = torch.from_numpy(img_np).float() / 255.0
        if len(image_tensor.shape) > 2:  # restore channel dim if original had one
            img_tensor = img_tensor.unsqueeze(0)
        return img_tensor
    return img_np
def create_dataloader(dataset, batch_size=4, shuffle=True, num_workers=0,
                      drop_last=False, seed=42, timeout=0):
    """Build a DataLoader with pinned memory, seeded workers, and the project collate_fn.

    Args:
        dataset: Map-style dataset to load from.
        batch_size: Samples per batch.
        shuffle: Whether to shuffle each epoch.
        num_workers: Worker process count; 0 loads in the main process.
        drop_last: Drop the final partial batch when True.
        seed: Base seed; each worker gets seed + worker_id for numpy.
        timeout: Worker timeout in seconds (only used with workers).

    Returns:
        torch.utils.data.DataLoader configured for this project.
    """
    def _seed_worker(worker_id):
        # Distinct but reproducible numpy seed per worker process
        np.random.seed(seed + worker_id)

    kwargs = dict(
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=drop_last,
        worker_init_fn=_seed_worker,
        collate_fn=custom_collate_fn,
    )
    # These options are only valid when worker processes exist
    if num_workers > 0:
        kwargs['timeout'] = timeout
        kwargs['persistent_workers'] = True
        kwargs['prefetch_factor'] = 2
    return DataLoader(dataset, **kwargs)
def create_quick_test_dataset(dataset, percentage=0.01):
    """Wrap `dataset` in a random subset for fast smoke tests.

    Args:
        dataset: Any map-style dataset (supports len() and integer indexing).
        percentage: Fraction of items to keep, in (0, 1].

    Returns:
        A Dataset serving roughly `percentage` of the items. Always keeps at
        least one sample when the source is non-empty; previously
        int() truncation could silently produce an empty subset for small
        datasets, making the quick-test loader useless.
    """
    from torch.utils.data import Dataset

    class SmallDatasetWrapper(Dataset):
        def __init__(self, dataset, percentage=0.01):
            self.dataset = dataset
            # At least one sample (unless the source is empty), never more
            # than the dataset holds.
            n_keep = min(len(dataset), max(1, int(len(dataset) * percentage)))
            self.indices = random.sample(range(len(dataset)), n_keep)
            logger.info(f"Using {len(self.indices)} samples out of {len(dataset)} ({percentage*100:.1f}%)")

        def __getitem__(self, idx):
            return self.dataset[self.indices[idx]]

        def __len__(self):
            return len(self.indices)

    return SmallDatasetWrapper(dataset, percentage)