# bgremove / app.py
# andevs's picture
# Update app.py
# 583b232 verified
# Raw
# History Blame Contribute Delete
# 46.5 kB
#!/usr/bin/env python3
# app.py - Background Remover Pro v10.0.0
# STABLE: Better edge preservation for colored hair and complex subjects
# FULL VERSION - All 1500+ lines preserved
import io
import time
import logging
import os
import tempfile
import asyncio
import functools
import sys
import socket
from pathlib import Path
from typing import Dict, Optional, List, Tuple
from fastapi import FastAPI, File, UploadFile, HTTPException, Query, BackgroundTasks
from fastapi.responses import StreamingResponse, JSONResponse, FileResponse, HTMLResponse
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image, ImageFilter, ImageEnhance
import numpy as np
from concurrent.futures import ThreadPoolExecutor
import cv2
import subprocess
import shutil
# ========== VERSION CHECK ==========
if sys.version_info < (3, 8):
    raise RuntimeError("Python 3.8 or higher required")

# ========== SETUP ==========
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

# Runtime environment for the model stack: model cache location under /tmp,
# single-threaded OpenMP, no visible CUDA devices, quieter TensorFlow logs and
# unbuffered output.
_RUNTIME_ENV = {
    "U2NET_HOME": "/tmp/u2net_models",
    "OMP_NUM_THREADS": "1",
    "CUDA_VISIBLE_DEVICES": "",
    "TF_CPP_MIN_LOG_LEVEL": "2",
    "PYTHONUNBUFFERED": "1",
}
os.environ.update(_RUNTIME_ENV)
socket.setdefaulttimeout(120)

# Scratch directories: model cache, uploaded inputs, rendered outputs.
for _work_dir in ("/tmp/u2net_models", "/tmp/uploads", "/tmp/outputs"):
    os.makedirs(_work_dir, exist_ok=True)

# Shared pool for the blocking image/video work submitted by the endpoints.
cpu_executor = ThreadPoolExecutor(max_workers=2)

# ========== AI MODEL LOADING ==========
AI_AVAILABLE = False
models: Dict[str, Dict] = {}  # model name -> {"session", "description", "size"}
remove_func = None  # set to rembg.remove once rembg imports successfully
def download_model_with_retry(model_name: str, max_retries: int = 5) -> Optional[object]:
    """Create a rembg session for ``model_name``, retrying with back-off.

    Between failed attempts the function sleeps 1, 2, 4, ... seconds (capped at
    30). No sleep happens after the final attempt.

    Args:
        model_name: rembg model identifier passed to ``rembg.new_session``.
        max_retries: Total number of attempts.

    Returns:
        The session object, or ``None`` if every attempt raised.
    """
    from rembg import new_session

    attempt = 0
    while attempt < max_retries:
        attempt += 1
        logger.info(f"Loading {model_name} (attempt {attempt}/{max_retries})")
        try:
            session = new_session(model_name)
        except Exception as exc:
            logger.warning(f"Failed to load {model_name}: {str(exc)[:200]}")
            if attempt < max_retries:
                delay = min(2 ** (attempt - 1), 30)
                logger.info(f"Retrying in {delay}s...")
                time.sleep(delay)
            continue
        logger.info(f"✅ Successfully loaded {model_name}")
        return session
    return None
# ========== STABLE BACKGROUND REMOVAL WITH EDGE PROTECTION ==========
def protect_edge_pixels(mask: np.ndarray, kernel_size: int = 3) -> np.ndarray:
    """Close small gaps in an alpha matte, then binarise it.

    The matte is morphologically closed (a dilation followed by an erosion)
    with an elliptical ``kernel_size`` x ``kernel_size`` element. It is then
    softened with a 5x5 Gaussian (sigma 1) and thresholded at 127.

    Returns:
        uint8 array with values 0 or 255, same height/width as ``mask``.
    """
    ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
    # MORPH_CLOSE with one iteration is exactly dilate-then-erode with `ellipse`.
    closed = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, ellipse)
    softened = cv2.GaussianBlur(closed.astype(np.float32), (5, 5), 1)
    binary = cv2.threshold(softened, 127, 255, cv2.THRESH_BINARY)[1]
    return binary.astype(np.uint8)
def refine_with_confidence(
    mask: np.ndarray,
    confidence_threshold: int = 200,
    medium_threshold: int = 150,
) -> np.ndarray:
    """Binarise an alpha matte, keeping only confident foreground.

    A pixel becomes 255 when either:

    * its value is strictly greater than ``confidence_threshold`` (high
      confidence), or
    * its value lies in the inclusive band
      ``[medium_threshold, confidence_threshold]`` and at least one of its 8
      neighbours is high confidence.

    Every other pixel becomes 0. Medium pixels are attached only to directly
    adjacent high-confidence pixels; there is no chaining through other
    medium pixels.

    The previous implementation dilated each connected component of the
    high-confidence mask separately, costing O(components x pixels). Dilation
    distributes over union, so a single 3x3 dilation of the whole
    high-confidence mask yields the identical result.

    Args:
        mask: 2-D alpha matte (uint8 0-255 at the visible call sites).
        confidence_threshold: Values above this count as high confidence.
        medium_threshold: Lower bound of the medium band. The default of 150
            reproduces the previous hard-coded value.

    Returns:
        uint8 array of the same shape containing only 0 and 255.
    """
    values = np.asarray(mask, dtype=np.float32)
    high = values > confidence_threshold
    medium = (values >= medium_threshold) & (values <= confidence_threshold)

    # 3x3 binary dilation of `high`; zero padding means the border adds nothing.
    rows, cols = high.shape
    padded = np.pad(high, 1, mode="constant", constant_values=False)
    near_high = np.zeros_like(high)
    for dy in range(3):
        for dx in range(3):
            near_high |= padded[dy:dy + rows, dx:dx + cols]

    keep = high | (medium & near_high)
    return np.where(keep, 255, 0).astype(np.uint8)
