import json
import torch
from PIL import Image
from torch.utils.data import Dataset


class PhysicsCoTDataset(Dataset):
    """Dataset for Qwen2.5-VL SFT with physics chain-of-thought (CoT) data.

    Expects a JSONL file in which each line is a record of the form
    {"messages": [user_msg, assistant_msg]}: the user message content is a
    list of {"type": "image", "image": "file://..."} and
    {"type": "text", "text": ...} items, and the assistant message carries
    the CoT target in content[0]["text"].
    """

    def __init__(self, data_path, processor, max_length=4096):
        self.processor = processor
        self.max_length = max_length
        # One JSON record per line (JSONL); skip blank lines.
        with open(data_path, 'r', encoding='utf-8') as f:
            self.records = [json.loads(line) for line in f if line.strip()]
        print(f"Loaded {len(self.records)} records from {data_path}")

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        record = self.records[idx]
        messages = record['messages']
        # Pull the image path and question text out of the user turn.
        user_msg = messages[0]
        image_path = None
        text_content = ""
        for content in user_msg['content']:
            if content['type'] == 'image':
                image_path = content['image'].replace('file://', '')
            elif content['type'] == 'text':
                text_content = content['text']
        # The assistant turn holds the chain-of-thought target text.
        assistant_msg = messages[1]
        assistant_text = assistant_msg['content'][0]['text']
        image = Image.open(image_path).convert('RGB')
        # The vision encoder needs a minimum spatial size per side; upscale tiny
        # images and, if integer rounding still leaves a side short, pad with
        # white to reach MIN_DIM.
        MIN_DIM = 56
        w, h = image.size
        if w < MIN_DIM or h < MIN_DIM:
            scale = max(MIN_DIM / w, MIN_DIM / h)
            new_w = int(w * scale)
            new_h = int(h * scale)
            image = image.resize((new_w, new_h), Image.LANCZOS)
            if new_w < MIN_DIM or new_h < MIN_DIM:
                padded = Image.new('RGB', (max(new_w, MIN_DIM), max(new_h, MIN_DIM)), (255, 255, 255))
                padded.paste(image, (0, 0))
                image = padded
        conversation = [
            {"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": text_content}]},
            {"role": "assistant", "content": [{"type": "text", "text": assistant_text}]},
        ]
        # Render the full conversation with the chat template, then tokenize
        # text and image together in a single processor call.
        text = self.processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=False)
        inputs = self.processor(text=[text], images=[image], padding=False, truncation=True, max_length=self.max_length, return_tensors="pt")
        input_ids = inputs['input_ids'].squeeze(0)
        attention_mask = inputs['attention_mask'].squeeze(0)
        labels = input_ids.clone()
        # Mask everything up to and including the assistant header so the loss
        # is computed only on the assistant's response tokens.
        assistant_token_str = "<|im_start|>assistant\n"
        assistant_token_ids = self.processor.tokenizer.encode(assistant_token_str, add_special_tokens=False)
        input_ids_list = input_ids.tolist()
        assistant_start = -1
        for i in range(len(input_ids_list) - len(assistant_token_ids) + 1):
            if input_ids_list[i:i + len(assistant_token_ids)] == assistant_token_ids:
                assistant_start = i + len(assistant_token_ids)
                break
        if assistant_start > 0:
            labels[:assistant_start] = -100
        else:
            # Truncation may have cut off the assistant turn entirely; fail loudly.
            raise ValueError(f"FATAL: assistant start token not found in sample {idx}.")
        labels[attention_mask == 0] = -100
        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels,
            # Vision tensors may be absent; pass None so the collator can skip them.
            'pixel_values': inputs['pixel_values'].squeeze(0) if 'pixel_values' in inputs else None,
            'image_grid_thw': inputs['image_grid_thw'].squeeze(0) if 'image_grid_thw' in inputs else None,
        }


class VLMDataCollator:
    """Collator that right-pads variable-length text inputs and gathers vision tensors."""

    def __init__(self, processor):
        self.processor = processor
        # Fall back to EOS when the tokenizer has no dedicated pad token.
        self.pad_token_id = processor.tokenizer.pad_token_id or processor.tokenizer.eos_token_id

    def __call__(self, features):
        max_len = max(f['input_ids'].size(0) for f in features)
        input_ids, attention_mask, labels, pixel_values, image_grid_thw = [], [], [], [], []
        for f in features:
            seq_len = f['input_ids'].size(0)
            pad_len = max_len - seq_len
            # Right-pad text tensors to the longest sequence in the batch;
            # padded positions get label -100 so they are ignored by the loss.
            input_ids.append(torch.cat([f['input_ids'], torch.full((pad_len,), self.pad_token_id, dtype=f['input_ids'].dtype)]))
            attention_mask.append(torch.cat([f['attention_mask'], torch.zeros(pad_len, dtype=f['attention_mask'].dtype)]))
            labels.append(torch.cat([f['labels'], torch.full((pad_len,), -100, dtype=f['labels'].dtype)]))
            if f.get('pixel_values') is not None:
                pixel_values.append(f['pixel_values'])
            if f.get('image_grid_thw') is not None:
                image_grid_thw.append(f['image_grid_thw'])
        batch = {
            'input_ids': torch.stack(input_ids),
            'attention_mask': torch.stack(attention_mask),
            'labels': torch.stack(labels),
        }
        # Image patches are concatenated along the first dimension across the
        # batch, with one (t, h, w) grid row per image.
        if pixel_values:
            batch['pixel_values'] = torch.cat(pixel_values, dim=0)
        if image_grid_thw:
            batch['image_grid_thw'] = torch.stack(image_grid_thw)
        return batch
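

# Minimal usage sketch, not part of the training entry point. Assumptions:
# "train.jsonl" is a placeholder path to data in the format described in
# PhysicsCoTDataset's docstring, and the checkpoint name below is only an
# example Qwen2.5-VL model loaded via transformers' AutoProcessor.
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    from transformers import AutoProcessor

    processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
    dataset = PhysicsCoTDataset("train.jsonl", processor)
    collator = VLMDataCollator(processor)
    loader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collator)

    batch = next(iter(loader))
    # Expect (batch, seq_len) text tensors plus flattened vision patch tensors.
    for key, value in batch.items():
        print(key, tuple(value.shape))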