"""Gradio app for interactively inspecting SAM captioner batches.

Runs batches through the model one click at a time and visualizes the
predicted masks, input boxes, and generated captions region by region.

Example launch (hydra-style overrides, as referenced in the error messages below;
exact option names live in src/conf):

    python <path-to-this-script> training.do_train=True training.do_eval=True
"""
import sys

# Make the `src` package importable when running this script from the repository root.
sys.path.append(".")

import logging
import os
from dataclasses import dataclass
from typing import Optional

import dotenv
import gradio as gr
import hydra
import numpy as np
import torch
from omegaconf import DictConfig, OmegaConf
from transformers import set_seed
from transformers.trainer_utils import get_last_checkpoint

from src.arguments import (
    global_setup,
    SAMCaptionerModelArguments,
)
from src.train import prepare_datasets, prepare_model, prepare_data_transform, prepare_processor, prepare_collate_fn

logger = logging.getLogger(__name__)


@hydra.main(version_base="1.3", config_path="../../src/conf", config_name="conf")
def main(args: DictConfig) -> None:
    logger.info(OmegaConf.to_yaml(args))
    args, training_args, model_args = global_setup(args)

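    # Detect a checkpoint left over in output_dir from a previous run so it can be resumed.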
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            logger.warning(
                f"Output directory ({training_args.output_dir}) already exists and is not empty, "
                "but it contains no checkpoint. Set `resume_from_checkpoint` to resume from a specific checkpoint."
            )
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "`--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    set_seed(args.training.seed)

    train_dataset, eval_dataset = prepare_datasets(args)

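    # dotenv loads credentials (e.g., the dataset sas_key) from a local .env file if one exists.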
    logger.info(f"Try to load sas_key from .env file: {dotenv.load_dotenv('.env')}.")
    use_auth_token = os.getenv("USE_AUTH_TOKEN", False)

    processor = prepare_processor(model_args, use_auth_token)

    train_dataset, eval_dataset = prepare_data_transform(
        training_args, model_args, train_dataset, eval_dataset, processor
    )
    if len(eval_dataset) > 1:
        raise ValueError(f"Only one eval dataset is supported, but got {len(eval_dataset)}. args: {args.eval_data}")
    eval_dataset = next(iter(eval_dataset.values()))

    collate_fn = prepare_collate_fn(training_args, model_args, processor)

    # Metrics are not computed in this interactive app.
    compute_metrics = None

    model = prepare_model(model_args, use_auth_token)

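    # An infinite iterator over the data loader so each button click can pull
    # the next batch without ever raising StopIteration.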
    def cycle(iterable):
        while True:
            for x in iterable:
                yield x

    if training_args.do_train:
        train_data_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=training_args.per_device_train_batch_size, collate_fn=collate_fn
        )
        train_data_loader = cycle(train_data_loader)
    else:
        train_data_loader = None

    if training_args.do_eval or training_args.do_inference:
        eval_data_loader = torch.utils.data.DataLoader(
            eval_dataset, batch_size=training_args.per_device_eval_batch_size, collate_fn=collate_fn
        )
        eval_data_loader = cycle(eval_data_loader)
    else:
        eval_data_loader = None

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model.to(device)
    dtype = model.dtype

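    # Per-session UI state: the last batch that was run through the model plus a
    # (batch_id, region_id) cursor used to step through regions when displaying.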
    @dataclass
    class BatchVariable:
        batch_input: Optional[dict] = None
        batch_output: Optional[dict] = None
        batch_id: int = 0
        region_id: int = 0

    @torch.no_grad()
    def run_one_batch(data_loader, batch_variable: BatchVariable):
        batch = next(data_loader)
        # Move tensors to the model's device; floating-point tensors are also cast to the model's dtype.
        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                if torch.is_floating_point(v):
                    batch[k] = v.to(device, dtype)
                else:
                    batch[k] = v.to(device)

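        # The SAM captioner variant also returns the cropped region patches,
        # which the app displays below the annotated image.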
        with torch.inference_mode():
            if isinstance(model_args, SAMCaptionerModelArguments):
                model_outputs = model.generate(**batch, return_patches=True, return_dict_in_generate=True)
            else:
                model_outputs = model.generate(**batch)

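        # Upscale the low-resolution predicted masks back to each image's original size.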
        original_sizes = batch["original_sizes"]
        reshaped_input_sizes = batch["reshaped_input_sizes"]
        pred_masks = model_outputs.pred_masks
        masks = processor.post_process_masks(pred_masks, original_sizes, reshaped_input_sizes)
        model_outputs.masks = masks

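        # `sequences` has shape (batch, region, head, token): flatten to decode all
        # captions at once, then reshape the strings back to (batch, region, head).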
        batch_size, region_size, num_heads, num_tokens = model_outputs.sequences.shape
        generated_captions = processor.tokenizer.batch_decode(
            model_outputs.sequences.view(-1, num_tokens), skip_special_tokens=True
        )
        generated_captions = (
            np.array(generated_captions, dtype=object).reshape(batch_size, region_size, num_heads).tolist()
        )
        model_outputs.generated_captions = generated_captions

        batch_variable.batch_input = batch
        batch_variable.batch_output = model_outputs

        return f"finished running one batch, batch_size={len(batch['images'])}, region_size={len(masks[0])}"

    def run_one_batch_train(batch_variable: BatchVariable):
        if train_data_loader is None:
            raise ValueError("train_data_loader is None, use `training.do_train=True`.")
        return run_one_batch(train_data_loader, batch_variable)

    def run_one_batch_eval(batch_variable: BatchVariable):
        if eval_data_loader is None:
            raise ValueError("eval_data_loader is None, use `training.do_eval=True` or `training.do_inference=True`.")
        return run_one_batch(eval_data_loader, batch_variable)

    def display_one_batch(batch_variable):
        masks = batch_variable.batch_output.masks
        generated_captions = batch_variable.batch_output.generated_captions
        batch = batch_variable.batch_input
        batch_id = batch_variable.batch_id
        region_id = batch_variable.region_id

        batch_size = len(batch["images"])
        region_size = len(masks[0])
        num_mask_heads = len(masks[0][0])
        num_caption_heads = len(generated_captions[0][0])
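        # Advance the cursor so the next click shows the next region, wrapping to
        # the next image in the batch and back to the start once exhausted.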
        batch_variable.region_id = (region_id + 1) % region_size
        if batch_variable.region_id == 0:
            batch_variable.batch_id = (batch_id + 1) % batch_size
            if batch_variable.batch_id == 0:
                print("reached the end of the batch")

        if isinstance(model_args, SAMCaptionerModelArguments):
            patches = batch_variable.batch_output.patches[batch_id][region_id]
        else:
            # Other models do not return patches; keep the three patch slots empty.
            patches = [None] * 3

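        # Return an (image, annotations) pair for gr.AnnotatedImage, where the
        # annotations are one (mask, caption) entry per head plus the input box,
        # followed by a status string and the (up to three) patch images.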
        return (
            (
                batch["images"][batch_id],
                [
                    (
                        i.cpu().numpy(),
                        f"mask-{head_id}:{generated_captions[batch_id][region_id][min(head_id, num_caption_heads - 1)]}",
                    )
                    for head_id, i in enumerate(masks[batch_id][region_id])
                ]
                + [(batch["metadata_input_boxes"][batch_id][region_id].int().tolist(), "box")],
            ),
            f"batch_id={batch_id}({batch_size}), region_id={region_id}({region_size})",
            *patches,
        )

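    # Build the UI: one panel for the train split and one for the eval split, each
    # with an annotated image, three patch previews, state, and run/display buttons.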
    with gr.Blocks() as app_main:
        train_annotated_image = gr.AnnotatedImage(height=500)
        with gr.Row():
            train_patch_images = [gr.Image(height=100) for _ in range(3)]
        train_batch_output = gr.Variable(BatchVariable())

        train_run_button = gr.Button(value="Run one batch")
        train_run_button_text = gr.Textbox(lines=1, label="train_run_button_text")
        train_display_button = gr.Button(value="Display one region")
        train_display_button_text = gr.Textbox(lines=1, label="train_display_button_text")

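        # "Run one batch" runs the model and then immediately displays the first
        # region; "Display one region" steps through the remaining regions.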
        train_run_button_handle = train_run_button.click(
            run_one_batch_train, inputs=[train_batch_output], outputs=[train_run_button_text]
        )
        train_run_button_handle.then(
            display_one_batch,
            inputs=[train_batch_output],
            outputs=[train_annotated_image, train_display_button_text, *train_patch_images],
        )
        train_display_button.click(
            display_one_batch,
            inputs=[train_batch_output],
            outputs=[train_annotated_image, train_display_button_text, *train_patch_images],
        )

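        # The eval panel mirrors the train panel, wired to the eval data loader.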
        eval_annotated_image = gr.AnnotatedImage(height=500)
        with gr.Row():
            eval_patch_images = [gr.Image(height=100) for _ in range(3)]
        eval_batch_output = gr.Variable(BatchVariable())

        eval_run_button = gr.Button(value="Run one batch")
        eval_run_button_text = gr.Textbox(lines=1, label="eval_run_button_text")
        eval_display_button = gr.Button(value="Display one region")
        eval_display_button_text = gr.Textbox(lines=1, label="eval_display_button_text")

        eval_run_button_handle = eval_run_button.click(
            run_one_batch_eval, inputs=[eval_batch_output], outputs=[eval_run_button_text]
        )
        eval_run_button_handle.then(
            display_one_batch,
            inputs=[eval_batch_output],
            outputs=[eval_annotated_image, eval_display_button_text, *eval_patch_images],
        )
        eval_display_button.click(
            display_one_batch,
            inputs=[eval_batch_output],
            outputs=[eval_annotated_image, eval_display_button_text, *eval_patch_images],
        )

    app_main.launch()


if __name__ == "__main__":
    main()