NV9523 commited on
Commit
90a85b5
·
verified ·
1 Parent(s): af0a2d0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -41
app.py CHANGED
@@ -3,7 +3,10 @@ from ultralytics import YOLO
3
  import cv2
4
  import numpy as np
5
 
6
- model = YOLO("rix_reg.pt")
 
 
 
7
 
8
  def get_model_names():
9
  if hasattr(model, "names") and model.names is not None:
@@ -12,87 +15,98 @@ def get_model_names():
12
  return model.model.names
13
  return {}
14
 
 
 
 
15
  def count_objects(results, cumulative=None):
16
  names = get_model_names()
17
  counter = cumulative if cumulative is not None else {}
18
-
19
  for r in results:
20
  for cls_id in r.boxes.cls:
21
  cls_id = int(cls_id)
22
  label = str(names[cls_id])
23
  counter[label] = counter.get(label, 0) + 1
24
-
25
  counter["Total"] = sum(counter.get(k, 0) for k in counter if k != "Total")
26
  return counter
27
 
28
- # Initialize camera state
29
- def init_camera(camera_index):
30
- cap = cv2.VideoCapture(camera_index)
31
- return {"cap": cap, "cumulative": {}}
32
-
33
- # Release camera
34
- def stop_camera(state):
35
- cap = state.get("cap")
36
- if cap is not None:
37
- cap.release()
38
- return np.zeros((480,640,3), dtype=np.uint8), {}, state
39
-
40
- # Detect frame and update cumulative count
41
- def detect_camera(state):
42
- cap = state.get("cap")
43
- if cap is None or not cap.isOpened():
44
- return np.zeros((480,640,3), dtype=np.uint8), state.get("cumulative", {}), state
45
-
46
  ret, frame = cap.read()
47
  if not ret:
48
- return np.zeros((480,640,3), dtype=np.uint8), state.get("cumulative", {}), state
 
 
 
 
 
49
 
 
 
 
 
 
 
 
50
  results = model.predict(frame, imgsz=640)
51
  annotated = results[0].plot()
52
- cumulative = count_objects(results, state.get("cumulative", {}))
53
- state["cumulative"] = cumulative
54
 
55
- return annotated, cumulative, state
 
56
 
 
 
 
57
  with gr.Blocks(title="Rix Detection") as demo:
 
58
  gr.Markdown("## 🛠️ Object Counting Dashboard")
59
 
60
  with gr.Tabs():
61
 
62
- # Tab 1: Image
63
  with gr.Tab("Image Detection"):
64
  img_input = gr.Image(type="numpy", label="Upload Image")
65
  img_out = gr.Image(label="Result Image")
66
  dashboard1 = gr.JSON(label="Counts")
67
  btn1 = gr.Button("Detect")
68
- btn1.click(fn=lambda img: detect_image(img), inputs=img_input, outputs=[img_out, dashboard1])
69
 
70
- # Tab 2: Video
71
  with gr.Tab("Video Detection"):
72
  video_input = gr.Video(label="Upload Video")
73
  video_out = gr.Image(label="Demo Frame Result")
74
  dashboard2 = gr.JSON(label="Counts")
75
  btn2 = gr.Button("Detect Video")
76
- btn2.click(fn=lambda vid: detect_video(vid), inputs=video_input, outputs=[video_out, dashboard2])
77
 
78
- # Tab 3: Live Camera
79
  with gr.Tab("Live Camera"):
80
- camera_index = gr.Number(value=0, label="Camera Index")
 
 
81
  cam_out = gr.Image(label="Live Result")
82
  dashboard3 = gr.JSON(label="Cumulative Counts")
83
  state = gr.State()
84
 
85
- start_btn = gr.Button("Start Camera")
86
- stop_btn = gr.Button("Stop Camera")
87
 
88
- start_btn.click(fn=init_camera, inputs=camera_index, outputs=state)
89
  stop_btn.click(fn=stop_camera, inputs=state, outputs=[cam_out, dashboard3, state])
90
-
91
- cam_out.update(
92
- every=0.1,
93
- fn=detect_camera,
94
- inputs=state,
95
- outputs=[cam_out, dashboard3, state]
96
- )
97
 
98
  demo.launch()
 
3
  import cv2
4
  import numpy as np
5
 
6
+ # ======================================================
7
+ # Load YOLO model
8
+ # ======================================================
9
+ model = YOLO("rix_reg.pt") # change to your model
10
 
