import os
import numpy as np
import torch
import torchvision
import gradio as gr
import spaces
from PIL import Image
import traceback
import warnings
import cv2
import random
from contextlib import nullcontext
from utils import (
    base64_to_image, logger, convert_to_3_4_aspect_ratio, image_to_base64,
    load_grounding_dino, DEVICE, OUTPUT_DIR, CONFIG_FILE, DINO_CKPT,
    transform_image, get_grounding_output, create_ghost_image,
    create_overlay_image, process_mask, load_swinir_x3, upscale_tiled_bgr,
    segment_image_on_white_background,
)
from segment_anything import SamPredictor, build_sam
from huggingface_hub import hf_hub_download

# FLUX imports for background generation
from diffusers import FluxFillPipeline, FluxTransformer2DModel
# SAM checkpoint with fallback
try:
    SAM_CKPT = hf_hub_download("segments-arnaud/sam_vit_h", "sam_vit_h_4b8939.pth")
except Exception:
    import requests

    SAM_CKPT = OUTPUT_DIR / "sam_vit_h_4b8939.pth"
    if not SAM_CKPT.exists():
        logger.info("Downloading SAM checkpoint from Meta's CDN (~2.6 GB)")
        url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
        with requests.get(url, stream=True) as r:
            r.raise_for_status()  # don't cache an HTML error page as the checkpoint
            with open(SAM_CKPT, "wb") as f:
                for chunk in r.iter_content(chunk_size=1048576):
                    if chunk:
                        f.write(chunk)
# SwinIR checkpoint path
SWINIR_CKPT = "upscaler_model/net_g_82500.pth"
# Category configurations for jewelry detection
CATEGORY_CONFIG = {
    "Bracelets": {
        "classes": ["bracelet", "wrist band", "bangle"],
        "box_threshold": 0.4,
        "text_threshold": 0.4,
        "nms_threshold": 0.6,
    },
    "Earrings": {
        "classes": ["earring", "earrings", "stud earring"],
        "box_threshold": 0.3,
        "text_threshold": 0.3,
        "nms_threshold": 0.65,
    },
    "Watches": {
        "classes": ["watch", "wristwatch", "smartwatch"],
        "box_threshold": 0.3,
        "text_threshold": 0.3,
        "nms_threshold": 0.6,
    },
    "Rings": {
        "classes": ["wedding ring", "finger ring"],
        "box_threshold": 0.25,
        "text_threshold": 0.25,
        "nms_threshold": 0.5,
    },
    "Mixed Jewelry": {
        "classes": ["ring", "wedding ring", "bracelet", "wristwatch", "wrist band",
                    "necklace", "earring", "stud earring", "jewelry"],
        "box_threshold": 0.25,
        "text_threshold": 0.25,
        "nms_threshold": 0.5,
    },
}
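# box_threshold filters GroundingDINO detections by box confidence;
# text_threshold sets the minimum token score for a phrase to be attached to a
# box; nms_threshold is the IoU above which overlapping boxes are suppressed
# during non-maximum suppression (see grounded_sam_inference below).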
# Negative words to filter out
NEGATIVE_WORDS = ["hand", "face", "arm", "mouth", "lips", "finger", "teeth", "eye", "nails", "fingernail", "mole"]

# Global model caches
DINO_MODEL = None
SAM_PREDICTOR = None
FLUX_PIPELINE = None
# ───────── Model Initialization Functions ─────────
def get_swinir_upsampler():
    """Initialize SwinIR upsampler (without GPU decorator)."""
    global SWINIR_UPSAMPLER
    if SWINIR_UPSAMPLER is None:
        try:
            if not os.path.exists(SWINIR_CKPT):
                logger.error(f"SwinIR checkpoint not found at: {SWINIR_CKPT}")
                return None
            SWINIR_UPSAMPLER = load_swinir_x3(SWINIR_CKPT, DEVICE)
            logger.info("SwinIR upsampler loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load SwinIR upsampler: {e}")
            SWINIR_UPSAMPLER = None
    return SWINIR_UPSAMPLER

def get_flux_pipeline():
    """Initialize FLUX pipeline for background generation (without GPU decorator)."""
    global FLUX_PIPELINE
    if FLUX_PIPELINE is None:
        try:
            transformer = FluxTransformer2DModel.from_pretrained(
                "SnapwearAI/SAKS_Background_Model",
                subfolder="transformer",
                torch_dtype=torch.bfloat16,
            )
            FLUX_PIPELINE = FluxFillPipeline.from_pretrained(
                "black-forest-labs/FLUX.1-Fill-dev",
                transformer=transformer,
                torch_dtype=torch.bfloat16,
            ).to("cuda")
            logger.info("FLUX pipeline loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load FLUX pipeline: {e}")
            FLUX_PIPELINE = None
    return FLUX_PIPELINE

# Increased duration for the entire process
def intelligent_upscale_and_resize(bg_image, target_size):
    """Use SwinIR to upscale the background, then resize to the exact target size."""
    global SWINIR_UPSAMPLER
    try:
        # Initialize upsampler if needed
        if SWINIR_UPSAMPLER is None:
            SWINIR_UPSAMPLER = get_swinir_upsampler()
        if SWINIR_UPSAMPLER is None:
            logger.warning("SwinIR upsampler not available, using standard resize")
            return bg_image.resize(target_size, Image.Resampling.LANCZOS)
        target_width, target_height = target_size
        current_width, current_height = bg_image.size
        # Calculate required scale factor
        width_scale = target_width / current_width
        height_scale = target_height / current_height
        required_scale = max(width_scale, height_scale)
        logger.info(f"Target size: {target_size}, Current size: {bg_image.size}")
        logger.info(f"Required scale: {required_scale:.2f}")
        # SwinIR has fixed 3x upscaling
        swinir_scale = 3
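        # Each SwinIR pass triples both dimensions and the pass count is capped
        # at 2 below, so the model-based upscale tops out at 9x; any remaining
        # gap is covered by the final LANCZOS resize.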
        # Only upscale if we need more than 1x scaling
        if required_scale > 1:
            # Convert PIL to cv2 format (BGR)
            cv_img = np.array(bg_image)
            cv_img = cv2.cvtColor(cv_img, cv2.COLOR_RGB2BGR)
            # Calculate how many times we need to apply SwinIR
            num_passes = 1
            cumulative_scale = swinir_scale
            # If we need more than 3x, we might need multiple passes
            while cumulative_scale < required_scale and num_passes < 2:
                num_passes += 1
                cumulative_scale *= swinir_scale
            logger.info(f"Using {num_passes} SwinIR pass(es) for {cumulative_scale}x upscaling")
            # Apply SwinIR upscaling
            upscaled_img = cv_img
            for i in range(num_passes):
                upscaled_img = upscale_tiled_bgr(upscaled_img, SWINIR_UPSAMPLER, DEVICE)
                logger.info(f"Pass {i+1}: upscaled to {upscaled_img.shape[1]}x{upscaled_img.shape[0]}")
            # Convert back to PIL
            upscaled_img = cv2.cvtColor(upscaled_img, cv2.COLOR_BGR2RGB)
            upscaled_pil = Image.fromarray(upscaled_img)
            logger.info(f"Final SwinIR output: {upscaled_pil.size}")
        else:
            upscaled_pil = bg_image
        # Now resize to the exact target size
        final_image = upscaled_pil.resize(target_size, Image.Resampling.LANCZOS)
        logger.info(f"Final resized to: {final_image.size}")
        return final_image
    except Exception as e:
        logger.error(f"Upscaling failed, using standard resize: {e}")
        return bg_image.resize(target_size, Image.Resampling.LANCZOS)

def resize_and_composite(bg_image, ghost_image, target_size):
    """Intelligently upscale the background image and composite the ghost on top."""
    # Use intelligent upscaling instead of a simple resize
    bg_resized = intelligent_upscale_and_resize(bg_image, target_size)
    # Convert to RGBA for compositing
    if bg_resized.mode != "RGBA":
        bg_resized = bg_resized.convert("RGBA")
    # Composite ghost image on top
    final_image = Image.alpha_composite(bg_resized, ghost_image)
    # Convert back to RGB
    final_image = final_image.convert("RGB")
    logger.info(f"Intelligently upscaled background to {target_size} and composited ghost image")
    return final_image

@spaces.GPU(duration=900)  # 15 min window
def grounded_sam_inference(image_b64, category, gender):
    """
    Enhanced inference function with mandatory background generation.
    """
    global DINO_MODEL, SAM_PREDICTOR, FLUX_PIPELINE
    try:
        logger.info("=== STARTING GROUNDED SAM INFERENCE ===")
        logger.info(f"Input category: {category}")
        logger.info(f"Input gender: {gender}")
        # Create auto-generated prompt based on category and gender
        jewelry_location_map = {
            "Rings": ("ring", "finger"),
            "Bracelets": ("bracelet", "wrist"),
            "Watches": ("watch", "wrist"),
            "Earrings": ("earring", "ear"),
            "Mixed Jewelry": ("jewelry", "body"),
        }
        if category in jewelry_location_map:
            jewelry_type, location = jewelry_location_map[category]
            # bg_prompt = f"{gender} model showcasing a {jewelry_type} on their {location}; high-end studio lighting, editorial fashion photography"
            bg_prompt = f"{jewelry_type} worn by a {gender} model"
        else:
            # bg_prompt = f"{gender} model showcasing jewelry; high-end studio lighting, editorial fashion photography"
            bg_prompt = f"jewelry worn by a {gender} model"  # generic fallback; jewelry_type is only bound in the branch above
        logger.info(f"Auto-generated prompt: {bg_prompt}")
        # Input validation
        if not image_b64:
            raise ValueError("Please upload an image.")
        if category not in CATEGORY_CONFIG:
            raise ValueError(f"Invalid category: {category}")
        # Get category configuration
        config = CATEGORY_CONFIG[category]
        classes = config["classes"]
        box_threshold = config["box_threshold"]
        text_threshold = config["text_threshold"]
        nms_threshold = config["nms_threshold"]
        logger.info(f"Using classes: {classes}")
        # Convert base64 to PIL Image and ensure 3:4 aspect ratio
        logger.info("=== CONVERTING AND RESIZING IMAGE ===")
        original_image = base64_to_image(image_b64)
        if not original_image:
            raise ValueError("Failed to decode input image.")
        original_image = original_image.convert("RGB")
        logger.info(f"Original image decoded: {original_image.size}")
        # Convert to 3:4 aspect ratio
        image_pil, padding_info = convert_to_3_4_aspect_ratio(original_image)
        logger.info(f"Image converted to 3:4 ratio: {image_pil.size}")
        # Initialize models if needed
        if DINO_MODEL is None:
            logger.info("=== LOADING GROUNDING DINO ===")
            DINO_MODEL = load_grounding_dino(CONFIG_FILE, DINO_CKPT)
            logger.info("GroundingDINO loaded successfully")
        if SAM_PREDICTOR is None:
            logger.info("=== LOADING SAM ===")
            sam_model = build_sam(checkpoint=str(SAM_CKPT))
            sam_model.to(DEVICE)
            SAM_PREDICTOR = SamPredictor(sam_model)
            logger.info("SAM loaded successfully")
        # Transform image for GroundingDINO
        img_tensor = transform_image(image_pil)
        # Create text prompt
        text_prompt = ". ".join(classes)
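        # GroundingDINO splits the prompt on "." and matches each chunk as an
        # independent phrase, so every class name acts as its own query.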
| logger.info(f"Text prompt: {text_prompt}") | |
        # ──────── GroundingDINO Detection ────────
        logger.info("=== RUNNING GROUNDING DINO ===")
        boxes, scores, phrases = get_grounding_output(
            DINO_MODEL, img_tensor, text_prompt, box_threshold, text_threshold
        )
        if len(boxes) == 0:
            logger.info("No detections found")
            empty_mask = Image.new("L", image_pil.size, 0)
            return image_to_base64(empty_mask), image_to_base64(image_pil), "", "❌ No objects detected"
        logger.info(f"Initial detections: {len(boxes)}")
        # Convert normalized boxes to pixel coordinates
        W, H = image_pil.size
        for i in range(boxes.size(0)):
            boxes[i] = boxes[i] * torch.tensor([W, H, W, H])
            boxes[i][:2] -= boxes[i][2:] / 2  # Convert center to top-left
            boxes[i][2:] += boxes[i][:2]  # Convert width/height to bottom-right
        # ──────── Negative Word Filtering ────────
        logger.info("=== FILTERING NEGATIVE WORDS ===")
        keep_idxs = []
        for i, phrase in enumerate(phrases):
            phrase_lower = phrase.lower()
            if any(neg in phrase_lower for neg in NEGATIVE_WORDS):
                logger.info(f"Filtered out: {phrase}")
            else:
                keep_idxs.append(i)
        if not keep_idxs:
            logger.info("All detections filtered by negative words")
            empty_mask = Image.new("L", image_pil.size, 0)
            return image_to_base64(empty_mask), image_to_base64(image_pil), "", "❌ No valid objects after filtering"
        boxes = boxes[keep_idxs]
        scores = scores[keep_idxs]
        phrases = [phrases[i] for i in keep_idxs]
        logger.info(f"After filtering: {len(boxes)} detections")
        # ──────── Non-Maximum Suppression ────────
        logger.info("=== APPLYING NMS ===")
        keep_nms = torchvision.ops.nms(boxes, scores, nms_threshold).tolist()
        final_boxes = boxes[keep_nms]
        final_phrases = [phrases[i] for i in keep_nms]
        logger.info(f"After NMS: {len(final_boxes)} detections")
        # ──────── SAM Segmentation ────────
        logger.info("=== RUNNING SAM SEGMENTATION ===")
        # Set image for SAM
        np_image = np.array(image_pil)
        SAM_PREDICTOR.set_image(np_image)
        # Transform boxes for SAM
        transformed_boxes = SAM_PREDICTOR.transform.apply_boxes_torch(
            final_boxes, np_image.shape[:2]
        ).to(DEVICE)
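        # apply_boxes_torch rescales the pixel-space xyxy boxes into SAM's
        # internal frame (the predictor resizes the longest image side to 1024).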
        # Run SAM prediction
        masks, _, _ = SAM_PREDICTOR.predict_torch(
            point_coords=None,
            point_labels=None,
            boxes=transformed_boxes,
            multimask_output=False,
        )
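        # With multimask_output=False SAM returns one mask per box, so masks has
        # shape (num_boxes, 1, H, W).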
| logger.info(f"Generated {masks.shape[0]} masks") | |
| # ββββββββ Create Results ββββββββ | |
| logger.info("=== CREATING RESULTS ===") | |
| # Merge all masks | |
| merged_mask = torch.any(masks.squeeze(1), dim=0).cpu().numpy().astype(np.uint8) * 255 | |
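        # torch.any over the box dimension takes the per-pixel union (logical OR)
        # of all instance masks; * 255 maps True/False to 255/0 for 8-bit "L" mode.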
        mask_pil = Image.fromarray(merged_mask, mode="L")
        # Create ghost image from the mask
        ghost_image = create_ghost_image(image_pil, mask_pil)
        # Create overlay image
        overlay_pil = create_overlay_image(
            image_pil, final_boxes, masks.squeeze(1), final_phrases
        )
        # ──────── Background Generation (Always) ────────
        logger.info("=== GENERATING BACKGROUND ===")
        try:
            # Initialize FLUX pipeline if needed
            if FLUX_PIPELINE is None:
                FLUX_PIPELINE = get_flux_pipeline()
            if FLUX_PIPELINE is not None:
                # Process mask for background generation
                processed_mask = process_mask(mask_pil)
                # Composite image on white background
                composited = segment_image_on_white_background(image_pil, processed_mask)
                # Generate background with FLUX using the auto-generated prompt
                final_prompt = bg_prompt.strip() + ", realistic, HD"
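                # FluxFillPipeline repaints the white regions of mask_image and
                # keeps the black regions; 768x1024 matches the 3:4 aspect ratio
                # enforced on the input image.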
                flux_out = FLUX_PIPELINE(
                    prompt=final_prompt,
                    image=composited,
                    mask_image=processed_mask,
                    width=768,
                    height=1024,
                    guidance_scale=40,
                    num_inference_steps=40,
                ).images
                bg_image = flux_out[0]
                # Resize background to match input size and composite ghost
                target_size = image_pil.size
                final_composite = resize_and_composite(bg_image, ghost_image, target_size)
                bg_image_b64 = image_to_base64(final_composite)
                logger.info("Background generation and compositing completed successfully")
            else:
                logger.error("FLUX pipeline not available")
                bg_image_b64 = ""
        except Exception as e:
            logger.error(f"Background generation failed: {e}")
            logger.error(f"Traceback: {traceback.format_exc()}")
            bg_image_b64 = ""
        # Convert to base64
        mask_b64 = image_to_base64(mask_pil)
        overlay_b64 = image_to_base64(overlay_pil)
        status = f"✅ Found {len(final_boxes)} {category.lower()} object(s); AI-upscaled background generated and ghost compositing completed"
        logger.info("=== INFERENCE COMPLETED SUCCESSFULLY ===")
        return mask_b64, overlay_b64, bg_image_b64, status
    except Exception as e:
        # Comprehensive error handling
        error_msg = f"❌ Error: {str(e)}"
        logger.error("=== INFERENCE FAILED ===")
        logger.error(f"Error: {e}")
        logger.error(f"Error type: {type(e)}")
        logger.error(f"Full traceback:\n{traceback.format_exc()}")
        # Return empty results
        try:
            if "original_image" in locals() and original_image:
                # Convert to 3:4 if we have the original image
                image_34, _ = convert_to_3_4_aspect_ratio(original_image)
                empty_mask = Image.new("L", image_34.size, 0)
                return image_to_base64(empty_mask), image_to_base64(image_34), "", error_msg
            else:
                return "", "", "", error_msg
        except Exception as fallback_error:
            logger.error(f"Fallback error handling failed: {fallback_error}")
            return "", "", "", error_msg

# ───────── Enhanced API interface ─────────
backend = gr.Interface(
    fn=grounded_sam_inference,
    inputs=[
        gr.Text(label="Image base64"),
        gr.Text(label="Category"),
        gr.Text(label="Gender"),
    ],
    outputs=[
        gr.Text(label="Mask base64"),
        gr.Text(label="Overlay base64"),
        gr.Text(label="Background Changed Image base64"),
        gr.Text(label="Status"),
    ],
    title="Enhanced Grounded SAM with Mandatory Background Generation",
    description="Backend API for jewelry detection, segmentation, and automatic background generation",
    allow_flagging="never",
)
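# Example client call (a sketch; assumes gradio_client is installed and the
# Space is reachable at this hypothetical local URL):
#
#   from gradio_client import Client
#   client = Client("http://localhost:7860")
#   mask_b64, overlay_b64, bg_b64, status = client.predict(
#       image_b64, "Rings", "female", api_name="/predict"
#   )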
if __name__ == "__main__":
    backend.queue(
        api_open=True,
        max_size=10,
        default_concurrency_limit=2,
    ).launch(
        show_api=False,
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
        debug=True,
    )