# id-making/id-maker/core/process_images.py
# Author: Esmaill1
# Initial commit: Combined ID Maker and CodeFormer for HF Spaces (commit 96f4ff4)
import os
import sys
import argparse
import time
import traceback
from pathlib import Path
from PIL import Image, ImageOps
import torch
from torchvision import transforms
# ---- Monkeypatch for transformers 4.50+ compatibility with custom Config classes ----
from transformers import configuration_utils

_orig_get_text_config = configuration_utils.PretrainedConfig.get_text_config


def _compat_get_text_config(self, *args, **kwargs):
    """Backfill ``is_encoder_decoder`` before delegating to the stock impl.

    Remote-code Config classes may omit the attribute, which newer
    transformers versions read unconditionally inside get_text_config.
    """
    if not hasattr(self, 'is_encoder_decoder'):
        self.is_encoder_decoder = False
    return _orig_get_text_config(self, *args, **kwargs)


configuration_utils.PretrainedConfig.get_text_config = _compat_get_text_config
# ---- End Monkeypatch ----
# ---- Monkeypatch for BiRefNet/RMBG-2.0 meta-tensor bug during initialization ----
_orig_linspace = torch.linspace


def _linspace_cpu_fallback(*args, **kwargs):
    """torch.linspace wrapper: redo the call on CPU if the result is a meta tensor.

    A meta tensor carries no data, which breaks model-init code that expects
    real values; forcing ``device="cpu"`` yields a concrete allocation.
    """
    result = _orig_linspace(*args, **kwargs)
    if not result.is_meta:
        return result
    return _orig_linspace(*args, **{**kwargs, "device": "cpu"})


torch.linspace = _linspace_cpu_fallback
# ---- End Monkeypatch ----
# ---- Monkeypatch for BiRefNet tied weights compatibility with transformers 4.50+ ----
def patch_birefnet_tied_weights():
    """Make ``PreTrainedModel.all_tied_weights_keys`` always produce a dict.

    BiRefNet models can leave ``_tied_weights_keys`` unset or None, which
    trips newer transformers code that expects a mapping. Best-effort: any
    failure (including transformers being absent) is logged, not raised.
    """
    try:
        from transformers import PreTrainedModel

        def _tied_keys_as_dict(self):
            # Coerce a missing/None attribute into an empty mapping.
            return getattr(self, "_tied_weights_keys", {}) or {}

        PreTrainedModel.all_tied_weights_keys = property(_tied_keys_as_dict)
        print("Applied robust BiRefNet tied weights patch")
    except Exception as e:
        print(f"Failed to apply BiRefNet tied weights patch: {e}")


patch_birefnet_tied_weights()
# ---- End Monkeypatch ----
from transformers import AutoModelForImageSegmentation, AutoConfig
import retouch
# Try to import devicetorch (from your project dependencies)
# This is a hard requirement: without it we cannot pick a device or move
# tensors, so the script exits immediately with a clear message.
try:
    import devicetorch
except ImportError:
    print("Error: 'devicetorch' not found. Please run this script from the project root or install requirements.")
    sys.exit(1)
# Configure allowed extensions
# Matched case-insensitively against Path.suffix (see main, which lowercases
# the suffix before testing membership).
ALLOWED_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.gif', '.webp', '.bmp'}
def setup_model():
    """Load and configure the BRIA-RMBG-2.0 segmentation model.

    Returns:
        tuple: ``(model, device)`` — the model in eval mode, already moved
        to ``device`` (the string reported by ``devicetorch.get``).

    Exits the process with status 1 if the model cannot be loaded.
    """
    print("Loading BRIA-RMBG-2.0 model...")
    # 1. Device Selection (delegated to devicetorch)
    device = devicetorch.get(torch)
    print(f"Device: {device}")
    if device == 'cpu':
        # Use all available cores for CPU inference.
        torch.set_num_threads(max(1, os.cpu_count() or 1))
    # 2. Load Model
    try:
        print("Loading model config...")
        config = AutoConfig.from_pretrained("cocktailpeanut/rm", trust_remote_code=True)
        # Explicitly set low_cpu_mem_usage=False to avoid meta-tensor issues
        model = AutoModelForImageSegmentation.from_pretrained(
            "cocktailpeanut/rm",
            config=config,
            trust_remote_code=True,
            low_cpu_mem_usage=False,
        )
        model = devicetorch.to(torch, model)
        model.eval()
    except Exception as e:
        print(f"Error loading model: {e}")
        traceback.print_exc()
        sys.exit(1)
    # 3. CPU Optimization (Optional): dynamic int8 quantization of Linear layers.
    if device == 'cpu':
        print("Applying Dynamic Quantization for CPU speedup...")
        try:
            model = torch.quantization.quantize_dynamic(
                model, {torch.nn.Linear}, dtype=torch.qint8
            )
        except Exception as e:
            # Best-effort speedup: keep the unquantized model, but say why
            # instead of silently swallowing the error.
            print(f"Quantization skipped ({e}); continuing with full precision.")
    return model, device
