# Compositor / main.py
# Khan19970's picture
# Update main.py
# ae8fd5e verified
import io
import logging
import traceback
import numpy as np
import cv2
import torch
from PIL import Image, ImageEnhance
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from skimage import color
# ──────────────────────────────────────────────────────────────────
# INITIALIZATION & CONFIG
# ──────────────────────────────────────────────────────────────────
# Timestamped, level-tagged log lines; INFO and above are emitted.
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
logger = logging.getLogger(__name__)
app = FastAPI(title="Automotive Compositor API - Spyne Pro Edition", version="5.0.0")
# NOTE(review): CORS accepts any origin, method, and header — confirm this is intended for the deployment.
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
# Cache of loaded model bundles keyed by model name; populated lazily by get_model().
_models: dict = {}
# ──────────────────────────────────────────────────────────────────
# MODEL MANAGEMENT (Optimized for HuggingFace Spaces)
# ──────────────────────────────────────────────────────────────────
def get_model(name: str):
    """Return the cached bundle for model *name*, loading it on first use.

    Supported names and the keys of the bundle each one yields:

    * ``"birefnet"`` -> ``model``, ``transform``, ``device``
    * ``"yolo_cls"`` -> ``model``
    * ``"depth"``    -> ``model``, ``transform``, ``device``

    Models are placed on CUDA when available (otherwise CPU) and cast to
    float32 to avoid half/full precision mismatches at inference time.

    Raises:
        KeyError: if *name* is neither cached nor one of the supported names.
    """
    if name in _models:
        return _models[name]

    # Fail with a descriptive message instead of falling through every branch
    # and surfacing an opaque KeyError on the final cache lookup.
    if name not in ("birefnet", "yolo_cls", "depth"):
        raise KeyError(
            f"Unknown model name: {name!r} (expected 'birefnet', 'yolo_cls' or 'depth')"
        )

    device = "cuda" if torch.cuda.is_available() else "cpu"

    if name == "birefnet":
        logger.info("Loading BiRefNet for Segmentation...")
        from transformers import AutoModelForImageSegmentation
        from torchvision import transforms

        # After loading, explicitly force float32 and move to CUDA when available.
        model = AutoModelForImageSegmentation.from_pretrained(
            "ZhengPeng7/BiRefNet_dynamic", trust_remote_code=True
        )
        model.to(device).eval().float()  # Force FP32 to avoid runtime mismatch errors
        transform = transforms.Compose([
            transforms.Resize((1024, 1024)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
        _models[name] = {"model": model, "transform": transform, "device": device}
    elif name == "yolo_cls":
        logger.info("Loading YOLOv8 Classification...")
        from ultralytics import YOLO

        # Move YOLO to the device as soon as it is loaded and force float32.
        model = YOLO("yolov8n-cls.pt")
        model.to(device).float()  # Force FP32
        _models[name] = {"model": model}
    else:  # name == "depth"
        logger.info("Loading Depth Estimator (MiDaS DPT)...")
        # trust_repo=True avoids the interactive security prompt; force float32.
        midas = torch.hub.load("intel-isl/MiDaS", "MiDaS", trust_repo=True)
        midas.to(device).eval().float()  # Force FP32
        midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms", trust_repo=True)
        _models[name] = {
            "model": midas,
            "transform": midas_transforms.default_transform,
            "device": device,
        }

    return _models[name]
# ──────────────────────────────────────────────────────────────────
# INTELLIGENCE & GEOMETRY
# ──────────────────────────────────────────────────────────────────
def classify_vehicle(pil_img: Image.Image) -> str:
    """Classify the vehicle's ground clearance as ``"high"`` or ``"low"``.

    Runs the cached YOLO classifier on *pil_img* and returns ``"high"`` when
    the top-1 class name contains an SUV/truck-style keyword, else ``"low"``.
    Any failure is logged and treated as ``"low"`` (best-effort).
    """
    try:
        classifier = get_model("yolo_cls")["model"]
        classifier.model.float()  # Safety cast to FP32 before inference
        top_result = classifier(pil_img, half=False, verbose=False)[0]
        label = top_result.names[top_result.probs.top1].lower()
        for keyword in ('suv', 'truck', 'pickup', 'bus', 'van', 'jeep'):
            if keyword in label:
                return "high"
        return "low"
    except Exception as e:
        logger.warning(f"Classification failed: {e}. Defaulting to low clearance.")
        return "low"
def refine_mask(mask: np.ndarray) -> np.ndarray:
    """Clean up a segmentation mask and return a feathered uint8 alpha channel.

    Steps: threshold at 128, morphological close (5x5 ellipse) then open
    (3x3 ellipse), 3x3 Gaussian blur, and a 1.2 power curve that steepens
    the edge falloff. The result is clipped to 0..255.
    """
    # Hard 0/255 mask so the morphology operates on a clean binary image.
    hard = np.where(mask > 128, 255, 0).astype(np.uint8)

    close_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    open_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
    # Close first (bridges small gaps), then open (drops small floating specks).
    hard = cv2.morphologyEx(hard, cv2.MORPH_CLOSE, close_kernel)
    hard = cv2.morphologyEx(hard, cv2.MORPH_OPEN, open_kernel)

    # Gentle feathering on a 0..1 float mask.
    soft = cv2.GaussianBlur(hard.astype(np.float32) / 255.0, (3, 3), 0)

    # Exponent > 1 pushes partial alphas down, tightening the edge.
    alpha = np.power(soft, 1.2) * 255
    return np.clip(alpha, 0, 255).astype(np.uint8)
def estimate_ground_plane(bg_pil: Image.Image) -> float:
    """Estimate where the ground sits in the background, as a height ratio.

    Runs the cached depth model on *bg_pil*, finds the row in the lower half
    of the depth map with the strongest mean vertical gradient, and returns
    that row plus 10% of the image height, as a fraction of the height,
    capped at 0.90.

    Any failure is logged and the fallback ratio 0.85 is returned
    (best-effort).
    """
    try:
        bundle = get_model("depth")
        # The MiDaS hub transforms expect an RGB array. A PIL image already
        # yields RGB, so no channel swap is applied (swapping to BGR would
        # feed the model red/blue-reversed input).
        img_rgb = np.array(bg_pil.convert("RGB"))
        # Explicitly cast the input to float32.
        input_batch = bundle["transform"](img_rgb).to(bundle["device"]).float()
        with torch.no_grad():
            bundle["model"].float()  # Safety cast
            prediction = bundle["model"](input_batch)
            # Resize the prediction back to the background's (height, width).
            prediction = torch.nn.functional.interpolate(
                prediction.unsqueeze(1),
                size=bg_pil.size[::-1],
                mode="bicubic",
                align_corners=False,
            ).squeeze()
        depth_map = prediction.cpu().numpy()

        # Only the lower half is searched for the ground transition.
        h, _w = depth_map.shape
        half = int(h * 0.5)
        gradient_y = cv2.Sobel(depth_map[half:, :], cv2.CV_64F, 0, 1, ksize=3)

        # Row with the highest mean absolute vertical gradient.
        y_profile = np.abs(gradient_y).mean(axis=1)
        peak_y = int(np.argmax(y_profile)) + half

        # Place the car slightly below the transition, never lower than 90%.
        return float(min((peak_y + int(h * 0.1)) / h, 0.90))
    except Exception as e:
        logger.warning(f"Depth estimation failed: {e}. Falling back to 0.85.")
        return 0.85
def apply_multiply_shadow(bg_rgb: np.ndarray, shadow_mask: np.ndarray, base_color: tuple = (20, 25, 30)) -> np.ndarray:
    """Darken *bg_rgb* with a multiply blend driven by *shadow_mask*.

    Args:
        bg_rgb: background image array (H x W x channels).
        shadow_mask: H x W shadow strength, 0 (none) to 255 (full).
        base_color: shadow tint in 0..255 per channel.

    Returns:
        A new uint8 array; the input background is not modified.
    """
    # Per-pixel shadow coverage in 0..1, with a trailing axis for broadcasting.
    coverage = np.clip(shadow_mask.astype(np.float32) / 255.0, 0, 1)[:, :, None]
    tint = np.full_like(bg_rgb, base_color, dtype=np.float32) / 255.0

    # factor = (1 - coverage) + tint * coverage: 1.0 outside the shadow,
    # the tint inside it, so background texture survives under the shadow.
    factor = (1.0 - coverage) + (tint * coverage)

    # Work in float32 to avoid uint8 overflow, then clip back.
    shaded = bg_rgb.astype(np.float32) * factor
    return np.clip(shaded, 0, 255).astype(np.uint8)
# ──────────────────────────────────────────────────────────────────
# CORE PIPELINE ENGINE
# ──────────────────────────────────────────────────────────────────
def harmonized_color_lab(car_rgb: np.ndarray, car_mask: np.ndarray, bg_rgb: np.ndarray) -> np.ndarray:
    """Nudge the car's colour and brightness toward the background ambience.

    Works in LAB space. Pixels where ``car_mask > 0.5`` are shifted toward the
    mean of the bottom 40% of the background: 30% of the way on the a/b
    (colour) channels and 15% of the way on L (lightness, clipped to 0..100).
    Pixels outside the mask keep their LAB values.

    Args:
        car_rgb: H x W x 3 car image.
        car_mask: H x W mask; values above 0.5 count as car.
        bg_rgb: background image the car will be placed on.

    Returns:
        The harmonised uint8 RGB image, or *car_rgb* itself (unchanged) when
        the mask selects no pixels.
    """
    # Compute the mask once and bail out before the expensive LAB conversions.
    inside = car_mask > 0.5
    if not inside.any():
        return car_rgb

    # Ambient reference: bottom 40% of the background, where the car will sit.
    bg_h = bg_rgb.shape[0]
    bg_lab = color.rgb2lab(bg_rgb[int(bg_h * 0.6):, :])
    car_lab = color.rgb2lab(car_rgb)

    car_l_mean, car_a_mean, car_b_mean = car_lab[inside].mean(axis=0)
    bg_l_mean, bg_a_mean, bg_b_mean = bg_lab.mean(axis=(0, 1))

    # Gentle colour-temperature shift (30% strength) on masked pixels only.
    car_lab[inside, 1] += (bg_a_mean - car_a_mean) * 0.3
    car_lab[inside, 2] += (bg_b_mean - car_b_mean) * 0.3

    # Gentle luminance shift (15% strength), kept inside the valid L range.
    l_shift = (bg_l_mean - car_l_mean) * 0.15
    car_lab[inside, 0] = np.clip(car_lab[inside, 0] + l_shift, 0, 100)

    # Back to 8-bit RGB.
    harmonized_rgb = color.lab2rgb(car_lab) * 255.0
    return np.clip(harmonized_rgb, 0, 255).astype(np.uint8)
def _shift_rows_down(layer: np.ndarray, rows: int) -> np.ndarray:
    """Return *layer* moved down by *rows*, zero-filling the vacated top rows."""
    shifted = np.zeros_like(layer)
    height = layer.shape[0]
    if rows < height:
        shifted[rows:] = layer[:height - rows]
    return shifted


def generate_dealership_shadows(bg_np: np.ndarray, car_alpha: np.ndarray, pos: tuple, v_type: str) -> np.ndarray:
    """Render contact and ambient shadows under the car onto the background.

    The bottom 30% of the car's alpha mask is stamped onto a background-sized
    canvas, shifted down and blurred twice (a tight contact shadow and a wide
    ambient pool), and the combination is multiply-blended into *bg_np*.

    Args:
        bg_np: background image array (H x W x 3).
        car_alpha: 2-D uint8 alpha mask of the positioned car.
        pos: (x, y) of the car's top-left corner on the background.
        v_type: clearance class of the vehicle. Currently unused; kept for
            interface compatibility.

    Returns:
        A new uint8 background array with the shadows applied.
    """
    bg_h, bg_w = bg_np.shape[:2]
    ch, cw = car_alpha.shape[:2]
    px, py = pos
    mask_canvas = np.zeros((bg_h, bg_w), dtype=np.float32)

    # Only the bottom 30% of the mask casts a shadow, so the roof doesn't
    # produce a halo.
    crop_h = int(ch * 0.30)
    y_start = max(py + ch - crop_h, 0)
    y_end = min(py + ch, bg_h)
    x_start = max(px, 0)
    x_end = min(px + cw, bg_w)

    if y_end > y_start and x_end > x_start:
        # Index the car mask with offsets relative to its own origin so the
        # stamp stays aligned even when the car overhangs a canvas edge.
        alpha_crop = car_alpha[y_start - py:y_end - py, x_start - px:x_end - px]
        mask_canvas[y_start:y_end, x_start:x_end] = alpha_crop / 255.0

    # 1. Contact shadow: tight, dark band just below the tyres (~1.5% shift).
    #    Shifting zero-fills instead of wrapping bottom rows back to the top.
    shift_c = max(int(ch * 0.015), 2)
    blur_c = int(cw * 0.02) | 1  # kernel sizes must be odd
    contact = cv2.GaussianBlur(_shift_rows_down(mask_canvas, shift_c), (blur_c, blur_c), 0)

    # 2. Ambient undercarriage shadow: wide, soft pool (~4% shift) with a
    #    large horizontal and small vertical blur.
    shift_a = max(int(ch * 0.04), 5)
    blur_ax = int(cw * 0.12) | 1
    blur_ay = int(ch * 0.05) | 1
    ambient = cv2.GaussianBlur(_shift_rows_down(mask_canvas, shift_a), (blur_ax, blur_ay), 0)

    # 3. Combine and multiply-blend with a cool slate-grey base colour.
    combined = (ambient * 0.5) + (contact * 0.9)
    shadow_mask = (np.clip(combined, 0, 1) * 255).astype(np.uint8)
    return apply_multiply_shadow(bg_np, shadow_mask, base_color=(15, 20, 25))
def generate_showroom_reflection(bg_np: np.ndarray, car_rgba: Image.Image, pos: tuple) -> Image.Image:
    """Build a transparent, background-sized layer holding the car's floor reflection.

    The car is flipped vertically, squashed to 35% of its height, faded from
    40% opacity at the top to fully transparent at the bottom, blurred, and
    pasted two pixels above the car's bottom edge. If that anchor row lies
    at or beyond the background's bottom edge, the layer is returned empty.
    """
    bg_h, bg_w = bg_np.shape[:2]
    cw, ch = car_rgba.size
    px, py = pos

    # Mirror the tightly cropped car and squash it for perspective.
    ref_h = int(ch * 0.35)
    mirrored = car_rgba.transpose(Image.FLIP_TOP_BOTTOM).resize((cw, ref_h), Image.LANCZOS)
    ref_np = np.array(mirrored)

    # Vertical fade-out: a column vector broadcast across the full width.
    fade = np.linspace(1.0, 0.0, ref_h)[:, None]
    ref_np[..., 3] = (ref_np[..., 3] * fade * 0.40).astype(np.uint8)

    # Gaussian blur, 7 px wide and ~10% of the reflection height tall (odd).
    ksize = (7, int(ref_h * 0.1) | 1)
    soft_rgb = cv2.GaussianBlur(ref_np[..., :3], ksize, 0)
    soft_alpha = cv2.GaussianBlur(ref_np[..., 3], ksize, 0)
    reflection = Image.fromarray(np.dstack([soft_rgb, soft_alpha]), "RGBA")

    layer = Image.new("RGBA", (bg_w, bg_h), (0, 0, 0, 0))

    # py + ch is the car's bottom edge; overlap it by 2 px so the shadow and
    # reflection seams fuse cleanly.
    target_y = py + ch - 2
    if target_y >= bg_h:
        return layer
    layer.paste(reflection, (px, target_y))
    return layer
def auto_position_car(car_rgba: Image.Image, bg: Image.Image, ground_y_ratio: float):
    """Scale the car to fit the background and compute where to paste it.

    The car is scaled to 72% of the background width, or less if that would
    exceed 60% of the background height. It is centred horizontally with its
    bottom edge on ``ground_y_ratio * background height``, then clamped so
    its top is at least 15% down and its bottom leaves a 5% margin.

    Returns:
        ``(resized_car, (x, y))`` where (x, y) is the top-left paste position.
    """
    bg_w, bg_h = bg.size
    cw, ch = car_rgba.size

    # 1. Fixed scale: target 72% of the background width...
    target_w = int(bg_w * 0.72)
    scale = target_w / cw
    # ...but cap the height at 60% so the car keeps breathing room up top.
    max_height = bg_h * 0.60
    if ch * scale > max_height:
        scale = max_height / ch
        target_w = int(cw * scale)
    target_h = int(ch * scale)
    car_res = car_rgba.resize((target_w, target_h), Image.LANCZOS)

    # The cropped car's bottom edge is the tyre line, so it sits on ground_y.
    px = (bg_w - target_w) // 2
    py = int(bg_h * ground_y_ratio) - target_h

    # 2. Safety clamp: bottom margin first, then the top limit takes priority.
    lowest_top = bg_h - target_h - int(bg_h * 0.05)
    highest_top = int(bg_h * 0.15)
    if py > lowest_top:
        py = lowest_top
    if py < highest_top:
        py = highest_top
    return car_res, (px, py)
def run_pipeline(car_pil: Image.Image, bg_pil: Image.Image) -> Image.Image:
    """Composite a car photo onto a background and return the final RGB image.

    Stages: classify the vehicle, estimate the ground line, segment the car,
    refine and tightly crop the cutout, position it, relight it in LAB space,
    render shadows and a floor reflection, paste the car, and apply a 5%
    contrast boost.

    Args:
        car_pil: photo containing the vehicle (any PIL mode).
        bg_pil: background scene (any PIL mode).

    Returns:
        The composited image in RGB mode, 1920 pixels wide.
    """
    # Normalise both inputs to 3-channel RGB up front: the RGBA assembly below
    # stacks exactly three colour channels with one alpha channel, so RGBA,
    # greyscale, or palette inputs would otherwise crash.
    car_pil = car_pil.convert("RGB")
    bg_pil = bg_pil.convert("RGB")

    # Standardize resolutions (high-res output), preserving aspect ratio.
    car_pil = car_pil.resize((1536, int(1536 * car_pil.height / car_pil.width)), Image.LANCZOS)
    bg_pil = bg_pil.resize((1920, int(1920 * bg_pil.height / bg_pil.width)), Image.LANCZOS)

    logger.info("1. Classifying Vehicle Geometry...")
    v_type = classify_vehicle(car_pil)

    logger.info("2. Estimating Scene Depth & Ground Plane...")
    ground_ratio = estimate_ground_plane(bg_pil)

    logger.info("3. Executing BiRefNet Segmentation...")
    bundle = get_model("birefnet")
    inp = bundle["transform"](car_pil).unsqueeze(0).to(bundle["device"]).float()
    with torch.no_grad():
        bundle["model"].float()  # FP32 safety cast
        preds = bundle["model"](inp)
        raw_mask = torch.sigmoid(preds[-1]).squeeze().cpu().numpy()
    # Bring the mask back to the car image's (width, height) as 0..255.
    raw_mask = (cv2.resize(raw_mask, car_pil.size) * 255).astype(np.uint8)

    logger.info("4. Refining Mask & Edges...")
    refined_alpha = refine_mask(raw_mask)

    # Combine colour and alpha into the initial RGBA cutout.
    car_rgba_temp = Image.fromarray(np.dstack([np.array(car_pil), refined_alpha]), "RGBA")

    # Strict crop: strip the transparent padding so the image bounds coincide
    # with the vehicle itself for the geometry calculations that follow.
    ys, xs = np.where(refined_alpha > 10)  # threshold for "solid" pixels
    if len(ys) > 0 and len(xs) > 0:
        strict_bbox = (np.min(xs), np.min(ys), np.max(xs) + 1, np.max(ys) + 1)
        car_rgba_temp = car_rgba_temp.crop(strict_bbox)
    else:
        logger.warning("Segmentation found no foreground pixels; the composite will contain no vehicle.")

    logger.info("5. Calculating Perspective Position...")
    # After the crop, py + ch is the lowest point of the vehicle.
    car_positioned, pos = auto_position_car(car_rgba_temp, bg_pil, ground_ratio)

    logger.info("6. Applying LAB Ambient Relighting...")
    c_arr = np.array(car_positioned)
    car_rgb = c_arr[..., :3]
    car_alpha = c_arr[..., 3]
    bg_np = np.array(bg_pil)
    harmonized_rgb = harmonized_color_lab(car_rgb, car_alpha / 255.0, bg_np)
    car_final = Image.fromarray(np.dstack([harmonized_rgb, car_alpha]), "RGBA")

    logger.info("7. Rendering Physical Shadows (Alpha-Shift)...")
    bg_with_shadows = generate_dealership_shadows(bg_np, car_alpha, pos, v_type)
    bg_layered = Image.fromarray(bg_with_shadows, "RGB").convert("RGBA")

    logger.info("8. Generating Showroom Floor Reflections...")
    reflection_layer = generate_showroom_reflection(bg_np, car_final, pos)
    bg_layered = Image.alpha_composite(bg_layered, reflection_layer)

    logger.info("9. Finalizing Composition...")
    # Paste the car last so it sits above the shadows and reflection.
    bg_layered.paste(car_final, pos, car_final)

    # Final contrast boost (5%) applied to the whole frame.
    enhancer = ImageEnhance.Contrast(bg_layered.convert("RGB"))
    return enhancer.enhance(1.05)
# ──────────────────────────────────────────────────────────────────
# API ENDPOINTS (FastAPI async structure preserved)
# ──────────────────────────────────────────────────────────────────
@app.post("/composite")
async def composite(car_image: UploadFile = File(...), background_image: UploadFile = File(...)):
    """Composite the uploaded car onto the uploaded background.

    Returns the result as a streamed JPEG (quality 95, no chroma subsampling).

    Raises:
        HTTPException: 400 when either upload cannot be decoded as an image;
            500 when the compositing pipeline itself fails.
    """
    # Decode separately from the pipeline so a bad upload is reported as a
    # client error rather than an internal server error.
    try:
        c_pil = Image.open(io.BytesIO(await car_image.read())).convert("RGB")
        b_pil = Image.open(io.BytesIO(await background_image.read())).convert("RGB")
    except (OSError, ValueError, Image.DecompressionBombError) as e:
        logger.warning("Rejected undecodable upload: %s", e)
        raise HTTPException(status_code=400, detail=f"Invalid image upload: {e}") from e

    try:
        result = run_pipeline(c_pil, b_pil)
        buf = io.BytesIO()
        # High quality JPEG output with preserved resolution.
        result.save(buf, format="JPEG", quality=95, subsampling=0)
        buf.seek(0)
        return StreamingResponse(buf, media_type="image/jpeg")
    except Exception as e:
        logger.error(f"Pipeline Failure: {traceback.format_exc()}")
        raise HTTPException(status_code=500, detail=str(e)) from e