Shantanukadam committed on
Commit
410dce1
·
verified ·
1 Parent(s): 6518693

Upload 4 files

Browse files
Files changed (4) hide show
  1. finetune.py +359 -0
  2. infer.py +205 -0
  3. paths.txt +8 -0
  4. requirements.txt +10 -0
finetune.py ADDED
@@ -0,0 +1,359 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import shutil
4
+ import xml.etree.ElementTree as ET
5
+ from pathlib import Path
6
+ import numpy as np
7
+ import cv2
8
+ import matplotlib.pyplot as plt
9
+ from tqdm import tqdm
10
+ from sklearn.model_selection import train_test_split
11
+ import yaml
12
+ from ultralytics import YOLO
13
+ import torch
14
+
15
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
PROJECT_DIR = Path("gun_detection_project")  # root for all generated artifacts
TRAIN_DIR = Path("Train")  # source directory with VOC-style training data
TEST_DIR = Path("test")  # source directory with VOC-style test data
LABELS = ["weapon"]  # single detection class covering all weapon types
TRAIN_VAL_SPLIT = 0.9  # fraction of training data kept for train (rest -> val)
21
+
22
def create_project_structure():
    """Create the on-disk directory layout used by the pipeline.

    Builds the images/labels trees for the train/val/test splits plus
    weights/ and results/ folders under PROJECT_DIR. Idempotent — existing
    directories are left untouched.

    Returns:
        bool: always True once every directory exists.
    """
    data_root = PROJECT_DIR / "data"
    required = [PROJECT_DIR, data_root]
    for kind in ("images", "labels"):
        for split in ("train", "val", "test"):
            required.append(data_root / kind / split)
    required += [PROJECT_DIR / "weights", PROJECT_DIR / "results"]

    for directory in required:
        directory.mkdir(parents=True, exist_ok=True)
        print(f"Created directory: {directory}")

    return True
43
+
44
def convert_bbox_to_yolo(size, box):
    """Convert a Pascal-VOC bounding box to normalized YOLO format.

    Args:
        size: (image_width, image_height) in pixels.
        box: (xmin, ymin, xmax, ymax) pixel coordinates.

    Returns:
        (x_center, y_center, width, height), each normalized to [0, 1].
    """
    dw = 1.0 / size[0]
    dh = 1.0 / size[1]
    xmin, ymin, xmax, ymax = box

    # Center point and extents, scaled by the reciprocal image dimensions.
    x_center = ((xmin + xmax) / 2.0) * dw
    y_center = ((ymin + ymax) / 2.0) * dh
    box_w = (xmax - xmin) * dw
    box_h = (ymax - ymin) * dh

    return x_center, y_center, box_w, box_h
63
+
64
def convert_annotation(xml_file, output_path, class_mapping):
    """Translate one Pascal-VOC XML annotation into a YOLO .txt label file.

    Any weapon-synonym class name is collapsed to class id 0; unknown class
    names are reported and skipped. `class_mapping` is accepted for interface
    compatibility but the single-class mapping is hard-coded.

    Args:
        xml_file: path to the VOC XML annotation.
        output_path: path of the YOLO label file to write.
        class_mapping: unused (kept for signature compatibility).

    Returns:
        bool: True on success, False if the XML could not be processed.
    """
    try:
        root = ET.parse(xml_file).getroot()
        size_node = root.find('size')
        width = int(size_node.find('width').text)
        height = int(size_node.find('height').text)
        dw, dh = 1.0 / width, 1.0 / height

        with open(output_path, 'w') as out_file:
            for obj in root.iter('object'):
                cls = obj.find('name').text.lower()
                if cls not in ("weapon", "gun", "pistol", "rifle", "firearm", "handgun"):
                    print(f"Warning: Unknown class '{cls}' in {xml_file}, skipping object")
                    continue
                cls_id = 0  # everything maps onto the single "weapon" class

                bnd = obj.find('bndbox')
                xmin = float(bnd.find('xmin').text)
                ymin = float(bnd.find('ymin').text)
                xmax = float(bnd.find('xmax').text)
                ymax = float(bnd.find('ymax').text)

                # VOC (corner) -> YOLO (normalized center + size), inlined.
                x_c = ((xmin + xmax) / 2.0) * dw
                y_c = ((ymin + ymax) / 2.0) * dh
                w_n = (xmax - xmin) * dw
                h_n = (ymax - ymin) * dh

                out_file.write(f"{cls_id} {x_c:.6f} {y_c:.6f} {w_n:.6f} {h_n:.6f}\n")

        return True
    except Exception as e:
        print(f"Error processing {xml_file}: {e}")
        return False