11
  def get_model_names():
12
  if hasattr(model, "names") and model.names is not None:
 
15
  return model.model.names
16
  return {}
17
 
18
+ # ======================================================
19
+ # Function to count all objects
20
+ # ======================================================
21
  def count_objects(results, cumulative=None):
22
  names = get_model_names()
23
  counter = cumulative if cumulative is not None else {}
 
24
  for r in results:
25
  for cls_id in r.boxes.cls:
26
  cls_id = int(cls_id)
27
  label = str(names[cls_id])
28
  counter[label] = counter.get(label, 0) + 1
 
29
  counter["Total"] = sum(counter.get(k, 0) for k in counter if k != "Total")
30
  return counter
31
 
32
+ # ======================================================
33
+ # Image Detection
34
+ # ======================================================
35
+ def detect_image(img):
36
+ results = model.predict(img, imgsz=640)
37
+ annotated = results[0].plot()
38
+ dashboard = count_objects(results)
39
+ return annotated, dashboard
40
+
41
+ # ======================================================
42
+ # Video Detection
43
+ # ======================================================
44
+ def detect_video(video_path):
45
+ cap = cv2.VideoCapture(video_path)
 
 
 
 
46
  ret, frame = cap.read()
47
  if not ret:
48
+ return None, {"Error": "Cannot read video"}
49
+ results = model.predict(frame, imgsz=640)
50
+ annotated = results[0].plot()
51
+ dashboard = count_objects(results)
52
+ cap.release()
53
+ return annotated, dashboard
54
 
55
+ # ======================================================
56
+ # Live Camera Detection with cumulative counting
57
+ # ======================================================
58
+ def init_camera():
59
+ return {} # reset cumulative counts
60
+
61
+ def detect_camera(frame, state):
62
  results = model.predict(frame, imgsz=640)
63
  annotated = results[0].plot()
64
+ cumulative = count_objects(results, state)
65
+ return annotated, cumulative, cumulative
66
 
67
+ def stop_camera(state):
68
+ return np.zeros((480,640,3), dtype=np.uint8), {}, state
69
 
70
+ # ======================================================
71
+ # GRADIO Interface
72
+ # ======================================================
73
  with gr.Blocks(title="Rix Detection") as demo:
74
+
75
  gr.Markdown("## 🛠️ Object Counting Dashboard")
76
 
77
  with gr.Tabs():
78
 
79
+ # ==================== TAB 1: Image ====================
80
  with gr.Tab("Image Detection"):
81
  img_input = gr.Image(type="numpy", label="Upload Image")
82
  img_out = gr.Image(label="Result Image")
83
  dashboard1 = gr.JSON(label="Counts")
84
  btn1 = gr.Button("Detect")
85
+ btn1.click(fn=detect_image, inputs=img_input, outputs=[img_out, dashboard1])
86
 
87
+ # ==================== TAB 2: Video ====================
88
  with gr.Tab("Video Detection"):
89
  video_input = gr.Video(label="Upload Video")
90
  video_out = gr.Image(label="Demo Frame Result")
91
  dashboard2 = gr.JSON(label="Counts")
92
  btn2 = gr.Button("Detect Video")
93
+ btn2.click(fn=detect_video, inputs=video_input, outputs=[video_out, dashboard2])
94
 
95
+ # ==================== TAB 3: Live Camera ====================
96
  with gr.Tab("Live Camera"):
97
+ gr.Markdown("### Select camera and start detection")
98
+ cam_index = gr.Number(value=0, label="Camera Index (e.g., 0, 1)")
99
+ cam_input = gr.Image(source="webcam", type="numpy", label="Camera Feed")
100
  cam_out = gr.Image(label="Live Result")
101
  dashboard3 = gr.JSON(label="Cumulative Counts")
102
  state = gr.State()
103
 
104
+ start_btn = gr.Button("Start Detection")
105
+ stop_btn = gr.Button("Stop Detection")
106
 
107
+ start_btn.click(fn=init_camera, inputs=None, outputs=state)
108
  stop_btn.click(fn=stop_camera, inputs=state, outputs=[cam_out, dashboard3, state])
109
+
110
+ cam_input.stream(fn=detect_camera, inputs=[cam_input, state], outputs=[cam_out, dashboard3, state])
 
 
 
 
 
111
 
112
  demo.launch()