lokesh341 commited on
Commit
80dceed
·
verified ·
1 Parent(s): ea0d6a2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +162 -97
app.py CHANGED
@@ -5,49 +5,65 @@ from ultralytics import YOLO
5
  import torch
6
  from scipy.interpolate import interp1d
7
  import tempfile
8
- from typing import Dict, Tuple
 
9
 
10
- # Initialize models
11
- BALL_MODEL = YOLO('yolov8n.pt') # Auto-downloads if not present
12
- STUMP_MODEL = YOLO('yolov8m.pt')
 
 
 
 
 
 
 
 
 
13
 
14
- # Pitch constants (in pixels)
15
- PITCH_LENGTH = 2000 # From crease to crease
16
- STUMPS_WIDTH = 71 # Standard cricket stump width (9 inches)
 
 
 
 
17
 
18
  def predict_trajectory(positions: list) -> Tuple[list, float]:
19
- """Predict ball trajectory using cubic interpolation"""
20
  if len(positions) < 3:
21
  return positions, 0.0
22
 
23
  x = [p[0] for p in positions]
24
  y = [p[1] for p in positions]
25
-
26
- # Cubic spline interpolation
27
  t = np.linspace(0, 1, len(positions))
28
- fx = interp1d(t, x, kind='cubic', fill_value="extrapolate")
29
- fy = interp1d(t, y, kind='cubic', fill_value="extrapolate")
30
-
31
- # Predict next 10 frames
32
- new_t = np.linspace(0, 1.5, len(positions)+10)
33
- new_x = fx(new_t)
34
- new_y = fy(new_t)
35
 
36
- # Calculate speed (pixels/frame to km/h)
37
- dx = new_x[-1] - new_x[-2]
38
- dy = new_y[-1] - new_y[-2]
39
- speed = np.sqrt(dx**2 + dy**2) * 25 * 3.6 / PITCH_LENGTH # Convert to km/h
40
-
41
- return list(zip(new_x, new_y)), speed
 
 
 
 
 
 
 
 
42
 
43
- def check_lbw(ball_pos: tuple, stump_pos: tuple, impact: tuple) -> Dict:
44
- """LBW decision system"""
45
- # Simplified decision logic
46
- hitting = "Hitting" if abs(ball_pos[0] - stump_pos[0]) < STUMPS_WIDTH else "Missing"
47
- in_line = "In Line" if impact[0] < stump_pos[0] + STUMPS_WIDTH else "Not in Line"
48
- pitching = "In Line" if impact[1] < stump_pos[1] + 100 else "Outside"
49
-
50
- decision = "Out" if all([hitting == "Hitting", in_line == "In Line", pitching == "In Line"]) else "Not Out"
 
 
 
51
 
52
  return {
53
  "decision": decision,
@@ -57,79 +73,127 @@ def check_lbw(ball_pos: tuple, stump_pos: tuple, impact: tuple) -> Dict:
57
  "pitching": pitching
58
  }
59
 
