Spaces:
Sleeping
Sleeping
| from PIL import Image | |
| import numpy as np | |
| from scipy import ndimage | |
| import torch | |
| import logging | |
| import base64 | |
| import io | |
| import numpy as np | |
| import gradio as gr | |
| import warnings | |
| from pathlib import Path | |
| from huggingface_hub import hf_hub_download | |
| from PIL import ImageDraw, ImageFont | |
| # Grounding DINO & Segment Anything imports | |
| import GroundingDINO.groundingdino.datasets.transforms as T | |
| from GroundingDINO.groundingdino.models import build_model | |
| from GroundingDINO.groundingdino.util.slconfig import SLConfig | |
| from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap | |
| # SwinIR imports for upscaling | |
| from basicsr.archs.swinir_arch import SwinIR | |
| from basicsr.utils import img2tensor, tensor2img | |
# Logging: INFO-level root configuration plus a module-scoped logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Every warning category is silenced for the whole process.
warnings.filterwarnings("ignore")

# ───────── Configuration ─────────
# Run on the GPU whenever torch reports one, otherwise on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Output directory, created at import time if it does not exist yet.
OUTPUT_DIR = Path("outputs")
OUTPUT_DIR.mkdir(exist_ok=True)

# Model paths: the GroundingDINO config is resolved relative to this file,
# while the checkpoint path is whatever hf_hub_download returns (this call
# runs at import time).
CONFIG_FILE = Path(__file__).parent / "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
DINO_CKPT = hf_hub_download(
    repo_id="ShilongLiu/GroundingDINO",
    filename="groundingdino_swint_ogc.pth",
)
def process_mask(image, threshold=50, invert=True):
    """
    Threshold an image into black/white, optionally swapping the two.

    :param image: Image object to process; it is converted to grayscale ("L") first.
    :param threshold: Gray levels below this value map to 0, all others to 255 (default is 50).
    :param invert: When true, every value v of the binary image is replaced by 255 - v (default is True).
    :return: The processed binary image object (not a path).
    """
    def to_binary(level):
        return 0 if level < threshold else 255

    def flip(level):
        return 255 - level

    # Grayscale first, then threshold into a mode-'1' image.
    binary_image = image.convert("L").point(to_binary, '1')
    return binary_image.point(flip) if invert else binary_image
def dots_to_points(editor_value):
    """
    Convert a brush layer of dots into a list of (x, y) float coordinates.

    Reads ``editor_value["background"]`` and ``editor_value["layers"]``. The
    first layer whose last channel (treated as alpha) contains a non-zero
    pixel is used; every connected blob of non-zero alpha yields one point at
    its centre of mass.

    Returns a tuple ``(background converted to RGB, [(x, y), ...])``.
    Raises ``gr.Error`` when there are no layers, when every layer is fully
    transparent, or when no blob is found.
    """
    background = editor_value["background"]  # noted as a PIL.Image by the original author
    layers = editor_value["layers"]
    if not layers:
        raise gr.Error("Draw at least one dot with the brush first!")

    # ── pick the *first* layer that has any opaque pixels ──
    alpha = None
    for entry in layers:
        # A layer is either an image itself or a mapping holding it under "data".
        candidate = entry if isinstance(entry, Image.Image) else entry["data"]
        channel = np.array(candidate.split()[-1])  # last band = alpha channel
        if channel.max() > 0:
            alpha = channel
            break
    if alpha is None:
        raise gr.Error("No non-empty brush layer found.")

    # ── binarise (opaque => 1) ──
    opaque = (alpha > 0).astype(np.uint8)

    # ── group contiguous blobs and take their centroids ──
    labelled, blob_count = ndimage.label(opaque)
    if blob_count == 0:
        raise gr.Error("No dots detected on the brush layer.")
    centres = ndimage.center_of_mass(opaque, labelled, range(1, blob_count + 1))  # (y, x)

    # center_of_mass yields (row, col); callers want (x, y).
    point_coords = [(float(col), float(row)) for row, col in centres]
    return background.convert("RGB"), point_coords
