Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -5,76 +5,114 @@ import torch
|
|
| 5 |
from transformers import DetrForObjectDetection, DetrImageProcessor
|
| 6 |
import matplotlib.pyplot as plt
|
| 7 |
from mpl_toolkits.mplot3d import Axes3D
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
# Load pre-trained model for ball detection
|
| 10 |
-
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
def process_drs_video(video_path):
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
break
|
| 21 |
-
frames.append(frame)
|
| 22 |
-
cap.release()
|
| 23 |
-
|
| 24 |
-
# Detect ball in frames
|
| 25 |
-
ball_positions = []
|
| 26 |
-
for frame in frames:
|
| 27 |
-
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
| 28 |
-
inputs = processor(images=frame_rgb, return_tensors="pt")
|
| 29 |
-
outputs = model(**inputs)
|
| 30 |
-
target_sizes = torch.tensor([frame_rgb.shape[:2]])
|
| 31 |
-
results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
|
| 32 |
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
z = 100 # Replace with actual depth model or pitch mapping
|
| 39 |
-
ball_positions.append([x_center.item(), y_center.item(), z])
|
| 40 |
break
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
|
| 48 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
pitching_in_line = trajectory[0][0] > 200 and trajectory[0][0] < 520 # Placeholder: Adjust based on pitch dimensions
|
| 53 |
-
impact_in_line = trajectory[-1][0] > 200 and trajectory[-1][0] < 520 # Check impact near stumps
|
| 54 |
-
hits_stumps = trajectory[-1][1] < 300 # Simplified: Ball low enough to hit stumps
|
| 55 |
-
decision = "Out" if pitching_in_line and impact_in_line and hits_stumps else "Not Out"
|
| 56 |
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
|
| 66 |
-
|
| 67 |
-
ax2 = fig.add_subplot(122)
|
| 68 |
-
ax2.scatter(trajectory[:, 0], trajectory[:, 1], c='blue', marker='o')
|
| 69 |
-
ax2.set_xlim(0, 720) # Adjust based on video resolution or pitch dimensions
|
| 70 |
-
ax2.set_ylim(0, 1280)
|
| 71 |
-
ax2.set_xlabel("Pitch Width")
|
| 72 |
-
ax2.set_ylabel("Pitch Length")
|
| 73 |
-
ax2.set_title("Pitch Map")
|
| 74 |
-
plt.savefig("drs_output.png")
|
| 75 |
-
plt.close()
|
| 76 |
|
| 77 |
-
|
|
|
|
|
|
|
| 78 |
|
| 79 |
# Gradio interface
|
| 80 |
iface = gr.Interface(
|
|
|
|
| 5 |
from transformers import DetrForObjectDetection, DetrImageProcessor
|
| 6 |
import matplotlib.pyplot as plt
|
| 7 |
from mpl_toolkits.mplot3d import Axes3D
|
| 8 |
+
import logging
|
| 9 |
+
|
| 10 |
+
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load pre-trained model for ball detection.
# The checkpoint name and cache directory are named once so the processor and
# the model are always loaded from the same place.
MODEL_NAME = "facebook/detr-resnet-50"
MODEL_CACHE_DIR = "/home/user/app/cache"

try:
    processor = DetrImageProcessor.from_pretrained(MODEL_NAME, cache_dir=MODEL_CACHE_DIR)
    model = DetrForObjectDetection.from_pretrained(MODEL_NAME, cache_dir=MODEL_CACHE_DIR)
except Exception as e:
    # logger.exception keeps the traceback in the log; re-raise so the app
    # fails at startup instead of running without a model.
    logger.exception("Error loading model: %s", e)
    raise
else:
    logger.info("Model and processor loaded successfully")
|
| 22 |
|
| 23 |
# --- Tunables for process_drs_video -----------------------------------------
_MAX_FRAMES = 100             # cap on decoded frames, to bound memory use
_FRAME_SIZE = (640, 360)      # (width, height) every frame is resized to
_FRAME_STRIDE = 2             # run detection on every Nth extracted frame
_DETECTION_THRESHOLD = 0.7    # minimum score kept by DETR post-processing
# NOTE(review): assumes the loaded checkpoint's id2label uses the COCO name
# "sports ball" for balls - confirm against model.config.id2label.
_BALL_LABEL = "sports ball"
_PLACEHOLDER_DEPTH = 100      # placeholder z; replace with a depth model if available
_LINE_X_MIN, _LINE_X_MAX = 100, 540   # x band treated as "in line" (640x360 frames)
_STUMP_Y_MAX = 200            # final y below this counts as hitting the stumps
_OUTPUT_PATH = "drs_output.png"


def _read_frames(video_path):
    """Return up to _MAX_FRAMES resized frames, or None if the video cannot be opened."""
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            return None
        frames = []
        while len(frames) < _MAX_FRAMES:
            ret, frame = cap.read()
            if not ret:
                break
            # Resize frame to reduce memory usage
            frames.append(cv2.resize(frame, _FRAME_SIZE))
        return frames
    finally:
        # Release the capture even if read/resize raised.
        cap.release()


def _detect_ball(frame, frame_number):
    """Return [x_center, y_center, z] of the best ball detection in a frame, or None."""
    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    inputs = processor(images=frame_rgb, return_tensors="pt")
    with torch.no_grad():  # Disable gradients for inference
        outputs = model(**inputs)

    target_sizes = torch.tensor([frame_rgb.shape[:2]])
    results = processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=_DETECTION_THRESHOLD
    )[0]

    # Keep only detections whose class is the ball, and pick the most
    # confident one; otherwise any detected object would be taken as the ball.
    id2label = model.config.id2label
    best_score, best_box = None, None
    for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
        if id2label.get(int(label)) != _BALL_LABEL:
            continue
        score = float(score)
        if best_score is None or score > best_score:
            best_score, best_box = score, box.tolist()

    if best_box is None:
        logger.warning("No ball detected in frame %d", frame_number)
        return None

    logger.info("Ball detected in frame %d: score=%.2f, box=%s", frame_number, best_score, best_box)
    x_min, y_min, x_max, y_max = best_box
    return [(x_min + x_max) / 2, (y_min + y_max) / 2, _PLACEHOLDER_DEPTH]


