|
|
import os
import tempfile
import time
from typing import Dict, List, Optional, Tuple

import cv2
import gradio as gr
import numpy as np
import torch
from ultralytics import YOLO
|
|
|
|
|
|
|
|
# --- Pipeline tuning knobs -------------------------------------------------

# Only every FRAME_SKIP-th decoded frame is analysed and written to the output.
FRAME_SKIP: int = 2

# Frames are resized to ANALYSIS_SIZE x ANALYSIS_SIZE before being fed to the models.
ANALYSIS_SIZE: int = 640

# When True, load_optimized_model() swaps the weights for an ONNX export.
# NOTE(review): despite the name, no quantization step is visible in this file.
USE_QUANTIZED: bool = True

# Number of analysed frames sent to the ball model in one inference call.
BATCH_SIZE: int = 4

# Confidence threshold passed to both detection models.
MIN_CONFIDENCE: float = 0.5

# --- Overlay colours -------------------------------------------------------

# Colour tuples handed straight to cv2 drawing calls, keyed by overlay element.
# NOTE(review): OpenCV interprets these as BGR; some values look like RGB
# triples (e.g. "hitting" is orange in RGB but blue-ish in BGR) - confirm the
# intended colours before changing anything.
COLORS: Dict[str, tuple] = {
    "out": (0, 0, 255),
    "not_out": (0, 255, 0),
    "hitting": (255, 165, 0),
    "impact": (255, 192, 203),
    "in_line": (255, 255, 0),
    "pitching": (0, 255, 255),
    "speed": (255, 0, 255),
    "trajectory": (0, 255, 255),
}
|
|
|
|
|
|
|
|
def load_optimized_model(model_name: str):
    """Load a YOLO model, preferring an ONNX export when USE_QUANTIZED is set.

    The PyTorch weights are always loaded first. When ``USE_QUANTIZED`` is
    true, an ``.onnx`` file next to the weights is reused if present and
    exported otherwise, and the ONNX model is returned instead.

    NOTE(review): despite the flag name, nothing here quantizes the model;
    the only optimization performed is the ONNX export.

    Args:
        model_name: Path or name of the weights file (e.g. ``"yolov8n.pt"``).

    Returns:
        A ``YOLO`` instance, or ``None`` if the weights could not be loaded.
        If only the ONNX step fails, the plain PyTorch model is returned so
        the pipeline keeps working.
    """
    try:
        model = YOLO(model_name)
    except Exception as e:
        print(f"Model loading error: {str(e)}")
        return None

    if not USE_QUANTIZED:
        return model

    # Swap only the extension; str.replace('.pt', ...) would also rewrite a
    # '.pt' occurring elsewhere in the path and is a no-op without a suffix.
    onnx_path = os.path.splitext(model_name)[0] + ".onnx"
    try:
        if not os.path.exists(onnx_path):
            model.export(format='onnx', dynamic=True, simplify=True)
        return YOLO(onnx_path)
    except Exception as e:
        # Export needs optional packages; losing the ONNX speed-up should not
        # cost us the already-loaded PyTorch model.
        print(f"Model loading error: {str(e)} (falling back to {model_name})")
        return model
|
|
|
|
|
# Models are loaded once at import time and shared by every request.
# NOTE(review): both are stock YOLOv8 checkpoints; whether they can detect
# cricket balls/stumps depends on the classes they were trained on - confirm.
print("Loading models...")
BALL_MODEL = load_optimized_model("yolov8n.pt")
STUMP_MODEL = load_optimized_model("yolov8m.pt")

# load_optimized_model() signals failure by returning None, so only claim
# success when both models are actually available.
_failed_models = [
    label
    for label, model in (("ball", BALL_MODEL), ("stump", STUMP_MODEL))
    if model is None
]
if _failed_models:
    print(f"Warning: failed to load model(s): {', '.join(_failed_models)}")
else:
    print("Models loaded successfully!")