| # ───────── SwinIR Functions ───────── | |
def load_swinir_x3(ckpt_path: str, device: str = "cuda"):
    """SwinIR-x3 network weights → ready model (half precision if GPU).

    Builds the SwinIR architecture with the hyper-parameters below, loads the
    checkpoint strictly, moves the model to ``device`` in eval mode and, for a
    device string starting with "cuda", converts it to fp16.
    """
    arch = dict(
        upscale=3, img_size=192, window_size=8,
        depths=[6]*6, embed_dim=60, num_heads=[6]*6,
        mlp_ratio=2, upsampler="pixelshuffle",
        img_range=1.0, in_chans=3,
    )
    model = SwinIR(**arch)

    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state = torch.load(ckpt_path, map_location="cpu")
    # Weights may be nested under "params"; otherwise the dict itself is used.
    model.load_state_dict(state.get("params", state), strict=True)

    model = model.to(device).eval()
    # fp16 for speed / memory on GPU
    return model.half() if device.startswith("cuda") else model
def upscale_tiled_bgr(img_bgr: np.ndarray,
                      net: torch.nn.Module,
                      device: str,
                      tile: int = 192,
                      pad: int = 16,
                      scale: int = 3) -> np.ndarray:
    """Upscale an image tile by tile and stitch the results together.

    Each ``tile`` x ``tile`` region is run through ``net`` together with up to
    ``pad`` pixels of surrounding context; only the part of the network output
    that corresponds to the tile itself is copied into the result.

    :param img_bgr: Input image array; the first two axes are read as
        (height, width) and the output is allocated with 3 channels.
    :param net: Super-resolution network called on one patch at a time.
    :param device: Torch device string; a value starting with "cuda" enables
        fp16 input and CUDA autocast.
    :param tile: Tile edge length in input pixels.
    :param pad: Context border, in input pixels, added around each tile.
    :param scale: Upscaling factor; must match what ``net`` produces
        (default 3, the value that used to be hard-coded).
    :return: uint8 array of shape (h * scale, w * scale, 3).
    """
    # Imported locally: the module never imports nullcontext, so the non-CUDA
    # path previously failed with NameError.
    from contextlib import nullcontext

    on_cuda = device.startswith("cuda")
    autocast_ctx = torch.cuda.amp.autocast if on_cuda else nullcontext

    h, w = img_bgr.shape[:2]
    out = np.empty((h * scale, w * scale, 3), np.uint8)

    for y in range(0, h, tile):
        for x in range(0, w, tile):
            # Padded patch bounds, clamped to the image.
            i0, j0 = max(0, y - pad), max(0, x - pad)
            i1, j1 = min(h, y + tile + pad), min(w, x + tile + pad)

            patch = img2tensor(img_bgr[i0:i1, j0:j1], bgr2rgb=True, float32=True) / 255.0
            patch = patch.unsqueeze(0).to(device)
            if on_cuda:
                patch = patch.half()

            # Inference only: no_grad avoids building an autograd graph per tile.
            with torch.no_grad(), autocast_ctx():
                sr = net(patch)
            sr = tensor2img(sr, rgb2bgr=True)  # uint8

            # Destination rectangle of the un-padded tile in the output.
            top, left = y * scale, x * scale
            bottom = min(y + tile, h) * scale
            right = min(x + tile, w) * scale

            # Matching rectangle inside the upscaled patch (skips the padding).
            pt_top = (y - i0) * scale
            pt_left = (x - j0) * scale
            pt_bot = pt_top + (bottom - top)
            pt_rgt = pt_left + (right - left)

            out[top:bottom, left:right] = sr[pt_top:pt_bot, pt_left:pt_rgt]
    return out
