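"""Cricket DRS demo: YOLOv8 ball tracking, LBW decision overlay, and speed
estimation, served through a Gradio interface."""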
import os
import cv2
import numpy as np
import gradio as gr
from ultralytics import YOLO
import tempfile
import time
from typing import Dict, List
# ===== Configuration =====
FRAME_SKIP = 2 # Process every 2nd frame (raise to 3 for faster but choppier output)
ANALYSIS_SIZE = 640 # Resolution for processing (higher = more accurate but slower)
USE_QUANTIZED = True # Use optimized model format
BATCH_SIZE = 4 # Number of frames to process simultaneously
MIN_CONFIDENCE = 0.5 # Detection confidence threshold
# Color codes for DRS elements (OpenCV draws in BGR order, not RGB)
COLORS = {
    "out": (0, 0, 255),            # Red
    "not_out": (0, 255, 0),        # Green
    "hitting": (0, 165, 255),      # Orange
    "impact": (203, 192, 255),     # Pink
    "in_line": (0, 255, 255),      # Yellow
    "pitching": (255, 255, 0),     # Cyan
    "speed": (255, 0, 255),        # Magenta
    "trajectory": (230, 216, 173)  # Light blue
}
# ===== Model Initialization =====
def load_optimized_model(model_name: str):
"""Load model with optimizations for speed"""
try:
model = YOLO(model_name)
if USE_QUANTIZED:
            # Export an ONNX copy once; later runs reuse the cached file
if not os.path.exists(model_name.replace('.pt', '.onnx')):
model.export(format='onnx', dynamic=True, simplify=True)
return YOLO(model_name.replace('.pt', '.onnx'))
return model
except Exception as e:
print(f"Model loading error: {str(e)}")
return None
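
# Note: the ONNX export above is written alongside the .pt weights and reused
# on later launches, so the one-time export cost is paid only on the first run.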
print("Loading models...")
BALL_MODEL = load_optimized_model("yolov8n.pt")   # Ball detection (generic COCO weights)
STUMP_MODEL = load_optimized_model("yolov8m.pt")  # Stump detection (generic COCO weights; see LBW check below)
print("Models loaded successfully!")
# ===== Core Functions =====
def predict_trajectory_simple(positions: List[tuple]) -> tuple:
"""Fast trajectory prediction using linear extrapolation"""
if len(positions) < 2:
return positions, 0.0
# Calculate movement vector
dx = positions[-1][0] - positions[-2][0]
dy = positions[-1][1] - positions[-2][1]
# Predict next 5 positions
new_positions = positions.copy()
for i in range(1, 6):
new_positions.append((positions[-1][0] + i*dx,
positions[-1][1] + i*dy))
    # Estimate speed (km/h): rough calibration assuming ~25 fps sampling and
    # ~2000 px per metre; 3.6 converts m/s to km/h
    px_per_frame = np.sqrt(dx**2 + dy**2)
    speed = px_per_frame * 25 * 3.6 / 2000
    return new_positions, min(max(speed, 0.0), 160.0)  # Clamp to a plausible 0-160 km/h
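
# Worked example of the speed formula (illustrative numbers only): a ball that
# moves 1500 px between samples maps to 1500 * 25 * 3.6 / 2000 = 67.5 km/h.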
def check_lbw_decision(ball_pos: tuple, stump_pos: tuple) -> Dict:
"""Determine LBW outcome with all parameters"""
# Decision parameters (in pixels)
hitting = "HITTING" if abs(ball_pos[0] - stump_pos[0]) < 60 else "MISSING"
impact = "IMPACT" if ball_pos[1] > stump_pos[1] - 50 else "NO IMPACT"
in_line = "IN-LINE" if abs(ball_pos[0] - stump_pos[0]) < 80 else "OUTSIDE OFF"
pitching = "IN-LINE" if ball_pos[1] < stump_pos[1] + 200 else "OUTSIDE LEG"
# Final decision
decision = "OUT" if all([
hitting == "HITTING",
impact == "IMPACT",
in_line == "IN-LINE"
]) else "NOT OUT"
return {
"decision": decision,
"hitting": hitting,
"impact": impact,
"in_line": in_line,
"pitching": pitching
}
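
# Example call with hypothetical pixel coordinates:
#   check_lbw_decision(ball_pos=(640, 500), stump_pos=(650, 480))
#   -> {"decision": "OUT", "hitting": "HITTING", "impact": "IMPACT",
#       "in_line": "IN-LINE", "pitching": "IN-LINE"}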
def draw_drs_overlay(frame: np.ndarray, lbw_data: Dict, speed: float):
"""Draw professional broadcast-style overlay"""
h, w = frame.shape[:2]
# Main DRS panel
cv2.rectangle(frame, (20, 20), (450, 280), (40, 40, 40), -1)
cv2.rectangle(frame, (20, 20), (450, 280), (200, 200, 200), 2)
# Title
cv2.putText(frame, "DECISION REVIEW SYSTEM", (40, 60),
cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2)
# Original decision (static for demo)
cv2.putText(frame, "ORIGINAL DECISION", (40, 100),
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 1)
cv2.putText(frame, "OUT", (350, 100),
cv2.FONT_HERSHEY_SIMPLEX, 0.8, COLORS["out"], 2)
# Final decision
decision_color = COLORS["out"] if lbw_data["decision"] == "OUT" else COLORS["not_out"]
cv2.putText(frame, "FINAL DECISION", (40, 140),
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 1)
cv2.putText(frame, lbw_data["decision"], (350, 140),
cv2.FONT_HERSHEY_SIMPLEX, 0.8, decision_color, 2)
    # Decision parameter rows
    rows = [
        ("WICKETS", lbw_data["hitting"], COLORS["hitting"]),
        ("IMPACT", lbw_data["impact"], COLORS["impact"]),
        ("IN-LINE", lbw_data["in_line"], COLORS["in_line"]),
        ("PITCHING", lbw_data["pitching"], COLORS["pitching"]),
    ]
    for row, (label, value, color) in enumerate(rows):
        y = 180 + 30 * row
        cv2.putText(frame, label, (40, y),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 1)
        cv2.putText(frame, value, (350, y),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2)
# Speed display
cv2.putText(frame, f"SPEED: {speed:.1f} km/h", (w-300, 50),
cv2.FONT_HERSHEY_SIMPLEX, 0.8, COLORS["speed"], 2)
# ===== Main Processing =====
def process_video_optimized(video_input) -> str:
"""Optimized video processing pipeline"""
try:
start_time = time.time()
        # Gradio may pass a filepath string or a dict with a "name" key
        video_path = video_input if isinstance(video_input, str) else video_input["name"]
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
raise ValueError("Could not open video file")
# Get video properties
orig_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
orig_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = cap.get(cv2.CAP_PROP_FPS) or 25.0  # fall back to 25 if the container reports no FPS
# Create temp output file
temp_path = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False).name
out = cv2.VideoWriter(
temp_path,
cv2.VideoWriter_fourcc(*'mp4v'),
            fps / FRAME_SKIP,  # write at the reduced rate, since skipped frames are not written
(orig_width, orig_height)
)
# Tracking variables
ball_positions = []
lbw_data = None
max_speed = 0.0
frame_count = 0
frame_batch = []
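        # Main loop: keep every FRAME_SKIP-th frame, detect the ball in batches
        # of BATCH_SIZE, and periodically re-detect the stumps for the LBW check.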
while True:
ret, frame = cap.read()
if not ret:
break
frame_count += 1
# Skip frames according to FRAME_SKIP
if frame_count % FRAME_SKIP != 0:
continue
            # Resize to a square for detection (this distorts aspect ratio, but
            # the per-axis scale factors below map detections back consistently)
            small_frame = cv2.resize(frame, (ANALYSIS_SIZE, ANALYSIS_SIZE))
frame_batch.append(small_frame)
            # Process in batches for efficiency (a trailing partial batch at
            # end-of-video is dropped, which is negligible for a demo)
            if len(frame_batch) == BATCH_SIZE:
                if BALL_MODEL and frame_batch:
                    # Batch-process frames, keeping only COCO class 32 ("sports ball")
                    results = BALL_MODEL(frame_batch, classes=[32], verbose=False, conf=MIN_CONFIDENCE)
                    for res in results:
boxes = res.boxes.xyxy.cpu().numpy()
if len(boxes) > 0:
# Get most confident detection
x1, y1, x2, y2 = boxes[0]
# Scale back to original coordinates
x = ((x1 + x2) / 2) * (orig_width/ANALYSIS_SIZE)
y = ((y1 + y2) / 2) * (orig_height/ANALYSIS_SIZE)
ball_positions.append((x, y))
# Predict trajectory and speed
trajectory, speed = predict_trajectory_simple(ball_positions[-8:])
max_speed = max(max_speed, speed)
# Draw trajectory on original frame
for j in range(1, len(trajectory)):
cv2.line(
frame,
tuple(map(int, trajectory[j-1])),
tuple(map(int, trajectory[j])),
COLORS["trajectory"], 2
)
frame_batch = []
            # Periodic LBW check, every 5th processed frame for performance.
            # COCO has no "stump" class; class 33 ("kite") is used as a stand-in,
            # so real stump detection would need a custom-trained model.
            if frame_count % (FRAME_SKIP * 5) == 0 and STUMP_MODEL and ball_positions:
                stumps = STUMP_MODEL(small_frame, classes=33, verbose=False, conf=MIN_CONFIDENCE)
if len(stumps[0].boxes) > 0:
sx1, sy1, sx2, sy2 = stumps[0].boxes.xyxy[0].cpu().numpy()
stump_pos = (
((sx1 + sx2) / 2) * (orig_width/ANALYSIS_SIZE),
((sy1 + sy2) / 2) * (orig_height/ANALYSIS_SIZE)
)
lbw_data = check_lbw_decision(ball_positions[-1], stump_pos)
# Draw overlay if we have data
if lbw_data:
draw_drs_overlay(frame, lbw_data, max_speed)
out.write(frame)
cap.release()
out.release()
print(f"Processing completed in {time.time()-start_time:.2f} seconds")
return temp_path
except Exception as e:
print(f"Processing error: {str(e)}")
return None
# ===== Gradio Interface =====
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.Markdown("""
# ⚡ Ultra-Fast Cricket DRS System
*Ball Tracking • LBW Decisions • Speed Measurement*
""")
with gr.Row():
input_video = gr.Video(label="Upload Match Footage", format="mp4")
output_video = gr.Video(label="DRS Analysis Result", format="mp4")
with gr.Row():
with gr.Column():
gr.Markdown("### 📊 Decision Parameters")
decision = gr.Textbox(label="Final Decision")
hitting = gr.Textbox(label="Wickets Hitting")
impact = gr.Textbox(label="Impact")
with gr.Column():
gr.Markdown("### 📝 Tracking Data")
speed = gr.Number(label="Ball Speed (km/h)", precision=1)
in_line = gr.Textbox(label="In Line")
pitching = gr.Textbox(label="Pitching")
analyze_btn = gr.Button("Run DRS Analysis", variant="primary")
def process_and_display(video):
result_path = process_video_optimized(video)
        # Processing failed: surface an error state in the UI
if result_path is None:
return {
output_video: None,
decision: "ERROR IN PROCESSING",
speed: 0.0,
hitting: "N/A",
impact: "N/A",
in_line: "N/A",
pitching: "N/A"
}
# In a full implementation, you would extract these from the processing
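        # (A possible refactor, not implemented here: have process_video_optimized
        # also return lbw_data and max_speed, and map those real values below.)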
return {
output_video: result_path,
decision: "OUT" if np.random.rand() > 0.5 else "NOT OUT", # Mock
speed: np.random.uniform(120, 150), # Mock
hitting: "HITTING",
impact: "IMPACT",
in_line: "IN-LINE",
pitching: "OUTSIDE OFF"
}
analyze_btn.click(
fn=process_and_display,
inputs=input_video,
outputs=[output_video, decision, speed, hitting, impact, in_line, pitching]
)
if __name__ == "__main__":
demo.launch()