Mahiruoshi commited on
Commit
adddaea
·
verified ·
1 Parent(s): 6647b08

Update http_server_ir_task2.py

Browse files
Files changed (1) hide show
  1. http_server_ir_task2.py +209 -80
http_server_ir_task2.py CHANGED
@@ -1,91 +1,220 @@
1
- from flask import Flask, request, jsonify
2
  import os
 
3
  import importlib.util
4
- from datetime import datetime
5
- from predict_task2 import Predictor
6
- from id_mapping import mapping
7
- #from show_annotation import start_annotation_process
8
- #from multiprocessing import Process, Queue
9
- #import cv2
10
- from show_stitched import *
11
-
12
- app = Flask(__name__)
13
 
14
  config_dir = os.path.abspath(os.path.dirname(__file__))
15
  config_path = os.path.join(config_dir, 'PC_CONFIG.py')
16
  spec = importlib.util.spec_from_file_location("PC_CONFIG", config_path)
17
  PC_CONFIG = importlib.util.module_from_spec(spec)
18
  spec.loader.exec_module(PC_CONFIG)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
- HOST = PC_CONFIG.HOST
21
- PORT = PC_CONFIG.IMAGE_REC_PORT
22
- UPLOAD_FOLDER = os.path.join(PC_CONFIG.FILE_DIRECTORY,"image-rec","images")
23
- app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
24
 
25
- def process_file(file_path, direction, task_type):
 
26
  predictor = Predictor()
27
- print("File received and saved successfully.")
28
- print(f"Direction received: {direction}")
29
- print(f"Task type received: {task_type}")
30
-
31
- startTime = datetime.now()
32
- class_name, results, detection_id = predictor.predict_id(file_path, task_type) # Perform prediction
33
- #show_annotation_queue.put((file_path, results, detection_id))
34
- class_id = str(mapping.get(class_name, -1))
35
- endTime = datetime.now()
36
- totalTime = (endTime - startTime).total_seconds()
37
- print(f"Predicted ID: {class_id}")
38
- print(f"Time taken for Predicting Image = {totalTime} s")
39
- return class_id
40
-
41
- @app.route('/status', methods=['GET'])
42
- def server_status():
43
- return jsonify({'status': 'OK'})
44
-
45
- @app.route('/upload', methods=['POST'])
46
- def upload_file():
47
- if 'file' not in request.files:
48
- return jsonify({'error': 'No file part'}), 400
49
- file = request.files['file']
50
- direction = request.form['direction']
51
- task_type = request.form['task_type']
52
- if file.filename == '':
53
- return jsonify({'error': 'No selected file'}), 400
54
- if file:
55
- filename = os.path.basename(file.filename)
56
- file_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
57
- file.save(file_path)
58
- # image = cv2.imread(file_path)
59
- # cv2.imshow("Uploaded Image", image)
60
- # cv2.waitKey(0) # Wait until a key is pressed
61
- # cv2.destroyAllWindows() # Close the window
62
-
63
- # Process the file and predict
64
- class_id = process_file(file_path, direction, task_type)
65
- return jsonify({'message': 'File successfully uploaded', 'predicted_id': class_id}), 200
66
-
67
- @app.route('/display_stitched', methods=['POST'])
68
- def display_stitched():
69
- showAnnotatedStitched()
70
- return jsonify({'display_stitched': 'OK'})
71
-
72
- if __name__ == '__main__':
73
- # show_annotation_queue = Queue()
74
- # process = Process(target=start_annotation_process, args=(show_annotation_queue,))
75
- # process.start()
76
-
77
- print()
78
- print(f"UPLOAD FOLDER: {UPLOAD_FOLDER}")
79
- # Port 5000 if free
80
- '''
81
- try:
82
- app.run(host=HOST, port=PORT, debug=False)
83
- except:
84
- print('Unable to Connect to PC_CONFIG Host and Port. Switching to 0.0.0.0:4000.')
85
- app.run(host='0.0.0.0', port=4000, debug=True)
86
- '''
87
-
88
- # Run on Port 4000
89
- app.run(host='0.0.0.0', port=4000, debug=True)
90
-
91
- #process.join()
 
1
+ import cv2
2
  import os
3
+ import time
4
  import importlib.util
5
+ import supervision as sv
6
+ from ultralytics import YOLO
 
 
 
 
 
 
 
7
 
8
  config_dir = os.path.abspath(os.path.dirname(__file__))
9
  config_path = os.path.join(config_dir, 'PC_CONFIG.py')