|
|
|
|
|
|
|
|
def predict_trajectory_simple(
    positions: List[tuple],
    *,
    fps: float = 25.0,
    px_per_meter: float = 2000.0,
    steps: int = 5,
    max_speed_kmh: float = 160.0,
) -> tuple:
    """Extend a ball track by linear extrapolation and estimate its speed.

    The displacement between the last two positions is repeated ``steps``
    times beyond the final point. The same displacement is converted to
    km/h as ``px/frame * fps / px_per_meter * 3.6`` and clamped to
    ``[0, max_speed_kmh]``.

    NOTE(review): the default ``px_per_meter`` of 2000 is the divisor the
    original formula used; it looks like an uncalibrated placeholder.

    Args:
        positions: Sequence of ``(x, y)`` points, oldest first.
        fps: Effective frame rate between consecutive positions.
        px_per_meter: Pixel-to-metre calibration factor; must be positive.
        steps: Number of extrapolated points to append.
        max_speed_kmh: Upper clamp for the returned speed.

    Returns:
        ``(trajectory, speed_kmh)``. ``trajectory`` is always a new list;
        with fewer than two positions it is a plain copy and the speed is
        ``0.0``.

    Raises:
        ValueError: If ``px_per_meter`` is not positive.
    """
    if px_per_meter <= 0:
        raise ValueError("px_per_meter must be positive")

    # Copy up front so callers never get their own list back.
    trajectory = list(positions)
    if len(trajectory) < 2:
        return trajectory, 0.0

    last_x, last_y = trajectory[-1][0], trajectory[-1][1]
    dx = last_x - trajectory[-2][0]
    dy = last_y - trajectory[-2][1]

    trajectory.extend(
        (last_x + k * dx, last_y + k * dy) for k in range(1, steps + 1)
    )

    px_per_frame = np.sqrt(dx**2 + dy**2)
    speed = px_per_frame * fps * 3.6 / px_per_meter

    # float() keeps the return type stable whether or not the clamp kicks in.
    return trajectory, float(min(max(speed, 0.0), max_speed_kmh))
|
|
|
|
|
def check_lbw_decision(ball_pos: tuple, stump_pos: tuple) -> Dict:
    """Classify an LBW appeal from one ball position and one stump position.

    All four checks are pixel-threshold heuristics on the same pair of
    points:

    * hitting  - horizontal offset from the stumps below 60 px
    * impact   - ball no more than 50 px above the stump centre
    * in_line  - horizontal offset from the stumps below 80 px
    * pitching - ball less than 200 px below the stump centre

    The batter is given OUT only when the ball is hitting, the impact is
    registered and in line, and the ball did not pitch outside leg. The
    pitching verdict used to be reported but ignored, so the overlay could
    show "OUTSIDE LEG" next to an OUT decision.

    Args:
        ball_pos: ``(x, y)`` of the ball in frame pixels.
        stump_pos: ``(x, y)`` of the stump centre in frame pixels.

    Returns:
        Dict with string values under the keys ``decision``, ``hitting``,
        ``impact``, ``in_line`` and ``pitching``.
    """
    lateral_offset = abs(ball_pos[0] - stump_pos[0])
    ball_y, stump_y = ball_pos[1], stump_pos[1]

    hitting = "HITTING" if lateral_offset < 60 else "MISSING"
    impact = "IMPACT" if ball_y > stump_y - 50 else "NO IMPACT"
    in_line = "IN-LINE" if lateral_offset < 80 else "OUTSIDE OFF"
    pitching = "IN-LINE" if ball_y < stump_y + 200 else "OUTSIDE LEG"

    is_out = (
        hitting == "HITTING"
        and impact == "IMPACT"
        and in_line == "IN-LINE"
        # A ball pitching outside leg stump can never be LBW.
        and pitching != "OUTSIDE LEG"
    )

    return {
        "decision": "OUT" if is_out else "NOT OUT",
        "hitting": hitting,
        "impact": impact,
        "in_line": in_line,
        "pitching": pitching,
    }
