viswanani committed on
Commit
34d5cfd
·
verified ·
1 Parent(s): 3d3f811

Update drs_engine.py

Browse files
Files changed (1) hide show
  1. drs_engine.py +77 -17
drs_engine.py CHANGED
@@ -3,11 +3,72 @@ import os
3
  import math
4
  import numpy as np
5
  import tempfile
 
6
  from ultralytics import YOLO
7
  from pydub import AudioSegment
8
  import ffmpeg
 
9
  from scipy.interpolate import UnivariateSpline
10
- import matplotlib.pyplot as plt
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  def process_video(video_path):
12
  cap = cv2.VideoCapture(video_path)
13
  if not cap.isOpened():
@@ -20,39 +81,35 @@ def process_video(video_path):
20
  temp_out = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
21
  out = cv2.VideoWriter(temp_out, fourcc, fps, (width, height))
22
 
23
- # AUTO CALIBRATION — assume standard pitch location
24
  src = np.array([
25
- [width * 0.3, height * 0.4], # TL
26
- [width * 0.7, height * 0.4], # TR
27
- [width * 0.7, height * 0.9], # BR
28
- [width * 0.3, height * 0.9], # BL
29
  ], dtype=np.float32)
30
-
31
  dst = np.array([
32
  [0, 0],
33
  [20.12, 0],
34
  [20.12, 3.05],
35
  [0, 3.05]
36
  ], dtype=np.float32)
37
-
38
  H, _ = cv2.findHomography(src, dst)
39
-
40
  def project_point(px, py):
41
  pt = np.array([[[px, py]]], dtype=np.float32)
42
- dst_pt = cv2.perspectiveTransform(pt, H)
43
- return dst_pt[0][0]
44
 
45
- trajectory = []
46
- real_trajectory = []
47
  bounce_detected = False
 
 
48
  prev_center = None
49
  verdict = "NOT OUT"
50
 
51
- for _ in range(int(cap.get(cv2.CAP_PROP_FRAME_COUNT))):
52
  ret, frame = cap.read()
53
  if not ret:
54
  break
55
-
56
  results = model(frame)
57
  for box in results[0].boxes:
58
  if int(box.cls[0]) == 0:
@@ -64,6 +121,7 @@ def process_video(video_path):
64
  cv2.circle(frame, (cx, cy), 8, (0, 0, 255), -1)
65
  if prev_center and not bounce_detected and cy - prev_center[1] > 15:
66
  bounce_detected = True
 
67
  cv2.putText(frame, "Bounce!", (cx, cy - 20), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0,255,255), 2)
68
  prev_center = (cx, cy)
69
  break
@@ -88,8 +146,10 @@ def process_video(video_path):
88
  (0,255,0) if verdict == "NOT OUT" else (0,0,255), 3)
89
 
90
  out.write(frame)
 
91
 
92
  cap.release()
93
  out.release()
94
- return add_voice_to_video(temp_out, verdict)
95
-
 
 
3
  import math
4
  import numpy as np
5
  import tempfile
6
+ import matplotlib.pyplot as plt
7
  from ultralytics import YOLO
8
  from pydub import AudioSegment
9
  import ffmpeg
10
+ from scipy.optimize import curve_fit
11
  from scipy.interpolate import UnivariateSpline
12
+
13
# Ball-detection model; the "best.pt" weights file must be present in the
# working directory at import time (loading happens as a module side effect).
model = YOLO("best.pt")
# NOTE(review): FPS is not referenced in the visible code — presumably a
# default frame rate for speed estimation; confirm against callers.
FPS = 30
15
+
16
def estimate_speed(p1, p2, fps, meters_per_pixel=0.03):
    """Estimate ball speed in km/h from two consecutive pixel positions.

    Args:
        p1, p2: (x, y) pixel coordinates of the ball in consecutive frames.
        fps: video frame rate (frames per second).
        meters_per_pixel: pixel-to-metre scale factor. The 0.03 default is a
            rough fixed calibration — TODO confirm against real footage /
            the homography used elsewhere in this file.

    Returns:
        Estimated speed in km/h.
    """
    # Pixel displacement between the two frames.
    dist = math.hypot(p2[0] - p1[0], p2[1] - p1[1])
    # pixels/frame -> metres/second.
    mps = dist * meters_per_pixel * fps
    # m/s -> km/h.
    return mps * 3.6
21
+
22
def add_voice_to_video(video_path, verdict):
    """Mux a verdict voice-over onto the annotated video.

    Args:
        video_path: path to the silent annotated .mp4.
        verdict: "OUT" or anything else; selects "out.mp3" vs "not_out.mp3"
            (both expected in the working directory).

    Returns:
        Path to a new temporary .mp4 containing the video plus voice track.
    """
    audio_file = "out.mp3" if verdict == "OUT" else "not_out.mp3"
    # Re-export to a temp file so the shared source clip is never touched.
    audio_temp = tempfile.NamedTemporaryFile(suffix=".mp3", delete=False).name
    AudioSegment.from_file(audio_file).export(audio_temp, format="mp3")
    final_output = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
    # BUG FIX: the audio must be a second ffmpeg *input* stream. The previous
    # `output(final_output, audio=audio_temp, ...)` call emitted a bogus
    # "-audio <path>" flag, so no audio track was ever muxed in.
    video_in = ffmpeg.input(video_path)
    audio_in = ffmpeg.input(audio_temp)
    (
        ffmpeg
        .output(video_in.video, audio_in.audio, final_output,
                vcodec='copy', acodec='aac', strict='experimental')
        .run(overwrite_output=True)
    )
    return final_output
