| import os |
| import sys |
| import argparse |
| import time |
| import traceback |
| from pathlib import Path |
| from PIL import Image, ImageOps |
| import torch |
| from torchvision import transforms |
|
|
| |
| from transformers import configuration_utils |
# Capture the stock implementation before replacing it, so the wrapper can delegate to it.
_original_get_text_config = configuration_utils.PretrainedConfig.get_text_config


def _patched_get_text_config(self, *args, **kwargs):
    """Call the original ``PretrainedConfig.get_text_config`` with a guaranteed
    ``is_encoder_decoder`` attribute.

    If this config instance has no ``is_encoder_decoder`` attribute, it is set
    to ``False`` on the instance first; all arguments are forwarded unchanged
    and the original method's result is returned as-is.
    """
    # NOTE(review): presumably guards configs (e.g. remote-code ones) that never
    # define this attribute while the stock method reads it — confirm.
    needs_default = not hasattr(self, "is_encoder_decoder")
    if needs_default:
        setattr(self, "is_encoder_decoder", False)
    return _original_get_text_config(self, *args, **kwargs)


setattr(configuration_utils.PretrainedConfig, "get_text_config", _patched_get_text_config)
| |
|
|
| |
import functools

# Keep the unpatched factory; the wrapper below delegates to it.
_orig_linspace = torch.linspace


@functools.wraps(_orig_linspace)
def _patched_linspace(*args, **kwargs):
    """Drop-in replacement for ``torch.linspace`` that never returns a meta tensor.

    The call is first forwarded unchanged. If the result lives on the meta
    device (e.g. because it was created under a meta-device context), it is
    recomputed with ``device="cpu"`` so callers get real values. Any other
    result is returned untouched.
    """
    # NOTE(review): presumably needed because model code uses linspace while
    # weights are being initialised on the meta device — confirm.
    t = _orig_linspace(*args, **kwargs)
    if t.is_meta:
        return _orig_linspace(*args, **{**kwargs, "device": "cpu"})
    return t


# functools.wraps keeps __name__/__doc__ and exposes the original via __wrapped__,
# so the global replacement stays introspectable.
torch.linspace = _patched_linspace
| |
|
|
| |
def patch_birefnet_tied_weights():
    """Install a fallback ``all_tied_weights_keys`` property on ``PreTrainedModel``.

    Reading the attribute returns a value explicitly assigned to the instance,
    otherwise ``self._tied_weights_keys``; a missing or falsy result becomes
    ``{}``. The property also has a setter: a read-only property on the class
    would make any ``model.all_tied_weights_keys = ...`` assignment raise
    ``AttributeError``.

    Best effort: any failure (e.g. transformers not importable) is printed and
    otherwise ignored.
    """
    try:
        from transformers import PreTrainedModel

        # Explicitly assigned values are kept in the instance __dict__ under
        # this private key (the property owns the public name).
        override_attr = "_all_tied_weights_keys_override"

        def _get_all_tied_weights_keys(self):
            state = vars(self)
            if override_attr in state:
                value = state[override_attr]
            else:
                value = getattr(self, "_tied_weights_keys", {})
            return value or {}

        def _set_all_tied_weights_keys(self, value):
            vars(self)[override_attr] = value

        PreTrainedModel.all_tied_weights_keys = property(
            _get_all_tied_weights_keys, _set_all_tied_weights_keys
        )
        print("Applied robust BiRefNet tied weights patch")

    except Exception as e:
        print(f"Failed to apply BiRefNet tied weights patch: {e}")


patch_birefnet_tied_weights()
| |
|
|
| from transformers import AutoModelForImageSegmentation, AutoConfig |
| import retouch |
|
|
| |
# devicetorch supplies the device helpers used below (get / to / empty_cache);
# abort early with a hint if it cannot be imported.
try:
    import devicetorch
except ImportError:
    print("Error: 'devicetorch' not found. Please run this script from the project root or install requirements.")
    raise SystemExit(1)
|
|
| |
# Lower-case file suffixes (leading dot included) that main() treats as input images.
ALLOWED_EXTENSIONS = set(".png .jpg .jpeg .gif .webp .bmp".split())
|
|
def setup_model():
    """Load the RMBG-2.0 segmentation model and prepare it for inference.

    The device is chosen by ``devicetorch.get``. On CPU the thread count is
    raised to the number of cores and dynamic int8 quantization of
    ``torch.nn.Linear`` layers is attempted (skipped with a message on failure).

    Returns:
        tuple: ``(model, device)`` — the model in eval mode, moved with
        ``devicetorch.to``, and the device value reported by devicetorch.

    Exits the process with status 1 if the config or weights cannot be loaded.
    """
    print("Loading BRIA-RMBG-2.0 model...")
    repo_id = "cocktailpeanut/rm"

    device = devicetorch.get(torch)
    print(f"Device: {device}")

    # NOTE(review): assumes devicetorch.get returns a plain string like 'cpu' — confirm.
    on_cpu = device == 'cpu'
    if on_cpu:
        # Use every available core for CPU inference.
        torch.set_num_threads(max(1, os.cpu_count() or 1))

    try:
        print("Loading model config...")
        config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)

        model = AutoModelForImageSegmentation.from_pretrained(
            repo_id,
            config=config,
            trust_remote_code=True,
            # NOTE(review): presumably keeps weights off the meta device during
            # init (see the torch.linspace patch above) — confirm.
            low_cpu_mem_usage=False,
        )
        model = devicetorch.to(torch, model)
        model.eval()
    except Exception as e:
        print(f"Error loading model: {e}")
        traceback.print_exc()
        sys.exit(1)

    if on_cpu:
        print("Applying Dynamic Quantization for CPU speedup...")
        try:
            model = torch.quantization.quantize_dynamic(
                model, {torch.nn.Linear}, dtype=torch.qint8
            )
        except Exception as e:
            # Quantization is only an optional speedup: keep the float model,
            # but say so instead of silently implying it was applied.
            print(f"Dynamic quantization failed, continuing without it: {e}")

    return model, device
