ElodieA commited on
Commit
4dff9c4
·
verified ·
1 Parent(s): d4d3a7d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -15
app.py CHANGED
@@ -5,6 +5,17 @@ from ultralytics import YOLO
5
  import pandas as pd
6
  import matplotlib.pyplot as plt
7
 
 
 
 
 
 
 
 
 
 
 
 
8
  def process_video(video_file):
9
  # Define colors for each class (8 classes)
10
  colors = [
@@ -18,9 +29,6 @@ def process_video(video_file):
18
  (0, 128, 128) # Class 7 - Teal
19
  ]
20
 
21
- # Define class names (example names, replace with actual class names if available)
22
- class_names = ['Hymenoptera', 'Mantodea', 'Odonata', 'Orthoptera', 'Coleoptera', 'Lepidoptera', 'Hemiptera']
23
-
24
  # Load the YOLOv8 model
25
  model = YOLO("insect_detection4.pt")
26
 
@@ -28,7 +36,7 @@ def process_video(video_file):
28
  cap = cv2.VideoCapture(video_file)
29
 
30
  # Prepare DataFrame for storing detection data
31
- columns = ["frame", "id", "class", "x", "y", "w", "h"]
32
  df = pd.DataFrame(columns=columns)
33
 
34
  frame_id = 0
@@ -53,13 +61,13 @@ def process_video(video_file):
53
  for i, box in enumerate(boxes):
54
  class_id = int(class_ids[i])
55
  confidence = confidences[i]
56
- unique_id = int(box.id[0]) # Ensure the ID is an integer
57
 
58
  # Append detection data to DataFrame