60
- def process_video(video_path: str) -> Dict:
61
- """Main processing function"""
62
- cap = cv2.VideoCapture(video_path)
63
- fps = cap.get(cv2.CAP_PROP_FPS)
64
- width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
65
- height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
66
-
67
- # Video writer setup
68
- temp_file = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
69
- out = cv2.VideoWriter(temp_file.name, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
70
-
71
- ball_positions = []
72
- lbw_data = None
73
- max_speed = 0.0
74
-
75
- while cap.isOpened():
76
- ret, frame = cap.read()
77
- if not ret:
78
- break
79
 
80
- # Ball detection
81
- ball_results = BALL_MODEL(frame, classes=32, verbose=False)
82
- boxes = ball_results[0].boxes.xyxy.cpu().numpy()
 
83
 
84
- if len(boxes) > 0:
85
- x1, y1, x2, y2 = boxes[0]
86
- x, y = (x1 + x2) // 2, (y1 + y2) // 2
87
- ball_positions.append((x, y))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
- # Predict trajectory
90
- trajectory, speed = predict_trajectory(ball_positions[-10:])
91
- max_speed = max(max_speed, speed)
92
 
93
- # Draw trajectory
94
- for i in range(1, len(trajectory)):
95
- cv2.line(frame, tuple(map(int, trajectory[i-1])), tuple(map(int, trajectory[i])), (0, 255, 255), 2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
- # LBW check (every 5 frames)
98
- if len(ball_positions) % 5 == 0:
99
- stump_results = STUMP_MODEL(frame, classes=33, verbose=False)
100
- if len(stump_results[0].boxes) > 0:
101
- sx1, sy1, sx2, sy2 = stump_results[0].boxes.xyxy[0].cpu().numpy()
102
- stump_pos = ((sx1 + sx2) // 2, (sy1 + sy2) // 2)
103
- lbw_data = check_lbw((x, y), stump_pos, ball_positions[-1])
104
-
105
- # Draw DRS overlay
106
- if lbw_data:
107
- cv2.putText(frame, f"Final Decision: {lbw_data['decision']}", (50, 50),
108
- cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255) if lbw_data['decision'] == "Out" else (0, 255, 0), 3)
109
- cv2.putText(frame, f"Hitting: {lbw_data['hitting']}", (50, 100), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2)
110
- cv2.putText(frame, f"Impact: {lbw_data['impact']}", (50, 140), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2)
111
- cv2.putText(frame, f"In Line: {lbw_data['in_line']}", (50, 180), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2)
112
- cv2.putText(frame, f"Pitching: {lbw_data['pitching']}", (50, 220), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2)
 
 
 
 
 
 
 
 
113
 
114
- # Draw speed
115
- cv2.putText(frame, f"Speed: {max_speed:.1f} km/h", (width-300, 50),
116
- cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 0), 2)
117
 
118
- out.write(frame)
119
-
120
- cap.release()
121
- out.release()
 
 
122
 
123
- return {
124
- "output_video": temp_file.name,
125
- "lbw_decision": lbw_data["decision"] if lbw_data else "No Decision",
126
- "max_speed": max_speed,
127
- "analytics": lbw_data if lbw_data else {}
128
- }
 
129
 
130
  # Gradio Interface
131
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
132
- gr.Markdown("# 🏏 Professional DRS System")
133
 
134
  with gr.Row():
135
  input_video = gr.Video(label="Input Match Footage", format="mp4")
@@ -137,7 +201,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
137
 
138
  with gr.Row():
139
  with gr.Column():
140
- gr.Markdown("### 📊 Decision Review System")
141
  decision = gr.Textbox(label="Final Decision")
142
  hitting = gr.Textbox(label="Hitting")
143
  impact = gr.Textbox(label="Impact")
@@ -150,7 +214,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
150
 
151
  analyze_btn = gr.Button("Run DRS Analysis", variant="primary")
152
 
153
- def analyze_wrapper(video):
154
  result = process_video(video)
155
  return {
156
  output_video: result["output_video"],
@@ -163,9 +227,10 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
163
  }
164
 
165
  analyze_btn.click(
166
- fn=analyze_wrapper,
167
  inputs=input_video,
168
  outputs=[output_video, decision, max_speed, hitting, impact, in_line, pitching]
169
  )
170
 
171
- demo.launch()
 
 
5
  import torch
6
  from scipy.interpolate import interp1d
7
  import tempfile
8
+ import os
9
+ from typing import Dict, Tuple, Optional
10
 
11
+ # Initialize models with error handling
12
+ def safe_load_model(model_name: str):
13
+ try:
14
+ model = YOLO(model_name)
15
+ # Verify model works
16
+ dummy = model(np.zeros((640,640,3)), verbose=False)
17
+ if dummy[0].boxes is not None:
18
+ print(f"✅ {model_name} loaded successfully")
19
+ return model
20
+ except Exception as e:
21
+ print(f"❌ Error loading {model_name}: {str(e)}")
22
+ return None
23
 
24
+ BALL_MODEL = safe_load_model("yolov8n.pt") # Ball detection
25
+ STUMP_MODEL = safe_load_model("yolov8m.pt") # Stump detection
26
+
27
+ # Constants (adjust based on your video)
28
+ PITCH_LENGTH_PX = 1800 # Pixels between creases
29
+ STUMPS_WIDTH_PX = 60 # Width of stumps in pixels
30
+ FRAME_SKIP = 2 # Process every 2nd frame for speed
31
 
32
  def predict_trajectory(positions: list) -> Tuple[list, float]:
33
+ """Predict ball path with cubic spline interpolation"""
34
  if len(positions) < 3:
35
  return positions, 0.0
36
 
37
  x = [p[0] for p in positions]
38
  y = [p[1] for p in positions]
 
 
39
  t = np.linspace(0, 1, len(positions))
 
 
 
 
 
 
 
40
 
41
+ try:
42
+ fx = interp1d(t, x, kind='cubic', fill_value="extrapolate")
43
+ fy = interp1d(t, y, kind='cubic', fill_value="extrapolate")
44
+ new_t = np.linspace(0, 1.5, len(positions)+10)
45
+ new_x = fx(new_t)
46
+ new_y = fy(new_t)
47
+
48
+ # Calculate speed (km/h)
49
+ dx = new_x[-1] - new_x[-2]
50
+ dy = new_y[-1] - new_y[-2]
51
+ speed = np.sqrt(dx**2 + dy**2) * 25 * 3.6 / PITCH_LENGTH_PX
52
+ return list(zip(new_x, new_y)), speed
53
+ except:
54
+ return positions, 0.0
55
 
56
+ def check_lbw(ball_pos: tuple, stump_pos: tuple) -> Dict:
57
+ """Make LBW decision with all parameters"""
58
+ hitting = "Hitting" if abs(ball_pos[0] - stump_pos[0]) < STUMPS_WIDTH_PX else "Missing"
59
+ in_line = "In Line" if ball_pos[0] < stump_pos[0] + STUMPS_WIDTH_PX else "Not in Line"
60
+ pitching = "In Line" if ball_pos[1] < stump_pos[1] + 100 else "Outside"
61
+
62
+ decision = "Out" if all([
63
+ hitting == "Hitting",
64
+ in_line == "In Line",
65
+ pitching == "In Line"
66
+ ]) else "Not Out"
67
 
68
  return {
69
  "decision": decision,
 
73
  "pitching": pitching
74
  }
75
 
76
+ def process_video(video_input) -> Dict:
77
+ """Main processing function with full error handling"""
78
+ try:
79
+ # Handle Gradio file input
80
+ if isinstance(video_input, dict):
81
+ video_path = video_input["name"]
82
+ else:
83
+ video_path = video_input
84
+
85
+ if not os.path.exists(video_path):
86
+ raise FileNotFoundError(f"Video file not found: {video_path}")
87
+
88
+ cap = cv2.VideoCapture(video_path)
89
+ if not cap.isOpened():
90
+ raise ValueError("Could not open video file")
 
 
 
 
91
 
92
+ # Get video properties
93
+ fps = cap.get(cv2.CAP_PROP_FPS)
94
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
95
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
96
 
97
+ # Create temp output file
98
+ temp_path = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False).name
99
+ out = cv2.VideoWriter(
100
+ temp_path,
101
+ cv2.VideoWriter_fourcc(*'mp4v'),
102
+ fps/FRAME_SKIP, # Adjusted for frame skipping
103
+ (width, height)
104
+ )
105
+
106
+ ball_positions = []
107
+ lbw_data = None
108
+ max_speed = 0.0
109
+ frame_count = 0
110
+
111
+ while True:
112
+ ret, frame = cap.read()
113
+ if not ret:
114
+ break
115
 