|
|
def get_transform():
    """Build the preprocessing pipeline fed to the segmentation model.

    The image is resized to a fixed 1024x1024, converted to a tensor, and
    normalized per channel with the widely used ImageNet mean/std values.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    pipeline = [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ]
    return transforms.Compose(pipeline)
|
|
def remove_background(model, image, transform):
    """Cut out the foreground of a single image.

    Args:
        model: segmentation model. Its output, or the last element when it
            returns a list/tuple, is passed through a sigmoid to get the mask.
        image: PIL image accepted by *transform* (main() passes RGB images).
        transform: preprocessing pipeline, e.g. the one from get_transform().

    Returns:
        A copy of *image* whose alpha channel is the predicted mask, resized
        back to the image's original dimensions.
    """
    width_height = image.size

    # Add a batch dimension and move the tensor to the device devicetorch selects.
    batch = devicetorch.to(torch, transform(image).unsqueeze(0))

    with torch.inference_mode():
        raw = model(batch)
        logits = raw[-1] if isinstance(raw, (list, tuple)) else raw
        probs = logits.sigmoid().cpu()

    # First item of the batch, singleton dims dropped, turned into a PIL mask
    # and scaled back up/down to the source size.
    alpha = transforms.ToPILImage()(probs[0].squeeze()).resize(width_height)

    cutout = image.copy()
    cutout.putalpha(alpha)

    devicetorch.empty_cache(torch)
    return cutout
|
|
def retouch_face(image, sensitivity=3.0, tone_smoothing=0.6):
    """Apply ``retouch.retouch_image_pil`` to *image* and log the outcome.

    Args:
        image: image passed straight through to the retouch helper.
        sensitivity: forwarded as the helper's second positional argument.
        tone_smoothing: forwarded as the helper's third positional argument.

    Returns:
        The retouched image on success, or *image* unchanged if the helper
        raised (best effort: the error is printed, not propagated).
    """
    # perf_counter is monotonic and high-resolution, unlike the wall clock.
    start = time.perf_counter()
    try:
        retouched_img, count = retouch.retouch_image_pil(image, sensitivity, tone_smoothing)
    except Exception as e:
        print(f"RETOUCH: Failed | Error: {e}")
        return image
    duration = (time.perf_counter() - start) * 1000
    print(f"RETOUCH: Success | Blemishes: {count} | Time: {duration:.1f}ms")
    return retouched_img
|
|
def _list_images(folder):
    """Return regular files in *folder* with an allowed suffix, sorted by path."""
    return sorted(
        p for p in folder.iterdir()
        if p.is_file() and p.suffix.lower() in ALLOWED_EXTENSIONS
    )


def main():
    """Command-line entry point: remove the background of every image in a folder.

    Reads images from ``--input`` and writes ``<stem>_rmbg.png`` files to
    ``--output`` (created if missing). A failure on one file is reported and
    the batch continues. Exits with status 1 if the input is missing or is
    not a folder.
    """
    parser = argparse.ArgumentParser(description="Batch Background Removal Tool")
    parser.add_argument('--input', '-i', required=True, help="Input folder containing images")
    parser.add_argument('--output', '-o', required=True, help="Output folder for processed images")
    args = parser.parse_args()

    input_path = Path(args.input)
    output_path = Path(args.output)

    if not input_path.exists():
        print(f"Error: Input folder '{input_path}' does not exist.")
        sys.exit(1)
    if not input_path.is_dir():
        # iterdir() below would otherwise fail with NotADirectoryError.
        print(f"Error: Input path '{input_path}' is not a folder.")
        sys.exit(1)

    output_path.mkdir(parents=True, exist_ok=True)

    model, _ = setup_model()
    transform = get_transform()

    # Sorted so progress output and overwrite order are deterministic.
    files = _list_images(input_path)
    total = len(files)

    print(f"\nFound {total} images. Starting processing...")
    print("-" * 50)

    failed = 0
    start_time = time.perf_counter()
    for idx, file_path in enumerate(files, 1):
        try:
            filename = file_path.name
            print(f"[{idx}/{total}] Processing {filename}...", end='', flush=True)

            # Close the source file promptly; exif_transpose/convert yield an
            # independent in-memory image that outlives the with-block.
            with Image.open(file_path) as src:
                img = ImageOps.exif_transpose(src).convert('RGB')

            result = remove_background(model, img, transform)

            # NOTE: inputs differing only by extension (a.jpg / a.png) map to
            # the same output file; the later one in sorted order overwrites.
            out_file = output_path / (file_path.stem + "_rmbg.png")
            result.save(out_file, "PNG")

            print(" Done.")

        except Exception as e:
            failed += 1
            print(f" Failed! Error: {e}")

    duration = time.perf_counter() - start_time
    print("-" * 50)
    print(f"Finished! Processed {total} images in {duration:.2f} seconds.")
    if failed:
        print(f"{failed} of {total} images failed.")
    print(f"Output saved to: {output_path.absolute()}")


if __name__ == "__main__":
    main()
|
|