|
|
|
|
|
def draw_drs_overlay(frame: np.ndarray, lbw_data: Dict, speed: float):
    """Paint the DRS information panel and speed read-out onto ``frame``.

    The frame is modified in place. A fixed-size panel in the top-left corner
    lists the original decision (always shown as "OUT"), the final decision
    and the four LBW parameters; the speed is printed 300 px in from the
    right edge.

    Args:
        frame: Image to draw on.
        lbw_data: Mapping with the keys ``decision``, ``hitting``,
            ``impact``, ``in_line`` and ``pitching``.
        speed: Value rendered as ``SPEED: <speed> km/h``.
    """
    _, width = frame.shape[:2]

    font = cv2.FONT_HERSHEY_SIMPLEX
    white = (255, 255, 255)
    label_x, value_x = 40, 350

    def put(text, origin, scale, color, thickness):
        cv2.putText(frame, text, origin, font, scale, color, thickness)

    # Panel: filled dark background first, then a light border on top.
    top_left, bottom_right = (20, 20), (450, 280)
    cv2.rectangle(frame, top_left, bottom_right, (40, 40, 40), -1)
    cv2.rectangle(frame, top_left, bottom_right, (200, 200, 200), 2)

    put("DECISION REVIEW SYSTEM", (label_x, 60), 0.8, white, 2)

    # The on-field call is not an input to this function; it is fixed text.
    put("ORIGINAL DECISION", (label_x, 100), 0.6, white, 1)
    put("OUT", (value_x, 100), 0.8, COLORS["out"], 2)

    if lbw_data["decision"] == "OUT":
        verdict_color = COLORS["out"]
    else:
        verdict_color = COLORS["not_out"]
    put("FINAL DECISION", (label_x, 140), 0.6, white, 1)
    put(lbw_data["decision"], (value_x, 140), 0.8, verdict_color, 2)

    # Parameter rows: each key indexes both lbw_data and COLORS.
    parameter_rows = (
        ("WICKETS", "hitting", 180),
        ("IMPACT", "impact", 210),
        ("IN-LINE", "in_line", 240),
        ("PITCHING", "pitching", 270),
    )
    for caption, key, baseline in parameter_rows:
        put(caption, (label_x, baseline), 0.6, white, 1)
        put(lbw_data[key], (value_x, baseline), 0.7, COLORS[key], 2)

    put(f"SPEED: {speed:.1f} km/h", (width - 300, 50), 0.8, COLORS["speed"], 2)
|
|
|
|
|
|
|
|
def _top_detection_center(result, scale_x: float, scale_y: float) -> Optional[Tuple[float, float]]:
    """Return the centre of the first box in ``result`` scaled to full-frame pixels, or None."""
    boxes = result.boxes.xyxy.cpu().numpy()
    if len(boxes) == 0:
        return None
    # Only the first box is used, as before.
    # NOTE(review): presumably the highest-confidence detection - confirm.
    x1, y1, x2, y2 = boxes[0]
    return float((x1 + x2) / 2 * scale_x), float((y1 + y2) / 2 * scale_y)


def _draw_trajectory(frame: np.ndarray, trajectory: List[tuple]) -> None:
    """Draw the trajectory as connected line segments on ``frame`` (in place)."""
    for start, end in zip(trajectory, trajectory[1:]):
        cv2.line(
            frame,
            tuple(map(int, start)),
            tuple(map(int, end)),
            COLORS["trajectory"], 2
        )


def _iter_analysis_batches(cap):
    """Yield batches of ``(frame_number, frame, small_frame)`` for every FRAME_SKIP-th frame.

    Batches hold up to BATCH_SIZE entries; the final batch may be shorter so
    that trailing frames are not dropped.
    """
    batch = []
    frame_count = 0
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        frame_count += 1
        if frame_count % FRAME_SKIP != 0:
            continue
        small_frame = cv2.resize(frame, (ANALYSIS_SIZE, ANALYSIS_SIZE))
        batch.append((frame_count, frame, small_frame))
        if len(batch) == BATCH_SIZE:
            yield batch
            batch = []
    if batch:
        yield batch


