NV9523 commited on
Commit
af0a2d0
·
verified ·
1 Parent(s): e1d2b3a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -38
app.py CHANGED
@@ -12,7 +12,6 @@ def get_model_names():
12
  return model.model.names
13
  return {}
14
 
15
- # Function to count objects in a single frame
16
  def count_objects(results, cumulative=None):
17
  names = get_model_names()
18
  counter = cumulative if cumulative is not None else {}
@@ -21,45 +20,36 @@ def count_objects(results, cumulative=None):
21
  for cls_id in r.boxes.cls:
22
  cls_id = int(cls_id)
23
  label = str(names[cls_id])
24
- if label not in counter:
25
- counter[label] = 1
26
- else:
27
- counter[label] += 1
28
 
29
  counter["Total"] = sum(counter.get(k, 0) for k in counter if k != "Total")
30
  return counter
31
 
32
- # Image detection (unchanged)
33
- def detect_image(img):
34
- results = model.predict(img, imgsz=640)
35
- annotated = results[0].plot()
36
- dashboard = count_objects(results)
37
- return annotated, dashboard
 
 
 
 
 
 
 
 
 
 
 
38
 
39
- # Video detection (unchanged)
40
- def detect_video(video_path):
41
- cap = cv2.VideoCapture(video_path)
42
  ret, frame = cap.read()
43
  if not ret:
44
- return None, {"Error": "Cannot read video"}
45
- results = model.predict(frame, imgsz=640)
46
- annotated = results[0].plot()
47
- dashboard = count_objects(results)
48
- cap.release()
49
- return annotated, dashboard
50
-
51
- # ====================
52
- # Live camera detection with cumulative counting
53
- # ====================
54
- def start_camera(_):
55
- return np.zeros((480, 640, 3), dtype=np.uint8), {}
56
 
57
- def detect_camera(frame, state):
58
  results = model.predict(frame, imgsz=640)
59
  annotated = results[0].plot()
60
-
61
- cumulative = state.get("cumulative", {})
62
- cumulative = count_objects(results, cumulative)
63
  state["cumulative"] = cumulative
64
 
65
  return annotated, cumulative, state
@@ -75,7 +65,7 @@ with gr.Blocks(title="Rix Detection") as demo:
75
  img_out = gr.Image(label="Result Image")
76
  dashboard1 = gr.JSON(label="Counts")
77
  btn1 = gr.Button("Detect")
78
- btn1.click(fn=detect_image, inputs=img_input, outputs=[img_out, dashboard1])
79
 
80
  # Tab 2: Video
81
  with gr.Tab("Video Detection"):
@@ -83,18 +73,26 @@ with gr.Blocks(title="Rix Detection") as demo:
83
  video_out = gr.Image(label="Demo Frame Result")
84
  dashboard2 = gr.JSON(label="Counts")
85
  btn2 = gr.Button("Detect Video")
86
- btn2.click(fn=detect_video, inputs=video_input, outputs=[video_out, dashboard2])
87
 
88
  # Tab 3: Live Camera
89
  with gr.Tab("Live Camera"):
90
- cam_input = gr.Image(sources=["webcam"], type="numpy", label="Camera")
91
- cam_out = gr.Image(label="Real-time Result")
92
  dashboard3 = gr.JSON(label="Cumulative Counts")
93
  state = gr.State()
94
 
95
- start_btn = gr.Button("Start Detection")
96
- start_btn.click(fn=start_camera, inputs=None, outputs=[cam_out, dashboard3])
97
-
98
- cam_input.stream(fn=detect_camera, inputs=[cam_input, state], outputs=[cam_out, dashboard3, state])
 
 
 
 
 
 
 
 
99
 
100
  demo.launch()
 
12
  return model.model.names
13
  return {}
14
 
 
15
  def count_objects(results, cumulative=None):
16
  names = get_model_names()
17
  counter = cumulative if cumulative is not None else {}
 
20
  for cls_id in r.boxes.cls:
21
  cls_id = int(cls_id)
22
  label = str(names[cls_id])
23
+ counter[label] = counter.get(label, 0) + 1
 
 
 
24
 
25
  counter["Total"] = sum(counter.get(k, 0) for k in counter if k != "Total")
26
  return counter
27
 
28
+ # Initialize camera state
29
+ def init_camera(camera_index):
30
+ cap = cv2.VideoCapture(camera_index)
31
+ return {"cap": cap, "cumulative": {}}
32
+
33
+ # Release camera
34
+ def stop_camera(state):
35
+ cap = state.get("cap")
36
+ if cap is not None:
37
+ cap.release()
38
+ return np.zeros((480,640,3), dtype=np.uint8), {}, state
39
+
40
+ # Detect frame and update cumulative count
41
+ def detect_camera(state):
42
+ cap = state.get("cap")
43
+ if cap is None or not cap.isOpened():
44
+ return np.zeros((480,640,3), dtype=np.uint8), state.get("cumulative", {}), state
45
 
 
 
 
46
  ret, frame = cap.read()
47
  if not ret:
48
+ return np.zeros((480,640,3), dtype=np.uint8), state.get("cumulative", {}), state
 
 
 
 
 
 
 
 
 
 
 
49
 
 
50
  results = model.predict(frame, imgsz=640)
51
  annotated = results[0].plot()
52
+ cumulative = count_objects(results, state.get("cumulative", {}))
 
 
53
  state["cumulative"] = cumulative
54
 
55
  return annotated, cumulative, state
 
65
  img_out = gr.Image(label="Result Image")
66
  dashboard1 = gr.JSON(label="Counts")
67
  btn1 = gr.Button("Detect")
68
+ btn1.click(fn=lambda img: detect_image(img), inputs=img_input, outputs=[img_out, dashboard1])
69
 
70
  # Tab 2: Video
71
  with gr.Tab("Video Detection"):
 
73
  video_out = gr.Image(label="Demo Frame Result")
74
  dashboard2 = gr.JSON(label="Counts")
75
  btn2 = gr.Button("Detect Video")
76
+ btn2.click(fn=lambda vid: detect_video(vid), inputs=video_input, outputs=[video_out, dashboard2])
77
 
78
  # Tab 3: Live Camera
79
  with gr.Tab("Live Camera"):
80
+ camera_index = gr.Number(value=0, label="Camera Index")
81
+ cam_out = gr.Image(label="Live Result")
82
  dashboard3 = gr.JSON(label="Cumulative Counts")
83
  state = gr.State()
84
 
85
+ start_btn = gr.Button("Start Camera")
86
+ stop_btn = gr.Button("Stop Camera")
87
+
88
+ start_btn.click(fn=init_camera, inputs=camera_index, outputs=state)
89
+ stop_btn.click(fn=stop_camera, inputs=state, outputs=[cam_out, dashboard3, state])
90
+
91
+ cam_out.update(
92
+ every=0.1,
93
+ fn=detect_camera,
94
+ inputs=state,
95
+ outputs=[cam_out, dashboard3, state]
96
+ )
97
 
98
  demo.launch()