lokesh341 commited on
Commit
ea0d6a2
Β·
verified Β·
1 Parent(s): d7d180f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +129 -137
app.py CHANGED
@@ -3,177 +3,169 @@ import numpy as np
3
  import gradio as gr
4
  from ultralytics import YOLO
5
  import torch
6
- import os
7
  import tempfile
8
- from typing import Dict, List
9
 
10
- # Initialize models safely
11
- def load_model(model_name: str):
12
- try:
13
- model = YOLO(model_name)
14
- # Test model with dummy data
15
- dummy_result = model(np.zeros((640, 640, 3)), verbose=False)
16
- if dummy_result[0].boxes is not None:
17
- print(f"βœ… {model_name} loaded successfully!")
18
- return model
19
- raise RuntimeError("Model test failed")
20
- except Exception as e:
21
- print(f"❌ Model load error: {str(e)}")
22
- return None
23
 
24
- BALL_MODEL = load_model("yolov8n.pt")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  def process_video(video_path: str) -> Dict:
27
- """Robust video processing with full error handling"""
28
- try:
29
- # Handle Gradio's temporary file path
30
- if isinstance(video_path, dict):
31
- video_path = video_path["name"]
32
-
33
- # Verify file exists
34
- if not os.path.exists(video_path):
35
- raise FileNotFoundError(f"Video file not found: {video_path}")
36
-
37
- cap = cv2.VideoCapture(video_path)
38
- if not cap.isOpened():
39
- raise ValueError("Could not open video file")
40
-
41
- # Get video properties
42
- fps = cap.get(cv2.CAP_PROP_FPS)
43
- width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
44
- height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
45
-
46
- # Prepare output
47
- output_frames = []
48
- analytics = {
49
- "max_speed": 0.0,
50
- "events": [],
51
- "status": "Processing...",
52
- "fps": fps,
53
- "resolution": f"{width}x{height}"
54
- }
55
 
56
- prev_pos = None
57
- frame_count = 0
 
58
 
59
- while True:
60
- ret, frame = cap.read()
61
- if not ret:
62
- break
63
 
64
- frame_count += 1
 
 
65
 
66
- # Resize for consistent processing (optional)
67
- frame = cv2.resize(frame, (1280, 720))
 
68
 
69
- # Ball detection
70
- if BALL_MODEL:
71
- results = BALL_MODEL(frame, classes=32, verbose=False)
72
- boxes = results[0].boxes.xyxy.cpu().numpy()
73
-
74
- if len(boxes) > 0:
75
- x1, y1, x2, y2 = map(int, boxes[0])
76
- x, y = (x1 + x2) // 2, (y1 + y2) // 2
77
-
78
- # Speed calculation
79
- if prev_pos:
80
- px_per_meter = 100 # Calibration needed
81
- speed = np.sqrt((x - prev_pos[0])**2 + (y - prev_pos[1])**2) * fps * 3.6 / px_per_meter
82
- analytics["max_speed"] = max(analytics["max_speed"], speed)
83
-
84
- # Visualize
85
- cv2.circle(frame, (x, y), 10, (0, 255, 0), -1)
86
- cv2.putText(frame, f"{speed:.1f} km/h", (x+15, y),
87
- cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 0), 2)
88
-
89
- prev_pos = (x, y)
90
-
91
- # Convert frame to RGB for Gradio
92
- frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
93
- output_frames.append(frame_rgb)
94
-
95
- # Update status every 10 frames
96
- if frame_count % 10 == 0:
97
- analytics["status"] = f"Processed {frame_count} frames"
98
 
99
- cap.release()
 
 
 
 
 
 
 
100
 
101
- # Handle empty output
102
- if not output_frames:
103
- raise ValueError("No frames processed - check video format")
104
 
105
- # Save output as temporary video file
106
- with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
107
- out = cv2.VideoWriter(tmp.name, cv2.VideoWriter_fourcc(*'mp4v'), fps, (1280, 720))
108
- for frame in output_frames:
109
- out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
110
- out.release()
111
-
112
- analytics["status"] = "βœ… Processing complete"
113
- return {
114
- "output_video": tmp.name,
115
- "analytics": analytics
116
- }
117
 
118
- except Exception as e:
119
- return {
120
- "output_video": None,
121
- "analytics": {
122
- "status": f"❌ Error: {str(e)}",
123
- "max_speed": 0.0,
124
- "events": [],
125
- "fps": 0,
126
- "resolution": "0x0"
127
- }
128
- }
129
 
130
  # Gradio Interface
131
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
132
- gr.Markdown("""
133
- # 🏏 Professional Cricket Tracker
134
- *Ball Tracking β€’ Speed Analysis β€’ Event Detection*
135
- """)
136
 
137
  with gr.Row():
