# Hugging Face Space application file (2,649 bytes).
# Space status shown on the page when this file was captured: "Runtime error".
import numpy as np
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import normalize
from huggingface_hub import hf_hub_download
import gradio as gr
from gradio_imageslider import ImageSlider
from briarmbg import BriaRMBG
import PIL
from PIL import Image
from typing import Tuple
# Build the segmentation network and fetch the published RMBG-1.4 weights
# from the Hugging Face Hub (hf_hub_download returns a local file path).
net = BriaRMBG()
# model_path = "./model1.pth"
model_path = hf_hub_download("briaai/RMBG-1.4", 'model.pth')

# Always deserialize onto the CPU first: without map_location, torch.load
# restores tensors to the device they were saved from, which fails when that
# device is not present on this machine. The module is moved to the GPU
# afterwards when one is available, so the final placement is the same as
# before on both paths.
# NOTE(review): torch.load unpickles the checkpoint; only load files from a
# trusted repository.
net.load_state_dict(torch.load(model_path, map_location="cpu"))
if torch.cuda.is_available():
    net = net.cuda()

# Evaluation mode for inference.
net.eval()
def resize_image(image):
    """Return an RGB copy of ``image`` scaled to the model input size.

    Args:
        image: Object exposing PIL-style ``convert`` and ``resize`` methods
            (``process`` passes a ``PIL.Image.Image``).

    Returns:
        The result of resizing the RGB-converted image to 1024x1024 with
        bilinear resampling. The target size is fixed, so the source aspect
        ratio is not preserved.
    """
    target_size = (1024, 1024)
    rgb_image = image.convert('RGB')
    return rgb_image.resize(target_size, Image.BILINEAR)
def process(image):
    """Cut the foreground out of ``image`` using the module-level ``net``.

    Args:
        image: Array accepted by ``Image.fromarray`` (presumably the numpy
            array Gradio supplies for an "image" input -- confirm against
            the Interface configuration).

    Returns:
        An RGBA PIL image with the same size as the input: the original
        pixels pasted onto a fully transparent canvas through the predicted
        mask.

    Raises:
        gr.Error: If ``image`` is ``None`` (nothing was provided).
    """
    if image is None:
        # Surface a readable message instead of failing inside Image.fromarray.
        raise gr.Error("Please upload an image first.")

    # prepare input
    orig_image = Image.fromarray(image)
    w, h = orig_image.size
    model_input = resize_image(orig_image)
    im_np = np.array(model_input)
    # HWC uint8 -> 1xCxHxW float, scaled to [0, 1] and shifted by the 0.5 mean.
    im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
    im_tensor = torch.unsqueeze(im_tensor, 0)
    im_tensor = torch.divide(im_tensor, 255.0)
    im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
    if torch.cuda.is_available():
        im_tensor = im_tensor.cuda()

    # inference -- no autograd graph is needed, so avoid building one
    with torch.no_grad():
        result = net(im_tensor)

        # post process: bring the first output back to the original size
        result = torch.squeeze(
            F.interpolate(result[0][0], size=(h, w), mode='bilinear'), 0
        )
        ma = torch.max(result)
        mi = torch.min(result)
        # Min-max normalise; the clamp keeps a constant mask from turning
        # into 0/0 = NaN before the uint8 cast.
        span = torch.clamp(ma - mi, min=1e-8)
        result = (result - mi) / span

    # image to pil
    im_array = (result * 255).detach().cpu().numpy().astype(np.uint8)
    # Drop only the leading axis: a bare np.squeeze would also remove the
    # height or width axis of a 1-pixel-tall or 1-pixel-wide image.
    pil_im = Image.fromarray(np.squeeze(im_array, axis=0))

    # paste the mask on the original image
    new_im = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
    new_im.paste(orig_image, mask=pil_im)
    return new_im
# Header text for the Interface. The standalone gr.Markdown / gr.HTML
# components previously created here were instantiated outside any gr.Blocks
# context, so Gradio never rendered them; title/description below already
# carry the same text and are what the page actually shows.
title = "Remove Background from Image for Free"
description = r"""Remove Background from Image for Free<br>
For upload your image and wait some seconds.
"""

# Example inputs offered in the UI; the path is relative to the working
# directory the app is started from.
examples = [['./input.jpg'],]

# Alternative before/after slider output, kept for reference:
# output = ImageSlider(position=0.5,label='Image without background', type="pil", show_download_button=True)
# demo = gr.Interface(fn=process,inputs="image", outputs=output, examples=examples, title=title, description=description)
demo = gr.Interface(
    fn=process,
    inputs="image",
    outputs="image",
    examples=examples,
    title=title,
    description=description,
)

if __name__ == "__main__":
    # share=False: do not request a public share link.
    demo.launch(share=False)