def get_transform():
    """Build the preprocessing pipeline expected by the model.

    Resizes to the fixed 1024x1024 model input, converts to a float tensor,
    and applies ImageNet mean/std normalization.
    """
    steps = [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
    return transforms.Compose(steps)
def remove_background(model, image, transform):
    """Segment one PIL image and return a copy with the predicted mask as alpha.

    The model runs at its own resolution (set by ``transform``); the mask is
    resized back to the input image's original dimensions before compositing.
    """
    original_size = image.size  # remember, so the mask can be scaled back

    # Preprocess: normalized tensor with a leading batch dimension.
    batch = transform(image).unsqueeze(0)
    batch = devicetorch.to(torch, batch)

    # Inference without autograd bookkeeping.
    with torch.inference_mode():
        outputs = model(batch)
        # Some model variants return a list/tuple of stage outputs; the final
        # one is the prediction we want.
        last = outputs[-1] if isinstance(outputs, (list, tuple)) else outputs
        preds = last.sigmoid().cpu()

    # Post-process: squeeze to 2-D, convert to PIL, resize to the source size.
    mask = transforms.ToPILImage()(preds[0].squeeze()).resize(original_size)

    # Composite: attach the mask as the alpha channel of a copy.
    cutout = image.copy()
    cutout.putalpha(mask)

    # Cleanup VRAM if needed between images.
    devicetorch.empty_cache(torch)
    return cutout
def retouch_face(image, sensitivity=3.0, tone_smoothing=0.6):
    """Wrapper for the surgical retouch logic with detailed logging.

    On any failure the original image is returned unchanged so the caller's
    pipeline keeps going.
    """
    began = time.time()
    try:
        retouched, blemish_count = retouch.retouch_image_pil(
            image, sensitivity, tone_smoothing
        )
        elapsed_ms = (time.time() - began) * 1000
        print(f"RETOUCH: Success | Blemishes: {blemish_count} | Time: {elapsed_ms:.1f}ms")
        return retouched
    except Exception as e:
        print(f"RETOUCH: Failed | Error: {e}")
        return image
def main():
    """CLI entry point: batch-remove backgrounds from every image in a folder.

    Reads images from ``--input``, writes ``<stem>_rmbg.png`` files (PNG to
    preserve transparency) into ``--output``. Per-file failures are logged
    and skipped; the batch continues.
    """
    parser = argparse.ArgumentParser(description="Batch Background Removal Tool")
    parser.add_argument('--input', '-i', required=True, help="Input folder containing images")
    parser.add_argument('--output', '-o', required=True, help="Output folder for processed images")
    args = parser.parse_args()

    input_path = Path(args.input)
    output_path = Path(args.output)
    if not input_path.exists():
        print(f"Error: Input folder '{input_path}' does not exist.")
        sys.exit(1)
    # Create output folder if it doesn't exist
    output_path.mkdir(parents=True, exist_ok=True)

    # Setup
    model, device = setup_model()
    transform = get_transform()

    # Process files (extension match is case-insensitive)
    files = [f for f in input_path.iterdir() if f.suffix.lower() in ALLOWED_EXTENSIONS]
    total = len(files)
    print(f"\nFound {total} images. Starting processing...")
    print("-" * 50)
    start_time = time.time()

    for idx, file_path in enumerate(files, 1):
        try:
            filename = file_path.name
            # BUG FIX: the progress line previously printed the literal
            # "(unknown)" and never used `filename`; show the real name.
            print(f"[{idx}/{total}] Processing {filename}...", end='', flush=True)
            # Load image, honor EXIF orientation, drop any alpha channel.
            # Context manager ensures the file handle is closed promptly.
            with Image.open(file_path) as src:
                img = ImageOps.exif_transpose(src).convert('RGB')
            # Process
            result = remove_background(model, img, transform)
            # Save (force PNG for transparency)
            out_file = output_path / (file_path.stem + "_rmbg.png")
            result.save(out_file, "PNG")
            print(" Done.")
        except Exception as e:
            print(f" Failed! Error: {e}")

    duration = time.time() - start_time
    print("-" * 50)
    print(f"Finished! Processed {total} images in {duration:.2f} seconds.")
    print(f"Output saved to: {output_path.absolute()}")


if __name__ == "__main__":
    main()