59
  new_row = pd.DataFrame({
60
  "frame": [frame_id],
61
- "id": [unique_id],
62
- "class": [int(box.cls[0])],
63
  "x": [box.xywh[0][0]],
64
  "y": [box.xywh[0][1]],
65
  "w": [box.xywh[0][2]],
@@ -68,7 +76,7 @@ def process_video(video_file):
68
  df = pd.concat([df, new_row], ignore_index=True)
69
 
70
  # Crop and save the image of the insect
71
- if unique_id not in unique_insect_crops:
72
  x_center, y_center, width, height = box.xywh[0]
73
  x1 = int(x_center - width / 2)
74
  y1 = int(y_center - height / 2)
@@ -77,7 +85,7 @@ def process_video(video_file):
77
  insect_crop = frame[y1:y2, x1:x2]
78
  crop_path = tempfile.mktemp(suffix=".png")
79
  cv2.imwrite(crop_path, insect_crop)
80
- unique_insect_crops[unique_id] = crop_path
81
 
82
  else:
83
  break
@@ -94,9 +102,10 @@ def process_video(video_file):
94
 
95
  # Create the plot from the CSV data
96
  plt.figure(figsize=(10, 6))
97
- for unique_id in df_from_csv['id'].unique():
98
- frames = df_from_csv[df_from_csv['id'] == unique_id]['frame']
99
- plt.plot(frames, [unique_id] * len(frames), 'o-', label=f'Insect {unique_id}')
 
100
 
101
  plt.xlabel('Frame')
102
  plt.ylabel('Insect ID')
@@ -109,17 +118,19 @@ def process_video(video_file):
109
  plt.savefig(plot_path)
110
  plt.close()
111
 
112
- gallery_items = [(crop_path, f'Insect {unique_id}') for unique_id, crop_path in unique_insect_crops.items()]
113
 
114
  return plot_path, gallery_items, csv_path
115
 
116
  # Create a Gradio interface
117
- inputs = gr.Video(label="Input Insect Trap Video")
 
 
118
  outputs = [
119
  gr.Image(label="Insect Detection Plot"),
120
  gr.Gallery(label="Unique Insect Crops"), # Added a gallery to display insect crops with labels
121
  gr.File(label="Download CSV")
122
  ]
123
 
124
- gr.Interface(fn=process_video, inputs=inputs, outputs=outputs).launch()
125
 
 
5
  import pandas as pd
6
  import matplotlib.pyplot as plt
7
 
8
+ # Define the label mapping
9
+ label_mapping = {
10
+ 0: 'Hymenoptera',
11
+ 1: 'Mantodea',
12
+ 2: 'Odonata',
13
+ 3: 'Orthoptera',
14
+ 4: 'Coleoptera',
15
+ 5: 'Lepidoptera',
16
+ 6: 'Hemiptera'
17
+ }
18
+
19
  def process_video(video_file):
20
  # Define colors for each class (8 classes)
21
  colors = [
 
29
  (0, 128, 128) # Class 7 - Teal
30
  ]
31
 
 
 
 
32
  # Load the YOLOv8 model
33
  model = YOLO("insect_detection4.pt")
34
 
 
36
  cap = cv2.VideoCapture(video_file)
37
 
38
  # Prepare DataFrame for storing detection data
39
+ columns = ["frame", "insect_id", "class", "x", "y", "w", "h"]
40
  df = pd.DataFrame(columns=columns)
41
 
42
  frame_id = 0
 
61
  for i, box in enumerate(boxes):
62
  class_id = int(class_ids[i])
63
  confidence = confidences[i]
64
+ insect_id = int(box.id[0]) # Ensure the ID is an integer
65
 
66
  # Append detection data to DataFrame
67
  new_row = pd.DataFrame({
68
  "frame": [frame_id],
69
+ "insect_id": [insect_id],
70
+ "class": [class_id],
71
  "x": [box.xywh[0][0]],
72
  "y": [box.xywh[0][1]],
73
  "w": [box.xywh[0][2]],
 
76
  df = pd.concat([df, new_row], ignore_index=True)
77
 
78
  # Crop and save the image of the insect
79
+ if insect_id not in unique_insect_crops:
80
  x_center, y_center, width, height = box.xywh[0]
81
  x1 = int(x_center - width / 2)
82
  y1 = int(y_center - height / 2)
 
85
  insect_crop = frame[y1:y2, x1:x2]
86
  crop_path = tempfile.mktemp(suffix=".png")
87
  cv2.imwrite(crop_path, insect_crop)
88
+ unique_insect_crops[insect_id] = (crop_path, label_mapping[class_id])
89
 
90
  else:
91
  break
 
102
 
103
  # Create the plot from the CSV data
104
  plt.figure(figsize=(10, 6))
105
+ for insect_id in df_from_csv['insect_id'].unique():
106
+ frames = df_from_csv[df_from_csv['insect_id'] == insect_id]['frame']
107
+ insect_class = label_mapping[df_from_csv[df_from_csv['insect_id'] == insect_id]['class'].values[0]]
108
+ plt.plot(frames, [insect_id] * len(frames), 'o-', label=f'{insect_class} {insect_id}')
109
 
110
  plt.xlabel('Frame')
111
  plt.ylabel('Insect ID')
 
118
  plt.savefig(plot_path)
119
  plt.close()
120
 
121
+ gallery_items = [(crop_path, f'{label} {insect_id}') for insect_id, (crop_path, label) in unique_insect_crops.items()]
122
 
123
  return plot_path, gallery_items, csv_path
124
 
125
  # Create a Gradio interface
126
+ example_video = "path_to_example_video.mp4" # Replace with the actual path to your example video
127
+
128
+ inputs = gr.Video(label="Input Insect Trap Video", value=example_video)
129
  outputs = [
130
  gr.Image(label="Insect Detection Plot"),
131
  gr.Gallery(label="Unique Insect Crops"), # Added a gallery to display insect crops with labels
132
  gr.File(label="Download CSV")
133
  ]
134
 
135
+ gr.Interface(fn=process_video, inputs=inputs, outputs=outputs, examples=[example_video]).launch()
136