lokesh341 committed on
Commit
d7d180f
·
verified ·
1 Parent(s): 8303862

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +104 -54
app.py CHANGED
@@ -3,69 +3,72 @@ import numpy as np
3
  import gradio as gr
4
  from ultralytics import YOLO
5
  import torch
6
- from typing import List, Dict
7
  import os
 
 
8
 
9
- # Verify PyTorch and CUDA
10
- print(f"PyTorch version: {torch.__version__}")
11
- print(f"CUDA available: {torch.cuda.is_available()}")
12
-
13
- # Initialize models with safe loading
14
- def load_model(model_path: str, model_type: str = "ball"):
15
- """Safely load YOLO model with verification"""
16
  try:
17
- # For Hugging Face Spaces - use cached models
18
- if not os.path.exists(model_path):
19
- model = YOLO(f"yolov8{model_type}.pt") # Auto-download
20
- torch.save(model.model.state_dict(), model_path)
21
- else:
22
- model = YOLO(model_path)
23
-
24
- # Verify model
25
- if model.predict(np.zeros((640, 640, 3)), verbose=False)[0].boxes is not None:
26
- print(f"{model_type.upper()} model loaded successfully!")
27
  return model
28
- else:
29
- raise RuntimeError("Model verification failed")
30
  except Exception as e:
31
- raise RuntimeError(f"Model loading error: {str(e)}")
 
32
 
33
- # Load models (will auto-download if not present)
34
- try:
35
- BALL_MODEL = load_model("yolov8n.pt", "n")
36
- STUMP_MODEL = load_model("yolov8m.pt", "m")
37
- except Exception as e:
38
- print(f"Critical error: {str(e)}")
39
- # Fallback to CPU-only basic detection
40
- BALL_MODEL = None
41
 
42
  def process_video(video_path: str) -> Dict:
43
- """Robust video processing with error handling"""
44
  try:
 
 
 
 
 
 
 
 
45
  cap = cv2.VideoCapture(video_path)
46
  if not cap.isOpened():
47
  raise ValueError("Could not open video file")
48
 
 
49
  fps = cap.get(cv2.CAP_PROP_FPS)
50
- frames = []
 
 
 
 
51
  analytics = {
52
- "max_speed": 0,
53
  "events": [],
54
- "status": "Success"
 
 
55
  }
56
 
57
  prev_pos = None
 
58
 
59
- while cap.isOpened():
60
  ret, frame = cap.read()
61
  if not ret:
62
  break
63
 
 
 
 
64
  frame = cv2.resize(frame, (1280, 720))
65
 
66
- # Ball detection (only if model loaded)
67
  if BALL_MODEL:
68
- results = BALL_MODEL(frame, classes=32, verbose=False) # Class 32 = sports ball
69
  boxes = results[0].boxes.xyxy.cpu().numpy()
70
 
71
  if len(boxes) > 0:
@@ -74,55 +77,102 @@ def process_video(video_path: str) -> Dict:
74
 
75
  # Speed calculation
76
  if prev_pos:
77
- speed = np.sqrt((x - prev_pos[0])**2 + (y - prev_pos[1])**2) * fps * 3.6 / 100
 
78
  analytics["max_speed"] = max(analytics["max_speed"], speed)
 
 
 
79
  cv2.putText(frame, f"{speed:.1f} km/h", (x+15, y),
80
  cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 0), 2)
81
 
82
  prev_pos = (x, y)
83
- cv2.circle(frame, (x, y), 10, (0, 255, 0), -1)
84
 
85
- frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
 
 
 
 
 
 
86
 
87
  cap.release()
88
- return {"frames": frames, "analytics": analytics}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
  except Exception as e:
91
  return {
92
- "frames": [],
93
  "analytics": {
94
- "status": f"Error: {str(e)}",
95
- "max_speed": 0,
96
- "events": []
 
 
97
  }
98
  }
99
 
100
  # Gradio Interface
101
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
102
- gr.Markdown("# 🏏 Cricket Ball Tracker (Fixed Version)")
 
 
 
