deepspeed / scripts /apps /sca_outputs_visualization_app.py
xingzhikb's picture
init
002bd9b
import sys
sys.path.append(".")
import logging
import os
from typing import Optional
import hydra
import torch
from omegaconf import DictConfig, OmegaConf
from src.arguments import (
global_setup,
SAMCaptionerModelArguments,
)
from transformers.trainer_utils import get_last_checkpoint
from transformers import set_seed
import gradio as gr
from dataclasses import dataclass
import numpy as np
from src.train import prepare_datasets, prepare_model, prepare_data_transform, prepare_processor, prepare_collate_fn
import dotenv
logger = logging.getLogger(__name__)
@hydra.main(version_base="1.3", config_path="../../src/conf", config_name="conf")
def main(args: DictConfig) -> None:
# NOTE(xiaoke): follow https://github.com/huggingface/transformers/blob/main/examples/pytorch/image-classification/run_image_classification.py
logger.info(OmegaConf.to_yaml(args))
args, training_args, model_args = global_setup(args)
# Detecting last checkpoint.
last_checkpoint = None
if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
last_checkpoint = get_last_checkpoint(training_args.output_dir)
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
logger.warning(
f"Output directory ({training_args.output_dir}) already exists and is not empty. "
"There is no checkpoint in the directory. Or we can resume from `resume_from_checkpoint`."
)
elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
logger.info(
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
)
# Set seed before initializing model.
set_seed(args.training.seed)
# Initialize our dataset and prepare it
train_dataset, eval_dataset = prepare_datasets(args)
# NOTE(xiaoke): load sas_key from .env for huggingface model downloading.
logger.info(f"Try to load sas_key from .env file: {dotenv.load_dotenv('.env')}.")
use_auth_token = os.getenv("USE_AUTH_TOKEN", False)
processor = prepare_processor(model_args, use_auth_token)
train_dataset, eval_dataset = prepare_data_transform(
training_args, model_args, train_dataset, eval_dataset, processor
)
if len(eval_dataset) > 1:
raise ValueError(f"Only support one eval dataset, but got {len(eval_dataset)}. args: {args.eval_data}")
eval_dataset = next(iter(eval_dataset.values()))
collate_fn = prepare_collate_fn(training_args, model_args, processor)
# Load the accuracy metric from the datasets package
# metric = evaluate.load("accuracy")
# Define our compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a
# predictions and label_ids field) and has to return a dictionary string to float.
# def compute_metrics(p):
# """Computes accuracy on a batch of predictions"""
# return metric.compute(predictions=np.argmax(p.predictions, axis=1), references=p.label_ids)
compute_metrics = None
# config = AutoConfig.from_pretrained(
# model_args.config_name or model_args.model_name_or_path,
# num_labels=len(labels),
# label2id=label2id,
# id2label=id2label,
# finetuning_task="image-classification",
# cache_dir=model_args.cache_dir,
# revision=model_args.model_revision,
# use_auth_token=True if model_args.use_auth_token else None,
# )
# model = AutoModelForImageClassification.from_pretrained(
# model_args.model_name_or_path,
# from_tf=bool(".ckpt" in model_args.model_name_or_path),
# config=config,
# cache_dir=model_args.cache_dir,
# revision=model_args.model_revision,
# use_auth_token=True if model_args.use_auth_token else None,
# ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
# )
# image_processor = AutoImageProcessor.from_pretrained(
# model_args.image_processor_name or model_args.model_name_or_path,
# cache_dir=model_args.cache_dir,
# revision=model_args.model_revision,
# use_auth_token=True if model_args.use_auth_token else None,
# )
model = prepare_model(model_args, use_auth_token)
def cycle(iterable):
while True:
for x in iterable:
yield x
if training_args.do_train:
train_data_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=training_args.per_device_train_batch_size, collate_fn=collate_fn
)
train_data_loader = cycle(train_data_loader)
else:
train_data_loader = None
if training_args.do_eval or training_args.do_inference:
eval_data_loader = torch.utils.data.DataLoader(
eval_dataset, batch_size=training_args.per_device_eval_batch_size, collate_fn=collate_fn
)
eval_data_loader = cycle(eval_data_loader)
else:
eval_data_loader = None
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)
dtype = model.dtype
@dataclass
class BatchVariable:
batch_input: Optional[dict] = None
batch_output: Optional[dict] = None
batch_id: int = 0
region_id: int = 0
@torch.no_grad()
def run_one_batch(data_loader, batch_variable: BatchVariable):
batch = next(data_loader)
for k, v in batch.items():
if isinstance(v, torch.Tensor):
if torch.is_floating_point(v):
batch[k] = v.to(device, dtype)
else:
batch[k] = v.to(device)
with torch.inference_mode():
if isinstance(model_args, SAMCaptionerModelArguments):
model_outputs = model.generate(**batch, return_patches=True, return_dict_in_generate=True)
else:
model_outputs = model.generate(**batch)
# add masks to model_outputs
original_sizes = batch["original_sizes"]
reshaped_input_sizes = batch["reshaped_input_sizes"]
pred_masks = model_outputs.pred_masks
masks = processor.post_process_masks(pred_masks, original_sizes, reshaped_input_sizes)
model_outputs.masks = masks
# add generated_captions to model_outputs
batch_size, region_size, num_heads, num_tokens = model_outputs.sequences.shape
generated_captions = processor.tokenizer.batch_decode(
model_outputs.sequences.view(-1, num_tokens), skip_special_tokens=True
)
generated_captions = (
np.array(generated_captions, dtype=object).reshape(batch_size, region_size, num_heads).tolist()
)
model_outputs.generated_captions = generated_captions
batch_variable.batch_input = batch
batch_variable.batch_output = model_outputs
return f"finished running one batch, batch_size={len(batch['images'])}, region_size={len(masks[0])}"
def run_one_batch_train(batch_variable: BatchVariable):
if train_data_loader is None:
raise ValueError("train_data_loader is None, use `training.do_train=True`.")
return run_one_batch(train_data_loader, batch_variable)
def run_one_batch_eval(batch_variable: BatchVariable):
if eval_data_loader is None:
raise ValueError("eval_data_loader is None, use `training.do_eval=True` or `training.do_inference=True`.")
return run_one_batch(eval_data_loader, batch_variable)
def display_one_batch(batch_variable):
masks = batch_variable.batch_output.masks
generated_captions = batch_variable.batch_output.generated_captions
batch = batch_variable.batch_input
batch_id = batch_variable.batch_id
region_id = batch_variable.region_id
batch_size = len(batch["images"])
region_size = len(masks[0])
num_mask_heads = len(masks[0][0])
num_caption_heads = len(generated_captions[0][0])
batch_variable.region_id = (region_id + 1) % region_size
if batch_variable.region_id == 0:
batch_variable.batch_id = (batch_id + 1) % batch_size
if batch_variable.batch_id == 0:
print("reached the end of the batch")
if isinstance(model_args, SAMCaptionerModelArguments):
patches = batch_variable.batch_output.patches[batch_id][region_id]
else:
# NOTE: This will lead to no images displayed.
patches = [None] * 3
# Tuple[numpy.ndarray | PIL.Image | str, List[Tuple[numpy.ndarray | Tuple[int, int, int, int], str]]]
# NOTE: repeat the captions if there are less than 3 heads
# NOTE: shape is list of list of obj, (batch, region, head)
return (
(
batch["images"][batch_id],
[
(
i.cpu().numpy(),
f"mask-{head_id}:{generated_captions[batch_id][region_id][min(head_id, num_caption_heads - 1)]}",
)
for head_id, i in enumerate(masks[batch_id][region_id])
]
+ [(batch["metadata_input_boxes"][batch_id][region_id].int().tolist(), "box")],
),
f"batch_id={batch_id}({batch_size}), region_id={region_id}({region_size})",
*patches,
)
with gr.Blocks() as app_main:
train_annotated_image = gr.AnnotatedImage(height=500)
with gr.Row():
train_patch_images = [gr.Image(height=100) for _ in range(3)]
train_batch_output = gr.Variable(BatchVariable())
train_run_button = gr.Button(value="Run one batch")
train_run_button_text = gr.Textbox(lines=1, label="train_run_button_text")
train_display_button = gr.Button(value="Display one region")
train_display_button_text = gr.Textbox(lines=1, label="train_display_button")
train_run_button_handle = train_run_button.click(
run_one_batch_train, inputs=[train_batch_output], outputs=[train_run_button_text]
)
train_run_button_handle.then(
display_one_batch,
inputs=[train_batch_output],
outputs=[train_annotated_image, train_display_button_text, *train_patch_images],
)
train_display_button.click(
display_one_batch,
inputs=[train_batch_output],
outputs=[train_annotated_image, train_display_button_text, *train_patch_images],
)
eval_annotated_image = gr.AnnotatedImage(height=500)
with gr.Row():
eval_patch_images = [gr.Image(height=100) for _ in range(3)]
eval_batch_output = gr.Variable(BatchVariable())
eval_run_button = gr.Button(value="Run one batch")
eval_run_button_text = gr.Textbox(lines=1, label="eval_run_button")
eval_display_button = gr.Button(value="Display one region")
eval_display_button_text = gr.Textbox(lines=1, label="eval_display_button")
eval_run_button_handle = eval_run_button.click(
run_one_batch_eval, inputs=[eval_batch_output], outputs=[eval_run_button_text]
)
eval_run_button_handle.then(
display_one_batch,
inputs=[eval_batch_output],
outputs=[eval_annotated_image, eval_display_button_text, *eval_patch_images],
)
eval_display_button.click(
display_one_batch,
inputs=[eval_batch_output],
outputs=[eval_annotated_image, eval_display_button_text, *eval_patch_images],
)
app_main.launch()
if __name__ == "__main__":
main()