File size: 5,545 Bytes
d686824
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
from typing import List, Optional

import cv2
import numpy as np
import supervision as sv
import torch
from PIL import Image
from torch.cuda.amp import autocast
from transformers import OwlViTForObjectDetection, OwlViTProcessor
class owlInterface:
    """Stateless base class for the OWL detector wrapper defined below."""

    def __init__(self):
        """Create the base object; there is nothing to configure here."""
        return None

class OWLInterface(owlInterface):
    """Open-vocabulary object detector wrapping Hugging Face OWL-ViT.

    Text queries live in ``self.texts``: a flat list of strings by default,
    or ``[["obj"], ..., [" "]]`` after ``reparameterize_object_list``.
    Results are returned as ``supervision.Detections`` objects.
    """

    def __init__(self, config_path: str, checkpoint_path: Optional[str] = None, device: str = "cuda:0"):
        """Load the processor and model, and move the model to ``device``.

        Args:
            config_path: Model id or local directory passed to
                ``from_pretrained`` (e.g. ``"google/owlvit-base-patch32"``).
            checkpoint_path: Unused here; presumably kept so the signature
                matches other detector wrappers.
            device: Torch device string that the model and inputs are moved to.
        """
        self.processor, self.model = self.load_model_and_tokenizer(config_path)
        self.device = device
        self.model = self.model.to(self.device)
        # Default queries until reparameterize_object_list() replaces them.
        self.texts = ["couch", "table", "woman"]

    def load_model_and_tokenizer(self, model_name):
        """Load the OWL-ViT processor and detection model.

        Returns:
            Tuple ``(processor, model)``.
        """
        processor = OwlViTProcessor.from_pretrained(model_name)
        model = OwlViTForObjectDetection.from_pretrained(model_name)
        return processor, model

    def forward_model(self, inputs):
        """Run the model on processor ``inputs`` without tracking gradients."""
        with torch.no_grad():
            outputs = self.model(**inputs)
        return outputs

    def _flat_queries(self) -> List[str]:
        """Return ``self.texts`` flattened into one list of query strings.

        Both the flat default form and the nested ``[["a"], ["b"], [" "]]``
        form are accepted. The result is the set of queries to search for in
        each image of a batch.
        """
        queries = []
        for item in self.texts:
            if isinstance(item, str):
                queries.append(item)
            else:
                queries.extend(item)
        return queries

    def inference(self, image_path, use_amp: bool = False):
        """Detect the current text queries in a single image file.

        Args:
            image_path: Path of the image to load; it is converted to RGB.
            use_amp: Accepted for interface compatibility; currently ignored.

        Returns:
            ``sv.Detections`` for the image, scaled to its pixel size.
        """
        # Convert inside the with-block: convert() returns an independent copy
        # that stays usable after the source file handle is closed.
        with Image.open(image_path) as src:
            image = src.convert("RGB")
        width, height = image.size
        inputs = self.processor(text=self.texts, images=image, return_tensors="pt").to(self.device)

        outputs = self.forward_model(inputs)

        # Post-processing rescales boxes using (height, width) per image.
        target_size = torch.tensor([[height, width]])
        results = self.processor.post_process_grounded_object_detection(
            outputs=outputs, target_sizes=target_size)[0]
        return sv.Detections.from_transformers(transformers_results=results)

    def inference_detector(self, images, use_amp: bool = False, threshold: float = 0.05,
                           debug_image_path: Optional[str] = "./annotated_image.png"):
        """Detect the current text queries in every image of a batch.

        Args:
            images: Sequence of HxWxC image arrays (anything ``np.asarray``
                accepts). Images may differ in size.
            use_amp: Accepted for interface compatibility; currently ignored.
            threshold: Minimum score passed to the processor's post-processing.
            debug_image_path: If not None, the first image is annotated with
                its boxes and saved here with channel order reversed.

        Returns:
            List with one ``sv.Detections`` per input image, in input order.
            The same list is stored on ``self.detections_inbatch``.
        """
        batch_images = [np.asarray(image) for image in images]
        if not batch_images:
            self.detections_inbatch = []
            return []

        # NOTE(review): the debug save reverses channels, which suggests callers
        # pass BGR (OpenCV) frames, while OWL-ViT expects RGB. Confirm the
        # caller's channel order.
        # Search every image for the same queries: one query list per image.
        queries = self._flat_queries()
        inputs = self.processor(
            text=[queries] * len(batch_images), images=batch_images, return_tensors="pt"
        ).to(self.device)
        outputs = self.forward_model(inputs)

        # One (height, width) row per image, so boxes are rescaled per image.
        target_sizes = torch.tensor([image.shape[:2] for image in batch_images])
        results = self.processor.post_process_grounded_object_detection(
            outputs=outputs, target_sizes=target_sizes, threshold=threshold)
        detections_inbatch = [
            sv.Detections.from_transformers(transformers_results=result) for result in results
        ]

        if debug_image_path is not None:
            # Annotate a copy so the caller's first image is never drawn on.
            annotated_image = sv.BoxAnnotator().annotate(batch_images[0].copy(), detections_inbatch[0])
            Image.fromarray(annotated_image[:, :, ::-1]).save(debug_image_path)
        self.detections_inbatch = detections_inbatch
        return detections_inbatch

    def bbox_visualization(self, images, detections_inbatch):
        """Draw detection boxes on each image.

        Args:
            images: Image arrays, paired positionally with ``detections_inbatch``.
            detections_inbatch: One ``sv.Detections`` per image.

        Returns:
            List of annotated images. NOTE(review): the annotator may draw on
            the input arrays in place; copy them first if that matters.
        """
        bounding_box_annotator = sv.BoxAnnotator()
        annotated_images = []
        for image, detections in zip(images, detections_inbatch):
            annotated_images.append(bounding_box_annotator.annotate(image, detections))
        return annotated_images

    def reparameterize_object_list(self, target_objects: List[str], cue_objects: List[str]):
        """Replace the OWL text queries with the given object names.

        Args:
            target_objects (List[str]): List of target object names.
            cue_objects (List[str]): List of cue object names.
        """
        combined_texts = target_objects + cue_objects

        # Store one single-item list per object plus a trailing blank query.
        # inference_detector() flattens this form before batching.
        self.texts = [[obj.strip()] for obj in combined_texts] + [[' ']]
        
def main():
    """Smoke-test OWLInterface on one image and save the annotated result.

    Paths and the device are hard-coded for a local development machine.
    """
    model_choice = 'owl_model'

    image_path = "/home/anabella/projects/MLLM/TSTAR/data/score/annotated_image.png"
    output_path = "/home/anabella/projects/MLLM/TSTAR/data/score/annotated_image3.png"

    if model_choice == 'owl_model':
        model_name = "google/owlvit-base-patch32"
        owl_interface = OWLInterface(
            config_path=model_name,
            checkpoint_path=None,
            device="cuda:0",
        )
        # bbox_visualization() takes image arrays plus matching detections,
        # not file paths, so load the image and run detection first.
        with Image.open(image_path) as src:
            image = np.array(src.convert("RGB"))
        detections_inbatch = owl_interface.inference_detector([image])
        annotated = owl_interface.bbox_visualization([image], detections_inbatch)[0]
        Image.fromarray(annotated).save(output_path)


if __name__ == "__main__":
    main()