| | import time |
| | import gc |
| | import torch |
| |
|
| | from PIL import Image |
| | from torchvision import transforms |
| | import gradio as gr |
| |
|
| | from transformers import AutoConfig, AutoModelForImageSegmentation |
| |
|
| | |
| |
|
def load_model(model_name="zhengpeng7/BiRefNet_lite"):
    """Load the BiRefNet segmentation model and move it to the best available device.

    Args:
        model_name: Hugging Face Hub repo id of the checkpoint (remote code is
            trusted). Defaults to the lite BiRefNet checkpoint this app uses.

    Returns:
        ``(model, device)``: the model in eval mode, already moved to
        ``device`` (``"cuda"`` when available, otherwise ``"cpu"``).
    """
    config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    config.is_encoder_decoder = False

    def dummy_text_config(*args, **kwargs):
        # Stand-in for the config's ``get_text_config``: this model has no text
        # sub-config. Presumably transformers only needs ``tie_word_embeddings``
        # from it (the sole attribute provided here) — TODO confirm. Any
        # arguments are accepted so upstream signature changes
        # (``decoder=...``, ``encoder=...``) do not break loading.
        class DummyTextConfig:
            tie_word_embeddings = False

        return DummyTextConfig()

    # Set on the instance, so it is called without an implicit ``self``.
    config.get_text_config = dummy_text_config

    model = AutoModelForImageSegmentation.from_pretrained(
        model_name,
        config=config,
        trust_remote_code=True,
    )

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    model.eval()
    return model, device
| |
|
| | |
# Preprocessing applied to every input image: resize to a fixed 1024x1024,
# convert to a float tensor, and normalize with the standard ImageNet
# mean/std values.
image_size = (1024, 1024)
transform_image = transforms.Compose(
    [
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(
            [0.485, 0.456, 0.406],
            [0.229, 0.224, 0.225],
        ),
    ]
)

# Load the model once at import time; handlers pass it to run_inference().
birefnet, device = load_model()
| |
|
def run_inference(images, model, device):
    """Cut out the foreground of every image in a single batched forward pass.

    Args:
        images: list of RGB PIL images.
        model: segmentation model. The last element of its output
            (``output[-1]``) is treated as mask logits — presumably BiRefNet's
            final prediction map; confirm against the model code.
        device: device the model lives on; the input batch is moved there.

    Returns:
        List of RGBA images, one per input and at the input's original size.
        Each is the input pasted through the predicted mask onto a transparent
        canvas. An empty ``images`` list yields ``[]``.

    Raises:
        torch.OutOfMemoryError: if the batch does not fit. The batch tensor is
            dropped before re-raising. Callers should retry outside their
            ``except`` clause so the traceback's frames can be freed as well.
    """
    if not images:
        # torch.stack() rejects an empty list, and there is nothing to do.
        return []

    original_sizes = [img.size for img in images]
    input_tensor = torch.stack([transform_image(img) for img in images]).to(device)
    try:
        with torch.no_grad():
            output = model(input_tensor)
            preds = output[-1].sigmoid().cpu()
            # Drop the on-device outputs now instead of at function exit, so
            # the empty_cache() call below can actually release them.
            del output
    except torch.OutOfMemoryError:
        del input_tensor
        torch.cuda.empty_cache()
        raise
    # The batch is no longer needed once predictions are on the CPU.
    del input_tensor

    to_pil = transforms.ToPILImage()
    results = []
    for img, pred, size in zip(images, preds, original_sizes):
        mask = to_pil(pred.squeeze()).resize(size)
        result = Image.new("RGBA", size, (0, 0, 0, 0))
        result.paste(img, mask=mask)
        results.append(result)

    del preds
    gc.collect()
    torch.cuda.empty_cache()
    return results
| |
|
| | def binary_search_max(images): |
| | low, high = 1, len(images) |
| | best, best_count = None, 0 |
| |
|
| | while low <= high: |
| | mid = (low + high) // 2 |
| | batch = images[:mid] |
| | try: |
| | |
| | global birefnet, device |
| | birefnet, device = load_model() |
| | res = run_inference(batch, birefnet, device) |
| | best, best_count = res, mid |
| | low = mid + 1 |
| | except torch.OutOfMemoryError: |
| | high = mid - 1 |
| |
|
| | return best, best_count |
| |
|
def extract_objects(filepaths):
    """Gradio handler: remove the background from every uploaded image.

    The whole upload is first tried as one batch. On
    ``torch.OutOfMemoryError``, a binary search finds the largest leading
    subset of the upload that fits. Only those results are returned, with a
    hint about the workable batch size.

    Args:
        filepaths: file paths from the ``gr.Files`` input. ``None`` or an empty
            list is treated as "nothing uploaded".

    Returns:
        ``(results, summary)``: a list of RGBA images and a timing/status message.
    """
    if not filepaths:
        return [], "No images uploaded."

    images = []
    for path in filepaths:
        # ``with`` closes the file handle; convert() returns an independent copy.
        with Image.open(path) as opened:
            images.append(opened.convert("RGB"))

    start_time = time.perf_counter()
    oom = False
    try:
        results = run_inference(images, birefnet, device)
    except torch.OutOfMemoryError:
        # Recover below, outside the except clause. While the exception is
        # being handled, its traceback keeps the failed forward pass's frames,
        # and the device tensors they reference, alive. That would starve
        # the fallback attempts of memory.
        oom = True

    if not oom:
        total_time = time.perf_counter() - start_time
        summary = (
            f"Total request time: {total_time:.2f}s\n"
            f"Processed {len(images)} images successfully."
        )
        return results, summary

    initial_attempt_time = time.perf_counter() - start_time
    gc.collect()
    torch.cuda.empty_cache()

    best, best_count = binary_search_max(images)
    total_time = time.perf_counter() - start_time

    if best is None:
        summary = (
            f"Initial attempt OOM after {initial_attempt_time:.2f}s.\n"
            f"Could not process even a single image.\n"
            f"Total time including fallback attempts: {total_time:.2f}s."
        )
        return [], summary

    summary = (
        f"Initial attempt OOM after {initial_attempt_time:.2f}s.\n"
        f"Found that {best_count} images can be processed without OOM.\n"
        f"Total time including fallback attempts: {total_time:.2f}s.\n"
        f"Next time, try using up to {best_count} images."
    )
    return best, summary
| |
|
# Web UI: multi-file upload in; a gallery of cut-outs plus a timing summary out.
iface = gr.Interface(
    fn=extract_objects,
    inputs=gr.Files(
        label="Upload Multiple Images",
        type="filepath",
        file_count="multiple",
    ),
    outputs=[
        gr.Gallery(label="Processed Images"),
        gr.Textbox(label="Timing Info"),
    ],
    title="BiRefNet Bulk Background Removal (with fallback)",
    # The fallback processes only the largest leading subset that fits and
    # drops the rest, so the description says exactly that.
    description=(
        "Upload multiple images. If the batch runs out of memory, only the "
        "largest leading subset that fits is processed, and the summary "
        "reports how many images to upload next time."
    ),
)

if __name__ == "__main__":
    # Start the server only when run as a script, not on import.
    iface.launch()
| |
|