from tqdm.auto import tqdm
from constants import *
from utils import *
import pickle
import os
import torch
import torch.nn.functional as F
from PIL import Image
from torch.utils.data import Dataset, DataLoader

def format_point_text(points):
    """Convert a list of points (x/y given as percentages) into the binned
    coordinate token sequence expected by the model. Handles any number of
    points, including an empty list."""
    text = "<result_start>"
    for point in points:
        # Clamp pixel coordinates to [0, IMAGE_SIZE - 1]; missing keys default to the image centre.
        px = min(max(int(point.get('x', 50) * IMAGE_SIZE / 100), 0), IMAGE_SIZE - 1)
        py = min(max(int(point.get('y', 50) * IMAGE_SIZE / 100), 0), IMAGE_SIZE - 1)
        x_bin = min(px // (IMAGE_SIZE // NUM_BINS), NUM_BINS - 1)
        y_bin = min(py // (IMAGE_SIZE // NUM_BINS), NUM_BINS - 1)
        text += f"<pointx_start><coord_bin_{x_bin}><pointx_end><pointy_start><coord_bin_{y_bin}><pointy_end>"
    text += "<result_end>" + tokenizer.eos_token
    return text
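
# Illustrative example (not executed): assuming, say, IMAGE_SIZE = 224 and NUM_BINS = 100,
# a point {'x': 50, 'y': 25} maps to pixel (112, 56) and bins (56, 28), so
# format_point_text([{'x': 50, 'y': 25}]) would return
# "<result_start><pointx_start><coord_bin_56><pointx_end><pointy_start><coord_bin_28><pointy_end><result_end>"
# followed by the tokenizer's EOS token.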

def format_data_for_training(sample):
    """Format data sample for training, handling 0 to MAX_POINTS continuous coordinates."""
    try:
        # Check if 'points' key exists and is a list, otherwise treat as 0 points
        sample_points = sample.get('points', [])
        if not isinstance(sample_points, list):
            print(f"Warning: Invalid 'points' type for {sample.get('image_url', 'N/A')}. Treating as 0 points.")
            sample_points = []

        # Limit the number of points processed
        points_to_process = sample_points[:MAX_POINTS]
        num_points = len(points_to_process)

        # Load image - this is where most memory is used
        image_path = f"{IMAGE_LOCATION}{sample['image_url']}"
        
        # Check if file exists before attempting to open
        if not os.path.exists(image_path):
            print(f"Warning: Image not found: {image_path}. Skipping.")
            return None
            
        # Open image with error handling
        try:
            image = Image.open(image_path)
            # Convert grayscale to RGB if needed
            if image.mode != 'RGB':
                image = image.convert('RGB')
            image_tensor = image_to_tensor(image)
            # Explicitly delete the PIL image to free memory
            del image
        except Exception as e:
            print(f"Error processing image {image_path}: {e}")
            return None

        # Process text with memory efficiency in mind
        prompt_text = f"<point_start>{sample['label']}<point_end>"
        # format_point_text correctly handles an empty points_to_process list
        target_text = format_point_text(points_to_process)

        # Tokenize with explicit max lengths
        prompt_tokens = tokenizer(prompt_text, return_tensors="pt", max_length=PROMPT_LENGTH, 
                                 truncation=True, padding=False)
        target_tokens = tokenizer(target_text, return_tensors="pt", max_length=TEXT_LENGTH, 
                                 truncation=True, padding=False)

        # Check for empty tokens after tokenization
        if prompt_tokens.input_ids.numel() == 0 or target_tokens.input_ids.numel() == 0:
            print(f"Warning: Empty tokens after tokenization for {sample.get('image_url', 'N/A')}. Skipping.")
            return None

        # --- Handle Multiple Continuous Coordinates with Padding (Handles num_points=0 correctly) ---
        continuous_coords_list = []
        for point in points_to_process: # This loop won't run if num_points is 0
            coord_x = min(max(point.get('x', 50) / 100.0, 0.0), 1.0)
            coord_y = min(max(point.get('y', 50) / 100.0, 0.0), 1.0)
            continuous_coords_list.append([coord_x, coord_y])

        # Pad coordinates to MAX_POINTS rows and build a validity mask.
        # With no points, every slot is padding (-1.0) and the mask is all zeros.
        if num_points == 0:
            padded_coords = torch.full((MAX_POINTS, 2), -1.0)
            coords_mask = torch.zeros(MAX_POINTS)
        else:
            coords_tensor = torch.tensor(continuous_coords_list, dtype=torch.float32)
            padding_needed = MAX_POINTS - num_points
            padded_coords = F.pad(coords_tensor, (0, 0, 0, padding_needed), value=-1.0)
            coords_mask = torch.cat([torch.ones(num_points, dtype=torch.float32),
                                   torch.zeros(padding_needed, dtype=torch.float32)])
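
        # Shape illustration (assuming, for example, MAX_POINTS = 10): with 3 valid
        # points, padded_coords is (10, 2) with rows 3..9 set to -1.0, and
        # coords_mask is [1, 1, 1, 0, 0, 0, 0, 0, 0, 0].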

        # Create and return the formatted sample
        return {
            "image": image_tensor,
            "prompt_ids": prompt_tokens.input_ids[0],
            "target_ids": target_tokens.input_ids[0],
            "continuous_coords": padded_coords,
            "coords_mask": coords_mask,
            "num_points": num_points,
            "label": sample['label'],
            "image_url": sample['image_url']
        }
    except FileNotFoundError:
        print(f"Warning: Image not found: {sample.get('image_url', 'N/A')}. Skipping.")
        return None
    except Exception as e:
        print(f"Error formatting sample ({sample.get('image_url', 'N/A')}): {e}. Skipping.")
        import traceback
        traceback.print_exc()
        return None


class PointDataset(Dataset):
    def __init__(self, data_path="active_point_dataset.pkl", split="train", test_size=1000):
        with open(data_path, "rb") as f:
            raw_data = pickle.load(f)

        # Keep samples with 0 to MAX_POINTS points. Check that 'points' is a list
        # before calling len() so malformed samples are dropped instead of raising.
        original_count = len(raw_data)
        raw_data = [sample for sample in raw_data
                    if isinstance(sample.get('points', []), list)
                    and len(sample.get('points', [])) <= MAX_POINTS]
        filtered_count = len(raw_data)
        print(f"Original raw data size: {original_count}")
        print(f"Filtered raw data to {filtered_count} samples with 0 to {MAX_POINTS} points.")

        total_samples = len(raw_data)
        if total_samples == 0:
            raise ValueError("No samples left after filtering. Check data or MAX_POINTS.")

        if total_samples <= test_size:
            print(f"Warning: Dataset size {total_samples} <= test_size {test_size}.")
            test_size = max(1, int(total_samples * 0.2)) if total_samples > 1 else 0
        train_end = total_samples - test_size
        print(f"Dataset: {total_samples} total (0 to {MAX_POINTS} points), {train_end} train, {test_size} test")

        if split == "train":
            if train_end <= 0:
                print("Warning: No samples allocated for training split.")
            self.raw_data = raw_data[:train_end]
        elif split == "test":
            if test_size <= 0:
                print("Warning: No samples allocated for test split.")
            self.raw_data = raw_data[train_end:]
        else:
            raise ValueError("split must be 'train' or 'test'")

        # Samples are kept raw and processed lazily in __getitem__, so images
        # are never all loaded into memory at once.
        print(f"Dataset initialized with {len(self.raw_data)} samples for {split}")

        # Optional: cache recently processed samples to speed up repeated access.
        self.cache_size = 8000  # Adjust based on memory constraints
        self.cache = {}  # FIFO cache of processed samples (oldest entry evicted first)

    def __len__(self):
        return len(self.raw_data)

    def __getitem__(self, idx):
        # Check if the item is in the cache
        if idx in self.cache:
            return self.cache[idx]
            
        # Process the sample on-demand
        sample = self.raw_data[idx]
        formatted = format_data_for_training(sample)
        
        # If processing failed, try the next sample
        if formatted is None:
            # Find next valid index (with wrapping)
            next_idx = (idx + 1) % len(self.raw_data)
            
            # Prevent infinite loop if all samples are invalid
            attempts = 0
            while formatted is None and attempts < min(10, len(self.raw_data)):
                sample = self.raw_data[next_idx]
                formatted = format_data_for_training(sample)
                next_idx = (next_idx + 1) % len(self.raw_data)
                attempts += 1
                
            # If we still don't have a valid sample after attempts, return a dummy sample
            if formatted is None:
                print(f"Warning: Failed to find valid sample after {attempts} attempts")
                # Create minimal valid sample with zeros
                formatted = self._create_dummy_sample()
        
        # Update cache: evict the oldest entry (insertion order) when full.
        if len(self.cache) >= self.cache_size:
            if self.cache:
                oldest_key = next(iter(self.cache))
                del self.cache[oldest_key]
        
        # Add to cache
        self.cache[idx] = formatted
        
        return formatted
        
    def _create_dummy_sample(self):
        """Creates a minimal valid sample when all else fails."""
        # Create empty image tensor
        image_tensor = torch.zeros(3, IMAGE_SIZE, IMAGE_SIZE)
        
        # Create minimal tokens
        prompt_text = "<point_start>dummy<point_end>"
        target_text = "<result_start><result_end>" + tokenizer.eos_token
        
        prompt_tokens = tokenizer(prompt_text, return_tensors="pt").input_ids[0]
        target_tokens = tokenizer(target_text, return_tensors="pt").input_ids[0]
        
        # Create empty coordinates
        padded_coords = torch.full((MAX_POINTS, 2), -1.0)
        coords_mask = torch.zeros(MAX_POINTS)
        
        return {
            "image": image_tensor,
            "prompt_ids": prompt_tokens,
            "target_ids": target_tokens,
            "continuous_coords": padded_coords,
            "coords_mask": coords_mask,
            "num_points": 0,
            "label": "dummy",
            "image_url": "none"
        }

    @staticmethod
    def collate_fn(batch):
        """Collate formatted samples into padded batch tensors, stacking the
        padded continuous coordinates and coordinate masks."""
        batch = [item for item in batch if item is not None]
        if not batch:
            return None

        images = torch.stack([item['image'] for item in batch]).to(DTYPE)

        # --- Pad Prompt IDs ---
        max_prompt_len = max(item['prompt_ids'].size(0) for item in batch)
        prompt_ids_padded, prompt_attention_mask = [], []
        for item in batch:
            ids, pad_len = item['prompt_ids'], max_prompt_len - item['prompt_ids'].size(0)
            prompt_ids_padded.append(torch.cat([ids, torch.full((pad_len,), tokenizer.pad_token_id, dtype=torch.long)]))
            prompt_attention_mask.append(torch.cat([torch.ones_like(ids, dtype=torch.long), torch.zeros(pad_len, dtype=torch.long)]))
        prompt_ids = torch.stack(prompt_ids_padded)
        prompt_attention_mask = torch.stack(prompt_attention_mask)

        # --- Pad Target IDs & Create Generative Targets ---
        max_target_len = max(item['target_ids'].size(0) for item in batch)
        target_ids_padded, target_attention_mask, generative_targets = [], [], []
        for item in batch:
            ids, pad_len = item['target_ids'], max_target_len - item['target_ids'].size(0)
            padded_ids = torch.cat([ids, torch.full((pad_len,), tokenizer.pad_token_id, dtype=torch.long)])
            target_ids_padded.append(padded_ids)
            mask = torch.cat([torch.ones_like(ids, dtype=torch.long), torch.zeros(pad_len, dtype=torch.long)])
            target_attention_mask.append(mask)
            # Next-token prediction targets: shift ids left by one; padded positions stay -100.
            targets = torch.full_like(padded_ids, -100)
            if ids.size(0) > 1:
                targets[:ids.size(0) - 1] = ids[1:]
            if ids.numel() > 0 and ids[-1] == tokenizer.eos_token_id:
                if ids.size(0) > 1:
                    targets[ids.size(0) - 1] = tokenizer.eos_token_id
                else:
                    targets[0] = -100
            generative_targets.append(targets)
        target_ids = torch.stack(target_ids_padded)
        target_attention_mask = torch.stack(target_attention_mask)
        generative_targets = torch.stack(generative_targets)

        # --- Stack Continuous Coords and Masks ---
        continuous_coords = torch.stack([item['continuous_coords'] for item in batch])
        coords_mask = torch.stack([item['coords_mask'] for item in batch])
        num_points = [item['num_points'] for item in batch]

        labels = [item['label'] for item in batch]
        image_urls = [item.get('image_url', '') for item in batch]

        return {
            'image': images,
            'prompt_ids': prompt_ids,
            'prompt_attention_mask': prompt_attention_mask,
            'target_ids': target_ids,
            'target_attention_mask': target_attention_mask,
            'generative_targets': generative_targets,
            'continuous_coords': continuous_coords,
            'coords_mask': coords_mask,
            'num_points': num_points,
            'label': labels,
            'image_url': image_urls
        }
    
def create_train_dataloader(batch_size=BATCH_SIZE, num_workers=0, prefetch_factor=2):
    """Create training dataloader with memory-efficient settings.
    
    Args:
        batch_size: Number of samples per batch
        num_workers: Number of worker processes for data loading
        prefetch_factor: Number of batches to prefetch per worker
        
    Returns:
        DataLoader instance or None if dataset is empty
    """
    dataset = PointDataset(split="train")
    if len(dataset) == 0:
        print("Warning: Train dataset is empty. Returning None.")
        return None
        
    # Configure DataLoader for memory efficiency
    return DataLoader(
        dataset, 
        batch_size=batch_size, 
        shuffle=True, 
        collate_fn=PointDataset.collate_fn, 
        pin_memory=True,  # Speeds up CPU to GPU transfer
        num_workers=num_workers,
        prefetch_factor=prefetch_factor if num_workers > 0 else None,  # Only valid with workers
        persistent_workers=num_workers > 0,  # Keep workers alive between epochs
        drop_last=False  # Don't drop the last incomplete batch
    )

def create_test_dataloader(batch_size=BATCH_SIZE, num_workers=0, prefetch_factor=2):
    """Create test dataloader with memory-efficient settings.
    
    Args:
        batch_size: Number of samples per batch
        num_workers: Number of worker processes for data loading
        prefetch_factor: Number of batches to prefetch per worker
        
    Returns:
        DataLoader instance or None if dataset is empty
    """
    dataset = PointDataset(split="test")
    if len(dataset) == 0:
        print("Warning: Test dataset is empty. Returning None.")
        return None
        
    # Test loader with similar memory settings but no shuffling
    return DataLoader(
        dataset, 
        batch_size=batch_size, 
        shuffle=False, 
        collate_fn=PointDataset.collate_fn, 
        pin_memory=True,
        num_workers=num_workers,
        prefetch_factor=prefetch_factor if num_workers > 0 else None,
        persistent_workers=num_workers > 0,
        drop_last=False
    )
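

# Minimal usage sketch (illustrative, not part of the training pipeline): build the
# train loader and inspect one batch. Assumes constants.py provides BATCH_SIZE,
# MAX_POINTS, IMAGE_SIZE, etc., and that "active_point_dataset.pkl" and the
# referenced images exist locally; adjust the batch size and paths as needed.
if __name__ == "__main__":
    train_loader = create_train_dataloader(batch_size=4)
    if train_loader is not None:
        batch = next(iter(train_loader))
        if batch is not None:
            print("image:", tuple(batch['image'].shape))              # (B, 3, H, W)
            print("prompt_ids:", tuple(batch['prompt_ids'].shape))    # (B, max prompt length)
            print("target_ids:", tuple(batch['target_ids'].shape))    # (B, max target length)
            print("continuous_coords:", tuple(batch['continuous_coords'].shape))  # (B, MAX_POINTS, 2)
            print("coords_mask:", tuple(batch['coords_mask'].shape))  # (B, MAX_POINTS)
            print("num_points:", batch['num_points'])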