Mahiruoshi commited on
Commit
0ec72e8
·
verified ·
1 Parent(s): 24016d8

Upload 47 files

Browse files
Files changed (5) hide show
  1. README.md +108 -12
  2. app.py +121 -70
  3. main.py +74 -12
  4. model.py +190 -42
  5. test.ipynb +74 -6
README.md CHANGED
@@ -1,12 +1,108 @@
1
- ---
2
- title: Mdpg4
3
- emoji: 🔥
4
- colorFrom: green
5
- colorTo: blue
6
- sdk: gradio
7
- sdk_version: 5.43.1
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Deploy in your labtop
2
+ The images with labels are now saved into results folder. Please collect them.
3
+ ```bash
4
+ # Afater cloning this branch
5
+ pip install -r requirements.txt
6
+ ```
7
+ # Inference Server
8
+ Start the server by
9
+
10
+ ```bash
11
+ python main.py
12
+ ```
13
+ Test script
14
+ ```bash
15
+ import requests
16
+
17
+ SERVER_URL = "http://localhost:7860"
18
+
19
+ image_file = "20230825_122540_jpg.rf.f0620856e7afdbd116ceffdfd512b03a.jpg"
20
+
21
+
22
+ with open(image_file, 'rb') as f:
23
+ files = {'file': f}
24
+ response = requests.post(f"{SERVER_URL}/image", files=files)
25
+
26
+ print(response.status_code)
27
+ print(response.json())
28
+
29
+ ```
30
+ Mapping Name to ID
31
+ ```bash
32
+ name_to_id = {
33
+ "NA": 'NA',
34
+ "Bullseye": 10,
35
+ "One": 11,
36
+ "Two": 12,
37
+ "Three": 13,
38
+ "Four": 14,
39
+ "Five": 15,
40
+ "Six": 16,
41
+ "Seven": 17,
42
+ "Eight": 18,
43
+ "Nine": 19,
44
+ "A": 20,
45
+ "B": 21,
46
+ "C": 22,
47
+ "D": 23,
48
+ "E": 24,
49
+ "F": 25,
50
+ "G": 26,
51
+ "H": 27,
52
+ "S": 28,
53
+ "T": 29,
54
+ "U": 30,
55
+ "V": 31,
56
+ "W": 32,
57
+ "X": 33,
58
+ "Y": 34,
59
+ "Z": 35,
60
+ "Up": 36,
61
+ "Down": 37,
62
+ "Right": 38,
63
+ "Left": 39,
64
+ "Up Arrow": 36,
65
+ "Down Arrow": 37,
66
+ "Right Arrow": 38,
67
+ "Left Arrow": 39,
68
+ "Stop": 40}
69
+ ```
70
+ # Training
71
+ ```bash
72
+ git clone https://github.com/ultralytics/yolov5 # clone repo
73
+ cd yolov5
74
+ pip install -qr requirements.txt # install dependencies
75
+ ```
76
+ Prepare dataset, pretrained model and config
77
+ ```bash
78
+ data.yaml
79
+ !cp "Week_8.pt" "best.pt"
80
+ ```
81
+ Train
82
+
83
+ # Demo Web
84
+ Now deployed In huggingface https://huggingface.co/spaces/Mahiruoshi/mdpg4
85
+ ## Test directly
86
+ ```
87
+ import requests
88
+
89
+ url = "https://mahiruoshi-mdpg4.hf.space/" # 你的 Space 地址
90
+ file_path = "20230825_122540_jpg.rf.f0620856e7afdbd116ceffdfd512b03a.jpg"
91
+
92
+ with open(file_path, "rb") as f:
93
+ files = {"file": f}
94
+ response = requests.post(url, files=files)
95
+
96
+ print("Status:", response.status_code)
97
+ try:
98
+ print("Response:", response.json())
99
+ except:
100
+ print("Response:", response.text)
101
+ ```
102
+
103
+ ```bash
104
+ # First time
105
+ python train.py --img 416 --batch 128 --epochs 150 --data E:/workspace/mdp/data.yaml --weights best.pt --cache
106
+
107
+ #python train.py --img 416 --batch 128 --epochs 150 --data E:/workspace/mdp/data.yaml --weights best.pt --cache --hyp hyp.yaml
108
+ ```
app.py CHANGED
@@ -1,70 +1,121 @@
1
- import time
2
- import os
3
- from flask import Flask, request, jsonify
4
- from flask_cors import CORS
5
-
6
- from model import *
7
-
8
- app = Flask(__name__)
9
- CORS(app)
10
- model = load_model()
11
-
12
- os.makedirs('uploads', exist_ok=True)
13
-
14
- @app.route('/', methods=['GET', 'POST'])
15
- def main_endpoint():
16
- if request.method == 'GET':
17
- return jsonify({
18
- "result": "ok",
19
- "service": "RPI Image Recognition API",
20
- "endpoints": {
21
- "GET /": "API status and documentation",
22
- "POST /": "Image prediction (upload 'file')",
23
- "GET /stitch": "Image stitching"
24
- },
25
- "model_loaded": model is not None
26
- })
27
-
28
- elif request.method == 'POST':
29
- if 'file' not in request.files:
30
- return jsonify({"error": "No file uploaded"}), 400
31
-
32
- file = request.files['file']
33
- if file.filename == '':
34
- return jsonify({"error": "No file selected"}), 400
35
-
36
- filename = file.filename
37
- file.save(os.path.join('uploads', filename))
38
-
39
- # filename format: "<timestamp>_<obstacle_id>_<signal>.jpeg"
40
- constituents = file.filename.split("_")
41
- obstacle_id = constituents[1] if len(constituents) > 1 else "unknown"
42
-
43
- ## Week 8 ##
44
- signal = constituents[2].strip(".jpg") if len(constituents) > 2 else "default"
45
- image_id = predict_image(filename, model, signal)
46
-
47
- ## Week 9 ##
48
- # We don't need to pass in the signal anymore
49
- #image_id = predict_image_week_9(filename,model)
50
-
51
- # Return the obstacle_id and image_id
52
- result = {
53
- "obstacle_id": obstacle_id,
54
- "image_id": image_id
55
- }
56
- return jsonify(result)
57
-
58
- @app.route('/stitch', methods=['GET'])
59
- def stitch():
60
- """
61
- This is the main endpoint for the stitching command. Stitches the images using two different functions, in effect creating two stitches, just for redundancy purposes
62
- """
63
- img = stitch_image()
64
- img.show()
65
- img2 = stitch_image_own()
66
- img2.show()
67
- return jsonify({"result": "ok"})
68
-
69
- if __name__ == '__main__':
70
- app.run(host='0.0.0.0', port=7860, debug=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import os
3
+ import uuid
4
+ import shutil
5
+ from flask import Flask, request, jsonify
6
+ from flask_cors import CORS
7
+ from model import *
8
+
9
+ app = Flask(__name__)
10
+ CORS(app)
11
+ model = load_model()
12
+ #model = None
13
+
14
+ @app.route('/status', methods=['GET'])
15
+ def status():
16
+ """
17
+ This is a health check endpoint to check if the server is running
18
+ :return: a json object with a key "result" and value "ok"
19
+ """
20
+ return jsonify({"result": "ok"})
21
+
22
+ @app.route('/image', methods=['POST'])
23
+ def image_predict():
24
+ """
25
+ This is the main endpoint for the image prediction algorithm
26
+ :return: a json object with a key "result" and value a dictionary with keys "obstacle_id" and "image_id"
27
+ """
28
+ file = request.files['file']
29
+ filename = file.filename
30
+
31
+ # Save to uploads folder first
32
+ file.save(os.path.join('uploads', filename))
33
+
34
+ # Try to parse filename format: "<timestamp>_<obstacle_id>_<signal>.jpeg"
35
+ # But be flexible with different formats
36
+ constituents = file.filename.split("_")
37
+
38
+ # Default values
39
+ obstacle_id = "unknown"
40
+ signal = "C" # Default to center
41
+
42
+ # Try to extract obstacle_id and signal if available
43
+ try:
44
+ if len(constituents) >= 2:
45
+ obstacle_id = constituents[1]
46
+ if len(constituents) >= 3:
47
+ # Remove file extension from signal
48
+ signal_part = constituents[2]
49
+ # Handle both .jpg and .png extensions
50
+ for ext in ['.jpg', '.jpeg', '.png', '.JPG', '.JPEG', '.PNG']:
51
+ if signal_part.endswith(ext):
52
+ signal = signal_part[:-len(ext)]
53
+ break
54
+ else:
55
+ signal = signal_part
56
+ except IndexError:
57
+ # Use default values if parsing fails
58
+ pass
59
+
60
+ ## Week 8 ##
61
+ # Check for optional preference parameter
62
+ prefer_close = request.form.get('prefer_close_objects', 'true').lower() == 'true'
63
+ detection_result = predict_image(filename, model, signal, prefer_close)
64
+
65
+ ## Week 9 ##
66
+ # We don't need to pass in the signal anymore
67
+ #detection_result = predict_image_week_9(filename,model)
68
+
69
+ # Extract image_id from detection result
70
+ image_id = detection_result["image_id"]
71
+
72
+ # Create results folder
73
+ results_folder = 'results'
74
+ if not os.path.exists(results_folder):
75
+ os.makedirs(results_folder)
76
+
77
+ # Generate UUID
78
+ unique_id = str(uuid.uuid4())
79
+
80
+ # Create new filename format: {UUID}_Label.png
81
+ new_filename = f"{unique_id}_{image_id}.png"
82
+
83
+ # Copy original image to results folder with new name
84
+ original_path = os.path.join('uploads', filename)
85
+ new_path = os.path.join(results_folder, new_filename)
86
+
87
+ try:
88
+ # Copy original file without any processing
89
+ shutil.copy2(original_path, new_path)
90
+ print(f"Original image saved to: {new_path}")
91
+ print(f"Annotated image saved to: {detection_result['marked_image_path']}")
92
+ except Exception as e:
93
+ print(f"Error saving original image: {e}")
94
+
95
+ # Return detailed detection information
96
+ result = {
97
+ "obstacle_id": obstacle_id,
98
+ "image_id": image_id,
99
+ "detection": {
100
+ "label": detection_result["label"],
101
+ "confidence": detection_result["confidence"],
102
+ "bbox_coordinates": detection_result["bbox"],
103
+ "original_image_path": new_path,
104
+ "annotated_image_path": detection_result["marked_image_path"]
105
+ }
106
+ }
107
+ return jsonify(result)
108
+
109
+ @app.route('/stitch', methods=['GET'])
110
+ def stitch():
111
+ """
112
+ This is the main endpoint for the stitching command. Stitches the images using two different functions, in effect creating two stitches, just for redundancy purposes
113
+ """
114
+ img = stitch_image()
115
+ img.show()
116
+ img2 = stitch_image_own()
117
+ img2.show()
118
+ return jsonify({"result": "ok"})
119
+
120
+ if __name__ == '__main__':
121
+ app.run(host='0.0.0.0', port=7860, debug=True)
main.py CHANGED
@@ -1,13 +1,16 @@
1
  import time
 
 
 
2
  from flask import Flask, request, jsonify
3
  from flask_cors import CORS
4
-
5
  from model import *
6
 
7
  app = Flask(__name__)
8
  CORS(app)
9
  model = load_model()
10
  #model = None
 
11
  @app.route('/status', methods=['GET'])
12
  def status():
13
  """
@@ -24,23 +27,82 @@ def image_predict():
24
  """
25
  file = request.files['file']
26
  filename = file.filename
 
 
27
  file.save(os.path.join('uploads', filename))
28
- # filename format: "<timestamp>_<obstacle_id>_<signal>.jpeg"
 
 
29
  constituents = file.filename.split("_")
30
- obstacle_id = constituents[1]
31
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  ## Week 8 ##
33
- signal = constituents[2].strip(".jpg")
34
- image_id = predict_image(filename, model, signal)
35
-
 
36
  ## Week 9 ##
37
  # We don't need to pass in the signal anymore
38
- #image_id = predict_image_week_9(filename,model)
39
-
40
- # Return the obstacle_id and image_id
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  result = {
42
  "obstacle_id": obstacle_id,
43
- "image_id": image_id
 
 
 
 
 
 
 
44
  }
45
  return jsonify(result)
46
 
@@ -56,4 +118,4 @@ def stitch():
56
  return jsonify({"result": "ok"})
57
 
58
  if __name__ == '__main__':
59
- app.run(host='0.0.0.0', port=5000, debug=True)
 
1
  import time
2
+ import os
3
+ import uuid
4
+ import shutil
5
  from flask import Flask, request, jsonify
6
  from flask_cors import CORS
 
7
  from model import *
8
 
9
  app = Flask(__name__)
10
  CORS(app)
11
  model = load_model()
12
  #model = None
13
+
14
  @app.route('/status', methods=['GET'])
15
  def status():
16
  """
 
27
  """
28
  file = request.files['file']
29
  filename = file.filename
30
+
31
+ # Save to uploads folder first
32
  file.save(os.path.join('uploads', filename))
33
+
34
+ # Try to parse filename format: "<timestamp>_<obstacle_id>_<signal>.jpeg"
35
+ # But be flexible with different formats
36
  constituents = file.filename.split("_")
37
+
38
+ # Default values
39
+ obstacle_id = "unknown"
40
+ signal = "C" # Default to center
41
+
42
+ # Try to extract obstacle_id and signal if available
43
+ try:
44
+ if len(constituents) >= 2:
45
+ obstacle_id = constituents[1]
46
+ if len(constituents) >= 3:
47
+ # Remove file extension from signal
48
+ signal_part = constituents[2]
49
+ # Handle both .jpg and .png extensions
50
+ for ext in ['.jpg', '.jpeg', '.png', '.JPG', '.JPEG', '.PNG']:
51
+ if signal_part.endswith(ext):
52
+ signal = signal_part[:-len(ext)]
53
+ break
54
+ else:
55
+ signal = signal_part
56
+ except IndexError:
57
+ # Use default values if parsing fails
58
+ pass
59
+
60
  ## Week 8 ##
61
+ # Check for optional preference parameter
62
+ prefer_close = request.form.get('prefer_close_objects', 'true').lower() == 'true'
63
+ detection_result = predict_image(filename, model, signal, prefer_close)
64
+
65
  ## Week 9 ##
66
  # We don't need to pass in the signal anymore
67
+ #detection_result = predict_image_week_9(filename,model)
68
+
69
+ # Extract image_id from detection result
70
+ image_id = detection_result["image_id"]
71
+
72
+ # Create results folder
73
+ results_folder = 'results'
74
+ if not os.path.exists(results_folder):
75
+ os.makedirs(results_folder)
76
+
77
+ # Generate UUID
78
+ unique_id = str(uuid.uuid4())
79
+
80
+ # Create new filename format: {UUID}_Label.png
81
+ new_filename = f"{unique_id}_{image_id}.png"
82
+
83
+ # Copy original image to results folder with new name
84
+ original_path = os.path.join('uploads', filename)
85
+ new_path = os.path.join(results_folder, new_filename)
86
+
87
+ try:
88
+ # Copy original file without any processing
89
+ shutil.copy2(original_path, new_path)
90
+ print(f"Original image saved to: {new_path}")
91
+ print(f"Annotated image saved to: {detection_result['marked_image_path']}")
92
+ except Exception as e:
93
+ print(f"Error saving original image: {e}")
94
+
95
+ # Return detailed detection information
96
  result = {
97
  "obstacle_id": obstacle_id,
98
+ "image_id": image_id,
99
+ "detection": {
100
+ "label": detection_result["label"],
101
+ "confidence": detection_result["confidence"],
102
+ "bbox_coordinates": detection_result["bbox"],
103
+ "original_image_path": new_path,
104
+ "annotated_image_path": detection_result["marked_image_path"]
105
+ }
106
  }
107
  return jsonify(result)
108
 
 
118
  return jsonify({"result": "ok"})
119
 
120
  if __name__ == '__main__':
121
+ app.run(host='0.0.0.0', port=5000, debug=True)
model.py CHANGED
@@ -36,7 +36,7 @@ def load_model():
36
 
37
  def draw_own_bbox(img,x1,y1,x2,y2,label,color=(36,255,12),text_color=(0,0,0)):
38
  """
39
- Draw bounding box on the image with text label and save both the raw and annotated image in the 'own_results' folder
40
 
41
  Inputs
42
  ------
@@ -58,7 +58,7 @@ def draw_own_bbox(img,x1,y1,x2,y2,label,color=(36,255,12),text_color=(0,0,0)):
58
 
59
  Returns
60
  -------
61
- None
62
 
63
  """
64
  name_to_id = {
@@ -109,9 +109,14 @@ def draw_own_bbox(img,x1,y1,x2,y2,label,color=(36,255,12),text_color=(0,0,0)):
109
  # Create a random string to be used as the suffix for the image name, just in case the same name is accidentally used
110
  rand = str(int(time.time()))
111
 
 
 
 
 
112
  # Save the raw image
113
  img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
114
- cv2.imwrite(f"own_results/raw_image_{label}_{rand}.jpg", img)
 
115
 
116
  # Draw the bounding box
117
  img = cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)
@@ -121,12 +126,15 @@ def draw_own_bbox(img,x1,y1,x2,y2,label,color=(36,255,12),text_color=(0,0,0)):
121
  img = cv2.rectangle(img, (x1, y1 - 20), (x1 + w, y1), color, -1)
122
  img = cv2.putText(img, label, (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.6, text_color, 1)
123
  # Save the annotated image
124
- cv2.imwrite(f"own_results/annotated_image_{label}_{rand}.jpg", img)
 
 
 
125
 
126
 
127
- def predict_image(image, model, signal):
128
  """
129
- Predict the image using the model and save the results in the 'runs' folder
130
 
131
  Inputs
132
  ------
@@ -135,22 +143,73 @@ def predict_image(image, model, signal):
135
  model: torch.hub.load - model to be used for prediction
136
 
137
  signal: str - signal to be used for filtering the predictions
 
 
 
138
 
139
  Returns
140
  -------
141
- str - predicted label
142
  """
143
- # Load the image
144
- img = Image.open(os.path.join('uploads', image))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
 
146
  # Convert PIL image to cv2 format for better compatibility
147
  img_cv2 = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
148
 
149
- # Ensure image is in the right format and size for the model
150
- # Resize if necessary while maintaining aspect ratio
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
  height, width = img_cv2.shape[:2]
152
  if height != 640 or width != 640:
153
- img_cv2 = cv2.resize(img_cv2, (640, 640))
 
 
 
 
154
 
155
  # Convert back to PIL for model input and ensure it's writable
156
  img_array = cv2.cvtColor(img_cv2, cv2.COLOR_BGR2RGB)
@@ -166,32 +225,41 @@ def predict_image(image, model, signal):
166
  # Convert the results to a pandas dataframe and calculate the height and width of the bounding box and the area of the bounding box
167
  df_results = results.pandas().xyxy[0]
168
 
169
- # If no detections found, try with lower confidence threshold
170
- if len(df_results) == 0:
171
- print(f"No objects detected with default confidence, trying with lower threshold for image: {image}")
172
- # Set lower confidence threshold on the model
173
- original_conf = model.conf
174
- model.conf = 0.1 # Lower confidence threshold
 
 
 
175
  results = model(img)
176
- # results.save('runs') # Skip saving to avoid OpenCV error
177
  df_results = results.pandas().xyxy[0]
178
-
179
- # If still no detections, try with even lower threshold
180
- if len(df_results) == 0:
181
- model.conf = 0.01 # Even lower confidence threshold
182
- results = model(img)
183
- # results.save('runs') # Skip saving to avoid OpenCV error
184
- df_results = results.pandas().xyxy[0]
185
-
186
- # Restore original confidence threshold
187
- model.conf = original_conf
 
 
 
 
 
188
 
189
  df_results['bboxHt'] = df_results['ymax'] - df_results['ymin']
190
  df_results['bboxWt'] = df_results['xmax'] - df_results['xmin']
191
  df_results['bboxArea'] = df_results['bboxHt'] * df_results['bboxWt']
192
 
193
- # Label with largest bbox height will be last
194
- df_results = df_results.sort_values('bboxArea', ascending=False)
 
 
195
 
196
  # Filter out Bullseye
197
  pred_list = df_results
@@ -253,8 +321,43 @@ def predict_image(image, model, signal):
253
  pred_shortlist.sort(key=lambda x: x['bboxArea'])
254
  pred = pred_shortlist[-1]
255
 
256
- # Draw the bounding box on the image
257
- draw_own_bbox(np.array(img), pred['xmin'], pred['ymin'], pred['xmax'], pred['ymax'], pred['name'])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
258
 
259
  name_to_id = {
260
  "NA": 'NA',
@@ -296,8 +399,23 @@ def predict_image(image, model, signal):
296
  }
297
  # Convert prediction to ID
298
  image_id = str(name_to_id[pred['name']])
299
- print(f"Final result: {image_id}")
300
- return image_id
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
301
 
302
  def predict_image_week_9(image, model):
303
  # Load the image
@@ -327,7 +445,9 @@ def predict_image_week_9(image, model):
327
 
328
  # Draw the bounding box on the image
329
  if not isinstance(pred,str):
330
- draw_own_bbox(np.array(img), pred['xmin'], pred['ymin'], pred['xmax'], pred['ymax'], pred['name'])
 
 
331
 
332
  # Dictionary is shorter as only two symbols, left and right are needed
333
  name_to_id = {
@@ -338,12 +458,36 @@ def predict_image_week_9(image, model):
338
  "Right Arrow": 38,
339
  "Left Arrow": 39,
340
  }
341
- # Return the image id
342
  if not isinstance(pred,str):
343
  image_id = str(name_to_id[pred['name']])
 
 
 
 
 
 
 
 
 
 
 
 
344
  else:
345
  image_id = 'NA'
346
- return image_id
 
 
 
 
 
 
 
 
 
 
 
 
347
 
348
 
349
  def stitch_image():
@@ -382,14 +526,18 @@ def stitch_image():
382
 
383
  def stitch_image_own():
384
  """
385
- Stitches the images in the folder together and saves it into own_results folder
386
 
387
- Basically similar to stitch_image() but with different folder names and slightly different drawing of bounding boxes and text
388
  """
389
- imgFolder = 'own_results'
390
  stitchedPath = os.path.join(imgFolder, f'stitched-{int(time.time())}.jpeg')
391
 
392
- imgPaths = glob.glob(os.path.join(imgFolder+"/annotated_image_*.jpg"))
 
 
 
 
393
  imgTimestamps = [imgPath.split("_")[-1][:-4] for imgPath in imgPaths]
394
 
395
  sortedByTimeStampImages = sorted(zip(imgPaths, imgTimestamps), key=lambda x: x[1])
 
36
 
37
  def draw_own_bbox(img,x1,y1,x2,y2,label,color=(36,255,12),text_color=(0,0,0)):
38
  """
39
+ Draw bounding box on the image with text label and save both the raw and annotated image in the 'results' folder
40
 
41
  Inputs
42
  ------
 
58
 
59
  Returns
60
  -------
61
+ str - path to the annotated image file
62
 
63
  """
64
  name_to_id = {
 
109
  # Create a random string to be used as the suffix for the image name, just in case the same name is accidentally used
110
  rand = str(int(time.time()))
111
 
112
+ # Create results folder if it doesn't exist
113
+ if not os.path.exists("results"):
114
+ os.makedirs("results")
115
+
116
  # Save the raw image
117
  img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
118
+ raw_image_path = f"results/raw_image_{label}_{rand}.jpg"
119
+ cv2.imwrite(raw_image_path, img)
120
 
121
  # Draw the bounding box
122
  img = cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)
 
126
  img = cv2.rectangle(img, (x1, y1 - 20), (x1 + w, y1), color, -1)
127
  img = cv2.putText(img, label, (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.6, text_color, 1)
128
  # Save the annotated image
129
+ annotated_image_path = f"results/annotated_image_{label}_{rand}.jpg"
130
+ cv2.imwrite(annotated_image_path, img)
131
+
132
+ return annotated_image_path
133
 
134
 
135
+ def predict_image(image, model, signal, prefer_close_objects=True):
136
  """
137
+ Predict the image using the model and save the results in the 'results' folder
138
 
139
  Inputs
140
  ------
 
143
  model: torch.hub.load - model to be used for prediction
144
 
145
  signal: str - signal to be used for filtering the predictions
146
+
147
+ prefer_close_objects: bool - if True, prioritize larger objects (closer),
148
+ if False, prioritize smaller objects (farther)
149
 
150
  Returns
151
  -------
152
+ dict - detection result with image_id, label, confidence, bbox, and marked_image_path
153
  """
154
+ # Load the image (supports both PNG and JPG)
155
+ img_path = os.path.join('uploads', image)
156
+ try:
157
+ img = Image.open(img_path)
158
+ # Convert to RGB if it's RGBA (PNG with transparency) or other modes
159
+ if img.mode != 'RGB':
160
+ img = img.convert('RGB')
161
+ except Exception as e:
162
+ print(f"Error loading image {image}: {e}")
163
+ # Return default result if image loading fails
164
+ return {
165
+ "image_id": "NA",
166
+ "label": "NA",
167
+ "confidence": 0.0,
168
+ "bbox": {"x1": 0.0, "y1": 0.0, "x2": 0.0, "y2": 0.0},
169
+ "marked_image_path": None
170
+ }
171
+
172
+ # Store original image dimensions for later coordinate scaling
173
+ original_width, original_height = img.size
174
 
175
  # Convert PIL image to cv2 format for better compatibility
176
  img_cv2 = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
177
 
178
+ # Resize to model input size while maintaining aspect ratio
179
+ def resize_with_aspect_ratio(image, target_size=640):
180
+ """Resize image to target size while maintaining aspect ratio using padding"""
181
+ height, width = image.shape[:2]
182
+
183
+ # Calculate scaling factor
184
+ scale = min(target_size / width, target_size / height)
185
+
186
+ # Calculate new dimensions
187
+ new_width = int(width * scale)
188
+ new_height = int(height * scale)
189
+
190
+ # Resize image
191
+ resized = cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA)
192
+
193
+ # Create a square canvas with padding
194
+ canvas = np.zeros((target_size, target_size, 3), dtype=np.uint8)
195
+
196
+ # Calculate padding offsets to center the image
197
+ y_offset = (target_size - new_height) // 2
198
+ x_offset = (target_size - new_width) // 2
199
+
200
+ # Place the resized image on the canvas
201
+ canvas[y_offset:y_offset + new_height, x_offset:x_offset + new_width] = resized
202
+
203
+ return canvas, scale, x_offset, y_offset
204
+
205
+ # Apply proper aspect ratio preserving resize
206
  height, width = img_cv2.shape[:2]
207
  if height != 640 or width != 640:
208
+ img_cv2, scale_factor, x_offset, y_offset = resize_with_aspect_ratio(img_cv2, 640)
209
+ else:
210
+ scale_factor = 1.0
211
+ x_offset = 0
212
+ y_offset = 0
213
 
214
  # Convert back to PIL for model input and ensure it's writable
215
  img_array = cv2.cvtColor(img_cv2, cv2.COLOR_BGR2RGB)
 
225
  # Convert the results to a pandas dataframe and calculate the height and width of the bounding box and the area of the bounding box
226
  df_results = results.pandas().xyxy[0]
227
 
228
+ # Try progressively lower confidence thresholds to ensure we get some detection
229
+ original_conf = model.conf
230
+ confidence_thresholds = [original_conf, 0.5, 0.3, 0.1, 0.05, 0.01]
231
+
232
+ for conf_threshold in confidence_thresholds:
233
+ if len(df_results) > 0:
234
+ break
235
+ print(f"No objects detected with confidence {conf_threshold}, trying lower threshold for image: {image}")
236
+ model.conf = conf_threshold
237
  results = model(img)
 
238
  df_results = results.pandas().xyxy[0]
239
+
240
+ # If still no detections with extremely low threshold, create a default detection
241
+ if len(df_results) == 0:
242
+ print(f"No detections found even with lowest threshold. Creating default detection.")
243
+ # Create a default bounding box in the center of the image
244
+ default_detection = {
245
+ 'xmin': 160, 'ymin': 160, 'xmax': 480, 'ymax': 480,
246
+ 'confidence': 0.01, 'name': 'One' # Default to 'One' as fallback
247
+ }
248
+ # Convert to DataFrame format
249
+ import pandas as pd
250
+ df_results = pd.DataFrame([default_detection])
251
+
252
+ # Restore original confidence threshold
253
+ model.conf = original_conf
254
 
255
  df_results['bboxHt'] = df_results['ymax'] - df_results['ymin']
256
  df_results['bboxWt'] = df_results['xmax'] - df_results['xmin']
257
  df_results['bboxArea'] = df_results['bboxHt'] * df_results['bboxWt']
258
 
259
+ # Sort by area based on preference for close or far objects
260
+ # prefer_close_objects=True: larger area first (closer objects)
261
+ # prefer_close_objects=False: smaller area first (farther objects)
262
+ df_results = df_results.sort_values('bboxArea', ascending=not prefer_close_objects)
263
 
264
  # Filter out Bullseye
265
  pred_list = df_results
 
321
  pred_shortlist.sort(key=lambda x: x['bboxArea'])
322
  pred = pred_shortlist[-1]
323
 
324
+ # Convert bounding box coordinates back to original image scale
325
+ def convert_bbox_to_original(bbox, scale_factor, x_offset, y_offset, original_width, original_height):
326
+ """Convert bounding box coordinates from model input size back to original image size"""
327
+ # Remove padding offsets
328
+ x1 = bbox['xmin'] - x_offset
329
+ y1 = bbox['ymin'] - y_offset
330
+ x2 = bbox['xmax'] - x_offset
331
+ y2 = bbox['ymax'] - y_offset
332
+
333
+ # Scale back to original size
334
+ x1 = x1 / scale_factor
335
+ y1 = y1 / scale_factor
336
+ x2 = x2 / scale_factor
337
+ y2 = y2 / scale_factor
338
+
339
+ # Clamp to original image bounds
340
+ x1 = max(0, min(x1, original_width))
341
+ y1 = max(0, min(y1, original_height))
342
+ x2 = max(0, min(x2, original_width))
343
+ y2 = max(0, min(y2, original_height))
344
+
345
+ return {
346
+ 'xmin': x1, 'ymin': y1, 'xmax': x2, 'ymax': y2,
347
+ 'confidence': bbox['confidence'], 'name': bbox['name']
348
+ }
349
+
350
+ # Convert coordinates to original image scale
351
+ original_pred = convert_bbox_to_original(pred, scale_factor, x_offset, y_offset, original_width, original_height)
352
+
353
+ # Load original image for annotation (not the resized version)
354
+ original_img = Image.open(os.path.join('uploads', image))
355
+ if original_img.mode != 'RGB':
356
+ original_img = original_img.convert('RGB')
357
+
358
+ # Draw the bounding box on the original image and get the marked image path
359
+ marked_image_path = draw_own_bbox(np.array(original_img), original_pred['xmin'], original_pred['ymin'],
360
+ original_pred['xmax'], original_pred['ymax'], original_pred['name'])
361
 
362
  name_to_id = {
363
  "NA": 'NA',
 
399
  }
400
  # Convert prediction to ID
401
  image_id = str(name_to_id[pred['name']])
402
+
403
+ # Prepare detailed detection result using original image coordinates
404
+ detection_result = {
405
+ "image_id": image_id,
406
+ "label": original_pred['name'],
407
+ "confidence": float(original_pred['confidence']),
408
+ "bbox": {
409
+ "x1": float(original_pred['xmin']),
410
+ "y1": float(original_pred['ymin']),
411
+ "x2": float(original_pred['xmax']),
412
+ "y2": float(original_pred['ymax'])
413
+ },
414
+ "marked_image_path": marked_image_path
415
+ }
416
+
417
+ print(f"Final result: {image_id} with bbox coordinates")
418
+ return detection_result
419
 
420
  def predict_image_week_9(image, model):
421
  # Load the image
 
445
 
446
  # Draw the bounding box on the image
447
  if not isinstance(pred,str):
448
+ marked_image_path = draw_own_bbox(np.array(img), pred['xmin'], pred['ymin'], pred['xmax'], pred['ymax'], pred['name'])
449
+ else:
450
+ marked_image_path = None
451
 
452
  # Dictionary is shorter as only two symbols, left and right are needed
453
  name_to_id = {
 
458
  "Right Arrow": 38,
459
  "Left Arrow": 39,
460
  }
461
+ # Return the image id and detailed information
462
  if not isinstance(pred,str):
463
  image_id = str(name_to_id[pred['name']])
464
+ detection_result = {
465
+ "image_id": image_id,
466
+ "label": pred['name'],
467
+ "confidence": float(pred['confidence']),
468
+ "bbox": {
469
+ "x1": float(pred['xmin']),
470
+ "y1": float(pred['ymin']),
471
+ "x2": float(pred['xmax']),
472
+ "y2": float(pred['ymax'])
473
+ },
474
+ "marked_image_path": marked_image_path
475
+ }
476
  else:
477
  image_id = 'NA'
478
+ detection_result = {
479
+ "image_id": image_id,
480
+ "label": "NA",
481
+ "confidence": 0.0,
482
+ "bbox": {
483
+ "x1": 0.0,
484
+ "y1": 0.0,
485
+ "x2": 0.0,
486
+ "y2": 0.0
487
+ },
488
+ "marked_image_path": None
489
+ }
490
+ return detection_result
491
 
492
 
493
  def stitch_image():
 
526
 
527
  def stitch_image_own():
528
  """
529
+ Stitches the images in the folder together and saves it into results folder
530
 
531
+ Similar to stitch_image() but works with annotated images from results folder
532
  """
533
+ imgFolder = 'results'
534
  stitchedPath = os.path.join(imgFolder, f'stitched-{int(time.time())}.jpeg')
535
 
536
+ imgPaths = glob.glob(os.path.join(imgFolder, "annotated_image_*.jpg"))
537
+ if not imgPaths:
538
+ print("No annotated images found for stitching")
539
+ return None
540
+
541
  imgTimestamps = [imgPath.split("_")[-1][:-4] for imgPath in imgPaths]
542
 
543
  sortedByTimeStampImages = sorted(zip(imgPaths, imgTimestamps), key=lambda x: x[1])
test.ipynb CHANGED
@@ -50,38 +50,106 @@
50
  {
51
  "cell_type": "code",
52
  "execution_count": null,
53
- "id": "970d2d85",
54
  "metadata": {},
55
  "outputs": [
56
  {
57
  "name": "stdout",
58
  "output_type": "stream",
59
  "text": [
 
60
  "200\n",
61
- "{'image_id': '39', 'obstacle_id': '122540'}\n"
62
  ]
63
  }
64
  ],
65
  "source": [
 
66
  "import requests\n",
67
  "\n",
68
  "SERVER_URL = \"http://localhost:5000\"\n",
 
69
  "\n",
70
- "image_file = \"20230825_122540_jpg.rf.f0620856e7afdbd116ceffdfd512b03a.jpg\"\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  "\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  "\n",
73
  "with open(image_file, 'rb') as f:\n",
74
  " files = {'file': f}\n",
 
75
  " response = requests.post(f\"{SERVER_URL}/image\", files=files)\n",
76
  "\n",
 
77
  "print(response.status_code)\n",
78
- "print(response.json())\n"
79
  ]
80
  }
81
  ],
82
  "metadata": {
83
  "kernelspec": {
84
- "display_name": "bert-vits2",
85
  "language": "python",
86
  "name": "python3"
87
  },
@@ -95,7 +163,7 @@
95
  "name": "python",
96
  "nbconvert_exporter": "python",
97
  "pygments_lexer": "ipython3",
98
- "version": "3.11.7"
99
  }
100
  },
