Spaces:
Runtime error
Runtime error
| import torch | |
| from transformers import AutoImageProcessor, AutoModelForObjectDetection | |
| from PIL import Image | |
| import cv2 | |
| import numpy as np | |
| import time | |
| import gradio as gr | |
# Device setup: prefer a CUDA GPU, then Apple's Metal backend (MPS), else CPU.
# `device` is always a torch.device so every branch yields the same type.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif getattr(torch.backends, 'mps', None) is not None and torch.backends.mps.is_available():
    # getattr guard: older torch builds have no torch.backends.mps at all.
    device = torch.device('mps')
else:
    device = torch.device('cpu')
# Pre-trained detector and its image processor, both loaded from the same
# Hugging Face Hub checkpoint id; the model is moved to `device`.
# NOTE(review): the checkpoint name suggests a fashion-item detector -- confirm
# its labels include "helmet", which process_image looks for.
ckpt = 'yainage90/fashion-object-detection'
image_processor, model = (
    AutoImageProcessor.from_pretrained(ckpt),
    AutoModelForObjectDetection.from_pretrained(ckpt).to(device),
)
def detect_objects(frame, threshold=0.4):
    """Detect objects in a single video frame.

    Args:
        frame: Image as a NumPy array in OpenCV BGR channel order; it is
            converted to RGB before being handed to the model.
        threshold: Minimum score a detection must reach to be kept; passed to
            ``image_processor.post_process_object_detection``. Defaults to 0.4,
            the value that used to be hard-coded here.

    Returns:
        List of ``(score, label, box)`` tuples: ``score`` is a float, ``label``
        an int key into ``model.config.id2label``, and ``box`` a list of floats
        as produced by the post-processor, scaled to the frame's pixel size.
        Each detection is also printed.
    """
    # The model consumes RGB images; OpenCV frames are BGR.
    image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

    with torch.no_grad():
        inputs = image_processor(images=[image], return_tensors="pt").to(device)
        outputs = model(**inputs)
        # The post-processor wants (height, width); PIL reports (width, height).
        target_sizes = torch.tensor([[image.height, image.width]])
        results = image_processor.post_process_object_detection(
            outputs, threshold=threshold, target_sizes=target_sizes
        )[0]

    items = []
    for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
        score = score.item()
        label = label.item()
        box = box.tolist()
        print(f"{model.config.id2label[label]}: {round(score, 3)} at {box}")
        items.append((score, label, box))
    return items
def process_image(image):
    """Run helmet detection on an image uploaded via Gradio.

    Args:
        image: PIL image from ``gr.Image(type="pil")``, or ``None`` when the
            input is empty (Gradio may call the function with ``None``, e.g.
            after the user clears the image on a live interface).

    Returns:
        Dict with ``"items_detected"`` (list of ``(score, label, box)`` tuples
        from ``detect_objects``) and ``"violation_message"`` (str).
    """
    if image is None:
        return {"items_detected": [], "violation_message": "No image provided."}

    if isinstance(image, Image.Image):
        # Normalise RGBA / greyscale / palette uploads to 3-channel RGB.
        image = image.convert("RGB")
    # np.asarray of a PIL image is RGB, but detect_objects converts BGR->RGB
    # itself and save_data writes with cv2.imwrite (BGR). Hand both an
    # OpenCV-ordered frame; otherwise the model sees swapped channels and the
    # saved JPEG has wrong colours.
    frame = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)

    items = detect_objects(frame)

    # NOTE(review): if model.config.id2label has no "helmet" entry this is
    # never True, and every image is reported as a violation -- confirm the
    # checkpoint's label set.
    helmet_detected = any(
        model.config.id2label[label] == "helmet" for _, label, _ in items
    )
    if helmet_detected:
        violation_message = "Helmet detected: No violation."
    else:
        violation_message = "Serious Traffic Violation: Rider not wearing a helmet!"

    # Save the frame whenever anything was detected.
    # NOTE(review): this also runs when a helmet WAS found, although save_data
    # names the file helmet_violation_*; confirm whether it should be gated on
    # `not helmet_detected`.
    if items:
        save_data(frame, items)
    return {"items_detected": items, "violation_message": violation_message}
def save_data(frame, items):
    """Save the frame as a JPEG, extract a plate number, and record both.

    Args:
        frame: Image array in OpenCV BGR order, as ``cv2.imwrite`` expects.
        items: Detections from ``detect_objects``; forwarded unchanged to
            ``save_to_database``.

    Raises:
        OSError: if OpenCV reports that the image could not be written.
    """
    # Nanosecond timestamp: with whole seconds, two frames processed in the
    # same second (easy with a live interface) overwrote each other's file.
    filename = f"helmet_violation_{time.time_ns()}.jpg"
    # cv2.imwrite reports failure by returning False instead of raising; don't
    # record a database entry pointing at a file that was never written.
    if not cv2.imwrite(filename, frame):
        raise OSError(f"Could not write image to {filename!r}")
    plate_number = extract_plate_number(frame)
    save_to_database(filename, plate_number, items)
def extract_plate_number(frame):
    """Return the license-plate text for *frame*.

    Placeholder: *frame* is not inspected; the fixed dummy string
    ``"XYZ 1234"`` is always returned.
    """
    # TODO: replace with an actual license plate recognition method.
    return "XYZ 1234"
def save_to_database(image_filename, plate_number, items):
    """Record a saved image with its plate number and detections.

    Stand-in for real persistence: it only prints two lines to stdout.
    """
    summary = f"Plate Number: {plate_number}, Image saved as {image_filename}"
    print(summary)
    print("Detected items:", items)
def _gradio_outputs(image):
    """Adapt process_image's dict result to the interface's two outputs.

    gr.Interface maps the returned values positionally onto ``outputs``;
    process_image returns a single dict, which does not match the two
    components (JSON, Textbox) declared below.
    """
    result = process_image(image)
    return result["items_detected"], result["violation_message"]


# Gradio UI: one PIL image in; detections (JSON) and the verdict (text) out.
interface = gr.Interface(
    fn=_gradio_outputs,
    inputs=gr.Image(type="pil"),
    outputs=[gr.JSON(), gr.Textbox()],
    live=True,
)

if __name__ == "__main__":
    # Launch the Gradio app only when run as a script, not on import.
    interface.launch(debug=True)