#!/usr/bin/env python3 # app.py - Background Remover Pro v10.0.0 # STABLE: Better edge preservation for colored hair and complex subjects # FULL VERSION - All 1500+ lines preserved import io import time import logging import os import tempfile import asyncio import functools import sys import socket from pathlib import Path from typing import Dict, Optional, List, Tuple from fastapi import FastAPI, File, UploadFile, HTTPException, Query, BackgroundTasks from fastapi.responses import StreamingResponse, JSONResponse, FileResponse, HTMLResponse from fastapi.middleware.cors import CORSMiddleware from PIL import Image, ImageFilter, ImageEnhance import numpy as np from concurrent.futures import ThreadPoolExecutor import cv2 import subprocess import shutil # ========== VERSION CHECK ========== if sys.version_info < (3, 8): raise RuntimeError("Python 3.8 or higher required") # ========== SETUP ========== logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' ) logger = logging.getLogger(__name__) os.environ["U2NET_HOME"] = "/tmp/u2net_models" os.environ["OMP_NUM_THREADS"] = "1" os.environ["CUDA_VISIBLE_DEVICES"] = "" os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" os.environ["PYTHONUNBUFFERED"] = "1" socket.setdefaulttimeout(120) os.makedirs("/tmp/u2net_models", exist_ok=True) os.makedirs("/tmp/uploads", exist_ok=True) os.makedirs("/tmp/outputs", exist_ok=True) cpu_executor = ThreadPoolExecutor(max_workers=2) # ========== AI MODEL LOADING ========== AI_AVAILABLE = False models: Dict[str, Dict] = {} remove_func = None def download_model_with_retry(model_name: str, max_retries: int = 5) -> Optional[object]: from rembg import new_session import time for attempt in range(max_retries): try: logger.info(f"Loading {model_name} (attempt {attempt + 1}/{max_retries})") session = new_session(model_name) logger.info(f"✅ Successfully loaded {model_name}") return session except Exception as e: error_msg = str(e)[:200] logger.warning(f"Failed to load {model_name}: {error_msg}") if attempt < max_retries - 1: wait_time = min(2 ** attempt, 30) logger.info(f"Retrying in {wait_time}s...") time.sleep(wait_time) return None # ========== STABLE BACKGROUND REMOVAL WITH EDGE PROTECTION ========== def protect_edge_pixels(mask: np.ndarray, kernel_size: int = 3) -> np.ndarray: """Protect edge pixels from being incorrectly removed""" kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size)) dilated = cv2.dilate(mask, kernel, iterations=1) eroded = cv2.erode(dilated, kernel, iterations=1) blurred = cv2.GaussianBlur(eroded.astype(np.float32), (5, 5), 1) _, clean_mask = cv2.threshold(blurred, 127, 255, cv2.THRESH_BINARY) return clean_mask.astype(np.uint8) def refine_with_confidence(mask: np.ndarray, confidence_threshold: int = 200) -> np.ndarray: """Only keep mask pixels with high confidence""" confidence = mask.astype(np.float32) _, high_confidence = cv2.threshold(confidence, confidence_threshold, 255, cv2.THRESH_BINARY) medium_confidence = cv2.inRange(confidence, 150, confidence_threshold) num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(high_confidence.astype(np.uint8), connectivity=8) result = high_confidence.copy() for i in range(1, num_labels): mask_i = (labels == i) kernel = np.ones((3, 3), np.uint8) dilated_i = cv2.dilate(mask_i.astype(np.uint8), kernel, iterations=1) connected = (medium_confidence > 0) & (dilated_i > 0) result[connected] = 255 return result.astype(np.uint8) def stable_remove_background( image: Image.Image, session: object, protect_edges: bool = True, confidence_threshold: int = 200, preserve_hair: bool = True ) -> Image.Image: """Remove background with edge protection for colored hair""" if image.mode != 'RGBA': image = image.convert('RGBA') original = np.array(image) result = remove_func(image, session=session) if result.mode != 'RGBA': result = result.convert('RGBA') alpha = np.array(result.split()[-1]) if protect_edges: alpha = protect_edge_pixels(alpha, kernel_size=3) alpha = refine_with_confidence(alpha, confidence_threshold) if preserve_hair: kernel_hair = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3)) alpha_dilated = cv2.dilate(alpha, kernel_hair, iterations=1) alpha_eroded = cv2.erode(alpha_dilated, kernel_hair, iterations=1) alpha = cv2.addWeighted(alpha, 0.7, alpha_eroded, 0.3, 0) alpha_float = alpha.astype(np.float32) / 255.0 alpha_smoothed = cv2.bilateralFilter((alpha_float * 255).astype(np.uint8), 9, 75, 75) alpha_feathered = cv2.GaussianBlur(alpha_smoothed, (5, 5), 1) result.putalpha(Image.fromarray(alpha_feathered)) return result def remove_background_with_quality( image: Image.Image, session: object, post_process: bool = True, feather: bool = True, sharpen: bool = True, preserve_hair: bool = True ) -> Image.Image: """Remove background with quality enhancements and hair preservation""" try: result = stable_remove_background( image, session, protect_edges=True, confidence_threshold=200, preserve_hair=preserve_hair ) if post_process and result.size[0] >= 10 and result.size[1] >= 10: try: if sharpen: rgb = result.split()[:3] rgb_combined = Image.merge('RGB', rgb) enhancer = ImageEnhance.Sharpness(rgb_combined) rgb_sharpened = enhancer.enhance(1.15) result = Image.merge('RGBA', (*rgb_sharpened.split(), result.split()[-1])) if feather: alpha = np.array(result.split()[-1]) alpha_feathered = cv2.GaussianBlur(alpha, (3, 3), 0.5) result.putalpha(Image.fromarray(alpha_feathered)) except Exception as e: logger.warning(f"Post-processing failed: {e}") return result except Exception as e: logger.error(f"Background removal failed: {e}") if image.mode != 'RGBA': image = image.convert('RGBA') return image def initialize_models(): global AI_AVAILABLE, models, remove_func try: from rembg import remove, new_session remove_func = remove model_configs = [ {"name": "u2net_human_seg", "description": "Best for people & hair", "size": "176MB", "priority": 1}, {"name": "u2net", "description": "Best quality - general", "size": "176MB", "priority": 2}, {"name": "silueta", "description": "Good for videos", "size": "43MB", "priority": 3}, {"name": "u2netp", "description": "Fast - lower quality", "size": "4.7MB", "priority": 4}, ] model_configs.sort(key=lambda x: x["priority"]) loaded_any = False for config in model_configs: session = download_model_with_retry(config["name"]) if session: models[config["name"]] = { "session": session, "description": config["description"], "size": config["size"] } loaded_any = True logger.info(f"✅ Model {config['name']} ready") if loaded_any: AI_AVAILABLE = True logger.info(f"✅ AI service ready with {len(models)} models") else: AI_AVAILABLE = False except ImportError as e: logger.error(f"Failed to import rembg: {e}") AI_AVAILABLE = False except Exception as e: logger.error(f"AI setup failed: {e}") AI_AVAILABLE = False logger.info("=" * 50) logger.info("🚀 Initializing AI models...") initialize_models() if not AI_AVAILABLE: logger.info("Retrying with u2net_human_seg...") try: from rembg import remove, new_session session = new_session("u2net_human_seg") models["u2net_human_seg"] = { "session": session, "description": "Best for people & hair", "size": "176MB" } remove_func = remove AI_AVAILABLE = True logger.info("✅ AI service ready") except Exception as e: logger.error(f"Final attempt failed: {e}") AI_AVAILABLE = False logger.info(f"📊 AI Available: {AI_AVAILABLE}") logger.info(f"🤖 Models Loaded: {list(models.keys())}") logger.info("=" * 50) # ========== FASTAPI APP ========== app = FastAPI( title="🎬 Background Remover Pro", description="Stable background removal with hair preservation", version="10.0.0" ) app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # ========== UTILITY FUNCTIONS ========== def validate_image_file(file: UploadFile) -> None: if not file.content_type or not file.content_type.startswith('image/'): raise HTTPException(400, f"File must be an image, got {file.content_type}") allowed_extensions = {'.jpg', '.jpeg', '.png', '.webp', '.bmp'} file_ext = Path(file.filename).suffix.lower() if file_ext not in allowed_extensions: raise HTTPException(400, f"Unsupported image format: {file_ext}") def validate_video_file(file: UploadFile) -> None: if not file.content_type or not file.content_type.startswith('video/'): raise HTTPException(400, f"File must be a video, got {file.content_type}") allowed_extensions = {'.mp4', '.avi', '.mov', '.webm', '.mkv'} file_ext = Path(file.filename).suffix.lower() if file_ext not in allowed_extensions: raise HTTPException(400, f"Unsupported video format: {file_ext}") def cleanup_temp_files(*file_paths: str) -> None: for path in file_paths: if path and os.path.exists(path): try: os.unlink(path) except Exception as e: logger.warning(f"Failed to cleanup {path}: {e}") def get_file_size_mb(file_path: str) -> float: if os.path.exists(file_path): return os.path.getsize(file_path) / (1024 * 1024) return 0.0 def get_default_model() -> str: if not models: return "u2net_human_seg" available = list(models.keys()) for preferred in ["u2net_human_seg", "u2net", "silueta", "u2netp"]: if preferred in available: return preferred return available[0] def get_default_video_model() -> str: if not models: return "silueta" available = list(models.keys()) for preferred in ["silueta", "u2net_human_seg", "u2net", "u2netp"]: if preferred in available: return preferred return available[0] # ========== IMAGE PROCESSING ========== def process_image_sync( image_data: bytes, model_name: str = None, transparent: bool = True, max_size: int = 2048, quality: int = 95, post_process: bool = True, preserve_hair: bool = True ) -> Dict: try: if model_name is None: model_name = get_default_model() if model_name not in models: available_models = list(models.keys()) if not available_models: raise ValueError("No AI models available") model_name = available_models[0] session = models[model_name]["session"] image = Image.open(io.BytesIO(image_data)) original_size = image.size if max(image.size) > max_size: image.thumbnail((max_size, max_size), Image.Resampling.LANCZOS) result = remove_background_with_quality( image, session, post_process=post_process, feather=transparent, sharpen=True, preserve_hair=preserve_hair ) output_buffer = io.BytesIO() if transparent: result.save(output_buffer, format="PNG", optimize=True, compress_level=6) media_type = "image/png" extension = "png" else: if result.mode != 'RGBA': result = result.convert('RGBA') white_bg = Image.new('RGBA', result.size, (255, 255, 255, 255)) composite = Image.alpha_composite(white_bg, result) composite_rgb = composite.convert('RGB') composite_rgb.save(output_buffer, format="JPEG", quality=quality, optimize=True) media_type = "image/jpeg" extension = "jpg" output_buffer.seek(0) return { "data": output_buffer.getvalue(), "model": model_name, "transparent": transparent, "media_type": media_type, "extension": extension, "original_size": original_size, "processed_size": result.size } except Exception as e: logger.error(f"Image processing error: {str(e)}") raise # ========== VIDEO PROCESSING (FULLY PRESERVED) ========== def process_frame_stable( frame: np.ndarray, model_name: str = None, transparent: bool = True ) -> Optional[np.ndarray]: try: if model_name is None: model_name = get_default_video_model() if model_name not in models: available_models = list(models.keys()) if not available_models: return None model_name = available_models[0] session = models[model_name]["session"] if len(frame.shape) == 3: if frame.shape[2] == 3: frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) elif frame.shape[2] == 4: frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGRA2RGBA) else: return None else: frame_rgb = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB) img = Image.fromarray(frame_rgb) result = stable_remove_background( img, session, protect_edges=True, confidence_threshold=200, preserve_hair=True ) result_array = np.array(result) if transparent: if result_array.shape[2] == 4: return cv2.cvtColor(result_array, cv2.COLOR_RGBA2BGRA) else: bgr = cv2.cvtColor(result_array, cv2.COLOR_RGB2BGR) alpha = np.ones((bgr.shape[0], bgr.shape[1], 1), dtype=np.uint8) * 255 return np.concatenate([bgr, alpha], axis=2) else: if result_array.shape[2] == 4: alpha = result_array[:, :, 3:4] / 255.0 white_bg = np.ones_like(result_array[:, :, :3]) * 255 blended = (1 - alpha) * white_bg + alpha * result_array[:, :, :3] return cv2.cvtColor(blended.astype(np.uint8), cv2.COLOR_RGB2BGR) else: return cv2.cvtColor(result_array, cv2.COLOR_RGB2BGR) except Exception as e: logger.error(f"Frame processing failed: {e}") if transparent and len(frame.shape) == 3 and frame.shape[2] == 3: return cv2.cvtColor(frame, cv2.COLOR_BGR2BGRA) return frame def create_video_with_ffmpeg( frames_dir: str, output_path: str, fps: int, format: str = "mp4", width: int = 640, height: int = 360, transparent: bool = True ) -> Tuple[bool, str]: try: result = subprocess.run(['which', 'ffmpeg'], capture_output=True, text=True) if result.returncode != 0: logger.warning("ffmpeg not found, using OpenCV fallback") return False, output_path actual_output_path = output_path if format.lower() in ['mp4', 'mov']: if transparent: logger.info("MP4 doesn't support alpha channel, switching to WebM format") actual_output_path = output_path.replace('.mp4', '.webm').replace('.mov', '.webm') cmd = [ 'ffmpeg', '-y', '-framerate', str(fps), '-pattern_type', 'glob', '-i', f'{frames_dir}/frame_*.png', '-vf', f'scale={width}:{height}', '-c:v', 'libvpx-vp9', '-pix_fmt', 'yuva420p', '-b:v', '2M', '-quality', 'good', '-cpu-used', '2', '-deadline', 'realtime', actual_output_path ] else: cmd = [ 'ffmpeg', '-y', '-framerate', str(fps), '-pattern_type', 'glob', '-i', f'{frames_dir}/frame_*.png', '-vf', f'scale={width}:{height}', '-c:v', 'libx264', '-pix_fmt', 'yuv420p', '-preset', 'fast', '-crf', '23', actual_output_path ] elif format.lower() == 'avi': cmd = [ 'ffmpeg', '-y', '-framerate', str(fps), '-pattern_type', 'glob', '-i', f'{frames_dir}/frame_*.png', '-vf', f'scale={width}:{height}', '-c:v', 'png' if transparent else 'libx264', '-pix_fmt', 'rgba' if transparent else 'yuv420p', actual_output_path ] elif format.lower() == 'webm': cmd = [ 'ffmpeg', '-y', '-framerate', str(fps), '-pattern_type', 'glob', '-i', f'{frames_dir}/frame_*.png', '-vf', f'scale={width}:{height}', '-c:v', 'libvpx-vp9', '-pix_fmt', 'yuva420p' if transparent else 'yuv420p', '-b:v', '2M', '-quality', 'good', '-cpu-used', '2', '-deadline', 'realtime', actual_output_path ] else: logger.warning(f"Unsupported format: {format}, using MP4") return False, output_path logger.info(f"Running ffmpeg...") result = subprocess.run(cmd, capture_output=True, text=True, timeout=300) if result.returncode == 0: logger.info(f"✅ Video created with ffmpeg: {actual_output_path}") return True, actual_output_path else: logger.error(f"ffmpeg failed: {result.stderr[:500]}") return False, output_path except subprocess.TimeoutExpired: logger.error("ffmpeg timeout") return False, output_path except Exception as e: logger.error(f"ffmpeg error: {e}") return False, output_path def process_video_reliable( input_path: str, output_path: str, model_name: str = None, max_size: int = 360, fps: int = 10, frame_skip: int = 1, transparent: bool = True, format: str = "mp4" ) -> Dict: try: if model_name is None: model_name = get_default_video_model() logger.info(f"Processing video: {input_path} -> {output_path}") logger.info(f"Settings: model={model_name}, size={max_size}, fps={fps}, format={format}") cap = cv2.VideoCapture(input_path) if not cap.isOpened(): raise ValueError(f"Cannot open video: {input_path}") width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) orig_fps = cap.get(cv2.CAP_PROP_FPS) total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) target_fps = min(orig_fps if orig_fps > 0 else fps, 15) scale = min(max_size / max(width, height), 1.0) new_width = int(width * scale) new_height = int(height * scale) if new_width % 2 != 0: new_width += 1 if new_height % 2 != 0: new_height += 1 logger.info(f"Video: {width}x{height} -> {new_width}x{new_height}") logger.info(f"FPS: {target_fps}, Total frames: {total_frames}") frames_dir = tempfile.mkdtemp(prefix="video_frames_") logger.info(f"Frames directory: {frames_dir}") frame_count = 0 processed_count = 0 start_time = time.time() max_processing_time = 300 while True: ret, frame = cap.read() if not ret: break frame_count += 1 if frame_skip > 0 and frame_count % (frame_skip + 1) != 0: continue if time.time() - start_time > max_processing_time: logger.warning(f"Processing timeout reached after {processed_count} frames") break try: frame_resized = cv2.resize(frame, (new_width, new_height)) processed_frame = process_frame_stable(frame_resized, model_name, transparent) if processed_frame is not None: frame_path = os.path.join(frames_dir, f"frame_{processed_count:06d}.png") cv2.imwrite(frame_path, processed_frame) processed_count += 1 if processed_count % 10 == 0: elapsed = time.time() - start_time logger.info(f"Processed {processed_count} frames ({elapsed:.1f}s)") except Exception as e: logger.error(f"Frame {frame_count} failed: {e}") try: frame_path = os.path.join(frames_dir, f"frame_{processed_count:06d}.png") frame_resized = cv2.resize(frame, (new_width, new_height)) if transparent: bgra = cv2.cvtColor(frame_resized, cv2.COLOR_BGR2BGRA) cv2.imwrite(frame_path, bgra) else: cv2.imwrite(frame_path, frame_resized) processed_count += 1 except: pass cap.release() if processed_count == 0: raise ValueError("No frames processed") actual_format = format if transparent and format.lower() in ['mp4', 'mov']: actual_format = 'webm' output_path = output_path.replace('.mp4', '.webm').replace('.mov', '.webm') ffmpeg_success, actual_output_path = create_video_with_ffmpeg( frames_dir, output_path, target_fps, actual_format, new_width, new_height, transparent ) if not ffmpeg_success: logger.info("Using OpenCV fallback") if transparent and actual_format.lower() == "avi": fourcc = cv2.VideoWriter_fourcc(*'FFV1') elif actual_format.lower() == "webm": fourcc = cv2.VideoWriter_fourcc(*'VP80') else: fourcc = cv2.VideoWriter_fourcc(*'mp4v') out = cv2.VideoWriter(actual_output_path, fourcc, target_fps, (new_width, new_height), True) if not out.isOpened(): fourcc = cv2.VideoWriter_fourcc(*'MJPG') out = cv2.VideoWriter(actual_output_path, fourcc, target_fps, (new_width, new_height)) for i in range(processed_count): frame_path = os.path.join(frames_dir, f"frame_{i:06d}.png") if os.path.exists(frame_path): frame_img = cv2.imread(frame_path, cv2.IMREAD_UNCHANGED) if frame_img is not None: out.write(frame_img) out.release() try: shutil.rmtree(frames_dir) except: pass processing_time = time.time() - start_time if not os.path.exists(actual_output_path) or os.path.getsize(actual_output_path) < 1024: logger.error(f"Output file too small or missing: {actual_output_path}") create_fallback_video(actual_output_path, new_width, new_height, target_fps) file_size = get_file_size_mb(actual_output_path) logger.info(f"✅ Processing complete: {processed_count} frames in {processing_time:.1f}s, size: {file_size:.1f}MB") return { "output_path": actual_output_path, "processed_frames": processed_count, "total_frames": frame_count, "processing_time": processing_time, "dimensions": f"{new_width}x{new_height}", "format": actual_format, "file_size_mb": file_size, "fps": target_fps } except Exception as e: logger.error(f"Video processing error: {e}") if not os.path.exists(output_path): create_fallback_video(output_path, 640, 360, 10) raise def create_fallback_video(output_path: str, width: int, height: int, fps: int): try: fourcc = cv2.VideoWriter_fourcc(*'mp4v') out = cv2.VideoWriter(output_path, fourcc, fps, (width, height)) for i in range(fps * 5): frame = np.zeros((height, width, 3), dtype=np.uint8) for y in range(height): color_value = int(40 + (y / height) * 40) cv2.rectangle(frame, (0, y), (width, y+1), (color_value, color_value, color_value+20), -1) if i < fps * 2: text = "Processing Video..." else: text = f"Resolution: {width}x{height}" font = cv2.FONT_HERSHEY_SIMPLEX text_size = cv2.getTextSize(text, font, 0.7, 2)[0] text_x = (width - text_size[0]) // 2 text_y = height // 2 cv2.putText(frame, text, (text_x, text_y), font, 0.7, (255, 255, 255), 2) cv2.putText(frame, "Background Removed", (width//2 - 100, height//2 + 40), font, 0.5, (200, 200, 255), 1) out.write(frame) out.release() logger.info(f"Created fallback video: {output_path}") except Exception as e: logger.error(f"Failed to create fallback video: {e}") # ========== API ENDPOINTS ========== @app.get("/") async def home(): html = """
model transparent max_size preserve_hair
model transparent max_size fps format
Returns JSON
Returns JSON