# 24yearsold's picture
# update: add ComfyUI Node Extension mention to description
# b55a1fc verified
import os
import sys
import numpy as np
from tqdm import tqdm
from einops import rearrange, reduce
from utils.io_utils import *
from utils.cv import img_alpha_blending
import sam3
from PIL import Image
from sam3 import build_sam3_image_model
from sam3.model.box_ops import box_xywh_to_cxcywh
from sam3.model.sam3_image_processor import Sam3Processor
from sam3.visualization_utils import draw_box_on_image, normalize_bbox, plot_results
from live2d.scrap_model import VALID_BODY_PARTS_V2
import torch
# turn on tfloat32 for Ampere GPUs
# https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# use bfloat16 for the entire notebook
torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
# Locate the BPE vocabulary bundled with the installed sam3 package.
sam3_root = os.path.join(os.path.dirname(sam3.__file__), "..")
bpe_path = f"{sam3_root}/sam3/assets/bpe_simple_vocab_16e6.txt.gz"

# Checkpoint location: override with $SAM3_CHECKPOINT; otherwise fall back to the
# original author-local path.
checkpoint_path = os.environ.get(
    'SAM3_CHECKPOINT',
    '/home/jlin/repos/live2d_parsing/local_gitclones/sam3/sam3.pt',
)
model = build_sam3_image_model(bpe_path=bpe_path, checkpoint_path=checkpoint_path)

# Usage: python <script> [SRC] [SAVE_DIR]
# Output root; the loop below creates one sub-directory per input image here.
save_dir = sys.argv[2] if len(sys.argv) > 2 else 'workspace/datasets/l2deval_sam3_ouput'
# Source passed to load_exec_list to obtain the image paths to process.
# NOTE(review): the default '' is kept for compatibility; how load_exec_list
# handles an empty source is not visible here — confirm.
src = sys.argv[1] if len(sys.argv) > 1 else ''
exec_list = load_exec_list(src)
# The processor only wraps the (shared) model and a fixed confidence threshold,
# so build it once instead of once per image.
processor = Sam3Processor(model, confidence_threshold=0.5)

for image_path in tqdm(exec_list):
    image = Image.open(image_path).convert('RGB')
    rgb = np.array(image)  # (H, W, 3) uint8; reused for every tag below
    inference_state = processor.set_image(image)

    # One output directory per input image, named after the file's stem.
    saved = os.path.join(save_dir, os.path.splitext(os.path.basename(image_path))[0])
    os.makedirs(saved, exist_ok=True)

    img_list = []
    for tag in VALID_BODY_PARTS_V2:
        # 'handwear' is queried with the text prompt 'arms,hands' instead of its tag name.
        prompt = 'arms,hands' if tag == 'handwear' else tag
        processor.reset_all_prompts(inference_state)
        inference_state = processor.set_text_prompt(state=inference_state, prompt=prompt)

        # Union ('any') over the b and c axes of the predicted masks -> one (h, w) mask.
        mask = reduce(inference_state['masks'], 'b c h w -> h w', 'any').to(
            device='cpu', dtype=torch.float32
        ).numpy()
        alpha = mask.astype(np.uint8) * 255
        tag_img = np.concatenate([rgb, alpha[..., None]], axis=2)  # RGBA cut-out

        # Write the per-tag cut-out; previously its path was computed but never saved.
        Image.fromarray(tag_img).save(os.path.join(saved, tag + '.png'))
        img_list.append(tag_img)

    # Blend all per-tag RGBA layers with img_alpha_blending and save the result.
    recon = img_alpha_blending(img_list, final_size=(image.height, image.width))
    Image.fromarray(recon).save(os.path.join(saved, 'reconstruction.png'))