| import gradio as gr | |
| import torch | |
| from carvekit.api.interface import Interface | |
| from carvekit.ml.wrap.basnet import BASNET | |
| from carvekit.ml.wrap.deeplab_v3 import DeepLabV3 | |
| from carvekit.ml.wrap.fba_matting import FBAMatting | |
| from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7 | |
| from carvekit.ml.wrap.u2net import U2NET | |
| from carvekit.pipelines.postprocessing import MattingMethod | |
| from carvekit.pipelines.preprocessing import PreprocessingStub | |
| from carvekit.trimap.generator import TrimapGenerator | |
# Run inference on the GPU when torch can see one, otherwise on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Segmentation backends selectable from the UI, keyed by the name shown in
# the "Segmentor model" dropdowns. All four models are instantiated here at
# import time, on `device`, with a batch size of 1.
segment_net = {
    "U2NET": U2NET(device=device, batch_size=1),
    "BASNET": BASNET(device=device, batch_size=1),
    "DeepLabV3": DeepLabV3(device=device, batch_size=1),
    "TracerUniversalB7": TracerUniversalB7(device=device, batch_size=1),
}

# FBA matting model; it is handed to MattingMethod below as the matting module
# of the post-processing stage.
fba = FBAMatting(device=device,
                 input_tensor_size=2048,
                 batch_size=1)

# Shared trimap generator: used by the post-processing stage and also called
# directly by the "Trimap generator" tab.
trimap = TrimapGenerator()

# Pipeline stages passed to carvekit's Interface for full background removal.
preprocessing = PreprocessingStub()
postprocessing = MattingMethod(matting_module=fba,
                               trimap_generator=trimap,
                               device=device)

# Dropdown choices: the model names, in the dict's insertion order.
method_choices = list(segment_net)
def generate_trimap(method, original):
    """Segment *original* with the chosen model and return its trimap.

    Args:
        method: Key into ``segment_net``, as selected in the
            "Segmentor model" dropdown.
        original: Input image from the ``gr.Image(type="pil")`` component.
            Gradio passes ``None`` when nothing has been uploaded.

    Returns:
        Whatever ``TrimapGenerator`` returns for the first (only) mask. This is
        presumably a PIL image, since the output component is ``type="pil"``.

    Raises:
        gr.Error: If no image was supplied or *method* is not a known model.
            The UI then shows a readable message instead of an opaque
            traceback from inside the model.
    """
    if original is None:
        raise gr.Error("Please upload an input image first.")
    seg_model = segment_net.get(method)
    if seg_model is None:
        raise gr.Error(f"Unknown segmentation model: {method!r}")
    # The segmentation models take a list of images and return a list of masks.
    masks = seg_model([original])
    return trimap(original_image=original, mask=masks[0])
def predict(method, image):
    """Remove the background of *image* using the selected segmentation model.

    Builds a carvekit ``Interface`` from the module-level ``preprocessing``
    stage, the chosen segmentation network, and the ``postprocessing``
    (FBA matting) stage. It then runs that pipeline on a single image.

    Args:
        method: Key into ``segment_net``, as selected in the
            "Segmentor model" dropdown.
        image: Input image from the ``gr.Image(type="pil")`` component.
            Gradio passes ``None`` when nothing has been uploaded.

    Returns:
        The first (and only) image produced by the pipeline. This is
        presumably a PIL image with the background removed.

    Raises:
        gr.Error: If no image was supplied or *method* is not a known model.
            The UI then shows a readable message instead of an opaque
            traceback.
    """
    if image is None:
        raise gr.Error("Please upload an input image first.")
    seg_model = segment_net.get(method)
    if seg_model is None:
        raise gr.Error(f"Unknown segmentation model: {method!r}")
    # Pass the same device the models were created on, rather than letting
    # Interface fall back to its own default.
    interface = Interface(pre_pipe=preprocessing,
                          post_pipe=postprocessing,
                          seg_pipe=seg_model,
                          device=device)
    return interface([image])[0]
# HTML footer shown at the bottom of the page (rendered with gr.HTML): the
# CarveKit logo and a link back to the upstream CarveKit repository.
footer = """
<center>
<img src='https://raw.githubusercontent.com/leonelhs/image-background-remove-tool/master/docs/imgs/logo.png' alt='CarveKit' width="200" height="80">
<br>
<b>
Demo based on <a href='https://github.com/OPHoperHPO/image-background-remove-tool'>CarveKit</a>
</b>
</center>
"""
# Segmentation model pre-selected in both tabs' dropdowns.
DEFAULT_METHOD = "TracerUniversalB7"

with gr.Blocks(title="CarveKit") as app:
    gr.Markdown("<center><h1><b>CarveKit</b></h1></center>")
    gr.HTML("<center><h3>High-quality image background removal</h3></center>")
    with gr.Tabs():
        # Tab 1: full background removal (segmentation + matting) via predict().
        with gr.TabItem("Remove background", id=0):
            with gr.Row(equal_height=False):
                with gr.Column():
                    input_img = gr.Image(type="pil", label="Input image")
                    drp_itf = gr.Dropdown(
                        value=DEFAULT_METHOD,
                        label="Segmentor model",
                        choices=method_choices)
                    run_btn = gr.Button(variant="primary")
                with gr.Column():
                    output_img = gr.Image(type="pil", label="result")
            run_btn.click(predict, [drp_itf, input_img], [output_img])
        # Tab 2: show only the intermediate trimap via generate_trimap().
        # This tab's dropdown gets its own name so it is not confused with
        # (or silently rebind) the first tab's dropdown.
        with gr.TabItem("Trimap generator", id=1):
            with gr.Row(equal_height=False):
                with gr.Column():
                    trimap_input = gr.Image(type="pil", label="Input image")
                    trimap_drp_itf = gr.Dropdown(
                        value=DEFAULT_METHOD,
                        label="Segmentor model",
                        choices=method_choices)
                    trimap_btn = gr.Button(variant="primary")
                with gr.Column():
                    trimap_output = gr.Image(type="pil", label="result")
            trimap_btn.click(generate_trimap,
                             [trimap_drp_itf, trimap_input],
                             [trimap_output])
    with gr.Row():
        gr.HTML(footer)

# Enable Gradio's request queue before launching the local server.
app.queue()
app.launch(share=False, debug=True, show_error=True)