File size: 9,163 Bytes
adddaea
689227c
adddaea
689227c
adddaea
 
689227c
 
 
 
 
 
adddaea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
689227c
 
adddaea
 
689227c
adddaea
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
import cv2
import os
import time
import importlib.util
import supervision as sv
from ultralytics import YOLO

# Load PC_CONFIG.py from this file's directory by explicit path, so the module
# works regardless of the current working directory or sys.path.
config_dir = os.path.abspath(os.path.dirname(__file__))
config_path = os.path.join(config_dir, 'PC_CONFIG.py')

spec = importlib.util.spec_from_file_location("PC_CONFIG", config_path)
PC_CONFIG = importlib.util.module_from_spec(spec)
spec.loader.exec_module(PC_CONFIG)  # executes the config file's top level

# Default YOLO weights file consumed by Predictor. NOTE: the name ``dir``
# shadows the builtin; it is kept because Predictor.__init__ reads it.
_weights_parts = (PC_CONFIG.BASE_DIR, "weights", "best_task2.pt")
dir = str(os.path.join(*_weights_parts))

class Predictor:
    """Wrap a YOLO model and choose a single "answer" detection per image.

    By default the weights come from the module-level ``dir`` path
    (``<PC_CONFIG.BASE_DIR>/weights/best_task2.pt``).
    """

    # Output directory for annotated images. It is relative, so it resolves
    # against the current working directory of the running process.
    ANNOTATED_DIR = "../data/annotated_images"

    def __init__(self, weights_path=None):
        """Load the YOLO model.

        Args:
            weights_path: Optional path to a YOLO weights file. When omitted,
                the module-level ``dir`` path is used (previous behavior).
        """
        if weights_path is None:
            weights_path = dir
        print("dir:", weights_path)
        self.model = YOLO(weights_path)

    def predict_id(self, image_file_path, task_type):
        """Run detection on one image and pick the detection to report.

        Selection rules:
          * ``task_type == "TASK_2"``: any non-Bullseye detection beats a
            Bullseye; within the same tier, higher confidence wins and the
            earlier detection wins exact ties.
          * any other task type: the detection whose bounding box has the
            largest side (max of width and height) wins; earliest on ties.

        Args:
            image_file_path: Path of the image, read with ``cv2.imread``.
            task_type: Task identifier, e.g. ``"TASK_1"`` or ``"TASK_2"``.

        Returns:
            ``(class_name, results, detection_id)``:
              * ``results`` is the raw list returned by the model (one entry
                for the single input image).
              * ``detection_id`` is the index of the chosen box within
                ``results[0].boxes``.
            When nothing is detected, ``class_name`` is ``None`` and
            ``detection_id`` is 0. If the image cannot be read,
            ``(None, None, None)`` is returned.
        """
        image = cv2.imread(image_file_path)
        if image is None:
            print(f"Error: Could not read image at {image_file_path}")
            return None, None, None

        # NOTE(review): this squashes non-square images to 640x640; YOLO
        # letterboxes internally, so confirm this resize is still intended.
        if image.shape[0] != 640 or image.shape[1] != 640:
            image = cv2.resize(image, (640, 640))

        results = self.model(image)
        print(results)

        class_name, detection_id = None, 0
        # A single image was passed in, so all detections live in results[0].
        first = results[0]
        if first.boxes is None or len(first.boxes) == 0:
            print("No detections found in the image")
            return class_name, results, detection_id

        # Convert tensors/arrays to plain Python lists once, up front.
        boxes = first.boxes.xyxy.tolist()  # [x1, y1, x2, y2] per detection
        scores = first.boxes.conf.tolist()
        class_ids = [int(c) for c in first.boxes.cls.tolist()]

        for i, (score, cls_id) in enumerate(zip(scores, class_ids)):
            print(f"Processing detection {i}: {first.names[cls_id]} (confidence: {score:.2f}, class_id: {cls_id})")

        if task_type == "TASK_2":
            detection_id = self._select_task2(first.names, scores, class_ids)
        else:
            detection_id = self._select_largest(boxes)
        class_name = first.names[class_ids[detection_id]]

        if class_name:
            print("class_name = " + class_name)
            # results[0] is the annotated frame; detection_id indexes boxes,
            # not results, so it must not be used to index the results list.
            self._save_annotated(first, class_name)
        else:
            print("class_name = None")

        return class_name, results, detection_id

    @staticmethod
    def _select_task2(names, scores, class_ids):
        """Return the index of the best detection, ranking Bullseye last."""
        detections = []
        for i, (score, cls_id) in enumerate(zip(scores, class_ids)):
            label = names[cls_id]
            is_bullseye = label.lower() == "bullseye"
            if is_bullseye:
                print("  Bullseye detected - adding with lowest priority")
            else:
                print(f"  Added to list: {label} with normal priority")
            detections.append({
                'index': i,
                'class_name': label,
                'confidence': score,
                'priority': -10 if is_bullseye else 0,
            })

        print(f"\nTotal detections found: {len(detections)}")
        # list.sort is stable even with reverse=True, so the earlier detection
        # wins when priority and confidence are both equal.
        detections.sort(key=lambda d: (d['priority'], d['confidence']), reverse=True)

        print("Sorted detections:")
        for det in detections:
            print(f"  - {det['class_name']}: priority={det['priority']}, confidence={det['confidence']:.2f}")

        selected = detections[0]
        print(f"\n✓ Selected detection: {selected['class_name']} (priority: {selected['priority']}, confidence: {selected['confidence']:.2f})")
        return selected['index']

    @staticmethod
    def _select_largest(boxes):
        """Return the index of the box with the largest side (first on ties)."""
        sides = [max(x2 - x1, y2 - y1) for x1, y1, x2, y2 in boxes]
        # max() returns the first maximal element, matching a strict ">" scan.
        return max(range(len(sides)), key=sides.__getitem__)

    def _save_annotated(self, result, class_name):
        """Best-effort save of the annotated image; failures are logged only."""
        timestamp = int(time.time())
        path = os.path.join(self.ANNOTATED_DIR, f"{class_name}_{timestamp}.jpg")
        try:
            os.makedirs(self.ANNOTATED_DIR, exist_ok=True)
            result.save(path)
        except Exception as e:  # saving must never break prediction
            print(f"error in saving photo! {e}")


if __name__ == "__main__":
    # Manual smoke test: load the model, then classify one bundled sample image.
    predictor = Predictor()
    sample_dir = os.path.join(PC_CONFIG.FILE_DIRECTORY, "image-rec", "sample_images")
    predictor.predict_id(os.path.join(sample_dir, "IMG_9325.jpg"), "TASK_1")