# File size: 4,831 Bytes
# commit 4ee9c8f
import torch
from PIL import Image
import os
import numpy as np
from transformers import AutoModel, AutoTokenizer, AutoImageProcessor
from types import MethodType
from detectron2.data import MetadataCatalog
from detectron2.utils.visualizer import ColorMode, Visualizer
from transformers import AutoModel, AutoTokenizer, AutoImageProcessor
from detectron2.data.detection_utils import read_image
from detectron2.utils.visualizer import GenericMask
import matplotlib.colors as mplc
def draw_instance_predictions_cache(self, labels, np_masks, jittering: bool = True):
    """
    Draw instance-level masks with text labels on the image held by this Visualizer.

    Intended to be bound onto a detectron2 ``Visualizer`` instance via
    ``MethodType``, replacing ``draw_instance_predictions`` so raw masks and
    labels can be supplied directly instead of an ``Instances`` object.

    Args:
        labels (list[str]): one text label per mask.
        np_masks: iterable of binary masks, each (H, W), matching the
            Visualizer's output image size.
        jittering (bool): if True, in ColorMode.SEGMENTATION randomly jitter
            the per-class colors to distinguish instances of the same class.

    Returns:
        output (VisImage): image object with visualizations.
    """
    boxes = None
    classes = None  # class indices are unavailable when called with raw masks/labels
    keypoints = None
    masks = [GenericMask(x, self.output.height, self.output.width) for x in np_masks]
    # Class-colored drawing needs class indices; guard on `classes` so the
    # branch cannot iterate over None (the unguarded original raised TypeError
    # whenever SEGMENTATION mode plus thing_colors metadata was in effect).
    if (
        self._instance_mode == ColorMode.SEGMENTATION
        and self.metadata.get("thing_colors")
        and classes is not None
    ):
        colors = (
            [self._jitter([x / 255 for x in self.metadata.thing_colors[c]]) for c in classes]
            if jittering
            else [
                tuple(mplc.to_rgb([x / 255 for x in self.metadata.thing_colors[c]]))
                for c in classes
            ]
        )
        alpha = 0.8
    else:
        colors = None
        alpha = 0.5
    self.overlay_instances(
        masks=masks,
        boxes=boxes,
        labels=labels,
        keypoints=keypoints,
        assigned_colors=colors,
        alpha=alpha,
    )
    return self.output
def visualize(image_path, cat_masks, out_path, tags):
    """
    Overlay segmentation masks (with text labels) on an image and save the result.

    Args:
        image_path (str): path to the input image on disk.
        cat_masks: array-like of binary masks, shape (N, H, W).
        out_path (str): path to write the visualization image.
        tags (list[str] | None): one label per mask; when None, masks are
            labeled with their index.
    """
    if tags is None:
        left_tags = [f'{i}' for i in range(len(cat_masks))]
    else:
        left_tags = tags
    # dict.fromkeys dedupes while preserving first-occurrence order; the
    # original used set(), whose iteration order varies between runs and made
    # the metadata key and class lists non-deterministic.
    unique_tags = list(dict.fromkeys(left_tags))
    text_prompt = ','.join(unique_tags)
    # Register throwaway metadata so the Visualizer can resolve class names.
    metadata = MetadataCatalog.get("__unused_ape_" + text_prompt)
    metadata.thing_classes = unique_tags
    metadata.stuff_classes = unique_tags
    input_image = read_image(image_path, format="BGR")
    # Visualizer expects RGB; flip the BGR channel order.
    visualizer = Visualizer(input_image[:, :, ::-1], metadata, instance_mode=ColorMode.IMAGE)
    # Bind the mask/label-based drawing method onto this Visualizer instance.
    visualizer.draw_instance_predictions = MethodType(draw_instance_predictions_cache, visualizer)
    vis_output = visualizer.draw_instance_predictions(labels=left_tags, np_masks=cat_masks)
    output_image = Image.fromarray(vis_output.get_image())
    output_image.save(out_path)
# ---- Model setup: load the locally exported panoptic VLM checkpoint ----
path = "./work_dirs/hf_pano_vlm"
model = AutoModel.from_pretrained(
    path,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    use_flash_attn=True,
    trust_remote_code=True).eval().cuda()
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)
# ---- Input image ----
image_path = "./FRAME02_ORI.jpg"
image = Image.open(image_path)
width, height = image.size  # PIL returns (width, height)
# Project-local import; kept mid-script as in the original file.
from projects.llava_sam2.datasets.coco_category import COCO_CATEGORIES
# Assemble "<p>name</p> [CLS], ..." for every COCO category. Currently unused:
# the hard-coded six-class question below is used instead (see the commented
# alternative), but this string can be swapped in to segment all classes.
coco_category_names = ""
for item in COCO_CATEGORIES:
    class_name = item['name']
    coco_category_names += f"<p>{class_name}</p> [CLS], "
coco_category_names = coco_category_names[:-2]  # drop the trailing ", "
# question = f"<image>\nSegment from the class prompt: {coco_category_names}."
question = f"<image>\nSegment from the class prompt: <p>person</p> [CLS], <p>car</p> [CLS], <p>road</p> [CLS], <p>tree</p> [CLS], <p>building</p> [CLS], <p>ground</p> [CLS]."
# Mask2Former image processor used by the model's prediction pipeline.
m2f_processor = AutoImageProcessor.from_pretrained("./facebook/mask2former-swin-large-coco-panoptic", trust_remote_code=True,)
# ---- Inference ----
chat_outputs = model.predict_forward(text=question, image=image, tokenizer=tokenizer, m2f_processor=m2f_processor)
answer = chat_outputs['prediction']
masks = chat_outputs['prediction_masks']  # NOTE(review): not used below — presumably kept for inspection
m2f_outputs = chat_outputs['m2f_outputs']
label_id_to_text = m2f_outputs['label_id_to_text']
# Convert the raw query logits into a panoptic segmentation at the original image size.
post_m2f_outputs = model.post_process_panoptic_segmentation(
    m2f_outputs['class_queries_logits'],
    m2f_outputs['masks_queries_logits'],
    target_sizes=[(height, width)],
)
print(f"user: {question}")
print(f"assistant: {answer}")
# ---- Extract per-segment binary masks and their text labels ----
segmentation = post_m2f_outputs[0]['segmentation']  # (H, W) map of segment ids
segments_info = post_m2f_outputs[0]['segments_info']
pano_masks, pano_tags = [], []
for item in segments_info:
    mask = segmentation == item['id']  # boolean mask for this segment
    pano_masks.append(mask.unsqueeze(0).cpu().numpy())
    pano_tags.append(label_id_to_text[item['label_id']])
pano_masks = np.concatenate(pano_masks, axis=0)  # stack to (N, H, W)
visualize(image_path, pano_masks, "./visualize_test_4.jpg", pano_tags)