NV9523 commited on
Commit
e1d2b3a
·
verified ·
1 Parent(s): 10b14b7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -59
app.py CHANGED
@@ -3,10 +3,7 @@ from ultralytics import YOLO
3
  import cv2
4
  import numpy as np
5
 
6
- # ======================================================
7
- # Load YOLO model
8
- # ======================================================
9
- model = YOLO("rix_reg.pt") # change to your model
10
 
11
  def get_model_names():
12
  if hasattr(model, "names") and model.names is not None:
@@ -15,117 +12,89 @@ def get_model_names():
15
  return model.model.names
16
  return {}
17
 
18
- # ======================================================
19
- # Function to count all objects in the model
20
- # ======================================================
21
- def count_objects(results):
22
  names = get_model_names()
23
- counter = {}
24
 
25
  for r in results:
26
  for cls_id in r.boxes.cls:
27
  cls_id = int(cls_id)
28
  label = str(names[cls_id])
29
-
30
- # increment count
31
  if label not in counter:
32
  counter[label] = 1
33
  else:
34
  counter[label] += 1
35
 
36
- counter["Total"] = sum(counter.get(k, 0) for k in counter)
37
  return counter
38
 
39
-
40
- # ======================================================
41
- # Tab 1 - Image processing
42
- # ======================================================
43
  def detect_image(img):
44
  results = model.predict(img, imgsz=640)
45
  annotated = results[0].plot()
46
-
47
  dashboard = count_objects(results)
48
  return annotated, dashboard
49
 
50
-
51
- # ======================================================
52
- # Tab 2 - Video processing
53
- # ======================================================
54
  def detect_video(video_path):
55
  cap = cv2.VideoCapture(video_path)
56
-
57
  ret, frame = cap.read()
58
  if not ret:
59
  return None, {"Error": "Cannot read video"}
60
-
61
- # demo first frame
62
  results = model.predict(frame, imgsz=640)
63
  annotated = results[0].plot()
64
-
65
  dashboard = count_objects(results)
66
  cap.release()
67
-
68
  return annotated, dashboard
69
 
 
 
 
 
 
70
 
71
- # ======================================================
72
- # Tab 3 - Live camera
73
- # ======================================================
74
- def detect_camera(frame):
75
  results = model.predict(frame, imgsz=640)
76
  annotated = results[0].plot()
77
 
78
- dashboard = count_objects(results)
79
- return annotated, dashboard
 
80
 
 
81
 
82
- # ======================================================
83
- # GRADIO interface
84
- # ======================================================
85
  with gr.Blocks(title="Rix Detection") as demo:
86
-
87
  gr.Markdown("## 🛠️ Object Counting Dashboard")
88
 
89
  with gr.Tabs():
90
 
91
- # ==================== TAB 1 ====================
92
  with gr.Tab("Image Detection"):
93
  img_input = gr.Image(type="numpy", label="Upload Image")
94
  img_out = gr.Image(label="Result Image")
95
  dashboard1 = gr.JSON(label="Counts")
96
-
97
  btn1 = gr.Button("Detect")
 
98
 
99
- btn1.click(
100
- fn=detect_image,
101
- inputs=img_input,
102
- outputs=[img_out, dashboard1]
103
- )
104
-
105
- # ==================== TAB 2 ====================
106
  with gr.Tab("Video Detection"):
107
  video_input = gr.Video(label="Upload Video")
108
  video_out = gr.Image(label="Demo Frame Result")
109
  dashboard2 = gr.JSON(label="Counts")
110
-
111
  btn2 = gr.Button("Detect Video")
 
112
 
113
- btn2.click(
114
- fn=detect_video,
115
- inputs=video_input,
116
- outputs=[video_out, dashboard2]
117
- )
118
-
119
- # ==================== TAB 3 ====================
120
  with gr.Tab("Live Camera"):
121
  cam_input = gr.Image(sources=["webcam"], type="numpy", label="Camera")
122
  cam_out = gr.Image(label="Real-time Result")
123
- dashboard3 = gr.JSON(label="Counts")
 
 
 
 
124
 
125
- cam_input.stream(
126
- fn=detect_camera,
127
- inputs=cam_input,
128
- outputs=[cam_out, dashboard3]
129
- )
130
 