| # ───────── Image Processing Utilities ───────── | |
def _pad_to_aspect_ratio(image, target_ratio, ratio_label):
    """Centre ``image`` on a white canvas with the given width/height ratio.

    Shared implementation of the ``convert_to_*_aspect_ratio`` functions.

    :param image: Image object exposing ``size`` as (width, height).
    :param target_ratio: Desired width / height ratio.
    :param ratio_label: Text used in the log message to name the ratio.
    :return: ``(result, (paste_x, paste_y, original_width, original_height))``.
        When the ratio is already within 0.01 of the target, the input image
        itself is returned with offsets (0, 0).
    """
    original_width, original_height = image.size
    current_ratio = original_width / original_height

    if abs(current_ratio - target_ratio) < 0.01:  # already close enough
        return image, (0, 0, original_width, original_height)

    if current_ratio > target_ratio:
        # Wider than the target: keep the width, grow the height.
        new_width = original_width
        new_height = int(original_width / target_ratio)
    else:
        # Taller than the target: keep the height, grow the width.
        new_width = int(original_height * target_ratio)
        new_height = original_height

    # White canvas with the original pasted in the centre.
    new_image = Image.new('RGB', (new_width, new_height), (255, 255, 255))
    paste_x = (new_width - original_width) // 2
    paste_y = (new_height - original_height) // 2
    new_image.paste(image, (paste_x, paste_y))

    logger.info(f"Converted image from {original_width}x{original_height} to {new_width}x{new_height} ({ratio_label} ratio)")
    return new_image, (paste_x, paste_y, original_width, original_height)


def convert_to_3_4_aspect_ratio(image):
    """Convert image to 3:4 aspect ratio without distortion or unnecessary cropping"""
    return _pad_to_aspect_ratio(image, 3 / 4, "3:4")


def convert_to_0_78_aspect_ratio(image):
    """Pad image with white to a 0.78 width/height ratio; returns (image, paste box)."""
    return _pad_to_aspect_ratio(image, 0.78, "0.78")


def convert_to_0_729_aspect_ratio(image):
    """Pad image with white to a 0.729 width/height ratio; returns (image, paste box)."""
    # The log label used to read "0.78" here, a copy-paste slip.
    return _pad_to_aspect_ratio(image, 0.729, "0.729")
