viswanani commited on
Commit
69c9ae2
·
verified ·
1 Parent(s): 5dbb53a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -86
app.py CHANGED
@@ -1,93 +1,24 @@
1
  import gradio as gr
2
- import cv2
3
- import numpy as np
4
- import tempfile
5
- import os
6
 
7
- # === CONFIG ===
8
- FRAME_SKIP = 3
9
- MAX_FRAMES = 200
10
- RESIZE_DIM = (640, 360)
11
 
12
- def track_ball_in_frame(frame):
13
- hsv = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV)
14
-
15
- # 🎯 Color thresholds for red ball
16
- lower_red1 = np.array([0, 100, 100])
17
- upper_red1 = np.array([10, 255, 255])
18
- lower_red2 = np.array([160, 100, 100])
19
- upper_red2 = np.array([180, 255, 255])
20
 
21
- mask1 = cv2.inRange(hsv, lower_red1, upper_red1)
22
- mask2 = cv2.inRange(hsv, lower_red2, upper_red2)
23
- mask = mask1 | mask2
24
- mask = cv2.GaussianBlur(mask, (5, 5), 0)
25
-
26
- contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
27
-
28
- if contours:
29
- largest = max(contours, key=cv2.contourArea)
30
- if cv2.contourArea(largest) > 20:
31
- (x, y), radius = cv2.minEnclosingCircle(largest)
32
- center = (int(x), int(y))
33
- radius = int(radius)
34
- cv2.circle(frame, center, radius, (0, 255, 0), 2)
35
- cv2.putText(frame, "Ball", (center[0]+10, center[1]), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
36
-
37
- return frame
38
-
39
- def process_frame(frame, i):
40
- return track_ball_in_frame(frame)
41
-
42
- def analyze_video(video_path):
43
- cap = cv2.VideoCapture(video_path)
44
- frames = []
45
- frame_idx = 0
46
- processed_count = 0
47
-
48
- while cap.isOpened():
49
- ret, frame = cap.read()
50
- if not ret:
51
- break
52
-
53
- frame_idx += 1
54
- if frame_idx % FRAME_SKIP != 0:
55
- continue
56
-
57
- frame = cv2.resize(frame, RESIZE_DIM)
58
- processed_frame = process_frame(frame, frame_idx)
59
-
60
- frames.append(processed_frame)
61
- processed_count += 1
62
-
63
- if processed_count >= MAX_FRAMES:
64
- break
65
-
66
- cap.release()
67
-
68
- output_path = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4").name
69
- height, width, _ = frames[0].shape
70
- out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), 15, (width, height))
71
-
72
- for f in frames:
73
- out.write(f)
74
- out.release()
75
-
76
- return output_path, f"✔️ Processed {processed_count} frames from {frame_idx} total."
77
-
78
- # === Gradio UI ===
79
  with gr.Blocks() as demo:
80
- gr.Markdown("## 🏏 GullyDRS - LBW Predictor with Ball Tracking")
81
- gr.Markdown("Upload a short cricket delivery video (.mp4) to detect and track the ball!")
82
-
83
- with gr.Row():
84
- video_input = gr.File(file_types=[".mp4"], type="filepath", label="Upload Cricket Delivery (.mp4)")
85
- submit_btn = gr.Button("Track Ball")
86
-
87
- with gr.Row():
88
- video_output = gr.Video(label="Tracked Ball Output (.mp4)")
89
- status_output = gr.Textbox(label="Processing Status")
90
-
91
- submit_btn.click(fn=analyze_video, inputs=video_input, outputs=[video_output, status_output])
92
 
93
  demo.launch()
 
1
  import gradio as gr
2
+ from model.predictor import LBWPredictor
3
+ from utils.preprocess import clean_input
 
 
4
 
5
+ predictor = LBWPredictor("model/lbw_model.joblib")
 
 
 
6
 
7
+ def predict_interface(**kwargs):
8
+ features = clean_input(kwargs)
9
+ return predictor.predict(features)
 
 
 
 
 
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  with gr.Blocks() as demo:
12
+ gr.Markdown("# GullyDRS: LBW Predictor 🏏")
13
+ inputs = [
14
+ gr.Number(label="Ball Speed (km/h)", value=130),
15
+ gr.Number(label="Impact X", value=0.25),
16
+ gr.Number(label="Impact Y", value=0.5),
17
+ gr.Number(label="Stump Height", value=0.71)
18
+ ]
19
+ btn = gr.Button("Predict LBW")
20
+ output = gr.Textbox(label="Decision")
21
+
22
+ btn.click(predict_interface, inputs=inputs, outputs=output)
 
23
 
24
  demo.launch()