import os
from types import MethodType

import numpy as np
import torch
from PIL import Image

import matplotlib.colors as mplc
from detectron2.data import MetadataCatalog
from detectron2.data.detection_utils import read_image
from detectron2.utils.visualizer import ColorMode, GenericMask, Visualizer
from transformers import AutoImageProcessor, AutoModel, AutoTokenizer


def draw_instance_predictions_cache(self, labels, np_masks, jittering: bool = True):
    """Draw caller-supplied instance masks and labels on the visualizer's image.

    Intended to be bound onto a detectron2 ``Visualizer`` instance via
    ``MethodType`` (see ``visualize`` below) so it can reuse ``self.output``
    and ``self.overlay_instances``.

    Args:
        labels: one text label per mask in ``np_masks``.
        np_masks: iterable of binary masks (one H x W array per instance).
        jittering: in ``ColorMode.SEGMENTATION``, randomly jitter the per-class
            colors so instances of the same class remain distinguishable.

    Returns:
        VisImage: image object with the masks and labels overlaid.
    """
    boxes = None
    scores = None  # kept for parity with detectron2's stock implementation
    classes = None
    keypoints = None
    masks = [GenericMask(m, self.output.height, self.output.width) for m in np_masks]

    # NOTE(review): the original iterated ``classes`` here even though it is
    # always None, which would raise TypeError if this branch were reached.
    # Guard on ``classes`` explicitly; with ColorMode.IMAGE (used below) the
    # branch is never taken and colors stay None, as before.
    if (
        classes is not None
        and self._instance_mode == ColorMode.SEGMENTATION
        and self.metadata.get("thing_colors")
    ):
        if jittering:
            colors = [
                self._jitter([x / 255 for x in self.metadata.thing_colors[c]])
                for c in classes
            ]
        else:
            colors = [
                tuple(mplc.to_rgb([x / 255 for x in self.metadata.thing_colors[c]]))
                for c in classes
            ]
        alpha = 0.8
    else:
        colors = None
        alpha = 0.5

    self.overlay_instances(
        masks=masks,
        boxes=boxes,
        labels=labels,
        keypoints=keypoints,
        assigned_colors=colors,
        alpha=alpha,
    )
    return self.output


def visualize(image_path, cat_masks, out_path, tags):
    """Overlay ``cat_masks`` on the image at ``image_path`` and save the result.

    Args:
        image_path: path to the source image (read as BGR, displayed as RGB).
        cat_masks: array/sequence of binary masks, one per instance.
        out_path: destination path for the rendered visualization.
        tags: per-mask text labels; if None, numeric indices are used instead.
    """
    if tags is None:
        left_tags = [f"{i}" for i in range(len(cat_masks))]
    else:
        left_tags = tags

    unique_tags = list(set(left_tags))
    text_prompt = ",".join(unique_tags)

    # Register a throwaway metadata entry keyed by the tag set so repeated
    # calls with the same tags reuse the same catalog slot.
    metadata = MetadataCatalog.get("__unused_ape_" + text_prompt)
    metadata.thing_classes = unique_tags
    metadata.stuff_classes = unique_tags

    input_image = read_image(image_path, format="BGR")
    visualizer = Visualizer(input_image[:, :, ::-1], metadata, instance_mode=ColorMode.IMAGE)
    # Swap in our method so raw masks + labels can be passed directly instead
    # of a detectron2 ``Instances`` object.
    visualizer.draw_instance_predictions = MethodType(draw_instance_predictions_cache, visualizer)
    vis_output = visualizer.draw_instance_predictions(labels=left_tags, np_masks=cat_masks)

    Image.fromarray(vis_output.get_image()).save(out_path)


# --- script entry: load model, run panoptic prediction, visualize -----------

path = "./work_dirs/hf_pano_vlm"
model = AutoModel.from_pretrained(
    path,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    use_flash_attn=True,
    trust_remote_code=True,
).eval().cuda()
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)

image_path = "./FRAME02_ORI.jpg"
image = Image.open(image_path)
width, height = image.size

from projects.llava_sam2.datasets.coco_category import COCO_CATEGORIES

# Build one "<p>name</p>[CLS], " fragment per COCO category.
# NOTE(review): the source file was whitespace-mangled around these string
# literals; the "<p>...</p>" phrase tags (and the "<image>" token below) are
# reconstructed from the blank-line pattern left by tag stripping -- confirm
# against the model's expected prompt format.
coco_category_names = ""
for item in COCO_CATEGORIES:
    class_name = item["name"]
    coco_category_names += f"<p>{class_name}</p>[CLS], "
coco_category_names = coco_category_names[:-2]  # drop trailing ", "

# question = f"<image>\nSegment from the class prompt: {coco_category_names}."
question = (
    "<image>\nSegment from the class prompt: "
    "<p>person</p>[CLS],<p>car</p>[CLS],<p>road</p>[CLS],"
    "<p>tree</p>[CLS],<p>building</p>[CLS],<p>ground</p>[CLS]."
)

m2f_processor = AutoImageProcessor.from_pretrained(
    "./facebook/mask2former-swin-large-coco-panoptic",
    trust_remote_code=True,
)

chat_outputs = model.predict_forward(
    text=question,
    image=image,
    tokenizer=tokenizer,
    m2f_processor=m2f_processor,
)
answer = chat_outputs["prediction"]
masks = chat_outputs["prediction_masks"]
m2f_outputs = chat_outputs["m2f_outputs"]
label_id_to_text = m2f_outputs["label_id_to_text"]

# Turn the raw Mask2Former query logits into a full-resolution panoptic map.
post_m2f_outputs = model.post_process_panoptic_segmentation(
    m2f_outputs["class_queries_logits"],
    m2f_outputs["masks_queries_logits"],
    target_sizes=[(height, width)],
)

print(f"user: {question}")
print(f"assistant: {answer}")

segmentation = post_m2f_outputs[0]["segmentation"]
segments_info = post_m2f_outputs[0]["segments_info"]

# One binary mask + one text tag per panoptic segment.
pano_masks, pano_tags = [], []
for item in segments_info:
    mask = segmentation == item["id"]
    pano_masks.append(mask.unsqueeze(0).cpu().numpy())
    pano_tags.append(label_id_to_text[item["label_id"]])
pano_masks = np.concatenate(pano_masks, axis=0)

visualize(image_path, pano_masks, "./visualize_test_4.jpg", pano_tags)