103
 
104
  with gr.Row():
105
- input_video = gr.Video(label="Input Video", format="mp4")
106
- output_video = gr.Video(label="Tracking Results")
 
 
 
 
 
 
 
 
107
 
108
- with gr.Accordion("Advanced Info", open=False):
109
- status = gr.Textbox(label="Processing Status")
110
- max_speed = gr.Number(label="Max Ball Speed (km/h)")
 
 
 
 
 
 
 
111
 
112
- analyze_btn = gr.Button("Analyze", variant="primary")
113
 
114
  def analyze_wrapper(video):
115
  result = process_video(video)
116
  return {
117
- output_video: result["frames"],
 
 
118
  status: result["analytics"]["status"],
119
- max_speed: result["analytics"]["max_speed"]
120
  }
121
 
122
  analyze_btn.click(
123
  fn=analyze_wrapper,
124
  inputs=input_video,
125
- outputs=[output_video, status, max_speed]
126
  )
127
 
128
  if __name__ == "__main__":
 
3
  import gradio as gr
4
  from ultralytics import YOLO
5
  import torch
 
6
  import os
7
+ import tempfile
8
+ from typing import Dict, List
9
 
10
# Initialize models safely
def load_model(model_name: str):
    """Load a YOLO model and verify it with one dummy inference.

    Args:
        model_name: Weights file name/path understood by ultralytics
            (auto-downloaded by ultralytics when not present locally).

    Returns:
        The loaded YOLO model, or None when loading or the smoke test
        fails — callers must handle the None fallback.
    """
    try:
        model = YOLO(model_name)
        # Smoke-test with a blank image. ultralytics expects HWC uint8
        # (0-255) numpy arrays, so pass uint8 explicitly instead of the
        # default float64 zeros.
        dummy_result = model(np.zeros((640, 640, 3), dtype=np.uint8), verbose=False)
        if dummy_result[0].boxes is not None:
            print(f"βœ… {model_name} loaded successfully!")
            return model
        raise RuntimeError("Model test failed")
    except Exception as e:  # boundary: degrade to None rather than crash app startup
        print(f"❌ Model load error: {str(e)}")
        return None

BALL_MODEL = load_model("yolov8n.pt")
 
 
 
 
 
 
 
25
 