131
  demo.launch()
 
3
  import cv2
4
  import numpy as np
5
 
6
+ model = YOLO("rix_reg.pt")
 
 
 
7
 
8
  def get_model_names():
9
  if hasattr(model, "names") and model.names is not None:
 
12
  return model.model.names
13
  return {}
14
 
15
+ # Function to count objects in a single frame
16
+ def count_objects(results, cumulative=None):
 
 
17
  names = get_model_names()
18
+ counter = cumulative if cumulative is not None else {}
19
 
20
  for r in results:
21
  for cls_id in r.boxes.cls:
22
  cls_id = int(cls_id)
23
  label = str(names[cls_id])
 
 
24
  if label not in counter:
25
  counter[label] = 1
26
  else:
27
  counter[label] += 1
28
 
29
+ counter["Total"] = sum(counter.get(k, 0) for k in counter if k != "Total")
30
  return counter
31
 
32
+ # Image detection (unchanged)
 
 
 
33
  def detect_image(img):
34
  results = model.predict(img, imgsz=640)
35
  annotated = results[0].plot()
 
36
  dashboard = count_objects(results)
37
  return annotated, dashboard
38
 
39
+ # Video detection (unchanged)
 
 
 
40
  def detect_video(video_path):
41
  cap = cv2.VideoCapture(video_path)
 
42
  ret, frame = cap.read()
43
  if not ret:
44
  return None, {"Error": "Cannot read video"}
 
 
45
  results = model.predict(frame, imgsz=640)
46
  annotated = results[0].plot()
 
47
  dashboard = count_objects(results)
48
  cap.release()
 
49
  return annotated, dashboard
50
 
51
+ # ====================
52
+ # Live camera detection with cumulative counting
53
+ # ====================
54
+ def start_camera(_):
55
+ return np.zeros((480, 640, 3), dtype=np.uint8), {}
56
 
57
+ def detect_camera(frame, state):
 
 
 
58
  results = model.predict(frame, imgsz=640)
59
  annotated = results[0].plot()
60
 
61
+ cumulative = state.get("cumulative", {})
62
+ cumulative = count_objects(results, cumulative)
63
+ state["cumulative"] = cumulative
64
 
65
+ return annotated, cumulative, state
66
 
 
 
 
67
  with gr.Blocks(title="Rix Detection") as demo:
 
68
  gr.Markdown("## 🛠️ Object Counting Dashboard")
69
 
70
  with gr.Tabs():
71
 
72
+ # Tab 1: Image
73
  with gr.Tab("Image Detection"):
74
  img_input = gr.Image(type="numpy", label="Upload Image")
75
  img_out = gr.Image(label="Result Image")
76
  dashboard1 = gr.JSON(label="Counts")
 
77
  btn1 = gr.Button("Detect")
78
+ btn1.click(fn=detect_image, inputs=img_input, outputs=[img_out, dashboard1])
79
 
80
+ # Tab 2: Video
 
 
 
 
 
 
81
  with gr.Tab("Video Detection"):
82
  video_input = gr.Video(label="Upload Video")
83
  video_out = gr.Image(label="Demo Frame Result")
84
  dashboard2 = gr.JSON(label="Counts")
 
85
  btn2 = gr.Button("Detect Video")
86
+ btn2.click(fn=detect_video, inputs=video_input, outputs=[video_out, dashboard2])
87
 
88
+ # Tab 3: Live Camera
 
 
 
 
 
 
89
  with gr.Tab("Live Camera"):
90
  cam_input = gr.Image(sources=["webcam"], type="numpy", label="Camera")
91
  cam_out = gr.Image(label="Real-time Result")
92
+ dashboard3 = gr.JSON(label="Cumulative Counts")
93
+ state = gr.State()
94
+
95
+ start_btn = gr.Button("Start Detection")
96
+ start_btn.click(fn=start_camera, inputs=None, outputs=[cam_out, dashboard3])
97
 
98
+ cam_input.stream(fn=detect_camera, inputs=[cam_input, state], outputs=[cam_out, dashboard3, state])
 
 
 
 
99
 
100
  demo.launch()