Tian Wang commited on
Commit
9843ba6
·
1 Parent(s): dadb116

Add detector and classifier models with training artifacts

Browse files
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.png filter=lfs diff=lfs merge=lfs -text
37
+ *.jpg filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ tags:
4
+ - object-detection
5
+ - image-classification
6
+ - yolo
7
+ - set-game
8
+ - card-game
9
+ - computer-vision
10
+ ---
11
+
12
+ # Set Solver Models
13
+
14
+ Trained models for the [Set card game](https://www.setgame.com/) solver.
15
+
16
+ ![Example detection](./example-detection.png)
17
+
18
+ **Live demo**: [huggingface.co/spaces/wangtianthu/set-solver](https://huggingface.co/spaces/wangtianthu/set-solver)
19
+
20
+ ## Models
21
+
22
+ ### Detector — YOLOv11n
23
+
24
+ Detects individual Set cards on a board image.
25
+
26
+ | Metric | Value |
27
+ |--------|-------|
28
+ | mAP50 | 99.5% |
29
+ | mAP50-95 | 97.4% |
30
+ | Architecture | YOLOv11n |
31
+ | Input size | 640x640 |
32
+ | Epochs | 10 |
33
+ | Training data | 4000 synthetic board images |
34
+
35
+ **Files**: `detector/weights/best.pt` (PyTorch), `detector/weights/best.onnx` (ONNX)
36
+
37
+ ### Classifier — MobileNetV3
38
+
39
+ Classifies each card's 4 attributes: shape, color, number, and fill.
40
+
41
+ | Metric | Value |
42
+ |--------|-------|
43
+ | Overall accuracy | 99.9% |
44
+ | Number accuracy | 100% |
45
+ | Color accuracy | 100% |
46
+ | Shape accuracy | 99.9% |
47
+ | Fill accuracy | 99.8% |
48
+ | Architecture | MobileNetV3-Small |
49
+ | Input size | 224x224 |
50
+ | Training data | ~9500 cropped card images (81 classes) |
51
+
52
+ **File**: `classifier/classifier_best.pt`
53
+
54
+ ## Usage
55
+
56
+ ```python
57
+ from ultralytics import YOLO
58
+ from PIL import Image
59
+
60
+ # Load detector
61
+ detector = YOLO("detector/weights/best.pt")
62
+ results = detector("board_photo.jpg", conf=0.25)
63
+
64
+ # Load classifier
65
+ import torch
66
+ from src.train.classifier import SetCardClassifier
67
+
68
+ classifier = SetCardClassifier(pretrained=False)
69
+ checkpoint = torch.load("classifier/classifier_best.pt", map_location="cpu")
70
+ classifier.load_state_dict(checkpoint["model_state_dict"])
71
+ classifier.eval()
72
+ ```
73
+
74
+ ## Training
75
+
76
+ Both models were trained on synthetic data generated by a custom board generator that produces realistic Set game layouts with varied backgrounds, perspective transforms, and noise objects.
77
+
78
+ Source code: [github.com/wangtian24/set-solver](https://github.com/wangtian24/set-solver)
classifier/classifier_best.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a0c464367eccfcfd6599377c9af35f72cd23c524b01eda7e9a11ccb1e3ba3f6d
3
+ size 11465795
classifier/training_results.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "test_loss": 0.009692762911799945,
3
+ "test_accuracy": {
4
+ "number": 1.0,
5
+ "color": 1.0,
6
+ "shape": 0.9992082343626286,
7
+ "fill": 0.997624703087886
8
+ },
9
+ "avg_test_accuracy": 0.9992082343626286,
10
+ "train_size": 9479,
11
+ "val_size": 1895,
12
+ "test_size": 1263
13
+ }
detector/args.yaml ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ task: detect
2
+ mode: train
3
+ model: yolo11n.pt
4
+ data: /Users/wangtian/workspace/set-solver/data/synthetic/dataset.yaml
5
+ epochs: 10
6
+ time: null
7
+ patience: 20
8
+ batch: 16
9
+ imgsz: 640
10
+ save: true
11
+ save_period: -1
12
+ cache: false
13
+ device: mps
14
+ workers: 8
15
+ project: /Users/wangtian/workspace/set-solver/weights
16
+ name: detector
17
+ exist_ok: true
18
+ pretrained: true
19
+ optimizer: auto
20
+ verbose: true
21
+ seed: 0
22
+ deterministic: true
23
+ single_cls: false
24
+ rect: false
25
+ cos_lr: false
26
+ close_mosaic: 10
27
+ resume: false
28
+ amp: true
29
+ fraction: 1.0
30
+ profile: false
31
+ freeze: null
32
+ multi_scale: 0.0
33
+ compile: false
34
+ overlap_mask: true
35
+ mask_ratio: 4
36
+ dropout: 0.0
37
+ val: true
38
+ split: val
39
+ save_json: false
40
+ conf: null
41
+ iou: 0.7
42
+ max_det: 300
43
+ half: false
44
+ dnn: false
45
+ plots: true
46
+ end2end: null
47
+ source: null
48
+ vid_stride: 1
49
+ stream_buffer: false
50
+ visualize: false
51
+ augment: false
52
+ agnostic_nms: false
53
+ classes: null
54
+ retina_masks: false
55
+ embed: null
56
+ show: false
57
+ save_frames: false
58
+ save_txt: false
59
+ save_conf: false
60
+ save_crop: false
61
+ show_labels: true
62
+ show_conf: true
63
+ show_boxes: true
64
+ line_width: null
65
+ format: torchscript
66
+ keras: false
67
+ optimize: false
68
+ int8: false
69
+ dynamic: false
70
+ simplify: true
71
+ opset: null
72
+ workspace: null
73
+ nms: false
74
+ lr0: 0.01
75
+ lrf: 0.01
76
+ momentum: 0.937
77
+ weight_decay: 0.0005
78
+ warmup_epochs: 3.0
79
+ warmup_momentum: 0.8
80
+ warmup_bias_lr: 0.1
81
+ box: 7.5
82
+ cls: 0.5
83
+ dfl: 1.5
84
+ pose: 12.0
85
+ kobj: 1.0
86
+ rle: 1.0
87
+ angle: 1.0
88
+ nbs: 64
89
+ hsv_h: 0.015
90
+ hsv_s: 0.7
91
+ hsv_v: 0.4
92
+ degrees: 0.0
93
+ translate: 0.1
94
+ scale: 0.5
95
+ shear: 0.0
96
+ perspective: 0.0
97
+ flipud: 0.0
98
+ fliplr: 0.5
99
+ bgr: 0.0
100
+ mosaic: 1.0
101
+ mixup: 0.0
102
+ cutmix: 0.0
103
+ copy_paste: 0.0
104
+ copy_paste_mode: flip
105
+ auto_augment: randaugment
106
+ erasing: 0.4
107
+ cfg: null
108
+ tracker: botsort.yaml
109
+ save_dir: /Users/wangtian/workspace/set-solver/weights/detector
detector/plots/BoxF1_curve.png ADDED

Git LFS Details

  • SHA256: 22ac76b663bfbc5c70c62254d0838e9551680613f046c10d9a3905712caa78a3
  • Pointer size: 130 Bytes
  • Size of remote file: 86.4 kB
detector/plots/BoxPR_curve.png ADDED

Git LFS Details

  • SHA256: df0dc0b651b5b88c715752280eaead0c02aad42774f657354029dd78d305161e
  • Pointer size: 130 Bytes
  • Size of remote file: 68.3 kB
detector/plots/BoxP_curve.png ADDED

Git LFS Details

  • SHA256: 7f3592367666105d034a2ce12f0ea69a7d1d3e8294fd4b322a2f58eabc30fef8
  • Pointer size: 130 Bytes
  • Size of remote file: 76.3 kB
detector/plots/BoxR_curve.png ADDED

Git LFS Details

  • SHA256: 4f631fabe01c94ae37e1f06bce7c88ce9b36f640e9d62a5123cd0bc56abee686
  • Pointer size: 130 Bytes
  • Size of remote file: 80.1 kB
detector/plots/confusion_matrix.png ADDED

Git LFS Details

  • SHA256: 8b72dfb16b3589e5f98b21b849bbe6eb7afda4d01aa4ef7dde0f2e661cee2a22
  • Pointer size: 130 Bytes
  • Size of remote file: 90.9 kB
detector/plots/confusion_matrix_normalized.png ADDED

Git LFS Details

  • SHA256: fb754284bd2eb4bbbaa38f906767e6f5c5bf17d40f4bb796162b9c0b982cb5b0
  • Pointer size: 130 Bytes
  • Size of remote file: 85 kB
detector/plots/results.png ADDED

Git LFS Details

  • SHA256: 5eaba2b057aa7791c33b23a93e9d82630c4dfee06eff0ee8e81154e482e6b058
  • Pointer size: 131 Bytes
  • Size of remote file: 273 kB
detector/results.csv ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ epoch,time,train/box_loss,train/cls_loss,train/dfl_loss,metrics/precision(B),metrics/recall(B),metrics/mAP50(B),metrics/mAP50-95(B),val/box_loss,val/cls_loss,val/dfl_loss,lr/pg0,lr/pg1,lr/pg2
2
+ 1,111.468,0.40144,1.19864,0.84562,0.99981,0.99984,0.995,0.9085,0.40196,1.20815,0.82106,0.000656085,0.000656085,0.000656085
3
+ 2,206.25,0.34115,0.4089,0.82012,1,1,0.995,0.95523,0.29123,0.4463,0.79571,0.0011918,0.0011918,0.0011918
4
+ 3,304.901,0.32374,0.34916,0.8127,0.99999,1,0.995,0.93135,0.35031,0.37025,0.79607,0.00159551,0.00159551,0.00159551
5
+ 4,395.233,0.31295,0.32333,0.81186,0.99997,1,0.995,0.96262,0.26761,0.2805,0.79735,0.001406,0.001406,0.001406
6
+ 5,492.083,0.26809,0.29046,0.80266,0.99989,1,0.995,0.97475,0.24272,0.25697,0.7803,0.001208,0.001208,0.001208
7
+ 6,606.75,0.23878,0.24702,0.79441,0.99999,1,0.995,0.98544,0.21188,0.22382,0.77178,0.00101,0.00101,0.00101
8
+ 7,741.542,0.22471,0.22465,0.79115,0.99999,0.99513,0.995,0.97671,0.21286,0.2033,0.77386,0.000812,0.000812,0.000812
9
+ 8,873.213,0.2143,0.21188,0.78911,0.99999,0.98882,0.98531,0.97465,0.17893,0.19046,0.76739,0.000614,0.000614,0.000614
10
+ 9,1041.5,0.19247,0.19322,0.78823,0.99999,0.97096,0.97537,0.9692,0.14963,0.16733,0.76335,0.000416,0.000416,0.000416
11
+ 10,1205.28,0.17649,0.17659,0.78384,0.99999,0.99914,0.995,0.99123,0.13999,0.14978,0.76308,0.000218,0.000218,0.000218
detector/weights/best.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ccd020c044339d0cd2e671a837f209a8adcdcef17982cbae90da2b962e160081
3
+ size 10477995
detector/weights/best.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d65deae13124271df8739b700d2f893bca1eb7a7bc8ac870702e714b787ceee7
3
+ size 5453594
example-detection.png ADDED

Git LFS Details

  • SHA256: c6bd78d069daceb3c9d0990163d75bdd0942a08c841196aa3bf8cf605652cdcd
  • Pointer size: 132 Bytes
  • Size of remote file: 2.8 MB