Spaces:
Runtime error
Runtime error
| import os | |
| import sys | |
| import numpy as np | |
| from tqdm import tqdm | |
| from einops import rearrange, reduce | |
| from utils.io_utils import * | |
| from utils.cv import img_alpha_blending | |
| import sam3 | |
| from PIL import Image | |
| from sam3 import build_sam3_image_model | |
| from sam3.model.box_ops import box_xywh_to_cxcywh | |
| from sam3.model.sam3_image_processor import Sam3Processor | |
| from sam3.visualization_utils import draw_box_on_image, normalize_bbox, plot_results | |
| from live2d.scrap_model import VALID_BODY_PARTS_V2 | |
| import torch | |
# Batch script: for every image returned by load_exec_list(src), prompt SAM3
# with each tag in VALID_BODY_PARTS_V2, turn the union of the returned masks
# into an alpha channel, write one RGBA layer per tag, and finally write the
# alpha-blended reconstruction of all layers.

# Turn on TensorFloat-32 for Ampere GPUs.
# https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Use bfloat16 autocast for everything that follows. The context manager is
# entered by hand and never exited, so it stays active until the process ends
# (this includes the model construction below).
torch.autocast("cuda", dtype=torch.bfloat16).__enter__()

sam3_root = os.path.join(os.path.dirname(sam3.__file__), "..")
bpe_path = f"{sam3_root}/sam3/assets/bpe_simple_vocab_16e6.txt.gz"
model = build_sam3_image_model(
    bpe_path=bpe_path,
    checkpoint_path='/home/jlin/repos/live2d_parsing/local_gitclones/sam3/sam3.pt',
)

save_dir = 'workspace/datasets/l2deval_sam3_ouput'
src = ''
exec_list = load_exec_list(src)

# The processor only wraps the model and a fixed confidence threshold, so it
# is built once here instead of once per image.
processor = Sam3Processor(model, confidence_threshold=0.5)

for image_path in tqdm(exec_list):
    # Close the underlying file as soon as the RGB copy has been made.
    with Image.open(image_path) as opened:
        image = opened.convert('RGB')
    # RGB pixels are identical for every tag; convert to an array only once.
    rgb = np.array(image)
    inference_state = processor.set_image(image)

    # One output directory per input image, named after the file stem.
    stem = os.path.splitext(os.path.basename(image_path))[0]
    saved = os.path.join(save_dir, stem)
    os.makedirs(saved, exist_ok=True)

    img_list = []
    for tag in VALID_BODY_PARTS_V2:
        # The 'handwear' tag is queried with the text 'arms,hands'; every
        # other tag is used verbatim as the text prompt.
        prompt = 'arms,hands' if tag == 'handwear' else tag

        processor.reset_all_prompts(inference_state)
        inference_state = processor.set_text_prompt(state=inference_state, prompt=prompt)

        # Collapse all returned masks into a single (h, w) map: a pixel is
        # set if any mask covers it.
        mask = reduce(inference_state['masks'], 'b c h w -> h w', 'any')
        mask = mask.to(device='cpu', dtype=torch.float32).numpy()
        alpha = mask.astype(np.uint8) * 255

        # RGBA layer: the original colours with the mask as the alpha channel.
        tag_img = np.concatenate([rgb, alpha[..., None]], axis=2)

        # Write the per-tag layer (the path was previously computed but the
        # image was never saved).
        Image.fromarray(tag_img).save(os.path.join(saved, tag + '.png'))
        img_list.append(tag_img)

    final_size = (image.height, image.width)
    recon = img_alpha_blending(img_list, final_size=final_size)
    Image.fromarray(recon).save(os.path.join(saved, 'reconstruction.png'))