26
def process_video(video_path: str) -> Dict:
    """Track the ball in a cricket video and estimate its speed.

    Args:
        video_path: Path to the input video. Gradio sometimes passes a
            dict payload instead of a plain path; both are accepted.

    Returns:
        Dict with "output_video" (path to the annotated mp4, or None on
        failure) and "analytics" (max_speed, events, status, fps,
        resolution). Never raises: every error is folded into the
        returned status string for the UI.
    """
    cap = None
    writer = None
    try:
        # Handle Gradio's dict-shaped file payload.
        if isinstance(video_path, dict):
            video_path = video_path["name"]

        # Verify the file exists before handing it to OpenCV.
        if not os.path.exists(video_path):
            raise FileNotFoundError(f"Video file not found: {video_path}")

        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise ValueError("Could not open video file")

        # Some containers report 0 (or NaN) fps; fall back to a sane default
        # so the speed math and the output writer stay valid.
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps or fps != fps:
            fps = 30.0
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        analytics = {
            "max_speed": 0.0,
            "events": [],
            "status": "Processing...",
            "fps": fps,
            "resolution": f"{width}x{height}",
        }

        # Stream frames straight into the writer instead of buffering the
        # whole clip in RAM. mkstemp + close avoids holding an open handle
        # on the same path the VideoWriter writes to (breaks on Windows).
        fd, out_path = tempfile.mkstemp(suffix=".mp4")
        os.close(fd)
        writer = cv2.VideoWriter(out_path, cv2.VideoWriter_fourcc(*"mp4v"),
                                 fps, (1280, 720))
        if not writer.isOpened():
            raise RuntimeError("Could not open output video writer")

        prev_pos = None
        frame_count = 0

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            frame_count += 1
            # Resize for consistent processing.
            frame = cv2.resize(frame, (1280, 720))

            # Ball detection (explicit None check — model truthiness is
            # not a reliable loaded/failed signal).
            if BALL_MODEL is not None:
                results = BALL_MODEL(frame, classes=32, verbose=False)  # 32 = COCO "sports ball"
                boxes = results[0].boxes.xyxy.cpu().numpy()

                if len(boxes) > 0:
                    # NOTE(review): the centre computation was hidden
                    # context in the diff; reconstructed as the centre of
                    # the first detected box — confirm against the
                    # original file.
                    x1, y1, x2, y2 = boxes[0][:4]
                    x, y = int((x1 + x2) / 2), int((y1 + y2) / 2)

                    # Speed calculation from per-frame displacement.
                    if prev_pos:
                        px_per_meter = 100  # rough calibration; needs per-camera tuning
                        speed = (np.sqrt((x - prev_pos[0]) ** 2 +
                                         (y - prev_pos[1]) ** 2)
                                 * fps * 3.6 / px_per_meter)
                        analytics["max_speed"] = max(analytics["max_speed"], speed)

                        # Visualize
                        cv2.circle(frame, (x, y), 10, (0, 255, 0), -1)
                        cv2.putText(frame, f"{speed:.1f} km/h", (x + 15, y),
                                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 0), 2)

                    prev_pos = (x, y)

            # Frame is already BGR — write directly, no RGB round-trip.
            writer.write(frame)

            # Update status every 10 frames.
            if frame_count % 10 == 0:
                analytics["status"] = f"Processed {frame_count} frames"

        if frame_count == 0:
            raise ValueError("No frames processed - check video format")

        analytics["status"] = "βœ… Processing complete"
        return {
            "output_video": out_path,
            "analytics": analytics,
        }

    except Exception as e:  # boundary: the UI expects a status string, not a traceback
        return {
            "output_video": None,
            "analytics": {
                "status": f"❌ Error: {str(e)}",
                "max_speed": 0.0,
                "events": [],
                "fps": 0,
                "resolution": "0x0",
            },
        }
    finally:
        # Always release OS resources, even on the error path.
        if cap is not None:
            cap.release()
        if writer is not None:
            writer.release()
129
 
130
# Gradio Interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🏏 Professional Cricket Tracker
    *Ball Tracking β€’ Speed Analysis β€’ Event Detection*
    """)

    with gr.Row():
        with gr.Column():
            input_video = gr.Video(label="Input Match Footage", format="mp4")
            # Only register the sample when it actually ships with the app:
            # gr.Examples raises at startup for a missing file.
            if os.path.exists("sample.mp4"):
                gr.Examples(
                    examples=["sample.mp4"],
                    inputs=input_video,
                    label="Try Sample Video"
                )

        with gr.Column():
            output_video = gr.Video(label="Tracking Results", format="mp4")

    with gr.Row():
        with gr.Column():
            gr.Markdown("### πŸ“Š Match Analytics")
            max_speed = gr.Number(label="Max Ball Speed (km/h)", precision=1)
            resolution = gr.Textbox(label="Video Resolution")

        with gr.Column():
            gr.Markdown("### πŸ“ Processing Info")
            status = gr.Textbox(label="Status")
            fps = gr.Number(label="Video FPS")

    analyze_btn = gr.Button("Analyze Video", variant="primary")

    def analyze_wrapper(video):
        """Map process_video's result dict onto the UI components."""
        result = process_video(video)
        return {
            output_video: result["output_video"],
            max_speed: result["analytics"]["max_speed"],
            resolution: result["analytics"]["resolution"],
            status: result["analytics"]["status"],
            fps: result["analytics"]["fps"]
        }

    analyze_btn.click(
        fn=analyze_wrapper,
        inputs=input_video,
        outputs=[output_video, max_speed, resolution, status, fps]
    )
177
 
178
  if __name__ == "__main__":