def stable_remove_background(
    image: Image.Image,
    session: object,
    protect_edges: bool = True,
    confidence_threshold: int = 200,
    preserve_hair: bool = True
) -> Image.Image:
    """Remove the background from ``image`` and post-process the alpha matte.

    Steps:

    1. Run ``remove_func`` (assigned from ``rembg.remove`` during model
       loading) with ``session``.
    2. If ``protect_edges``, clean the matte with :func:`protect_edge_pixels`
       and :func:`refine_with_confidence`.
    3. If ``preserve_hair``, blend the matte 70/30 with its 3x3 elliptical
       closing.
    4. Smooth the matte with a bilateral filter followed by a small Gaussian
       blur.

    Args:
        image: Input picture; converted to RGBA before segmentation.
        session: Model session passed through to ``remove_func``.
        protect_edges: Apply the closing/threshold + confidence refinement.
        confidence_threshold: Forwarded to :func:`refine_with_confidence`.
        preserve_hair: Apply the closing blend described above.

    Returns:
        RGBA image whose alpha channel is the processed matte.
    """
    if image.mode != 'RGBA':
        image = image.convert('RGBA')

    result = remove_func(image, session=session)
    if result.mode != 'RGBA':
        result = result.convert('RGBA')

    alpha = np.array(result.split()[-1])

    if protect_edges:
        alpha = protect_edge_pixels(alpha, kernel_size=3)
        alpha = refine_with_confidence(alpha, confidence_threshold)

    if preserve_hair:
        kernel_hair = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
        alpha_dilated = cv2.dilate(alpha, kernel_hair, iterations=1)
        alpha_closed = cv2.erode(alpha_dilated, kernel_hair, iterations=1)
        alpha = cv2.addWeighted(alpha, 0.7, alpha_closed, 0.3, 0)

    # The matte is already uint8 at this point. The previous float32
    # `/255 -> *255 -> astype(uint8)` round trip added nothing and could
    # truncate some values down by one.
    alpha = alpha.astype(np.uint8, copy=False)
    alpha_smoothed = cv2.bilateralFilter(alpha, 9, 75, 75)
    alpha_feathered = cv2.GaussianBlur(alpha_smoothed, (5, 5), 1)
    result.putalpha(Image.fromarray(alpha_feathered))
    return result
def remove_background_with_quality(
    image: Image.Image,
    session: object,
    post_process: bool = True,
    feather: bool = True,
    sharpen: bool = True,
    preserve_hair: bool = True
) -> Image.Image:
    """Cut out the subject via :func:`stable_remove_background`, then polish it.

    When ``post_process`` is set and the cut-out is at least 10x10 pixels,
    two optional steps run:

    * ``sharpen``: the RGB bands get a 1.15x Pillow sharpness boost.
    * ``feather``: the alpha band gets a 3x3 Gaussian blur (sigma 0.5).

    A failure in this polish step is logged and the partially polished
    cut-out is returned.

    If background removal itself raises, the error is logged and the
    original image is returned unchanged (converted to RGBA). This is a
    best-effort fallback.
    """
    try:
        cutout = stable_remove_background(
            image,
            session,
            protect_edges=True,
            confidence_threshold=200,
            preserve_hair=preserve_hair,
        )
        width, height = cutout.size
        if not (post_process and width >= 10 and height >= 10):
            return cutout

        try:
            if sharpen:
                *colour_bands, alpha_band = cutout.split()
                sharpened = ImageEnhance.Sharpness(Image.merge('RGB', colour_bands)).enhance(1.15)
                cutout = Image.merge('RGBA', (*sharpened.split(), alpha_band))
            if feather:
                softened = cv2.GaussianBlur(np.array(cutout.split()[-1]), (3, 3), 0.5)
                cutout.putalpha(Image.fromarray(softened))
        except Exception as e:
            logger.warning(f"Post-processing failed: {e}")
        return cutout
    except Exception as e:
        logger.error(f"Background removal failed: {e}")
        return image if image.mode == 'RGBA' else image.convert('RGBA')
def initialize_models():
    """Import rembg and load every configured model into the module globals.

    The function does the following:

    * Sets ``remove_func`` to ``rembg.remove``.
    * Stores each successfully created session in ``models``, keyed by model
      name, together with its description and size.
    * Sets ``AI_AVAILABLE`` to whether this call loaded at least one model.

    Import failures and unexpected errors are logged and leave
    ``AI_AVAILABLE`` False.
    """
    global AI_AVAILABLE, models, remove_func
    try:
        # new_session is unused here; importing it keeps the original
        # ImportError behaviour for incomplete rembg installs.
        from rembg import remove, new_session
        remove_func = remove

        catalogue = [
            {"name": "u2net_human_seg", "description": "Best for people & hair", "size": "176MB", "priority": 1},
            {"name": "u2net", "description": "Best quality - general", "size": "176MB", "priority": 2},
            {"name": "silueta", "description": "Good for videos", "size": "43MB", "priority": 3},
            {"name": "u2netp", "description": "Fast - lower quality", "size": "4.7MB", "priority": 4},
        ]

        newly_loaded = []
        for entry in sorted(catalogue, key=lambda item: item["priority"]):
            name = entry["name"]
            session = download_model_with_retry(name)
            if not session:
                continue
            models[name] = {
                "session": session,
                "description": entry["description"],
                "size": entry["size"],
            }
            newly_loaded.append(name)
            logger.info(f"✅ Model {name} ready")

        if newly_loaded:
            AI_AVAILABLE = True
            logger.info(f"✅ AI service ready with {len(models)} models")
        else:
            AI_AVAILABLE = False
    except ImportError as e:
        logger.error(f"Failed to import rembg: {e}")
        AI_AVAILABLE = False
    except Exception as e:
        logger.error(f"AI setup failed: {e}")
        AI_AVAILABLE = False
_BANNER = "=" * 50
logger.info(_BANNER)
logger.info("🚀 Initializing AI models...")
initialize_models()

if not AI_AVAILABLE:
    # Nothing usable was loaded: make one last direct attempt at the
    # human-segmentation model before starting up in degraded mode.
    logger.info("Retrying with u2net_human_seg...")
    try:
        from rembg import remove, new_session
        session = new_session("u2net_human_seg")
        models["u2net_human_seg"] = dict(
            session=session,
            description="Best for people & hair",
            size="176MB",
        )
        remove_func = remove
        AI_AVAILABLE = True
        logger.info("✅ AI service ready")
    except Exception as e:
        logger.error(f"Final attempt failed: {e}")
        AI_AVAILABLE = False

for _summary_line in (
    f"📊 AI Available: {AI_AVAILABLE}",
    f"🤖 Models Loaded: {list(models.keys())}",
    _BANNER,
):
    logger.info(_summary_line)
# ========== FASTAPI APP ==========
app = FastAPI(
    title="🎬 Background Remover Pro",
    description="Stable background removal with hair preservation",
    version="10.0.0",
)

# NOTE(review): a wildcard origin combined with allow_credentials=True may let
# any site make credentialed requests. The endpoints in this file use no
# cookies or auth, so allow_credentials=False is probably sufficient; confirm
# with the clients before changing it.
_cors_options = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_options)
# ========== UTILITY FUNCTIONS ==========
def validate_image_file(file: UploadFile) -> None:
    """Reject uploads that are not a supported still image.

    Raises:
        HTTPException: status 400 in either of these cases:

            * the declared content type is not ``image/*``;
            * the filename extension is not .jpg/.jpeg/.png/.webp/.bmp
              (case-insensitive).

            A missing filename counts as having no extension. Previously
            ``Path(None)`` raised ``TypeError``, which surfaced as a 500.
    """
    content_type = file.content_type
    if not content_type or not content_type.startswith('image/'):
        raise HTTPException(400, f"File must be an image, got {file.content_type}")
    allowed_extensions = {'.jpg', '.jpeg', '.png', '.webp', '.bmp'}
    file_ext = Path(file.filename or "").suffix.lower()
    if file_ext not in allowed_extensions:
        raise HTTPException(400, f"Unsupported image format: {file_ext}")
