|
|
import torch |
|
|
from PIL import Image |
|
|
import os |
|
|
import numpy as np |
|
|
|
|
|
from transformers import AutoModel, AutoTokenizer, AutoImageProcessor |
|
|
|
|
|
from types import MethodType |
|
|
from detectron2.data import MetadataCatalog |
|
|
from detectron2.utils.visualizer import ColorMode, Visualizer |
|
|
|
|
|
from transformers import AutoModel, AutoTokenizer, AutoImageProcessor |
|
|
from detectron2.data.detection_utils import read_image |
|
|
from detectron2.utils.visualizer import GenericMask |
|
|
import matplotlib.colors as mplc |
|
|
def draw_instance_predictions_cache(self, labels, np_masks, jittering: bool = True):
    """
    Draw precomputed instance masks with text labels on this Visualizer's image.

    Meant to be bound onto a detectron2 ``Visualizer`` instance (via
    ``types.MethodType``) as a replacement for ``draw_instance_predictions``
    that consumes raw numpy masks instead of an ``Instances`` object.

    Args:
        labels (list[str]): one label string per mask.
        np_masks: iterable of binary masks, each covering the full image
            (height/width are taken from ``self.output``).
        jittering (bool): in ColorMode.SEGMENTATION, randomly jitter the
            per-class color so instances of the same class are
            distinguishable.

    Returns:
        output (VisImage): image object with the visualizations drawn.
    """
    # Only masks and labels are available in this variant; boxes, class ids
    # and keypoints are intentionally absent from the overlay.
    boxes = None
    classes = None
    keypoints = None

    masks = [GenericMask(x, self.output.height, self.output.width) for x in np_masks]

    # BUGFIX: `classes` is always None in this variant, so the original code
    # crashed (TypeError: 'NoneType' is not iterable) whenever the
    # SEGMENTATION color branch was taken. Guard on `classes` and fall back
    # to auto-assigned colors instead.
    if (
        self._instance_mode == ColorMode.SEGMENTATION
        and self.metadata.get("thing_colors")
        and classes is not None
    ):
        if jittering:
            colors = [
                self._jitter([x / 255 for x in self.metadata.thing_colors[c]])
                for c in classes
            ]
        else:
            colors = [
                tuple(mplc.to_rgb([x / 255 for x in self.metadata.thing_colors[c]]))
                for c in classes
            ]
        alpha = 0.8
    else:
        # Let overlay_instances pick colors automatically.
        colors = None
        alpha = 0.5

    self.overlay_instances(
        masks=masks,
        boxes=boxes,
        labels=labels,
        keypoints=keypoints,
        assigned_colors=colors,
        alpha=alpha,
    )
    return self.output
|
|
|
|
|
|
|
|
def visualize(image_path, cat_masks, out_path, tags):
    """
    Overlay instance masks on an image and write the visualization to disk.

    Args:
        image_path: path of the source image to read.
        cat_masks: array-like of binary masks, one per instance.
        out_path: destination path for the rendered image.
        tags: optional list of text labels, one per mask; when None, the
            mask indices are used as labels instead.
    """
    # Fall back to numeric labels when the caller supplies no tags.
    if tags is None:
        display_labels = [f'{idx}' for idx in range(len(cat_masks))]
    else:
        display_labels = tags

    # Register a throwaway metadata entry keyed by the de-duplicated labels,
    # so the Visualizer can resolve class names.
    distinct_labels = list(set(display_labels))
    metadata = MetadataCatalog.get("__unused_ape_" + ','.join(distinct_labels))
    metadata.thing_classes = distinct_labels
    metadata.stuff_classes = distinct_labels

    # read_image yields BGR; the Visualizer wants RGB, hence the channel flip.
    bgr_image = read_image(image_path, format="BGR")
    drawer = Visualizer(bgr_image[:, :, ::-1], metadata, instance_mode=ColorMode.IMAGE)

    # Swap in the mask-cache drawing method and render.
    drawer.draw_instance_predictions = MethodType(draw_instance_predictions_cache, drawer)
    rendered = drawer.draw_instance_predictions(labels=display_labels, np_masks=cat_masks)

    Image.fromarray(rendered.get_image()).save(out_path)
|
|
|
|
|
# --- Model setup -----------------------------------------------------------
# Load the panoptic VLM checkpoint (custom HF model, hence trust_remote_code)
# in bfloat16 with flash attention, and move it to the GPU for inference.
path = "./work_dirs/hf_pano_vlm"
model = AutoModel.from_pretrained(
    path,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    use_flash_attn=True,
    trust_remote_code=True).eval().cuda()

tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)

# --- Input image -----------------------------------------------------------
image_path = "./FRAME02_ORI.jpg"
image = Image.open(image_path)
width, height = image.size  # PIL reports (width, height)

# Build a "<p>name</p> [CLS]" prompt fragment for every COCO category.
# NOTE(review): this string is constructed but never used below — `question`
# hard-codes its own six classes; presumably kept for experimentation.
from projects.llava_sam2.datasets.coco_category import COCO_CATEGORIES
coco_category_names = ""
for item in COCO_CATEGORIES:
    class_name = item['name']
    coco_category_names += f"<p>{class_name}</p> [CLS], "
coco_category_names = coco_category_names[:-2]  # drop the trailing ", "

# Class-prompted segmentation request; each "<p>..</p> [CLS]" pair marks one
# class to segment.
question = f"<image>\nSegment from the class prompt: <p>person</p> [CLS], <p>car</p> [CLS], <p>road</p> [CLS], <p>tree</p> [CLS], <p>building</p> [CLS], <p>ground</p> [CLS]."

# Mask2Former image processor handed to the model's prediction path
# (loaded from a local directory).
m2f_processor = AutoImageProcessor.from_pretrained("./facebook/mask2former-swin-large-coco-panoptic", trust_remote_code=True,)

# --- Inference -------------------------------------------------------------
chat_outputs = model.predict_forward(text=question, image=image, tokenizer=tokenizer, m2f_processor=m2f_processor)
answer = chat_outputs['prediction']
masks = chat_outputs['prediction_masks']

m2f_outputs = chat_outputs['m2f_outputs']

# Maps Mask2Former label ids back to the prompt text of each class.
label_id_to_text = m2f_outputs['label_id_to_text']

# Post-process the raw query logits into a panoptic segmentation at the
# original image resolution.
post_m2f_outputs = model.post_process_panoptic_segmentation(
    m2f_outputs['class_queries_logits'],
    m2f_outputs['masks_queries_logits'],
    target_sizes=[(height, width)],
)

print(f"user: {question}")
print(f"assistant: {answer}")

# --- Visualization ---------------------------------------------------------
# `segmentation` is a per-pixel segment-id map (torch tensor, per the
# unsqueeze/cpu calls below); `segments_info` holds one dict per segment.
segmentation = post_m2f_outputs[0]['segmentation']
segments_info = post_m2f_outputs[0]['segments_info']
pano_masks, pano_tags = [], []
for item in segments_info:
    # Boolean mask of the pixels belonging to this segment id.
    mask = segmentation == item['id']
    pano_masks.append(mask.unsqueeze(0).cpu().numpy())
    pano_tags.append(label_id_to_text[item['label_id']])

# Stack the per-segment masks into a single (num_segments, H, W) array.
pano_masks = np.concatenate(pano_masks, axis=0)

visualize(image_path, pano_masks, "./visualize_test_4.jpg", pano_tags)
|
|
|
|
|
|
|
|
|