|
|
import sys |
|
|
|
|
|
sys.path.append(".") |
|
|
|
|
|
import logging |
|
|
import os |
|
|
from typing import Optional, Dict |
|
|
|
|
|
import hydra |
|
|
import torch |
|
|
from hydra.utils import instantiate |
|
|
from datasets import DatasetDict, load_dataset, IterableDatasetDict |
|
|
from omegaconf import DictConfig, OmegaConf |
|
|
from src.data.transforms import SamCaptionerDataTransform |
|
|
from src.data.collator import SamCaptionerDataCollator |
|
|
from src.arguments import Arguments, global_setup, SAMCaptionerModelArguments, SCAModelArguments |
|
|
from src.models.sam_captioner import SAMCaptionerConfig, SAMCaptionerModel, SAMCaptionerProcessor |
|
|
|
|
|
from transformers.trainer_utils import get_last_checkpoint |
|
|
from transformers import set_seed, Trainer |
|
|
import gradio as gr |
|
|
from dataclasses import dataclass |
|
|
from hydra import initialize, compose |
|
|
from src.train import prepare_datasets, prepare_model, prepare_data_transform, prepare_processor |
|
|
from datasets import Dataset, IterableDataset |
|
|
import dotenv |
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
def test_with_initialize() -> None: |
|
|
with initialize(version_base="1.3", config_path="../../../src/conf"): |
|
|
args = compose( |
|
|
config_name="conf", |
|
|
overrides=[ |
|
|
"train_data=[vg-densecap-region_descriptions]", |
|
|
"eval_data=[vg-densecap-region_descriptions]", |
|
|
"+model=base_sam_captioner", |
|
|
"training.do_train=True", |
|
|
"training.do_eval=True", |
|
|
"training.overwrite_output_dir=True", |
|
|
"training.num_masks_per_sample=6", |
|
|
], |
|
|
) |
|
|
main(args) |
|
|
|
|
|
|
|
|
def main(args: DictConfig) -> None: |
|
|
|
|
|
|
|
|
logger.info(OmegaConf.to_yaml(args)) |
|
|
args, training_args, model_args = global_setup(args) |
|
|
|
|
|
|
|
|
last_checkpoint = None |
|
|
if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: |
|
|
last_checkpoint = get_last_checkpoint(training_args.output_dir) |
|
|
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: |
|
|
raise ValueError( |
|
|
f"Output directory ({training_args.output_dir}) already exists and is not empty. " |
|
|
"Use --overwrite_output_dir to overcome." |
|
|
) |
|
|
elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: |
|
|
logger.info( |
|
|
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " |
|
|
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch." |
|
|
) |
|
|
|
|
|
|
|
|
set_seed(args.training.seed) |
|
|
|
|
|
|
|
|
train_dataset, eval_dataset = prepare_datasets(args) |
|
|
|
|
|
|
|
|
logger.info(f"Try to load sas_key from .env file: {dotenv.load_dotenv('.env')}.") |
|
|
use_auth_token = os.getenv("USE_AUTH_TOKEN", False) |
|
|
|
|
|
processor = prepare_processor(model_args, use_auth_token) |
|
|
|
|
|
train_dataset, eval_dataset = prepare_data_transform( |
|
|
training_args, model_args, train_dataset, eval_dataset, processor |
|
|
) |
|
|
if len(eval_dataset) > 1: |
|
|
raise ValueError( |
|
|
f"len(eval_dataset) should be 1, but got {len(eval_dataset)}. Check args.eval_data: {args.eval_data}" |
|
|
) |
|
|
eval_dataset = next(iter(eval_dataset.values())) |
|
|
|
|
|
collate_fn = SamCaptionerDataCollator(processor.tokenizer) |
|
|
|
|
|
model = prepare_model(model_args, use_auth_token) |
|
|
|
|
|
def cycle(iterable): |
|
|
while True: |
|
|
for x in iterable: |
|
|
yield x |
|
|
|
|
|
if training_args.do_train: |
|
|
train_data_loader = torch.utils.data.DataLoader( |
|
|
train_dataset, batch_size=training_args.per_device_train_batch_size, collate_fn=collate_fn |
|
|
) |
|
|
train_data_loader = cycle(train_data_loader) |
|
|
else: |
|
|
train_data_loader = None |
|
|
|
|
|
if training_args.do_eval: |
|
|
eval_data_loader = torch.utils.data.DataLoader( |
|
|
eval_dataset, batch_size=training_args.per_device_eval_batch_size, collate_fn=collate_fn |
|
|
) |
|
|
eval_data_loader = cycle(eval_data_loader) |
|
|
else: |
|
|
eval_data_loader = None |
|
|
|
|
|
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") |
|
|
model = model.to(device) |
|
|
dtype = model.dtype |
|
|
|
|
|
@torch.no_grad() |
|
|
def run_one_batch(batch): |
|
|
for k, v in batch.items(): |
|
|
if isinstance(v, torch.Tensor): |
|
|
if torch.is_floating_point(v): |
|
|
batch[k] = v.to(device, dtype) |
|
|
else: |
|
|
batch[k] = v.to(device) |
|
|
|
|
|
with torch.no_grad(): |
|
|
model_outputs = model.sam(**batch) |
|
|
|
|
|
original_sizes = batch["original_sizes"] |
|
|
reshaped_input_sizes = batch["reshaped_input_sizes"] |
|
|
pred_masks = model_outputs.pred_masks |
|
|
masks = processor.post_process_masks(pred_masks, original_sizes, reshaped_input_sizes) |
|
|
|
|
|
model_outputs["masks"] = masks |
|
|
return model_outputs |
|
|
|
|
|
full_batch = next(train_data_loader) |
|
|
full_batch_model_outputs = run_one_batch(full_batch) |
|
|
|
|
|
def print_batch(batch): |
|
|
for k, v in batch.items(): |
|
|
if isinstance(v, torch.Tensor): |
|
|
print(f"{k}: {v.shape}") |
|
|
else: |
|
|
print(f"{k}: {type(v)}") |
|
|
|
|
|
def make_sub_batch(batch, batch_dim_slice: slice, num_masks_dim_slice: slice): |
|
|
returned_batch = {} |
|
|
batch_num_masks_keys_to_be_edited = ["input_boxes"] |
|
|
batch_keys_to_be_edited = [ |
|
|
"images", |
|
|
"pixel_values", |
|
|
"original_sizes", |
|
|
"reshaped_input_sizes", |
|
|
] |
|
|
assert len(set(batch_num_masks_keys_to_be_edited).intersection(set(batch_keys_to_be_edited))) == 0 |
|
|
for k, v in batch.items(): |
|
|
if k in batch_num_masks_keys_to_be_edited: |
|
|
returned_batch[k] = v[batch_dim_slice, num_masks_dim_slice] |
|
|
elif k in batch_keys_to_be_edited: |
|
|
returned_batch[k] = v[batch_dim_slice] |
|
|
else: |
|
|
returned_batch[k] = v |
|
|
return returned_batch |
|
|
|
|
|
batch_dim_slice = slice(3, 5) |
|
|
num_masks_dim_slice = slice(0, 5) |
|
|
sub_batch = make_sub_batch(full_batch, batch_dim_slice, num_masks_dim_slice) |
|
|
print("sub_batch") |
|
|
print_batch(sub_batch) |
|
|
print("full_batch") |
|
|
print_batch(full_batch) |
|
|
sub_batch_model_outputs = run_one_batch(sub_batch) |
|
|
|
|
|
def make_sub_model_outputs(model_outputs, batch_dim_slice: slice, num_masks_dim_slice: slice): |
|
|
returned_model_outputs = {} |
|
|
|
|
|
batch_num_masks_model_outputs_keys_to_be_edited = ["iou_scores", "pred_masks"] |
|
|
batch_model_outputs_keys_to_be_edited = ["masks"] |
|
|
assert ( |
|
|
len( |
|
|
set(batch_num_masks_model_outputs_keys_to_be_edited).intersection( |
|
|
set(batch_model_outputs_keys_to_be_edited) |
|
|
) |
|
|
) |
|
|
== 0 |
|
|
) |
|
|
|
|
|
for k, v in model_outputs.items(): |
|
|
if k in batch_num_masks_model_outputs_keys_to_be_edited: |
|
|
returned_model_outputs[k] = v[batch_dim_slice, num_masks_dim_slice] |
|
|
elif k in batch_model_outputs_keys_to_be_edited: |
|
|
returned_model_outputs[k] = v[batch_dim_slice] |
|
|
else: |
|
|
returned_model_outputs[k] = v |
|
|
return returned_model_outputs |
|
|
|
|
|
sub_full_batch_model_outputs = make_sub_model_outputs( |
|
|
full_batch_model_outputs, batch_dim_slice, num_masks_dim_slice |
|
|
) |
|
|
print("sub_full_batch_model_outputs") |
|
|
print_batch(sub_full_batch_model_outputs) |
|
|
print("full_batch_model_outputs") |
|
|
print(sub_full_batch_model_outputs.get("iou_scores")) |
|
|
print(sub_batch_model_outputs.iou_scores) |
|
|
for k, v in sub_full_batch_model_outputs.items(): |
|
|
if isinstance(v, torch.Tensor): |
|
|
print(f"torch allclose {k}: {v.shape}") |
|
|
print(torch.allclose(v, getattr(sub_batch_model_outputs, k))) |
|
|
sub_full_batch_model_outputs.get("pred_masks") |
|
|
sub_batch_model_outputs.get("pred_masks") |
|
|
torch.max(torch.abs(sub_full_batch_model_outputs.get("pred_masks") - sub_batch_model_outputs.get("pred_masks"))) |
|
|
|
|
|
full_batch_input_boxes = full_batch["input_boxes"] |
|
|
with torch.no_grad(): |
|
|
full_batch_out_1, full_batch_out_2 = model.sam.prompt_encoder( |
|
|
input_boxes=full_batch_input_boxes, input_points=None, input_labels=None, input_masks=None |
|
|
) |
|
|
sub_batch_input_boxes = sub_batch["input_boxes"] |
|
|
with torch.no_grad(): |
|
|
sub_batch_out_1, sub_batch_out_2 = model.sam.prompt_encoder( |
|
|
input_boxes=sub_batch_input_boxes, input_points=None, input_labels=None, input_masks=None |
|
|
) |
|
|
full_batch_out_1[batch_dim_slice, num_masks_dim_slice], sub_batch_out_1 |
|
|
torch.allclose(full_batch_out_1[batch_dim_slice, num_masks_dim_slice], sub_batch_out_1, atol=1e-6) |
|
|
torch.allclose(full_batch_input_boxes[batch_dim_slice, num_masks_dim_slice], sub_batch_input_boxes) |
|
|
|