from typing import List, Optional
from transformers import OwlViTProcessor, OwlViTForObjectDetection
import cv2
import numpy as np
from PIL import Image
import torch
import supervision as sv
from torch.cuda.amp import autocast

class owlInterface:
    def __init__(self):
        """
        Base interface for the OWL-ViT detector. Subclasses are expected to load
        the model with their own configuration and checkpoint.
        """
        pass

class OWLInterface(owlInterface):
    def __init__(self, config_path: str, checkpoint_path: Optional[str] = None, device: str = "cuda:0"):
        # config_path is a Hugging Face model name such as "google/owlvit-base-patch32";
        # checkpoint_path is kept for interface compatibility and is currently unused.
        self.processor, self.model = self.load_model_and_tokenizer(config_path)
        self.device = device
        self.model = self.model.to(self.device)
        # Default text queries; override them with reparameterize_object_list().
        self.texts = ["couch", "table", "woman"]

    def load_model_and_tokenizer(self, model_name):
        # Load the OWL-ViT processor and detection model from the Hugging Face Hub
        processor = OwlViTProcessor.from_pretrained(model_name)
        model = OwlViTForObjectDetection.from_pretrained(model_name)
        return processor, model

    def forward_model(self, inputs, use_amp: bool = False):
        # Forward pass without gradients; optionally run under mixed precision.
        with torch.no_grad():
            with autocast(enabled=use_amp):
                outputs = self.model(**inputs)
        return outputs

    def inference(self, image_path, use_amp: bool = False):
        # Single-image inference from an image path.
        with Image.open(image_path).convert("RGB") as image:
            width, height = image.size
            inputs = self.processor(text=self.texts, images=image, return_tensors="pt").to(self.device)
        # Run model inference
        outputs = self.forward_model(inputs, use_amp=use_amp)
        # Post-process outputs; target_sizes expects (height, width)
        target_size = torch.tensor([[height, width]])
        results = self.processor.post_process_grounded_object_detection(
            outputs=outputs, target_sizes=target_size)[0]
        detections = sv.Detections.from_transformers(transformers_results=results)
        return detections

    def inference_detector(self, images, use_amp: bool = False):
        # Batched inference over a list of same-sized images (numpy arrays).
        batch_images = np.array(images)
        height, width = batch_images[0].shape[:2]
        # The same text queries are applied to every image in the batch.
        inputs = self.processor(
            text=[self.texts] * len(batch_images),
            images=list(batch_images),
            return_tensors="pt",
        ).to(self.device)
        detections_inbatch = []
        # Run model inference
        outputs = self.forward_model(inputs, use_amp=use_amp)
        # Post-process outputs; one (height, width) entry per image in the batch
        target_sizes = torch.tensor([[height, width] for _ in batch_images])
        results = self.processor.post_process_grounded_object_detection(
            outputs=outputs, target_sizes=target_sizes, threshold=0.05)
        for result in results:
            detections = sv.Detections.from_transformers(transformers_results=result)
            detections_inbatch.append(detections)
        check = True
        if check:
            # Save the first annotated image for a quick visual sanity check
            bounding_box_annotator = sv.BoxAnnotator()
            annotated_image = bounding_box_annotator.annotate(batch_images[0].copy(), detections_inbatch[0])
            output_image = Image.fromarray(annotated_image[:, :, ::-1])
            output_image.save("./annotated_image.png")
        self.detections_inbatch = detections_inbatch
        return detections_inbatch

    def bbox_visualization(self, images, detections_inbatch):
        # Draw bounding boxes on each image using its corresponding detections.
        bounding_box_annotator = sv.BoxAnnotator()
        annotated_images = []
        for image, detections in zip(images, detections_inbatch):
            annotated_image = bounding_box_annotator.annotate(image, detections)
            annotated_images.append(annotated_image)
        return annotated_images

    def reparameterize_object_list(self, target_objects: List[str], cue_objects: List[str]):
        """
        Reparameterize the detection object list used as text queries by the OWL-ViT model.
        Args:
            target_objects (List[str]): List of target object names.
            cue_objects (List[str]): List of cue object names.
        """
        # Combine target objects and cue objects into a single list of text queries
        combined_texts = target_objects + cue_objects
        # OWL-ViT takes a flat list of query strings; they are repeated per image at inference time
        self.texts = [obj.strip() for obj in combined_texts]

def main():
    model_choice = 'owl_model'
    image_path = "/home/anabella/projects/MLLM/TSTAR/data/score/annotated_image.png"
    output_path = "/home/anabella/projects/MLLM/TSTAR/data/score/annotated_image3.png"
    if model_choice == 'owl_model':
        model_name = "google/owlvit-base-patch32"
        owl_interface = OWLInterface(
            config_path=model_name,
            checkpoint_path=None,
            device="cuda:0",
        )
        # Read the image (BGR, as loaded by OpenCV), run batched detection,
        # annotate the detections, and write the result to disk.
        image = cv2.imread(image_path)
        detections_inbatch = owl_interface.inference_detector([image])
        annotated_images = owl_interface.bbox_visualization([image], detections_inbatch)
        cv2.imwrite(output_path, annotated_images[0])

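# A minimal usage sketch for the single-image path, which main() does not exercise.
# Assumptions: the image path "sample.png" and the helper name run_single_image_example
# are illustrative only and not part of the original code; a CUDA device is available.
def run_single_image_example():
    owl = OWLInterface(
        config_path="google/owlvit-base-patch32",
        checkpoint_path=None,
        device="cuda:0",
    )
    # Swap in a custom set of text queries before running detection
    owl.reparameterize_object_list(target_objects=["couch", "table"], cue_objects=["woman"])
    # Detect objects in a single image and inspect the boxes and confidence scores
    detections = owl.inference("sample.png")
    print(detections.xyxy, detections.confidence)
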
if __name__ == "__main__":
    main()