"""Dataset, collation, and DataLoader helpers for point-annotation samples."""

from tqdm.auto import tqdm
# NOTE(review): IMAGE_SIZE, NUM_BINS, MAX_POINTS, PROMPT_LENGTH, TEXT_LENGTH,
# IMAGE_LOCATION, DTYPE, BATCH_SIZE, tokenizer and image_to_tensor are not
# defined in this file; presumably they come from the two star imports below.
# The original import order is kept so any name shadowing stays unchanged.
from constants import *
from utils import *
import pickle
from torch.utils.data import Dataset
import torch
import torch.nn.functional as F
from PIL import Image
from torch.utils.data import DataLoader
import os
import traceback


def format_point_text(points):
    """Build the target text for a list of point dicts.

    Args:
        points: Iterable of dicts with optional 'x' / 'y' entries (defaulting
            to 50). Values are scaled by IMAGE_SIZE / 100, i.e. they are
            treated as percentages of the image size.

    Returns:
        The accumulated text followed by ``tokenizer.eos_token``.
    """
    text = ""
    for point in points:
        # Percent -> pixel position, clamped to [0, IMAGE_SIZE - 1].
        px = min(max(int(point.get('x', 50) * IMAGE_SIZE / 100), 0), IMAGE_SIZE - 1)
        py = min(max(int(point.get('y', 50) * IMAGE_SIZE / 100), 0), IMAGE_SIZE - 1)
        # Pixel position -> bin index, clamped to the last bin.
        x_bin = min(px // (IMAGE_SIZE // NUM_BINS), NUM_BINS - 1)
        y_bin = min(py // (IMAGE_SIZE // NUM_BINS), NUM_BINS - 1)
        # NOTE(review): this literal is empty, so x_bin / y_bin never reach
        # the output. It looks like per-point tokens were lost from the
        # string; confirm the intended format before relying on this text.
        text += f""
    text += "" + tokenizer.eos_token
    return text


def _build_padded_coords(points):
    """Return (padded_coords, coords_mask) for up to MAX_POINTS points.

    Coordinates are divided by 100 and clamped to [0, 1]. Rows beyond the
    real points are filled with -1.0 and flagged 0 in the mask.
    """
    num_points = len(points)
    if num_points == 0:
        return torch.full((MAX_POINTS, 2), -1.0), torch.zeros(MAX_POINTS)

    coords = [
        [
            min(max(point.get('x', 50) / 100.0, 0.0), 1.0),
            min(max(point.get('y', 50) / 100.0, 0.0), 1.0),
        ]
        for point in points
    ]
    coords_tensor = torch.tensor(coords, dtype=torch.float32)
    padding_needed = MAX_POINTS - num_points
    # The pad spec is (left, right) for the last dim, then (top, bottom) for
    # the first dim, so this appends rows only.
    padded_coords = F.pad(coords_tensor, (0, 0, 0, padding_needed), value=-1.0)
    coords_mask = torch.cat([
        torch.ones(num_points, dtype=torch.float32),
        torch.zeros(padding_needed, dtype=torch.float32),
    ])
    return padded_coords, coords_mask


def format_data_for_training(sample):
    """Format one raw sample for training, handling 0 to MAX_POINTS points.

    Args:
        sample: Dict with 'image_url' and 'label', and optionally 'points'
            (a list of point dicts). A non-list 'points' value is treated as
            zero points.

    Returns:
        A dict with keys 'image', 'prompt_ids', 'target_ids',
        'continuous_coords', 'coords_mask', 'num_points', 'label' and
        'image_url', or None when the sample cannot be processed (missing
        image, image decoding error, empty tokenization, or any other
        exception, which is printed together with its traceback).
    """
    try:
        sample_points = sample.get('points', [])
        if not isinstance(sample_points, list):
            print(f"Warning: Invalid 'points' type for {sample.get('image_url', 'N/A')}. Treating as 0 points.")
            sample_points = []

        # Only the first MAX_POINTS points are used.
        points_to_process = sample_points[:MAX_POINTS]
        num_points = len(points_to_process)

        image_path = f"{IMAGE_LOCATION}{sample['image_url']}"
        if not os.path.exists(image_path):
            print(f"Warning: Image not found: {image_path}. Skipping.")
            return None

        try:
            # The context manager closes the underlying file handle
            # deterministically instead of relying on garbage collection.
            with Image.open(image_path) as image:
                rgb_image = image if image.mode == 'RGB' else image.convert('RGB')
                image_tensor = image_to_tensor(rgb_image)
        except Exception as e:
            print(f"Error processing image {image_path}: {e}")
            return None

        # NOTE(review): the prompt is just the label rendered as a string;
        # any surrounding prompt template is not visible here.
        prompt_text = f"{sample['label']}"
        target_text = format_point_text(points_to_process)

        prompt_tokens = tokenizer(
            prompt_text,
            return_tensors="pt",
            max_length=PROMPT_LENGTH,
            truncation=True,
            padding=False,
        )
        target_tokens = tokenizer(
            target_text,
            return_tensors="pt",
            max_length=TEXT_LENGTH,
            truncation=True,
            padding=False,
        )

        if prompt_tokens.input_ids.numel() == 0 or target_tokens.input_ids.numel() == 0:
            print(f"Warning: Empty tokens after tokenization for {sample.get('image_url', 'N/A')}. Skipping.")
            return None

        padded_coords, coords_mask = _build_padded_coords(points_to_process)

        return {
            "image": image_tensor,
            "prompt_ids": prompt_tokens.input_ids[0],
            "target_ids": target_tokens.input_ids[0],
            "continuous_coords": padded_coords,
            "coords_mask": coords_mask,
            "num_points": num_points,
            "label": sample['label'],
            "image_url": sample['image_url']
        }

    except FileNotFoundError:
        print(f"Warning: Image not found: {sample.get('image_url', 'N/A')}. Skipping.")
        return None
    except Exception as e:
        # Best-effort by design: a bad sample is reported and skipped so it
        # does not abort data loading.
        print(f"Error formatting sample ({sample.get('image_url', 'N/A')}): {e}. Skipping.")
        traceback.print_exc()
        return None


def _has_usable_points(sample):
    """Return True when 'points' is a list with at most MAX_POINTS entries."""
    points = sample.get('points', [])
    # The type check must come first: len() on a non-sized value such as
    # None would raise TypeError.
    return isinstance(points, list) and len(points) <= MAX_POINTS


def _pad_ids(ids, target_len):
    """Right-pad a 1-D id tensor with the pad token; return (padded, mask)."""
    pad_len = target_len - ids.size(0)
    padded = torch.cat([ids, torch.full((pad_len,), tokenizer.pad_token_id, dtype=torch.long)])
    mask = torch.cat([
        torch.ones_like(ids, dtype=torch.long),
        torch.zeros(pad_len, dtype=torch.long),
    ])
    return padded, mask


class PointDataset(Dataset):
    """Dataset over pickled point-annotation samples, formatted on demand.

    Raw samples are kept in memory; images are loaded and tokenized only in
    ``__getitem__``. Processed samples are kept in a bounded cache.
    """

    def __init__(self, data_path="active_point_dataset.pkl", split="train", test_size=1000):
        """Load the pickle, filter samples, and select the requested split.

        Args:
            data_path: Path to the pickled list of raw sample dicts.
            split: 'train' (leading samples) or 'test' (trailing samples).
            test_size: Number of trailing samples reserved for the test
                split. When the dataset has no more than this many samples,
                20% (at least one sample, or zero for a single-sample
                dataset) is used instead.

        Raises:
            ValueError: If no samples survive filtering, or ``split`` is not
                'train' or 'test'.
        """
        # SECURITY NOTE: pickle.load can execute arbitrary code; only load
        # dataset files from a trusted source.
        with open(data_path, "rb") as f:
            raw_data = pickle.load(f)

        original_count = len(raw_data)
        raw_data = [sample for sample in raw_data if _has_usable_points(sample)]
        filtered_count = len(raw_data)
        print(f"Original raw data size: {original_count}")
        print(f"Filtered raw data to {filtered_count} samples with 0 to {MAX_POINTS} points.")

        total_samples = len(raw_data)
        if total_samples == 0:
            raise ValueError("No samples left after filtering. Check data or MAX_POINTS.")

        if total_samples <= test_size:
            print(f"Warning: Dataset size {total_samples} <= test_size {test_size}.")
            test_size = max(1, int(total_samples * 0.2)) if total_samples > 1 else 0

        train_end = total_samples - test_size
        print(f"Dataset: {total_samples} total (0 to {MAX_POINTS} points), {train_end} train, {test_size} test")

        if split == "train":
            if train_end <= 0:
                print("Warning: No samples allocated for training split.")
            self.raw_data = raw_data[:train_end]
        elif split == "test":
            if test_size <= 0:
                print("Warning: No samples allocated for test split.")
            self.raw_data = raw_data[train_end:]
        else:
            raise ValueError("split must be 'train' or 'test'")

        # Samples are NOT preprocessed here; images are loaded lazily.
        print(f"Dataset initialized with {len(self.raw_data)} samples for {split}")

        # Bounded cache of processed samples, keyed by index. Dict insertion
        # order doubles as recency order (oldest first).
        # NOTE: with DataLoader worker processes each worker holds its own
        # copy of the dataset object, and therefore its own cache.
        self.cache_size = 8000  # Adjust based on memory constraints
        self.cache = {}

    def __len__(self):
        """Return the number of raw samples in this split."""
        return len(self.raw_data)

    def __getitem__(self, idx):
        """Return the formatted sample for ``idx``.

        If formatting fails, up to min(10, len) following samples (with
        wrap-around) are tried; if all fail, a dummy sample is returned. The
        result is cached under ``idx``.
        """
        if idx in self.cache:
            # Re-insert on hit so the entry becomes the most recently used;
            # without this the eviction below would be FIFO, not LRU.
            formatted = self.cache.pop(idx)
            self.cache[idx] = formatted
            return formatted

        formatted = format_data_for_training(self.raw_data[idx])

        if formatted is None:
            total = len(self.raw_data)
            next_idx = (idx + 1) % total
            max_attempts = min(10, total)  # Bounded so all-invalid data cannot loop forever
            attempts = 0
            while formatted is None and attempts < max_attempts:
                formatted = format_data_for_training(self.raw_data[next_idx])
                next_idx = (next_idx + 1) % total
                attempts += 1

            if formatted is None:
                print(f"Warning: Failed to find valid sample after {attempts} attempts")
                formatted = self._create_dummy_sample()

        # Evict the least recently used entry (first key) when full.
        if len(self.cache) >= self.cache_size and self.cache:
            del self.cache[next(iter(self.cache))]

        self.cache[idx] = formatted
        return formatted

    def _create_dummy_sample(self):
        """Create a minimal zero-point sample used when no real one loads."""
        image_tensor = torch.zeros(3, IMAGE_SIZE, IMAGE_SIZE)

        prompt_text = "dummy"
        target_text = "" + tokenizer.eos_token
        prompt_tokens = tokenizer(prompt_text, return_tensors="pt").input_ids[0]
        target_tokens = tokenizer(target_text, return_tensors="pt").input_ids[0]

        # Same "no points" encoding as real samples: -1.0 coords, zero mask.
        padded_coords = torch.full((MAX_POINTS, 2), -1.0)
        coords_mask = torch.zeros(MAX_POINTS)

        return {
            "image": image_tensor,
            "prompt_ids": prompt_tokens,
            "target_ids": target_tokens,
            "continuous_coords": padded_coords,
            "coords_mask": coords_mask,
            "num_points": 0,
            "label": "dummy",
            "image_url": "none"
        }

    @staticmethod
    def collate_fn(batch):
        """Collate formatted samples into a padded batch.

        None items are dropped; an empty batch yields None. Prompt and target
        ids are right-padded with ``tokenizer.pad_token_id`` to the longest
        sequence in the batch, with matching attention masks.
        'generative_targets' holds the target ids shifted left by one, with
        -100 at every position that has no target.

        Returns:
            Dict of batched tensors ('image', 'prompt_ids',
            'prompt_attention_mask', 'target_ids', 'target_attention_mask',
            'generative_targets', 'continuous_coords', 'coords_mask') and
            per-sample lists ('num_points', 'label', 'image_url'), or None.
        """
        batch = [item for item in batch if item is not None]
        if not batch:
            return None

        images = torch.stack([item['image'] for item in batch]).to(DTYPE)

        # --- Pad prompt ids ---
        max_prompt_len = max(item['prompt_ids'].size(0) for item in batch)
        prompt_ids_padded, prompt_masks = [], []
        for item in batch:
            padded_ids, mask = _pad_ids(item['prompt_ids'], max_prompt_len)
            prompt_ids_padded.append(padded_ids)
            prompt_masks.append(mask)
        prompt_ids = torch.stack(prompt_ids_padded)
        prompt_attention_mask = torch.stack(prompt_masks)

        # --- Pad target ids and build shifted generative targets ---
        max_target_len = max(item['target_ids'].size(0) for item in batch)
        target_ids_padded, target_masks, generative_targets = [], [], []
        for item in batch:
            ids = item['target_ids']
            seq_len = ids.size(0)
            padded_ids, mask = _pad_ids(ids, max_target_len)
            target_ids_padded.append(padded_ids)
            target_masks.append(mask)

            targets = torch.full_like(padded_ids, -100)
            if seq_len > 1:
                # Position i is trained to predict token i + 1.
                targets[:seq_len - 1] = ids[1:]
            if ids.numel() > 0 and ids[-1] == tokenizer.eos_token_id:
                if seq_len > 1:
                    # The final EOS position also gets EOS as its target.
                    targets[seq_len - 1] = tokenizer.eos_token_id
                else:
                    targets[0] = -100
            generative_targets.append(targets)
        target_ids = torch.stack(target_ids_padded)
        target_attention_mask = torch.stack(target_masks)
        generative_targets = torch.stack(generative_targets)

        # --- Stack continuous coords and masks (already fixed-size) ---
        continuous_coords = torch.stack([item['continuous_coords'] for item in batch])
        coords_mask = torch.stack([item['coords_mask'] for item in batch])
        num_points = [item['num_points'] for item in batch]
        labels = [item['label'] for item in batch]
        image_urls = [item.get('image_url', '') for item in batch]

        return {
            'image': images,
            'prompt_ids': prompt_ids,
            'prompt_attention_mask': prompt_attention_mask,
            'target_ids': target_ids,
            'target_attention_mask': target_attention_mask,
            'generative_targets': generative_targets,
            'continuous_coords': continuous_coords,
            'coords_mask': coords_mask,
            'num_points': num_points,
            'label': labels,
            'image_url': image_urls
        }


def _build_dataloader(dataset, batch_size, shuffle, num_workers, prefetch_factor):
    """Build a DataLoader with the settings shared by train and test."""
    use_workers = num_workers > 0
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=PointDataset.collate_fn,
        pin_memory=True,
        num_workers=num_workers,
        # prefetch_factor and persistent_workers only apply with workers.
        prefetch_factor=prefetch_factor if use_workers else None,
        persistent_workers=use_workers,
        drop_last=False,  # Keep the last incomplete batch
    )


def create_train_dataloader(batch_size=BATCH_SIZE, num_workers=0, prefetch_factor=2):
    """Create the shuffled training dataloader.

    Args:
        batch_size: Number of samples per batch
        num_workers: Number of worker processes for data loading
        prefetch_factor: Number of batches to prefetch per worker

    Returns:
        DataLoader instance or None if dataset is empty
    """
    dataset = PointDataset(split="train")
    if len(dataset) == 0:
        print("Warning: Train dataset is empty. Returning None.")
        return None
    return _build_dataloader(dataset, batch_size, True, num_workers, prefetch_factor)


def create_test_dataloader(batch_size=BATCH_SIZE, num_workers=0, prefetch_factor=2):
    """Create the unshuffled test dataloader.

    Args:
        batch_size: Number of samples per batch
        num_workers: Number of worker processes for data loading
        prefetch_factor: Number of batches to prefetch per worker

    Returns:
        DataLoader instance or None if dataset is empty
    """
    dataset = PointDataset(split="test")
    if len(dataset) == 0:
        print("Warning: Test dataset is empty. Returning None.")
        return None
    return _build_dataloader(dataset, batch_size, False, num_workers, prefetch_factor)