# Hugging Face Spaces status: Runtime error
import os

import torch
import gradio as gr
from PIL import Image

# Step 1: Search for best.pt in the training directory
base_path = "yolov5/runs/train/"


def _find_best_weights(search_root):
    """Return the most recently modified ``best.pt`` under *search_root*, or None.

    Several ``best.pt`` files may exist (presumably one per training-run
    sub-directory — confirm against your training layout). ``os.walk`` visits
    directories in filesystem-dependent order, so taking the first hit could
    silently load an older run; choosing the newest file is deterministic.
    """
    candidates = [
        os.path.join(root, "best.pt")
        for root, _dirs, files in os.walk(search_root)
        if "best.pt" in files
    ]
    if not candidates:
        return None
    return max(candidates, key=os.path.getmtime)


def _load_model(weights_path):
    """Load the YOLOv5 detector, preferring custom weights when available.

    If *weights_path* is None, or loading it raises, fall back to the
    pre-trained ``yolov5s`` checkpoint from torch.hub.

    Returns:
        The loaded model, or None if every load attempt failed (each failure
        is printed rather than raised, so the caller can report one clear error).
    """
    if weights_path is None:
        print("Trained weights (best.pt) not found.")
        print("Using pre-trained YOLOv5 weights (yolov5s.pt) instead.")
    else:
        try:
            print(f"Model weights found at: {weights_path}")
            return torch.hub.load('ultralytics/yolov5', 'custom', path=weights_path)
        except Exception as e:
            print(f"Error loading custom model: {e}")
    # Step 2: Pre-trained fallback — guarded on both paths so a failure here
    # surfaces as the RuntimeError below instead of an unhandled hub exception.
    try:
        return torch.hub.load('ultralytics/yolov5', 'yolov5s')
    except Exception as e:
        print(f"Error loading pre-trained YOLOv5 model: {e}")
        return None


best_path = _find_best_weights(base_path)
model = _load_model(best_path)

# Ensure the model was loaded properly before proceeding
if model is None:
    raise RuntimeError("Failed to load YOLOv5 model. Please check the weights or model path.")
# Step 3: Weapon labels that trigger a threat warning.
# These should correspond to the class names reported by the loaded model
# (printed at detection time as "Model class names"); adjust for your dataset.
weapon_classes = [
    'bomb',
    'gun',
    'pistol',
    'Automatic',
    'Rifle',
    'Bazooka',
    'Handgun',
    'Knife',
    'Grenade Launcher',
    'Shotgun',
    'SMG',
    'Sniper',
    'Sword',
]
def detect_weapons(image, confidence_threshold=0.5):
    """Run the YOLOv5 model on *image* and report whether a weapon was detected.

    Args:
        image: The uploaded image as passed by the Gradio input component
            (configured with ``type="numpy"``), or None when the user submits
            without uploading anything.
        confidence_threshold: Minimum detection confidence (inclusive) for a
            detection to be counted. Defaults to 0.5.

    Returns:
        A ``(message, annotated_image)`` tuple. ``message`` holds the threat
        verdict and the comma-separated detected class names;
        ``annotated_image`` is a PIL image with bounding boxes rendered, or
        None when there was no image or detection failed.
    """
    if image is None:
        return "Please upload an image.", None
    try:
        results = model(image)
    except Exception as e:
        return f"Error during detection: {e}", None

    # Check available model class names
    print("Model class names:", results.names)

    # Build the detections table once (the original called pandas() twice).
    detections = results.pandas().xyxy[0]
    confident = detections[detections['confidence'] >= confidence_threshold]
    detected_classes = confident['name'].unique()
    print("Detected classes:", detected_classes)

    # Match case-insensitively: label casing depends on the dataset the weights
    # were trained on (e.g. YOLOv5's pre-trained COCO fallback reports 'knife'),
    # while weapon_classes uses mixed case such as 'Knife'.
    detected_folded = {str(name).casefold() for name in detected_classes}
    detected_threats = [
        weapon for weapon in weapon_classes if weapon.casefold() in detected_folded
    ]

    if detected_threats:
        threat_message = "Threat detected: Be careful"
    else:
        threat_message = "No threat detected. But all other features are good."

    detected_objects = ', '.join(detected_classes)
    # results.render() returns the input images with boxes drawn on them.
    annotated_image = Image.fromarray(results.render()[0])
    return f"{threat_message}\nDetected objects: {detected_objects}", annotated_image
# Step 4: Gradio Interface
def inference(image):
    """Gradio callback that delegates to ``detect_weapons``.

    Returns the two values ``detect_weapons`` produces, in order: the threat
    text for the Textbox output and the annotated image for the Image output.
    """
    outcome = detect_weapons(image)
    message, annotated = outcome
    return (message, annotated)
# Wire the callback to one image input and two outputs (text verdict + annotated image).
iface = gr.Interface(
    fn=inference,
    inputs=gr.Image(type="numpy", label="Upload Image"),
    outputs=[
        gr.Textbox(label="Threat Detection"),
        gr.Image(label="Detected Image"),
    ],
    title="Weapon Detection AI",
    description="Upload an image to detect weapons like bombs, guns, and pistols."
)

# Step 5: Launch Gradio App — only when executed as a script, so importing
# this module (e.g. from tests or another app) does not start a blocking server.
if __name__ == "__main__":
    iface.launch()
| v |