| import gradio as gr |
| from gradio_image_prompter import ImagePrompter |
| from detectron2.config import LazyConfig, instantiate |
| from detectron2.checkpoint import DetectionCheckpointer |
| import cv2 |
| import numpy as np |
| import torch |
| from huggingface_hub import hf_hub_download |
| import spaces |
|
|
# Choose the compute device once at import time: prefer CUDA when available.
_cuda_ok = torch.cuda.is_available()
DEVICE = 'cuda' if _cuda_ok else 'cpu'
print(f"Is CUDA available: {_cuda_ok}")
if _cuda_ok:
    print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
| |
# Download the pretrained SEMat checkpoint for each supported backbone
# variant up front; maps variant name -> local checkpoint path.
model_choice = {
    variant: hf_hub_download(repo_id="XiaRho/SEMat", filename=f"SEMat_{variant}.pth", repo_type="model")
    for variant in ('SAM', 'HQ-SAM', 'SAM2')
}
|
|
def load_model(model_type='HQ-SAM'):
    """Instantiate a SEMat model of the given variant and load its weights.

    Args:
        model_type: one of the keys of ``model_choice`` ('SAM', 'HQ-SAM',
            'SAM2'), selecting both the LazyConfig file and the downloaded
            checkpoint.

    Returns:
        (model, model_type): the model in eval mode on ``DEVICE``, and the
        variant name actually loaded.

    Raises:
        ValueError: if ``model_type`` is not a known variant.
    """
    # Validate explicitly: `assert` is stripped under `python -O`, which
    # would let an unknown key fall through to a confusing KeyError below.
    if model_type not in model_choice:
        raise ValueError(f"Unknown model type {model_type!r}; expected one of {list(model_choice)}")

    config_path = './configs/SEMat_{}.py'.format(model_type)
    cfg = LazyConfig.load(config_path)

    # Disable the SAM backbone's own checkpoint path (the attribute name
    # differs between configs): the full weights come from the SEMat
    # checkpoint loaded below instead.
    if hasattr(cfg.model.sam_model, 'ckpt_path'):
        cfg.model.sam_model.ckpt_path = None
    else:
        cfg.model.sam_model.checkpoint = None

    model = instantiate(cfg.model)
    # LoRA adapters must be attached before the checkpoint is loaded so the
    # checkpoint's LoRA weights have matching parameters to land in.
    if model.lora_rank is not None:
        model.init_lora()
    model.to(DEVICE)
    DetectionCheckpointer(model).load(model_choice[model_type])
    model.eval()
    return model, model_type
|
|
def transform_image_bbox(prompts):
    """Convert an ImagePrompter payload into SEMat's fixed 1024x1024 input.

    Args:
        prompts: dict with "image" (H x W x 3 array) and "points" (exactly
            one box, encoded as ``[x1, y1, 2, x2, y2, 3]`` — labels 2/3 mark
            the two box corners; other labels mean a point was drawn).

    Returns:
        (img, bbox, (ori_H, ori_W), (new_H, new_W)) where ``img`` is a CHW
        float32 array in [0, 1], resized so the longer side is 1024 and
        zero-padded bottom/right to 1024x1024; ``bbox`` is a (1, 4) float
        array in the resized frame, clipped to [0, 1023].

    Raises:
        gr.Error: if the prompt is not exactly one box.
    """
    if len(prompts["points"]) != 1:
        raise gr.Error("Please input only one BBox.", duration=5)
    # Unpack once; the original code unpacked twice and the first int
    # conversion was dead (overwritten by the scaled values below).
    [[x1, y1, label_a, x2, y2, label_b]] = prompts["points"]
    if label_a != 2 or label_b != 3:
        raise gr.Error("Please input BBox instead of point.", duration=5)

    img = prompts["image"]
    ori_H, ori_W, _ = img.shape

    # Scale so the longer side becomes 1024, rounding to the nearest pixel.
    scale = 1024 * 1.0 / max(ori_H, ori_W)
    new_W = int(ori_W * scale + 0.5)
    new_H = int(ori_H * scale + 0.5)

    img = cv2.resize(img, (new_W, new_H), interpolation=cv2.INTER_LINEAR)
    # Zero-pad to a square canvas; the image occupies the top-left corner.
    padding = np.zeros([1024, 1024, 3], dtype=img.dtype)
    padding[:new_H, :new_W, :] = img
    img = padding

    # HWC uint8 -> CHW float32 in [0, 1].
    img = img.transpose((2, 0, 1)).astype(np.float32) / 255.0

    # Map the box corners (raw prompt floats) into the resized frame and
    # clamp to the 1024x1024 canvas.
    x1, y1, x2, y2 = (int(v * scale + 0.5) for v in (x1, y1, x2, y2))
    bbox = np.clip(np.array([[x1, y1, x2, y2]]) * 1.0, 0, 1023.0)

    return img, bbox, (ori_H, ori_W), (new_H, new_W)
