In [None]:
import sys

BASE_DIR = "../../"
sys.path.append(BASE_DIR)

import gradio as gr
from src.models.sam_captioner import SAMCaptionerConfig, SAMCaptionerModel, SAMCaptionerProcessor
import torch
from PIL import Image
import requests
import numpy as np
import time
from transformers import CLIPProcessor, CLIPModel


import logging
import os

import hydra
from hydra.utils import instantiate
from datasets import Dataset, load_dataset, IterableDataset, concatenate_datasets, interleave_datasets
from omegaconf import DictConfig, OmegaConf
from src.data.transforms import SamCaptionerDataTransform, SCADataTransform
from src.data.collator import SamCaptionerDataCollator, SCADataCollator
from src.arguments import (
 Arguments,
 global_setup,
 SAMCaptionerModelArguments,
 SCAModelBaseArguments,
 SCAModelArguments,
 SCADirectDecodingModelArguments,
)
from src.models.sam_captioner import SAMCaptionerConfig, SAMCaptionerModel, SAMCaptionerProcessor
from src.sca_seq2seq_trainer import SCASeq2SeqTrainer
from src.models.sca import ScaModel, ScaConfig, ScaProcessor, ScaDirectDecodingModel
from src.integrations import CustomWandbCallBack, EvaluateFirstStepCallback
import src.models.sca

from transformers.trainer_utils import _re_checkpoint
from transformers import set_seed
import json
from src.train import prepare_datasets, prepare_model, prepare_data_transform, prepare_processor
from hydra import initialize, compose
import subprocess
import dotenv

# Notebook-wide logger, named after the running module.
logger = logging.getLogger(__name__)

# Placeholders: both names are rebound further down once the processor and
# the model have actually been loaded.
model, processor = None, None

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision is only selected together with CUDA. On the CPU fallback the
# model stays in float32, because float16 support in PyTorch's CPU kernels is
# incomplete and slow.
dtype = torch.float16 if device == "cuda" else torch.float32

In [None]:
# CKPT_PATH=
# python scripts/apps/sca_app.py \
# +model=base_sca_multitask_v2 \
# model.model_name_or_path=$CKPT_PATH \
# model.lm_head_model_name_or_path=$(python scripts/tools/get_sub_model_name_from_ckpt.py $CKPT_PATH "lm")
def get_lm_head_name(cmd_script_path, cmd_ckpt_path):
    """Run a helper script and return what it prints for the "lm" sub-model.

    The script is executed as ``<python> cmd_script_path cmd_ckpt_path lm`` and
    its standard output, stripped of surrounding whitespace, is returned.

    Args:
        cmd_script_path: Path of the Python script to execute.
        cmd_ckpt_path: Checkpoint path forwarded to the script as its first
            argument.

    Returns:
        The script's standard output decoded as UTF-8 and stripped.

    Raises:
        RuntimeError: If the script exits with a non-zero status. The message
            carries the exit status and the captured standard error.
    """
    # Passing an argument list (no shell) keeps paths that contain spaces or
    # shell metacharacters intact. sys.executable runs the script with the
    # interpreter backing this kernel instead of whichever `python` is on PATH.
    completed = subprocess.run(
        [sys.executable, str(cmd_script_path), str(cmd_ckpt_path), "lm"],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )

    # Decode the output from bytes to string
    stdout = completed.stdout.decode("utf-8").strip()
    stderr = completed.stderr.decode("utf-8").strip()

    # Success is decided by the exit status: warnings written to stderr by a
    # script that succeeded must not be treated as a failure, and a failing
    # script must not go unnoticed just because it wrote nothing to stderr.
    if completed.returncode != 0:
        raise RuntimeError(
            f"{cmd_script_path} exited with status {completed.returncode}: {stderr}"
        )
    if stderr:
        # Keep diagnostics visible without aborting.
        print(stderr, file=sys.stderr)

    return stdout

