| import spaces |
| import gradio as gr |
| import torch |
| import os |
| import sys |
| from loadimg import load_img |
| from ben_base import BEN_Base |
| import random |
| import huggingface_hub |
| import numpy as np |
|
|
def set_random_seed(seed):
    """Seed every random number generator used by this app.

    The same value is handed to Python's ``random`` module, NumPy and
    PyTorch (the default generator plus the CUDA generators). cuDNN is
    also switched to deterministic mode with benchmarking turned off.

    Args:
        seed: Seed value forwarded unchanged to each seeding function.
    """
    # cuDNN flags: request deterministic algorithms, disable benchmark mode.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)
|
|
# Global, one-time setup: seed the RNGs and pick the float32 matmul precision
# before the model is constructed.
set_random_seed(9)
torch.set_float32_matmul_precision("high")

# Build the network, then ask the Hugging Face Hub for a local path to the
# BEN2 checkpoint file.
model = BEN_Base()
model_path = huggingface_hub.hf_hub_download(
    repo_id="PramaLLC/BEN2", filename="BEN2_Base.pth"
)

# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Load the weights, move the model to the chosen device, switch to eval mode.
model.loadcheckpoints(model_path)
model.to(device)
model.eval()
|
|
# Directory that receives the PNG written by the image-processing callback.
output_folder = 'output_images'
# exist_ok=True makes the creation idempotent and removes the window between
# "check that it is missing" and "create it" in which another process could
# create the directory and make os.makedirs raise FileExistsError.
os.makedirs(output_folder, exist_ok=True)
|
|
def fn(image):
    """Remove the background from one image and save the result as a PNG.

    Args:
        image: Any input accepted by ``load_img``; here it is the value
            delivered by the Gradio image component.

    Returns:
        A ``(result_image, image_path)`` tuple: the object returned by
        ``process`` and the path of the PNG file it was saved to.
    """
    rgb_input = load_img(image, output_type="pil").convert("RGB")
    foreground = process(rgb_input)

    # NOTE: the file name is fixed, so every call writes to the same path and
    # replaces whatever an earlier call stored there.
    saved_to = os.path.join(output_folder, "foreground.png")
    foreground.save(saved_to)
    return foreground, saved_to
|
|
|
|
@spaces.GPU
def process_video(video_path):
    """Run video segmentation on ``video_path`` and return the result path.

    Args:
        video_path: Path of the uploaded video, forwarded to
            ``model.segment_video``.

    Returns:
        The fixed string ``"./foreground.mp4"``.
    """
    model.segment_video(video_path)
    # NOTE(review): the returned path is never handed to segment_video, so
    # this presumably relies on segment_video writing "./foreground.mp4" by
    # default -- confirm against BEN_Base.
    return "./foreground.mp4"
|
|
@spaces.GPU
def process(image):
    """Run the model on a single image and return its output.

    Args:
        image: Image passed straight to ``model.inference``.

    Returns:
        Whatever ``model.inference`` returns; callers in this file call
        ``.save(path)`` on it.
    """
    result = model.inference(image)
    # Diagnostic output: log the concrete type produced by the model.
    print(type(result))
    return result
|
|
def process_file(f):
    """Remove the background from an image file and save a PNG next to it.

    The output path is the input path with its extension replaced by
    ``.png`` (or with ``.png`` appended when the input has no extension).

    Args:
        f: Path of the input image file.

    Returns:
        The path of the PNG file that was written.
    """
    # os.path.splitext only treats a dot in the final path component as the
    # extension separator. The previous f.rsplit(".", 1) cut at the last dot
    # anywhere in the path, so an extension-less file inside a directory such
    # as "cache.v2/" lost part of its path, and a dotfile became just ".png".
    stem, _ = os.path.splitext(f)
    name_path = stem + ".png"

    rgb_input = load_img(f, output_type="pil").convert("RGB")
    transparent = process(rgb_input)
    transparent.save(name_path)
    return name_path
|
|
| |
# Input components handed to the two interfaces defined further down.
image = gr.Image(label="Upload an image")
video = gr.Video(label="Upload a video")

# Example input: "image.jpg" located in the same directory as this script,
# loaded as a PIL image.
current_dir = os.path.abspath(os.path.dirname(__file__))
image_path = os.path.join(current_dir, "image.jpg")
examples = load_img(image_path, output_type="pil")
|
|
| |
# Image tab: displays the result returned by `fn` and offers the saved PNG
# as a downloadable file.
tab1 = gr.Interface(
    fn=fn,
    inputs=image,
    outputs=[gr.Image(label="Result Foreground"), gr.File(label="Download PNG")],
    examples=[examples],
    api_name="image",
)
|
|
| |
# Video tab: feeds the uploaded video to `process_video` and plays back the
# file whose path it returns.
tab2 = gr.Interface(
    fn=process_video,
    inputs=video,
    outputs=gr.Video(label="Result Video"),
    title="Video Processing (experimental)",
    description="Note: For ZeroGPU timeout, videos are limited to processing the first 100 frames only.",
    api_name="video",
)
|
|
| |
# Combine both interfaces into a single app with one tab per interface.
demo = gr.TabbedInterface(
    interface_list=[tab1, tab2],
    tab_names=["Image Processing", "Video Processing"],
    title="BEN2 for background removal. Download the image/video for higher quality foreground.",
)


# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch(show_error=True)