viswanani commited on
Commit
98786f5
·
verified ·
1 Parent(s): ade3db7

Update drs_engine.py

Browse files
Files changed (1) hide show
  1. drs_engine.py +18 -78
drs_engine.py CHANGED
@@ -1,74 +1,3 @@
1
- import cv2
2
- import os
3
- import math
4
- import numpy as np
5
- import tempfile
6
- import matplotlib.pyplot as plt
7
- from ultralytics import YOLO
8
- from pydub import AudioSegment
9
- import ffmpeg
10
- from scipy.optimize import curve_fit
11
- from scipy.interpolate import UnivariateSpline
12
-
13
- model = YOLO("best.pt")
14
- FPS = 30
15
-
16
- def estimate_speed(p1, p2, fps):
17
- dist = math.hypot(p2[0] - p1[0], p2[1] - p1[1])
18
- meters_per_pixel = 0.03
19
- mps = dist * meters_per_pixel * fps
20
- return mps * 3.6
21
-
22
- def add_voice_to_video(video_path, verdict):
23
- audio_file = "out.mp3" if verdict == "OUT" else "not_out.mp3"
24
- audio_temp = tempfile.NamedTemporaryFile(suffix=".mp3", delete=False).name
25
- AudioSegment.from_file(audio_file).export(audio_temp, format="mp3")
26
- final_output = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
27
- (
28
- ffmpeg
29
- .input(video_path)
30
- .output(final_output, audio=audio_temp, vcodec='copy', acodec='aac', strict='experimental')
31
- .run(overwrite_output=True)
32
- )
33
- return final_output
34
-
35
- def extend_trajectory_with_rotation(points, bounce_idx, final_x=20.12):
36
- x_vals = [pt[0] for pt in points]
37
- y_vals = [pt[1] for pt in points]
38
- if bounce_idx is None or bounce_idx >= len(points) - 2:
39
- return x_vals, y_vals
40
- x_pre = x_vals[:bounce_idx]
41
- y_pre = y_vals[:bounce_idx]
42
- def poly2(x, a, b, c): return a*x**2 + b*x + c
43
- try:
44
- popt, _ = curve_fit(poly2, x_pre, y_pre)
45
- x_post = np.linspace(x_vals[bounce_idx], final_x, 50)
46
- curve_shift = np.linspace(0, 0.05, len(x_post))
47
- y_post = poly2(x_post, *popt) + curve_shift
48
- return x_vals + list(x_post), y_vals + list(y_post)
49
- except:
50
- return x_vals, y_vals
51
-
52
- def draw_top_down_trajectory(points, bounce_frame_idx, output_path):
53
- if len(points) < 4:
54
- return None
55
- x_vals = [pt[0] for pt in points]
56
- y_vals = [pt[1] for pt in points]
57
- x_ext, y_ext = extend_trajectory_with_rotation(points, bounce_frame_idx)
58
- plt.figure(figsize=(10, 3))
59
- plt.plot(x_ext, y_ext, 'r-', label='Predicted Trajectory')
60
- plt.scatter(x_vals, y_vals, c='blue', s=10, label='Detected Points')
61
- plt.axvline(x=17.68, color='gray', linestyle='--', label='Stumps (17.68m)')
62
- plt.title("Top-Down Predicted Ball Path")
63
- plt.xlabel("Pitch Length (m)")
64
- plt.ylabel("Lateral Movement (m)")
65
- plt.grid(True)
66
- plt.legend()
67
- image_path = output_path.replace(".mp4", "_trajectory.png")
68
- plt.savefig(image_path)
69
- plt.close()
70
- return image_path
71
-
72
  def process_video(video_path):
73
  cap = cv2.VideoCapture(video_path)
74
  if not cap.isOpened():
@@ -81,7 +10,11 @@ def process_video(video_path):
81
  temp_out = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
82
  out = cv2.VideoWriter(temp_out, fourcc, fps, (width, height))
83
 
84
- # Auto homography
 
 
 
 
85
  src = np.array([
86
  [width * 0.3, height * 0.4],
87
  [width * 0.7, height * 0.4],
@@ -102,14 +35,14 @@ def process_video(video_path):
102
  trajectory, real_trajectory = [], []
103
  bounce_detected = False
104
  bounce_frame_idx = None
105
- frame_index = 0
106
  prev_center = None
107
  verdict = "NOT OUT"
108
 
109
- while True:
110
  ret, frame = cap.read()
111
  if not ret:
112
  break
 
113
  results = model(frame)
114
  for box in results[0].boxes:
115
  if int(box.cls[0]) == 0:
@@ -121,7 +54,7 @@ def process_video(video_path):
121
  cv2.circle(frame, (cx, cy), 8, (0, 0, 255), -1)
122
  if prev_center and not bounce_detected and cy - prev_center[1] > 15:
123
  bounce_detected = True
124
- bounce_frame_idx = frame_index
125
  cv2.putText(frame, "Bounce!", (cx, cy - 20), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0,255,255), 2)
126
  prev_center = (cx, cy)
127
  break
@@ -146,11 +79,18 @@ def process_video(video_path):
146
  (0,255,0) if verdict == "NOT OUT" else (0,0,255), 3)
147
 
148
  out.write(frame)
149
- frame_index += 1
150
 
151
  cap.release()
152
  out.release()
 
 
153
  topdown_image = draw_top_down_trajectory(real_trajectory, bounce_frame_idx, temp_out)
154
- final_video = add_voice_to_video(temp_out, verdict)
155
- return final_video, topdown_image
156
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  def process_video(video_path):
2
  cap = cv2.VideoCapture(video_path)
3
  if not cap.isOpened():
 
10
  temp_out = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
11
  out = cv2.VideoWriter(temp_out, fourcc, fps, (width, height))
12
 
13
+ # ⏱️ Trim to 4 seconds max
14
+ max_frames = int(fps * 4)
15
+ frame_count = 0
16
+
17
+ # Homography
18
  src = np.array([
19
  [width * 0.3, height * 0.4],
20
  [width * 0.7, height * 0.4],
 
35
  trajectory, real_trajectory = [], []
36
  bounce_detected = False
37
  bounce_frame_idx = None
 
38
  prev_center = None
39
  verdict = "NOT OUT"
40
 
41
+ while frame_count < max_frames:
42
  ret, frame = cap.read()
43
  if not ret:
44
  break
45
+
46
  results = model(frame)
47
  for box in results[0].boxes:
48
  if int(box.cls[0]) == 0:
 
54
  cv2.circle(frame, (cx, cy), 8, (0, 0, 255), -1)
55
  if prev_center and not bounce_detected and cy - prev_center[1] > 15:
56
  bounce_detected = True
57
+ bounce_frame_idx = frame_count
58
  cv2.putText(frame, "Bounce!", (cx, cy - 20), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0,255,255), 2)
59
  prev_center = (cx, cy)
60
  break
 
79
  (0,255,0) if verdict == "NOT OUT" else (0,0,255), 3)
80
 
81
  out.write(frame)
82
+ frame_count += 1
83
 
84
  cap.release()
85
  out.release()
86
+
87
+ # Top-down chart
88
  topdown_image = draw_top_down_trajectory(real_trajectory, bounce_frame_idx, temp_out)
 
 
89
 
90
+ # Optional: voice-over
91
+ try:
92
+ final_video = add_voice_to_video(temp_out, verdict)
93
+ except:
94
+ final_video = temp_out
95
+
96
+ return final_video, topdown_image