101
  "nbformat": 4,
 
50
  {
51
  "cell_type": "code",
52
  "execution_count": null,
53
+ "id": "a89ceef6",
54
  "metadata": {},
55
  "outputs": [
56
  {
57
  "name": "stdout",
58
  "output_type": "stream",
59
  "text": [
60
+ "优先近距离物体:\n",
61
  "200\n",
62
+ "{'detection': {'annotated_image_path': 'results/annotated_image_Up-36_1757948296.jpg', 'bbox_coordinates': {'x1': 545.3320312499999, 'x2': 560.7070312499999, 'y1': 15.254882812499998, 'y2': 33.75292968749999}, 'confidence': 0.01422882080078125, 'label': 'Up', 'original_image_path': 'results\\\\3483d55f-887a-4364-8d0b-6910faa6a585_36.png'}, 'image_id': '36', 'obstacle_id': 'unknown'}\n"
63
  ]
64
  }
65
  ],
66
  "source": [
67
+ "# 选项1: 优先检测较近的物体(默认行为,面积较大的物体)\n",
68
  "import requests\n",
69
  "\n",
70
  "SERVER_URL = \"http://localhost:5000\"\n",
71
+ "image_file = \"Screenshot 2025-09-15 225930.png\"\n",
72
  "\n",
73
+ "with open(image_file, 'rb') as f:\n",
74
+ " files = {'file': f}\n",
75
+ " data = {'prefer_close_objects': 'true'} # 优先近距离物体\n",
76
+ " response = requests.post(f\"{SERVER_URL}/image\", files=files, data=data)\n",
77
+ "\n",
78
+ "print(\"优先近距离物体:\")\n",
79
+ "print(response.status_code)\n",
80
+ "print(response.json())"
81
+ ]
82
+ },
83
+ {
84
+ "cell_type": "code",
85
+ "execution_count": 16,
86
+ "id": "21f15172",
87
+ "metadata": {},
88
+ "outputs": [
89
+ {
90
+ "name": "stdout",
91
+ "output_type": "stream",
92
+ "text": [
93
+ "优先远距离物体:\n",
94
+ "200\n",
95
+ "{'detection': {'annotated_image_path': 'results/annotated_image_Up-36_1757948327.jpg', 'bbox_coordinates': {'x1': 545.3320312499999, 'x2': 560.7070312499999, 'y1': 15.254882812499998, 'y2': 33.75292968749999}, 'confidence': 0.01422882080078125, 'label': 'Up', 'original_image_path': 'results\\\\e7dcd5cf-db24-4821-9fe6-5e16412ba51c_36.png'}, 'image_id': '36', 'obstacle_id': 'unknown'}\n"
96
+ ]
97
+ }
98
+ ],
99
+ "source": [
100
+ "# 选项2: 优先检测较远的物体(面积较小的物体)\n",
101
+ "import requests\n",
102
+ "\n",
103
+ "SERVER_URL = \"http://localhost:5000\"\n",
104
+ "image_file = \"b.png\"\n",
105
  "\n",
106
+ "with open(image_file, 'rb') as f:\n",
107
+ " files = {'file': f}\n",
108
+ " data = {'prefer_close_objects': 'false'} # 优先远距离物体\n",
109
+ " response = requests.post(f\"{SERVER_URL}/image\", files=files, data=data)\n",
110
+ "\n",
111
+ "print(\"优先远距离物体:\")\n",
112
+ "print(response.status_code)\n",
113
+ "print(response.json())"
114
+ ]
115
+ },
116
+ {
117
+ "cell_type": "code",
118
+ "execution_count": 15,
119
+ "id": "6b29a73a",
120
+ "metadata": {},
121
+ "outputs": [
122
+ {
123
+ "name": "stdout",
124
+ "output_type": "stream",
125
+ "text": [
126
+ "默认行为(优先近距离物体):\n",
127
+ "200\n",
128
+ "{'detection': {'annotated_image_path': 'results/annotated_image_Up-36_1757948317.jpg', 'bbox_coordinates': {'x1': 545.3320312499999, 'x2': 560.7070312499999, 'y1': 15.254882812499998, 'y2': 33.75292968749999}, 'confidence': 0.01422882080078125, 'label': 'Up', 'original_image_path': 'results\\\\d0387124-696d-4233-90db-fe511ed62828_36.png'}, 'image_id': '36', 'obstacle_id': 'unknown'}\n"
129
+ ]
130
+ }
131
+ ],
132
+ "source": [
133
+ "# 选项3: 不指定参数(使用默认行为,等同于 prefer_close_objects=true)\n",
134
+ "import requests\n",
135
+ "\n",
136
+ "SERVER_URL = \"http://localhost:5000\"\n",
137
+ "image_file = \"b.png\"\n",
138
  "\n",
139
  "with open(image_file, 'rb') as f:\n",
140
  " files = {'file': f}\n",
141
+ " # 不添加 data 参数,使用默认行为(优先近距离物体)\n",
142
  " response = requests.post(f\"{SERVER_URL}/image\", files=files)\n",
143
  "\n",
144
+ "print(\"默认行为(优先近距离物体):\")\n",
145
  "print(response.status_code)\n",
146
+ "print(response.json())"
147
  ]
148
  }
149
  ],
150
  "metadata": {
151
  "kernelspec": {
152
+ "display_name": "chatbot",
153
  "language": "python",
154
  "name": "python3"
155
  },
 
163
  "name": "python",
164
  "nbconvert_exporter": "python",
165
  "pygments_lexer": "ipython3",
166
+ "version": "3.8.16"
167
  }
168
  },
169
  "nbformat": 4,