| # def convert_to_864_1184(image): | |
| # original_width, original_height = image.size | |
| # target_width = 864 | |
| # target_height = 1184 | |
| # if original_width == target_width and original_height == target_height: | |
| # return image, (0, 0, original_width, original_height) | |
| # new_image = Image.new('RGB', (target_width, target_height), (255, 255, 255)) | |
| # paste_x = (target_width - original_width) // 2 | |
| # paste_y = (target_height - original_height) // 2 | |
| # new_image.paste(image, (paste_x, paste_y)) | |
| # return new_image, (paste_x, paste_y, original_width, original_height) | |
def overlay_ghost_mask(mask_img, background_img):
    """Composite ``mask_img`` over ``background_img`` using the mask's own alpha.

    Both inputs are converted to RGBA. One of them is first re-canvased so the
    two aspect ratios agree: the background when it is smaller than the mask
    in either dimension, the mask otherwise. The mask is then resized to fit
    the background and pasted with its alpha channel as the paste mask.

    :return: RGBA image with the size of the (possibly re-canvased) background.
    """
    mask_img = mask_img.convert('RGBA')
    background_img = background_img.convert('RGBA')
    bg_width, bg_height = background_img.size
    mask_width, mask_height = mask_img.size

    bg_ratio = bg_width / bg_height
    mask_ratio = mask_width / mask_height

    # NOTE(review): in both branches the new canvas is never larger than the
    # image pasted into it, so this step centre-crops and the fill colour is
    # normally not visible. Confirm cropping (not padding) is intended.
    if bg_width < mask_width or bg_height < mask_height:
        # Re-canvas the background to the mask's aspect ratio.
        if mask_ratio > bg_ratio:
            canvas_size = (bg_width, int(bg_width / mask_ratio))
        else:
            canvas_size = (int(bg_height * mask_ratio), bg_height)
        canvas = Image.new('RGBA', canvas_size, (255, 255, 255, 255))
        offset = ((canvas_size[0] - bg_width) // 2, (canvas_size[1] - bg_height) // 2)
        canvas.paste(background_img, offset)
        background_img = canvas
        bg_width, bg_height = canvas_size
    else:
        # Re-canvas the mask to the background's aspect ratio.
        if bg_ratio > mask_ratio:
            canvas_size = (mask_width, int(mask_width / bg_ratio))
        else:
            canvas_size = (int(mask_height * bg_ratio), mask_height)
        canvas = Image.new('RGBA', canvas_size, (0, 0, 0, 0))
        offset = ((canvas_size[0] - mask_width) // 2, (canvas_size[1] - mask_height) // 2)
        canvas.paste(mask_img, offset)
        mask_img = canvas
        mask_width, mask_height = canvas_size

    # Ratios after the adjustment above.
    bg_ratio = bg_width / bg_height
    mask_ratio = mask_width / mask_height

    if abs(mask_ratio - bg_ratio) < 0.01:
        # Close enough: stretch the mask over the whole background.
        fit_size = (bg_width, bg_height)
        position = (0, 0)
    else:
        # Fit the mask inside the background and centre it.
        if mask_ratio > bg_ratio:
            fit_size = (bg_width, int(bg_width / mask_ratio))
        else:
            fit_size = (int(bg_height * mask_ratio), bg_height)
        position = ((bg_width - fit_size[0]) // 2, (bg_height - fit_size[1]) // 2)

    mask_resized = mask_img.resize(fit_size, Image.Resampling.LANCZOS)
    result = background_img.copy()
    result.paste(mask_resized, position, mask_resized)
    return result
def create_ghost_image(image, mask):
    """Create a ghost/transparent version of the masked area.

    :param image: Source image; an RGBA version of it is used for the pixels.
    :param mask: Mask image, converted to mode 'L' when it is not already.
    :return: RGBA image. Colour is copied from ``image`` wherever the mask is
        above 128; alpha is the mask scaled to a maximum of 180 (about 70%
        opacity).
    """
    if mask.mode != 'L':
        mask = mask.convert('L')

    mask_array = np.array(mask)
    # convert() returns a new image, so the source is never modified.
    image_array = np.array(image.convert('RGBA'))

    # Start fully transparent, then copy colour only inside the mask.
    ghost_array = np.zeros_like(image_array)
    inside = mask_array > 128
    ghost_array[inside] = image_array[inside]

    # Alpha follows the mask for every pixel, including mask values of 128 or
    # less where no colour was copied.
    ghost_array[:, :, 3] = (mask_array / 255.0 * 180).astype(np.uint8)

    # The placeholder Image.new(...) that used to be allocated here was
    # overwritten before use and has been removed.
    ghost_image = Image.fromarray(ghost_array, 'RGBA')
    logger.info("Created ghost image from mask")
    return ghost_image
| # ───────── Helper Functions ───────── | |
def numpy_to_pil(array):
    """Convert a numpy array to a PIL image.

    uint8 arrays are passed through unchanged. Any other dtype is scaled by
    255 when its maximum is at most 1.0, then clipped to [0, 255] and cast to
    uint8. Clipping prevents out-of-range values from wrapping around in the
    cast.

    :param array: numpy array holding the image data.
    :return: Image built with ``Image.fromarray``.
    """
    if array.dtype != np.uint8:
        if array.max() <= 1.0:
            # Treated as normalised data.
            array = array * 255
        array = np.clip(array, 0, 255).astype(np.uint8)
    return Image.fromarray(array)
def base64_to_image(b64_str):
    """Convert base64 string to PIL Image.

    A leading ``data:`` URI header (everything up to the first comma) is
    stripped before decoding. Returns ``None`` and logs an error when the
    input is empty or when any step of the decoding fails.
    """
    if not b64_str:
        logger.error("Empty base64 string provided")
        return None

    try:
        # Keep only the payload of a "data:<mime>;base64,<payload>" string.
        payload = b64_str.split(',', 1)[1] if b64_str.startswith('data:') else b64_str
        logger.info(f"Decoding base64 string of length: {len(payload)}")

        raw_bytes = base64.b64decode(payload)
        image = Image.open(io.BytesIO(raw_bytes))
        logger.info(f"Successfully created PIL image: {image.size}, mode: {image.mode}")
        return image
    except Exception as e:
        # Best effort by design: callers receive None instead of an exception.
        logger.error(f"Failed to decode base64 to image: {e}")
        return None
| # def image_to_base64(image): | |
| # """Convert PIL Image to base64 string.""" | |
| # if image is None: | |
| # return "" | |
| # if image.mode != 'RGB': | |
| # image = image.convert('RGB') | |
| # buffer = io.BytesIO() | |
| # image.save(buffer, format="PNG", optimize=True) | |
| # buffer.seek(0) | |
| # return base64.b64encode(buffer.getvalue()).decode('utf-8') | |
def image_to_base64(image):
    """Encode an image as a base64 PNG string.

    Images in mode 'RGBA' or 'LA', or carrying a 'transparency' entry in
    ``image.info``, are saved as they are; every other image is converted to
    RGB first. Returns an empty string when ``image`` is None.
    """
    if image is None:
        return ""

    keeps_transparency = image.mode in ('RGBA', 'LA') or 'transparency' in image.info
    if not keeps_transparency:
        image = image.convert('RGB')

    buffer = io.BytesIO()
    image.save(buffer, format="PNG", optimize=True)
    return base64.b64encode(buffer.getvalue()).decode('utf-8')
def segment_image_on_white_background(image, mask):
    """Composite image onto white background using mask.

    The mask is inverted (v -> 255 - v) before being used as the paste mask,
    so ``image`` shows through where the original mask is dark and the white
    background remains where it is bright.
    """
    def invert(level):
        return 255 - level

    paste_mask = Image.eval(mask, invert)

    # White canvas of the same size, with the image pasted through the mask.
    canvas = Image.new("RGB", image.size, (255, 255, 255))
    canvas.paste(image, (0, 0), mask=paste_mask)
    return canvas
def create_overlay_image(image_pil, boxes, masks, phrases):
    """Create overlay image with detections and masks.

    For each (box, mask, phrase) triple the box and label are drawn, then the
    mask is blended on top as a semi-transparent layer (alpha 100) in a colour
    cycled from a fixed palette.

    :param image_pil: Base image; it is copied, not modified.
    :param boxes: Iterable of box tensors; each is passed on as ``box.int().tolist()``.
    :param masks: Iterable of mask tensors; non-zero entries are painted.
        Assumes each mask is 2-D (rows, cols) — TODO confirm against callers.
    :param phrases: Iterable of label strings, one per box.
    :return: RGB image with all overlays applied.
    """
    overlay = image_pil.copy().convert("RGBA")
    draw = ImageDraw.Draw(overlay)
    colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255), (0, 255, 255)]
    width, height = image_pil.size

    for i, (box, mask, phrase) in enumerate(zip(boxes, masks, phrases)):
        color = colors[i % len(colors)]

        # Draw bounding box and label
        draw_box_and_label(draw, box.int().tolist(), phrase, color)

        # Build the mask layer in one vectorised assignment instead of
        # drawing each pixel separately.
        selected = mask.cpu().numpy() != 0
        rows = min(height, selected.shape[0])
        cols = min(width, selected.shape[1])
        layer = np.zeros((height, width, 4), dtype=np.uint8)
        # Mask pixels outside the image bounds are ignored.
        layer[:rows, :cols][selected[:rows, :cols]] = (*color, 100)  # Semi-transparent

        overlay.alpha_composite(Image.fromarray(layer))
    return overlay.convert("RGB")
| # ───────── SAM Helper Functions ───────── | |
def transform_image(image_pil):
    """Transform image for GroundingDINO.

    Applies, in order: a resize with target size 800 and ``max_size=1333``,
    conversion to a tensor, and per-channel normalisation with the mean/std
    values below (presumably the standard ImageNet statistics). Only the
    transformed image is returned; the target slot of the pipeline is unused.
    """
    steps = [
        T.RandomResize([800], max_size=1333),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
    pipeline = T.Compose(steps)
    # The pipeline takes (image, target) pairs; no target is supplied here.
    tensor, _unused_target = pipeline(image_pil, None)
    return tensor
def load_grounding_dino(config_path, ckpt_path):
    """Load GroundingDINO model.

    Builds the model from the config file, loads the checkpoint's "model"
    weights non-strictly, and returns the model in eval mode.
    """
    cfg = SLConfig.fromfile(str(config_path))
    cfg.device = DEVICE
    model = build_model(cfg)

    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(ckpt_path, map_location="cpu")["model"]
    # strict=False: missing or unexpected keys are tolerated, not raised.
    model.load_state_dict(clean_state_dict(state_dict), strict=False)

    model.eval()
    return model
def get_grounding_output(model, image, caption, box_threshold, text_threshold):
    """Get detection outputs from GroundingDINO.

    The caption is lower-cased, stripped and given a trailing "." if missing.
    Predictions whose best sigmoid logit is not above ``box_threshold`` are
    dropped; each remaining one gets a phrase from the tokens whose logit
    exceeds ``text_threshold``.

    :return: ``(boxes, scores, phrases)`` — the kept boxes, a tensor with the
        maximum logit per kept box, and the matching list of phrases.
    """
    query = caption.lower().strip()
    if not query.endswith("."):
        query += "."

    with torch.no_grad():
        outputs = model(image[None], captions=[query])  # image[None] adds a batch axis

    # First (only) batch element, moved to CPU.
    logits = outputs["pred_logits"].cpu().sigmoid()[0]
    boxes = outputs["pred_boxes"].cpu()[0]

    # Filter by box threshold
    keep = logits.max(1)[0] > box_threshold
    logits, boxes = logits[keep], boxes[keep]

    # Get phrases and scores
    tokenizer = model.tokenizer
    tokenized = tokenizer(query)
    phrases = [
        get_phrases_from_posmap(row > text_threshold, tokenized, tokenizer)
        for row in logits
    ]
    scores = [row.max().item() for row in logits]
    return boxes, torch.tensor(scores), phrases
def draw_box_and_label(draw, box, label, color):
    """Draw bounding box and label.

    :param draw: Drawing object providing ``rectangle``, ``textbbox`` and ``text``.
    :param box: Sequence ``(x1, y1, x2, y2)`` of box corners.
    :param label: Text to draw above the box; nothing is drawn when it is empty.
    :param color: Outline colour, also used as the label background.
    """
    x1, y1, x2, y2 = box
    draw.rectangle([(x1, y1), (x2, y2)], outline=color, width=3)

    if not label:
        return

    # Draw label background and text
    try:
        font = ImageFont.load_default()
        bbox = draw.textbbox((x1, y1), label, font=font)
        text_width = bbox[2] - bbox[0]
        text_height = bbox[3] - bbox[1]
        # Background rectangle for text
        draw.rectangle([(x1, y1-text_height-4), (x1+text_width+4, y1)], fill=color)
        draw.text((x1+2, y1-text_height-2), label, fill="white", font=font)
    except Exception:
        # Fallback without font. Narrowed from a bare ``except`` so that
        # KeyboardInterrupt and SystemExit are no longer swallowed.
        draw.text((x1, y1-15), label, fill=color)