116
+ frame_count += 1
117
+ if frame_count % FRAME_SKIP != 0:
118
+ continue
119
 
120
+ # Ball detection
121
+ if BALL_MODEL:
122
+ results = BALL_MODEL(frame, classes=32, verbose=False)
123
+ boxes = results[0].boxes.xyxy.cpu().numpy()
124
+
125
+ if len(boxes) > 0:
126
+ x1, y1, x2, y2 = boxes[0]
127
+ x, y = (x1 + x2) // 2, (y1 + y2) // 2
128
+ ball_positions.append((x, y))
129
+
130
+ # Predict trajectory and speed
131
+ trajectory, speed = predict_trajectory(ball_positions[-10:])
132
+ max_speed = max(max_speed, speed)
133
+
134
+ # Draw trajectory
135
+ for i in range(1, len(trajectory)):
136
+ cv2.line(
137
+ frame,
138
+ tuple(map(int, trajectory[i-1])),
139
+ tuple(map(int, trajectory[i])),
140
+ (0, 255, 255), 2
141
+ )
142
+
143
+ # LBW check (every 5 processed frames)
144
+ if len(ball_positions) % 5 == 0 and STUMP_MODEL:
145
+ stump_results = STUMP_MODEL(frame, classes=33, verbose=False)
146
+ if len(stump_results[0].boxes) > 0:
147
+ sx1, sy1, sx2, sy2 = stump_results[0].boxes.xyxy[0].cpu().numpy()
148
+ stump_pos = ((sx1 + sx2) // 2, (sy1 + sy2) // 2)
149
+ lbw_data = check_lbw((x, y), stump_pos)
150
 
151
+ # Draw DRS overlay
152
+ if lbw_data:
153
+ cv2.putText(
154
+ frame, f"Final Decision: {lbw_data['decision']}",
155
+ (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1,
156
+ (0, 0, 255) if lbw_data['decision'] == "Out" else (0, 255, 0), 3
157
+ )
158
+ cv2.putText(frame, f"Hitting: {lbw_data['hitting']}", (50, 100),
159
+ cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2)
160
+ cv2.putText(frame, f"Impact: {lbw_data['impact']}", (50, 140),
161
+ cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2)
162
+ cv2.putText(frame, f"In Line: {lbw_data['in_line']}", (50, 180),
163
+ cv2.FERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2)
164
+ cv2.putText(frame, f"Pitching: {lbw_data['pitching']}", (50, 220),
165
+ cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2)
166
+
167
+ # Draw speed
168
+ cv2.putText(
169
+ frame, f"Speed: {max_speed:.1f} km/h",
170
+ (width-300, 50), cv2.FONT_HERSHEY_SIMPLEX,
171
+ 1, (255, 255, 0), 2
172
+ )
173
+
174
+ out.write(frame)
175
 
176
+ cap.release()
177
+ out.release()
 
178
 
179
+ return {
180
+ "output_video": temp_path,
181
+ "lbw_decision": lbw_data["decision"] if lbw_data else "No Decision",
182
+ "max_speed": max_speed,
183
+ "analytics": lbw_data if lbw_data else {}
184
+ }
185
 
186
+ except Exception as e:
187
+ return {
188
+ "output_video": None,
189
+ "lbw_decision": f"Error: {str(e)}",
190
+ "max_speed": 0.0,
191
+ "analytics": {}
192
+ }
193
 
194
  # Gradio Interface
195
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
196
+ gr.Markdown("# 🏏 Professional Cricket DRS System")
197
 
198
  with gr.Row():
199
  input_video = gr.Video(label="Input Match Footage", format="mp4")
 
201
 
202
  with gr.Row():
203
  with gr.Column():
204
+ gr.Markdown("### 📊 Decision Review")
205
  decision = gr.Textbox(label="Final Decision")
206
  hitting = gr.Textbox(label="Hitting")
207
  impact = gr.Textbox(label="Impact")
 
214
 
215
  analyze_btn = gr.Button("Run DRS Analysis", variant="primary")
216
 
217
+ def wrapper(video):
218
  result = process_video(video)
219
  return {
220
  output_video: result["output_video"],
 
227
  }
228
 
229
  analyze_btn.click(
230
+ fn=wrapper,
231
  inputs=input_video,
232
  outputs=[output_video, decision, max_speed, hitting, impact, in_line, pitching]
233
  )
234
 
235
+ if __name__ == "__main__":
236
+ demo.launch(debug=True)