# Hugging Face Spaces demo: zero-shot object detection (OWLv2) + SAM segmentation.
| import torch | |
| import gradio as gr | |
| from transformers import Owlv2Processor, Owlv2ForObjectDetection | |
| import spaces | |
| import numpy as np | |
| from PIL import Image | |
| import io | |
| import random | |
| from transformers import SamModel, SamProcessor | |
def apply_colored_masks_on_image(image, masks):
    """Overlay every mask in *masks* on *image* as a translucent, randomly colored layer.

    Args:
        image: PIL Image or numpy RGB array (converted to PIL if needed).
        masks: tensor of masks, shape (N, ..., H, W); each masks[i] squeezes to (H, W).

    Returns:
        An RGBA PIL image with one half-transparent colored overlay per mask.
    """
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image.astype('uint8'), 'RGB')
    composite = image.convert("RGBA")
    for idx in range(masks.shape[0]):
        mask_np = masks[idx].squeeze().cpu().numpy()
        # Binary mask -> 8-bit alpha channel (255 inside the mask, 0 outside).
        alpha = Image.fromarray((mask_np * 255).astype(np.uint8), 'L')
        # Random color per mask; the fixed 128 alpha is replaced by putalpha below.
        color = tuple([random.randint(0, 255) for _ in range(3)] + [128])
        overlay = Image.new("RGBA", image.size, color)
        overlay.putalpha(alpha)
        composite = Image.alpha_composite(composite, overlay)
    return composite
# Select the compute device: prefer CUDA when a GPU is present, else CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# OWLv2: open-vocabulary (zero-shot) object detector — turns text queries into boxes.
model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble").to(device)
processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16-ensemble")
# SAM (Segment Anything): prompted with the OWLv2 boxes to produce segmentation masks.
model_sam = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
processor_sam = SamProcessor.from_pretrained("facebook/sam-vit-huge")
def query_image(img, text_queries, score_threshold=0.5):
    """Detect objects matching *text_queries* in *img*, then segment them with SAM.

    Args:
        img: RGB image as a numpy array (H, W, 3), as supplied by gr.Image.
        text_queries: comma-separated labels, e.g. "cat, dog".
        score_threshold: minimum detection confidence to keep a box.

    Returns:
        (annotated_image, labels): the SAM-overlaid image (or the unmodified
        image when nothing is detected) and a list of (box, label) pairs for
        gr.AnnotatedImage.
    """
    # Strip whitespace so "cat, dog" behaves the same as "cat,dog".
    text_queries = [q.strip() for q in text_queries.split(",")]

    # OWLv2 preprocessing pads the image to a square of the longest side,
    # so detection boxes must be rescaled against that square size.
    size = max(img.shape[:2])
    target_sizes = torch.Tensor([[size, size]])

    inputs = processor(text=text_queries, images=img, return_tensors="pt").to(device)
    with torch.no_grad():
        model_outputs = model(**inputs)

    # Post-processing runs on CPU tensors.
    model_outputs.logits = model_outputs.logits.cpu()
    model_outputs.pred_boxes = model_outputs.pred_boxes.cpu()
    results = processor.post_process_object_detection(outputs=model_outputs, target_sizes=target_sizes)
    boxes, scores, labels = results[0]["boxes"], results[0]["scores"], results[0]["labels"]

    img_pil = Image.fromarray(img.astype('uint8'), 'RGB')
    result_labels = []
    result_boxes = []
    for box, score, label in zip(boxes, scores, labels):
        if score < score_threshold:
            continue
        box = [int(i) for i in box.tolist()]
        result_labels.append((box, text_queries[label.item()]))
        result_boxes.append(box)

    # Fix: SAM cannot be prompted with an empty box list — return the
    # unmodified image when nothing clears the threshold instead of crashing.
    if not result_boxes:
        return img_pil, result_labels

    sam_image = generate_image_with_sam(np.array(img_pil), result_boxes)
    return sam_image, result_labels
def generate_image_with_sam(img, input_boxes):
    """Segment *img* with SAM using bounding-box prompts and overlay the masks.

    Args:
        img: RGB image as a numpy array (H, W, 3).
        input_boxes: list of [x0, y0, x1, y1] pixel boxes to prompt SAM with
            (must be non-empty).

    Returns:
        An RGBA PIL image with one translucent, randomly colored mask per box.
    """
    img_pil = Image.fromarray(img.astype('uint8'), 'RGB')

    # Run the heavy vision encoder once, reusing its embeddings below.
    inputs = processor_sam(img_pil, return_tensors="pt").to(device)
    image_embeddings = model_sam.get_image_embeddings(inputs["pixel_values"])

    # Re-run the processor with box prompts, then swap the raw pixels for the
    # precomputed embeddings. (Fix: removed a stray no-op
    # `inputs["input_boxes"].shape` expression statement.)
    inputs = processor_sam(img_pil, input_boxes=[input_boxes], return_tensors="pt").to(device)
    inputs.pop("pixel_values", None)
    inputs.update({"image_embeddings": image_embeddings})

    with torch.no_grad():
        outputs = model_sam(**inputs, multimask_output=False)

    # Upscale the low-resolution predicted masks back to the original image size.
    masks = processor_sam.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )
    return apply_colored_masks_on_image(img_pil, masks[0])
# NOTE(review): this variable is unused — gr.Interface below passes its own
# `description` string. Consider removing it or wiring it into the Interface.
description = """
Split anythings
"""
# Gradio UI: image input + comma-separated text queries + confidence slider.
# NOTE(review): the slider default (0.1) differs from query_image's own
# default (0.5); the slider value always wins because Gradio passes it explicitly.
demo = gr.Interface(
    fn=query_image,
    inputs=[gr.Image(), gr.Textbox(label="Query Text"), gr.Slider(0, 1, value=0.1, label="Score Threshold")],
    outputs=gr.AnnotatedImage(),
    title="Zero-Shot Object Detection SV3",
    description="This interface demonstrates object detection using zero-shot object detection and SAM for image segmentation.",
    examples=[
        ["images/dark_cell.png", "gray cells", 0.1],
        ["images/animals.png", "Rabbit,Squirrel,Parrot,Hedgehog,Turtle,Ladybug,Chick,Frog,Butterfly,Snail,Mouse", 0.35],
    ],
)
demo.launch()