138
- with gr.Column():
139
- input_video = gr.Video(label="Input Match Footage", format="mp4")
140
- gr.Examples(
141
- examples=["sample.mp4"],
142
- inputs=input_video,
143
- label="Try Sample Video"
144
- )
145
-
146
- with gr.Column():
147
- output_video = gr.Video(label="Tracking Results", format="mp4")
148
 
149
  with gr.Row():
150
  with gr.Column():
151
- gr.Markdown("### πŸ“Š Match Analytics")
152
- max_speed = gr.Number(label="Max Ball Speed (km/h)", precision=1)
153
- resolution = gr.Textbox(label="Video Resolution")
 
154
 
155
  with gr.Column():
156
- gr.Markdown("### πŸ“ Processing Info")
157
- status = gr.Textbox(label="Status")
158
- fps = gr.Number(label="Video FPS")
 
159
 
160
- analyze_btn = gr.Button("Analyze Video", variant="primary")
161
 
162
  def analyze_wrapper(video):
163
  result = process_video(video)
164
  return {
165
  output_video: result["output_video"],
166
- max_speed: result["analytics"]["max_speed"],
167
- resolution: result["analytics"]["resolution"],
168
- status: result["analytics"]["status"],
169
- fps: result["analytics"]["fps"]
 
 
170
  }
171
 
172
  analyze_btn.click(
173
  fn=analyze_wrapper,
174
  inputs=input_video,
175
- outputs=[output_video, max_speed, resolution, status, fps]
176
  )
177
 
178
- if __name__ == "__main__":
179
- demo.launch(debug=True)
 
3
  import gradio as gr
4
  from ultralytics import YOLO
5
  import torch
6
+ from scipy.interpolate import interp1d
7
  import tempfile
8
+ from typing import Dict, Tuple
9
 
10
+ # Initialize models
11
+ BALL_MODEL = YOLO('yolov8n.pt') # Auto-downloads if not present
12
+ STUMP_MODEL = YOLO('yolov8m.pt')
 
 
 
 
 
 
 
 
 
 
13
 
14
+ # Pitch constants (in pixels)
15
+ PITCH_LENGTH = 2000 # From crease to crease
16
+ STUMPS_WIDTH = 71 # Standard cricket stump width (9 inches)
17
+
18
+ def predict_trajectory(positions: list) -> Tuple[list, float]:
19
+ """Predict ball trajectory using cubic interpolation"""
20
+ if len(positions) < 3:
21
+ return positions, 0.0
22
+
23
+ x = [p[0] for p in positions]
24
+ y = [p[1] for p in positions]
25
+
26
+ # Cubic spline interpolation
27
+ t = np.linspace(0, 1, len(positions))
28
+ fx = interp1d(t, x, kind='cubic', fill_value="extrapolate")
29
+ fy = interp1d(t, y, kind='cubic', fill_value="extrapolate")
30
+
31
+ # Predict next 10 frames
32
+ new_t = np.linspace(0, 1.5, len(positions)+10)
33
+ new_x = fx(new_t)
34
+ new_y = fy(new_t)
35
+
36
+ # Calculate speed (pixels/frame to km/h)
37
+ dx = new_x[-1] - new_x[-2]
38
+ dy = new_y[-1] - new_y[-2]
39
+ speed = np.sqrt(dx**2 + dy**2) * 25 * 3.6 / PITCH_LENGTH # Convert to km/h
40
+
41
+ return list(zip(new_x, new_y)), speed
42
+
43
+ def check_lbw(ball_pos: tuple, stump_pos: tuple, impact: tuple) -> Dict:
44
+ """LBW decision system"""
45
+ # Simplified decision logic
46
+ hitting = "Hitting" if abs(ball_pos[0] - stump_pos[0]) < STUMPS_WIDTH else "Missing"
47
+ in_line = "In Line" if impact[0] < stump_pos[0] + STUMPS_WIDTH else "Not in Line"
48
+ pitching = "In Line" if impact[1] < stump_pos[1] + 100 else "Outside"
49
+
50
+ decision = "Out" if all([hitting == "Hitting", in_line == "In Line", pitching == "In Line"]) else "Not Out"
51
+
52
+ return {
53
+ "decision": decision,
54
+ "hitting": hitting,
55
+ "impact": "Impact" if decision == "Out" else "No Impact",
56
+ "in_line": in_line,
57
+ "pitching": pitching
58
+ }
59
 
60
  def process_video(video_path: str) -> Dict:
61
+ """Main processing function"""
62
+ cap = cv2.VideoCapture(video_path)
63
+ fps = cap.get(cv2.CAP_PROP_FPS)
64
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
65
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
66
+
67
+ # Video writer setup
68
+ temp_file = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
69
+ out = cv2.VideoWriter(temp_file.name, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
70
+
71
+ ball_positions = []
72
+ lbw_data = None
73
+ max_speed = 0.0
74
+
75
+ while cap.isOpened():
76
+ ret, frame = cap.read()
77
+ if not ret:
78
+ break
 
 
 
 
 
 
 
 
 
 
79
 
80
+ # Ball detection
81
+ ball_results = BALL_MODEL(frame, classes=32, verbose=False)
82
+ boxes = ball_results[0].boxes.xyxy.cpu().numpy()
83
 
84
+ if len(boxes) > 0:
85
+ x1, y1, x2, y2 = boxes[0]
86
+ x, y = (x1 + x2) // 2, (y1 + y2) // 2
87
+ ball_positions.append((x, y))
88
 
89
+ # Predict trajectory
90
+ trajectory, speed = predict_trajectory(ball_positions[-10:])
91
+ max_speed = max(max_speed, speed)
92
 
93
+ # Draw trajectory
94
+ for i in range(1, len(trajectory)):
95
+ cv2.line(frame, tuple(map(int, trajectory[i-1])), tuple(map(int, trajectory[i])), (0, 255, 255), 2)
96
 
97
+ # LBW check (every 5 frames)
98
+ if len(ball_positions) % 5 == 0:
99
+ stump_results = STUMP_MODEL(frame, classes=33, verbose=False)
100
+ if len(stump_results[0].boxes) > 0:
101
+ sx1, sy1, sx2, sy2 = stump_results[0].boxes.xyxy[0].cpu().numpy()
102
+ stump_pos = ((sx1 + sx2) // 2, (sy1 + sy2) // 2)
103
+ lbw_data = check_lbw((x, y), stump_pos, ball_positions[-1])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
 
105
+ # Draw DRS overlay
106
+ if lbw_data:
107
+ cv2.putText(frame, f"Final Decision: {lbw_data['decision']}", (50, 50),
108
+ cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255) if lbw_data['decision'] == "Out" else (0, 255, 0), 3)
109
+ cv2.putText(frame, f"Hitting: {lbw_data['hitting']}", (50, 100), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2)
110
+ cv2.putText(frame, f"Impact: {lbw_data['impact']}", (50, 140), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2)
111
+ cv2.putText(frame, f"In Line: {lbw_data['in_line']}", (50, 180), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2)
112
+ cv2.putText(frame, f"Pitching: {lbw_data['pitching']}", (50, 220), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2)
113
 
114
+ # Draw speed
115
+ cv2.putText(frame, f"Speed: {max_speed:.1f} km/h", (width-300, 50),
116
+ cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 0), 2)
117
 
118
+ out.write(frame)
 
 
 
 
 
 
 
 
 
 
 
119
 
120
+ cap.release()
121
+ out.release()
122
+
123
+ return {
124
+ "output_video": temp_file.name,
125
+ "lbw_decision": lbw_data["decision"] if lbw_data else "No Decision",
126
+ "max_speed": max_speed,
127
+ "analytics": lbw_data if lbw_data else {}
128
+ }
 
 
129
 
130
  # Gradio Interface
131
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
132
+ gr.Markdown("# 🏏 Professional DRS System")
 
 
 
133
 
134
  with gr.Row():
135
+ input_video = gr.Video(label="Input Match Footage", format="mp4")
136
+ output_video = gr.Video(label="DRS Analysis", format="mp4")
 
 
 
 
 
 
 
 
137
 
138
  with gr.Row():
139
  with gr.Column():
140
+ gr.Markdown("### πŸ“Š Decision Review System")
141
+ decision = gr.Textbox(label="Final Decision")
142
+ hitting = gr.Textbox(label="Hitting")
143
+ impact = gr.Textbox(label="Impact")
144
 
145
  with gr.Column():
146
+ gr.Markdown("### πŸ“ Ball Tracking")
147
+ max_speed = gr.Number(label="Max Speed (km/h)", precision=1)
148
+ in_line = gr.Textbox(label="In Line")
149
+ pitching = gr.Textbox(label="Pitching")
150
 
151
+ analyze_btn = gr.Button("Run DRS Analysis", variant="primary")
152
 
153
  def analyze_wrapper(video):
154
  result = process_video(video)
155
  return {
156
  output_video: result["output_video"],
157
+ decision: result["lbw_decision"],
158
+ max_speed: result["max_speed"],
159
+ hitting: result["analytics"].get("hitting", "N/A"),
160
+ impact: result["analytics"].get("impact", "N/A"),
161
+ in_line: result["analytics"].get("in_line", "N/A"),
162
+ pitching: result["analytics"].get("pitching", "N/A")
163
  }
164
 
165
  analyze_btn.click(
166
  fn=analyze_wrapper,
167
  inputs=input_video,
168
+ outputs=[output_video, decision, max_speed, hitting, impact, in_line, pitching]
169
  )
170
 
171
+ demo.launch()