103
+
104
def prepare_dataset():
    """Convert annotations and distribute images/labels into split folders.

    Training data is divided into train/val; test data goes straight into
    the test split. Also writes the dataset.yaml config for YOLO.

    Returns:
        (train_count, val_count, test_count) file counts per split.
    """
    train_counts = process_directory(TRAIN_DIR, ["train", "val"])
    test_counts = process_directory(TEST_DIR, ["test"])

    # Summary of what landed in each split.
    grand_total = sum(train_counts.values()) + sum(test_counts.values())
    print(f"Total dataset files: {grand_total}")
    print(f"Training files: {train_counts.get('train', 0)}")
    print(f"Validation files: {train_counts.get('val', 0)}")
    print(f"Test files: {test_counts.get('test', 0)}")

    create_data_yaml()

    return train_counts["train"], train_counts["val"], test_counts["test"]
123
+
124
def process_directory(source_dir, splits):
    """Match annotation/image pairs in *source_dir* and copy them into splits.

    Reads every XML under Annotations/, resolves the corresponding image in
    JPEGImages/ (falling back to stem-based matching when the XML's
    <filename> is stale), then distributes the pairs: a lone "test" split
    receives everything, otherwise a train/val split is made.

    Args:
        source_dir: Path of the source directory (expects Annotations/ and
            JPEGImages/ subfolders).
        splits: list of split names — ["test"] or ["train", "val"].

    Returns:
        dict mapping split name to number of file pairs processed.
    """
    xml_files = list(Path(source_dir / "Annotations").glob("*.xml"))
    print(f"Found {len(xml_files)} annotation files in {source_dir}")

    pairs = []
    for xml_file in xml_files:
        filename = ET.parse(xml_file).getroot().find('filename').text

        candidate = Path(source_dir / "JPEGImages" / filename)
        if not candidate.exists():
            # The XML <filename> may not match the real file: try the same
            # stem with any extension first.
            matches = list(Path(source_dir / "JPEGImages").glob(f"{Path(filename).stem}.*"))
            if matches:
                candidate = matches[0]
            else:
                # Last resort: the XML file's own stem with .jpg.
                candidate = Path(source_dir / "JPEGImages" / f"{xml_file.stem}.jpg")

        if candidate.exists():
            pairs.append((xml_file, candidate))
        else:
            print(f"Warning: No matching image found for {xml_file.name}")

    print(f"Successfully matched {len(pairs)} annotation-image pairs in {source_dir}")

    by_split = {}
    if "test" in splits and len(splits) == 1:
        # A pure test directory is copied wholesale.
        by_split["test"] = pairs
    else:
        # Training directory: deterministic train/val partition.
        train_pairs, val_pairs = train_test_split(
            pairs, train_size=TRAIN_VAL_SPLIT, random_state=42
        )
        by_split["train"] = train_pairs
        by_split["val"] = val_pairs

    counts = {}
    for split_name, split_pairs in by_split.items():
        process_dataset_split(split_pairs, split_name)
        counts[split_name] = len(split_pairs)

    return counts
