AjaykumarPilla committed on
Commit
e7cb6e4
·
verified ·
1 Parent(s): 8b5e7c7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +97 -59
app.py CHANGED
@@ -5,76 +5,114 @@ import torch
5
  from transformers import DetrForObjectDetection, DetrImageProcessor
6
  import matplotlib.pyplot as plt
7
  from mpl_toolkits.mplot3d import Axes3D
 
 
 
 
 
8
 
9
  # Load pre-trained model for ball detection
10
- processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50", cache_dir="/home/user/app/cache")
11
- model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50", cache_dir="/home/user/app/cache")
 
 
 
 
 
12
 
13
  def process_drs_video(video_path):
14
- # Extract frames from video
15
- cap = cv2.VideoCapture(video_path)
16
- frames = []
17
- while cap.isOpened():
18
- ret, frame = cap.read()
19
- if not ret:
20
- break
21
- frames.append(frame)
22
- cap.release()
23
-
24
- # Detect ball in frames
25
- ball_positions = []
26
- for frame in frames:
27
- frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
28
- inputs = processor(images=frame_rgb, return_tensors="pt")
29
- outputs = model(**inputs)
30
- target_sizes = torch.tensor([frame_rgb.shape[:2]])
31
- results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
32
 
33
- for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
34
- if score > 0.9: # High confidence threshold
35
- x_center = (box[0] + box[2]) / 2
36
- y_center = (box[1] + box[3]) / 2
37
- # Simplified depth estimation (z-coordinate, placeholder)
38
- z = 100 # Replace with actual depth model or pitch mapping
39
- ball_positions.append([x_center.item(), y_center.item(), z])
40
  break
41
- else:
42
- ball_positions.append(None)
43
-
44
- ball_positions = [pos for pos in ball_positions if pos is not None]
45
- if not ball_positions:
46
- return "Error: No ball detected in video", None
47
 
48
- trajectory = np.array(ball_positions)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
- # Simplified LBW decision logic (without edge detection)
51
- # Assumptions: Ball pitches in line, no bat contact (due to no audio), check stump impact
52
- pitching_in_line = trajectory[0][0] > 200 and trajectory[0][0] < 520 # Placeholder: Adjust based on pitch dimensions
53
- impact_in_line = trajectory[-1][0] > 200 and trajectory[-1][0] < 520 # Check impact near stumps
54
- hits_stumps = trajectory[-1][1] < 300 # Simplified: Ball low enough to hit stumps
55
- decision = "Out" if pitching_in_line and impact_in_line and hits_stumps else "Not Out"
56
 
57
- # 3D Trajectory Plot
58
- fig = plt.figure(figsize=(10, 5))
59
- ax = fig.add_subplot(121, projection='3d')
60
- ax.plot(trajectory[:, 0], trajectory[:, 1], trajectory[:, 2], 'r-')
61
- ax.set_xlabel("X (Pitch Width)")
62
- ax.set_ylabel("Y (Pitch Length)")
63
- ax.set_zlabel("Z (Height)")
64
- ax.set_title("3D Ball Trajectory")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
- # Pitch Map Plot
67
- ax2 = fig.add_subplot(122)
68
- ax2.scatter(trajectory[:, 0], trajectory[:, 1], c='blue', marker='o')
69
- ax2.set_xlim(0, 720) # Adjust based on video resolution or pitch dimensions
70
- ax2.set_ylim(0, 1280)
71
- ax2.set_xlabel("Pitch Width")
72
- ax2.set_ylabel("Pitch Length")
73
- ax2.set_title("Pitch Map")
74
- plt.savefig("drs_output.png")
75
- plt.close()
76
 
77
- return decision, "drs_output.png"
 
 
78
 
79
  # Gradio interface