def validate_video_file(file: UploadFile) -> None:
    """Reject uploads that are not a supported video container.

    Raises:
        HTTPException: status 400 in either of these cases:

            * the declared content type is not ``video/*``;
            * the filename extension is not .mp4/.avi/.mov/.webm/.mkv
              (case-insensitive).

            A missing filename counts as having no extension. Previously
            ``Path(None)`` raised ``TypeError``, which surfaced as a 500.
    """
    content_type = file.content_type
    if not content_type or not content_type.startswith('video/'):
        raise HTTPException(400, f"File must be a video, got {file.content_type}")
    allowed_extensions = {'.mp4', '.avi', '.mov', '.webm', '.mkv'}
    file_ext = Path(file.filename or "").suffix.lower()
    if file_ext not in allowed_extensions:
        raise HTTPException(400, f"Unsupported video format: {file_ext}")
def cleanup_temp_files(*file_paths: str) -> None:
    """Best-effort deletion of temporary files.

    Falsy entries (``None``, ``""``) and paths that do not exist are
    skipped. A failed deletion is logged as a warning and never raised.
    """
    existing = (target for target in file_paths if target and os.path.exists(target))
    for target in existing:
        try:
            os.unlink(target)
        except Exception as exc:
            logger.warning(f"Failed to cleanup {target}: {exc}")
def get_file_size_mb(file_path: str) -> float:
    """Return the size of ``file_path`` in MiB, or 0.0 if it does not exist."""
    if not os.path.exists(file_path):
        return 0.0
    bytes_per_mib = 1024 * 1024
    return os.path.getsize(file_path) / bytes_per_mib
def get_default_model() -> str:
    """Pick the model used for images when the caller does not name one.

    Returns the first loaded model in the order u2net_human_seg, u2net,
    silueta, u2netp. If none of those is loaded, returns the first loaded
    model. If nothing is loaded at all, returns ``"u2net_human_seg"``.
    """
    if not models:
        return "u2net_human_seg"
    ranking = ("u2net_human_seg", "u2net", "silueta", "u2netp")
    return next((name for name in ranking if name in models), next(iter(models)))
def get_default_video_model() -> str:
    """Pick the model used for video frames when the caller does not name one.

    Returns the first loaded model in the order silueta, u2net_human_seg,
    u2net, u2netp. If none of those is loaded, returns the first loaded
    model. If nothing is loaded at all, returns ``"silueta"``.
    """
    if not models:
        return "silueta"
    ranking = ("silueta", "u2net_human_seg", "u2net", "u2netp")
    return next((name for name in ranking if name in models), next(iter(models)))
# ========== IMAGE PROCESSING ==========
def process_image_sync(
    image_data: bytes,
    model_name: Optional[str] = None,
    transparent: bool = True,
    max_size: int = 2048,
    quality: int = 95,
    post_process: bool = True,
    preserve_hair: bool = True
) -> Dict:
    """Decode, cut out and re-encode one uploaded image.

    This is blocking work; the API runs it in ``cpu_executor``.

    Args:
        image_data: Raw encoded image bytes.
        model_name: Loaded model to use. ``None`` uses
            :func:`get_default_model`. An unknown name falls back to the
            first loaded model.
        transparent: ``True`` produces a PNG with alpha. ``False`` flattens
            onto white and produces a JPEG.
        max_size: Longest side after down-scaling (aspect ratio kept).
        quality: JPEG quality (only used when ``transparent`` is False).
        post_process: Forwarded to :func:`remove_background_with_quality`.
        preserve_hair: Forwarded to :func:`remove_background_with_quality`.

    Returns:
        A dict with these keys:

        * ``data``: encoded bytes.
        * ``model``, ``transparent``, ``media_type``, ``extension``.
        * ``original_size``: size after applying EXIF orientation, before
          down-scaling.
        * ``processed_size``.

    Raises:
        ValueError: If no model is loaded. Any other error (e.g. an
            undecodable image) is logged and re-raised.
    """
    from PIL import ImageOps  # local import: EXIF orientation handling

    try:
        if model_name is None:
            model_name = get_default_model()
        if model_name not in models:
            available_models = list(models.keys())
            if not available_models:
                raise ValueError("No AI models available")
            model_name = available_models[0]
        session = models[model_name]["session"]

        image = Image.open(io.BytesIO(image_data))
        # Camera JPEGs often record rotation only in the EXIF Orientation tag.
        # Bake it into the pixels so the cut-out is not returned sideways.
        image = ImageOps.exif_transpose(image)
        original_size = image.size
        if max(image.size) > max_size:
            image.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)

        result = remove_background_with_quality(
            image,
            session,
            post_process=post_process,
            feather=transparent,
            sharpen=True,
            preserve_hair=preserve_hair
        )

        output_buffer = io.BytesIO()
        if transparent:
            result.save(output_buffer, format="PNG", optimize=True, compress_level=6)
            media_type, extension = "image/png", "png"
        else:
            if result.mode != 'RGBA':
                result = result.convert('RGBA')
            white_bg = Image.new('RGBA', result.size, (255, 255, 255, 255))
            composite_rgb = Image.alpha_composite(white_bg, result).convert('RGB')
            composite_rgb.save(output_buffer, format="JPEG", quality=quality, optimize=True)
            media_type, extension = "image/jpeg", "jpg"

        return {
            "data": output_buffer.getvalue(),
            "model": model_name,
            "transparent": transparent,
            "media_type": media_type,
            "extension": extension,
            "original_size": original_size,
            "processed_size": result.size
        }
    except Exception as e:
        logger.error(f"Image processing error: {str(e)}")
        raise
# ========== VIDEO PROCESSING (FULLY PRESERVED) ==========
def process_frame_stable(
    frame: np.ndarray,
    model_name: str = None,
    transparent: bool = True
) -> Optional[np.ndarray]:
    """Remove the background from one OpenCV frame.

    Args:
        frame: Input frame. A 3-D array is treated as BGR (3 channels) or
            BGRA (4 channels); anything without a third axis is treated as
            grayscale.
        model_name: Loaded model to use. ``None`` uses
            :func:`get_default_video_model`. An unknown name falls back to
            the first loaded model.
        transparent: ``True`` returns BGRA with the matte as alpha.
            ``False`` returns BGR flattened onto white.

    Returns:
        The processed frame, or ``None`` when no model is loaded or the frame
        has an unsupported channel count. If processing raises, the error is
        logged and the input frame is returned instead (converted to BGRA
        when ``transparent`` and the input is 3-channel).
    """
    try:
        chosen = get_default_video_model() if model_name is None else model_name
        if chosen not in models:
            fallback_names = list(models.keys())
            if not fallback_names:
                return None
            chosen = fallback_names[0]
        session = models[chosen]["session"]

        if len(frame.shape) != 3:
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
        elif frame.shape[2] == 3:
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        elif frame.shape[2] == 4:
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGRA2RGBA)
        else:
            return None

        cutout = np.array(stable_remove_background(
            Image.fromarray(frame_rgb),
            session,
            protect_edges=True,
            confidence_threshold=200,
            preserve_hair=True,
        ))
        has_alpha = cutout.shape[2] == 4

        if transparent:
            if has_alpha:
                return cv2.cvtColor(cutout, cv2.COLOR_RGBA2BGRA)
            bgr = cv2.cvtColor(cutout, cv2.COLOR_RGB2BGR)
            opaque = np.ones((bgr.shape[0], bgr.shape[1], 1), dtype=np.uint8) * 255
            return np.concatenate([bgr, opaque], axis=2)

        if not has_alpha:
            return cv2.cvtColor(cutout, cv2.COLOR_RGB2BGR)
        # Alpha-blend the cut-out over a white background.
        coverage = cutout[:, :, 3:4] / 255.0
        white = np.ones_like(cutout[:, :, :3]) * 255
        blended = (1 - coverage) * white + coverage * cutout[:, :, :3]
        return cv2.cvtColor(blended.astype(np.uint8), cv2.COLOR_RGB2BGR)
    except Exception as e:
        logger.error(f"Frame processing failed: {e}")
        if transparent and len(frame.shape) == 3 and frame.shape[2] == 3:
            return cv2.cvtColor(frame, cv2.COLOR_BGR2BGRA)
        return frame
