import argparse
import os
import uuid
from io import BytesIO

import gradio as gr
import torch
import torch.nn.functional as F
import numpy as np
import cv2
import clip
import boto3
from PIL import Image
from dotenv import load_dotenv
from transformers import pipeline
from torchvision import transforms
from torchvision.ops import box_convert
from neo4j import GraphDatabase

# Qdrant imports
from qdrant_client import QdrantClient
from qdrant_client.http.models import PointStruct, VectorParams, Distance

# SAM imports
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

# Grounding DINO imports
import groundingdino
print(groundingdino.__file__)  # sanity check: which groundingdino install is in use
from groundingdino.util.inference import load_model, predict, load_image
from groundingdino.datasets.transforms import Compose, RandomResize, ToTensor, Normalize

# SEEM imports are done lazily inside load_seem_model() so the app can start
# without the SEEM package installed.

load_dotenv()  # Loads variables from .env
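# Expected .env keys (inferred from the os.getenv() calls in this file; values
# are placeholders, and the "QRANDT_HOST" spelling is kept as-is because that
# is the name the code actually reads):
#   QRANDT_HOST=...       # Qdrant endpoint URL
#   QDRANT_API=...        # Qdrant API key
#   S3_ACCESS_KEY=...     # AWS access key id
#   S3_SECRET_KEY=...     # AWS secret access key
#   S3_REGION=us-east-1   # optional, defaults to us-east-1
#   NEO4J_URI=...         # e.g. neo4j+s://<host>
#   NEO4J_USER=...
#   NEO4J_PASSWORD=...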
# Global handle for the lazily loaded SEEM model.
seem_model = None

# ------------------ Custom Gradio ImageMask Component ------------------
class ImageMask(gr.components.ImageEditor):
    """
    Thin wrapper around gr.ImageEditor configured as an interactive sketch
    canvas (the Gradio 4.x replacement for the old source="canvas",
    tool="sketch" Image component).
    """
    is_template = True

    def __init__(self, **kwargs):
        super().__init__(interactive=True, **kwargs)

    def preprocess(self, x):
        return super().preprocess(x)
def load_seem_model():
    """
    Load the SEEM model. Requires the SEEM package to be installed; its
    modules are imported lazily here so the rest of the app can run without it.
    Adjust the config path and checkpoint identifier as needed.
    """
    global seem_model
    # Lazy SEEM imports (fail at call time, not import time, if SEEM is absent).
    from modeling.BaseModel import BaseModel
    from modeling import build_model
    from utils.distributed import init_distributed
    from utils.arguments import load_opt_from_config_files

    cfg = parse_option()
    opt = load_opt_from_config_files([cfg.conf_files])
    opt = init_distributed(opt)
    pretrained_pth = os.path.join("seem_focall_v0.pt")
    seem_model = BaseModel(opt, build_model(opt)).from_pretrained(pretrained_pth)
    seem_model.eval().cuda()  # evaluation mode on GPU

    # Pre-compute text embeddings for the segmentation classes so the language
    # encoder does not hit a missing-attribute error at inference time.
    try:
        from utils.constants import COCO_PANOPTIC_CLASSES
        class_list = [name.replace('-other', '').replace('-merged', '') for name in COCO_PANOPTIC_CLASSES] + ["background"]
        with torch.no_grad():
            lang_encoder = seem_model.model.sem_seg_head.predictor.lang_encoder
            lang_encoder.get_text_embeddings(class_list, is_eval=True)
        print("Text embeddings for COCO classes loaded.")
    except Exception as e:
        print(f"Warning: failed to load class text embeddings: {e}")
    print("SEEM model loaded.")


def parse_option():
    parser = argparse.ArgumentParser('SEEM Demo', add_help=False)
    parser.add_argument('--conf_files', default="configs/focall_unicl_lang_demo.yaml", metavar="FILE", help='path to config file')
    # Tolerate unrelated CLI args (e.g. when launched by a hosting platform).
    cfg, _ = parser.parse_known_args()
    return cfg
# Load the CLIP model and preprocessing function.
device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model, preprocess = clip.load("ViT-B/32", device=device)

# Initialize an image captioning pipeline.
captioner = pipeline("image-to-text", model="nlpconnect/vit-gpt2-image-captioning")

# Define the embedding dimensionality (ViT-B/32 produces 512-d vectors).
embedding_dim = 512

print("Qdrant host: " + str(os.getenv("QRANDT_HOST")))

# Set up Qdrant client and collection.
qdrant_client = QdrantClient(
    url=os.getenv("QRANDT_HOST"),
    api_key=os.getenv("QDRANT_API"),
)
COLLECTION_NAME = "object_collection"
if not qdrant_client.collection_exists(COLLECTION_NAME):
    qdrant_client.create_collection(
        collection_name=COLLECTION_NAME,
        vectors_config=VectorParams(size=embedding_dim, distance=Distance.COSINE)
    )
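# The collection stores one point per captured view. Cosine distance works
# cleanly here because embed_image() (defined below) L2-normalizes every CLIP
# vector, so cosine similarity reduces to a plain dot product on unit vectors.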
# Initialize SAM 2 (Segment Anything Model) for segmentation.
sam_checkpoint = "./checkpoints/sam2.1_hiera_small.pt"  # update to your SAM 2 checkpoint
sam_model_cfg = "configs/sam2.1/sam2.1_hiera_s.yaml"
predictor = SAM2ImagePredictor(build_sam2(sam_model_cfg, sam_checkpoint))

# Load Grounding DINO. load_model() reads the config file itself, so the
# config path and checkpoint path are all it needs.
grounding_config_file = "./configs/GroundingDINO_SwinT_OGC.py"
grounding_checkpoint = "./checkpoints/groundingdino_swint_ogc.pth"
grounding_model = load_model(grounding_config_file, grounding_checkpoint, device=device)

# Optionally eager-load SEEM at startup (it is otherwise loaded on first use):
# load_seem_model()
# Grab AWS credentials from .env and build the S3 client.
aws_key = os.getenv("S3_ACCESS_KEY")
aws_secret = os.getenv("S3_SECRET_KEY")
aws_region = os.getenv("S3_REGION", "us-east-1")
session = boto3.Session(
    aws_access_key_id=aws_key,
    aws_secret_access_key=aws_secret,
    region_name=aws_region,
)
s3 = session.client("s3")
s3_bucket = 'object-mem'

# Neo4j connection for the house/object graph.
NEO4J_URI = os.getenv("NEO4J_URI")
NEO4J_USER = os.getenv("NEO4J_USER")
NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD")
neo4j_driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))

HOUSE_ID = 'c8c5fdea-7138-44ea-9f02-7fdcd47ff8cf'
# Shared preprocessing (currently unused: the resize call in apply_seem is commented out).
resize_transform = transforms.Compose([
    transforms.Resize(512, interpolation=transforms.InterpolationMode.BICUBIC)
])
# ------------------------------
# Helper functions
# ------------------------------
def resize_image(image, max_width=800):
    """
    Resize a numpy RGB image to at most `max_width` pixels wide, preserving
    aspect ratio. Images that are already narrow enough are returned unchanged.
    """
    if image is None:
        return None
    pil_img = Image.fromarray(image)
    width, height = pil_img.size
    if width > max_width:
        new_height = int(height * (max_width / width))
        resized_img = pil_img.resize((max_width, new_height), Image.LANCZOS)
        return np.array(resized_img)
    # Callers such as apply_sam() feed this result straight into
    # predictor.set_image, so always return an ndarray (never gr.skip()).
    return image
def generate_description_vllm(pil_image):
    """
    Generate a default caption for the image using the captioning model.
    """
    output = captioner(pil_image)
    return output[0]['generated_text']
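# Usage sketch (assumes a local "sample.jpg" exists; the path is illustrative):
#   caption = generate_description_vllm(Image.open("sample.jpg").convert("RGB"))
#   print(caption)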
# ---------------- apply_seem ----------------
def apply_seem(editor_output,
               background_mode: str = "remove",
               crop_result: bool = True,
               debug_montage: bool = False) -> np.ndarray:
    """
    1) Extract the user's sketch from the ImageEditor layers,
    2) run a single spatial-only SEEM inference,
    3) upsample and threshold the chosen mask,
    4) composite (remove or blur the background), and
    5) optionally crop to the foreground bounding box.

    With debug_montage=True, returns a horizontal montage of up to five raw
    mask proposals instead of the composited image.
    """
    if seem_model is None:
        load_seem_model()

    # --- 1) pull RGB + sketch mask ---
    if isinstance(editor_output, dict):
        bg = editor_output.get('background')
        if bg is None:
            return None
        image = bg[..., :3]
        stroke_mask = np.zeros(image.shape[:2], dtype=np.uint8)
        for layer in editor_output.get('layers', []):
            stroke_mask |= (layer[..., 3] > 0).astype(np.uint8)
    else:
        arr = editor_output
        if arr.shape[2] == 4:
            image = arr[..., :3]
            stroke_mask = (arr[..., 3] > 0).astype(np.uint8)
        else:
            image = arr
            stroke_mask = np.zeros(arr.shape[:2], dtype=np.uint8)

    # If there is no sketch, bail out.
    if stroke_mask.sum() == 0:
        return image

    # --- 2) resize & to-tensor ---
    pil = Image.fromarray(image)
    pil_r = pil  # resize_transform(pil) disabled; SEEM runs at the input resolution
    img_np = np.asarray(pil_r)
    h, w = img_np.shape[:2]

    # Dilate the stroke so it is actually "seen" by SEEM.
    stroke_small = cv2.resize(stroke_mask, (w, h), interpolation=cv2.INTER_NEAREST)
    kernel = np.ones((15, 15), dtype=np.uint8)
    stroke_small = cv2.dilate(stroke_small, kernel, iterations=1)

    img_t = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).float() / 255.0
    img_t = img_t.cuda()
    stroke_t = torch.from_numpy(stroke_small[None, None]).bool().cuda()

    # --- 3) single-pass spatial inference ---
    ts = seem_model.model.task_switch
    ts['spatial'] = True
    ts['visual'] = False
    ts['grounding'] = False
    ts['audio'] = False

    data = {
        'image': img_t[0],      # [3,H,W]
        'height': h,
        'width': w,
        'stroke': stroke_t,     # [1,1,H,W]
        'spatial_query_pos_mask': [stroke_t[0]]
    }
    with torch.no_grad():
        results, _, _ = seem_model.model.evaluate_demo([data])

    # --- 4) pick & upsample mask ---
    v_emb = results['pred_maskembs']    # [1,M,D]
    s_emb = results['pred_pspatials']   # [1,1,D] (N=1 for a single stroke mask)
    pred_ms = results['pred_masks']     # [1,M,H',W']
    # Dot-product similarity between each mask embedding and the stroke's
    # spatial query embedding; the best-matching proposal wins.
    sim = v_emb @ s_emb.transpose(1, 2)  # [1,M,1]
    idx = sim[0, :, 0].argmax().item()
    mask_lo = torch.sigmoid(pred_ms[0, idx])  # logits -> [0,1]
    mask_up = F.interpolate(mask_lo[None, None], (h, w), mode='bilinear')[0, 0].cpu().numpy() > 0.5

    if debug_montage:
        # Visualize up to five raw proposals side by side instead of compositing.
        masks = []
        for i in range(min(pred_ms.shape[1], 5)):
            m = pred_ms[0, i]
            up = F.interpolate(m[None, None], (h, w), mode='bilinear')[0, 0].cpu().numpy() > 0
            masks.append(Image.fromarray((up * 255).astype(np.uint8)))
        total_width = sum(im.width for im in masks)
        max_height = max(im.height for im in masks)
        montage = Image.new('L', (total_width, max_height))
        x_offset = 0
        for im in masks:
            montage.paste(im, (x_offset, 0))
            x_offset += im.width
        return montage

    # --- 5) composite & crop back to original ---
    orig_h, orig_w = image.shape[:2]
    mask_full = cv2.resize(mask_up.astype(np.uint8), (orig_w, orig_h),
                           interpolation=cv2.INTER_NEAREST).astype(bool)
    mask_3c = np.stack([mask_full] * 3, axis=-1).astype(np.float32)
    if background_mode == 'extreme_blur':
        blur = cv2.GaussianBlur(image, (101, 101), 0)
        out = image * mask_3c + blur * (1 - mask_3c)
    else:
        bg = np.full_like(image, 255)
        out = image * mask_3c + bg * (1 - mask_3c)
    out = out.astype(np.uint8)
    if crop_result:
        ys, xs = np.where(mask_full)
        if ys.size:
            out = out[ys.min():ys.max() + 1, xs.min():xs.max() + 1]
    return out
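# Usage sketch (assumes `canvas` is an HxWx4 RGBA array whose alpha channel
# holds the user's stroke):
#   cut_out = apply_seem(canvas, background_mode="remove", crop_result=True)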
def apply_sam(editor_output, background_mode="remove", crop_result=True) -> np.ndarray:
    """
    Uses SAM to generate a segmentation mask from the sketch (stroke_mask),
    then either removes or extremely blurs the background. Optionally crops
    to the foreground bounding box.

    Parameters:
        editor_output: either a dict with 'background' and 'layers', or an HxWx3/4 array
        background_mode: "remove" or "extreme_blur"
        crop_result: whether to crop the output to the foreground bbox

    Returns:
        HxWx3 uint8 array
    """
    # --- 1) pull RGB + sketch mask ---
    if isinstance(editor_output, dict):
        bg = editor_output.get('background')
        if bg is None:
            return None
        image = bg[..., :3]
        stroke_mask = np.zeros(image.shape[:2], dtype=np.uint8)
        for layer in editor_output.get('layers', []):
            stroke_mask |= (layer[..., 3] > 0).astype(np.uint8)
    else:
        arr = editor_output
        if arr.shape[2] == 4:
            image = arr[..., :3]
            stroke_mask = (arr[..., 3] > 0).astype(np.uint8)
        else:
            image = arr
            stroke_mask = np.zeros(arr.shape[:2], dtype=np.uint8)

    # If there is no sketch, just return the original.
    if stroke_mask.sum() == 0:
        return image

    # Preprocess & set image.
    image = resize_image(image)
    predictor.set_image(image)

    # Rescale the stroke mask to the (possibly resized) image.
    h, w = image.shape[:2]
    stroke_small = cv2.resize(stroke_mask, (w, h), interpolation=cv2.INTER_NEAREST)
    ys, xs = np.nonzero(stroke_small)
    if len(xs) == 0:
        raise ValueError("stroke_mask provided but contains no nonzero pixels")

    # Sample up to N evenly spaced positive point prompts from the stroke.
    coords = np.stack([xs, ys], axis=1)
    N = min(10, len(coords))
    idxs = np.linspace(0, len(coords) - 1, num=N, dtype=int)
    point_coords = coords[idxs]
    point_labels = np.ones(N, dtype=int)

    # Predict using the stroke-derived point prompts.
    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
        masks, scores, logits = predictor.predict(
            point_coords=point_coords,
            point_labels=point_labels,
            box=None,
            multimask_output=False
        )

    # Pick the highest-score mask and binarize it.
    best_idx = int(np.argmax(scores))
    mask = masks[best_idx] > 0.5
    mask_3c = np.repeat(mask[:, :, None], 3, axis=2).astype(np.float32)

    # Composite.
    if background_mode == "extreme_blur":
        blurred = cv2.GaussianBlur(image, (101, 101), 0)
        output = image.astype(np.float32) * mask_3c + blurred * (1 - mask_3c)
    else:  # "remove"
        white = np.full_like(image, 255, dtype=np.uint8).astype(np.float32)
        output = image.astype(np.float32) * mask_3c + white * (1 - mask_3c)
    output = output.astype(np.uint8)

    # Optional crop.
    if crop_result:
        ys, xs = np.where(mask)
        if xs.size and ys.size:
            output = output[ys.min():ys.max() + 1, xs.min():xs.max() + 1]
    return output
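# Usage sketch with the ImageEditor dict format documented above (the array
# names and shapes are illustrative):
#   out = apply_sam({"background": rgba_bg, "layers": [stroke_rgba]},
#                   background_mode="extreme_blur")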
def apply_grounded_sam(editor_output, prompt: str, background_mode: str = "remove",
                       box_threshold=0.3, text_threshold=0.25, crop_result=True) -> np.ndarray:
    # 1) pull RGB + sketch mask
    if isinstance(editor_output, dict):
        bg = editor_output.get('background')
        if bg is None:
            return None
        image = bg[..., :3]
        stroke_mask = np.zeros(image.shape[:2], dtype=np.uint8)
        for layer in editor_output.get('layers', []):
            stroke_mask |= (layer[..., 3] > 0).astype(np.uint8)
    else:
        arr = editor_output
        if arr.shape[2] == 4:
            image = arr[..., :3]
            stroke_mask = (arr[..., 3] > 0).astype(np.uint8)
        else:
            image = arr
            stroke_mask = np.zeros(arr.shape[:2], dtype=np.uint8)

    pil = Image.fromarray(image)
    h, w = pil.height, pil.width

    # 2) Grounding DINO preprocessing (the normalization its repo uses).
    transform = Compose([
        RandomResize([800], max_size=1333),
        ToTensor(),
        Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    img_t, _ = transform(pil, None)  # tensor [C,H,W]
    img_t = img_t.to(device)

    # 3) run DINO's predict API: it tokenizes the caption, forwards, and post-processes.
    boxes, scores, phrases = predict(
        model=grounding_model,
        image=img_t,
        caption=prompt,
        box_threshold=box_threshold,
        text_threshold=text_threshold,
        device=device
    )
    if boxes.numel() == 0:
        return image  # no detections -> return the original

    # 4) convert normalized cxcywh -> absolute xyxy pixel coordinates
    #    (boxes is a [N,4] tensor with values in [0,1]).
    boxes_abs = boxes * torch.tensor([w, h, w, h], device=boxes.device)
    xyxy = box_convert(boxes=boxes_abs, in_fmt="cxcywh", out_fmt="xyxy")
    sam_boxes = xyxy.cpu().numpy()  # [N,4] in pixel coords

    # Optional point prompts from the sketch, tiled across all detected boxes.
    point_coords = None
    point_labels = None
    if stroke_mask.sum() > 0:
        ys, xs = np.nonzero(stroke_mask)
        coords = np.stack([xs, ys], axis=1)
        # Sample up to N evenly spaced stroke points.
        N = min(10, len(coords))
        idxs = np.linspace(0, len(coords) - 1, num=N, dtype=int)
        point_coords = coords[idxs][None, ...]            # (1, P, 2)
        point_labels = np.ones(N, dtype=int)[None, ...]   # (1, P)
        # Tile to (B, P, 2) and (B, P), one prompt set per detected box.
        box_count = boxes.shape[0]
        point_coords = np.tile(point_coords, (box_count, 1, 1))
        point_labels = np.tile(point_labels, (box_count, 1))

    # 5) feed the boxes (and optional points) into SAM 2.
    predictor.set_image(image)
    masks, scores_sam, _ = predictor.predict(
        point_coords=point_coords,
        point_labels=point_labels,
        box=sam_boxes,
        multimask_output=False
    )

    # 6) pick the best SAM proposal, composite & crop.
    best = int(np.argmax(scores_sam))
    mask = masks[best] > 0.5  # expected shape (H, W)
    if mask.ndim == 3 and mask.shape[0] == 1:
        mask = mask[0]  # squeeze a leading batch dim if present

    # Expand to a 3-channel float mask of shape (H, W, 3).
    mask_3c = np.repeat(mask[..., None], 3, axis=2).astype(np.float32)
    img_f = image.astype(np.float32)
    one_c = 1.0 - mask_3c

    if background_mode == "extreme_blur":
        blurred = cv2.GaussianBlur(image, (101, 101), 0).astype(np.float32)
        output_f = img_f * mask_3c + blurred * one_c
    elif background_mode == "highlight":
        alpha = 0.5
        overlay_color = np.array([255, 0, 0], dtype=np.float32)  # pure red
        output_f = img_f.copy()
        # img_f[mask] is (N,3); blend each foreground pixel with red.
        output_f[mask] = (1 - alpha) * img_f[mask] + alpha * overlay_color
    else:  # "remove"
        white = np.full_like(img_f, 255, dtype=np.float32)
        output_f = img_f * mask_3c + white * one_c

    output = output_f.astype(np.uint8)
    if crop_result:
        ys, xs = np.where(mask)
        if xs.size and ys.size:
            output = output[ys.min():ys.max() + 1, xs.min():xs.max() + 1]
    return output
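# Usage sketch: text-driven segmentation, no sketch required (the prompt
# wording is illustrative):
#   out = apply_grounded_sam(np.array(pil_img), "red apple",
#                            background_mode="highlight")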
def update_preview(image, background_mode, click_points):
    """
    Returns a preview image. If background_mode is not "None", processes the
    image with SEEM; otherwise passes it through unchanged.
    """
    if image is None:
        return None
    if background_mode != "None":
        mode = background_mode.lower().replace(" ", "_")
        processed_image = apply_seem(image, background_mode=mode)
    else:
        processed_image = image
    return processed_image
def update_caption(image, background_mode, click_points):
    """
    Updates the description textbox by generating a caption from the processed image.
    """
    if image is None:
        return gr.update(value="")
    processed_image = image
    pil_image = Image.fromarray(processed_image)
    caption = generate_description_vllm(pil_image)
    return gr.update(value=caption)
def add_item(image, description, object_id, background_mode, click_points):
    """
    Memorizes an item from the (already preview-processed) image:
      - computes its CLIP embedding,
      - uploads the full-resolution view to S3,
      - stores the vector + payload in Qdrant, and
      - records the object in the Neo4j house graph.
    """
    pil_image = Image.fromarray(image)

    # CLIP embedding.
    image_features = embed_image(pil_image)

    # Generate ids.
    if not object_id or object_id.strip() == "":
        object_id = str(uuid.uuid4())
    view_id = str(uuid.uuid4())

    # Upload the original full-res image to S3.
    key = f"object_collection/{object_id}/{view_id}.png"
    image_url = upload_to_s3(pil_image, s3_bucket, key)

    store_in_qdrant(view_id, vector=image_features.tolist(), object_id=object_id,
                    house_id=HOUSE_ID, image_url=image_url, description=description)
    store_in_neo4j(object_id, HOUSE_ID, description, object_id)
    return f"Item added under object ID: {object_id}\nDescription: {description}"
def query_item(query_image, background_mode, click_points, k=5):
    """
    Queries with the (already preview-processed) image:
      - computes its CLIP embedding,
      - retrieves the k nearest views from Qdrant, and
      - aggregates per-object scores into a ranked result list.
    """
    pil_query = Image.fromarray(query_image)
    query_features = embed_image(pil_query)
    search_results = qdrant_client.search(
        collection_name=COLLECTION_NAME,
        query_vector=query_features.tolist(),
        limit=k
    )

    # Aggregate per object: keep the best view score, collect descriptions.
    object_scores = {}
    object_views = {}
    for result in search_results:
        obj_id = result.payload.get("object_id")
        score = result.score
        if obj_id in object_scores:
            object_scores[obj_id] = max(object_scores[obj_id], score)
            object_views[obj_id].append(result.payload.get("description"))
        else:
            object_scores[obj_id] = score
            object_views[obj_id] = [result.payload.get("description")]

    # Softmax over the aggregated scores to get pseudo-probabilities.
    all_scores = np.array(list(object_scores.values()))
    exp_scores = np.exp(all_scores)
    probabilities = exp_scores / np.sum(exp_scores) if np.sum(exp_scores) > 0 else np.zeros_like(exp_scores)

    results = []
    for i, (obj_id, score) in enumerate(object_scores.items()):
        results.append({
            "object_id": obj_id,
            "aggregated_similarity": float(score),
            "probability": float(probabilities[i]),
            "descriptions": object_views[obj_id]
        })
    return results
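# The softmax above turns the per-object max similarities s_i into
# p_i = exp(s_i) / sum_j exp(s_j). For example (illustrative numbers),
# scores [0.9, 0.7] yield probabilities of roughly [0.55, 0.45].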
# Global list of click points collected from image select events.
click_points_global = []


def update_click_points_str(event: gr.SelectData):
    """
    Callback to update click points. Receives the image select event, appends
    the (x, y) coordinate from event.index to the global list, and returns the
    updated state plus a formatted string.
    """
    global click_points_global
    if event is None:
        return click_points_global, ""
    x = event.index[0]
    y = event.index[1]
    if x is not None and y is not None:
        click_points_global.append([x, y])
    points_str = ";".join([f"{pt[0]},{pt[1]}" for pt in click_points_global])
    return click_points_global, points_str


def clear_click_points():
    """
    Clears the global list of click points.
    """
    global click_points_global
    click_points_global = []
    return click_points_global, ""
def embed_image(pil_image: Image.Image):
    image = preprocess(pil_image).unsqueeze(0).to(device)
    with torch.no_grad():
        embedding = clip_model.encode_image(image)
    image_features = embedding[0].cpu().numpy()
    # L2-normalize so cosine similarity reduces to a dot product.
    norm = np.linalg.norm(image_features)
    if norm > 0:
        image_features = image_features / norm
    return image_features
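# Because embed_image() returns unit vectors, the cosine similarity of two
# images is just a dot product (sketch; img_a and img_b are PIL images):
#   sim = float(embed_image(img_a) @ embed_image(img_b))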
def upload_to_s3(pil_image, bucket: str, key: str) -> str:
    """
    Save a PIL image to S3 under `key` and return its URL.
    """
    # 1) write into an in-memory buffer
    buf = BytesIO()
    pil_image.save(buf, format="PNG")
    buf.seek(0)
    # 2) upload
    s3.upload_fileobj(buf, bucket, key, ExtraArgs={"ContentType": "image/png"})
    # 3) build the URL from the region the client was configured with
    return f"https://{bucket}.s3.{aws_region}.amazonaws.com/{key}"
def store_in_qdrant(view_id, vector, object_id, house_id, image_url: str, description: str = None):
    # The description is stored in the payload so query_item() can surface it.
    payload = {"object_id": object_id, "image_url": image_url,
               "house_id": house_id, "description": description}
    point = PointStruct(id=view_id, vector=vector, payload=payload)
    qdrant_client.upsert(collection_name=COLLECTION_NAME, points=[point])
    return view_id


def store_in_neo4j(object_id, house_id, description, qdrant_object_id):
    with neo4j_driver.session() as session:
        session.run("""
            MERGE (h:House {house_id: $house_id})
            MERGE (o:Object {object_id: $object_id})
            SET o.description = $description,
                o.qdrant_object_id = $qdrant_object_id
            MERGE (h)-[:CONTAINS]->(o)
        """, {
            "object_id": object_id,
            "house_id": house_id,
            "description": description,
            "qdrant_object_id": qdrant_object_id
        })
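# Sketch of a Cypher query over the graph created above, listing every object
# recorded for a house:
#   MATCH (h:House {house_id: $house_id})-[:CONTAINS]->(o:Object)
#   RETURN o.object_id, o.description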
# ------------------------------
# Gradio Interface
# ------------------------------
# Preview function shared by both tabs.
def preview_fn(editor_output, mode):
    # If there is no input yet, skip the preview.
    if editor_output is None or (isinstance(editor_output, dict) and 'background' not in editor_output):
        return None
    return apply_sam(editor_output, mode)
with gr.Blocks() as demo:
    with gr.Tab("Add Item"):
        image_input = gr.ImageEditor(label="Upload & Sketch", type="numpy")
        seg_prompt_input = gr.Textbox(label="Segmentation Prompt", placeholder="e.g. 'red apple'")
        description_input = gr.Textbox(label="Description", lines=3)
        object_id_input = gr.Textbox(label="Object ID (optional)")
        background_mode = gr.Radio(choices=["remove", "extreme_blur"], value="remove", label="Background Mode")
        preview_button = gr.Button("Preview")
        preview_output = gr.Image(label="Preview Processed Image", type="numpy")
        submit_button = gr.Button("Submit")
        output_text = gr.Textbox(label="Result")

        # Preview is triggered manually: with a text prompt we use Grounded SAM,
        # otherwise the sketch-driven SAM path.
        preview_button.click(
            fn=lambda img, mode, prompt: (
                apply_grounded_sam(img, prompt, background_mode=mode)
                if prompt else
                apply_sam(img, mode)
            ),
            inputs=[image_input, background_mode, seg_prompt_input],
            outputs=[preview_output]
        )
        submit_button.click(fn=add_item,
                            inputs=[preview_output, description_input, object_id_input, background_mode, image_input],
                            outputs=[output_text])

    with gr.Tab("Query Item"):
        query_input = gr.ImageEditor(label="Query & Sketch", type="numpy")
        query_prompt = gr.Textbox(label="Segmentation Prompt", placeholder="optional text-based mask")
        query_mode = gr.Radio(choices=["remove", "extreme_blur"], value="remove", label="Background Mode")
        query_preview = gr.Image(label="Query Preview", type="numpy")
        k_slider = gr.Slider(1, 10, value=1, step=1, label="Results k")
        query_button = gr.Button("Search")
        query_output = gr.JSON(label="Query Results")

        # Preview is triggered automatically on upload.
        query_input.upload(
            fn=lambda img, mode, prompt: (
                apply_grounded_sam(img, prompt, background_mode=mode)
                if prompt else
                apply_sam(img, mode)
            ),
            inputs=[query_input, query_mode, query_prompt],
            outputs=[query_preview]
        )
        # Manual preview refresh.
        query_preview_button = gr.Button("Refresh Preview")
        query_preview_button.click(fn=preview_fn,
                                   inputs=[query_input, query_mode],
                                   outputs=[query_preview])
        query_button.click(fn=query_item,
                           inputs=[query_preview, query_mode, query_input, k_slider],
                           outputs=[query_output])

demo.launch()