from typing import List, Optional

import cv2
import numpy as np
import supervision as sv
import torch
from PIL import Image
from torch.cuda.amp import autocast
from transformers import OwlViTProcessor, OwlViTForObjectDetection


class owlInterface:
    def __init__(self):
        """
        Initialize the OWL-ViT model with the given configuration and checkpoint.
        """
        pass


class OWLInterface(owlInterface):
    def __init__(self, config_path: str, checkpoint_path: Optional[str] = None, device: str = "cuda:0"):
        # config_path is a Hugging Face model name or local path, e.g.
        # "google/owlvit-base-patch32"; checkpoint_path is currently unused
        # and kept only for interface compatibility.
        self.processor, self.model = self.load_model_and_tokenizer(config_path)
        self.device = device
        self.model = self.model.to(self.device)
        self.model.eval()
        # Default text queries; override via reparameterize_object_list().
        self.texts = ["couch", "table", "woman"]

    def load_model_and_tokenizer(self, model_name: str):
        processor = OwlViTProcessor.from_pretrained(model_name)
        model = OwlViTForObjectDetection.from_pretrained(model_name)
        return processor, model

    def forward_model(self, inputs, use_amp: bool = False):
        # Run the model without gradients; optionally under CUDA mixed
        # precision when use_amp is set.
        with torch.no_grad():
            if use_amp:
                with autocast():
                    outputs = self.model(**inputs)
            else:
                outputs = self.model(**inputs)
        return outputs

    def inference(self, image_path: str, use_amp: bool = False):
        with Image.open(image_path).convert("RGB") as image:
            width, height = image.size
            inputs = self.processor(text=self.texts, images=image, return_tensors="pt").to(self.device)

        # Run model inference
        outputs = self.forward_model(inputs, use_amp=use_amp)

        # Post-process outputs; target_sizes entries are (height, width)
        target_sizes = torch.tensor([[height, width]])
        results = self.processor.post_process_grounded_object_detection(
            outputs=outputs, target_sizes=target_sizes)[0]
        detections = sv.Detections.from_transformers(transformers_results=results)
        return detections
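
    # The returned sv.Detections exposes .xyxy (N x 4 pixel boxes),
    # .confidence, and .class_id. A minimal usage sketch (hypothetical path):
    #     dets = owl_interface.inference("frame.png")
    #     for box, score in zip(dets.xyxy, dets.confidence):
    #         print(box, score)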

    def inference_detector(self, images, use_amp: bool = False, threshold: float = 0.05, check: bool = False):
        # `images` is a sequence of same-sized H x W x 3 uint8 arrays
        # (BGR if they come from cv2).
        batch_images = np.array(images)
        height, width = batch_images[0].shape[:2]

        # The processor needs one copy of the text queries per image.
        inputs = self.processor(
            text=[self.texts] * len(batch_images),
            images=list(batch_images),
            return_tensors="pt",
        ).to(self.device)

        # Run model inference
        outputs = self.forward_model(inputs, use_amp=use_amp)

        target_sizes = torch.tensor([[height, width] for _ in batch_images])
        results = self.processor.post_process_grounded_object_detection(
            outputs=outputs, target_sizes=target_sizes, threshold=threshold)

        detections_inbatch = []
        for result in results:
            detections = sv.Detections.from_transformers(transformers_results=result)
            detections_inbatch.append(detections)

        if check:
            # Debug aid: save the first annotated frame for inspection.
            bounding_box_annotator = sv.BoxAnnotator()
            annotated_image = bounding_box_annotator.annotate(batch_images[0].copy(), detections_inbatch[0])
            output_image = Image.fromarray(annotated_image[:, :, ::-1])  # BGR -> RGB
            output_image.save("./annotated_image.png")

        self.detections_inbatch = detections_inbatch
        return detections_inbatch
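
    # A minimal batched-usage sketch (hypothetical frame paths; assumes all
    # frames share one resolution, as np.array(images) above requires):
    #     frames = [cv2.imread(p) for p in ["f0.png", "f1.png"]]
    #     dets_per_frame = owl_interface.inference_detector(frames, threshold=0.05)
    #     print([len(d) for d in dets_per_frame])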

    def bbox_visualization(self, images, detections_inbatch):
        """Draw detection boxes on each image and return the annotated copies."""
        bounding_box_annotator = sv.BoxAnnotator()
        annotated_images = []
        for image, detections in zip(images, detections_inbatch):
            annotated_image = bounding_box_annotator.annotate(image, detections)
            annotated_images.append(annotated_image)
        return annotated_images

    def reparameterize_object_list(self, target_objects: List[str], cue_objects: List[str]):
        """
        Reset the text queries used by the OWL-ViT model.

        Args:
            target_objects (List[str]): List of target object names.
            cue_objects (List[str]): List of cue object names.
        """
        # Combine target and cue objects into a flat list of query strings;
        # OWL-ViT takes one text query per class, so the nested YOLO-World
        # prompt format (with a trailing padding entry) is not needed here.
        combined_texts = target_objects + cue_objects
        self.texts = [obj.strip() for obj in combined_texts]
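
    # Example (hypothetical object lists): after
    #     owl_interface.reparameterize_object_list(["red cup"], ["kitchen table"])
    # subsequent inference() / inference_detector() calls query OWL-ViT for
    # exactly those phrases.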


def main():
    model_choice = 'owl_model'
    image_path = "/home/anabella/projects/MLLM/TSTAR/data/score/annotated_image.png"
    output_path = "/home/anabella/projects/MLLM/TSTAR/data/score/annotated_image3.png"

    if model_choice == 'owl_model':
        model_name = "google/owlvit-base-patch32"
        owl_interface = OWLInterface(
            config_path=model_name,
            checkpoint_path=None,
            device="cuda:0",
        )
        # Detect first, then annotate: bbox_visualization expects images and
        # their detections, not file paths.
        detections = owl_interface.inference(image_path)
        image = cv2.imread(image_path)
        annotated = owl_interface.bbox_visualization([image], [detections])
        cv2.imwrite(output_path, annotated[0])


if __name__ == "__main__":
    main()