NV9523 commited on
Commit
1379c9a
·
verified ·
1 Parent(s): 90a85b5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -37
app.py CHANGED
@@ -6,7 +6,7 @@ import numpy as np
6
  # ======================================================
7
  # Load YOLO model
8
  # ======================================================
9
- model = YOLO("rix_reg.pt") # change to your model
10
 
11
  def get_model_names():
12
  if hasattr(model, "names") and model.names is not None:
@@ -16,59 +16,71 @@ def get_model_names():
16
  return {}
17
 
18
  # ======================================================
19
- # Function to count all objects
20
  # ======================================================
21
- def count_objects(results, cumulative=None):
22
  names = get_model_names()
23
- counter = cumulative if cumulative is not None else {}
 
24
  for r in results:
25
  for cls_id in r.boxes.cls:
26
  cls_id = int(cls_id)
27
  label = str(names[cls_id])
28
- counter[label] = counter.get(label, 0) + 1
29
- counter["Total"] = sum(counter.get(k, 0) for k in counter if k != "Total")
 
 
 
 
 
 
30
  return counter
31
 
 
32
  # ======================================================
33
- # Image Detection
34
  # ======================================================
35
  def detect_image(img):
36
  results = model.predict(img, imgsz=640)
37
  annotated = results[0].plot()
 
38
  dashboard = count_objects(results)
39
  return annotated, dashboard
40
 
 
41
  # ======================================================
42
- # Video Detection
43
  # ======================================================
44
  def detect_video(video_path):
45
  cap = cv2.VideoCapture(video_path)
 
46
  ret, frame = cap.read()
47
  if not ret:
48
  return None, {"Error": "Cannot read video"}
 
 
49
  results = model.predict(frame, imgsz=640)
50
  annotated = results[0].plot()
 
51
  dashboard = count_objects(results)
52
  cap.release()
 
53
  return annotated, dashboard
54
 
 
55
  # ======================================================
56
- # Live Camera Detection with cumulative counting
57
  # ======================================================
58
- def init_camera():
59
- return {} # reset cumulative counts
60
-
61
- def detect_camera(frame, state):
62
  results = model.predict(frame, imgsz=640)
63
  annotated = results[0].plot()
64
- cumulative = count_objects(results, state)
65
- return annotated, cumulative, cumulative
66
 
67
- def stop_camera(state):
68
- return np.zeros((480,640,3), dtype=np.uint8), {}, state
 
69
 
70
  # ======================================================
71
- # GRADIO Interface
72
  # ======================================================
73
  with gr.Blocks(title="Rix Detection") as demo:
74
 
@@ -76,37 +88,44 @@ with gr.Blocks(title="Rix Detection") as demo:
76
 
77
  with gr.Tabs():
78
 
79
- # ==================== TAB 1: Image ====================
80
  with gr.Tab("Image Detection"):
81
  img_input = gr.Image(type="numpy", label="Upload Image")
82
  img_out = gr.Image(label="Result Image")
83
  dashboard1 = gr.JSON(label="Counts")
 
84
  btn1 = gr.Button("Detect")
85
- btn1.click(fn=detect_image, inputs=img_input, outputs=[img_out, dashboard1])
86
 
87
- # ==================== TAB 2: Video ====================
 
 
 
 
 
 
88
  with gr.Tab("Video Detection"):
89
  video_input = gr.Video(label="Upload Video")
90
  video_out = gr.Image(label="Demo Frame Result")
91
  dashboard2 = gr.JSON(label="Counts")
92
- btn2 = gr.Button("Detect Video")
93
- btn2.click(fn=detect_video, inputs=video_input, outputs=[video_out, dashboard2])
94
-
95
- # ==================== TAB 3: Live Camera ====================
96
- with gr.Tab("Live Camera"):
97
- gr.Markdown("### Select camera and start detection")
98
- cam_index = gr.Number(value=0, label="Camera Index (e.g., 0, 1)")
99
- cam_input = gr.Image(source="webcam", type="numpy", label="Camera Feed")
100
- cam_out = gr.Image(label="Live Result")
101
- dashboard3 = gr.JSON(label="Cumulative Counts")
102
- state = gr.State()
103
 
104
- start_btn = gr.Button("Start Detection")
105
- stop_btn = gr.Button("Stop Detection")
106
 
