Arizal Firdaus Bagus Pratama committed on
Commit
e1cbdac
·
verified ·
1 Parent(s): c1c1fef

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -25
app.py CHANGED
@@ -9,10 +9,11 @@ import gradio as gr
9
  import os
10
 
11
  # Import the Sort class from the local 'sort.py' file
 
12
  from sort import Sort
13
 
14
  # --- LOAD MODELS AND TRACKER ONCE (PENTING!) ---
15
- # This part runs only once when the app starts, so we don't reload the model for every user.
16
  print("Loading model and processor...")
17
  model_checkpoint = "facebook/detr-resnet-50"
18
  image_processor = AutoImageProcessor.from_pretrained(model_checkpoint)
@@ -26,7 +27,7 @@ print("Model loaded successfully.")
26
  # ---------------------------------------------------
27
 
28
  def iou(boxA, boxB):
29
- # Standard IoU calculation
30
  xA = max(boxA[0], boxB[0])
31
  yA = max(boxA[1], boxB[1])
32
  xB = min(boxA[2], boxB[2])
@@ -37,33 +38,41 @@ def iou(boxA, boxB):
37
  iou_score = interArea / float(boxAArea + boxBArea - interArea)
38
  return iou_score
39
 
40
- # --- THE MAIN PROCESSING FUNCTION ---
41
  def process_video(input_video_path):
42
- # Initialize tracker and counters for each new video
43
- tracker = Sort(min_hits=1, iou_threshold=0.3)
44
  total_counts = {'person': 0, 'bicycle': 0, 'car': 0, 'motorcycle': 0}
45
  counted_ids = set()
46
 
47
- # Define the output path for the processed video
48
  output_video_path = "output.mp4"
49
 
50
  cap = cv2.VideoCapture(input_video_path)
51
  if not cap.isOpened():
52
  raise gr.Error(f"Could not open video file.")
53
 
54
- frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
55
- frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
56
  fps = int(cap.get(cv2.CAP_PROP_FPS))
57
 
58
- # Use 'mp4v' codec which is widely compatible
59
- out = cv2.VideoWriter(output_video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (frame_width, frame_height))
60
-
 
 
 
 
 
61
  while True:
62
  ret, frame = cap.read()
63
  if not ret:
64
  break
65
 
66
- # --- (Logic from our notebook goes here) ---
 
 
 
 
 
67
  pil_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
68
  inputs = image_processor(images=pil_image, return_tensors="pt").to(device)
69
  with torch.no_grad():
@@ -71,6 +80,7 @@ def process_video(input_video_path):
71
  target_sizes = torch.tensor([pil_image.size[::-1]])
72
  results = image_processor.post_process_object_detection(outputs, threshold=0.6, target_sizes=target_sizes)[0]
73
 
 
74
  detections_for_sort = []
75
  original_detections = []
76
  for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
@@ -80,12 +90,15 @@ def process_video(input_video_path):
80
  detections_for_sort.append([box_list[0], box_list[1], box_list[2], box_list[3], score.item()])
81
  original_detections.append({'box': box_list, 'label': label_name})
82
 
 
83
  tracked_objects_raw = []
84
  if len(detections_for_sort) > 0:
85
  tracked_objects_raw = tracker.update(np.array(detections_for_sort))
86
 
 
87
  for obj in tracked_objects_raw:
88
  x1, y1, x2, y2, obj_id = [int(val) for val in obj]
 
89
  best_iou = 0
90
  best_label = None
91
  for det in original_detections:
@@ -94,6 +107,7 @@ def process_video(input_video_path):
94
  best_iou = iou_score
95
  best_label = det['label']
96
 
 
97
  if best_label and obj_id not in counted_ids:
98
  total_counts[best_label] += 1
99
  counted_ids.add(obj_id)
@@ -102,6 +116,7 @@ def process_video(input_video_path):
102
  cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