10
  spec = importlib.util.spec_from_file_location("PC_CONFIG", config_path)
11
  PC_CONFIG = importlib.util.module_from_spec(spec)
12
  spec.loader.exec_module(PC_CONFIG)
13
+ dir = str(os.path.join(PC_CONFIG.BASE_DIR, "weights", "best_task2.pt"))
14
+
15
+ class Predictor:
16
+ def __init__(self):
17
+ # Load a pre-trained yolov8n model
18
+ print("dir:",dir)
19
+ self.model = YOLO(dir) # replace model here
20
+ # self.print_class_ids() # Print class IDs upon initialization
21
+
22
+ # def print_class_ids(self):
23
+ # # Print all class names and their corresponding IDs
24
+ # for id, name in enumerate(self.model.names):
25
+ # print(f"ID: {id}, Name: {name}")
26
+
27
+ # def predict_id(self, image_file_path, task_type):
28
+ # # Load the image
29
+ # image = cv2.imread(image_file_path)
30
+
31
+ # # Run inference on the image
32
+ # results = self.model(image)
33
+
34
+ # # Print results
35
+ # print(results)
36
+ # # Show annotation
37
+ # self.show_annotation(image, results)
38
+
39
+ # # Extract class name
40
+ # class_name, largest_size, detection_id = None, -1, None
41
+ # for result in results: # Assuming 'results' is a list
42
+ # print(f"task_type is {task_type}")
43
+
44
+ # if task_type == "TASK_2":
45
+ # for prediction in result.predictions:
46
+ # print(prediction)
47
+ # class_name = prediction.class_name
48
+ # detection_id = prediction.detection_id
49
+ # if class_name != "Bullseye":
50
+ # break
51
+ # else:
52
+ # for prediction in result.predictions:
53
+ # print(prediction)
54
+ # if largest_size == -1 or max(prediction.width, prediction.height) > largest_size:
55
+ # largest_size = max(prediction.width, prediction.height)
56
+ # class_name = prediction.class_name
57
+ # detection_id = prediction.detection_id
58
+
59
+ # if class_name:
60
+ # print("class_name = " + class_name)
61
+ # else:
62
+ # print("class_name = None")
63
+
64
+ # return class_name, results, detection_id
65
+
66
+ def predict_id(self, image_file_path, task_type):
67
+ # Load the image
68
+ image = cv2.imread(image_file_path)
69
+ # Validation for image existence
70
+ if image is None:
71
+ print(f"Error: Could not read image at {image_file_path}")
72
+ return None, None, None
73
+
74
+ # Check the image size and resize if necessary
75
+ if image.shape[0] != 640 or image.shape[1] != 640:
76
+ image = cv2.resize(image, (640, 640)) # Resize to 640x640
77
+
78
+ # Run inference on the image
79
+ results = self.model(image) # Directly pass the image
80
+
81
+ # Print results
82
+ print(results)
83
+
84
+ # Show annotation (using YOLOv8's plotting capabilities)
85
+ # results[0].show()
86
+
87
+ # Extract class name, largest size, and detection ID
88
+ class_name, largest_size, detection_id = None, -1, 0
89
+
90
+ # Check if there are any detections
91
+ if results[0].boxes is None or len(results[0].boxes) == 0:
92
+ print("No detections found in the image")
93
+ return class_name, results, detection_id
94
+
95
+ boxes = results[0].boxes.xyxy # Get bounding boxes (x1, y1, x2, y2)
96
+ scores = results[0].boxes.conf # Get confidence scores
97
+ class_ids = results[0].boxes.cls # Get class IDs
98
+
99
+ # Store all detections with their priority
100
+ detections_list = []
101
+
102
+ for i in range(len(boxes)):
103
+ detected_class = results[0].names[int(class_ids[i])]
104
+ confidence = float(scores[i])
105
+ yolo_class_id = int(class_ids[i])
106
+ print(f"Processing detection {i}: {detected_class} (confidence: {confidence:.2f}, class_id: {yolo_class_id})")
107
+
108
+ if task_type == "TASK_2":
109
+ # Check by class name - only set Bullseye to lowest priority
110
+ if detected_class.lower() != 'bullseye':
111
+ # All non-bullseye detections have equal priority (0), sorted by confidence
112
+ print(f" Added to list: {detected_class} with normal priority")
113
+ detections_list.append({
114
+ 'index': i,
115
+ 'class_name': detected_class,
116
+ 'confidence': confidence,
117
+ 'priority': 0 # Equal priority for all non-bullseye detections
118
+ })
119
+ else:
120
+ print(f" Bullseye detected - adding with lowest priority")
121
+ detections_list.append({
122
+ 'index': i,
123
+ 'class_name': detected_class,
124
+ 'confidence': confidence,
125
+ 'priority': -10 # Lowest priority for bullseye
126
+ })
127
+ else:
128
+ # Determine the largest bounding box
129
+ box_width = boxes[i][2] - boxes[i][0]
130
+ box_height = boxes[i][3] - boxes[i][1]
131
+ size = max(box_width, box_height)
132
+ detection_id = i
133
+
134
+ if largest_size == -1 or size > largest_size:
135
+ largest_size = size
136
+ class_name = detected_class
137
+ detection_id = i
138
+
139
+ # For TASK_2, select detection based on priority, then confidence
140
+ if task_type == "TASK_2" and detections_list:
141
+ print(f"\nTotal detections found: {len(detections_list)}")
142
+ # Sort by priority (descending), then by confidence (descending)
143
+ detections_list.sort(key=lambda x: (x['priority'], x['confidence']), reverse=True)
144
+
145
+ # Print sorted list for debugging
146
+ print("Sorted detections:")
147
+ for det in detections_list:
148
+ print(f" - {det['class_name']}: priority={det['priority']}, confidence={det['confidence']:.2f}")
149
+
150
+ # Select the highest priority detection
151
+ selected = detections_list[0]
152
+ class_name = selected['class_name']
153
+ detection_id = selected['index']
154
+ print(f"\n✓ Selected detection: {class_name} (priority: {selected['priority']}, confidence: {selected['confidence']:.2f})")
155
+
156
+ if class_name:
157
+ print("class_name = " + class_name)
158
+ timestamp = int(time.time())
159
+ # Save the annotated image
160
+ try:
161
+ results[detection_id].save(f'../data/annotated_images/{class_name}_{timestamp}.jpg')
162
+ except:
163
+ print("error in saving photo!")
164
+ else:
165
+ print("class_name = None")
166
+
167
+ return class_name, results, detection_id
168
+
169
+
170
+ # def show_annotation(self, image, results):
171
+ # # Create supervision annotators
172
+ # bounding_box_annotator = sv.BoundingBoxAnnotator()
173
+ # label_annotator = sv.LabelAnnotator()
174
+
175
+ # # Process results from YOLOv8
176
+ # detections = []
177
+ # for result in results:
178
+ # for detection in result.boxes.data: # Accessing YOLOv8's box data
179
+ # class_id = int(detection[5]) # Class ID
180
+ # x1, y1, x2, y2 = map(int, detection[:4]) # Bounding box coordinates
181
+ # score = float(detection[4]) # Confidence score
182
+
183
+ # # Add to detections
184
+ # detections.append({
185
+ # "bbox": [x1, y1, x2, y2],
186
+ # "confidence": score,
187
+ # "class_id": class_id
188
+ # })
189
+
190
+ # # Convert detections to the expected format for supervision
191
+ # if detections:
192
+ # detections = sv.Detections(
193
+ # xyxy=[d["bbox"] for d in detections],
194
+ # confidence=[d["confidence"] for d in detections],
195
+ # class_id=[d["class_id"] for d in detections]
196
+ # )
197
+
198
+ # # Annotate the image with inference results
199
+ # annotated_image = bounding_box_annotator.annotate(scene=image, detections=detections)
200
+ # annotated_image = label_annotator.annotate(scene=annotated_image, detections=detections)
201
+
202
+ # # Display the annotated image
203
+ # try:
204
+ # cv2.imshow("Annotated Image", annotated_image)
205
+ # cv2.waitKey(0) # Wait indefinitely until a key is pressed
206
+ # except Exception as e:
207
+ # print(f"Error displaying image: {e}")
208
+ # finally:
209
+ # cv2.destroyAllWindows() # Close all OpenCV windows
210
+ # else:
211
+ # print("No detections found.")
212
 
 
 
 
 
213
 
214
+ if __name__ == "__main__":
215
+ # Example usage
216
  predictor = Predictor()
217
+ # Specify the path to your image
218
+ image_file_path = os.path.join(PC_CONFIG.FILE_DIRECTORY, "image-rec", "sample_images", "IMG_9325.jpg")
219
+ # Predict and display the class name
220
+ predictor.predict_id(image_file_path, "TASK_1")