# PanoVLM_dev0 / chat_with_sa2va.py
# Source: zhouyik's Hugging Face repository ("Upload folder using
# huggingface_hub", commit 4ee9c8f, verified).
import torch
from PIL import Image
import os
import numpy as np
from transformers import AutoModel, AutoTokenizer, AutoImageProcessor
from types import MethodType
from detectron2.data import MetadataCatalog
from detectron2.utils.visualizer import ColorMode, Visualizer
from transformers import AutoModel, AutoTokenizer, AutoImageProcessor
from detectron2.data.detection_utils import read_image
from detectron2.utils.visualizer import GenericMask
import matplotlib.colors as mplc
def draw_instance_predictions_cache(self, labels, np_masks, jittering: bool = True):
    """
    Draw pre-computed instance masks with caller-supplied labels on the image
    held by a detectron2 ``Visualizer`` (this function is bound onto a
    ``Visualizer`` instance via ``MethodType`` in ``visualize``).

    Args:
        labels (list[str]): one text label per mask in ``np_masks``.
        np_masks: iterable of binary masks — assumed (N, H, W) matching the
            visualizer's image size (TODO confirm against caller).
        jittering: if True, in color mode SEGMENTATION, randomly jitter the
            colors per class to distinguish instances from the same class.

    Returns:
        output (VisImage): image object with visualizations.
    """
    boxes = None
    scores = None
    classes = None  # no class ids in this cached-prediction path
    keypoints = None
    masks = [GenericMask(x, self.output.height, self.output.width) for x in np_masks]
    # BUGFIX: the original iterated `for c in classes` here while `classes`
    # is always None, so the SEGMENTATION branch raised TypeError whenever it
    # was reached. Guard on `classes is not None` so that configuration falls
    # back to auto-assigned colors instead of crashing.
    if (
        self._instance_mode == ColorMode.SEGMENTATION
        and self.metadata.get("thing_colors")
        and classes is not None
    ):
        colors = (
            [self._jitter([x / 255 for x in self.metadata.thing_colors[c]]) for c in classes]
            if jittering
            else [
                tuple(mplc.to_rgb([x / 255 for x in self.metadata.thing_colors[c]]))
                for c in classes
            ]
        )
        alpha = 0.8
    else:
        colors = None  # let overlay_instances pick colors
        alpha = 0.5
    self.overlay_instances(
        masks=masks,
        boxes=boxes,
        labels=labels,
        keypoints=keypoints,
        assigned_colors=colors,
        alpha=alpha,
    )
    return self.output
def visualize(image_path, cat_masks, out_path, tags):
    """
    Render binary masks with text tags over an image and save the result.

    Args:
        image_path (str): path of the image to draw on.
        cat_masks: stacked binary masks, one per instance — presumably
            (N, H, W) matching the image size (TODO confirm with caller).
        out_path (str): destination path for the rendered visualization.
        tags (list[str] | None): one label per mask; when None, masks are
            labeled with their indices ("0", "1", ...).
    """
    if tags is None:
        left_tags = [f'{i}' for i in range(len(cat_masks))]
    else:
        left_tags = tags
    # FIX: dict.fromkeys dedups while preserving first-seen order. The
    # original list(set(...)) gave a nondeterministic order, which made both
    # the class list and the MetadataCatalog key unstable across runs.
    unique_tags = list(dict.fromkeys(left_tags))
    text_prompt = ','.join(unique_tags)
    # Fresh (per tag-set) metadata entry so thing/stuff classes can be set.
    metadata = MetadataCatalog.get("__unused_ape_" + text_prompt)
    metadata.thing_classes = unique_tags
    metadata.stuff_classes = unique_tags
    result_masks = cat_masks
    # read_image returns BGR; Visualizer expects RGB, hence the channel flip.
    input_image = read_image(image_path, format="BGR")
    visualizer = Visualizer(input_image[:, :, ::-1], metadata, instance_mode=ColorMode.IMAGE)
    # Bind the cached-prediction drawing routine onto this instance.
    visualizer.draw_instance_predictions = MethodType(draw_instance_predictions_cache, visualizer)
    vis_output = visualizer.draw_instance_predictions(labels=left_tags, np_masks=result_masks)
    output_image = Image.fromarray(vis_output.get_image())
    output_image.save(out_path)
# ---- Model & tokenizer --------------------------------------------------------
path = "./work_dirs/hf_pano_vlm"
model = AutoModel.from_pretrained(
    path,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    use_flash_attn=True,
    trust_remote_code=True).eval().cuda()
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)

# ---- Input image --------------------------------------------------------------
image_path = "./FRAME02_ORI.jpg"
image = Image.open(image_path)
width, height = image.size

# Build a "<p>name</p> [CLS], ..." class prompt covering every COCO category.
# (join replaces the original quadratic `+=` loop + trailing-slice; the
# resulting string is byte-identical.)
from projects.llava_sam2.datasets.coco_category import COCO_CATEGORIES
coco_category_names = ", ".join(f"<p>{item['name']}</p> [CLS]" for item in COCO_CATEGORIES)
# question = f"<image>\nSegment from the class prompt: {coco_category_names}."
question = f"<image>\nSegment from the class prompt: <p>person</p> [CLS], <p>car</p> [CLS], <p>road</p> [CLS], <p>tree</p> [CLS], <p>building</p> [CLS], <p>ground</p> [CLS]."

# Mask2Former image processor used by the panoptic segmentation head.
m2f_processor = AutoImageProcessor.from_pretrained("./facebook/mask2former-swin-large-coco-panoptic", trust_remote_code=True,)
chat_outputs = model.predict_forward(text=question, image=image, tokenizer=tokenizer, m2f_processor=m2f_processor)
answer = chat_outputs['prediction']
masks = chat_outputs['prediction_masks']
m2f_outputs = chat_outputs['m2f_outputs']
label_id_to_text = m2f_outputs['label_id_to_text']
# Turn the raw Mask2Former query logits into a panoptic segmentation map at
# the original image resolution (target size is (H, W)).
post_m2f_outputs = model.post_process_panoptic_segmentation(
    m2f_outputs['class_queries_logits'],
    m2f_outputs['masks_queries_logits'],
    target_sizes=[(height, width)],
)
print(f"user: {question}")
print(f"assistant: {answer}")

segmentation = post_m2f_outputs[0]['segmentation']
segments_info = post_m2f_outputs[0]['segments_info']
# Expand each panoptic segment id into a (1, H, W) boolean mask + its label.
pano_masks, pano_tags = [], []
for item in segments_info:
    mask = segmentation == item['id']
    pano_masks.append(mask.unsqueeze(0).cpu().numpy())
    pano_tags.append(label_id_to_text[item['label_id']])
# NOTE(review): np.concatenate raises if no segments were predicted — confirm
# whether an empty panoptic result is possible for this model.
pano_masks = np.concatenate(pano_masks, axis=0)  # (N, H, W)
visualize(image_path, pano_masks, "./visualize_test_4.jpg", pano_tags)