# Helper script and fine-tuned checkpoint, both resolved against BASE_DIR.
cmd_script_path = os.path.join(BASE_DIR, "scripts/tools/get_sub_model_name_from_ckpt.py")
cmd_ckpt_path = os.path.join(
    BASE_DIR,
    "amlt/sca-weights.111823/finetune-gpt2_large-lr_1e_4-1xlr-lsj-bs_1-pretrain_1e_4_no_lsj_bs_32.111223.rr1-4x8-v100-32g-pre/checkpoint-100000",
)
cmd_model = "base_sca_multitask_v2"

# Ask the helper script which language-model head belongs to the checkpoint.
cmd_lm_head_model_name_or_path = get_lm_head_name(cmd_script_path, cmd_ckpt_path)

# Compose the Hydra config with the three overrides of the
# `scripts/apps/sca_app.py` invocation: model group, checkpoint and LM head.
with initialize(version_base="1.3", config_path="../../src/conf"):
    args = compose(
        config_name="conf",
        overrides=[
            "+model={}".format(cmd_model),
            "model.model_name_or_path={}".format(cmd_ckpt_path),
            "model.lm_head_model_name_or_path={}".format(cmd_lm_head_model_name_or_path),
        ],
    )

# global_setup hands back the config together with the training and model
# argument groups; `args` is rebound to the returned config.
args, training_args, model_args = global_setup(args)

# Set seed before initializing model.
set_seed(args.training.seed)

In [None]:
# NOTE(xiaoke): the .env file is read first so that its variables (sas_key,
# auth token) are in the environment for the Hugging Face download below.
dotenv_loaded = dotenv.load_dotenv('.env')
logger.info(f"Try to load sas_key from .env file: {dotenv_loaded}.")
use_auth_token = os.getenv("USE_AUTH_TOKEN", False)

processor = prepare_processor(model_args, use_auth_token)

# Normalisation statistics exposed by the SAM image processor.
sam_image_processor = processor.sam_processor.image_processor
image_mean = sam_image_processor.image_mean
image_std = sam_image_processor.image_std

# Build the model, then move it to the selected device and precision.
model = prepare_model(model_args, use_auth_token)
model = model.to(device, dtype)

In [None]:
# Download the demo image used throughout the rest of the notebook.
img_url = "https://raw.githubusercontent.com/facebookresearch/segment-anything/main/notebooks/images/truck.jpg"
# The timeout keeps the cell from hanging forever on a stalled connection, and
# raise_for_status turns an HTTP error into a clear exception instead of a
# confusing image-decoding failure.
with requests.get(img_url, stream=True, timeout=30) as response:
    response.raise_for_status()
    # convert() decodes the whole image while the connection is still open
    # (Image.open alone is lazy) and yields a three-channel RGB image.
    input_image = Image.open(response.raw).convert("RGB")

In [None]:
# Four prompts of one (x, y) point each; no box prompts are used.
input_points = [[[[0, 0]], [[0, 200]], [[200, 200]], [[200, 0]]]]
input_boxes = None

inputs = processor(input_image, input_points=input_points, input_boxes=input_boxes, return_tensors="pt")

# Move every tensor to the target device. Only float32 tensors are cast to
# `dtype`; tensors of any other dtype keep their own.
for k, v in inputs.items():
    if not isinstance(v, torch.Tensor):
        continue
    target_dtype = dtype if v.dtype == torch.float32 else v.dtype
    inputs[k] = v.to(device, target_dtype)

In [None]:
multimask_output = False

# Generation options, collected once so the timed section holds only the call.
generation_kwargs = dict(
    multimask_output=multimask_output,
    pad_token_id=processor.tokenizer.eos_token_id,
    num_beams=3,
    # max_new_tokens=20,
    # return_patches=return_patches,
    # return_dict_in_generate=True,
)

# Wall-clock time of one generation pass, run under inference mode.
tic = time.perf_counter()
with torch.inference_mode():
    model_outputs = model.generate(**inputs, **generation_kwargs)
