{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import sys\n", "\n", "BASE_DIR = \"../../\"\n", "sys.path.append(BASE_DIR)\n", "\n", "import gradio as gr\n", "from src.models.sam_captioner import SAMCaptionerConfig, SAMCaptionerModel, SAMCaptionerProcessor\n", "import torch\n", "from PIL import Image\n", "import requests\n", "import numpy as np\n", "import time\n", "from transformers import CLIPProcessor, CLIPModel\n", "\n", "\n", "import logging\n", "import os\n", "\n", "import hydra\n", "from hydra.utils import instantiate\n", "from datasets import Dataset, load_dataset, IterableDataset, concatenate_datasets, interleave_datasets\n", "from omegaconf import DictConfig, OmegaConf\n", "from src.data.transforms import SamCaptionerDataTransform, SCADataTransform\n", "from src.data.collator import SamCaptionerDataCollator, SCADataCollator\n", "from src.arguments import (\n", " Arguments,\n", " global_setup,\n", " SAMCaptionerModelArguments,\n", " SCAModelBaseArguments,\n", " SCAModelArguments,\n", " SCADirectDecodingModelArguments,\n", ")\n", "from src.models.sam_captioner import SAMCaptionerConfig, SAMCaptionerModel, SAMCaptionerProcessor\n", "from src.sca_seq2seq_trainer import SCASeq2SeqTrainer\n", "from src.models.sca import ScaModel, ScaConfig, ScaProcessor, ScaDirectDecodingModel\n", "from src.integrations import CustomWandbCallBack, EvaluateFirstStepCallback\n", "import src.models.sca\n", "\n", "from transformers.trainer_utils import _re_checkpoint\n", "from transformers import set_seed\n", "import json\n", "from src.train import prepare_datasets, prepare_model, prepare_data_transform, prepare_processor\n", "from hydra import initialize, compose\n", "import subprocess\n", "import dotenv\n", "\n", "logger = logging.getLogger(__name__)\n", "\n", "model = None\n", "processor = None" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", "dtype = torch.float16" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# CKPT_PATH=\n", "# python scripts/apps/sca_app.py \\\n", "# +model=base_sca_multitask_v2 \\\n", "# model.model_name_or_path=$CKPT_PATH \\\n", "# model.lm_head_model_name_or_path=$(python scripts/tools/get_sub_model_name_from_ckpt.py $CKPT_PATH \"lm\")\n", "def get_lm_head_name(cmd_script_path, cmd_ckpt_path):\n", "\n", "\n", " command = f'python {cmd_script_path} {cmd_ckpt_path} \"lm\"'\n", " # Use subprocess to run the command and capture the output\n", " process = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)\n", "\n", " # Get the output\n", " stdout, stderr = process.communicate()\n", "\n", " # Decode the output from bytes to string\n", " stdout = stdout.decode('utf-8').strip()\n", " stderr = stderr.decode('utf-8').strip()\n", " if stderr != '':\n", " raise Exception(stderr)\n", "\n", " return stdout\n", "\n", "cmd_script_path = \"scripts/tools/get_sub_model_name_from_ckpt.py\"\n", "cmd_ckpt_path = \"amlt/sca-weights.111823/finetune-gpt2_large-lr_1e_4-1xlr-lsj-bs_1-pretrain_1e_4_no_lsj_bs_32.111223.rr1-4x8-v100-32g-pre/checkpoint-100000\"\n", "cmd_model = \"base_sca_multitask_v2\"\n", "\n", "cmd_script_path = os.path.join(BASE_DIR, cmd_script_path)\n", "cmd_ckpt_path = os.path.join(BASE_DIR, cmd_ckpt_path)\n", "cmd_lm_head_model_name_or_path = get_lm_head_name(cmd_script_path, cmd_ckpt_path)\n", "\n", "with initialize(version_base=\"1.3\", config_path=\"../../src/conf\"):\n", " args = compose(\n", " config_name=\"conf\",\n", " overrides=[\n", " f\"+model={cmd_model}\",\n", " f\"model.model_name_or_path={cmd_ckpt_path}\",\n", " f\"model.lm_head_model_name_or_path={cmd_lm_head_model_name_or_path}\"\n", " ],\n", " )\n", "\n", "\n", "args, training_args, model_args = global_setup(args)\n", "\n", "# Set seed before initializing model.\n", "set_seed(args.training.seed)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# NOTE(xiaoke): load sas_key from .env for huggingface model downloading.\n", "logger.info(f\"Try to load sas_key from .env file: {dotenv.load_dotenv('.env')}.\")\n", "use_auth_token = os.getenv(\"USE_AUTH_TOKEN\", False)\n", "\n", "processor = prepare_processor(model_args, use_auth_token)\n", "\n", "image_mean, image_std = (\n", " processor.sam_processor.image_processor.image_mean,\n", " processor.sam_processor.image_processor.image_std,\n", ")\n", "\n", "model = prepare_model(model_args, use_auth_token)\n", "model = model.to(device, dtype)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "img_url = \"https://raw.githubusercontent.com/facebookresearch/segment-anything/main/notebooks/images/truck.jpg\"\n", "input_image = Image.open(requests.get(img_url, stream=True).raw)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "input_points = [[[[0, 0]], [[0, 200]], [[200, 200]], [[200, 0]]]]\n", "input_boxes = None\n", "\n", "inputs = processor(input_image, input_points=input_points, input_boxes=input_boxes, return_tensors=\"pt\")\n", "for k, v in inputs.items():\n", " if isinstance(v, torch.Tensor):\n", " # NOTE(xiaoke): in original clip, dtype is float16\n", " inputs[k] = v.to(device, dtype if v.dtype == torch.float32 else v.dtype)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "multimask_output = False\n", "tic = time.perf_counter()\n", "with torch.inference_mode():\n", " model_outputs = model.generate(\n", " **inputs,\n", " multimask_output=multimask_output,\n", " pad_token_id=processor.tokenizer.eos_token_id,\n", " num_beams=3,\n", " # max_new_tokens=20,\n", " # return_patches=return_patches,\n", " # return_dict_in_generate=True,\n", " )\n", "toc = time.perf_counter()\n", "print(f\"Time taken: {(toc - tic)*1000:0.4f} ms\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "batch_size, num_masks, num_text_heads, num_tokens = model_outputs.sequences.shape\n", "batch_size_, num_masks, num_mask_heads, *_ = model_outputs.pred_masks.shape\n", "\n", "masks = processor.post_process_masks(\n", " model_outputs.pred_masks, inputs[\"original_sizes\"], inputs[\"reshaped_input_sizes\"]\n", " ) # List[(num_masks, num_heads, H, W)]\n", "iou_scores = model_outputs.iou_scores # (batch_size, num_masks, num_heads)\n", "captions = processor.tokenizer.batch_decode(\n", " model_outputs.sequences.reshape(-1, num_tokens), skip_special_tokens=True\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import amcg\n", "\n", "generator = amcg.ScaAutomaticMaskCaptionGenerator(model, processor)\n", "np_input_image = np.array(input_image)\n", "outputs = generator.generate(np_input_image)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import torch\n", "import matplotlib.pyplot as plt\n", "import cv2\n", "\n", "\n", "def show_anns(anns):\n", " if len(anns) == 0:\n", " return\n", " sorted_anns = sorted(anns, key=(lambda x: x[\"area\"]), reverse=True)\n", " ax = plt.gca()\n", " ax.set_autoscale_on(False)\n", "\n", " img = np.ones((sorted_anns[0][\"segmentation\"].shape[0], sorted_anns[0][\"segmentation\"].shape[1], 4))\n", " img[:, :, 3] = 0\n", " for ann in sorted_anns:\n", " m = ann[\"segmentation\"]\n", " color_mask = np.concatenate([np.random.random(3), [0.35]])\n", " img[m] = color_mask\n", " if \"caption\" in ann:\n", " captions: str = ann[\"caption\"]\n", " # calculate the centroid of the mask\n", " y, x = np.where(m)\n", " random_index = np.random.choice(range(len(x)))\n", " random_position = (x[random_index], y[random_index])\n", " # display the caption at the centroid of the mask\n", " ax.text(*random_position, captions, color=\"white\", fontsize=12, ha=\"center\", va=\"center\")\n", " ax.imshow(img)\n", "\n", "\n", "plt.figure(figsize=(20, 20))\n", "plt.imshow(input_image)\n", "show_anns(outputs)\n", "plt.axis(\"off\")\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plt.figure(figsize=(20,20))\n", "plt.imshow(input_image)\n", "plt.axis('off')\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "input_image.size" ] } ], "metadata": { "kernelspec": { "display_name": "sca-v2", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.18" } }, "nbformat": 4, "nbformat_minor": 2 }