80
  iface = gr.Interface(
 
5
  from transformers import DetrForObjectDetection, DetrImageProcessor
6
  import matplotlib.pyplot as plt
7
  from mpl_toolkits.mplot3d import Axes3D
8
+ import logging
9
+
10
# Set up module-level logging (INFO so frame-extraction/detection progress is visible)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load pre-trained DETR model for ball detection.
# Failure here is fatal for the app, so log with traceback and re-raise.
try:
    processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50", cache_dir="/home/user/app/cache")
    model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50", cache_dir="/home/user/app/cache")
    logger.info("Model and processor loaded successfully")
except Exception:
    # logger.exception records the full traceback, unlike error(f"...{e}")
    logger.exception("Error loading model")
    raise
22
 
23
def process_drs_video(video_path):
    """Run a simplified DRS (ball tracking + LBW decision) pass over a cricket video.

    Extracts up to 100 frames (downscaled to 640x360 to bound memory), runs
    DETR ball detection on every 2nd frame, builds a rough 3D trajectory
    (depth is a placeholder constant), applies placeholder LBW rules, and
    saves a trajectory/pitch-map figure.

    Args:
        video_path: Path to the input video file.

    Returns:
        Tuple ``(decision, image_path)`` where ``decision`` is ``"Out"``,
        ``"Not Out"`` or an ``"Error: ..."`` string, and ``image_path`` is
        the saved plot path (``None`` on error).
    """
    try:
        # --- Frame extraction -------------------------------------------------
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            logger.error("Failed to open video file")
            return "Error: Could not open video file", None

        frames = []
        frame_count = 0
        while cap.isOpened() and frame_count < 100:  # cap frames to avoid memory issues
            ret, frame = cap.read()
            if not ret:
                break
            # Downscale to reduce memory and inference cost
            frames.append(cv2.resize(frame, (640, 360)))
            frame_count += 1
        cap.release()
        logger.info("Extracted %d frames from video", len(frames))

        if not frames:
            logger.error("No frames extracted from video")
            return "Error: No frames extracted from video", None

        # --- Ball detection ---------------------------------------------------
        ball_positions = []
        for i, frame in enumerate(frames[::2]):  # every 2nd frame to reduce load
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            inputs = processor(images=frame_rgb, return_tensors="pt")
            with torch.no_grad():  # inference only — no gradients needed
                outputs = model(**inputs)

            target_sizes = torch.tensor([frame_rgb.shape[:2]])
            # threshold=0.7 already filters low-confidence boxes; no second
            # score check is needed below.
            results = processor.post_process_object_detection(
                outputs, target_sizes=target_sizes, threshold=0.7
            )[0]

            ball_found = False
            for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
                # BUG FIX: only accept detections of the COCO "sports ball"
                # class. Previously ANY confident object (player, bat,
                # stumps) was treated as the ball.
                if model.config.id2label[label.item()] != "sports ball":
                    continue
                x_center = (box[0] + box[2]) / 2
                y_center = (box[1] + box[3]) / 2
                z = 100  # Placeholder depth (replace with depth model / pitch mapping)
                ball_positions.append([x_center.item(), y_center.item(), z])
                ball_found = True
                logger.info("Ball detected in sampled frame %d: score=%.2f, box=%s",
                            i, score, box.tolist())
                break  # one ball per frame
            if not ball_found:
                ball_positions.append(None)
                logger.warning("No ball detected in sampled frame %d", i)

        ball_positions = [pos for pos in ball_positions if pos is not None]
        if not ball_positions:
            logger.error("No ball detected in any frame")
            return "Error: No ball detected in video", None

        trajectory = np.array(ball_positions)
        logger.info("Trajectory shape: %s", trajectory.shape)

        # --- Simplified LBW decision logic ------------------------------------
        # Placeholder bounds tuned for the 640x360 working resolution.
        pitching_in_line = 100 < trajectory[0][0] < 540
        impact_in_line = 100 < trajectory[-1][0] < 540
        hits_stumps = trajectory[-1][1] < 200  # ball low enough to hit stumps
        decision = "Out" if pitching_in_line and impact_in_line and hits_stumps else "Not Out"
        logger.info("LBW Decision: %s, Pitching: %s, Impact: %s, Stumps: %s",
                    decision, pitching_in_line, impact_in_line, hits_stumps)

        # --- Plots ------------------------------------------------------------
        output_path = "drs_output.png"
        fig = plt.figure(figsize=(10, 5))
        try:
            # 3D trajectory
            ax = fig.add_subplot(121, projection='3d')
            ax.plot(trajectory[:, 0], trajectory[:, 1], trajectory[:, 2], 'r-')
            ax.set_xlabel("X (Pitch Width)")
            ax.set_ylabel("Y (Pitch Length)")
            ax.set_zlabel("Z (Height)")
            ax.set_title("3D Ball Trajectory")

            # Pitch map
            ax2 = fig.add_subplot(122)
            ax2.scatter(trajectory[:, 0], trajectory[:, 1], c='blue', marker='o')
            ax2.set_xlim(0, 640)  # match working resolution
            ax2.set_ylim(0, 360)
            ax2.set_xlabel("Pitch Width")
            ax2.set_ylabel("Pitch Length")
            ax2.set_title("Pitch Map")

            plt.savefig(output_path)
        finally:
            # BUG FIX: always release the figure, even if plotting/saving
            # raises — previously the figure leaked on the exception path.
            plt.close(fig)
        logger.info("Output saved to %s", output_path)

        return decision, output_path

    except Exception as e:
        logger.exception("Error processing video")
        return f"Error: {str(e)}", None
116
 
117
  # Gradio interface
118
  iface = gr.Interface(