34
+
35
+ def extend_trajectory_with_rotation(points, bounce_idx, final_x=20.12):
36
+ x_vals = [pt[0] for pt in points]
37
+ y_vals = [pt[1] for pt in points]
38
+ if bounce_idx is None or bounce_idx >= len(points) - 2:
39
+ return x_vals, y_vals
40
+ x_pre = x_vals[:bounce_idx]
41
+ y_pre = y_vals[:bounce_idx]
42
+ def poly2(x, a, b, c): return a*x**2 + b*x + c
43
+ try:
44
+ popt, _ = curve_fit(poly2, x_pre, y_pre)
45
+ x_post = np.linspace(x_vals[bounce_idx], final_x, 50)
46
+ curve_shift = np.linspace(0, 0.05, len(x_post))
47
+ y_post = poly2(x_post, *popt) + curve_shift
48
+ return x_vals + list(x_post), y_vals + list(y_post)
49
+ except:
50
+ return x_vals, y_vals
51
+
52
def draw_top_down_trajectory(points, bounce_frame_idx, output_path):
    """Render a top-down plot of the observed and predicted ball path.

    Saves a PNG alongside ``output_path`` (same name with a
    ``_trajectory.png`` suffix) and returns its path, or None when fewer
    than 4 points were detected.
    """
    if len(points) < 4:
        # Too few detections to draw anything meaningful.
        return None
    detected_x = [p[0] for p in points]
    detected_y = [p[1] for p in points]
    predicted_x, predicted_y = extend_trajectory_with_rotation(points, bounce_frame_idx)

    plt.figure(figsize=(10, 3))
    plt.plot(predicted_x, predicted_y, 'r-', label='Predicted Trajectory')
    plt.scatter(detected_x, detected_y, c='blue', s=10, label='Detected Points')
    # Vertical reference line at the stumps, 17.68 m down the pitch.
    plt.axvline(x=17.68, color='gray', linestyle='--', label='Stumps (17.68m)')
    plt.title("Top-Down Predicted Ball Path")
    plt.xlabel("Pitch Length (m)")
    plt.ylabel("Lateral Movement (m)")
    plt.grid(True)
    plt.legend()

    image_path = output_path.replace(".mp4", "_trajectory.png")
    plt.savefig(image_path)
    plt.close()
    return image_path
71
+
72
  def process_video(video_path):
73
  cap = cv2.VideoCapture(video_path)
74
  if not cap.isOpened():
 
81
  temp_out = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
82
  out = cv2.VideoWriter(temp_out, fourcc, fps, (width, height))
83
 
84
+ # Auto homography
85
  src = np.array([
86
+ [width * 0.3, height * 0.4],
87
+ [width * 0.7, height * 0.4],
88
+ [width * 0.7, height * 0.9],
89
+ [width * 0.3, height * 0.9],
90
  ], dtype=np.float32)
 
91
  dst = np.array([
92
  [0, 0],
93
  [20.12, 0],
94
  [20.12, 3.05],
95
  [0, 3.05]
96
  ], dtype=np.float32)
 
97
  H, _ = cv2.findHomography(src, dst)
 
98
  def project_point(px, py):
99
  pt = np.array([[[px, py]]], dtype=np.float32)
100
+ return cv2.perspectiveTransform(pt, H)[0][0]
 
101
 
102
+ trajectory, real_trajectory = [], []
 
103
  bounce_detected = False
104
+ bounce_frame_idx = None
105
+ frame_index = 0
106
  prev_center = None
107
  verdict = "NOT OUT"
108
 
109
+ while True:
110
  ret, frame = cap.read()
111
  if not ret:
112
  break
 
113
  results = model(frame)
114
  for box in results[0].boxes:
115
  if int(box.cls[0]) == 0:
 
121
  cv2.circle(frame, (cx, cy), 8, (0, 0, 255), -1)
122
  if prev_center and not bounce_detected and cy - prev_center[1] > 15:
123
  bounce_detected = True
124
+ bounce_frame_idx = frame_index
125
  cv2.putText(frame, "Bounce!", (cx, cy - 20), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0,255,255), 2)
126
  prev_center = (cx, cy)
127
  break
 
146
  (0,255,0) if verdict == "NOT OUT" else (0,0,255), 3)
147
 
148
  out.write(frame)
149
+ frame_index += 1
150
 
151
  cap.release()
152
  out.release()
153
+ topdown_image = draw_top_down_trajectory(real_trajectory, bounce_frame_idx, temp_out)
154
+ final_video = add_voice_to_video(temp_out, verdict)
155
+ return final_video, topdown_image