import sys

sys.path.append(".")

import logging
import os
from typing import Optional, Dict

import hydra
import torch
from hydra.utils import instantiate
from datasets import DatasetDict, load_dataset, IterableDatasetDict
from omegaconf import DictConfig, OmegaConf
from src.data.transforms import SamCaptionerDataTransform
from src.data.collator import SamCaptionerDataCollator
from src.arguments import Arguments, global_setup, SAMCaptionerModelArguments, SCAModelArguments
from src.models.sam_captioner import SAMCaptionerConfig, SAMCaptionerModel, SAMCaptionerProcessor

from transformers.trainer_utils import get_last_checkpoint
from transformers import set_seed, Trainer
import gradio as gr
from dataclasses import dataclass
from hydra import initialize, compose
from src.train import prepare_datasets, prepare_model, prepare_data_transform, prepare_processor
from datasets import Dataset, IterableDataset
import dotenv


logger = logging.getLogger(__name__)


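# Pytest entry point: compose the Hydra config with the test overrides below and run `main`.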
def test_with_initialize() -> None:
    with initialize(version_base="1.3", config_path="../../../src/conf"):
        args = compose(
            config_name="conf",
            overrides=[
                "train_data=[vg-densecap-region_descriptions]",
                "eval_data=[vg-densecap-region_descriptions]",
                "+model=base_sam_captioner",
                "training.do_train=True",
                "training.do_eval=True",
                "training.overwrite_output_dir=True",
                "training.num_masks_per_sample=6",
            ],
        )
        main(args)


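# Build the datasets, processor, and model, run SAM on a full batch and on a slice of that
# batch, then compare the corresponding outputs (presumably a sanity check that the results
# do not depend on batch composition).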
def main(args: DictConfig) -> None:
    # NOTE(xiaoke): follow https://github.com/huggingface/transformers/blob/main/examples/pytorch/image-classification/run_image_classification.py

    logger.info(OmegaConf.to_yaml(args))
    args, training_args, model_args = global_setup(args)

    # Detecting last checkpoint.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # Set seed before initializing model.
    set_seed(args.training.seed)

    # Initialize our dataset and prepare it
    train_dataset, eval_dataset = prepare_datasets(args)

    # NOTE(xiaoke): load sas_key from .env for huggingface model downloading.
    logger.info(f"Try to load sas_key from .env file: {dotenv.load_dotenv('.env')}.")
    use_auth_token = os.getenv("USE_AUTH_TOKEN", False)

    processor = prepare_processor(model_args, use_auth_token)

    train_dataset, eval_dataset = prepare_data_transform(
        training_args, model_args, train_dataset, eval_dataset, processor
    )
    if len(eval_dataset) > 1:
        raise ValueError(
            f"len(eval_dataset) should be 1, but got {len(eval_dataset)}. Check args.eval_data: {args.eval_data}"
        )
    eval_dataset = next(iter(eval_dataset.values()))

    collate_fn = SamCaptionerDataCollator(processor.tokenizer)

    model = prepare_model(model_args, use_auth_token)

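    # Wrap a DataLoader in an endless iterator so `next()` can be called indefinitely.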
    def cycle(iterable):
        while True:
            for x in iterable:
                yield x

    if training_args.do_train:
        train_data_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=training_args.per_device_train_batch_size, collate_fn=collate_fn
        )
        train_data_loader = cycle(train_data_loader)
    else:
        train_data_loader = None

    if training_args.do_eval:
        eval_data_loader = torch.utils.data.DataLoader(
            eval_dataset, batch_size=training_args.per_device_eval_batch_size, collate_fn=collate_fn
        )
        eval_data_loader = cycle(eval_data_loader)
    else:
        eval_data_loader = None

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model.to(device)
    dtype = model.dtype

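    # Move one batch to the model's device/dtype, run the SAM sub-module, and attach the
    # post-processed masks (resized back to the original image sizes) to the outputs.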
    @torch.no_grad()
    def run_one_batch(batch):
        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                if torch.is_floating_point(v):
                    batch[k] = v.to(device, dtype)
                else:
                    batch[k] = v.to(device)

        # Grad is already disabled by the `@torch.no_grad()` decorator on this function.
        model_outputs = model.sam(**batch)

        original_sizes = batch["original_sizes"]
        reshaped_input_sizes = batch["reshaped_input_sizes"]
        pred_masks = model_outputs.pred_masks
        masks = processor.post_process_masks(pred_masks, original_sizes, reshaped_input_sizes)

        model_outputs["masks"] = masks
        return model_outputs

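    # NOTE: assumes `training.do_train=True` (as in the test overrides above); otherwise
    # `train_data_loader` is None and `next()` fails.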
    full_batch = next(train_data_loader)
    full_batch_model_outputs = run_one_batch(full_batch)

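    # Print the shape of every tensor (or the type of every non-tensor) in a batch/output dict.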
    def print_batch(batch):
        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                print(f"{k}: {v.shape}")
            else:
                print(f"{k}: {type(v)}")

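    # Slice a batch along the batch dim; keys that also carry a per-box dim (e.g. `input_boxes`)
    # are additionally sliced along the number-of-masks dim.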
    def make_sub_batch(batch, batch_dim_slice: slice, num_masks_dim_slice: slice):
        returned_batch = {}
        batch_num_masks_keys_to_be_edited = ["input_boxes"]
        batch_keys_to_be_edited = [
            "images",
            "pixel_values",
            "original_sizes",
            "reshaped_input_sizes",
        ]
        assert len(set(batch_num_masks_keys_to_be_edited).intersection(set(batch_keys_to_be_edited))) == 0
        for k, v in batch.items():
            if k in batch_num_masks_keys_to_be_edited:
                returned_batch[k] = v[batch_dim_slice, num_masks_dim_slice]
            elif k in batch_keys_to_be_edited:
                returned_batch[k] = v[batch_dim_slice]
            else:
                returned_batch[k] = v
        return returned_batch

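    # Sub-batch: images 3-4 and the first 5 prompt boxes per image.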
    batch_dim_slice = slice(3, 5)
    num_masks_dim_slice = slice(0, 5)
    sub_batch = make_sub_batch(full_batch, batch_dim_slice, num_masks_dim_slice)
    print("sub_batch")
    print_batch(sub_batch)
    print("full_batch")
    print_batch(full_batch)
    sub_batch_model_outputs = run_one_batch(sub_batch)

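    # Same slicing for the model outputs: `iou_scores`/`pred_masks` have a per-box dim, while
    # `masks` (a list over images) is only sliced along the batch dim.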
    def make_sub_model_outputs(model_outputs, batch_dim_slice: slice, num_masks_dim_slice: slice):
        returned_model_outputs = {}

        batch_num_masks_model_outputs_keys_to_be_edited = ["iou_scores", "pred_masks"]
        batch_model_outputs_keys_to_be_edited = ["masks"]
        assert (
            len(
                set(batch_num_masks_model_outputs_keys_to_be_edited).intersection(
                    set(batch_model_outputs_keys_to_be_edited)
                )
            )
            == 0
        )

        for k, v in model_outputs.items():
            if k in batch_num_masks_model_outputs_keys_to_be_edited:
                returned_model_outputs[k] = v[batch_dim_slice, num_masks_dim_slice]
            elif k in batch_model_outputs_keys_to_be_edited:
                returned_model_outputs[k] = v[batch_dim_slice]
            else:
                returned_model_outputs[k] = v
        return returned_model_outputs

    sub_full_batch_model_outputs = make_sub_model_outputs(
        full_batch_model_outputs, batch_dim_slice, num_masks_dim_slice
    )
    print("sub_full_batch_model_outputs")
    print_batch(sub_full_batch_model_outputs)
    print("full_batch_model_outputs")
    print(sub_full_batch_model_outputs.get("iou_scores"))
    print(sub_batch_model_outputs.iou_scores)
    for k, v in sub_full_batch_model_outputs.items():
        if isinstance(v, torch.Tensor):
            print(f"torch allclose {k}: {v.shape}")
            print(torch.allclose(v, getattr(sub_batch_model_outputs, k)))
    sub_full_batch_model_outputs.get("pred_masks")
    sub_batch_model_outputs.get("pred_masks")
    torch.max(torch.abs(sub_full_batch_model_outputs.get("pred_masks") - sub_batch_model_outputs.get("pred_masks")))

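    # Drill down one level: run the prompt encoder on the boxes directly and check that the
    # sparse prompt embeddings from the full batch match the sub-batch ones.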
    full_batch_input_boxes = full_batch["input_boxes"]
    with torch.no_grad():
        full_batch_out_1, full_batch_out_2 = model.sam.prompt_encoder(
            input_boxes=full_batch_input_boxes, input_points=None, input_labels=None, input_masks=None
        )
    sub_batch_input_boxes = sub_batch["input_boxes"]
    with torch.no_grad():
        sub_batch_out_1, sub_batch_out_2 = model.sam.prompt_encoder(
            input_boxes=sub_batch_input_boxes, input_points=None, input_labels=None, input_masks=None
        )
    # The sparse prompt embeddings and the raw input boxes should match between the two runs.
    print(torch.allclose(full_batch_out_1[batch_dim_slice, num_masks_dim_slice], sub_batch_out_1, atol=1e-6))
    print(torch.allclose(full_batch_input_boxes[batch_dim_slice, num_masks_dim_slice], sub_batch_input_boxes))