toc = time.perf_counter()

elapsed_ms = (toc - tic) * 1000
print(f"Time taken: {elapsed_ms:0.4f} ms")

In [None]:
# Unpack the leading dimensions of the generation outputs. `sequences` must be
# 4-D for this unpacking to succeed; only the first three dimensions of
# `pred_masks` are named, the rest are collected into `_`.
batch_size, num_masks, num_text_heads, num_tokens = model_outputs.sequences.shape
batch_size_, num_masks, num_mask_heads, *_ = model_outputs.pred_masks.shape

# Post-process the predicted masks using the original and reshaped input sizes
# that the processor stored in `inputs`.
masks = processor.post_process_masks(
 model_outputs.pred_masks, inputs["original_sizes"], inputs["reshaped_input_sizes"]
 ) # List[(num_masks, num_heads, H, W)]
iou_scores = model_outputs.iou_scores # (batch_size, num_masks, num_heads)
# Flatten all leading dimensions so every row is one token sequence, then
# decode each row to text without special tokens.
captions = processor.tokenizer.batch_decode(
 model_outputs.sequences.reshape(-1, num_tokens), skip_special_tokens=True
)

In [None]:
import amcg

# Wrap the loaded model and processor in the automatic mask-caption generator
# provided by `amcg`.
generator = amcg.ScaAutomaticMaskCaptionGenerator(model, processor)
# The generator is fed a NumPy array built from the PIL image.
np_input_image = np.array(input_image)
# `outputs` is later handed to show_anns, which reads "area", "segmentation"
# and, when present, "caption" from each entry.
outputs = generator.generate(np_input_image)

In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2


def show_anns(anns):
    """Overlay annotation masks, and their captions when present, on the current axes.

    Every mask is painted with a random colour at 0.35 opacity onto a
    transparent RGBA overlay, which is then drawn on ``plt.gca()``. A caption is
    written at a randomly chosen pixel of its mask (not at the centroid), so the
    text always lands inside the mask.

    Args:
        anns: Sequence of dicts. Each entry must provide "area" (used for
            ordering) and "segmentation", a 2-D mask used to index the overlay
            (expected to be boolean; confirm against the generator). An
            optional "caption" entry is drawn as text.

    Returns:
        None. Nothing is drawn when ``anns`` is empty.
    """
    if len(anns) == 0:
        return
    # Largest masks are painted first so that smaller ones end up on top.
    sorted_anns = sorted(anns, key=(lambda x: x["area"]), reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)

    # Fully transparent RGBA canvas with the size of the first mask.
    height, width = sorted_anns[0]["segmentation"].shape[:2]
    overlay = np.ones((height, width, 4))
    overlay[:, :, 3] = 0
    for ann in sorted_anns:
        mask = ann["segmentation"]
        overlay[mask] = np.concatenate([np.random.random(3), [0.35]])
        if "caption" not in ann:
            continue
        ys, xs = np.where(mask)
        if len(xs) == 0:
            # An empty mask has no pixel to anchor the caption on; sampling
            # from it would raise ValueError, so the caption is skipped.
            continue
        # Anchor the caption on a random pixel that belongs to the mask.
        pick = np.random.choice(len(xs))
        ax.text(xs[pick], ys[pick], ann["caption"], color="white", fontsize=12, ha="center", va="center")
    ax.imshow(overlay)


# Draw the image, then let show_anns paint masks and captions on the same
# (current) axes.
fig, ax = plt.subplots(figsize=(20, 20))
ax.imshow(input_image)
show_anns(outputs)
ax.axis("off")
plt.show()

In [None]:
# The untouched input image, for comparison with the annotated figure.
fig, ax = plt.subplots(figsize=(20, 20))
ax.imshow(input_image)
ax.axis('off')
plt.show()

In [None]:
# PIL reports the image size as a (width, height) tuple in pixels.
input_image.size