# id-making / id-maker /core /pipeline.py
# Esmaill1
# Initial commit: Combined ID Maker and CodeFormer for HF Spaces
# 96f4ff4
import os
import time
import argparse
import sys
import torch
from pathlib import Path
from PIL import Image
# Import functions from existing scripts
# We might need to handle the monkeypatch for transformers in process_images
import crop
import process_images
import color_steal
import white_bg
import restoration
def _print_banner(title):
    """Print *title* framed by two 50-character ``=`` rules."""
    print("\n" + "=" * 50)
    print(title)
    print("=" * 50)


def _remove_backgrounds(crop_dir, trans_dir):
    """Remove the background from every supported image in *crop_dir*.

    Each result is written to *trans_dir* as ``<stem>_rmbg.png``. A failure on
    one image is reported on stdout and does not stop the remaining images.

    Args:
        crop_dir: Folder holding the cropped input images.
        trans_dir: Folder receiving the background-free PNGs (created if
            missing).
    """
    # Imported once per call instead of once per image.
    from PIL import ImageOps

    input_path = Path(crop_dir)
    output_path = Path(trans_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    # iterdir() yields entries in arbitrary order; sorting keeps the progress
    # log reproducible. is_file() skips directories that merely carry an
    # image-like suffix.
    files = sorted(
        f for f in input_path.iterdir()
        if f.is_file() and f.suffix.lower() in process_images.ALLOWED_EXTENSIONS
    )
    if not files:
        print(f"No images found in {crop_dir} for background removal.")
        return

    # Model setup is the heavy part, so it only happens once there is work.
    model, _device = process_images.setup_model()
    transform = process_images.get_transform()

    total = len(files)
    for idx, file_path in enumerate(files, 1):
        print(f"[{idx}/{total}] Removing background: {file_path.name}...", end='', flush=True)
        try:
            # The context manager closes the source file promptly; the
            # transposed/converted image is an independent in-memory copy.
            with Image.open(file_path) as raw_img:
                img = ImageOps.exif_transpose(raw_img).convert('RGB')
            result = process_images.remove_background(model, img, transform)
            out_name = file_path.stem + "_rmbg.png"
            result.save(output_path / out_name, "PNG")
        except Exception as e:
            # Deliberate best-effort: report and carry on with the next image.
            print(f" Failed! {e}")
        else:
            print(" Done.")


def run_pipeline(raw_dir, crop_dir, trans_dir, colored_dir, white_dir, curves_file, restore=False, fidelity=0.5):
    """Run the full image pipeline from raw photos to white-background results.

    Steps, in order:
      0. (optional) face restoration via ``restoration.batch_restore``
      1. cropping via ``crop.batch_process``
      2. background removal via ``process_images``
      3. color grading via ``color_steal`` (skipped when no curves load)
      4. white background via ``white_bg.add_white_background``

    Args:
        raw_dir: Folder with the raw input images.
        crop_dir: Folder receiving the cropped images.
        trans_dir: Folder receiving the background-free images.
        colored_dir: Folder receiving the color-graded images.
        white_dir: Folder receiving the final images.
        curves_file: Path handed to ``color_steal.load_trained_curves``.
        restore: When true, run face restoration first; restored images go to
            a ``restored`` folder next to *crop_dir* and feed the crop step.
        fidelity: Forwarded unchanged to ``restoration.batch_restore``.

    Returns:
        None. Progress and total elapsed time are printed to stdout.
    """
    start_total = time.time()

    # Step 0: Face Restoration
    current_raw_dir = raw_dir
    if restore:
        _print_banner("STEP 0: Face Restoration (CodeFormer)")
        restored_dir = os.path.join(os.path.dirname(crop_dir), "restored")
        restoration.batch_restore(raw_dir, restored_dir, fidelity=fidelity)
        current_raw_dir = restored_dir

    # Step 1: Crop
    _print_banner("STEP 1: Cropping and Face Detection")
    crop.batch_process(current_raw_dir, crop_dir)

    # Step 2: Background Removal
    _print_banner("STEP 2: Background Removal (AI)")
    _remove_backgrounds(crop_dir, trans_dir)

    # Step 3: Color Grading
    _print_banner("STEP 3: Color Grading")
    luts = color_steal.load_trained_curves(curves_file)
    if not luts:
        print(f"Warning: No trained curves found at {curves_file}. Skipping color grading.")
        # Without curves the transparent images feed step 4 directly.
        current_input_for_white = trans_dir
    else:
        color_steal.apply_to_folder(luts, trans_dir, colored_dir)
        current_input_for_white = colored_dir

    # Step 4: White Background
    _print_banner("STEP 4: Adding White Background & Finalizing")
    white_bg.add_white_background(current_input_for_white, white_dir)

    end_total = time.time()
    print("\n" + "=" * 50)
    print(f"PIPELINE COMPLETE in {end_total - start_total:.2f} seconds")
    print(f"Final results are in: {os.path.abspath(white_dir)}")
    print("=" * 50)
if __name__ == "__main__":
    # Command-line entry point: parse options, make sure the working folders
    # exist, then hand everything to run_pipeline.
    cli = argparse.ArgumentParser(description="Full Image Processing Pipeline")

    # (flag, default, help) for the five working folders, in pipeline order.
    folder_options = (
        ("--raw", "raw", "Folder with raw images"),
        ("--crop", "crop", "Folder for cropped images"),
        ("--trans", "trans", "Folder for transparent images"),
        ("--colored", "colored", "Folder for color-graded images"),
        ("--white", "white", "Folder for final results"),
    )
    for flag, default_dir, help_text in folder_options:
        cli.add_argument(flag, default=default_dir, help=help_text)

    cli.add_argument("--curves", default="trained_curves.npz", help="Pre-trained curves file")
    cli.add_argument("--restore", action="store_true", help="Enable face restoration using CodeFormer")
    cli.add_argument("--fidelity", type=float, default=0.5, help="CodeFormer fidelity (0-1, lower is more restoration)")
    opts = cli.parse_args()

    # Create any working folder that is not there yet, reporting each one.
    for folder in (opts.raw, opts.crop, opts.trans, opts.colored, opts.white):
        if os.path.exists(folder):
            continue
        os.makedirs(folder)
        print(f"Created directory: {folder}")

    run_pipeline(
        opts.raw,
        opts.crop,
        opts.trans,
        opts.colored,
        opts.white,
        opts.curves,
        restore=opts.restore,
        fidelity=opts.fidelity,
    )