File size: 4,831 Bytes
4ee9c8f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import torch
from PIL import Image
import os
import numpy as np

from transformers import AutoModel, AutoTokenizer, AutoImageProcessor

from types import MethodType
from detectron2.data import MetadataCatalog
from detectron2.utils.visualizer import ColorMode, Visualizer

from transformers import AutoModel, AutoTokenizer, AutoImageProcessor
from detectron2.data.detection_utils import read_image
from detectron2.utils.visualizer import GenericMask
import matplotlib.colors as mplc
def draw_instance_predictions_cache(self, labels, np_masks, jittering: bool = True):
    """
    Draw instance-level masks on an image from precomputed numpy masks.

    Drop-in replacement for ``Visualizer.draw_instance_predictions`` (bound via
    ``MethodType``) that takes raw binary masks and label strings instead of a
    detectron2 ``Instances`` object.

    Args:
        labels (list[str]): one display label per mask.
        np_masks: iterable of binary masks; each must match the output image
            size (``self.output.height`` x ``self.output.width``).
        jittering: if True, in color mode SEGMENTATION, randomly jitter the
            per-class colors to distinguish instances of the same class.

    Returns:
        output (VisImage): image object with visualizations.
    """
    # Kept for parity with the upstream API; only masks/labels are supplied here.
    boxes = None
    scores = None
    classes = None
    keypoints = None

    masks = [GenericMask(x, self.output.height, self.output.width) for x in np_masks]

    # BUG FIX: the original evaluated `for c in classes` in this branch even
    # though `classes` is always None above, raising TypeError whenever
    # _instance_mode == ColorMode.SEGMENTATION and thing_colors is set.
    # Guarding on `classes is not None` keeps IMAGE-mode behavior identical
    # and makes SEGMENTATION mode fall back to default colors instead of
    # crashing.
    if (
        self._instance_mode == ColorMode.SEGMENTATION
        and self.metadata.get("thing_colors")
        and classes is not None
    ):
        colors = (
            [self._jitter([x / 255 for x in self.metadata.thing_colors[c]]) for c in classes]
            if jittering
            else [
                tuple(mplc.to_rgb([x / 255 for x in self.metadata.thing_colors[c]]))
                for c in classes
            ]
        )

        alpha = 0.8
    else:
        colors = None
        alpha = 0.5

    self.overlay_instances(
        masks=masks,
        boxes=boxes,
        labels=labels,
        keypoints=keypoints,
        assigned_colors=colors,
        alpha=alpha,
    )
    return self.output


def visualize(image_path, cat_masks, out_path, tags):
    """
    Overlay the given binary masks on the image at ``image_path`` and save
    the rendered result to ``out_path``.

    Args:
        image_path (str): path of the source image to draw on.
        cat_masks: stacked binary masks, one per segment.
        out_path (str): destination file for the rendered image.
        tags (list[str] | None): per-mask labels; when None, masks are
            labeled by their index.
    """
    # Fall back to index labels when no tags were supplied.
    left_tags = tags if tags is not None else [f'{i}' for i in range(len(cat_masks))]

    # Register a throwaway metadata entry keyed by the label set so the
    # Visualizer has thing/stuff classes to reference.
    unique_tags = list(set(left_tags))
    metadata = MetadataCatalog.get("__unused_ape_" + ','.join(unique_tags))
    metadata.thing_classes = unique_tags
    metadata.stuff_classes = unique_tags

    # read_image yields BGR; the Visualizer expects RGB, hence the flip.
    input_image = read_image(image_path, format="BGR")
    visualizer = Visualizer(input_image[:, :, ::-1], metadata, instance_mode=ColorMode.IMAGE)

    # Swap in the cache-based drawing routine as a bound method.
    visualizer.draw_instance_predictions = MethodType(draw_instance_predictions_cache, visualizer)
    vis_output = visualizer.draw_instance_predictions(labels=left_tags, np_masks=cat_masks)

    Image.fromarray(vis_output.get_image()).save(out_path)

# --- Model & tokenizer setup --------------------------------------------------
path = "./work_dirs/hf_pano_vlm"
model = AutoModel.from_pretrained(
    path,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    use_flash_attn=True,
    trust_remote_code=True).eval().cuda()

tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)

# --- Input image --------------------------------------------------------------
image_path = "./FRAME02_ORI.jpg"
image = Image.open(image_path)
width, height = image.size  # PIL reports (width, height)

from projects.llava_sam2.datasets.coco_category import COCO_CATEGORIES
# Build the full COCO class prompt with str.join instead of quadratic `+=`
# concatenation; produces the identical "<p>name</p> [CLS], ..." string.
coco_category_names = ", ".join(
    f"<p>{item['name']}</p> [CLS]" for item in COCO_CATEGORIES
)
# question = f"<image>\nSegment from the class prompt: {coco_category_names}."
question = f"<image>\nSegment from the class prompt: <p>person</p> [CLS], <p>car</p> [CLS], <p>road</p> [CLS], <p>tree</p> [CLS], <p>building</p> [CLS], <p>ground</p> [CLS]."

m2f_processor = AutoImageProcessor.from_pretrained("./facebook/mask2former-swin-large-coco-panoptic", trust_remote_code=True,)

# --- Inference ----------------------------------------------------------------
chat_outputs = model.predict_forward(text=question, image=image, tokenizer=tokenizer, m2f_processor=m2f_processor)
answer = chat_outputs['prediction']
masks = chat_outputs['prediction_masks']

m2f_outputs = chat_outputs['m2f_outputs']

label_id_to_text = m2f_outputs['label_id_to_text']

# Post-process raw Mask2Former logits into a panoptic segmentation map at the
# original image resolution.
post_m2f_outputs = model.post_process_panoptic_segmentation(
    m2f_outputs['class_queries_logits'],
    m2f_outputs['masks_queries_logits'],
    target_sizes=[(height, width)],
)

print(f"user: {question}")
print(f"assistant: {answer}")

# --- Collect per-segment masks and labels, then render ------------------------
segmentation = post_m2f_outputs[0]['segmentation']
segments_info = post_m2f_outputs[0]['segments_info']
pano_masks, pano_tags = [], []
for item in segments_info:
    # Each segment id selects one region of the panoptic label map.
    mask = segmentation == item['id']
    pano_masks.append(mask.unsqueeze(0).cpu().numpy())
    pano_tags.append(label_id_to_text[item['label_id']])

pano_masks = np.concatenate(pano_masks, axis=0)

visualize(image_path, pano_masks, "./visualize_test_4.jpg", pano_tags)