def create_video_with_ffmpeg(
    frames_dir: str,
    output_path: str,
    fps: int,
    format: str = "mp4",
    width: int = 640,
    height: int = 360,
    transparent: bool = True
) -> Tuple[bool, str]:
    """Encode ``frames_dir/frame_*.png`` into a video with the ffmpeg CLI.

    Encoder choice by ``format``:

    * ``mp4`` / ``mov``: libx264. If ``transparent``, VP9 with alpha is used
      instead, and a ``.mp4``/``.mov`` extension on ``output_path`` is
      switched to ``.webm``.
    * ``avi``: PNG codec (RGBA) when ``transparent``, else libx264.
    * ``webm``: VP9, with alpha when ``transparent``.

    Frames are globbed rather than read as a numbered sequence, so gaps in
    the numbering do not cut the video short.

    Args:
        frames_dir: Directory holding the frame PNGs.
        output_path: Requested output file.
        fps: Input frame rate given to ffmpeg; ints or floats both work.
        format: Requested container.
        width: Output width passed to the scale filter.
        height: Output height passed to the scale filter.
        transparent: Whether to preserve the alpha channel.

    Returns:
        ``(True, path_written)`` on success. ``(False, output_path)`` when
        ffmpeg is not installed, the format is unsupported, ffmpeg exits
        non-zero, times out (300 s), or anything else raises; the caller
        then falls back to OpenCV.
    """
    try:
        if shutil.which('ffmpeg') is None:
            logger.warning("ffmpeg not found, using OpenCV fallback")
            return False, output_path

        fmt = format.lower()
        input_args = [
            'ffmpeg', '-y',
            '-framerate', str(fps),
            '-pattern_type', 'glob',
            '-i', f'{frames_dir}/frame_*.png',
            '-vf', f'scale={width}:{height}',
        ]

        def vp9_args(pix_fmt: str) -> List[str]:
            # In ffmpeg's libvpx wrapper `-quality` is an alias of `-deadline`.
            # The old command passed `-quality good` and then `-deadline
            # realtime`, so only `realtime` ever took effect; pass just that.
            return [
                '-c:v', 'libvpx-vp9',
                '-pix_fmt', pix_fmt,
                '-b:v', '2M',
                '-cpu-used', '2',
                '-deadline', 'realtime',
            ]

        actual_output_path = output_path
        if fmt in ('mp4', 'mov'):
            if transparent:
                logger.info("MP4 doesn't support alpha channel, switching to WebM format")
                # Only swap a real extension. str.replace would also rewrite
                # '.mp4' occurring elsewhere in the path.
                root, ext = os.path.splitext(output_path)
                if ext.lower() in ('.mp4', '.mov'):
                    actual_output_path = root + '.webm'
                cmd = input_args + vp9_args('yuva420p') + [actual_output_path]
            else:
                cmd = input_args + [
                    '-c:v', 'libx264',
                    '-pix_fmt', 'yuv420p',
                    '-preset', 'fast',
                    '-crf', '23',
                    actual_output_path,
                ]
        elif fmt == 'avi':
            cmd = input_args + [
                '-c:v', 'png' if transparent else 'libx264',
                '-pix_fmt', 'rgba' if transparent else 'yuv420p',
                actual_output_path,
            ]
        elif fmt == 'webm':
            cmd = input_args + vp9_args('yuva420p' if transparent else 'yuv420p') + [actual_output_path]
        else:
            logger.warning(f"Unsupported format: {format}, using MP4")
            return False, output_path

        logger.info("Running ffmpeg...")
        completed = subprocess.run(cmd, capture_output=True, text=True, timeout=300)
        if completed.returncode == 0:
            logger.info(f"✅ Video created with ffmpeg: {actual_output_path}")
            return True, actual_output_path
        logger.error(f"ffmpeg failed: {completed.stderr[:500]}")
        return False, output_path
    except subprocess.TimeoutExpired:
        logger.error("ffmpeg timeout")
        return False, output_path
    except Exception as e:
        logger.error(f"ffmpeg error: {e}")
        return False, output_path
def _write_passthrough_frame(
    frame: np.ndarray,
    frame_path: str,
    frame_size: Tuple[int, int],
    transparent: bool,
) -> bool:
    """Write ``frame`` resized but otherwise unprocessed; return False if that raised.

    Used when background removal failed for a frame, so the output keeps its
    length. With ``transparent`` the frame is stored as fully opaque BGRA.
    """
    try:
        frame_resized = cv2.resize(frame, frame_size)
        if transparent:
            cv2.imwrite(frame_path, cv2.cvtColor(frame_resized, cv2.COLOR_BGR2BGRA))
        else:
            cv2.imwrite(frame_path, frame_resized)
        return True
    except Exception as e:
        logger.warning(f"Could not write fallback frame {frame_path}: {e}")
        return False


def _encode_frames_with_opencv(
    frames_dir: str,
    frame_total: int,
    output_path: str,
    fmt: str,
    fps: float,
    frame_size: Tuple[int, int],
    transparent: bool,
) -> None:
    """Fallback encoder: write frame_000000.png ... in order with cv2.VideoWriter."""
    fmt = fmt.lower()
    if transparent and fmt == "avi":
        fourcc = cv2.VideoWriter_fourcc(*'FFV1')
    elif fmt == "webm":
        fourcc = cv2.VideoWriter_fourcc(*'VP80')
    else:
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    writer = cv2.VideoWriter(output_path, fourcc, fps, frame_size, True)
    if not writer.isOpened():
        writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'MJPG'), fps, frame_size)
    try:
        for index in range(frame_total):
            frame_path = os.path.join(frames_dir, f"frame_{index:06d}.png")
            if not os.path.exists(frame_path):
                continue
            frame_img = cv2.imread(frame_path, cv2.IMREAD_UNCHANGED)
            if frame_img is None:
                continue
            # A colour VideoWriter takes 3-channel BGR frames. BGRA/grayscale
            # frames were previously dropped, leaving an empty file. Alpha
            # cannot be stored on this path, so it is discarded.
            if frame_img.ndim == 2:
                frame_img = cv2.cvtColor(frame_img, cv2.COLOR_GRAY2BGR)
            elif frame_img.shape[2] == 4:
                frame_img = cv2.cvtColor(frame_img, cv2.COLOR_BGRA2BGR)
            writer.write(frame_img)
    finally:
        writer.release()