103
  cv2.putText(frame, f'{best_label} ID: {obj_id}', (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
104
 
 
105
  y_offset = 30
106
  for obj_name, count in total_counts.items():
107
  text = f'Total {obj_name.capitalize()}: {count}'
@@ -114,40 +129,32 @@ def process_video(input_video_path):
114
  cap.release()
115
  out.release()
116
 
117
- # Return the path to the processed video
118
  return output_video_path
119
 
120
- # --- GRADIO INTERFACE ---
121
-
122
- # Build the layout with gr.Blocks
123
  with gr.Blocks(css="footer {visibility: hidden}") as demo:
124
- # 1. Title and Description (no change)
125
  gr.Markdown("<h1>Real-Time Object Tracking & Counting with DETR and SORT</h1>")
126
  gr.Markdown("Upload a video to see object detection and tracking in action. This demo uses Facebook's DETR model for detection and the SORT algorithm to assign unique IDs and count objects.")
127
 
128
- # 2. Main Row for Input and Output
129
  with gr.Row():
130
- # We set a fixed width and height for both video components
131
- # to prevent the layout from shifting.
132
  input_video = gr.Video(label="Input Video", width=640, height=360)
133
  output_video = gr.Video(label="Processed Video", width=640, height=360)
134
 
135
- # 3. Submit Button (no change)
136
  submit_button = gr.Button("Submit", variant="primary")
137
 
138
- # 4. Examples (no change)
139
  gr.Examples(
140
  examples=[['5402016-hd_1920_1080_30fps.mp4']],
141
  inputs=input_video,
142
  label="Click an example to run"
143
  )
144
 
145
- # 5. Link button to function (no change)
146
  submit_button.click(
147
  fn=process_video,
148
  inputs=input_video,
149
  outputs=output_video
150
  )
151
 
152
- # Launch the app
153
- demo.launch()
 
9
  import os
10
 
11
  # Import the Sort class from the local 'sort.py' file
12
+ # Pastikan file 'sort.py' ada di direktori yang sama dengan app.py
13
  from sort import Sort
14
 
15
  # --- LOAD MODELS AND TRACKER ONCE (PENTING!) ---
16
+ # Bagian ini hanya berjalan sekali saat aplikasi dimulai.
17
  print("Loading model and processor...")
18
  model_checkpoint = "facebook/detr-resnet-50"
19
  image_processor = AutoImageProcessor.from_pretrained(model_checkpoint)
 
27
  # ---------------------------------------------------
28
 
29
  def iou(boxA, boxB):
30
+ # Fungsi untuk menghitung Intersection over Union (IoU)
31
  xA = max(boxA[0], boxB[0])
32
  yA = max(boxA[1], boxB[1])
33
  xB = min(boxA[2], boxB[2])
 
38
  iou_score = interArea / float(boxAArea + boxBArea - interArea)
39
  return iou_score
40
 
41
+ # --- FUNGSI PEMROSESAN UTAMA ---
42
  def process_video(input_video_path):
43
+ # Inisialisasi tracker dan penghitung untuk setiap video baru
44
+ tracker = Sort(min_hits=3, iou_threshold=0.3)
45
  total_counts = {'person': 0, 'bicycle': 0, 'car': 0, 'motorcycle': 0}
46
  counted_ids = set()
47
 
48
+ # Tentukan path output untuk video yang diproses
49
  output_video_path = "output.mp4"
50
 
51
  cap = cv2.VideoCapture(input_video_path)
52
  if not cap.isOpened():
53
  raise gr.Error(f"Could not open video file.")
54
 
 
 
55
  fps = int(cap.get(cv2.CAP_PROP_FPS))
56
 
57
+ # --- OPTIMISASI: Atur resolusi baru yang lebih kecil ---
58
+ new_width = 960
59
+ new_height = 540
60
+
61
+ # Gunakan codec 'mp4v' yang kompatibel dan resolusi baru
62
+ out = cv2.VideoWriter(output_video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (new_width, new_height))
63
+
64
+ frame_number = 0
65
  while True:
66
  ret, frame = cap.read()
67
  if not ret:
68
  break
69
 
70
+ frame_number += 1
71
+
72
+ # --- OPTIMISASI: Ubah ukuran setiap frame sebelum dideteksi ---
73
+ frame = cv2.resize(frame, (new_width, new_height))
74
+
75
+ # 1. Deteksi objek dengan DETR
76
  pil_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
77
  inputs = image_processor(images=pil_image, return_tensors="pt").to(device)
78
  with torch.no_grad():
 
80
  target_sizes = torch.tensor([pil_image.size[::-1]])
81
  results = image_processor.post_process_object_detection(outputs, threshold=0.6, target_sizes=target_sizes)[0]
82
 
83
+ # 2. Format deteksi untuk SORT
84
  detections_for_sort = []
85
  original_detections = []
86
  for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
 
90
  detections_for_sort.append([box_list[0], box_list[1], box_list[2], box_list[3], score.item()])
91
  original_detections.append({'box': box_list, 'label': label_name})
92
 
93
+ # 3. Update tracker
94
  tracked_objects_raw = []
95
  if len(detections_for_sort) > 0:
96
  tracked_objects_raw = tracker.update(np.array(detections_for_sort))
97
 
98
+ # 4. Logika Penghitungan & Visualisasi
99
  for obj in tracked_objects_raw:
100
  x1, y1, x2, y2, obj_id = [int(val) for val in obj]
101
+
102
  best_iou = 0
103
  best_label = None
104
  for det in original_detections:
 
107
  best_iou = iou_score
108
  best_label = det['label']
109
 
110
+ # Hitung objek jika ID-nya baru
111
  if best_label and obj_id not in counted_ids:
112
  total_counts[best_label] += 1
113
  counted_ids.add(obj_id)
 
116
  cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
117
  cv2.putText(frame, f'{best_label} ID: {obj_id}', (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
118
 
119
+ # Tampilkan total hitungan kumulatif
120
  y_offset = 30
121
  for obj_name, count in total_counts.items():
122
  text = f'Total {obj_name.capitalize()}: {count}'
 
129
  cap.release()
130
  out.release()
131
 
132
+ print(f"Video processing finished. Total frames: {frame_number}")
133
  return output_video_path
134
 
135
+ # --- ANTARMUKA GRADIO (Dengan Layout Stabil) ---
 
 
136
  with gr.Blocks(css="footer {visibility: hidden}") as demo:
 
137
  gr.Markdown("<h1>Real-Time Object Tracking & Counting with DETR and SORT</h1>")
138
  gr.Markdown("Upload a video to see object detection and tracking in action. This demo uses Facebook's DETR model for detection and the SORT algorithm to assign unique IDs and count objects.")
139
 
 
140
  with gr.Row():
141
+ # Atur ukuran video yang tetap untuk mencegah layout "melompat"
 
142
  input_video = gr.Video(label="Input Video", width=640, height=360)
143
  output_video = gr.Video(label="Processed Video", width=640, height=360)
144
 
 
145
  submit_button = gr.Button("Submit", variant="primary")
146
 
 
147
  gr.Examples(
148
  examples=[['5402016-hd_1920_1080_30fps.mp4']],
149
  inputs=input_video,
150
  label="Click an example to run"
151
  )
152
 
 
153
  submit_button.click(
154
  fn=process_video,
155
  inputs=input_video,
156
  outputs=output_video
157
  )
158
 
159
+ # Jalankan aplikasi
160
+ demo.launch()