def _lbw_decision(trajectory):
    """Apply the simplified LBW rules to the first and last trajectory points."""
    pitching_in_line = _LINE_X_MIN < trajectory[0][0] < _LINE_X_MAX
    impact_in_line = _LINE_X_MIN < trajectory[-1][0] < _LINE_X_MAX
    hits_stumps = trajectory[-1][1] < _STUMP_Y_MAX
    decision = "Out" if pitching_in_line and impact_in_line and hits_stumps else "Not Out"
    logger.info(
        "LBW Decision: %s, Pitching: %s, Impact: %s, Stumps: %s",
        decision, pitching_in_line, impact_in_line, hits_stumps,
    )
    return decision


def _save_plots(trajectory):
    """Draw the 3D trajectory and the pitch map side by side and save them to _OUTPUT_PATH."""
    fig = plt.figure(figsize=(10, 5))
    try:
        # 3D Trajectory Plot
        ax = fig.add_subplot(121, projection='3d')
        ax.plot(trajectory[:, 0], trajectory[:, 1], trajectory[:, 2], 'r-')
        ax.set_xlabel("X (Pitch Width)")
        ax.set_ylabel("Y (Pitch Length)")
        ax.set_zlabel("Z (Height)")
        ax.set_title("3D Ball Trajectory")

        # Pitch Map Plot
        ax2 = fig.add_subplot(122)
        ax2.scatter(trajectory[:, 0], trajectory[:, 1], c='blue', marker='o')
        ax2.set_xlim(0, _FRAME_SIZE[0])  # Match video resolution
        ax2.set_ylim(0, _FRAME_SIZE[1])
        ax2.set_xlabel("Pitch Width")
        ax2.set_ylabel("Pitch Length")
        ax2.set_title("Pitch Map")

        fig.savefig(_OUTPUT_PATH)
    finally:
        # Close this figure even when plotting or saving fails.
        plt.close(fig)
    logger.info("Output saved to %s", _OUTPUT_PATH)
    return _OUTPUT_PATH


def process_drs_video(video_path):
    """Run the simplified DRS pipeline on a video and return the LBW verdict.

    Steps: read and resize up to ``_MAX_FRAMES`` frames, run DETR on every
    ``_FRAME_STRIDE``-th frame to locate the ball, build a trajectory from the
    detections, apply the placeholder LBW rules, and save a two-panel plot.

    Args:
        video_path: Path of the video file, passed to ``cv2.VideoCapture``.

    Returns:
        A ``(message, image_path)`` tuple. On success ``message`` is ``"Out"``
        or ``"Not Out"`` and ``image_path`` is the saved plot. On any failure
        ``message`` starts with ``"Error: "`` and ``image_path`` is ``None``;
        exceptions are logged and reported this way rather than raised.
    """
    try:
        # Extract frames from video
        frames = _read_frames(video_path)
        if frames is None:
            logger.error("Failed to open video file")
            return "Error: Could not open video file", None
        logger.info("Extracted %d frames from video", len(frames))

        if not frames:
            logger.error("No frames extracted from video")
            return "Error: No frames extracted from video", None

        # Detect ball in frames. frame_number is the index among the extracted
        # frames, so log lines refer to the actual frame, not the strided index.
        ball_positions = []
        for frame_number in range(0, len(frames), _FRAME_STRIDE):
            position = _detect_ball(frames[frame_number], frame_number)
            if position is not None:
                ball_positions.append(position)

        if not ball_positions:
            logger.error("No ball detected in any frame")
            return "Error: No ball detected in video", None

        trajectory = np.array(ball_positions)
        logger.info("Trajectory shape: %s", trajectory.shape)

        decision = _lbw_decision(trajectory)
        output_path = _save_plots(trajectory)
        return decision, output_path

    except Exception as e:
        # Boundary handler for the UI: log with traceback and return the error
        # as text.
        logger.exception("Error processing video: %s", e)
        return f"Error: {str(e)}", None
|
| 116 |
|
| 117 |
# Gradio interface
|
| 118 |
iface = gr.Interface(
|