Image Classification
ultralytics
yolo
drowsiness-detection
computer-vision
mosesb commited on
Commit
aa46824
·
verified ·
1 Parent(s): 9222b90

Upload folder using huggingface_hub

Browse files
Files changed (14) hide show
  1. .gitattributes +4 -0
  2. README.md +97 -0
  3. app.py +49 -0
  4. args.yaml +105 -0
  5. best.pt +3 -0
  6. last.pt +3 -0
  7. output.jpg +3 -0
  8. output_augmentation.jpg +3 -0
  9. output_confusion_matrix.png +3 -0
  10. output_grad_cam.jpg +0 -0
  11. results.csv +31 -0
  12. results.png +3 -0
  13. sample_1.jpg +0 -0
  14. sample_2.jpg +0 -0
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ output.jpg filter=lfs diff=lfs merge=lfs -text
37
+ output_augmentation.jpg filter=lfs diff=lfs merge=lfs -text
38
+ output_confusion_matrix.png filter=lfs diff=lfs merge=lfs -text
39
+ results.png filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ library_name: ultralytics
4
+ tags:
5
+ - image-classification
6
+ - yolo
7
+ - ultralytics
8
+ - drowsiness-detection
9
+ - computer-vision
10
+ widget:
11
+ - modelId: mosesb/drowsiness-detection-yolo-cls
12
+ title: Drowsiness Detection With YOLO CLS
13
+ url: https://huggingface.co/spaces/mosesb/drowsiness-detection-yolo-cls/resolve/main/output.jpg
14
+ datasets:
15
+ - ismailnasri20/driver-drowsiness-dataset-ddd
16
+ - yasharjebraeily/drowsy-detection-dataset
17
+ metrics:
18
+ - accuracy
19
+ - f1
20
+ ---
21
+
22
+ # YOLOv11 Model for Drowsiness Detection
23
+
24
+ This repository contains a YOLO classification model fine-tuned to detect driver drowsiness from images. The model classifies input images into two categories: `Drowsy` and `Non Drowsy` (Awake).
25
+
26
+ This model was trained using the `ultralytics` framework and demonstrates high performance on an unseen test set, making it a reliable tool for safety applications.
27
+
28
+ ## Model Details
29
+ * **Base Model:** `yolo11x-cls` (from the Ultralytics v8 ecosystem)
30
+ * **Fine-tuned on:** A combined dataset for driver drowsiness detection.
31
+ * **Classes:** `Drowsy`, `Non Drowsy`
32
+ * **Framework:** PyTorch, Ultralytics
33
+
34
+ ## How to Get Started
35
+
36
+ You can easily use this model with the `ultralytics` library.
37
+
38
+ ```python
39
+ # Install ultralytics
40
+ !pip install ultralytics
41
+
42
+ from ultralytics import YOLO
43
+
44
+ # Load the model from the Hugging Face Hub
45
+ model = YOLO('your-username/your-repo-name')
46
+
47
+ # Run inference on an image
48
+ image_path = 'path/to/your/image.jpg'
49
+ results = model.predict(image_path)
50
+
51
+ # Print the top prediction
52
+ probs = results[0].probs
53
+ top1_class_index = probs.top1
54
+ top1_confidence = probs.top1conf
55
+ class_name = model.names[top1_class_index]
56
+
57
+ print(f"Prediction: {class_name} with confidence {top1_confidence:.4f}")
58
+ ```
59
+
60
+ ## Training Procedure
61
+
62
+ The model was fine-tuned on a large dataset of driver images. The training process involved:
63
+ - **Data Augmentation:** Standard augmentations like random flips, color jitter (HSV), and scaling were applied.
64
+ - **Transfer Learning:** The model was initialized with weights pretrained on a large-scale dataset, enabling rapid convergence.
65
+
66
+ ### Key Hyperparameters
67
+ - **Image Size:** 224x224
68
+ - **Batch Size:** 185 (auto-tuned)
69
+ - **Optimizer:** SGD with momentum
70
+
71
+ ![Training Results](results.png)
72
+
73
+ ## Evaluation
74
+
75
+ The model was evaluated on a completely **unseen test set** to ensure a fair assessment of its generalization capabilities.
76
+
77
+ ### Key Performance Metrics
78
+ | Metric | Value | Description |
79
+ | :----: | :----: | :------------------------------------------------- |
80
+ | **Accuracy** | 99.80% | Overall correctness on the test set. |
81
+ | **APCER** | 0.00% | Rate of 'Drowsy' drivers missed (False Negatives). |
82
+ | **BPCER** | 0.41% | Rate of 'Non Drowsy' drivers flagged (False Positives). |
83
+ | **ACER** | 0.21% | Average of APCER and BPCER. |
84
+
85
+ *APCER (Attack Presentation Classification Error Rate) is the most critical safety metric.*
86
+
87
+ ![Confusion Matrix](output_confusion_matrix.png)
88
+
89
+ ### Model Explainability (Grad-CAM)
90
+ To ensure the model is focusing on relevant facial features, Grad-CAM was used. The heatmaps confirm that the model's predictions are primarily based on the eye and mouth regions, as expected.
91
+
92
+ ![Grad-CAM](output_grad_cam.jpg)
93
+
94
+ ## Intended Use and Limitations
95
+ This model is intended as a proof-of-concept for driver safety systems. It should not be used as the sole mechanism for preventing accidents. Real-world performance may vary based on lighting conditions, camera angles, occlusions (e.g., sunglasses), and individual differences.
96
+
97
+ *This model card is based on the training notebook [`yolov11_drowsiness.ipynb`](https://github.com/mosesab/YOLOV11-Drowsiness-Detection/blob/main/yolov11_drowsiness.ipynb).*
app.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from ultralytics import YOLO
3
+ import torch
4
+
5
+ # Load the fine-tuned YOLOv11x model
6
+ # The model will be in the same directory in the HF Space
7
+ model = YOLO('best.pt')
8
+
9
+ # Define the prediction function
10
+ def predict_drowsiness(image):
11
+ """
12
+ Takes a PIL image, runs inference, and returns a dictionary of class probabilities.
13
+ """
14
+ # Run prediction
15
+ results = model.predict(image, verbose=False)
16
+
17
+ # Get the class names from the model
18
+ names_dict = results[0].names
19
+
20
+ # Get the probabilities
21
+ probs = results[0].probs.data.cpu().numpy()
22
+
23
+ # Create a dictionary of {class_name: probability}
24
+ return {names_dict[i]: prob for i, prob in enumerate(probs)}
25
+
26
+ # --- Gradio Interface ---
27
+ # Define the title and description for the demo
28
+ title = "YOLOv11 Drowsiness Detection"
29
+ description = """
30
+ This demo showcases a fine-tuned YOLO classification model for detecting driver drowsiness.
31
+ Upload an image of a driver, and the model will predict whether the person is 'Drowsy' or 'Non Drowsy' (Awake).
32
+ This model was trained as detailed in the notebook below and achieves high accuracy on the test set.
33
+ Training Notebook Repo: https://github.com/mosesab/YOLOV11-Drowsiness-Detection/blob/main/yolov11_drowsiness.ipynb
34
+ """
35
+ article = "Driver fatigue is a major cause of accidents. This model analyzes facial images to predict the likelihood of drowsiness in real time."
36
+
37
+ # Create the Gradio interface
38
+ iface = gr.Interface(
39
+ fn=predict_drowsiness,
40
+ inputs=gr.Image(type="pil", label="Upload Driver Image"),
41
+ outputs=gr.Label(num_top_classes=2, label="Prediction"),
42
+ title=title,
43
+ description=description,
44
+ article=article,
45
+ examples=[ "sample_1.jpg", "sample_2.jpg" ]
46
+ )
47
+
48
+ # Launch the app
49
+ iface.launch()
args.yaml ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ task: classify
2
+ mode: train
3
+ model: yolo11x-cls.pt
4
+ data: /workspace/dataset
5
+ epochs: 300
6
+ time: null
7
+ patience: 20
8
+ batch: -1
9
+ imgsz: 224
10
+ save: true
11
+ save_period: -1
12
+ cache: true
13
+ device: null
14
+ workers: 8
15
+ project: drowsiness_training
16
+ name: yolo_cls_run
17
+ exist_ok: true
18
+ pretrained: true
19
+ optimizer: auto
20
+ verbose: true
21
+ seed: 0
22
+ deterministic: true
23
+ single_cls: false
24
+ rect: false
25
+ cos_lr: false
26
+ close_mosaic: 10
27
+ resume: false
28
+ amp: true
29
+ fraction: 1.0
30
+ profile: false
31
+ freeze: null
32
+ multi_scale: false
33
+ overlap_mask: true
34
+ mask_ratio: 4
35
+ dropout: 0.0
36
+ val: true
37
+ split: val
38
+ save_json: false
39
+ conf: null
40
+ iou: 0.7
41
+ max_det: 300
42
+ half: false
43
+ dnn: false
44
+ plots: true
45
+ source: null
46
+ vid_stride: 1
47
+ stream_buffer: false
48
+ visualize: false
49
+ augment: false
50
+ agnostic_nms: false
51
+ classes: null
52
+ retina_masks: false
53
+ embed: null
54
+ show: false
55
+ save_frames: false
56
+ save_txt: false
57
+ save_conf: false
58
+ save_crop: false
59
+ show_labels: true
60
+ show_conf: true
61
+ show_boxes: true
62
+ line_width: null
63
+ format: torchscript
64
+ keras: false
65
+ optimize: false
66
+ int8: false
67
+ dynamic: false
68
+ simplify: true
69
+ opset: null
70
+ workspace: null
71
+ nms: false
72
+ lr0: 0.01
73
+ lrf: 0.01
74
+ momentum: 0.937
75
+ weight_decay: 0.0005
76
+ warmup_epochs: 3.0
77
+ warmup_momentum: 0.8
78
+ warmup_bias_lr: 0.1
79
+ box: 7.5
80
+ cls: 0.5
81
+ dfl: 1.5
82
+ pose: 12.0
83
+ kobj: 1.0
84
+ nbs: 64
85
+ hsv_h: 0.015
86
+ hsv_s: 0.7
87
+ hsv_v: 0.4
88
+ degrees: 0.0
89
+ translate: 0.1
90
+ scale: 0.5
91
+ shear: 0.0
92
+ perspective: 0.0
93
+ flipud: 0.0
94
+ fliplr: 0.5
95
+ bgr: 0.0
96
+ mosaic: 1.0
97
+ mixup: 0.0
98
+ cutmix: 0.0
99
+ copy_paste: 0.0
100
+ copy_paste_mode: flip
101
+ auto_augment: randaugment
102
+ erasing: 0.4
103
+ cfg: null
104
+ tracker: botsort.yaml
105
+ save_dir: drowsiness_training/yolo_cls_run
best.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4fed29ef7a18c011b440ce470554b2eb6c54fbb7c417b7ebf8245896f1a45032
3
+ size 57006897
last.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f844b288a66917e1e12d9ae7773a8f3dfcf168ee1c2daf458b816a47240a3fab
3
+ size 57008241
output.jpg ADDED

Git LFS Details

  • SHA256: 7b9b8427f0b0ef025fc7ecfece35a9c93ae642e110a3590d5fd152d4825fab68
  • Pointer size: 131 Bytes
  • Size of remote file: 112 kB
output_augmentation.jpg ADDED

Git LFS Details

  • SHA256: aed79762df50a5df4f1ac1bea87d1b1133f917a57562cf8ae137243b1a2e59dc
  • Pointer size: 131 Bytes
  • Size of remote file: 101 kB
output_confusion_matrix.png ADDED

Git LFS Details

  • SHA256: fcf8ff481f41218f486ed8036476c037bad72fbbcc5c6df8cac698a290e419ce
  • Pointer size: 131 Bytes
  • Size of remote file: 124 kB
output_grad_cam.jpg ADDED
results.csv ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ epoch,time,train/loss,metrics/accuracy_top1,metrics/accuracy_top5,val/loss,lr/pg0,lr/pg1,lr/pg2
2
+ 1,36.4885,0.14788,0.98671,1,0.03707,0.00332807,0.00332807,0.00332807
3
+ 1,78.2575,0.4195,0.97034,1,0.09415,0.00331811,0.00331811,0.00331811
4
+ 2,154.248,0.07064,0.99314,1,0.02335,0.0066295,0.0066295,0.0066295
5
+ 3,229.881,0.02835,0.99678,1,0.00745,0.00991888,0.00991888,0.00991888
6
+ 4,305.269,0.01964,0.99846,1,0.00526,0.009901,0.009901,0.009901
7
+ 5,380.811,0.01352,0.99874,1,0.00512,0.009868,0.009868,0.009868
8
+ 6,456.478,0.01245,0.99832,1,0.00642,0.009835,0.009835,0.009835
9
+ 7,531.746,0.00972,0.9993,1,0.00297,0.009802,0.009802,0.009802
10
+ 8,607.184,0.01014,0.99888,1,0.00312,0.009769,0.009769,0.009769
11
+ 9,682.386,0.00689,0.99958,1,0.00191,0.009736,0.009736,0.009736
12
+ 10,757.863,0.00809,0.9986,1,0.00412,0.009703,0.009703,0.009703
13
+ 11,833.026,0.00768,0.99846,1,0.00442,0.00967,0.00967,0.00967
14
+ 12,908.255,0.00767,0.99888,1,0.00294,0.009637,0.009637,0.009637
15
+ 13,983.448,0.00702,0.9979,1,0.00534,0.009604,0.009604,0.009604
16
+ 14,1058.57,0.0091,0.99818,1,0.00499,0.009571,0.009571,0.009571
17
+ 15,1133.77,0.00801,0.99902,1,0.00208,0.009538,0.009538,0.009538
18
+ 16,1208.97,0.00812,0.99622,1,0.01176,0.009505,0.009505,0.009505
19
+ 17,1284.23,0.01047,0.99818,1,0.0056,0.009472,0.009472,0.009472
20
+ 18,1359.47,0.01251,0.9986,1,0.0035,0.009439,0.009439,0.009439
21
+ 19,1434.67,0.01087,0.99874,1,0.00518,0.009406,0.009406,0.009406
22
+ 20,1509.67,0.01463,0.99482,1,0.01421,0.009373,0.009373,0.009373
23
+ 21,1584.82,0.01296,0.99818,1,0.00519,0.00934,0.00934,0.00934
24
+ 22,1659.91,0.0123,0.99902,1,0.0033,0.009307,0.009307,0.009307
25
+ 23,1735.13,0.0137,0.99804,1,0.00708,0.009274,0.009274,0.009274
26
+ 24,1810.33,0.01558,0.9986,1,0.00429,0.009241,0.009241,0.009241
27
+ 25,1885.4,0.0152,0.9986,1,0.00395,0.009208,0.009208,0.009208
28
+ 26,1961.98,0.02328,0.99944,1,0.00325,0.009175,0.009175,0.009175
29
+ 27,2040.69,0.01656,0.99902,1,0.00282,0.009142,0.009142,0.009142
30
+ 28,2115.7,0.01849,0.99888,1,0.00395,0.009109,0.009109,0.009109
31
+ 29,2190.86,0.02094,0.99804,1,0.00647,0.009076,0.009076,0.009076
results.png ADDED

Git LFS Details

  • SHA256: 2e5c685506ce9c470c1673f21dd8d113fc9d0b4ffceb26cd665b977d2bdb2258
  • Pointer size: 131 Bytes
  • Size of remote file: 115 kB
sample_1.jpg ADDED
sample_2.jpg ADDED