Spaces:
Running
Running
Tian Wang
committed on
Commit
·
8a34385
1
Parent(s):
c3f6e96
Deploy Set Solver web app
Browse files- .dockerignore +18 -0
- Dockerfile +21 -0
- README.md +9 -4
- requirements-web.txt +11 -0
- src/__init__.py +1 -0
- src/inference/__init__.py +1 -0
- src/inference/classify.py +103 -0
- src/inference/solve.py +424 -0
- src/solver/__init__.py +11 -0
- src/solver/set_finder.py +173 -0
- src/train/__init__.py +1 -0
- src/train/classifier.py +361 -0
- src/web/__init__.py +0 -0
- src/web/app.py +113 -0
- src/web/templates/index.html +383 -0
- weights/classifier_best.pt +3 -0
- weights/detector/weights/best.pt +3 -0
.dockerignore
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.venv/
|
| 2 |
+
.git/
|
| 3 |
+
__pycache__/
|
| 4 |
+
*.pyc
|
| 5 |
+
data/
|
| 6 |
+
training_images/
|
| 7 |
+
docs/
|
| 8 |
+
scripts/
|
| 9 |
+
*.ipynb
|
| 10 |
+
.DS_Store
|
| 11 |
+
showcase.html
|
| 12 |
+
|
| 13 |
+
# Training artifacts in weights/detector (keep only weights/*.pt)
|
| 14 |
+
weights/detector/*.png
|
| 15 |
+
weights/detector/*.jpg
|
| 16 |
+
weights/detector/*.csv
|
| 17 |
+
weights/detector/weights/last.pt
|
| 18 |
+
weights/detector/weights/best.onnx
|
Dockerfile
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.11-slim
|
| 2 |
+
|
| 3 |
+
WORKDIR /app
|
| 4 |
+
|
| 5 |
+
# Install system deps for opencv
|
| 6 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 7 |
+
libgl1 libglib2.0-0 \
|
| 8 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 9 |
+
|
| 10 |
+
# Install Python deps (CPU-only torch)
|
| 11 |
+
COPY requirements-web.txt .
|
| 12 |
+
RUN pip install --no-cache-dir -r requirements-web.txt
|
| 13 |
+
|
| 14 |
+
# Copy application code and weights
|
| 15 |
+
COPY src/ src/
|
| 16 |
+
COPY weights/ weights/
|
| 17 |
+
|
| 18 |
+
# Hugging Face Spaces uses port 7860
|
| 19 |
+
EXPOSE 7860
|
| 20 |
+
|
| 21 |
+
CMD ["uvicorn", "src.web.app:app", "--host", "0.0.0.0", "--port", "7860"]
|
README.md
CHANGED
|
@@ -1,10 +1,15 @@
|
|
| 1 |
---
|
| 2 |
title: Set Solver
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: docker
|
|
|
|
| 7 |
pinned: false
|
| 8 |
---
|
| 9 |
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
title: Set Solver
|
| 3 |
+
emoji: 🃏
|
| 4 |
+
colorFrom: green
|
| 5 |
+
colorTo: blue
|
| 6 |
sdk: docker
|
| 7 |
+
app_port: 7860
|
| 8 |
pinned: false
|
| 9 |
---
|
| 10 |
|
| 11 |
+
# Set Solver
|
| 12 |
+
|
| 13 |
+
Vision-based solver for the [Set card game](https://www.setgame.com/).
|
| 14 |
+
|
| 15 |
+
Point your camera at Set cards → Get all valid Sets highlighted in real time.
|
requirements-web.txt
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Web deployment only (CPU inference)
|
| 2 |
+
--extra-index-url https://download.pytorch.org/whl/cpu
|
| 3 |
+
torch
|
| 4 |
+
torchvision
|
| 5 |
+
ultralytics>=8.0
|
| 6 |
+
pillow>=10.0
|
| 7 |
+
numpy>=1.24
|
| 8 |
+
opencv-python-headless>=4.8
|
| 9 |
+
fastapi
|
| 10 |
+
uvicorn[standard]
|
| 11 |
+
python-multipart
|
src/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Set Solver
|
src/inference/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Inference scripts
|
src/inference/classify.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Inference script for classifying a single card image.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torchvision import transforms
|
| 7 |
+
from PIL import Image
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
# Import from training module
|
| 11 |
+
import sys
|
| 12 |
+
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
| 13 |
+
from src.train.classifier import (
|
| 14 |
+
SetCardClassifier,
|
| 15 |
+
NUMBER_NAMES, COLOR_NAMES, SHAPE_NAMES, FILL_NAMES
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
WEIGHTS_DIR = Path(__file__).parent.parent.parent / "weights"
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def load_model(weights_path: Path | None = None, device: str | None = None):
    """Load the trained card classifier.

    Args:
        weights_path: Checkpoint file; defaults to weights/classifier_best.pt
            next to the repo root.
        device: Torch device string. When None, auto-selects
            cuda > mps > cpu based on availability.

    Returns:
        Tuple of (model in eval mode on `device`, device string).
    """
    if weights_path is None:
        weights_path = WEIGHTS_DIR / "classifier_best.pt"

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

    model = SetCardClassifier(pretrained=False)
    # map_location keeps GPU-trained checkpoints loadable on CPU-only hosts.
    checkpoint = torch.load(weights_path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.to(device)
    model.eval()

    return model, device
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def classify_card(image: Image.Image, model, device) -> dict:
|
| 39 |
+
"""
|
| 40 |
+
Classify a card image.
|
| 41 |
+
|
| 42 |
+
Returns dict with predicted attributes and confidences.
|
| 43 |
+
"""
|
| 44 |
+
transform = transforms.Compose([
|
| 45 |
+
transforms.Resize((224, 224)),
|
| 46 |
+
transforms.ToTensor(),
|
| 47 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
| 48 |
+
])
|
| 49 |
+
|
| 50 |
+
img_tensor = transform(image).unsqueeze(0).to(device)
|
| 51 |
+
|
| 52 |
+
with torch.no_grad():
|
| 53 |
+
outputs = model(img_tensor)
|
| 54 |
+
|
| 55 |
+
# Get predictions and confidences
|
| 56 |
+
result = {}
|
| 57 |
+
for key, names in [
|
| 58 |
+
("number", NUMBER_NAMES),
|
| 59 |
+
("color", COLOR_NAMES),
|
| 60 |
+
("shape", SHAPE_NAMES),
|
| 61 |
+
("fill", FILL_NAMES),
|
| 62 |
+
]:
|
| 63 |
+
probs = torch.softmax(outputs[key], dim=1)[0]
|
| 64 |
+
pred_idx = probs.argmax().item()
|
| 65 |
+
result[key] = {
|
| 66 |
+
"value": names[pred_idx],
|
| 67 |
+
"confidence": probs[pred_idx].item(),
|
| 68 |
+
"all_probs": {name: probs[i].item() for i, name in enumerate(names)},
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
return result
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def main():
|
| 75 |
+
import argparse
|
| 76 |
+
|
| 77 |
+
parser = argparse.ArgumentParser(description="Classify a Set card image")
|
| 78 |
+
parser.add_argument("image", type=str, help="Path to card image")
|
| 79 |
+
args = parser.parse_args()
|
| 80 |
+
|
| 81 |
+
print("Loading model...")
|
| 82 |
+
model, device = load_model()
|
| 83 |
+
|
| 84 |
+
print(f"Classifying {args.image}...")
|
| 85 |
+
image = Image.open(args.image).convert("RGB")
|
| 86 |
+
result = classify_card(image, model, device)
|
| 87 |
+
|
| 88 |
+
print("\nPrediction:")
|
| 89 |
+
print(f" Number: {result['number']['value']} ({result['number']['confidence']:.1%})")
|
| 90 |
+
print(f" Color: {result['color']['value']} ({result['color']['confidence']:.1%})")
|
| 91 |
+
print(f" Shape: {result['shape']['value']} ({result['shape']['confidence']:.1%})")
|
| 92 |
+
print(f" Fill: {result['fill']['value']} ({result['fill']['confidence']:.1%})")
|
| 93 |
+
|
| 94 |
+
# Human-readable card name
|
| 95 |
+
n = result['number']['value']
|
| 96 |
+
c = result['color']['value']
|
| 97 |
+
s = result['shape']['value']
|
| 98 |
+
f = result['fill']['value']
|
| 99 |
+
print(f"\nCard: {n} {f} {c} {s}(s)")
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
if __name__ == "__main__":
|
| 103 |
+
main()
|
src/inference/solve.py
ADDED
|
@@ -0,0 +1,424 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
End-to-end Set solver pipeline.
|
| 3 |
+
|
| 4 |
+
Photo → Detect cards → Classify each → Find Sets → Visualize
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import sys
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import List, Tuple, Optional
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
from PIL import Image, ImageDraw, ImageFont
|
| 13 |
+
from ultralytics import YOLO
|
| 14 |
+
import numpy as np
|
| 15 |
+
|
| 16 |
+
# Add parent to path for imports
|
| 17 |
+
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
| 18 |
+
|
| 19 |
+
from src.train.classifier import (
|
| 20 |
+
SetCardClassifier,
|
| 21 |
+
NUMBER_NAMES, COLOR_NAMES, SHAPE_NAMES, FILL_NAMES,
|
| 22 |
+
)
|
| 23 |
+
from src.solver.set_finder import Card, Shape, Color, Number, Fill, find_all_sets
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
WEIGHTS_DIR = Path(__file__).parent.parent.parent / "weights"
|
| 27 |
+
DATA_WEIGHTS_DIR = Path.home() / "data" / "set-solver" / "weights"
|
| 28 |
+
|
| 29 |
+
# Chinese shorthand names: {1,2,3}-{实,空,线}-{红,绿,紫}-{菱,圆,弯}
CHINESE_NUMBER = {"one": "1", "two": "2", "three": "3"}
CHINESE_FILL = {"full": "实", "empty": "空", "partial": "线"}
CHINESE_COLOR = {"red": "红", "green": "绿", "blue": "紫"}
CHINESE_SHAPE = {"diamond": "菱", "oval": "圆", "squiggle": "弯"}


def card_to_chinese(attrs: dict) -> str:
    """Convert card attributes to Chinese shorthand like '2实红菱'.

    Attribute values without a mapping fall through unchanged.
    """
    pieces = [
        CHINESE_NUMBER.get(attrs['number'], attrs['number']),
        CHINESE_FILL.get(attrs['fill'], attrs['fill']),
        CHINESE_COLOR.get(attrs['color'], attrs['color']),
        CHINESE_SHAPE.get(attrs['shape'], attrs['shape']),
    ]
    return "".join(pieces)
|
| 43 |
+
|
| 44 |
+
# Colors for highlighting Sets (RGB)
|
| 45 |
+
SET_COLORS = [
|
| 46 |
+
(255, 0, 0), # Red
|
| 47 |
+
(0, 255, 0), # Green
|
| 48 |
+
(0, 0, 255), # Blue
|
| 49 |
+
(255, 255, 0), # Yellow
|
| 50 |
+
(255, 0, 255), # Magenta
|
| 51 |
+
(0, 255, 255), # Cyan
|
| 52 |
+
(255, 128, 0), # Orange
|
| 53 |
+
(128, 0, 255), # Purple
|
| 54 |
+
]
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class SetSolver:
|
| 58 |
+
"""End-to-end Set solver."""
|
| 59 |
+
|
| 60 |
+
def __init__(
|
| 61 |
+
self,
|
| 62 |
+
detector_path: Optional[Path] = None,
|
| 63 |
+
classifier_path: Optional[Path] = None,
|
| 64 |
+
device: Optional[str] = None,
|
| 65 |
+
):
|
| 66 |
+
if device is None:
|
| 67 |
+
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
| 68 |
+
self.device = device
|
| 69 |
+
|
| 70 |
+
# Load detector
|
| 71 |
+
if detector_path is None:
|
| 72 |
+
# Check ~/data first, then repo weights
|
| 73 |
+
data_path = DATA_WEIGHTS_DIR / "detector" / "weights" / "best.pt"
|
| 74 |
+
repo_path = WEIGHTS_DIR / "detector" / "weights" / "best.pt"
|
| 75 |
+
detector_path = data_path if data_path.exists() else repo_path
|
| 76 |
+
print(f"Loading detector from {detector_path}")
|
| 77 |
+
self.detector = YOLO(str(detector_path))
|
| 78 |
+
|
| 79 |
+
# Load classifier
|
| 80 |
+
if classifier_path is None:
|
| 81 |
+
classifier_path = WEIGHTS_DIR / "classifier_best.pt"
|
| 82 |
+
print(f"Loading classifier from {classifier_path}")
|
| 83 |
+
self.classifier = SetCardClassifier(pretrained=False)
|
| 84 |
+
checkpoint = torch.load(classifier_path, map_location=device)
|
| 85 |
+
self.classifier.load_state_dict(checkpoint["model_state_dict"])
|
| 86 |
+
self.classifier.to(device)
|
| 87 |
+
self.classifier.eval()
|
| 88 |
+
|
| 89 |
+
# Classifier transform
|
| 90 |
+
from torchvision import transforms
|
| 91 |
+
self.transform = transforms.Compose([
|
| 92 |
+
transforms.Resize((224, 224)),
|
| 93 |
+
transforms.ToTensor(),
|
| 94 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
| 95 |
+
])
|
| 96 |
+
|
| 97 |
+
def detect_cards(self, image: Image.Image, conf: float = 0.5) -> List[dict]:
|
| 98 |
+
"""
|
| 99 |
+
Detect cards in image.
|
| 100 |
+
|
| 101 |
+
Returns list of detections with bounding boxes.
|
| 102 |
+
Filters out oversized detections that likely merged two cards.
|
| 103 |
+
"""
|
| 104 |
+
results = self.detector(image, conf=conf, verbose=False)
|
| 105 |
+
|
| 106 |
+
detections = []
|
| 107 |
+
for result in results:
|
| 108 |
+
boxes = result.boxes
|
| 109 |
+
for box in boxes:
|
| 110 |
+
x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
|
| 111 |
+
c = box.conf[0].cpu().item()
|
| 112 |
+
w, h = x2 - x1, y2 - y1
|
| 113 |
+
detections.append({
|
| 114 |
+
"bbox": (int(x1), int(y1), int(x2), int(y2)),
|
| 115 |
+
"confidence": c,
|
| 116 |
+
"area": w * h,
|
| 117 |
+
})
|
| 118 |
+
|
| 119 |
+
# Filter out merged detections: if a box is >2x the median area,
|
| 120 |
+
# it's likely covering two cards
|
| 121 |
+
if len(detections) >= 3:
|
| 122 |
+
areas = sorted(d["area"] for d in detections)
|
| 123 |
+
median_area = areas[len(areas) // 2]
|
| 124 |
+
detections = [d for d in detections if d["area"] <= median_area * 2.2]
|
| 125 |
+
|
| 126 |
+
return detections
|
| 127 |
+
|
| 128 |
+
def classify_card(self, card_image: Image.Image) -> dict:
|
| 129 |
+
"""Classify a cropped card image."""
|
| 130 |
+
img_tensor = self.transform(card_image).unsqueeze(0).to(self.device)
|
| 131 |
+
|
| 132 |
+
with torch.no_grad():
|
| 133 |
+
outputs = self.classifier(img_tensor)
|
| 134 |
+
|
| 135 |
+
result = {}
|
| 136 |
+
for key, names in [
|
| 137 |
+
("number", NUMBER_NAMES),
|
| 138 |
+
("color", COLOR_NAMES),
|
| 139 |
+
("shape", SHAPE_NAMES),
|
| 140 |
+
("fill", FILL_NAMES),
|
| 141 |
+
]:
|
| 142 |
+
probs = torch.softmax(outputs[key], dim=1)[0]
|
| 143 |
+
pred_idx = probs.argmax().item()
|
| 144 |
+
result[key] = names[pred_idx]
|
| 145 |
+
result[f"{key}_conf"] = probs[pred_idx].item()
|
| 146 |
+
|
| 147 |
+
return result
|
| 148 |
+
|
| 149 |
+
def detection_to_card(self, attrs: dict, bbox: Tuple[int, int, int, int]) -> Card:
|
| 150 |
+
"""Convert classification result to Card object."""
|
| 151 |
+
# Map classifier output to solver enums
|
| 152 |
+
# Training data uses "blue" but standard Set calls it "purple"
|
| 153 |
+
color_map = {"red": "RED", "green": "GREEN", "blue": "PURPLE"}
|
| 154 |
+
# Training data uses "partial" for striped, "full" for solid
|
| 155 |
+
fill_map = {"empty": "EMPTY", "full": "SOLID", "partial": "STRIPED"}
|
| 156 |
+
|
| 157 |
+
return Card(
|
| 158 |
+
shape=Shape[attrs["shape"].upper()],
|
| 159 |
+
color=Color[color_map[attrs["color"]]],
|
| 160 |
+
number=Number[attrs["number"].upper()],
|
| 161 |
+
fill=Fill[fill_map[attrs["fill"]]],
|
| 162 |
+
bbox=bbox,
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
def solve_from_image(
|
| 166 |
+
self,
|
| 167 |
+
image: Image.Image,
|
| 168 |
+
conf: float = 0.5,
|
| 169 |
+
) -> dict:
|
| 170 |
+
"""
|
| 171 |
+
Solve a Set game from a PIL Image directly.
|
| 172 |
+
|
| 173 |
+
Args:
|
| 174 |
+
image: PIL Image (RGB)
|
| 175 |
+
conf: Detection confidence threshold
|
| 176 |
+
|
| 177 |
+
Returns:
|
| 178 |
+
Dict with detected cards, found Sets, and annotated result image
|
| 179 |
+
"""
|
| 180 |
+
image = image.convert("RGB")
|
| 181 |
+
|
| 182 |
+
detections = self.detect_cards(image, conf=conf)
|
| 183 |
+
|
| 184 |
+
cards = []
|
| 185 |
+
for det in detections:
|
| 186 |
+
x1, y1, x2, y2 = det["bbox"]
|
| 187 |
+
card_crop = image.crop((x1, y1, x2, y2))
|
| 188 |
+
attrs = self.classify_card(card_crop)
|
| 189 |
+
card = self.detection_to_card(attrs, det["bbox"])
|
| 190 |
+
cards.append({
|
| 191 |
+
"card": card,
|
| 192 |
+
"attrs": attrs,
|
| 193 |
+
"detection": det,
|
| 194 |
+
})
|
| 195 |
+
|
| 196 |
+
card_objects = [c["card"] for c in cards]
|
| 197 |
+
sets = find_all_sets(card_objects)
|
| 198 |
+
|
| 199 |
+
# Generate one annotated image per set (each highlighting only that set)
|
| 200 |
+
result_images = []
|
| 201 |
+
if sets:
|
| 202 |
+
for i in range(len(sets)):
|
| 203 |
+
result_images.append(self._draw_results(image, cards, sets, highlight_idx=i))
|
| 204 |
+
else:
|
| 205 |
+
result_images.append(self._draw_results(image, cards, sets))
|
| 206 |
+
|
| 207 |
+
return {
|
| 208 |
+
"num_cards": len(cards),
|
| 209 |
+
"cards": [
|
| 210 |
+
{
|
| 211 |
+
"attrs": c["attrs"],
|
| 212 |
+
"chinese": card_to_chinese(c["attrs"]),
|
| 213 |
+
"bbox": c["detection"]["bbox"],
|
| 214 |
+
"confidence": c["detection"]["confidence"],
|
| 215 |
+
}
|
| 216 |
+
for c in cards
|
| 217 |
+
],
|
| 218 |
+
"num_sets": len(sets),
|
| 219 |
+
"sets": [
|
| 220 |
+
[str(card) for card in s]
|
| 221 |
+
for s in sets
|
| 222 |
+
],
|
| 223 |
+
"sets_chinese": [
|
| 224 |
+
[card_to_chinese(next(c["attrs"] for c in cards if c["card"] is card)) for card in s]
|
| 225 |
+
for s in sets
|
| 226 |
+
],
|
| 227 |
+
"sets_bboxes": [
|
| 228 |
+
[card.bbox for card in s]
|
| 229 |
+
for s in sets
|
| 230 |
+
],
|
| 231 |
+
"result_images": result_images,
|
| 232 |
+
}
|
| 233 |
+
|
| 234 |
+
def solve(
|
| 235 |
+
self,
|
| 236 |
+
image_path: str,
|
| 237 |
+
conf: float = 0.5,
|
| 238 |
+
output_path: Optional[str] = None,
|
| 239 |
+
show: bool = False,
|
| 240 |
+
) -> dict:
|
| 241 |
+
"""
|
| 242 |
+
Solve a Set game from image.
|
| 243 |
+
|
| 244 |
+
Args:
|
| 245 |
+
image_path: Path to input image
|
| 246 |
+
conf: Detection confidence threshold
|
| 247 |
+
output_path: Path to save annotated output image
|
| 248 |
+
show: Whether to display the result
|
| 249 |
+
|
| 250 |
+
Returns:
|
| 251 |
+
Dict with detected cards and found Sets
|
| 252 |
+
"""
|
| 253 |
+
# Load image
|
| 254 |
+
image = Image.open(image_path).convert("RGB")
|
| 255 |
+
print(f"Loaded image: {image.size}")
|
| 256 |
+
|
| 257 |
+
# Detect cards
|
| 258 |
+
print("Detecting cards...")
|
| 259 |
+
detections = self.detect_cards(image, conf=conf)
|
| 260 |
+
print(f"Found {len(detections)} cards")
|
| 261 |
+
|
| 262 |
+
# Classify each card
|
| 263 |
+
print("Classifying cards...")
|
| 264 |
+
cards = []
|
| 265 |
+
for det in detections:
|
| 266 |
+
x1, y1, x2, y2 = det["bbox"]
|
| 267 |
+
card_crop = image.crop((x1, y1, x2, y2))
|
| 268 |
+
attrs = self.classify_card(card_crop)
|
| 269 |
+
card = self.detection_to_card(attrs, det["bbox"])
|
| 270 |
+
cards.append({
|
| 271 |
+
"card": card,
|
| 272 |
+
"attrs": attrs,
|
| 273 |
+
"detection": det,
|
| 274 |
+
})
|
| 275 |
+
|
| 276 |
+
# Find Sets
|
| 277 |
+
print("Finding Sets...")
|
| 278 |
+
card_objects = [c["card"] for c in cards]
|
| 279 |
+
sets = find_all_sets(card_objects)
|
| 280 |
+
print(f"Found {len(sets)} valid Set(s)")
|
| 281 |
+
|
| 282 |
+
# Draw results
|
| 283 |
+
result_image = self._draw_results(image, cards, sets)
|
| 284 |
+
|
| 285 |
+
if output_path:
|
| 286 |
+
result_image.save(output_path)
|
| 287 |
+
print(f"Saved result to {output_path}")
|
| 288 |
+
|
| 289 |
+
if show:
|
| 290 |
+
result_image.show()
|
| 291 |
+
|
| 292 |
+
return {
|
| 293 |
+
"num_cards": len(cards),
|
| 294 |
+
"cards": [
|
| 295 |
+
{
|
| 296 |
+
"attrs": c["attrs"],
|
| 297 |
+
"chinese": card_to_chinese(c["attrs"]),
|
| 298 |
+
"bbox": c["detection"]["bbox"],
|
| 299 |
+
"confidence": c["detection"]["confidence"],
|
| 300 |
+
}
|
| 301 |
+
for c in cards
|
| 302 |
+
],
|
| 303 |
+
"num_sets": len(sets),
|
| 304 |
+
"sets": [
|
| 305 |
+
[str(card) for card in s]
|
| 306 |
+
for s in sets
|
| 307 |
+
],
|
| 308 |
+
"sets_chinese": [
|
| 309 |
+
[card_to_chinese(next(c["attrs"] for c in cards if c["card"] is card)) for card in s]
|
| 310 |
+
for s in sets
|
| 311 |
+
],
|
| 312 |
+
"result_image": result_image,
|
| 313 |
+
}
|
| 314 |
+
|
| 315 |
+
def _draw_results(
|
| 316 |
+
self,
|
| 317 |
+
image: Image.Image,
|
| 318 |
+
cards: List[dict],
|
| 319 |
+
sets: List[Tuple[Card, Card, Card]],
|
| 320 |
+
highlight_idx: Optional[int] = None,
|
| 321 |
+
) -> Image.Image:
|
| 322 |
+
"""Draw bounding boxes and Set highlights on image.
|
| 323 |
+
|
| 324 |
+
Args:
|
| 325 |
+
highlight_idx: If set, only highlight this one set (0-based).
|
| 326 |
+
If None, highlight all sets.
|
| 327 |
+
"""
|
| 328 |
+
result = image.copy()
|
| 329 |
+
draw = ImageDraw.Draw(result)
|
| 330 |
+
|
| 331 |
+
# Try to load a Chinese-compatible font
|
| 332 |
+
font = None
|
| 333 |
+
font_paths = [
|
| 334 |
+
"/System/Library/Fonts/PingFang.ttc", # macOS
|
| 335 |
+
"/System/Library/Fonts/STHeiti Light.ttc", # macOS
|
| 336 |
+
"/usr/share/fonts/truetype/droid/DroidSansFallbackFull.ttf", # Linux
|
| 337 |
+
"C:\\Windows\\Fonts\\msyh.ttc", # Windows
|
| 338 |
+
]
|
| 339 |
+
for font_path in font_paths:
|
| 340 |
+
try:
|
| 341 |
+
font = ImageFont.truetype(font_path, 18)
|
| 342 |
+
break
|
| 343 |
+
except:
|
| 344 |
+
continue
|
| 345 |
+
if font is None:
|
| 346 |
+
font = ImageFont.load_default()
|
| 347 |
+
|
| 348 |
+
# Determine which set(s) to highlight
|
| 349 |
+
if highlight_idx is not None and 0 <= highlight_idx < len(sets):
|
| 350 |
+
highlighted_sets = [(highlight_idx, sets[highlight_idx])]
|
| 351 |
+
else:
|
| 352 |
+
highlighted_sets = list(enumerate(sets))
|
| 353 |
+
|
| 354 |
+
# Build set of highlighted card ids
|
| 355 |
+
highlighted_card_ids = set()
|
| 356 |
+
for _, card_set in highlighted_sets:
|
| 357 |
+
for card in card_set:
|
| 358 |
+
highlighted_card_ids.add(id(card))
|
| 359 |
+
|
| 360 |
+
# Draw only highlighted cards
|
| 361 |
+
for c in cards:
|
| 362 |
+
card = c["card"]
|
| 363 |
+
if id(card) not in highlighted_card_ids:
|
| 364 |
+
continue
|
| 365 |
+
attrs = c["attrs"]
|
| 366 |
+
x1, y1, x2, y2 = card.bbox
|
| 367 |
+
|
| 368 |
+
color_idx = highlighted_sets[0][0] if len(highlighted_sets) == 1 else 0
|
| 369 |
+
for si, card_set in highlighted_sets:
|
| 370 |
+
if card in card_set:
|
| 371 |
+
color_idx = si
|
| 372 |
+
break
|
| 373 |
+
color = SET_COLORS[color_idx % len(SET_COLORS)]
|
| 374 |
+
draw.rectangle([x1, y1, x2, y2], outline=color, width=4)
|
| 375 |
+
|
| 376 |
+
label = card_to_chinese(attrs)
|
| 377 |
+
draw.text((x1, y1 - 20), label, fill=color, font=font)
|
| 378 |
+
|
| 379 |
+
# Draw Set info
|
| 380 |
+
if highlight_idx is not None:
|
| 381 |
+
draw.text((10, 10), f"Set {highlight_idx + 1} / {len(sets)}", fill=(255, 255, 255), font=font)
|
| 382 |
+
else:
|
| 383 |
+
draw.text((10, 10), f"Found {len(sets)} Set(s)", fill=(255, 255, 255), font=font)
|
| 384 |
+
|
| 385 |
+
return result
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
def main():
|
| 389 |
+
import argparse
|
| 390 |
+
|
| 391 |
+
parser = argparse.ArgumentParser(description="Solve Set game from image")
|
| 392 |
+
parser.add_argument("image", type=str, help="Path to input image")
|
| 393 |
+
parser.add_argument("--output", "-o", type=str, help="Path to save output image")
|
| 394 |
+
parser.add_argument("--conf", type=float, default=0.25, help="Detection confidence")
|
| 395 |
+
parser.add_argument("--show", action="store_true", help="Display result")
|
| 396 |
+
args = parser.parse_args()
|
| 397 |
+
|
| 398 |
+
solver = SetSolver()
|
| 399 |
+
result = solver.solve(
|
| 400 |
+
args.image,
|
| 401 |
+
conf=args.conf,
|
| 402 |
+
output_path=args.output,
|
| 403 |
+
show=args.show,
|
| 404 |
+
)
|
| 405 |
+
|
| 406 |
+
print("\n" + "="*50)
|
| 407 |
+
print("结果 RESULTS")
|
| 408 |
+
print("="*50)
|
| 409 |
+
print(f"检测到卡牌: {result['num_cards']}")
|
| 410 |
+
print(f"找到Set: {result['num_sets']}")
|
| 411 |
+
|
| 412 |
+
if result['cards']:
|
| 413 |
+
print("\n卡牌:")
|
| 414 |
+
for c in result['cards']:
|
| 415 |
+
print(f" {c['chinese']}")
|
| 416 |
+
|
| 417 |
+
if result['sets_chinese']:
|
| 418 |
+
print("\nSets:")
|
| 419 |
+
for i, s in enumerate(result['sets_chinese'], 1):
|
| 420 |
+
print(f" Set {i}: {' + '.join(s)}")
|
| 421 |
+
|
| 422 |
+
|
| 423 |
+
if __name__ == "__main__":
|
| 424 |
+
main()
|
src/solver/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .set_finder import (
|
| 2 |
+
Card, Shape, Color, Number, Fill,
|
| 3 |
+
is_valid_set, find_all_sets, find_first_set,
|
| 4 |
+
generate_all_cards, card_to_index, index_to_card
|
| 5 |
+
)
|
| 6 |
+
|
| 7 |
+
__all__ = [
|
| 8 |
+
'Card', 'Shape', 'Color', 'Number', 'Fill',
|
| 9 |
+
'is_valid_set', 'find_all_sets', 'find_first_set',
|
| 10 |
+
'generate_all_cards', 'card_to_index', 'index_to_card'
|
| 11 |
+
]
|
src/solver/set_finder.py
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Set-finding algorithm.
|
| 3 |
+
|
| 4 |
+
A valid Set consists of 3 cards where, for each attribute,
|
| 5 |
+
the values are either ALL THE SAME or ALL DIFFERENT.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from dataclasses import dataclass
|
| 9 |
+
from enum import IntEnum
|
| 10 |
+
from itertools import combinations
|
| 11 |
+
from typing import List, Tuple
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class Shape(IntEnum):
|
| 15 |
+
DIAMOND = 0
|
| 16 |
+
OVAL = 1
|
| 17 |
+
SQUIGGLE = 2
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class Color(IntEnum):
|
| 21 |
+
RED = 0
|
| 22 |
+
GREEN = 1
|
| 23 |
+
PURPLE = 2
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class Number(IntEnum):
|
| 27 |
+
ONE = 0
|
| 28 |
+
TWO = 1
|
| 29 |
+
THREE = 2
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class Fill(IntEnum):
|
| 33 |
+
SOLID = 0
|
| 34 |
+
STRIPED = 1
|
| 35 |
+
EMPTY = 2
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@dataclass
|
| 39 |
+
class Card:
|
| 40 |
+
"""A Set card with 4 attributes."""
|
| 41 |
+
shape: Shape
|
| 42 |
+
color: Color
|
| 43 |
+
number: Number
|
| 44 |
+
fill: Fill
|
| 45 |
+
|
| 46 |
+
# Optional: position in image (for visualization)
|
| 47 |
+
bbox: Tuple[float, float, float, float] = None # x, y, w, h
|
| 48 |
+
|
| 49 |
+
def __hash__(self):
|
| 50 |
+
return hash((self.shape, self.color, self.number, self.fill))
|
| 51 |
+
|
| 52 |
+
def __eq__(self, other):
|
| 53 |
+
if not isinstance(other, Card):
|
| 54 |
+
return False
|
| 55 |
+
return (self.shape == other.shape and
|
| 56 |
+
self.color == other.color and
|
| 57 |
+
self.number == other.number and
|
| 58 |
+
self.fill == other.fill)
|
| 59 |
+
|
| 60 |
+
def to_tuple(self) -> Tuple[int, int, int, int]:
|
| 61 |
+
"""Return attributes as tuple of ints."""
|
| 62 |
+
return (self.shape, self.color, self.number, self.fill)
|
| 63 |
+
|
| 64 |
+
@classmethod
|
| 65 |
+
def from_tuple(cls, attrs: Tuple[int, int, int, int], bbox=None) -> "Card":
|
| 66 |
+
"""Create card from tuple of attribute indices."""
|
| 67 |
+
return cls(
|
| 68 |
+
shape=Shape(attrs[0]),
|
| 69 |
+
color=Color(attrs[1]),
|
| 70 |
+
number=Number(attrs[2]),
|
| 71 |
+
fill=Fill(attrs[3]),
|
| 72 |
+
bbox=bbox
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
def __repr__(self):
|
| 76 |
+
n = ["one", "two", "three"][self.number]
|
| 77 |
+
return f"{n} {self.fill.name.lower()} {self.color.name.lower()} {self.shape.name.lower()}(s)"
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def is_valid_set(card1: Card, card2: Card, card3: Card) -> bool:
|
| 81 |
+
"""
|
| 82 |
+
Check if three cards form a valid Set.
|
| 83 |
+
|
| 84 |
+
For each attribute, the three values must be either:
|
| 85 |
+
- All the same (e.g., all red)
|
| 86 |
+
- All different (e.g., red, green, purple)
|
| 87 |
+
"""
|
| 88 |
+
for attr in ['shape', 'color', 'number', 'fill']:
|
| 89 |
+
values = [getattr(card1, attr), getattr(card2, attr), getattr(card3, attr)]
|
| 90 |
+
unique = len(set(values))
|
| 91 |
+
# Valid: all same (1 unique) or all different (3 unique)
|
| 92 |
+
# Invalid: exactly 2 unique
|
| 93 |
+
if unique == 2:
|
| 94 |
+
return False
|
| 95 |
+
return True
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def find_all_sets(cards: List[Card]) -> List[Tuple[Card, Card, Card]]:
|
| 99 |
+
"""
|
| 100 |
+
Find all valid Sets among the given cards.
|
| 101 |
+
|
| 102 |
+
Uses brute force: check all C(n,3) combinations.
|
| 103 |
+
For 12 cards: C(12,3) = 220 combinations - very fast.
|
| 104 |
+
For 21 cards (max in real game): C(21,3) = 1330 combinations - still fast.
|
| 105 |
+
"""
|
| 106 |
+
valid_sets = []
|
| 107 |
+
for combo in combinations(cards, 3):
|
| 108 |
+
if is_valid_set(*combo):
|
| 109 |
+
valid_sets.append(combo)
|
| 110 |
+
return valid_sets
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def find_first_set(cards: List[Card]) -> Tuple[Card, Card, Card] | None:
    """Return the first valid Set found, or None when no Set exists."""
    candidates = (trio for trio in combinations(cards, 3) if is_valid_set(*trio))
    return next(candidates, None)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
# --- Utilities ---
|
| 122 |
+
|
| 123 |
+
def generate_all_cards() -> List[Card]:
    """Enumerate the full 81-card Set deck (3^4 attribute combinations)."""
    return [
        Card(shape=s, color=c, number=n, fill=f)
        for s in Shape
        for c in Color
        for n in Number
        for f in Fill
    ]
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def card_to_index(card: Card) -> int:
    """Map a card to its unique deck index in [0, 80] (base-3 encoding)."""
    # Most-significant digit first: shape, color, number, fill.
    idx = 0
    for digit in (card.shape, card.color, card.number, card.fill):
        idx = idx * 3 + digit
    return idx
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def index_to_card(idx: int) -> Card:
    """Inverse of card_to_index: decode a deck index in [0, 80] to a Card."""
    # Peel off base-3 digits from least significant (fill) upward.
    idx, fill = divmod(idx, 3)
    idx, number = divmod(idx, 3)
    shape, color = divmod(idx, 3)
    return Card(Shape(shape), Color(color), Number(number), Fill(fill))
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
# --- Demo ---
|
| 152 |
+
|
| 153 |
+
if __name__ == "__main__":
    # Demo: deal a random board of 12 cards and report every Set on it.
    import random

    deck = generate_all_cards()
    print(f"Total cards in deck: {len(deck)}")

    # Deal 12 cards
    board = random.sample(deck, 12)
    print(f"\nDealt {len(board)} cards:")
    for pos, card in enumerate(board):
        print(f" {pos+1}. {card}")

    # Find all sets
    found = find_all_sets(board)
    print(f"\nFound {len(found)} valid Set(s):")
    for pos, trio in enumerate(found):
        print(f"\n Set {pos+1}:")
        for member in trio:
            print(f" - {member}")
|
src/train/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Training scripts
|
src/train/classifier.py
ADDED
|
@@ -0,0 +1,361 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Train a card attribute classifier on the existing labeled images.
|
| 3 |
+
|
| 4 |
+
Uses MobileNetV3-Small for iPhone compatibility.
|
| 5 |
+
Multi-head output: predicts all 4 attributes simultaneously.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
import json
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from typing import Tuple, Dict, List
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn as nn
|
| 15 |
+
from torch.utils.data import Dataset, DataLoader, random_split
|
| 16 |
+
from torchvision import transforms, models
|
| 17 |
+
from torchvision.io import read_image, ImageReadMode
|
| 18 |
+
from PIL import Image
|
| 19 |
+
import numpy as np
|
| 20 |
+
from tqdm import tqdm
|
| 21 |
+
|
| 22 |
+
# === Config ===
|
| 23 |
+
|
| 24 |
+
DATA_DIR = Path(__file__).parent.parent.parent / "training_images"
|
| 25 |
+
SYNTHETIC_DATA_DIR = Path(__file__).parent.parent.parent / "training_images_synthetic"
|
| 26 |
+
WEIGHTS_DIR = Path(__file__).parent.parent.parent / "weights"
|
| 27 |
+
WEIGHTS_DIR.mkdir(exist_ok=True)
|
| 28 |
+
|
| 29 |
+
# Attribute mappings (folder names → indices)
|
| 30 |
+
NUMBER_MAP = {"one": 0, "two": 1, "three": 2}
|
| 31 |
+
COLOR_MAP = {"red": 0, "green": 1, "blue": 2} # blue = purple in standard Set
|
| 32 |
+
SHAPE_MAP = {"diamond": 0, "oval": 1, "squiggle": 2}
|
| 33 |
+
FILL_MAP = {"empty": 0, "full": 1, "partial": 2} # partial = striped
|
| 34 |
+
|
| 35 |
+
# Reverse mappings for inference
|
| 36 |
+
NUMBER_NAMES = ["one", "two", "three"]
|
| 37 |
+
COLOR_NAMES = ["red", "green", "blue"]
|
| 38 |
+
SHAPE_NAMES = ["diamond", "oval", "squiggle"]
|
| 39 |
+
FILL_NAMES = ["empty", "full", "partial"]
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# === Dataset ===
|
| 43 |
+
|
| 44 |
+
class SetCardDataset(Dataset):
    """Dataset of labeled Set card images.

    Expects one or more root directories laid out as
    ``root/<number>/<color>/<shape>/<fill>/*.png``, where each path
    component is a key of the corresponding *_MAP constant. Items are
    ``(image, label_tensor)`` pairs; the label tensor holds the four
    attribute indices in (number, color, shape, fill) order.
    """

    def __init__(self, data_dirs, transform=None):
        # Accept either a single root Path or a list of roots.
        if isinstance(data_dirs, Path):
            data_dirs = [data_dirs]
        self.transform = transform
        self.samples: List[Tuple[Path, Dict[str, int]]] = []

        # Walk the directory structure to find all images
        for data_dir in data_dirs:
            if not data_dir.exists():
                continue
            count_before = len(self.samples)
            # Enumerate every attribute combination; missing folders are skipped.
            for number in NUMBER_MAP:
                for color in COLOR_MAP:
                    for shape in SHAPE_MAP:
                        for fill in FILL_MAP:
                            folder = data_dir / number / color / shape / fill
                            if folder.exists():
                                for img_path in folder.glob("*.png"):
                                    labels = {
                                        "number": NUMBER_MAP[number],
                                        "color": COLOR_MAP[color],
                                        "shape": SHAPE_MAP[shape],
                                        "fill": FILL_MAP[fill],
                                    }
                                    self.samples.append((img_path, labels))
            print(f"Loaded {len(self.samples) - count_before} samples from {data_dir}")

        print(f"Total: {len(self.samples)} samples")

    def __len__(self):
        # Number of labeled images discovered at construction time.
        return len(self.samples)

    def __getitem__(self, idx):
        """Return (transformed RGB image, long tensor of 4 attribute indices)."""
        img_path, labels = self.samples[idx]

        # Load image
        image = Image.open(img_path).convert("RGB")

        if self.transform:
            image = self.transform(image)

        # Stack labels into tensor
        label_tensor = torch.tensor([
            labels["number"],
            labels["color"],
            labels["shape"],
            labels["fill"],
        ], dtype=torch.long)

        return image, label_tensor

    def get_raw(self, idx):
        """Get raw PIL image and labels (no transform)."""
        img_path, labels = self.samples[idx]
        image = Image.open(img_path).convert("RGB")
        label_tensor = torch.tensor([
            labels["number"],
            labels["color"],
            labels["shape"],
            labels["fill"],
        ], dtype=torch.long)
        return image, label_tensor
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
# === Model ===
|
| 112 |
+
|
| 113 |
+
class SetCardClassifier(nn.Module):
    """
    Multi-head classifier for Set card attributes.

    A MobileNetV3-Small backbone (chosen for mobile deployment) feeds
    four independent 3-way linear heads, one per attribute.
    """

    def __init__(self, pretrained: bool = True):
        super().__init__()

        # Backbone: MobileNetV3-Small, optionally with ImageNet weights.
        backbone_weights = models.MobileNet_V3_Small_Weights.DEFAULT if pretrained else None
        self.backbone = models.mobilenet_v3_small(weights=backbone_weights)

        # Feature width the stock classifier expected as input.
        feature_dim = self.backbone.classifier[0].in_features

        # Strip the stock classifier so the backbone emits raw features.
        self.backbone.classifier = nn.Identity()

        # One 3-way linear head per attribute.
        self.heads = nn.ModuleDict({
            name: nn.Linear(feature_dim, 3)
            for name in ("number", "color", "shape", "fill")
        })

    def forward(self, x):
        """Return a dict of per-attribute logits, each of shape (batch, 3)."""
        features = self.backbone(x)
        return {name: head(features) for name, head in self.heads.items()}
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
# === Training ===
|
| 153 |
+
|
| 154 |
+
def train_epoch(model, loader, optimizer, criterion, device):
    """
    Run one optimization pass over `loader`.

    The fill head's loss is weighted 2x because fill is the hardest
    attribute to classify. Returns (mean batch loss, per-attribute
    accuracy dict keyed by "number"/"color"/"shape"/"fill").
    """
    attrs = ("number", "color", "shape", "fill")
    model.train()
    running_loss = 0
    hits = {name: 0 for name in attrs}
    seen = 0

    fill_weight = 2.0
    for images, labels in tqdm(loader, desc="Training", leave=False):
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)

        # Sum per-head cross-entropy (fill counted double) and track accuracy.
        batch_loss = 0
        for col, name in enumerate(attrs):
            head_loss = criterion(outputs[name], labels[:, col])
            batch_loss = batch_loss + (fill_weight * head_loss if name == "fill" else head_loss)
            predicted = outputs[name].argmax(dim=1)
            hits[name] += (predicted == labels[:, col]).sum().item()

        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()
        seen += labels.size(0)

    mean_loss = running_loss / len(loader)
    accuracies = {name: count / seen for name, count in hits.items()}
    return mean_loss, accuracies
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def evaluate(model, loader, criterion, device):
    """
    Score `model` on `loader` without gradient tracking.

    Unlike train_epoch, all four heads contribute equally to the loss.
    Returns (mean batch loss, per-attribute accuracy dict).
    """
    attrs = ("number", "color", "shape", "fill")
    model.eval()
    running_loss = 0
    hits = {name: 0 for name in attrs}
    seen = 0

    with torch.no_grad():
        for images, labels in tqdm(loader, desc="Evaluating", leave=False):
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)

            batch_loss = 0
            for col, name in enumerate(attrs):
                batch_loss = batch_loss + criterion(outputs[name], labels[:, col])
                predicted = outputs[name].argmax(dim=1)
                hits[name] += (predicted == labels[:, col]).sum().item()

            running_loss += batch_loss.item()
            seen += labels.size(0)

    mean_loss = running_loss / len(loader)
    accuracies = {name: count / seen for name, count in hits.items()}
    return mean_loss, accuracies
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
def main():
    """Train the multi-head card classifier end to end.

    Loads labeled (and optional synthetic) card crops, splits them into
    train/val/test, trains with heavy augmentation, checkpoints the best
    validation model to weights/classifier_best.pt, then reports final
    test metrics to weights/training_results.json.
    """
    # === Hyperparameters ===
    BATCH_SIZE = 32
    EPOCHS = 50
    LR = 1e-3
    VAL_SPLIT = 0.15   # fraction of all samples held out for validation
    TEST_SPLIT = 0.10  # fraction held out for the final test report
    IMG_SIZE = 224     # input resolution fed to the backbone

    # Prefer CUDA, then Apple MPS, then CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
    print(f"Using device: {device}")

    # === Data transforms ===
    # Aggressive augmentation: detector crops are imperfect and cards can
    # sit at any angle, so train with crops/flips/rotations/warps.
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(IMG_SIZE, scale=(0.7, 1.0)),  # Simulate imperfect detector crops
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomRotation(180),  # Cards can be any orientation
        transforms.RandomPerspective(distortion_scale=0.15, p=0.5),  # Perspective warp from detection
        transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.05),
        transforms.RandomGrayscale(p=0.05),  # Force model to not rely solely on color for fill
        transforms.GaussianBlur(kernel_size=5, sigma=(0.1, 2.0)),  # ~30% effective via random sigma
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Validation/test: deterministic resize + normalize only.
    val_transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # === Load dataset (clean + synthetic crops) ===
    data_dirs = [DATA_DIR]
    if SYNTHETIC_DATA_DIR.exists():
        data_dirs.append(SYNTHETIC_DATA_DIR)
    full_dataset = SetCardDataset(data_dirs, transform=None)  # No transform yet

    # Split into train/val/test
    total = len(full_dataset)
    test_size = int(total * TEST_SPLIT)
    val_size = int(total * VAL_SPLIT)
    train_size = total - val_size - test_size

    # Fixed seed keeps the split reproducible across runs.
    train_dataset, val_dataset, test_dataset = random_split(
        full_dataset, [train_size, val_size, test_size],
        generator=torch.Generator().manual_seed(42)
    )

    print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")

    # Wrap with transform (can't change transform on Subset, so we wrap)
    class TransformDataset(torch.utils.data.Dataset):
        # Thin wrapper applying a per-split transform on top of a Subset.
        def __init__(self, subset, transform):
            self.subset = subset
            self.transform = transform
        def __len__(self):
            return len(self.subset)
        def __getitem__(self, idx):
            image, label = self.subset[idx]
            if self.transform:
                image = self.transform(image)
            return image, label

    train_dataset = TransformDataset(train_dataset, train_transform)
    val_dataset = TransformDataset(val_dataset, val_transform)
    test_dataset = TransformDataset(test_dataset, val_transform)

    # Use num_workers=0 on macOS to avoid shared memory issues
    import platform
    num_workers = 0 if platform.system() == "Darwin" else 4

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=num_workers)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=num_workers)

    # === Model ===
    model = SetCardClassifier(pretrained=True).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

    # === Training loop ===
    best_val_acc = 0  # best epoch-average validation accuracy so far

    for epoch in range(EPOCHS):
        train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device)
        val_loss, val_acc = evaluate(model, val_loader, criterion, device)
        scheduler.step()

        # Average accuracy across all heads
        avg_train_acc = sum(train_acc.values()) / 4
        avg_val_acc = sum(val_acc.values()) / 4

        print(f"Epoch {epoch+1}/{EPOCHS}")
        print(f" Train Loss: {train_loss:.4f}, Acc: {avg_train_acc:.4f}")
        print(f" Val Loss: {val_loss:.4f}, Acc: {avg_val_acc:.4f}")
        print(f" Val per-head: num={val_acc['number']:.3f} col={val_acc['color']:.3f} "
              f"shp={val_acc['shape']:.3f} fil={val_acc['fill']:.3f}")

        # Save best model (checkpoint keyed on average validation accuracy)
        if avg_val_acc > best_val_acc:
            best_val_acc = avg_val_acc
            torch.save({
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "val_acc": val_acc,
            }, WEIGHTS_DIR / "classifier_best.pt")
            print(f" Saved new best model (val_acc={avg_val_acc:.4f})")

    # === Final evaluation on test set ===
    print("\n" + "="*50)
    print("Final Test Evaluation")
    print("="*50)

    # Load best model (re-read the checkpoint saved above)
    checkpoint = torch.load(WEIGHTS_DIR / "classifier_best.pt")
    model.load_state_dict(checkpoint["model_state_dict"])

    test_loss, test_acc = evaluate(model, test_loader, criterion, device)
    avg_test_acc = sum(test_acc.values()) / 4

    print(f"Test Loss: {test_loss:.4f}")
    print(f"Test Accuracy (avg): {avg_test_acc:.4f}")
    print(f" Number: {test_acc['number']:.4f}")
    print(f" Color: {test_acc['color']:.4f}")
    print(f" Shape: {test_acc['shape']:.4f}")
    print(f" Fill: {test_acc['fill']:.4f}")

    # Save final results
    results = {
        "test_loss": test_loss,
        "test_accuracy": test_acc,
        "avg_test_accuracy": avg_test_acc,
        "train_size": train_size,
        "val_size": val_size,
        "test_size": test_size,
    }
    with open(WEIGHTS_DIR / "training_results.json", "w") as f:
        json.dump(results, f, indent=2)

    print(f"\nModel saved to {WEIGHTS_DIR / 'classifier_best.pt'}")
    print(f"Results saved to {WEIGHTS_DIR / 'training_results.json'}")
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
# Script entry point: train, evaluate, and save the classifier.
if __name__ == "__main__":
    main()
|
src/web/__init__.py
ADDED
|
File without changes
|
src/web/app.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Web-based real-time Set solver.
|
| 3 |
+
|
| 4 |
+
FastAPI backend serving a single HTML page with live camera feed.
|
| 5 |
+
Processes frames via the SetSolver pipeline and returns annotated results.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import base64
|
| 9 |
+
import io
|
| 10 |
+
import sys
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
|
| 13 |
+
from fastapi import FastAPI, UploadFile, File
|
| 14 |
+
from fastapi.responses import HTMLResponse
|
| 15 |
+
from PIL import Image
|
| 16 |
+
|
| 17 |
+
# Add project root to path
|
| 18 |
+
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
| 19 |
+
|
| 20 |
+
from src.inference.solve import SetSolver
|
| 21 |
+
|
| 22 |
+
app = FastAPI(title="Set Solver")
|
| 23 |
+
|
| 24 |
+
# Global solver instance (loaded once at startup)
|
| 25 |
+
solver: SetSolver = None
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# FastAPI startup hook: build the (heavy) solver once per process so
# request handlers can reuse it via the module-level `solver` global.
@app.on_event("startup")
def load_solver():
    """Instantiate the global SetSolver when the server starts."""
    global solver
    print("Loading Set Solver pipeline...")
    solver = SetSolver()
    print("Solver ready!")
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@app.get("/", response_class=HTMLResponse)
def index():
    """Serve the single-page camera UI."""
    template = Path(__file__).parent / "templates" / "index.html"
    return template.read_text()
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
@app.post("/api/solve")
async def solve_frame(file: UploadFile = File(...)):
    """Accept a JPEG frame, run solver, return results."""

    def encode_jpeg_b64(pil_img, quality):
        # Serialize a PIL image to a base64-encoded JPEG string.
        buf = io.BytesIO()
        pil_img.save(buf, format="JPEG", quality=quality)
        return base64.b64encode(buf.getvalue()).decode("utf-8")

    contents = await file.read()
    image = Image.open(io.BytesIO(contents)).convert("RGB")

    result = solver.solve_from_image(image, conf=0.25)

    # Per-set annotated frames, encoded for the client's <img> tags.
    result["result_images_b64"] = [
        encode_jpeg_b64(img, 85) for img in result.pop("result_images")
    ]

    # Individual card crops for each found set (trophy strip display).
    result["per_set_cards_b64"] = [
        [encode_jpeg_b64(image.crop(tuple(bbox)), 90) for bbox in bboxes]
        for bboxes in result.get("sets_bboxes", [])
    ]

    return result
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
# Script entry point: launch the dev server, generating a throwaway
# self-signed TLS certificate by default (getUserMedia requires a
# secure context when the page is not served from localhost).
if __name__ == "__main__":
    import argparse
    import subprocess
    import tempfile
    import uvicorn

    parser = argparse.ArgumentParser(description="Set Solver web server")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--no-ssl", action="store_true", help="Disable auto-generated SSL (camera requires HTTPS on non-localhost)")
    args = parser.parse_args()

    ssl_kwargs = {}
    if not args.no_ssl:
        # Generate a self-signed cert so mobile browsers allow camera access
        # (valid for 1 day; written to a temp dir that is not cleaned up here).
        cert_dir = Path(tempfile.mkdtemp())
        cert_file = cert_dir / "cert.pem"
        key_file = cert_dir / "key.pem"
        subprocess.run([
            "openssl", "req", "-x509", "-newkey", "rsa:2048",
            "-keyout", str(key_file), "-out", str(cert_file),
            "-days", "1", "-nodes",
            "-subj", "/CN=set-solver",
        ], check=True, capture_output=True)
        ssl_kwargs = {"ssl_certfile": str(cert_file), "ssl_keyfile": str(key_file)}
        proto = "https"
    else:
        proto = "http"

    # Show access URLs (best-effort LAN IP lookup; falls back to loopback)
    import socket
    hostname = socket.gethostname()
    try:
        local_ip = socket.gethostbyname(hostname)
    except socket.gaierror:
        local_ip = "127.0.0.1"
    print(f"\n Set Solver running at:")
    print(f" Local: {proto}://localhost:{args.port}")
    print(f" Network: {proto}://{local_ip}:{args.port}\n")

    uvicorn.run("src.web.app:app", host="0.0.0.0", port=args.port, reload=False, **ssl_kwargs)
|
src/web/templates/index.html
ADDED
|
@@ -0,0 +1,383 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!DOCTYPE html>
|
| 2 |
+
<html lang="en">
|
| 3 |
+
<head>
|
| 4 |
+
<meta charset="UTF-8">
|
| 5 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0, user-scalable=no">
|
| 6 |
+
<title>Set Solver</title>
|
| 7 |
+
<style>
|
| 8 |
+
* { margin: 0; padding: 0; box-sizing: border-box; }
|
| 9 |
+
body {
|
| 10 |
+
background: #000;
|
| 11 |
+
color: #fff;
|
| 12 |
+
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif;
|
| 13 |
+
overflow: hidden;
|
| 14 |
+
height: 100dvh;
|
| 15 |
+
width: 100vw;
|
| 16 |
+
display: flex;
|
| 17 |
+
flex-direction: column;
|
| 18 |
+
}
|
| 19 |
+
#trophy {
|
| 20 |
+
display: none;
|
| 21 |
+
flex-direction: row;
|
| 22 |
+
justify-content: center;
|
| 23 |
+
align-items: center;
|
| 24 |
+
gap: 6px;
|
| 25 |
+
padding: 6px;
|
| 26 |
+
background: #111;
|
| 27 |
+
flex-shrink: 0;
|
| 28 |
+
}
|
| 29 |
+
#trophy.active { display: flex; }
|
| 30 |
+
#trophy img {
|
| 31 |
+
height: 60px;
|
| 32 |
+
max-width: 30vw;
|
| 33 |
+
border-radius: 4px;
|
| 34 |
+
border: 2px solid #4f4;
|
| 35 |
+
object-fit: contain;
|
| 36 |
+
}
|
| 37 |
+
#camera-container {
|
| 38 |
+
position: relative;
|
| 39 |
+
flex: 1;
|
| 40 |
+
display: flex;
|
| 41 |
+
align-items: center;
|
| 42 |
+
justify-content: center;
|
| 43 |
+
overflow: hidden;
|
| 44 |
+
}
|
| 45 |
+
video, #result-img {
|
| 46 |
+
max-width: 100%;
|
| 47 |
+
max-height: 100%;
|
| 48 |
+
object-fit: contain;
|
| 49 |
+
}
|
| 50 |
+
#result-img { display: none; }
|
| 51 |
+
#bottom-bar {
|
| 52 |
+
position: absolute;
|
| 53 |
+
bottom: 0; left: 0; right: 0;
|
| 54 |
+
display: flex;
|
| 55 |
+
flex-direction: column;
|
| 56 |
+
align-items: center;
|
| 57 |
+
padding-bottom: 16px;
|
| 58 |
+
z-index: 15;
|
| 59 |
+
pointer-events: none;
|
| 60 |
+
}
|
| 61 |
+
#set-nav {
|
| 62 |
+
display: none;
|
| 63 |
+
align-items: center;
|
| 64 |
+
gap: 12px;
|
| 65 |
+
margin-bottom: 10px;
|
| 66 |
+
pointer-events: auto;
|
| 67 |
+
}
|
| 68 |
+
#set-nav.active { display: flex; }
|
| 69 |
+
#set-nav .nav-arrow {
|
| 70 |
+
background: rgba(255,255,255,0.2);
|
| 71 |
+
border: none;
|
| 72 |
+
color: #fff;
|
| 73 |
+
font-size: 22px;
|
| 74 |
+
width: 40px; height: 40px;
|
| 75 |
+
border-radius: 50%;
|
| 76 |
+
cursor: pointer;
|
| 77 |
+
display: flex;
|
| 78 |
+
align-items: center;
|
| 79 |
+
justify-content: center;
|
| 80 |
+
}
|
| 81 |
+
#set-nav .nav-arrow:active { background: rgba(255,255,255,0.4); }
|
| 82 |
+
#set-label {
|
| 83 |
+
font-size: 14px;
|
| 84 |
+
color: #ccc;
|
| 85 |
+
min-width: 100px;
|
| 86 |
+
text-align: center;
|
| 87 |
+
}
|
| 88 |
+
#scan-btn {
|
| 89 |
+
border: none;
|
| 90 |
+
border-radius: 28px;
|
| 91 |
+
padding: 14px 48px;
|
| 92 |
+
font-size: 18px;
|
| 93 |
+
font-weight: 600;
|
| 94 |
+
cursor: pointer;
|
| 95 |
+
transition: background 0.2s;
|
| 96 |
+
pointer-events: auto;
|
| 97 |
+
}
|
| 98 |
+
#scan-btn.start {
|
| 99 |
+
background: #4f4;
|
| 100 |
+
color: #000;
|
| 101 |
+
}
|
| 102 |
+
#scan-btn.stop {
|
| 103 |
+
background: #f44;
|
| 104 |
+
color: #fff;
|
| 105 |
+
}
|
| 106 |
+
#scan-btn.restart {
|
| 107 |
+
background: #ff0;
|
| 108 |
+
color: #000;
|
| 109 |
+
}
|
| 110 |
+
#scan-btn:active { opacity: 0.7; }
|
| 111 |
+
#status-bar {
|
| 112 |
+
position: absolute;
|
| 113 |
+
top: 8px; left: 8px;
|
| 114 |
+
background: rgba(0,0,0,0.6);
|
| 115 |
+
border-radius: 8px;
|
| 116 |
+
padding: 4px 10px;
|
| 117 |
+
font-size: 13px;
|
| 118 |
+
z-index: 5;
|
| 119 |
+
}
|
| 120 |
+
#status-bar .dot {
|
| 121 |
+
display: inline-block;
|
| 122 |
+
width: 8px; height: 8px;
|
| 123 |
+
border-radius: 50%;
|
| 124 |
+
margin-right: 6px;
|
| 125 |
+
vertical-align: middle;
|
| 126 |
+
}
|
| 127 |
+
.dot.active { background: #4f4; }
|
| 128 |
+
.dot.inactive { background: #f44; }
|
| 129 |
+
.dot.processing { background: #ff4; }
|
| 130 |
+
.dot.idle { background: #888; }
|
| 131 |
+
</style>
|
| 132 |
+
</head>
|
| 133 |
+
<body>
|
| 134 |
+
<div id="trophy"></div>
|
| 135 |
+
<div id="camera-container">
|
| 136 |
+
<video id="video" autoplay playsinline muted></video>
|
| 137 |
+
<img id="result-img" alt="Result">
|
| 138 |
+
<div id="status-bar">
|
| 139 |
+
<span class="dot inactive" id="status-dot"></span>
|
| 140 |
+
<span id="status-text">Starting camera...</span>
|
| 141 |
+
</div>
|
| 142 |
+
<div id="bottom-bar">
|
| 143 |
+
<div id="set-nav">
|
| 144 |
+
<button class="nav-arrow" id="prev-btn">←</button>
|
| 145 |
+
<span id="set-label"></span>
|
| 146 |
+
<button class="nav-arrow" id="next-btn">→</button>
|
| 147 |
+
</div>
|
| 148 |
+
<button id="scan-btn" class="start">Start</button>
|
| 149 |
+
</div>
|
| 150 |
+
</div>
|
| 151 |
+
|
| 152 |
+
<canvas id="capture-canvas" style="display:none;"></canvas>
|
| 153 |
+
|
| 154 |
+
<script>
|
// ---- DOM handles & mutable scanner state -------------------------------
// Element ids live in the markup above; look each one up once at load.
const $id = (id) => document.getElementById(id);

const video = $id('video');
const resultImg = $id('result-img');
const trophy = $id('trophy');
const setNav = $id('set-nav');
const setLabel = $id('set-label');
const prevBtn = $id('prev-btn');
const nextBtn = $id('next-btn');
const scanBtn = $id('scan-btn');
const statusDot = $id('status-dot');
const statusText = $id('status-text');
const canvas = $id('capture-canvas');

// Scanner mode flags.
let stream = null;      // active MediaStream from getUserMedia
let scanning = false;   // capture loop is running
let processing = false; // a /api/solve request is currently in flight
let frozen = false; // true when showing results
let loopTimer = null;   // setInterval handle for the capture loop

// Result state for cycling through sets
let resultData = null;
let currentSetIdx = 0;
| 176 |
+
|
async function startCamera() {
  // Acquire the rear-facing camera and start the live preview, updating
  // the status bar on success or failure. getUserMedia is only exposed
  // over HTTPS (or localhost), so bail out with a hint otherwise.
  if (!navigator.mediaDevices || !navigator.mediaDevices.getUserMedia) {
    statusDot.className = 'dot inactive';
    statusText.textContent = 'Camera API unavailable — use https://';
    console.error('mediaDevices not available. Page must be served over HTTPS (or localhost).');
    return;
  }
  try {
    stream = await navigator.mediaDevices.getUserMedia({
      video: { facingMode: 'environment', width: { ideal: 1280 }, height: { ideal: 720 } },
      audio: false,
    });
    // Best effort: turn off auto blur/framing effects and pin pan/tilt/zoom
    // at their current values so captured frames stay geometrically stable.
    for (const videoTrack of stream.getVideoTracks()) {
      const capabilities = videoTrack.getCapabilities?.() || {};
      const tweaks = {};
      if ('backgroundBlur' in capabilities) tweaks.backgroundBlur = false;
      if ('faceFraming' in capabilities) tweaks.faceFraming = false;
      if ('pan' in capabilities) tweaks.pan = videoTrack.getSettings().pan;
      if ('tilt' in capabilities) tweaks.tilt = videoTrack.getSettings().tilt;
      if ('zoom' in capabilities) tweaks.zoom = videoTrack.getSettings().zoom;
      if (Object.keys(tweaks).length > 0) {
        // Unsupported constraints are non-fatal; the preview still works.
        try { await videoTrack.applyConstraints({ advanced: [tweaks] }); } catch (e) { /* ignore */ }
      }
    }
    video.srcObject = stream;
    await video.play();
    statusDot.className = 'dot idle';
    statusText.textContent = 'Ready — press Start';
  } catch (err) {
    statusDot.className = 'dot inactive';
    statusText.textContent = 'Camera access denied — check browser permissions';
    console.error('Camera error:', err);
  }
}
| 211 |
+
|
function restart() {
  // Leave the frozen results view and return to the live camera preview.
  // Note: this does NOT resume scanning — the user presses Start for that.
  frozen = false;
  scanning = false;
  resultData = null;
  currentSetIdx = 0;

  // Swap the annotated still back out for the live feed.
  resultImg.style.display = 'none';
  video.style.display = 'block';

  // Clear the trophy strip and the set navigation arrows.
  trophy.classList.remove('active');
  trophy.innerHTML = '';
  setNav.classList.remove('active');

  // Reset controls to the idle state.
  scanBtn.textContent = 'Start';
  scanBtn.className = 'start';
  statusDot.className = 'dot idle';
  statusText.textContent = 'Ready — press Start';
}
| 229 |
+
|
function startScanning() {
  // Enter scanning mode and kick off the ~3 fps capture loop.
  scanning = true;
  scanBtn.textContent = 'Stop';
  scanBtn.className = 'stop';
  statusDot.className = 'dot active';
  statusText.textContent = 'Scanning...';

  // Never leave two loops running at once.
  if (loopTimer) clearInterval(loopTimer);
  loopTimer = setInterval(() => {
    // Skip a tick while the previous /api/solve request is in flight.
    if (scanning && !processing) captureAndSolve();
  }, 333);
}
| 242 |
+
|
function stopScanning() {
  // Halt the capture loop and return the controls to their idle state.
  scanning = false;
  if (loopTimer) {
    clearInterval(loopTimer);
    loopTimer = null;
  }
  scanBtn.textContent = 'Start';
  scanBtn.className = 'start';
  statusDot.className = 'dot idle';
  statusText.textContent = 'Stopped';
}
| 251 |
+
|
async function captureAndSolve() {
  // Grab the current video frame, POST it to /api/solve, and freeze on the
  // result when at least one Set is found. Re-entrancy is prevented via the
  // `processing` flag (also checked by the interval loop in startScanning).
  if (!scanning || processing) return;
  // Guard: before the first frame arrives, videoWidth/videoHeight are 0.
  // Drawing a 0x0 canvas makes toBlob yield null, and appending null to
  // FormData would send the literal string "null" — skip this tick instead.
  if (!video.videoWidth || !video.videoHeight) return;
  processing = true;
  statusDot.className = 'dot processing';

  try {
    canvas.width = video.videoWidth;
    canvas.height = video.videoHeight;
    const ctx = canvas.getContext('2d');
    ctx.drawImage(video, 0, 0);

    const blob = await new Promise(resolve => canvas.toBlob(resolve, 'image/jpeg', 0.8));
    if (!blob) throw new Error('canvas.toBlob returned null'); // encode failed
    const formData = new FormData();
    formData.append('file', blob, 'frame.jpg');

    const resp = await fetch('/api/solve', { method: 'POST', body: formData });
    if (!resp.ok) throw new Error(`HTTP ${resp.status}`);
    const data = await resp.json();

    // The user may have pressed Stop while the request was in flight.
    if (!scanning) return;

    statusText.textContent = `${data.num_cards} cards`;
    statusDot.className = 'dot active';

    if (data.num_sets > 0) {
      showResult(data);
    }
  } catch (err) {
    console.error('Solve error:', err);
    if (scanning) statusDot.className = 'dot active';
  } finally {
    processing = false;
  }
}
| 286 |
+
|
function showResult(data) {
  // Freeze on a solved frame: stop the capture loop, swap the live feed
  // for the annotated still, and present the first found Set.
  scanning = false;
  frozen = true;
  if (loopTimer) {
    clearInterval(loopTimer);
    loopTimer = null;
  }

  resultData = data;
  currentSetIdx = 0;

  video.style.display = 'none';
  resultImg.style.display = 'block';

  // Navigation arrows only make sense when there is more than one Set.
  if (data.num_sets > 1) {
    setNav.classList.add('active');
  }

  showCurrentSet();

  scanBtn.textContent = 'Restart';
  scanBtn.className = 'restart';
  statusDot.className = 'dot active';
  statusText.textContent = `Found ${data.num_sets} Set${data.num_sets > 1 ? 's' : ''}!`;

  speak('Set!');
}
| 312 |
+
|
function showCurrentSet() {
  // Render the set selected by currentSetIdx: the annotated frame, the
  // three trophy card crops, and the "Set i / n" navigation label.
  if (!resultData) return;
  const { result_images_b64, per_set_cards_b64, num_sets } = resultData;
  const idx = currentSetIdx;

  // Annotated image for this set.
  resultImg.src = 'data:image/jpeg;base64,' + result_images_b64[idx];

  // Trophy strip: the three cards making up this set.
  const cards = per_set_cards_b64[idx];
  if (cards && cards.length === 3) {
    const imgTags = cards.map(b64 => `<img src="data:image/jpeg;base64,${b64}">`);
    trophy.innerHTML = imgTags.join('');
    trophy.classList.add('active');
  }

  // Navigation label.
  setLabel.textContent = `Set ${idx + 1} / ${num_sets}`;
}
| 333 |
+
|
function prevSet() {
  // Cycle backwards through the found sets, wrapping at the first one.
  if (!resultData || resultData.num_sets <= 1) return;
  const total = resultData.num_sets;
  currentSetIdx = (currentSetIdx + total - 1) % total;
  showCurrentSet();
}
| 339 |
+
|
function nextSet() {
  // Cycle forwards through the found sets, wrapping back to the first.
  if (!resultData || resultData.num_sets <= 1) return;
  const total = resultData.num_sets;
  currentSetIdx = (currentSetIdx + 1) % total;
  showCurrentSet();
}
| 345 |
+
|
function speak(text) {
  // Announce `text` out loud via the Web Speech API; silently does
  // nothing in browsers without speechSynthesis support.
  if (!('speechSynthesis' in window)) return;
  const utterance = new SpeechSynthesisUtterance(text);
  utterance.rate = 1.2;
  utterance.pitch = 1.1;
  speechSynthesis.speak(utterance);
}
| 354 |
+
|
// The primary action is mode-dependent: frozen results -> back to live
// view; scanning -> stop; idle -> start. Shared by the button and spacebar.
function primaryAction() {
  if (frozen) {
    restart();
  } else if (scanning) {
    stopScanning();
  } else {
    startScanning();
  }
}

scanBtn.addEventListener('click', primaryAction);
prevBtn.addEventListener('click', prevSet);
nextBtn.addEventListener('click', nextSet);

// Keyboard shortcuts: space toggles scanning, arrow keys cycle sets.
document.addEventListener('keydown', e => {
  if (e.key === ' ') {
    e.preventDefault(); // keep the page from scrolling / re-clicking the button
    primaryAction();
  } else if (e.key === 'ArrowLeft') {
    prevSet();
  } else if (e.key === 'ArrowRight') {
    nextSet();
  }
});

startCamera();
| 381 |
+
</script>
|
| 382 |
+
</body>
|
| 383 |
+
</html>
|
weights/classifier_best.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a0c464367eccfcfd6599377c9af35f72cd23c524b01eda7e9a11ccb1e3ba3f6d
|
| 3 |
+
size 11465795
|
weights/detector/weights/best.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d65deae13124271df8739b700d2f893bca1eb7a7bc8ac870702e714b787ceee7
|
| 3 |
+
size 5453594
|