176
+
177
def process_dataset_split(file_pairs, split_name):
    """Copy images and write YOLO label files for one dataset split.

    Args:
        file_pairs: iterable of (xml_path, image_path) tuples.
        split_name: "train", "val", or "test" — selects the target folders.
    """
    images_dir = PROJECT_DIR / "data" / "images" / split_name
    labels_dir = PROJECT_DIR / "data" / "labels" / split_name

    print(f"Processing {len(file_pairs)} files for {split_name} set")

    for xml_file, img_file in tqdm(file_pairs):
        # Image goes in verbatim under its own name.
        shutil.copy(img_file, images_dir / img_file.name)
        # Label shares the XML's stem, converted to YOLO txt format.
        convert_annotation(xml_file, labels_dir / f"{xml_file.stem}.txt", LABELS)
193
+
194
def create_data_yaml():
    """Write the dataset.yaml configuration consumed by YOLOv8.

    Absolute paths are used throughout so training works regardless of the
    current working directory.
    """
    data_root = PROJECT_DIR.absolute() / "data"
    config = {
        'path': str(data_root),
        'train': str(data_root / "images" / "train"),
        'val': str(data_root / "images" / "val"),
        'test': str(data_root / "images" / "test"),
        'names': dict(enumerate(LABELS)),  # {index: class_name}
        'nc': len(LABELS),
    }

    with open(PROJECT_DIR / "data" / "dataset.yaml", 'w') as f:
        yaml.dump(config, f, default_flow_style=False)

    print(f"Created dataset configuration at {PROJECT_DIR / 'data' / 'dataset.yaml'}")
210
+
211
def train_model():
    """Fine-tune a pretrained YOLOv8-medium model on the weapon dataset.

    Hyperparameters favour stability (AdamW, cosine LR, warmup, weight
    decay) and use augmentations chosen for weapon imagery: no vertical
    flips, minimal shear/perspective, small rotations.

    Returns:
        The ultralytics training results object.
    """
    print("Starting model training...")

    # Medium variant: reasonable accuracy/speed trade-off.
    model = YOLO('yolov8m.pt')

    results = model.train(
        data=str(PROJECT_DIR / "data" / "dataset.yaml"),
        epochs=100,
        patience=10,            # early stopping after 10 stagnant epochs
        batch=8,
        imgsz=640,
        pretrained=True,
        optimizer='AdamW',      # works well for detection fine-tuning
        lr0=0.001,              # initial learning rate
        lrf=0.01,               # final LR as a fraction of lr0
        weight_decay=0.0005,    # L2 regularization against overfitting
        warmup_epochs=3,
        cos_lr=True,            # cosine learning-rate schedule
        box=7.5,                # box loss gain
        cls=0.5,                # class loss gain
        dfl=1.5,                # distribution focal loss gain
        val=True,
        plots=True,
        save=True,
        save_period=10,         # checkpoint every 10 epochs
        project=str(PROJECT_DIR / "results"),
        name='gun_detection',
        exist_ok=True,
        cache=False,            # image caching disabled (saves RAM)
        device=0 if torch.cuda.is_available() else 'cpu',
        amp=True,               # mixed precision for faster training
        augment=True,           # keep default augmentation pipeline
        mixup=0.1,              # mixup augmentation strength
        mosaic=1.0,             # mosaic augmentation strength
        degrees=0.3,            # small rotations only for this domain
        translate=0.1,          # translation augmentation
        scale=0.5,              # scale augmentation
        shear=0.0,              # shear disabled
        perspective=0.0,        # perspective disabled
        flipud=0.0,             # no vertical flips
        fliplr=0.5,             # horizontal flips at 50%
        hsv_h=0.015,            # hue jitter
        hsv_s=0.7,              # saturation jitter
        hsv_v=0.4,              # value/brightness jitter
    )

    return results
