import json

import torch
from PIL import Image
from torch.utils.data import Dataset


class PhysicsCoTDataset(Dataset):
    """Dataset for Qwen2.5-VL SFT with physics CoT."""

    def __init__(self, data_path, processor, max_length=4096):
        self.processor = processor
        self.max_length = max_length
        # One JSON record per line (JSONL), each holding a `messages` list.
        with open(data_path, 'r', encoding='utf-8') as f:
            self.records = [json.loads(line) for line in f]
        print(f"Loaded {len(self.records)} records from {data_path}")

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        record = self.records[idx]
        messages = record['messages']

        # Pull the image path and prompt text out of the user turn.
        user_msg = messages[0]
        image_path = None
        text_content = ""
        for content in user_msg['content']:
            if content['type'] == 'image':
                image_path = content['image'].replace('file://', '')
            elif content['type'] == 'text':
                text_content = content['text']

        # The assistant turn holds the target CoT response.
        assistant_msg = messages[1]
        assistant_text = assistant_msg['content'][0]['text']

        # Upscale images below the vision encoder's minimum spatial size;
        # pad with white if integer truncation still undershoots MIN_DIM.
        image = Image.open(image_path).convert('RGB')
        MIN_DIM = 56
        w, h = image.size
        if w < MIN_DIM or h < MIN_DIM:
            scale = max(MIN_DIM / w, MIN_DIM / h)
            new_w = int(w * scale)
            new_h = int(h * scale)
            image = image.resize((new_w, new_h), Image.LANCZOS)
            if new_w < MIN_DIM or new_h < MIN_DIM:
                padded = Image.new('RGB', (max(new_w, MIN_DIM), max(new_h, MIN_DIM)), (255, 255, 255))
                padded.paste(image, (0, 0))
                image = padded

        conversation = [
            {"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": text_content}]},
            {"role": "assistant", "content": [{"type": "text", "text": assistant_text}]},
        ]
        text = self.processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=False)
        inputs = self.processor(
            text=[text],
            images=[image],
            padding=False,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        input_ids = inputs['input_ids'].squeeze(0)
        attention_mask = inputs['attention_mask'].squeeze(0)

        # Mask everything up to and including the assistant header so the
        # loss is computed only on the response tokens (-100 is the
        # ignore_index of PyTorch's cross-entropy loss).
        labels = input_ids.clone()
        assistant_token_str = "<|im_start|>assistant\n"
        assistant_token_ids = self.processor.tokenizer.encode(assistant_token_str, add_special_tokens=False)
        input_ids_list = input_ids.tolist()
        assistant_start = -1
        for i in range(len(input_ids_list) - len(assistant_token_ids) + 1):
            if input_ids_list[i:i + len(assistant_token_ids)] == assistant_token_ids:
                assistant_start = i + len(assistant_token_ids)
                break
        if assistant_start != -1:
            labels[:assistant_start] = -100
        else:
            raise ValueError(f"FATAL: assistant start token not found in sample {idx}.")
        labels[attention_mask == 0] = -100

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels,
            # No squeeze here: the Qwen2.5-VL processor returns pixel_values
            # already flattened over patches, with no batch dim to strip.
            'pixel_values': inputs['pixel_values'] if 'pixel_values' in inputs else None,
            # (num_images, 3) -> (3,) for the single image in this sample.
            'image_grid_thw': inputs['image_grid_thw'].squeeze(0) if 'image_grid_thw' in inputs else None,
        }
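

# A quick way to eyeball the loss mask (a debugging sketch, not part of the
# original pipeline; `debug_label_mask` is a hypothetical helper): decoding
# only the unmasked positions should reproduce the assistant's CoT text plus
# the trailing end-of-turn tokens.
def debug_label_mask(dataset, idx=0):
    sample = dataset[idx]
    supervised = sample['input_ids'][sample['labels'] != -100]
    print(dataset.processor.tokenizer.decode(supervised))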


class VLMDataCollator:
    """Custom data collator for variable-length VLM inputs."""

    def __init__(self, processor):
        self.processor = processor
        # Fall back to EOS when the tokenizer defines no dedicated pad token.
        self.pad_token_id = processor.tokenizer.pad_token_id or processor.tokenizer.eos_token_id

    def __call__(self, features):
        # Right-pad every sequence in the batch to the longest one.
        max_len = max(f['input_ids'].size(0) for f in features)
        input_ids, attention_mask, labels, pixel_values, image_grid_thw = [], [], [], [], []
        for f in features:
            seq_len = f['input_ids'].size(0)
            pad_len = max_len - seq_len
            input_ids.append(torch.cat([f['input_ids'], torch.full((pad_len,), self.pad_token_id, dtype=f['input_ids'].dtype)]))
            attention_mask.append(torch.cat([f['attention_mask'], torch.zeros(pad_len, dtype=f['attention_mask'].dtype)]))
            labels.append(torch.cat([f['labels'], torch.full((pad_len,), -100, dtype=f['labels'].dtype)]))
            if f.get('pixel_values') is not None:
                pixel_values.append(f['pixel_values'])
            if f.get('image_grid_thw') is not None:
                image_grid_thw.append(f['image_grid_thw'])
        batch = {
            'input_ids': torch.stack(input_ids),
            'attention_mask': torch.stack(attention_mask),
            'labels': torch.stack(labels),
        }
        # Vision patches are concatenated along the patch axis; the per-image
        # (t, h, w) grids are stacked into shape (num_images, 3).
        if pixel_values:
            batch['pixel_values'] = torch.cat(pixel_values, dim=0)
        if image_grid_thw:
            batch['image_grid_thw'] = torch.stack(image_grid_thw)
        return batch
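

# A minimal wiring sketch showing how the dataset and collator fit together.
# Assumptions not taken from the original pipeline: the Hugging Face hub id
# and the "train.jsonl" path below are placeholders.
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    from transformers import AutoProcessor

    processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
    dataset = PhysicsCoTDataset("train.jsonl", processor)
    collator = VLMDataCollator(processor)
    loader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collator)
    batch = next(iter(loader))
    print({k: tuple(v.shape) for k, v in batch.items()})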