Spaces:
Running
Running
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| from torchvision import models, transforms | |
| from PIL import Image | |
| import cv2 | |
| import numpy as np | |
# 1. Obtain the OpenCV Haar cascade used to locate eyes.
import urllib.request
import os

_CASCADE_FILENAME = 'haarcascade_eye.xml'
_CASCADE_URL = 'https://raw.githubusercontent.com/opencv/opencv/master/data/haarcascades/haarcascade_eye.xml'
cascade_path = _CASCADE_FILENAME

# opencv-python wheels ship the stock cascades under cv2.data.haarcascades;
# prefer that copy so startup does not depend on network access.
_bundled_dir = getattr(getattr(cv2, 'data', None), 'haarcascades', '')
_bundled_path = os.path.join(_bundled_dir, _CASCADE_FILENAME) if _bundled_dir else ''
if _bundled_path and os.path.exists(_bundled_path):
    cascade_path = _bundled_path
elif not os.path.exists(cascade_path):
    # Download to a temporary name and rename only after success. An
    # interrupted download would otherwise leave a truncated file that the
    # os.path.exists check above would trust on every later start.
    _tmp_path = cascade_path + '.part'
    urllib.request.urlretrieve(_CASCADE_URL, _tmp_path)
    os.replace(_tmp_path, cascade_path)

eye_cascade = cv2.CascadeClassifier(cascade_path)
# CascadeClassifier does not raise on a missing or corrupt file. It yields an
# empty classifier that only fails later inside detectMultiScale, so check now.
if eye_cascade.empty():
    raise RuntimeError(f"Could not load Haar cascade from {cascade_path!r}")
# 2. Rebuild the network and load the trained weights.
# The architecture must mirror the training-time definition exactly (MobileNetV2
# backbone with a dropout + 2-way linear head). Otherwise load_state_dict, which
# is strict by default, rejects the checkpoint.
model = models.mobilenet_v2(weights=None)  # random init; real weights come from the checkpoint
_num_classes = 2
model.classifier = nn.Sequential(
    nn.Dropout(p=0.5),
    nn.Linear(model.last_channel, _num_classes),
)
_checkpoint = torch.load('ddobj_model.pth', map_location=torch.device('cpu'))
model.load_state_dict(_checkpoint)
del _checkpoint  # weights were copied into the model; drop the extra reference
model.eval()  # switch dropout/batch-norm layers to inference behaviour
# 3. Preprocessing applied to every eye crop before inference.
# Grayscale is key to matching the MRL dataset. The single channel is
# replicated to 3 so the tensor fits a 3-channel backbone. Mean/std are the
# usual ImageNet normalisation statistics.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
_preprocess_steps = [
    transforms.Grayscale(num_output_channels=3),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
]
transform = transforms.Compose(_preprocess_steps)
# 4. The Smart Prediction Function
def predict_drowsiness(image):
    """Find an eye in a face photo and classify it as open or closed.

    Args:
        image: The picture from the Gradio ``gr.Image`` input. With Gradio's
            defaults this is an RGB numpy array; it is ``None`` when the
            user submits without uploading anything.

    Returns:
        A ``(status, eye_crop)`` pair. ``status`` is the prediction label or
        an error message. ``eye_crop`` is the PIL image the model actually
        classified, or ``None`` when nothing could be analysed.
    """
    # Gradio passes None when the form is submitted without an image;
    # np.array(None) would otherwise crash inside cv2.cvtColor.
    if image is None:
        return "ERROR: No image received. Please upload a clear face photo.", None

    # Convert Gradio image to OpenCV format
    img_cv = np.array(image)
    # A 2-D array is already single-channel. Otherwise treat the pixels as RGB
    # (Gradio's channel order), not OpenCV's default BGR.
    if img_cv.ndim == 2:
        gray = img_cv
    else:
        gray = cv2.cvtColor(img_cv, cv2.COLOR_RGB2GRAY)

    # Detect eyes in the image
    eyes = eye_cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=5, minSize=(30, 30))
    if len(eyes) == 0:
        return "ERROR: Could not detect any eyes in the image. Please upload a clear face photo.", None

    # detectMultiScale documents no ordering of its detections, so eyes[0] is
    # not necessarily the largest. Pick the biggest (x, y, w, h) box by area.
    x, y, w, h = max(eyes, key=lambda box: int(box[2]) * int(box[3]))

    # Crop the eye from the original image and hand it to PyTorch as a PIL
    # image; the transform pipeline handles grayscale conversion and resizing.
    eye_crop = img_cv[y:y + h, x:x + w]
    eye_pil = Image.fromarray(eye_crop)
    input_tensor = transform(eye_pil).unsqueeze(0)  # add a batch dimension

    # Run the model without building an autograd graph.
    with torch.no_grad():
        outputs = model(input_tensor)
        _, predicted = torch.max(outputs, 1)

    # NOTE(review): presumably index 0 = closed, 1 = open, matching the class
    # indices used at training time; confirm against the training dataset.
    classes = ["DROWSY ALERT! 🚨 (Eyes Closed)", "NOT DROWSY ✅ (Eyes Open)"]
    result = classes[predicted.item()]

    # Return the prediction AND show the user the exact crop the model looked at
    return result, eye_pil
# 5. Build the UI: one face photo in, a status message and the eye crop out.
_face_input = gr.Image(label="Upload Full Face Photo")
_status_output = gr.Textbox(label="DDobj System Status")
_crop_output = gr.Image(label="What the AI saw (Eye Crop)")

interface = gr.Interface(
    fn=predict_drowsiness,
    inputs=_face_input,
    outputs=[_status_output, _crop_output],
    title="DDobj: Driver Drowsiness Detection",
    description="Upload a photo. The system will automatically locate the eyes, isolate them, and analyze them for fatigue.",
    theme="default",
)
def _main():
    """Launch the Gradio app when this file is run as a script."""
    interface.launch()


if __name__ == "__main__":
    _main()