from typing import List, Optional

import numpy as np
import supervision as sv
import torch
from PIL import Image
from torch.cuda.amp import autocast
from transformers import OwlViTForObjectDetection, OwlViTProcessor


class owlInterface:
    def __init__(self):
        """
        Base interface for an OWL-ViT detector. Subclasses initialize the
        model from a given configuration and checkpoint.
        """
        pass


class OWLInterface(owlInterface):
    def __init__(self, config_path: str, checkpoint_path: Optional[str] = None, device: str = "cuda:0"):
        # config_path is a Hugging Face model name or local path, e.g.
        # "google/owlvit-base-patch32"; checkpoint_path is kept for interface
        # compatibility and is currently unused.
        self.processor, self.model = self.load_model_and_tokenizer(config_path)
        self.device = device
        self.model = self.model.to(self.device)
        # Default text queries; override via reparameterize_object_list().
        self.texts = ["couch", "table", "woman"]

    def load_model_and_tokenizer(self, model_name):
        processor = OwlViTProcessor.from_pretrained(model_name)
        model = OwlViTForObjectDetection.from_pretrained(model_name)
        return processor, model

    def forward_model(self, inputs, use_amp: bool = False):
        # Optionally run under mixed precision; autocast is a no-op when disabled.
        with torch.no_grad(), autocast(enabled=use_amp):
            outputs = self.model(**inputs)
        return outputs

    def inference(self, image_path, use_amp: bool = False):
        with Image.open(image_path).convert("RGB") as image:
            width, height = image.size
            inputs = self.processor(text=self.texts, images=image, return_tensors="pt").to(self.device)

        # Run model inference
        outputs = self.forward_model(inputs, use_amp=use_amp)

        # Post-process outputs; target_sizes expects (height, width) pairs.
        target_size = torch.tensor([[height, width]])
        results = self.processor.post_process_grounded_object_detection(
            outputs=outputs, target_sizes=target_size)[0]
        detections = sv.Detections.from_transformers(transformers_results=results)
        return detections

    def inference_detector(self, images, use_amp: bool = False):
        # `images` is expected to be a list of same-sized HxWxC uint8 arrays;
        # the text queries in self.texts are replicated once per image so the
        # whole batch is processed in a single forward pass.
        batch_images = np.array(images)
        inputs = self.processor(
            text=[self.texts] * len(batch_images),
            images=list(batch_images),
            return_tensors="pt",
        ).to(self.device)
        height, width = batch_images[0].shape[:2]

        detections_inbatch = []
        # Run model inference
        outputs = self.forward_model(inputs, use_amp=use_amp)

        target_sizes = torch.tensor([[height, width] for _ in batch_images])
        results = self.processor.post_process_grounded_object_detection(
            outputs=outputs, target_sizes=target_sizes, threshold=0.05)

        for result in results:
            detections = sv.Detections.from_transformers(transformers_results=result)
            detections_inbatch.append(detections)

        check = True
        if check:  # save the first annotated image for visual sanity checking
            bounding_box_annotator = sv.BoxAnnotator()
            annotated_image = bounding_box_annotator.annotate(batch_images[0], detections_inbatch[0])
            output_image = Image.fromarray(annotated_image[:, :, ::-1])
            output_image.save("./annotated_image.png")

        self.detections_inbatch = detections_inbatch
        return detections_inbatch

    def bbox_visualization(self, images, detections_inbatch):
        # Draw bounding boxes on each image of the batch and return the
        # annotated copies.
        bounding_box_annotator = sv.BoxAnnotator()
        annotated_images = []
        for image, detections in zip(images, detections_inbatch):
            annotated_image = bounding_box_annotator.annotate(image, detections)
            annotated_images.append(annotated_image)
        return annotated_images
    def reparameterize_object_list(self, target_objects: List[str], cue_objects: List[str]):
        """
        Reparameterize the detection object list used by the OWL model.

        Args:
            target_objects (List[str]): List of target object names.
            cue_objects (List[str]): List of cue object names.
        """
        # Combine target and cue objects into a flat list of text queries,
        # the per-image prompt format expected by OwlViTProcessor.
        combined_texts = target_objects + cue_objects
        self.texts = [obj.strip() for obj in combined_texts]


def main():
    model_choice = 'owl_model'
    image_path = "/home/anabella/projects/MLLM/TSTAR/data/score/annotated_image.png"
    output_path = "/home/anabella/projects/MLLM/TSTAR/data/score/annotated_image3.png"

    if model_choice == 'owl_model':
        model_name = "google/owlvit-base-patch32"
        owl_interface = OWLInterface(
            config_path=model_name,
            checkpoint_path=None,
            device="cuda:0"
        )
        # Run single-image inference, annotate the detections, and save.
        detections = owl_interface.inference(image_path)
        image = np.array(Image.open(image_path).convert("RGB"))
        annotated = owl_interface.bbox_visualization([image], [detections])[0]
        Image.fromarray(annotated).save(output_path)


if __name__ == "__main__":
    main()
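

# A minimal sketch of batched usage, assuming a list of same-sized RGB frames
# on disk and an available CUDA device (the class default). The paths, the
# query lists, and the demo_batch_inference name are illustrative
# placeholders, not part of the original pipeline.
def demo_batch_inference(frame_paths: List[str]):
    owl_interface = OWLInterface(config_path="google/owlvit-base-patch32")
    # Swap the default queries for task-specific targets and contextual cues.
    owl_interface.reparameterize_object_list(
        target_objects=["couch", "table"], cue_objects=["woman"])
    frames = [np.array(Image.open(p).convert("RGB")) for p in frame_paths]
    detections = owl_interface.inference_detector(frames)
    annotated = owl_interface.bbox_visualization(frames, detections)
    for i, img in enumerate(annotated):
        Image.fromarray(img).save(f"./annotated_{i}.png")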