import os
import numpy as np
import torch
import torchvision
import gradio as gr
import spaces
from PIL import Image
import traceback
import warnings
import cv2
import random
from contextlib import nullcontext
from utils import (
    base64_to_image, logger, convert_to_3_4_aspect_ratio, image_to_base64,
    load_grounding_dino, DEVICE, OUTPUT_DIR, CONFIG_FILE, DINO_CKPT,
    transform_image, get_grounding_output, create_ghost_image,
    create_overlay_image, process_mask, load_swinir_x3, upscale_tiled_bgr,
    segment_image_on_white_background,
)
from segment_anything import SamPredictor, build_sam
from huggingface_hub import hf_hub_download
# FLUX imports for background generation
from diffusers import FluxFillPipeline, FluxTransformer2DModel
# SAM checkpoint with fallback
try:
SAM_CKPT = hf_hub_download("segments-arnaud/sam_vit_h", "sam_vit_h_4b8939.pth")
except Exception:  # Exception already covers FileNotFoundError; fall back to a direct download
import requests
SAM_CKPT = OUTPUT_DIR / "sam_vit_h_4b8939.pth"
if not SAM_CKPT.exists():
logger.info("Downloading SAM checkpoint from Meta's CDN (~2.6 GB)")
url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
        with requests.get(url, stream=True) as r, open(SAM_CKPT, "wb") as f:
            r.raise_for_status()  # abort on HTTP errors instead of saving an error page as the checkpoint
for chunk in r.iter_content(chunk_size=1048576):
if chunk:
f.write(chunk)
# SwinIR checkpoint path
SWINIR_CKPT = "upscaler_model/net_g_82500.pth"
# Category configurations for jewelry detection
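# Each entry maps a UI category to the GroundingDINO query phrases plus the
# cutoffs used downstream: box_threshold/text_threshold gate detections by
# confidence, nms_threshold is the IoU limit for duplicate suppression.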
CATEGORY_CONFIG = {
"Bracelets": {
"classes": ["bracelet", "wrist band", "bangle"],
"box_threshold": 0.4,
"text_threshold": 0.4,
"nms_threshold": 0.6
},
"Earrings": {
"classes": ["earring", "earrings", "stud earring"],
"box_threshold": 0.3,
"text_threshold": 0.3,
"nms_threshold": 0.65
},
"Watches": {
"classes": ["watch", "wristwatch", "smartwatch"],
"box_threshold": 0.3,
"text_threshold": 0.3,
"nms_threshold": 0.6
},
"Rings": {
"classes": ["wedding ring", "finger ring"],
"box_threshold": 0.25,
"text_threshold": 0.25,
"nms_threshold": 0.5
},
"Mixed Jewelry": {
"classes": ["ring", "wedding ring", "bracelet", "wristwatch", "wrist band", "necklace", "earring", "stud earring", "jewelry"],
"box_threshold": 0.25,
"text_threshold": 0.25,
"nms_threshold": 0.5
}
}
# Negative words to filter out
NEGATIVE_WORDS = ["hand", "face", "arm", "mouth", "lips", "finger", "teeth", "eye", "nails", "fingernail", "mole"]
# Global model caches
DINO_MODEL = None
SAM_PREDICTOR = None
FLUX_PIPELINE = None
SWINIR_UPSAMPLER = None
# ───────── Model Initialization Functions ─────────
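# Models are loaded lazily on first use and cached in the module-level
# globals above, so repeated requests reuse the same weights instead of
# reloading them on every call.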
def get_swinir_upsampler():
"""Initialize SwinIR upsampler (without GPU decorator)"""
global SWINIR_UPSAMPLER
if SWINIR_UPSAMPLER is None:
try:
if not os.path.exists(SWINIR_CKPT):
logger.error(f"SwinIR checkpoint not found at: {SWINIR_CKPT}")
return None
SWINIR_UPSAMPLER = load_swinir_x3(SWINIR_CKPT, DEVICE)
logger.info("SwinIR upsampler loaded successfully")
except Exception as e:
logger.error(f"Failed to load SwinIR upsampler: {e}")
SWINIR_UPSAMPLER = None
return SWINIR_UPSAMPLER
def get_flux_pipeline():
"""Initialize FLUX pipeline for background generation (without GPU decorator)"""
global FLUX_PIPELINE
if FLUX_PIPELINE is None:
try:
transformer = FluxTransformer2DModel.from_pretrained(
"SnapwearAI/SAKS_Background_Model",
subfolder="transformer",
torch_dtype=torch.bfloat16
)
FLUX_PIPELINE = FluxFillPipeline.from_pretrained(
"black-forest-labs/FLUX.1-Fill-dev",
transformer=transformer,
torch_dtype=torch.bfloat16
).to("cuda")
logger.info("FLUX pipeline loaded successfully")
except Exception as e:
logger.error(f"Failed to load FLUX pipeline: {e}")
FLUX_PIPELINE = None
return FLUX_PIPELINE
@spaces.GPU(duration=120) # Increased duration for the entire process
def intelligent_upscale_and_resize(bg_image, target_size):
"""Use SwinIR to upscale background, then resize to exact target size"""
global SWINIR_UPSAMPLER
try:
# Initialize upsampler if needed
if SWINIR_UPSAMPLER is None:
SWINIR_UPSAMPLER = get_swinir_upsampler()
if SWINIR_UPSAMPLER is None:
logger.warning("SwinIR upsampler not available, using standard resize")
return bg_image.resize(target_size, Image.Resampling.LANCZOS)
target_width, target_height = target_size
current_width, current_height = bg_image.size
# Calculate required scale factor
width_scale = target_width / current_width
height_scale = target_height / current_height
required_scale = max(width_scale, height_scale)
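        # Example: a 768x1024 background with a 1536x2048 target gives
        # width_scale = height_scale = 2.0, so a single 3x SwinIR pass
        # suffices before the final LANCZOS resize down to the exact target.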
logger.info(f"Target size: {target_size}, Current size: {bg_image.size}")
logger.info(f"Required scale: {required_scale:.2f}")
# SwinIR has fixed 3x upscaling
swinir_scale = 3
# Only upscale if we need more than 1x scaling
if required_scale > 1:
# Convert PIL to cv2 format (BGR)
cv_img = np.array(bg_image)
cv_img = cv2.cvtColor(cv_img, cv2.COLOR_RGB2BGR)
# Calculate how many times we need to apply SwinIR
num_passes = 1
cumulative_scale = swinir_scale
# If we need more than 3x, we might need multiple passes
while cumulative_scale < required_scale and num_passes < 2:
num_passes += 1
cumulative_scale *= swinir_scale
logger.info(f"Using {num_passes} SwinIR pass(es) for {cumulative_scale}x upscaling")
# Apply SwinIR upscaling
upscaled_img = cv_img
for i in range(num_passes):
upscaled_img = upscale_tiled_bgr(upscaled_img, SWINIR_UPSAMPLER, DEVICE)
logger.info(f"Pass {i+1}: upscaled to {upscaled_img.shape[1]}x{upscaled_img.shape[0]}")
# Convert back to PIL
upscaled_img = cv2.cvtColor(upscaled_img, cv2.COLOR_BGR2RGB)
upscaled_pil = Image.fromarray(upscaled_img)
logger.info(f"Final SwinIR output: {upscaled_pil.size}")
else:
upscaled_pil = bg_image
# Now resize to exact target size
final_image = upscaled_pil.resize(target_size, Image.Resampling.LANCZOS)
logger.info(f"Final resized to: {final_image.size}")
return final_image
except Exception as e:
logger.error(f"Upscaling failed, using standard resize: {e}")
return bg_image.resize(target_size, Image.Resampling.LANCZOS)
def resize_and_composite(bg_image, ghost_image, target_size):
"""Intelligently upscale background image and composite ghost on top"""
# Use intelligent upscaling instead of simple resize
bg_resized = intelligent_upscale_and_resize(bg_image, target_size)
# Convert to RGBA for compositing
if bg_resized.mode != 'RGBA':
bg_resized = bg_resized.convert('RGBA')
# Composite ghost image on top
final_image = Image.alpha_composite(bg_resized, ghost_image)
# Convert back to RGB
final_image = final_image.convert('RGB')
logger.info(f"Intelligently upscaled background to {target_size} and composited ghost image")
return final_image
@spaces.GPU(duration=900, keep_on_idle=True)  # 15 min window
def grounded_sam_inference(image_b64, category, gender):
"""
Enhanced inference function with mandatory background generation.
"""
global DINO_MODEL, SAM_PREDICTOR, FLUX_PIPELINE
try:
logger.info("=== STARTING GROUNDED SAM INFERENCE ===")
logger.info(f"Input category: {category}")
logger.info(f"Input gender: {gender}")
# Create auto-generated prompt based on category and gender
jewelry_location_map = {
"Rings": ("ring", "finger"),
"Bracelets": ("bracelet", "wrist"),
"Watches": ("watch", "wrist"),
"Earrings": ("earring", "ear"),
"Mixed Jewelry": ("jewelry", "body")
}
if category in jewelry_location_map:
jewelry_type, location = jewelry_location_map[category]
#bg_prompt = f"{gender} model showcasing a {jewelry_type} on their {location}; high-end studio lighting, editorial fashion photography"
bg_prompt = f"{jewelry_type} worn by a {gender} model"
else:
#bg_prompt = f"{gender} model showcasing jewelry; high-end studio lighting, editorial fashion photography"
            bg_prompt = f"jewelry worn by a {gender} model"  # generic fallback; jewelry_type is undefined on this path
logger.info(f"Auto-generated prompt: {bg_prompt}")
# Input validation
if not image_b64:
raise ValueError("Please upload an image.")
if category not in CATEGORY_CONFIG:
raise ValueError(f"Invalid category: {category}")
# Get category configuration
config = CATEGORY_CONFIG[category]
classes = config["classes"]
box_threshold = config["box_threshold"]
text_threshold = config["text_threshold"]
nms_threshold = config["nms_threshold"]
logger.info(f"Using classes: {classes}")
# Convert base64 to PIL Image and ensure 3:4 aspect ratio
logger.info("=== CONVERTING AND RESIZING IMAGE ===")
original_image = base64_to_image(image_b64)
if not original_image:
raise ValueError("Failed to decode input image.")
original_image = original_image.convert("RGB")
logger.info(f"Original image decoded: {original_image.size}")
# Convert to 3:4 aspect ratio
image_pil, padding_info = convert_to_3_4_aspect_ratio(original_image)
logger.info(f"Image converted to 3:4 ratio: {image_pil.size}")
# Initialize models if needed
if DINO_MODEL is None:
logger.info("=== LOADING GROUNDING DINO ===")
DINO_MODEL = load_grounding_dino(CONFIG_FILE, DINO_CKPT)
logger.info("GroundingDINO loaded successfully")
if SAM_PREDICTOR is None:
logger.info("=== LOADING SAM ===")
sam_model = build_sam(checkpoint=str(SAM_CKPT))
sam_model.to(DEVICE)
SAM_PREDICTOR = SamPredictor(sam_model)
logger.info("SAM loaded successfully")
# Transform image for GroundingDINO
img_tensor = transform_image(image_pil)
# Create text prompt
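        # GroundingDINO expects a single caption with candidate phrases separated by "."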
text_prompt = ". ".join(classes)
logger.info(f"Text prompt: {text_prompt}")
# ──────── GroundingDINO Detection ────────
logger.info("=== RUNNING GROUNDING DINO ===")
boxes, scores, phrases = get_grounding_output(
DINO_MODEL, img_tensor, text_prompt, box_threshold, text_threshold
)
if len(boxes) == 0:
logger.info("No detections found")
empty_mask = Image.new('L', image_pil.size, 0)
return image_to_base64(empty_mask), image_to_base64(image_pil), "", "❌ No objects detected"
logger.info(f"Initial detections: {len(boxes)}")
# Convert normalized boxes to pixel coordinates
W, H = image_pil.size
        boxes = boxes * torch.tensor([W, H, W, H])  # scale to pixel units
        boxes[:, :2] -= boxes[:, 2:] / 2  # (cx, cy) -> top-left (x1, y1)
        boxes[:, 2:] += boxes[:, :2]      # (w, h) -> bottom-right (x2, y2)
# ──────── Negative Word Filtering ────────
logger.info("=== FILTERING NEGATIVE WORDS ===")
keep_idxs = []
for i, phrase in enumerate(phrases):
phrase_lower = phrase.lower()
if any(neg in phrase_lower for neg in NEGATIVE_WORDS):
logger.info(f"Filtered out: {phrase}")
else:
keep_idxs.append(i)
if not keep_idxs:
logger.info("All detections filtered by negative words")
empty_mask = Image.new('L', image_pil.size, 0)
return image_to_base64(empty_mask), image_to_base64(image_pil), "", "❌ No valid objects after filtering"
boxes = boxes[keep_idxs]
scores = scores[keep_idxs]
phrases = [phrases[i] for i in keep_idxs]
logger.info(f"After filtering: {len(boxes)} detections")
# ──────── Non-Maximum Suppression ────────
logger.info("=== APPLYING NMS ===")
keep_nms = torchvision.ops.nms(boxes, scores, nms_threshold).tolist()
final_boxes = boxes[keep_nms]
final_phrases = [phrases[i] for i in keep_nms]
logger.info(f"After NMS: {len(final_boxes)} detections")
# ──────── SAM Segmentation ────────
logger.info("=== RUNNING SAM SEGMENTATION ===")
# Set image for SAM
np_image = np.array(image_pil)
SAM_PREDICTOR.set_image(np_image)
# Transform boxes for SAM
transformed_boxes = SAM_PREDICTOR.transform.apply_boxes_torch(
final_boxes, np_image.shape[:2]
).to(DEVICE)
# Run SAM prediction
masks, _, _ = SAM_PREDICTOR.predict_torch(
point_coords=None,
point_labels=None,
boxes=transformed_boxes,
multimask_output=False,
)
logger.info(f"Generated {masks.shape[0]} masks")
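        # With multimask_output=False, masks has shape (num_boxes, 1, H, W);
        # the singleton channel dim is squeezed out before merging below.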
# ──────── Create Results ────────
logger.info("=== CREATING RESULTS ===")
# Merge all masks
merged_mask = torch.any(masks.squeeze(1), dim=0).cpu().numpy().astype(np.uint8) * 255
mask_pil = Image.fromarray(merged_mask, mode='L')
# Create ghost image from the mask
ghost_image = create_ghost_image(image_pil, mask_pil)
# Create overlay image
overlay_pil = create_overlay_image(
image_pil, final_boxes, masks.squeeze(1), final_phrases
)
# ──────── Background Generation (Always) ────────
logger.info("=== GENERATING BACKGROUND ===")
try:
# Initialize FLUX pipeline if needed
if FLUX_PIPELINE is None:
FLUX_PIPELINE = get_flux_pipeline()
if FLUX_PIPELINE is not None:
# Process mask for background generation
processed_mask = process_mask(mask_pil)
# Composite image on white background
composited = segment_image_on_white_background(image_pil, processed_mask)
# Generate background with FLUX using auto-generated prompt
final_prompt = bg_prompt.strip() + ", realistic, HD"
flux_out = FLUX_PIPELINE(
prompt=final_prompt,
image=composited,
mask_image=processed_mask,
width=768,
height=1024,
guidance_scale=40,
num_inference_steps=40,
).images
bg_image = flux_out[0]
# Resize background to match input size and composite ghost
target_size = image_pil.size
final_composite = resize_and_composite(bg_image, ghost_image, target_size)
bg_image_b64 = image_to_base64(final_composite)
logger.info("Background generation and compositing completed successfully")
else:
logger.error("FLUX pipeline not available")
bg_image_b64 = ""
except Exception as e:
logger.error(f"Background generation failed: {e}")
logger.error(f"Traceback: {traceback.format_exc()}")
bg_image_b64 = ""
# Convert to base64
mask_b64 = image_to_base64(mask_pil)
overlay_b64 = image_to_base64(overlay_pil)
        status = f"✅ Found {len(final_boxes)} {category.lower()} objects; background generation, AI upscaling, and ghost compositing completed"
logger.info("=== INFERENCE COMPLETED SUCCESSFULLY ===")
return mask_b64, overlay_b64, bg_image_b64, status
except Exception as e:
# Comprehensive error handling
error_msg = f"❌ Error: {str(e)}"
        logger.error("=== INFERENCE FAILED ===")
logger.error(f"Error: {e}")
logger.error(f"Error type: {type(e)}")
logger.error(f"Full traceback:\n{traceback.format_exc()}")
# Return empty results
try:
if 'original_image' in locals() and original_image:
# Convert to 3:4 if we have original image
image_34, _ = convert_to_3_4_aspect_ratio(original_image)
empty_mask = Image.new('L', image_34.size, 0)
return image_to_base64(empty_mask), image_to_base64(image_34), "", error_msg
else:
return "", "", "", error_msg
except Exception as fallback_error:
logger.error(f"Fallback error handling failed: {fallback_error}")
return "", "", "", error_msg
# ───────── Enhanced API interface ─────────
backend = gr.Interface(
fn=grounded_sam_inference,
inputs=[
gr.Text(label="Image base64"),
gr.Text(label="Category"),
gr.Text(label="Gender"),
],
outputs=[
gr.Text(label="Mask base64"),
gr.Text(label="Overlay base64"),
gr.Text(label="Background Changed Image base64"),
gr.Text(label="Status"),
],
title="Enhanced Grounded SAM with Mandatory Background Generation",
description="Backend API for jewelry detection, segmentation, and automatic background generation",
allow_flagging="never",
)
if __name__ == "__main__":
backend.queue(
api_open=True,
max_size=10,
default_concurrency_limit=2
).launch(
show_api=False,
server_name="0.0.0.0",
server_port=7860,
share=False,
show_error=True,
debug=True
)
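# ───────── Example client usage (sketch) ─────────
# A minimal sketch of calling this backend from Python, assuming the Space is
# deployed and the default Interface endpoint "/predict" is reachable; the
# Space id below is a placeholder, not the real deployment.
#
#   from gradio_client import Client
#
#   client = Client("your-org/SAKS-backend")  # hypothetical Space id
#   mask_b64, overlay_b64, bg_b64, status = client.predict(
#       image_b64,   # base64-encoded input photo
#       "Rings",     # one of the CATEGORY_CONFIG keys
#       "female",    # gender string used in the background prompt
#       api_name="/predict",
#   )
#   print(status)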