261
+
262
def export_model(format='onnx'):
    """Export the best trained checkpoint to the requested format.

    Bug fix: ultralytics saves checkpoints under
    results/gun_detection/weights/ (see paths.txt), not in the run
    directory itself, so the old ``glob('*.pt')[0]`` always raised
    IndexError. We now prefer weights/best.pt and fall back to any .pt
    found anywhere under the run directory.

    Args:
        format: ultralytics export format name (default 'onnx').

    Raises:
        FileNotFoundError: if no checkpoint exists under the run directory.
    """
    run_dir = PROJECT_DIR / "results" / "gun_detection"
    best_model_path = run_dir / "weights" / "best.pt"
    if not best_model_path.exists():
        checkpoints = sorted(run_dir.glob("**/*.pt"))
        if not checkpoints:
            raise FileNotFoundError(f"No .pt checkpoint found under {run_dir}")
        best_model_path = checkpoints[0]

    model = YOLO(best_model_path)
    model.export(format=format)
    print(f"Model exported to {format.upper()} format")
271
+
272
def run_inference(image_path):
    """Run the trained detector on one image and save an annotated figure.

    Bug fixes relative to the original:
    - checkpoints live under results/gun_detection/weights/, so the old
      ``glob('*.pt')[0]`` on the run directory always raised IndexError;
    - ``cv2.imread`` returns None (no exception) for unreadable paths,
      which previously crashed the drawing code.

    Args:
        image_path: path to the image file (string, as cv2.imread expects).

    Returns:
        The ultralytics results list.

    Raises:
        FileNotFoundError: if no checkpoint exists under the run directory.
    """
    run_dir = PROJECT_DIR / "results" / "gun_detection"
    best_model_path = run_dir / "weights" / "best.pt"
    if not best_model_path.exists():
        checkpoints = sorted(run_dir.glob("**/*.pt"))
        if not checkpoints:
            raise FileNotFoundError(f"No .pt checkpoint found under {run_dir}")
        best_model_path = checkpoints[0]
    model = YOLO(best_model_path)

    # Run inference with a low confidence floor.
    results = model(image_path, conf=0.25)

    for result in results:
        boxes = result.boxes
        print(f"Detected {len(boxes)} guns")

        img = cv2.imread(image_path)
        if img is None:
            # cv2.imread signals failure by returning None, not raising.
            print(f"Warning: could not read {image_path} for plotting")
            continue
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        for box in boxes:
            x1, y1, x2, y2 = map(int, box.xyxy[0])
            conf = float(box.conf[0])

            # Draw bounding box and confidence label.
            cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
            cv2.putText(img, f"Gun: {conf:.2f}", (x1, y1 - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

        plt.figure(figsize=(10, 8))
        plt.imshow(img)
        plt.title("Gun Detection Results")
        plt.axis("off")
        plt.savefig(PROJECT_DIR / "results" / "inference_example.png")
        plt.close()

    return results
307
+
308
def verify_dataset():
    """Report how many label files are empty across all splits.

    Empty label files are legal (images with no annotations), so this
    check only warns and never aborts.

    Returns:
        bool: always True.
    """
    empty_files = 0
    total_files = 0

    for split in ("train", "val", "test"):
        label_dir = PROJECT_DIR / "data" / "labels" / split
        if not label_dir.exists():
            continue

        label_files = list(label_dir.glob("*.txt"))
        total_files += len(label_files)
        empty_files += sum(1 for f in label_files if f.stat().st_size == 0)

    if empty_files > 0:
        print(f"⚠️ WARNING: Found {empty_files}/{total_files} empty label files!")
        print("Training will continue, treating empty label files as images without annotations.")
    else:
        print(f"✅ All {total_files} label files contain data")

    return True
332
+
333
def main():
    """Run the full pipeline: setup, verification, training, export, demo."""
    create_project_structure()

    # Dataset preparation is currently disabled (assumed already done):
    # train_count, val_count, test_count = prepare_dataset()
    # print(f"Dataset prepared: {train_count} training samples, {val_count} validation samples, {test_count} test samples")

    # Warn (but do not abort) on empty label files before training.
    verify_dataset()

    train_model()
    print("Training completed!")

    export_model(format='onnx')

    # Optional demo: annotate the first available test image.
    test_images = list(Path(PROJECT_DIR / "data" / "images" / "test").glob("*.jpg"))
    if test_images:
        run_inference(str(test_images[0]))
        print(f"Inference example saved to {PROJECT_DIR / 'results' / 'inference_example.png'}")

if __name__ == "__main__":
    main()
infer.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from pathlib import Path
3
+ import cv2
4
+ import numpy as np
5
+ from ultralytics import YOLO
6
+ import time
7
+
8
def run_inference_on_image(model_path, image_path, conf_threshold=0.5, save_path=None):
    """Run gun detection on a single image and save or display the result.

    Bug fix: ``cv2.imread`` returns None (no exception) for missing or
    unreadable paths; the original then crashed inside the drawing or
    imwrite calls. We now detect that case and report it.

    Args:
        model_path: path to the trained .pt model file.
        image_path: path to the input image.
        conf_threshold: minimum detection confidence.
        save_path: if given, write the annotated image here; otherwise
            show it in a window until a key is pressed.
    """
    model = YOLO(model_path)

    # Time just the model call.
    start_time = time.time()
    results = model(image_path, conf=conf_threshold)
    inference_time = time.time() - start_time

    img = cv2.imread(image_path)
    if img is None:
        print(f"Error: Could not read image {image_path}")
        return

    # Draw each detection on the image.
    for result in results:
        boxes = result.boxes
        print(f"Detected {len(boxes)} guns in {inference_time:.4f} seconds")

        for box in boxes:
            x1, y1, x2, y2 = map(int, box.xyxy[0])
            conf = float(box.conf[0])

            cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
            cv2.putText(img, f"Gun: {conf:.2f}", (x1, y1 - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

    if save_path:
        cv2.imwrite(save_path, img)
        print(f"Result saved to {save_path}")
    else:
        cv2.imshow("Gun Detection Result", img)
        cv2.waitKey(0)
        cv2.destroyAllWindows()
43
+
44
def run_inference_on_video(model_path, video_path, conf_threshold=0.55, save_path=None):
    """Run gun detection frame-by-frame over a video file.

    Bug fixes relative to the original:
    - guards against ZeroDivisionError when a frame's measured inference
      time is 0.0 (coarse timer resolution on fast hardware);
    - some containers report FPS as 0, which produced an unplayable output
      file — fall back to 30 FPS for the writer in that case;
    - the writer variable is always defined, so the cleanup path cannot
      hit a NameError.

    Args:
        model_path: path to the trained .pt model file.
        video_path: path to the input video.
        conf_threshold: minimum detection confidence.
        save_path: if given, write the annotated video here; otherwise
            display frames live ('q' quits).
    """
    model = YOLO(model_path)

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"Error: Could not open video {video_path}")
        return

    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = int(cap.get(cv2.CAP_PROP_FPS)) or 30  # some containers report 0

    writer = None
    if save_path:
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        writer = cv2.VideoWriter(save_path, fourcc, fps, (width, height))

    frame_count = 0
    total_time = 0.0

    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break

        # Time just the model call (conversion included, as before).
        start_time = time.time()
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        results = model(frame_rgb, conf=conf_threshold)
        inference_time = time.time() - start_time
        total_time += inference_time
        frame_count += 1

        annotated_frame = frame.copy()
        for result in results:
            for box in result.boxes:
                x1, y1, x2, y2 = map(int, box.xyxy[0])
                conf = float(box.conf[0])

                # Redundant with the model-side threshold; kept as a guard.
                if conf < conf_threshold:
                    continue

                cv2.rectangle(annotated_frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
                cv2.putText(annotated_frame, f"Weapon: {conf:.2f}", (x1, y1 - 10),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

        # Guard against a zero-length timing interval.
        fps_value = 1 / inference_time if inference_time > 0 else float('inf')
        cv2.putText(annotated_frame, f"FPS: {fps_value:.1f}", (20, 40),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)

        if writer is not None:
            writer.write(annotated_frame)
        else:
            cv2.imshow("Gun Detection", annotated_frame)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break

    cap.release()
    if writer is not None:
        writer.release()
    cv2.destroyAllWindows()

    avg_fps = frame_count / total_time if total_time > 0 else 0
    print(f"Processed {frame_count} frames in {total_time:.2f} seconds ({avg_fps:.2f} FPS)")
125
+
126
def run_inference_on_webcam(model_path, camera_id=0, conf_threshold=0.55):
    """Run live gun detection on a webcam feed until 'q' is pressed.

    Args:
        model_path: path to the trained .pt model file.
        camera_id: OpenCV camera index (default 0).
        conf_threshold: minimum detection confidence.
    """
    model = YOLO(model_path)

    cap = cv2.VideoCapture(camera_id)
    if not cap.isOpened():
        print(f"Error: Could not open webcam {camera_id}")
        return

    while cap.isOpened():
        grabbed, frame = cap.read()
        if not grabbed:
            break

        # Time the detection call; input is converted BGR -> RGB first.
        t0 = time.time()
        results = model(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB), conf=conf_threshold)
        elapsed = time.time() - t0

        display = frame.copy()
        for result in results:
            for box in result.boxes:
                confidence = float(box.conf[0])
                x1, y1, x2, y2 = map(int, box.xyxy[0])

                # Skip detections below the threshold.
                if confidence < conf_threshold:
                    continue

                cv2.rectangle(display, (x1, y1), (x2, y2), (0, 255, 0), 2)
                cv2.putText(display, f"Weapon: {confidence:.2f}", (x1, y1 - 10),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

        # Overlay the per-frame throughput.
        cv2.putText(display, f"FPS: {1/elapsed:.1f}", (20, 40),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)

        cv2.imshow("Gun Detection (Press 'q' to quit)", display)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    cap.release()
    cv2.destroyAllWindows()
184
+
185
def main():
    """CLI entry point: dispatch to image, video, or webcam inference."""
    parser = argparse.ArgumentParser(description="Run inference with YOLOv8 gun detection model")
    parser.add_argument("--model", type=str, required=True, help="Path to the trained model")
    parser.add_argument("--source", type=str, required=True,
                        help="Path to image, video file or 'webcam' for live detection")
    parser.add_argument("--conf", type=float, default=0.5, help="Confidence threshold")
    parser.add_argument("--output", type=str, default=None, help="Path to save results")
    args = parser.parse_args()

    # Dispatch on the (case-insensitive) source string.
    source = args.source.lower()
    if source == "webcam":
        run_inference_on_webcam(args.model, camera_id=0, conf_threshold=args.conf)
    elif source.endswith(('.mp4', '.avi', '.mov', '.mkv')):
        run_inference_on_video(args.model, args.source, conf_threshold=args.conf, save_path=args.output)
    else:
        run_inference_on_image(args.model, args.source, conf_threshold=args.conf, save_path=args.output)

if __name__ == "__main__":
    main()
paths.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # For image inference
2
+ python infer.py --model gun_detection_project/results/gun_detection/weights/best.pt --source path/to/image.jpg
3
+
4
+ # For video inference
5
+ python infer.py --model gun_detection_project/results/gun_detection/weights/best.pt --source path/to/video.mp4 --output results.mp4
6
+
7
+ # For webcam
8
+ python infer.py --model gun_detection_project/results/gun_detection/weights/best.pt --source webcam
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ ultralytics>=8.0.0
2
+ torch>=1.7.0
3
+ torchvision>=0.8.1
4
+ numpy>=1.18.5
5
+ opencv-python>=4.1.2
6
+ matplotlib>=3.2.2
7
+ PyYAML>=5.3.1
8
+ tqdm>=4.41.0
9
+ scikit-learn>=0.24.2
10
+ Pillow>=7.1.2