107
- start_btn.click(fn=init_camera, inputs=None, outputs=state)
108
- stop_btn.click(fn=stop_camera, inputs=state, outputs=[cam_out, dashboard3, state])
 
 
 
109
 
110
- cam_input.stream(fn=detect_camera, inputs=[cam_input, state], outputs=[cam_out, dashboard3, state])
 
 
 
 
 
 
 
 
 
 
111
 
112
  demo.launch()
 
6
  # ======================================================
7
  # Load YOLO model
8
  # ======================================================
9
+ model = YOLO("rix_reg.pt") # change to your model
10
 
11
  def get_model_names():
12
  if hasattr(model, "names") and model.names is not None:
 
16
  return {}
17
 
18
  # ======================================================
19
+ # Function to count all objects in the model
20
  # ======================================================
21
+ def count_objects(results):
22
  names = get_model_names()
23
+ counter = {}
24
+
25
  for r in results:
26
  for cls_id in r.boxes.cls:
27
  cls_id = int(cls_id)
28
  label = str(names[cls_id])
29
+
30
+ # increment count
31
+ if label not in counter:
32
+ counter[label] = 1
33
+ else:
34
+ counter[label] += 1
35
+
36
+ counter["Total"] = sum(counter.get(k, 0) for k in counter)
37
  return counter
38
 
39
+
40
  # ======================================================
41
+ # Tab 1 - Image processing
42
  # ======================================================
43
  def detect_image(img):
44
  results = model.predict(img, imgsz=640)
45
  annotated = results[0].plot()
46
+
47
  dashboard = count_objects(results)
48
  return annotated, dashboard
49
 
50
+
51
  # ======================================================
52
+ # Tab 2 - Video processing
53
  # ======================================================
54
  def detect_video(video_path):
55
  cap = cv2.VideoCapture(video_path)
56
+
57
  ret, frame = cap.read()
58
  if not ret:
59
  return None, {"Error": "Cannot read video"}
60
+
61
+ # demo first frame
62
  results = model.predict(frame, imgsz=640)
63
  annotated = results[0].plot()
64
+
65
  dashboard = count_objects(results)
66
  cap.release()
67
+
68
  return annotated, dashboard
69
 
70
+
71
  # ======================================================
72
+ # Tab 3 - Live camera
73
  # ======================================================
74
+ def detect_camera(frame):
 
 
 
75
  results = model.predict(frame, imgsz=640)
76
  annotated = results[0].plot()
 
 
77
 
78
+ dashboard = count_objects(results)
79
+ return annotated, dashboard
80
+
81
 
82
  # ======================================================
83
+ # GRADIO interface
84
  # ======================================================
85
  with gr.Blocks(title="Rix Detection") as demo:
86
 
 
88
 
89
  with gr.Tabs():
90
 
91
+ # ==================== TAB 1 ====================
92
  with gr.Tab("Image Detection"):
93
  img_input = gr.Image(type="numpy", label="Upload Image")
94
  img_out = gr.Image(label="Result Image")
95
  dashboard1 = gr.JSON(label="Counts")
96
+
97
  btn1 = gr.Button("Detect")
 
98
 
99
+ btn1.click(
100
+ fn=detect_image,
101
+ inputs=img_input,
102
+ outputs=[img_out, dashboard1]
103
+ )
104
+
105
+ # ==================== TAB 2 ====================
106
  with gr.Tab("Video Detection"):
107
  video_input = gr.Video(label="Upload Video")
108
  video_out = gr.Image(label="Demo Frame Result")
109
  dashboard2 = gr.JSON(label="Counts")
 
 
 
 
 
 
 
 
 
 
 
110
 
111
+ btn2 = gr.Button("Detect Video")
 
112
 
113
+ btn2.click(
114
+ fn=detect_video,
115
+ inputs=video_input,
116
+ outputs=[video_out, dashboard2]
117
+ )
118
 
119
+ # ==================== TAB 3 ====================
120
+ with gr.Tab("Live Camera"):
121
+ cam_input = gr.Image(sources=["webcam"], type="numpy", label="Camera")
122
+ cam_out = gr.Image(label="Real-time Result")
123
+ dashboard3 = gr.JSON(label="Counts")
124
+
125
+ cam_input.stream(
126
+ fn=detect_camera,
127
+ inputs=cam_input,
128
+ outputs=[cam_out, dashboard3]
129
+ )
130
 
131
  demo.launch()