Tian Wang
committed on
Commit 9843ba6 · 1 Parent(s): dadb116
Add detector and classifier models with training artifacts
- .gitattributes +2 -0
- README.md +78 -0
- classifier/classifier_best.pt +3 -0
- classifier/training_results.json +13 -0
- detector/args.yaml +109 -0
- detector/plots/BoxF1_curve.png +3 -0
- detector/plots/BoxPR_curve.png +3 -0
- detector/plots/BoxP_curve.png +3 -0
- detector/plots/BoxR_curve.png +3 -0
- detector/plots/confusion_matrix.png +3 -0
- detector/plots/confusion_matrix_normalized.png +3 -0
- detector/plots/results.png +3 -0
- detector/results.csv +11 -0
- detector/weights/best.onnx +3 -0
- detector/weights/best.pt +3 -0
- example-detection.png +3 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
+*.jpg filter=lfs diff=lfs merge=lfs -text
README.md
ADDED
@@ -0,0 +1,78 @@
+---
+license: mit
+tags:
+- object-detection
+- image-classification
+- yolo
+- set-game
+- card-game
+- computer-vision
+---
+
+# Set Solver Models
+
+Trained models for the [Set card game](https://www.setgame.com/) solver.
+
+
+
+**Live demo**: [huggingface.co/spaces/wangtianthu/set-solver](https://huggingface.co/spaces/wangtianthu/set-solver)
+
+## Models
+
+### Detector — YOLOv11n
+
+Detects individual Set cards on a board image.
+
+| Metric | Value |
+|--------|-------|
+| mAP50 | 99.5% |
+| mAP50-95 | 97.4% |
+| Architecture | YOLOv11n |
+| Input size | 640x640 |
+| Epochs | 10 |
+| Training data | 4000 synthetic board images |
+
+**Files**: `detector/weights/best.pt` (PyTorch), `detector/weights/best.onnx` (ONNX)
+
+### Classifier — MobileNetV3
+
+Classifies each card's 4 attributes: shape, color, number, and fill.
+
+| Metric | Value |
+|--------|-------|
+| Overall accuracy | 99.9% |
+| Number accuracy | 100% |
+| Color accuracy | 100% |
+| Shape accuracy | 99.9% |
+| Fill accuracy | 99.8% |
+| Architecture | MobileNetV3-Small |
+| Input size | 224x224 |
+| Training data | ~9500 cropped card images (81 classes) |
+
+**File**: `classifier/classifier_best.pt`
+
+## Usage
+
+```python
+from ultralytics import YOLO
+from PIL import Image
+
+# Load detector
+detector = YOLO("detector/weights/best.pt")
+results = detector("board_photo.jpg", conf=0.25)
+
+# Load classifier
+import torch
+from src.train.classifier import SetCardClassifier
+
+classifier = SetCardClassifier(pretrained=False)
+checkpoint = torch.load("classifier/classifier_best.pt", map_location="cpu")
+classifier.load_state_dict(checkpoint["model_state_dict"])
+classifier.eval()
+```
+
+## Training
+
+Both models were trained on synthetic data generated by a custom board generator that produces realistic Set game layouts with varied backgrounds, perspective transforms, and noise objects.
+
+Source code: [github.com/wangtian24/set-solver](https://github.com/wangtian24/set-solver)
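The README describes classifying each detected card by four attributes. As a rough illustration of how a solver might use those labels, the sketch below checks the standard Set rule (for every attribute, the three cards are either all the same or all different) and enumerates the valid sets on a board. The `Card` layout, the attribute value strings, and the `find_sets` helper are hypothetical and are not taken from the repository.

```python
from itertools import combinations
from typing import NamedTuple


class Card(NamedTuple):
    # Hypothetical container for the four attributes the classifier predicts.
    shape: str
    color: str
    number: int
    fill: str


def is_set(a: Card, b: Card, c: Card) -> bool:
    """Return True if, for every attribute, the values are all equal or all distinct."""
    # Exactly two distinct values means two cards agree and one differs, which breaks the rule.
    return all(len({x, y, z}) != 2 for x, y, z in zip(a, b, c))


def find_sets(cards: list[Card]) -> list[tuple[Card, Card, Card]]:
    """Brute-force every 3-card combination; cheap for a board of a dozen or so cards."""
    return [trio for trio in combinations(cards, 3) if is_set(*trio)]


board = [
    Card("oval", "red", 1, "solid"),
    Card("oval", "green", 2, "solid"),
    Card("oval", "purple", 3, "solid"),
    Card("diamond", "red", 1, "striped"),
]
sets_found = find_sets(board)
print(len(sets_found))
```

Brute force is used here because a board holds few cards, so the number of 3-card combinations stays small.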
classifier/classifier_best.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a0c464367eccfcfd6599377c9af35f72cd23c524b01eda7e9a11ccb1e3ba3f6d
+size 11465795
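The three-line file above is a Git LFS pointer rather than the model weights themselves. A minimal sketch of reading such a pointer into its fields follows; `parse_lfs_pointer` is a hypothetical helper name, and the pointer text is copied from the diff above.

```python
def parse_lfs_pointer(text: str) -> dict:
    """Split each 'key value' line of a Git LFS pointer file into named fields."""
    fields = {}
    for line in text.strip().splitlines():
        key, _, value = line.partition(" ")
        fields[key] = value
    # The oid is written as '<hash-algorithm>:<hex digest>'.
    algo, _, digest = fields["oid"].partition(":")
    return {
        "version": fields["version"],
        "hash_algo": algo,
        "digest": digest,
        "size_bytes": int(fields["size"]),
    }


pointer_text = """\
version https://git-lfs.github.com/spec/v1
oid sha256:a0c464367eccfcfd6599377c9af35f72cd23c524b01eda7e9a11ccb1e3ba3f6d
size 11465795
"""
pointer = parse_lfs_pointer(pointer_text)
print(pointer["hash_algo"], pointer["size_bytes"])
```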
classifier/training_results.json
ADDED
@@ -0,0 +1,13 @@
+{
+"test_loss": 0.009692762911799945,
+"test_accuracy": {
+"number": 1.0,
+"color": 1.0,
+"shape": 0.9992082343626286,
+"fill": 0.997624703087886
+},
+"avg_test_accuracy": 0.9992082343626286,
+"train_size": 9479,
+"val_size": 1895,
+"test_size": 1263
+}
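The reported numbers are consistent with `avg_test_accuracy` being the plain mean of the four per-attribute accuracies. The short check below recomputes that mean and converts each accuracy into an approximate count of misclassified test cards. Reading the average as a simple mean is an inference from the values, not something the file states.

```python
import math

# Values copied from classifier/training_results.json above.
test_accuracy = {
    "number": 1.0,
    "color": 1.0,
    "shape": 0.9992082343626286,
    "fill": 0.997624703087886,
}
avg_test_accuracy = 0.9992082343626286
test_size = 1263

# Mean of the four per-attribute accuracies.
mean_accuracy = sum(test_accuracy.values()) / len(test_accuracy)

# Approximate number of wrong predictions per attribute on the test split.
errors = {name: round((1.0 - acc) * test_size) for name, acc in test_accuracy.items()}

print(math.isclose(mean_accuracy, avg_test_accuracy, rel_tol=1e-12))
print(errors)
```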
detector/args.yaml
ADDED
@@ -0,0 +1,109 @@
+task: detect
+mode: train
+model: yolo11n.pt
+data: /Users/wangtian/workspace/set-solver/data/synthetic/dataset.yaml
+epochs: 10
+time: null
+patience: 20
+batch: 16
+imgsz: 640
+save: true
+save_period: -1
+cache: false
+device: mps
+workers: 8
+project: /Users/wangtian/workspace/set-solver/weights
+name: detector
+exist_ok: true
+pretrained: true
+optimizer: auto
+verbose: true
+seed: 0
+deterministic: true
+single_cls: false
+rect: false
+cos_lr: false
+close_mosaic: 10
+resume: false
+amp: true
+fraction: 1.0
+profile: false
+freeze: null
+multi_scale: 0.0
+compile: false
+overlap_mask: true
+mask_ratio: 4
+dropout: 0.0
+val: true
+split: val
+save_json: false
+conf: null
+iou: 0.7
+max_det: 300
+half: false
+dnn: false
+plots: true
+end2end: null
+source: null
+vid_stride: 1
+stream_buffer: false
+visualize: false
+augment: false
+agnostic_nms: false
+classes: null
+retina_masks: false
+embed: null
+show: false
+save_frames: false
+save_txt: false
+save_conf: false
+save_crop: false
+show_labels: true
+show_conf: true
+show_boxes: true
+line_width: null
+format: torchscript
+keras: false
+optimize: false
+int8: false
+dynamic: false
+simplify: true
+opset: null
+workspace: null
+nms: false
+lr0: 0.01
+lrf: 0.01
+momentum: 0.937
+weight_decay: 0.0005
+warmup_epochs: 3.0
+warmup_momentum: 0.8
+warmup_bias_lr: 0.1
+box: 7.5
+cls: 0.5
+dfl: 1.5
+pose: 12.0
+kobj: 1.0
+rle: 1.0
+angle: 1.0
+nbs: 64
+hsv_h: 0.015
+hsv_s: 0.7
+hsv_v: 0.4
+degrees: 0.0
+translate: 0.1
+scale: 0.5
+shear: 0.0
+perspective: 0.0
+flipud: 0.0
+fliplr: 0.5
+bgr: 0.0
+mosaic: 1.0
+mixup: 0.0
+cutmix: 0.0
+copy_paste: 0.0
+copy_paste_mode: flip
+auto_augment: randaugment
+erasing: 0.4
+cfg: null
+tracker: botsort.yaml
+save_dir: /Users/wangtian/workspace/set-solver/weights/detector
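With `cos_lr: false`, the learning rate is expected to decay linearly from a base value down to the fraction `lrf` of it. The sketch below evaluates such a schedule for this 10-epoch run. Several things here are assumptions: the base rate of 0.002 is inferred from the `lr/pg*` columns of `detector/results.csv` (plausibly the value picked by `optimizer: auto`, rather than the `lr0: 0.01` listed in this file), the decay formula is assumed to match the Ultralytics trainer and should be checked against the installed version, and the first three epochs are skipped because warmup alters them.

```python
EPOCHS = 10      # epochs in args.yaml
LRF = 0.01       # lrf in args.yaml: final LR as a fraction of the base LR
BASE_LR = 0.002  # assumption: inferred from results.csv, not stated in args.yaml


def linear_lr(epoch_index: int) -> float:
    """Learning rate at a zero-based epoch index under linear decay."""
    factor = max(1 - epoch_index / EPOCHS, 0) * (1.0 - LRF) + LRF
    return BASE_LR * factor


# Epochs 4 to 10 (one-based), i.e. after the 3 warmup epochs.
schedule = {epoch: round(linear_lr(epoch - 1), 6) for epoch in range(4, EPOCHS + 1)}
print(schedule)
```

Under these assumptions the computed values agree with the `lr/pg0` column logged for epochs 4 through 10 in `detector/results.csv`.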
detector/plots/BoxF1_curve.png
ADDED (Git LFS)
detector/plots/BoxPR_curve.png
ADDED (Git LFS)
detector/plots/BoxP_curve.png
ADDED (Git LFS)
detector/plots/BoxR_curve.png
ADDED (Git LFS)
detector/plots/confusion_matrix.png
ADDED (Git LFS)
detector/plots/confusion_matrix_normalized.png
ADDED (Git LFS)
detector/plots/results.png
ADDED (Git LFS)
detector/results.csv
ADDED
@@ -0,0 +1,11 @@
+epoch,time,train/box_loss,train/cls_loss,train/dfl_loss,metrics/precision(B),metrics/recall(B),metrics/mAP50(B),metrics/mAP50-95(B),val/box_loss,val/cls_loss,val/dfl_loss,lr/pg0,lr/pg1,lr/pg2
+1,111.468,0.40144,1.19864,0.84562,0.99981,0.99984,0.995,0.9085,0.40196,1.20815,0.82106,0.000656085,0.000656085,0.000656085
+2,206.25,0.34115,0.4089,0.82012,1,1,0.995,0.95523,0.29123,0.4463,0.79571,0.0011918,0.0011918,0.0011918
+3,304.901,0.32374,0.34916,0.8127,0.99999,1,0.995,0.93135,0.35031,0.37025,0.79607,0.00159551,0.00159551,0.00159551
+4,395.233,0.31295,0.32333,0.81186,0.99997,1,0.995,0.96262,0.26761,0.2805,0.79735,0.001406,0.001406,0.001406
+5,492.083,0.26809,0.29046,0.80266,0.99989,1,0.995,0.97475,0.24272,0.25697,0.7803,0.001208,0.001208,0.001208
+6,606.75,0.23878,0.24702,0.79441,0.99999,1,0.995,0.98544,0.21188,0.22382,0.77178,0.00101,0.00101,0.00101
+7,741.542,0.22471,0.22465,0.79115,0.99999,0.99513,0.995,0.97671,0.21286,0.2033,0.77386,0.000812,0.000812,0.000812
+8,873.213,0.2143,0.21188,0.78911,0.99999,0.98882,0.98531,0.97465,0.17893,0.19046,0.76739,0.000614,0.000614,0.000614
+9,1041.5,0.19247,0.19322,0.78823,0.99999,0.97096,0.97537,0.9692,0.14963,0.16733,0.76335,0.000416,0.000416,0.000416
+10,1205.28,0.17649,0.17659,0.78384,0.99999,0.99914,0.995,0.99123,0.13999,0.14978,0.76308,0.000218,0.000218,0.000218
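To compare epochs, the log can be read with the standard `csv` module. The sketch below embeds a subset of the columns from the rows above and picks the epoch with the highest `metrics/mAP50-95(B)`. Ranking by that single column is a simplification chosen for illustration and may differ from how the trainer selects `best.pt`.

```python
import csv
import io

# Subset of columns copied from detector/results.csv above.
LOG = """\
epoch,metrics/recall(B),metrics/mAP50(B),metrics/mAP50-95(B)
1,0.99984,0.995,0.9085
2,1,0.995,0.95523
3,1,0.995,0.93135
4,1,0.995,0.96262
5,1,0.995,0.97475
6,1,0.995,0.98544
7,0.99513,0.995,0.97671
8,0.98882,0.98531,0.97465
9,0.97096,0.97537,0.9692
10,0.99914,0.995,0.99123
"""

rows = list(csv.DictReader(io.StringIO(LOG)))

# Pick the row whose strict-IoU metric is highest.
best = max(rows, key=lambda r: float(r["metrics/mAP50-95(B)"]))
best_epoch = int(best["epoch"])
best_map = float(best["metrics/mAP50-95(B)"])
print(f"epoch {best_epoch}: mAP50-95 = {best_map:.3f}")
```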
detector/weights/best.onnx
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ccd020c044339d0cd2e671a837f209a8adcdcef17982cbae90da2b962e160081
+size 10477995
detector/weights/best.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d65deae13124271df8739b700d2f893bca1eb7a7bc8ac870702e714b787ceee7
+size 5453594
example-detection.png
ADDED (Git LFS)