|
|
|
|
|
|
|
|
if __name__ == '__main__':

    # Eagerly load the default variant (HQ-SAM) so the first request is fast.
    model, model_type = load_model()

    @spaces.GPU
    def inference_image(prompts, input_model_type):
        """Run SEMat matting on a user-prompted image.

        prompts: ImagePrompter payload with "image" (H x W x 3 array) and
        "points" (one box, ``[x1, y1, 2, x2, y2, 3]``).
        input_model_type: a key of ``model_choice``.
        Returns an H x W x 3 uint8 alpha matte (single channel replicated
        to 3 channels), cropped to the un-padded resized size.
        """
        # The demo keeps a single live model; swap it in-place when the
        # dropdown selection differs from what is currently loaded.
        global model_type
        global model

        if input_model_type != model_type:
            gr.Info('Loading SEMat of {} version.'.format(input_model_type), duration=5)
            _model, _ = load_model(input_model_type)
            model_type = input_model_type
            model = _model

        # Preprocess to the fixed 1024x1024 padded CHW input plus scaled box.
        image, bbox, ori_H_W, pad_H_W = transform_image_bbox(prompts)
        input_data = {
            'image': torch.from_numpy(image)[None].to(model.device),
            'bbox': torch.from_numpy(bbox)[None].to(model.device),
        }

        with torch.no_grad():
            inputs = model.preprocess_inputs(input_data)
            images, bbox, gt_alpha, trimap, condition = inputs['images'], inputs['bbox'], inputs['alpha'], inputs['trimap'], inputs['condition']

            # Build the conditioning signal the backbone expects: a learned
            # embedding, the raw bbox, or nothing, depending on model config.
            if model.backbone_condition:
                condition_proj = model.condition_embedding(condition)
            elif model.backbone_bbox_prompt is not None or model.bbox_prompt_all_block is not None:
                condition_proj = bbox
            else:
                condition_proj = None

            low_res_masks, pred_alphas, pred_trimap, sam_hq_matting_token = model.forward_samhq_and_matting_decoder(images, bbox, condition_proj)

        # Crop the zero-padding away and replicate the 1-channel matte to RGB
        # for display. NOTE(review): the matte is returned at the resized
        # resolution (pad_H_W), not the original ori_H_W, which is unused
        # here — presumably Gradio rescales for display; confirm.
        output_alpha = np.uint8(pred_alphas[0, 0][:pad_H_W[0], :pad_H_W[1], None].repeat(1, 1, 3).cpu().numpy() * 255)

        return output_alpha

    # --- Gradio UI: box-prompt image input, model selector, run button. ---
    with gr.Blocks() as demo:

        with gr.Row():
            with gr.Column(scale=45):
                img_in = ImagePrompter(type='numpy', show_label=False, label="query image")

            with gr.Column(scale=45):
                img_out = gr.Image(type='pil', label="output")

        with gr.Row():
            with gr.Column(scale=45):
                input_model_type = gr.Dropdown(list(model_choice.keys()), value='HQ-SAM', label="Trained SEMat Version")

            with gr.Column(scale=45):
                bt = gr.Button()

        bt.click(inference_image, inputs=[img_in, input_model_type], outputs=[img_out])

    demo.launch()
|
|
|
|