def process_video_reliable(
    input_path: str,
    output_path: str,
    model_name: str = None,
    max_size: int = 360,
    fps: int = 10,
    frame_skip: int = 1,
    transparent: bool = True,
    format: str = "mp4"
) -> Dict:
    """Remove the background from a video and re-encode it.

    Pipeline:

    1. Read frames with OpenCV, keeping every ``frame_skip + 1``-th frame
       (all frames when ``frame_skip`` is 0).
    2. Resize so the longer side is at most ``max_size``, with even
       dimensions for yuv420p encoders.
    3. Run :func:`process_frame_stable` on each kept frame.
    4. Store the results as numbered PNGs in a temporary directory, which is
       always removed afterwards.
    5. Encode with ffmpeg, or with OpenCV if ffmpeg fails.

    Frame processing stops after 300 s.

    Args:
        input_path: Source video file.
        output_path: Requested output file. It is switched to ``.webm`` for
            transparent MP4/MOV requests.
        model_name: Model for :func:`process_frame_stable`. ``None`` uses
            :func:`get_default_video_model`.
        max_size: Longest output side in pixels.
        fps: Frame rate assumed for the source when the container does not
            report one.
        frame_skip: Number of source frames dropped between kept frames.
        transparent: Keep an alpha channel where the container allows it.
        format: Requested container (mp4, mov, avi, webm).

    Returns:
        A dict with these keys:

        * ``output_path``: the file actually written.
        * ``processed_frames``.
        * ``total_frames``: frames read.
        * ``processing_time``, ``dimensions``, ``format``, ``file_size_mb``.
        * ``fps``: the output rate, i.e. the source rate divided by the frame
          stride, so the original duration is preserved.

    Raises:
        ValueError: If the video cannot be opened, reports no size, or no
            frame could be written. A placeholder video is written to
            ``output_path`` (when nothing exists there yet) before the error
            is re-raised.
    """
    frames_dir = None
    try:
        if model_name is None:
            model_name = get_default_video_model()
        logger.info(f"Processing video: {input_path} -> {output_path}")
        logger.info(f"Settings: model={model_name}, size={max_size}, fps={fps}, format={format}")

        cap = cv2.VideoCapture(input_path)
        if not cap.isOpened():
            raise ValueError(f"Cannot open video: {input_path}")
        try:
            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            orig_fps = cap.get(cv2.CAP_PROP_FPS)
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            if width <= 0 or height <= 0:
                raise ValueError(f"Cannot read video dimensions: {input_path}")

            # Keeping one frame in `stride` means the output rate must be
            # divided by `stride` to preserve duration. The old fixed
            # min(orig_fps, 15) sped up or slowed down most inputs.
            stride = frame_skip + 1 if frame_skip > 0 else 1
            source_fps = orig_fps if orig_fps and orig_fps > 0 else fps
            target_fps = source_fps / stride

            scale = min(max_size / max(width, height), 1.0)
            new_width = max(2, int(width * scale))
            new_height = max(2, int(height * scale))
            new_width += new_width % 2
            new_height += new_height % 2
            frame_size = (new_width, new_height)
            logger.info(f"Video: {width}x{height} -> {new_width}x{new_height}")
            logger.info(f"FPS: {target_fps}, Total frames: {total_frames}")

            frames_dir = tempfile.mkdtemp(prefix="video_frames_")
            logger.info(f"Frames directory: {frames_dir}")

            frame_count = 0
            processed_count = 0
            start_time = time.time()
            max_processing_time = 300
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                frame_count += 1
                if frame_count % stride != 0:
                    continue
                if time.time() - start_time > max_processing_time:
                    logger.warning(f"Processing timeout reached after {processed_count} frames")
                    break
                try:
                    frame_resized = cv2.resize(frame, frame_size)
                    processed_frame = process_frame_stable(frame_resized, model_name, transparent)
                    if processed_frame is not None:
                        frame_path = os.path.join(frames_dir, f"frame_{processed_count:06d}.png")
                        cv2.imwrite(frame_path, processed_frame)
                        processed_count += 1
                        if processed_count % 10 == 0:
                            elapsed = time.time() - start_time
                            logger.info(f"Processed {processed_count} frames ({elapsed:.1f}s)")
                except Exception as e:
                    logger.error(f"Frame {frame_count} failed: {e}")
                    fallback_path = os.path.join(frames_dir, f"frame_{processed_count:06d}.png")
                    if _write_passthrough_frame(frame, fallback_path, frame_size, transparent):
                        processed_count += 1
        finally:
            cap.release()

        if processed_count == 0:
            raise ValueError("No frames processed")

        actual_format = format
        if transparent and format.lower() in ['mp4', 'mov']:
            actual_format = 'webm'
            root, ext = os.path.splitext(output_path)
            if ext.lower() in ('.mp4', '.mov'):
                output_path = root + '.webm'

        ffmpeg_success, actual_output_path = create_video_with_ffmpeg(
            frames_dir, output_path, target_fps, actual_format,
            new_width, new_height, transparent
        )
        if not ffmpeg_success:
            logger.info("Using OpenCV fallback")
            _encode_frames_with_opencv(
                frames_dir, processed_count, actual_output_path,
                actual_format, target_fps, frame_size, transparent
            )

        processing_time = time.time() - start_time
        if not os.path.exists(actual_output_path) or os.path.getsize(actual_output_path) < 1024:
            logger.error(f"Output file too small or missing: {actual_output_path}")
            create_fallback_video(
                actual_output_path, new_width, new_height, max(1, int(round(target_fps)))
            )
        file_size = get_file_size_mb(actual_output_path)
        logger.info(f"✅ Processing complete: {processed_count} frames in {processing_time:.1f}s, size: {file_size:.1f}MB")
        return {
            "output_path": actual_output_path,
            "processed_frames": processed_count,
            "total_frames": frame_count,
            "processing_time": processing_time,
            "dimensions": f"{new_width}x{new_height}",
            "format": actual_format,
            "file_size_mb": file_size,
            "fps": target_fps
        }
    except Exception as e:
        logger.error(f"Video processing error: {e}")
        if not os.path.exists(output_path):
            create_fallback_video(output_path, 640, 360, 10)
        raise
    finally:
        # The PNG frames are only intermediate; never leave them behind.
        if frames_dir:
            shutil.rmtree(frames_dir, ignore_errors=True)