def _run_drs_analysis(video_input) -> Tuple[Optional[str], Optional[Dict], float]:
    """Run the full pipeline and return ``(output_path, lbw_data, max_speed)``.

    Frames are buffered per batch so that each one is annotated with its own
    detection before being written; previously every frame but the last of a
    batch was written before inference ran, and all trajectories of the batch
    were drawn on that last frame.

    On any failure the error is printed, the partial output file is removed
    and ``(None, None, 0.0)`` is returned. ``lbw_data`` is ``None`` when no
    stump detection ever produced a decision.
    """
    start_time = time.time()
    cap = None
    out = None
    temp_path = None
    completed = False

    try:
        if video_input is None:
            raise ValueError("No video provided")
        video_path = video_input if isinstance(video_input, str) else video_input["name"]

        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise ValueError("Could not open video file")

        orig_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        orig_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            # Some containers report 0/NaN; fall back to the 25 fps that the
            # speed estimate in predict_trajectory_simple() already assumes.
            fps = 25.0

        # Factors mapping analysis-size coordinates back to the original frame.
        scale_x = orig_width / ANALYSIS_SIZE
        scale_y = orig_height / ANALYSIS_SIZE

        # Close our handle before OpenCV opens the path for writing.
        with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp:
            temp_path = tmp.name
        out = cv2.VideoWriter(
            temp_path,
            cv2.VideoWriter_fourcc(*'mp4v'),
            fps / FRAME_SKIP,  # only every FRAME_SKIP-th frame is written
            (orig_width, orig_height)
        )
        if not out.isOpened():
            raise ValueError("Could not open output video for writing")

        ball_positions = []
        lbw_data = None
        max_speed = 0.0

        for batch in _iter_analysis_batches(cap):
            if BALL_MODEL is not None:
                # NOTE(review): no class filter is applied, so any detected
                # object can be taken for the ball - confirm against the model.
                results = BALL_MODEL(
                    [small for _, _, small in batch],
                    verbose=False,
                    conf=MIN_CONFIDENCE,
                )
            else:
                results = [None] * len(batch)

            for (frame_number, frame, small_frame), res in zip(batch, results):
                center = None
                if res is not None:
                    center = _top_detection_center(res, scale_x, scale_y)
                if center is not None:
                    ball_positions.append(center)
                    trajectory, speed = predict_trajectory_simple(ball_positions[-8:])
                    max_speed = max(max_speed, speed)
                    _draw_trajectory(frame, trajectory)

                # Stump detection is throttled to every 5th analysed frame.
                if frame_number % (FRAME_SKIP * 5) == 0 and STUMP_MODEL is not None and ball_positions:
                    # NOTE(review): class id 33 is passed through unchanged;
                    # verify it is the stump class of STUMP_MODEL.
                    stumps = STUMP_MODEL(small_frame, classes=33, verbose=False, conf=MIN_CONFIDENCE)
                    stump_pos = _top_detection_center(stumps[0], scale_x, scale_y)
                    if stump_pos is not None:
                        lbw_data = check_lbw_decision(ball_positions[-1], stump_pos)

                if lbw_data:
                    draw_drs_overlay(frame, lbw_data, max_speed)

                out.write(frame)

        completed = True
        print(f"Processing completed in {time.time()-start_time:.2f} seconds")
        return temp_path, lbw_data, float(max_speed)

    except Exception as e:
        print(f"Processing error: {str(e)}")
        return None, None, 0.0

    finally:
        # Release native handles on every path; out.release() also finalises
        # the file before the caller sees the returned path.
        if cap is not None:
            cap.release()
        if out is not None:
            out.release()
        if not completed and temp_path is not None:
            try:
                os.remove(temp_path)
            except OSError:
                pass


def process_video_optimized(video_input) -> Optional[str]:
    """Annotate a video with ball tracking and the DRS overlay.

    Args:
        video_input: Path to the video, or a mapping whose ``"name"`` entry
            holds the path.

    Returns:
        Path of the annotated temporary ``.mp4`` file, or ``None`` if
        processing failed (the error is printed).
    """
    return _run_drs_analysis(video_input)[0]


with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # ⚡ Ultra-Fast Cricket DRS System
    *Ball Tracking • LBW Decisions • Speed Measurement*
    """)

    with gr.Row():
        input_video = gr.Video(label="Upload Match Footage", format="mp4")
        output_video = gr.Video(label="DRS Analysis Result", format="mp4")

    with gr.Row():
        with gr.Column():
            gr.Markdown("### 📊 Decision Parameters")
            decision = gr.Textbox(label="Final Decision")
            hitting = gr.Textbox(label="Wickets Hitting")
            impact = gr.Textbox(label="Impact")

        with gr.Column():
            gr.Markdown("### 📝 Tracking Data")
            speed = gr.Number(label="Ball Speed (km/h)", precision=1)
            in_line = gr.Textbox(label="In Line")
            pitching = gr.Textbox(label="Pitching")

    analyze_btn = gr.Button("Run DRS Analysis", variant="primary")

    def process_and_display(video):
        """Run the analysis and map its results onto the output components."""
        result_path, lbw_data, max_speed = _run_drs_analysis(video)

        if result_path is None:
            return {
                output_video: None,
                decision: "ERROR IN PROCESSING",
                speed: 0.0,
                hitting: "N/A",
                impact: "N/A",
                in_line: "N/A",
                pitching: "N/A"
            }

        # Report what the pipeline measured; an empty verdict means no LBW
        # decision could be made for this clip.
        verdict = lbw_data or {}
        return {
            output_video: result_path,
            decision: verdict.get("decision", "NO DECISION"),
            speed: max_speed,
            hitting: verdict.get("hitting", "N/A"),
            impact: verdict.get("impact", "N/A"),
            in_line: verdict.get("in_line", "N/A"),
            pitching: verdict.get("pitching", "N/A")
        }

    analyze_btn.click(
        fn=process_and_display,
        inputs=input_video,
        outputs=[output_video, decision, speed, hitting, impact, in_line, pitching]
    )
|
|
|
|
|
if __name__ == "__main__":
    # Start the Gradio app only when this file is executed as a script.
    demo.launch()