def _fallback_background(width: int, height: int) -> np.ndarray:
    """Return the placeholder's vertical gradient as a (height, width, 3) uint8 BGR array.

    Row ``y`` has blue = green = ``int(40 + (y / height) * 40)`` and red 20
    higher. This matches the previous one-rectangle-per-row drawing.
    """
    levels = (40 + (np.arange(height) / height) * 40).astype(np.int64)
    background = np.empty((height, width, 3), dtype=np.uint8)
    background[:, :, 0] = levels[:, None]
    background[:, :, 1] = levels[:, None]
    background[:, :, 2] = (levels + 20)[:, None]
    return background


def create_fallback_video(output_path: str, width: int, height: int, fps: int):
    """Write a 5-second placeholder clip (gradient plus caption) to ``output_path``.

    Used when real encoding produced nothing usable. The first 2 seconds show
    "Processing Video...", the rest show the resolution. Errors are logged,
    never raised, and the writer is always released.

    ``fps`` may arrive as a float (callers pass measured frame rates). It is
    rounded to a whole number of at least 1, because a float previously made
    ``range(fps * 5)`` raise ``TypeError``.
    """
    fps_whole = max(1, int(round(fps)))
    writer = None
    try:
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        writer = cv2.VideoWriter(output_path, fourcc, fps_whole, (width, height))
        background = _fallback_background(width, height)
        font = cv2.FONT_HERSHEY_SIMPLEX

        def render(text: str) -> np.ndarray:
            frame = background.copy()
            text_size = cv2.getTextSize(text, font, 0.7, 2)[0]
            text_x = (width - text_size[0]) // 2
            text_y = height // 2
            cv2.putText(frame, text, (text_x, text_y), font, 0.7, (255, 255, 255), 2)
            cv2.putText(frame, "Background Removed",
                        (width // 2 - 100, height // 2 + 40), font, 0.5, (200, 200, 255), 1)
            return frame

        # Only two distinct frames exist, so render each once instead of
        # redrawing the gradient row by row for every frame.
        intro = render("Processing Video...")
        outro = render(f"Resolution: {width}x{height}")
        for i in range(fps_whole * 5):
            writer.write(intro if i < fps_whole * 2 else outro)
        writer.release()
        writer = None
        logger.info(f"Created fallback video: {output_path}")
    except Exception as e:
        logger.error(f"Failed to create fallback video: {e}")
    finally:
        if writer is not None:
            writer.release()
# ========== API ENDPOINTS ==========
# Static landing page. The live status line and the model badges are filled
# in client-side from /api/health and /api/models.
_HOME_PAGE_HTML = """
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>🎬 Background Remover Pro</title>
<style>
* { margin: 0; padding: 0; box-sizing: border-box; }
body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, sans-serif;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: #fff;
margin: 0;
padding: 20px;
min-height: 100vh;
display: flex;
align-items: center;
justify-content: center;
}
.container {
max-width: 900px;
margin: 0 auto;
background: rgba(255, 255, 255, 0.1);
backdrop-filter: blur(10px);
border-radius: 20px;
padding: 40px;
box-shadow: 0 20px 60px rgba(0, 0, 0, 0.3);
border: 1px solid rgba(255, 255, 255, 0.2);
}
h1 {
font-size: 2.5em;
margin-bottom: 10px;
background: linear-gradient(45deg, #fff, #e0e7ff);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
text-align: center;
}
.subtitle {
text-align: center;
color: #cbd5e0;
margin-bottom: 30px;
font-size: 1.1em;
}
.status {
background: linear-gradient(135deg, rgba(72, 187, 120, 0.2), rgba(72, 187, 120, 0.1));
padding: 15px 20px;
border-radius: 12px;
text-align: center;
margin: 20px 0;
border: 1px solid rgba(72, 187, 120, 0.3);
}
.status.error {
background: linear-gradient(135deg, rgba(237, 137, 54, 0.2), rgba(237, 137, 54, 0.1));
border: 1px solid rgba(237, 137, 54, 0.3);
}
h2 {
margin: 25px 0 15px 0;
color: #e0e7ff;
font-size: 1.5em;
}
.endpoint-grid {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(250px, 1fr));
gap: 15px;
margin: 20px 0;
}
.endpoint {
background: rgba(255, 255, 255, 0.05);
padding: 20px;
border-radius: 12px;
border-left: 4px solid #667eea;
transition: transform 0.2s, background 0.2s;
}
.endpoint:hover {
background: rgba(255, 255, 255, 0.1);
transform: translateY(-2px);
}
.endpoint strong {
display: block;
font-size: 1.1em;
margin-bottom: 8px;
color: #a78bfa;
}
.endpoint code {
background: rgba(0, 0, 0, 0.3);
padding: 4px 8px;
border-radius: 6px;
font-size: 0.9em;
display: inline-block;
margin: 5px 0;
}
.feature-list {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
gap: 10px;
margin: 20px 0;
}
.feature {
background: rgba(255, 255, 255, 0.05);
padding: 12px;
border-radius: 8px;
text-align: center;
}
.footer {
text-align: center;
margin-top: 30px;
padding-top: 20px;
border-top: 1px solid rgba(255, 255, 255, 0.1);
color: #a0aec0;
}
.badge {
display: inline-block;
padding: 4px 10px;
background: rgba(102, 126, 234, 0.3);
border-radius: 20px;
font-size: 0.85em;
margin: 5px 5px 0 0;
}
@media (max-width: 768px) {
.container { padding: 20px; }
h1 { font-size: 2em; }
}
</style>
</head>
<body>
<div class="container">
<h1>🎬 Background Remover Pro</h1>
<div class="subtitle">AI-Powered Background Removal for Images & Videos • Hair Preservation</div>
<div class="status" id="status">
Checking API status...
</div>
<h2>📡 API Endpoints</h2>
<div class="endpoint-grid">
<div class="endpoint">
<strong>POST /api/process/image</strong>
<small>Remove background from images</small><br>
<code>model</code> <code>transparent</code> <code>max_size</code> <code>preserve_hair</code>
</div>
<div class="endpoint">
<strong>POST /api/process/video</strong>
<small>Process videos (up to 1080p)</small><br>
<code>model</code> <code>transparent</code> <code>max_size</code> <code>fps</code> <code>format</code>
</div>
<div class="endpoint">
<strong>GET /api/health</strong>
<small>System health and model status</small><br>
<code>Returns JSON</code>
</div>
<div class="endpoint">
<strong>GET /api/models</strong>
<small>List available AI models</small><br>
<code>Returns JSON</code>
</div>
</div>
<h2>🎯 Features</h2>
<div class="feature-list">
<div class="feature">✨ AI Background Removal</div>
<div class="feature">🎥 Video Processing</div>
<div class="feature">🖼️ Image Processing</div>
<div class="feature">🔮 Full Transparency Support</div>
<div class="feature">🪄 Hair Preservation</div>
<div class="feature">🛡️ Edge Protection</div>
<div class="feature">📊 Multiple Formats</div>
<div class="feature">🚀 Up to 4K Resolution</div>
</div>
<h2>🤖 Available Models</h2>
<div style="margin: 15px 0;" id="models-list">
Loading models...
</div>
<div class="footer">
<p>🚀 Ready to process your media • High-performance AI • Production-ready</p>
<p style="margin-top: 10px; font-size: 0.9em;">
<a href="/docs" style="color: #a78bfa; text-decoration: none;">Interactive API Docs</a> •
<a href="/redoc" style="color: #a78bfa; text-decoration: none;">ReDoc</a>
</p>
</div>
</div>
<script>
fetch('/api/health')
.then(r => r.json())
.then(data => {
const statusDiv = document.getElementById('status');
if (data.ai_available) {
statusDiv.innerHTML = '✅ API Operational • Version 10.0.0 • ' + data.models_count + ' models loaded • Hair Preservation Active';
statusDiv.className = 'status';
} else {
statusDiv.innerHTML = '⚠️ API Running (AI models loading...) • Version 10.0.0';
statusDiv.className = 'status error';
}
fetch('/api/models')
.then(r => r.json())
.then(modelData => {
const modelsDiv = document.getElementById('models-list');
let html = '';
for (const [name, info] of Object.entries(modelData.models)) {
html += `<span class="badge">${name} - ${info.description}</span>`;
}
modelsDiv.innerHTML = html || 'No models loaded yet. Please wait...';
});
})
.catch(e => {
document.getElementById('status').innerHTML = '❌ API Error • Check logs';
document.getElementById('status').className = 'status error';
});
</script>
</body>
</html>
"""


@app.get("/")
async def home():
    """Serve the static HTML landing page."""
    return HTMLResponse(content=_HOME_PAGE_HTML)
@app.get("/api/health")
async def health_check():
    """Report service status, loaded models, advertised capabilities and defaults."""
    loaded_names = list(models.keys())
    interpreter = ".".join(str(part) for part in sys.version_info[:3])
    capability_flags = {
        "hair_preservation": True,
        "edge_protection": True,
        "confidence_thresholding": True,
        "alpha_channel": True,
        "stable_removal": True,
    }
    return {
        "status": "healthy" if AI_AVAILABLE else "degraded",
        "ai_available": AI_AVAILABLE,
        "models_loaded": loaded_names,
        "models_count": len(models),
        "video_formats": ["mp4", "avi", "webm", "mov"],
        "image_formats": ["png", "jpg", "jpeg", "webp", "bmp"],
        "supported_resolutions": ["144p", "240p", "360p", "480p", "720p", "1080p", "4K"],
        "max_resolution": "4K",
        "version": "10.0.0",
        "python_version": interpreter,
        "default_model": get_default_model(),
        "default_video_model": get_default_video_model(),
        "features": capability_flags,
    }
@app.get("/api/models")
async def list_models():
    """List loaded models (description and size, without sessions) and the default picks."""
    model_info = {
        name: {
            "description": entry["description"],
            "size": entry["size"],
            "available": True,
        }
        for name, entry in models.items()
    }
    return {
        "available": AI_AVAILABLE,
        "models": model_info,
        "default_image_model": get_default_model(),
        "default_video_model": get_default_video_model(),
    }
@app.post("/api/process/image")
async def process_image_api(
    file: UploadFile = File(...),
    model: str = Query(None, description="AI model to use (auto-select if not specified)"),
    transparent: bool = Query(True, description="Keep transparent background (PNG format)"),
    max_size: int = Query(2048, ge=512, le=4096, description="Maximum dimension (higher = better quality)"),
    quality: int = Query(95, ge=70, le=100, description="Output quality (JPEG only)"),
    post_process: bool = Query(True, description="Apply edge refinement and sharpening"),
    preserve_hair: bool = Query(True, description="Preserve hair strands and fine details")
):
    """Remove the background from an uploaded image and stream the result back.

    Processing runs in ``cpu_executor`` via :func:`process_image_sync`. The
    response carries timing, model and size details in ``X-*`` headers.

    Raises:
        HTTPException: 400 for an invalid, empty or >30 MB upload; 503 while
            no model is loaded; 500 for any other processing failure.
    """
    from urllib.parse import quote

    start_time = time.time()
    max_bytes = 30 * 1024 * 1024
    try:
        validate_image_file(file)
        # Read at most one byte past the limit, so an oversized upload is
        # rejected without being pulled into memory in full.
        contents = await file.read(max_bytes + 1)
        if len(contents) == 0:
            raise HTTPException(400, "Empty file")
        if len(contents) > max_bytes:
            raise HTTPException(400, "File too large (max 30MB)")
        logger.info(f"Processing image: {file.filename} ({len(contents)/1024:.1f}KB) | preserve_hair={preserve_hair}")
        if not AI_AVAILABLE:
            raise HTTPException(503, "AI service temporarily unavailable. Please try again in a moment.")

        loop = asyncio.get_running_loop()
        process_func = functools.partial(
            process_image_sync,
            image_data=contents,
            model_name=model,
            transparent=transparent,
            max_size=max_size,
            quality=quality,
            post_process=post_process,
            preserve_hair=preserve_hair
        )
        result = await loop.run_in_executor(cpu_executor, process_func)
        processing_time = time.time() - start_time

        # Header values must be latin-1 encodable. Send a sanitised ASCII
        # `filename` plus an RFC 5987 `filename*` carrying the full UTF-8 name,
        # instead of failing with a 500 on non-ASCII upload names.
        stem = Path(file.filename or "image").stem
        download_name = f"processed_{stem}.{result['extension']}"
        ascii_name = "".join(
            ch for ch in download_name if 32 <= ord(ch) < 127 and ch not in '"\\'
        )
        content_disposition = (
            f'attachment; filename="{ascii_name}"; '
            f"filename*=UTF-8''{quote(download_name, safe='')}"
        )

        headers = {
            "X-Processing-Time": f"{processing_time:.3f}",
            "X-Model-Used": result["model"],
            "X-Transparent": str(result["transparent"]).lower(),
            "X-Original-Size": f"{result['original_size'][0]}x{result['original_size'][1]}",
            "X-Processed-Size": f"{result['processed_size'][0]}x{result['processed_size'][1]}",
            "X-Hair-Preservation": str(preserve_hair),
            "Content-Disposition": content_disposition
        }
        logger.info(f"✅ Image processed in {processing_time:.2f}s using {result['model']}")
        return StreamingResponse(
            io.BytesIO(result["data"]),
            media_type=result["media_type"],
            headers=headers
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Image API error: {str(e)}")
        raise HTTPException(500, f"Image processing failed: {str(e)[:100]}")
@app.post("/api/process/video")
async def process_video_api(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(...),
    model: str = Query(None, description="AI model for processing (auto-select if not specified)"),
    transparent: bool = Query(True, description="Keep transparent background"),
    max_size: int = Query(360, ge=144, le=1080, description="Maximum resolution"),
    fps: int = Query(10, ge=5, le=30, description="Output frame rate"),
    frame_skip: int = Query(1, ge=0, le=5, description="Skip frames (0=none)"),
    format: str = Query("mp4", description="Output format (mp4, avi, webm, mov)")
):
    """Remove the background from an uploaded video and return the encoded file.

    Handling of temporary files:

    * The upload is stored under /tmp/uploads and the result under
      /tmp/outputs.
    * Both are deleted by background tasks after the response is sent, or
      immediately if processing fails.
    * Transparent MP4/MOV requests are produced as WebM.

    Raises:
        HTTPException: 400 in any of these cases:

            * invalid upload;
            * unsupported output format;
            * empty upload;
            * upload larger than 100 MB.

            503 while no model is loaded. 500 for any other failure.
    """
    start_time = time.time()
    max_bytes = 100 * 1024 * 1024
    try:
        validate_video_file(file)
        output_ext = format.lower()
        # `format` ends up in a temp-file suffix and selects the encoder, so
        # only accept the formats the service advertises.
        if output_ext not in ("mp4", "avi", "webm", "mov"):
            raise HTTPException(400, f"Unsupported output format: {format}")

        # Bounded read: never hold more than the limit (+1 byte) in memory.
        contents = await file.read(max_bytes + 1)
        if len(contents) == 0:
            raise HTTPException(400, "Empty file")
        if len(contents) > max_bytes:
            raise HTTPException(400, "File too large (max 100MB)")
        logger.info(f"Processing video: {file.filename} ({len(contents)/1024/1024:.1f}MB)")
        if not AI_AVAILABLE:
            raise HTTPException(503, "AI service temporarily unavailable. Please try again in a moment.")

        input_path = None
        output_path = None
        try:
            # mkstemp creates the files atomically; tempfile.mktemp is racy.
            input_suffix = Path(file.filename or "").suffix or '.mp4'
            fd, input_path = tempfile.mkstemp(suffix=input_suffix, dir="/tmp/uploads")
            with os.fdopen(fd, 'wb') as f:
                f.write(contents)
            del contents  # release the upload buffer before the long processing step

            if transparent and output_ext in ('mp4', 'mov'):
                output_ext = 'webm'
                format = 'webm'
                logger.info("Switched to WebM format for transparency support")
            fd, output_path = tempfile.mkstemp(suffix=f'.{output_ext}', dir="/tmp/outputs")
            os.close(fd)

            loop = asyncio.get_running_loop()
            process_func = functools.partial(
                process_video_reliable,
                input_path=input_path,
                output_path=output_path,
                model_name=model,
                max_size=max_size,
                fps=fps,
                frame_skip=frame_skip,
                transparent=transparent,
                format=format
            )
            result = await loop.run_in_executor(cpu_executor, process_func)
            processing_time = time.time() - start_time

            background_tasks.add_task(cleanup_temp_files, input_path)
            background_tasks.add_task(cleanup_temp_files, result["output_path"])
            if result["output_path"] != output_path:
                # The placeholder created by mkstemp was not used.
                background_tasks.add_task(cleanup_temp_files, output_path)

            media_types = {
                "mp4": "video/mp4",
                "avi": "video/x-msvideo",
                "webm": "video/webm",
                "mov": "video/quicktime"
            }
            media_type = media_types.get(result["format"].lower(), "video/mp4")
            filename = Path(file.filename or "video").stem
            # No manual Content-Disposition: FileResponse builds it from
            # `filename`, including RFC 5987 encoding for non-ASCII names. A
            # header given here would override that and could fail latin-1
            # encoding.
            headers = {
                "X-Processing-Time": f"{processing_time:.3f}",
                "X-Model-Used": model or get_default_video_model(),
                "X-Transparent": str(transparent).lower(),
                "X-Video-Frames": str(result["processed_frames"]),
                "X-Video-Dimensions": result["dimensions"],
                "X-Video-Format": result["format"],
                "X-Video-Resolution": f"{max_size}p",
                "X-Video-FPS": str(result["fps"]),
                "X-File-Size-MB": f"{result['file_size_mb']:.2f}",
            }
            logger.info(f"✅ Video processed in {processing_time:.1f}s: {result['dimensions']} @ {result['fps']}fps")
            return FileResponse(
                result["output_path"],
                media_type=media_type,
                filename=f"processed_{filename}.{result['format']}",
                headers=headers,
                background=background_tasks
            )
        except Exception:
            cleanup_temp_files(input_path, output_path)
            raise
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Video API error: {str(e)}")
        raise HTTPException(500, f"Video processing failed: {str(e)[:100]}")
@app.exception_handler(HTTPException)
async def http_exception_handler(request, exc):
    """Render an HTTPException as JSON ``{error, status_code, path}``.

    Headers attached to the exception (for example ``Retry-After`` or
    ``WWW-Authenticate``) are forwarded. Previously they were dropped.
    """
    return JSONResponse(
        status_code=exc.status_code,
        content={
            "error": exc.detail,
            "status_code": exc.status_code,
            "path": str(request.url.path)
        },
        headers=getattr(exc, "headers", None)
    )
@app.exception_handler(Exception)
async def general_exception_handler(request, exc):
    """Turn any unhandled exception into a JSON 500 response.

    The full traceback is logged; before, only ``str(exc)`` was logged. The
    response includes the first 200 characters of the message.
    NOTE(review): that message may reveal internals to clients; consider
    dropping ``detail``.
    """
    logger.error(f"Unhandled exception: {str(exc)}", exc_info=exc)
    return JSONResponse(
        status_code=500,
        content={
            "error": "Internal server error",
            "detail": str(exc)[:200],
            "status_code": 500,
            "path": str(request.url.path)
        }
    )
@app.on_event("startup")
async def startup_event():
    """Log a start-up banner with model availability and the interpreter version."""
    separator = "=" * 50
    banner = (
        separator,
        "🚀 Background Remover Pro v10.0.0 Starting...",
        f"📊 AI Available: {AI_AVAILABLE}",
        f"🤖 Models Loaded: {list(models.keys())}",
        f"🐍 Python Version: {sys.version}",
        "✨ Features: Hair Preservation, Edge Protection, Stable Removal",
        separator,
    )
    for line in banner:
        logger.info(line)
@app.on_event("shutdown")
async def shutdown_event():
    """Stop the worker pool and remove the upload/output scratch directories.

    Directory removal is best effort: failures are logged, not raised. The
    previous bare ``except: pass`` also swallowed KeyboardInterrupt and
    SystemExit and hid every error.
    """
    logger.info("Shutting down...")
    cpu_executor.shutdown(wait=True)
    for dir_path in ["/tmp/uploads", "/tmp/outputs"]:
        if not os.path.exists(dir_path):
            continue
        try:
            shutil.rmtree(dir_path)
        except OSError as e:
            logger.warning(f"Failed to remove {dir_path}: {e}")
if __name__ == "__main__":
    import uvicorn

    # Serve this module's `app` on all interfaces; PORT defaults to 7860.
    server_options = {
        "host": "0.0.0.0",
        "port": int(os.environ.get("PORT", 7860)),
        "timeout_keep_alive": 300,
        "log_level": "info",
        "reload": False,
    }
    uvicorn.run("app:app", **server_options)