Tian Wang commited on
Commit
8a34385
·
1 Parent(s): c3f6e96

Deploy Set Solver web app

Browse files
.dockerignore ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .venv/
2
+ .git/
3
+ __pycache__/
4
+ *.pyc
5
+ data/
6
+ training_images/
7
+ docs/
8
+ scripts/
9
+ *.ipynb
10
+ .DS_Store
11
+ showcase.html
12
+
13
+ # Training artifacts in weights/detector (keep only weights/*.pt)
14
+ weights/detector/*.png
15
+ weights/detector/*.jpg
16
+ weights/detector/*.csv
17
+ weights/detector/weights/last.pt
18
+ weights/detector/weights/best.onnx
Dockerfile ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.11-slim
2
+
3
+ WORKDIR /app
4
+
5
+ # Install system deps for opencv
6
+ RUN apt-get update && apt-get install -y --no-install-recommends \
7
+ libgl1 libglib2.0-0 \
8
+ && rm -rf /var/lib/apt/lists/*
9
+
10
+ # Install Python deps (CPU-only torch)
11
+ COPY requirements-web.txt .
12
+ RUN pip install --no-cache-dir -r requirements-web.txt
13
+
14
+ # Copy application code and weights
15
+ COPY src/ src/
16
+ COPY weights/ weights/
17
+
18
+ # Hugging Face Spaces uses port 7860
19
+ EXPOSE 7860
20
+
21
+ CMD ["uvicorn", "src.web.app:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -1,10 +1,15 @@
1
  ---
2
  title: Set Solver
3
- emoji: 📉
4
- colorFrom: purple
5
- colorTo: purple
6
  sdk: docker
 
7
  pinned: false
8
  ---
9
 
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
1
  ---
2
  title: Set Solver
3
+ emoji: 🃏
4
+ colorFrom: green
5
+ colorTo: blue
6
  sdk: docker
7
+ app_port: 7860
8
  pinned: false
9
  ---
10
 
11
+ # Set Solver
12
+
13
+ Vision-based solver for the [Set card game](https://www.setgame.com/).
14
+
15
+ Point your camera at Set cards → Get all valid Sets highlighted in real time.
requirements-web.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Web deployment only (CPU inference)
2
+ --extra-index-url https://download.pytorch.org/whl/cpu
3
+ torch
4
+ torchvision
5
+ ultralytics>=8.0
6
+ pillow>=10.0
7
+ numpy>=1.24
8
+ opencv-python-headless>=4.8
9
+ fastapi
10
+ uvicorn[standard]
11
+ python-multipart
src/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Set Solver
src/inference/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Inference scripts
src/inference/classify.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Inference script for classifying a single card image.
3
+ """
4
+
5
+ import torch
6
+ from torchvision import transforms
7
+ from PIL import Image
8
+ from pathlib import Path
9
+
10
+ # Import from training module
11
+ import sys
12
+ sys.path.insert(0, str(Path(__file__).parent.parent.parent))
13
+ from src.train.classifier import (
14
+ SetCardClassifier,
15
+ NUMBER_NAMES, COLOR_NAMES, SHAPE_NAMES, FILL_NAMES
16
+ )
17
+
18
+ WEIGHTS_DIR = Path(__file__).parent.parent.parent / "weights"
19
+
20
+
21
def load_model(weights_path: Path | None = None, device: str | None = None):
    """Load the trained multi-head card classifier.

    Args:
        weights_path: Checkpoint path; defaults to ``weights/classifier_best.pt``.
        device: Torch device string; when None, auto-selects cuda > mps > cpu.

    Returns:
        Tuple ``(model, device)`` — the model is on *device* and in eval mode.
    """
    if weights_path is None:
        weights_path = WEIGHTS_DIR / "classifier_best.pt"

    if device is None:
        if torch.cuda.is_available():
            device = "cuda"
        elif torch.backends.mps.is_available():
            device = "mps"
        else:
            device = "cpu"

    # pretrained=False: backbone weights are overwritten by the checkpoint anyway.
    model = SetCardClassifier(pretrained=False)
    checkpoint = torch.load(weights_path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.to(device)
    model.eval()

    return model, device
36
+
37
+
38
def classify_card(image: Image.Image, model, device) -> dict:
    """Run the classifier on a single card image.

    Returns a dict keyed by attribute ("number", "color", "shape", "fill");
    each entry holds the predicted value, its confidence, and the full
    probability distribution over that attribute's three candidate values.
    """
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    batch = preprocess(image).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(batch)

    # Attribute name lists share the index order of the classifier heads.
    heads = {
        "number": NUMBER_NAMES,
        "color": COLOR_NAMES,
        "shape": SHAPE_NAMES,
        "fill": FILL_NAMES,
    }

    result = {}
    for key, names in heads.items():
        probs = torch.softmax(logits[key], dim=1)[0]
        best = probs.argmax().item()
        result[key] = {
            "value": names[best],
            "confidence": probs[best].item(),
            "all_probs": {name: probs[i].item() for i, name in enumerate(names)},
        }

    return result
72
+
73
+
74
def main():
    """CLI entry point: classify one card image and print its attributes."""
    import argparse

    parser = argparse.ArgumentParser(description="Classify a Set card image")
    parser.add_argument("image", type=str, help="Path to card image")
    args = parser.parse_args()

    print("Loading model...")
    model, device = load_model()

    print(f"Classifying {args.image}...")
    image = Image.open(args.image).convert("RGB")
    result = classify_card(image, model, device)

    print("\nPrediction:")
    print(f" Number: {result['number']['value']} ({result['number']['confidence']:.1%})")
    print(f" Color: {result['color']['value']} ({result['color']['confidence']:.1%})")
    print(f" Shape: {result['shape']['value']} ({result['shape']['confidence']:.1%})")
    print(f" Fill: {result['fill']['value']} ({result['fill']['confidence']:.1%})")

    # Human-readable card name, e.g. "two full red oval(s)".
    number, color, shape, fill = (
        result[key]["value"] for key in ("number", "color", "shape", "fill")
    )
    print(f"\nCard: {number} {fill} {color} {shape}(s)")


if __name__ == "__main__":
    main()
src/inference/solve.py ADDED
@@ -0,0 +1,424 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ End-to-end Set solver pipeline.
3
+
4
+ Photo → Detect cards → Classify each → Find Sets → Visualize
5
+ """
6
+
7
+ import sys
8
+ from pathlib import Path
9
+ from typing import List, Tuple, Optional
10
+
11
+ import torch
12
+ from PIL import Image, ImageDraw, ImageFont
13
+ from ultralytics import YOLO
14
+ import numpy as np
15
+
16
+ # Add parent to path for imports
17
+ sys.path.insert(0, str(Path(__file__).parent.parent.parent))
18
+
19
+ from src.train.classifier import (
20
+ SetCardClassifier,
21
+ NUMBER_NAMES, COLOR_NAMES, SHAPE_NAMES, FILL_NAMES,
22
+ )
23
+ from src.solver.set_finder import Card, Shape, Color, Number, Fill, find_all_sets
24
+
25
+
26
+ WEIGHTS_DIR = Path(__file__).parent.parent.parent / "weights"
27
+ DATA_WEIGHTS_DIR = Path.home() / "data" / "set-solver" / "weights"
28
+
29
# Chinese shorthand names: {1,2,3}-{实,空,线}-{红,绿,紫}-{菱,圆,弯}
CHINESE_NUMBER = {"one": "1", "two": "2", "three": "3"}
CHINESE_FILL = {"full": "实", "empty": "空", "partial": "线"}
CHINESE_COLOR = {"red": "红", "green": "绿", "blue": "紫"}
CHINESE_SHAPE = {"diamond": "菱", "oval": "圆", "squiggle": "弯"}


def card_to_chinese(attrs: dict) -> str:
    """Render card attributes as the Chinese shorthand, e.g. '2实红菱'.

    Unrecognized attribute values pass through untranslated.
    """
    lookups = (
        (CHINESE_NUMBER, "number"),
        (CHINESE_FILL, "fill"),
        (CHINESE_COLOR, "color"),
        (CHINESE_SHAPE, "shape"),
    )
    return "".join(table.get(attrs[key], attrs[key]) for table, key in lookups)
43
+
44
+ # Colors for highlighting Sets (RGB)
45
+ SET_COLORS = [
46
+ (255, 0, 0), # Red
47
+ (0, 255, 0), # Green
48
+ (0, 0, 255), # Blue
49
+ (255, 255, 0), # Yellow
50
+ (255, 0, 255), # Magenta
51
+ (0, 255, 255), # Cyan
52
+ (255, 128, 0), # Orange
53
+ (128, 0, 255), # Purple
54
+ ]
55
+
56
+
57
class SetSolver:
    """End-to-end Set solver: detect cards, classify attributes, find Sets."""

    def __init__(
        self,
        detector_path: Optional[Path] = None,
        classifier_path: Optional[Path] = None,
        device: Optional[str] = None,
    ):
        """Load the YOLO card detector and the attribute classifier.

        Args:
            detector_path: YOLO weights; defaults to the copy under ~/data
                when it exists, else the repo's weights/detector/weights/best.pt.
            classifier_path: Classifier checkpoint; defaults to
                weights/classifier_best.pt.
            device: Torch device string; when None, auto-selects cuda > mps > cpu.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
        self.device = device

        # Load detector — prefer the local training copy under ~/data.
        if detector_path is None:
            data_path = DATA_WEIGHTS_DIR / "detector" / "weights" / "best.pt"
            repo_path = WEIGHTS_DIR / "detector" / "weights" / "best.pt"
            detector_path = data_path if data_path.exists() else repo_path
        print(f"Loading detector from {detector_path}")
        self.detector = YOLO(str(detector_path))

        # Load classifier from checkpoint (pretrained backbone not needed —
        # the checkpoint overwrites all weights).
        if classifier_path is None:
            classifier_path = WEIGHTS_DIR / "classifier_best.pt"
        print(f"Loading classifier from {classifier_path}")
        self.classifier = SetCardClassifier(pretrained=False)
        checkpoint = torch.load(classifier_path, map_location=device)
        self.classifier.load_state_dict(checkpoint["model_state_dict"])
        self.classifier.to(device)
        self.classifier.eval()

        # Same preprocessing the classifier was trained with.
        from torchvision import transforms
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def detect_cards(self, image: Image.Image, conf: float = 0.5) -> List[dict]:
        """Detect card bounding boxes in *image*.

        Returns a list of dicts with "bbox" (x1, y1, x2, y2 ints),
        "confidence", and "area". Oversized boxes (> ~2.2x the median area)
        are dropped because they usually cover two cards merged into one
        detection.
        """
        results = self.detector(image, conf=conf, verbose=False)

        detections = []
        for result in results:
            for box in result.boxes:
                x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
                detections.append({
                    "bbox": (int(x1), int(y1), int(x2), int(y2)),
                    "confidence": box.conf[0].cpu().item(),
                    "area": (x2 - x1) * (y2 - y1),
                })

        # With at least 3 boxes the median area is a reliable per-card scale.
        if len(detections) >= 3:
            areas = sorted(d["area"] for d in detections)
            median_area = areas[len(areas) // 2]
            detections = [d for d in detections if d["area"] <= median_area * 2.2]

        return detections

    def classify_card(self, card_image: Image.Image) -> dict:
        """Classify one cropped card.

        Returns a flat dict: attribute name -> predicted label string, plus
        "<attr>_conf" -> confidence for each attribute.
        """
        img_tensor = self.transform(card_image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            outputs = self.classifier(img_tensor)

        result = {}
        for key, names in [
            ("number", NUMBER_NAMES),
            ("color", COLOR_NAMES),
            ("shape", SHAPE_NAMES),
            ("fill", FILL_NAMES),
        ]:
            probs = torch.softmax(outputs[key], dim=1)[0]
            pred_idx = probs.argmax().item()
            result[key] = names[pred_idx]
            result[f"{key}_conf"] = probs[pred_idx].item()

        return result

    def detection_to_card(self, attrs: dict, bbox: Tuple[int, int, int, int]) -> Card:
        """Convert classifier output to a solver ``Card``.

        The training labels use "blue" where standard Set says purple, and
        "partial"/"full" where the solver says striped/solid — map them here.
        """
        color_map = {"red": "RED", "green": "GREEN", "blue": "PURPLE"}
        fill_map = {"empty": "EMPTY", "full": "SOLID", "partial": "STRIPED"}

        return Card(
            shape=Shape[attrs["shape"].upper()],
            color=Color[color_map[attrs["color"]]],
            number=Number[attrs["number"].upper()],
            fill=Fill[fill_map[attrs["fill"]]],
            bbox=bbox,
        )

    def _classify_detections(self, image: Image.Image, detections: List[dict]) -> List[dict]:
        """Crop + classify every detection; pair each with its Card object."""
        cards = []
        for det in detections:
            x1, y1, x2, y2 = det["bbox"]
            attrs = self.classify_card(image.crop((x1, y1, x2, y2)))
            cards.append({
                "card": self.detection_to_card(attrs, det["bbox"]),
                "attrs": attrs,
                "detection": det,
            })
        return cards

    def _summarize(self, cards: List[dict], sets) -> dict:
        """Build the result dict shared by solve() and solve_from_image()."""
        return {
            "num_cards": len(cards),
            "cards": [
                {
                    "attrs": c["attrs"],
                    "chinese": card_to_chinese(c["attrs"]),
                    "bbox": c["detection"]["bbox"],
                    "confidence": c["detection"]["confidence"],
                }
                for c in cards
            ],
            "num_sets": len(sets),
            "sets": [
                [str(card) for card in s]
                for s in sets
            ],
            "sets_chinese": [
                # Match each set member back to its attrs by object identity.
                [card_to_chinese(next(c["attrs"] for c in cards if c["card"] is card)) for card in s]
                for s in sets
            ],
        }

    def solve_from_image(
        self,
        image: Image.Image,
        conf: float = 0.5,
    ) -> dict:
        """
        Solve a Set game from a PIL Image directly.

        Args:
            image: PIL Image (RGB)
            conf: Detection confidence threshold

        Returns:
            Dict with detected cards, found Sets, and one annotated image
            per Set (or a single unannotated image when no Set exists).
        """
        image = image.convert("RGB")

        detections = self.detect_cards(image, conf=conf)
        cards = self._classify_detections(image, detections)
        sets = find_all_sets([c["card"] for c in cards])

        # One annotated image per set, each highlighting only that set.
        if sets:
            result_images = [
                self._draw_results(image, cards, sets, highlight_idx=i)
                for i in range(len(sets))
            ]
        else:
            result_images = [self._draw_results(image, cards, sets)]

        summary = self._summarize(cards, sets)
        summary["sets_bboxes"] = [
            [card.bbox for card in s]
            for s in sets
        ]
        summary["result_images"] = result_images
        return summary

    def solve(
        self,
        image_path: str,
        conf: float = 0.5,
        output_path: Optional[str] = None,
        show: bool = False,
    ) -> dict:
        """
        Solve a Set game from an image file, with progress printed to stdout.

        Args:
            image_path: Path to input image
            conf: Detection confidence threshold
            output_path: Path to save annotated output image
            show: Whether to display the result

        Returns:
            Dict with detected cards, found Sets, and the annotated image.
        """
        image = Image.open(image_path).convert("RGB")
        print(f"Loaded image: {image.size}")

        print("Detecting cards...")
        detections = self.detect_cards(image, conf=conf)
        print(f"Found {len(detections)} cards")

        print("Classifying cards...")
        cards = self._classify_detections(image, detections)

        print("Finding Sets...")
        sets = find_all_sets([c["card"] for c in cards])
        print(f"Found {len(sets)} valid Set(s)")

        result_image = self._draw_results(image, cards, sets)

        if output_path:
            result_image.save(output_path)
            print(f"Saved result to {output_path}")

        if show:
            result_image.show()

        summary = self._summarize(cards, sets)
        summary["result_image"] = result_image
        return summary

    def _draw_results(
        self,
        image: Image.Image,
        cards: List[dict],
        sets: List[Tuple[Card, Card, Card]],
        highlight_idx: Optional[int] = None,
    ) -> Image.Image:
        """Draw bounding boxes and Set highlights on a copy of *image*.

        Args:
            highlight_idx: If set, only highlight this one set (0-based).
                If None, highlight all sets.
        """
        result = image.copy()
        draw = ImageDraw.Draw(result)

        # Find a font with CJK coverage for the Chinese shorthand labels.
        font = None
        font_paths = [
            "/System/Library/Fonts/PingFang.ttc",  # macOS
            "/System/Library/Fonts/STHeiti Light.ttc",  # macOS
            "/usr/share/fonts/truetype/droid/DroidSansFallbackFull.ttf",  # Linux
            "C:\\Windows\\Fonts\\msyh.ttc",  # Windows
        ]
        for font_path in font_paths:
            try:
                font = ImageFont.truetype(font_path, 18)
                break
            except OSError:
                # Font not present on this platform — try the next candidate.
                continue
        if font is None:
            font = ImageFont.load_default()

        # Determine which set(s) to highlight.
        if highlight_idx is not None and 0 <= highlight_idx < len(sets):
            highlighted_sets = [(highlight_idx, sets[highlight_idx])]
        else:
            highlighted_sets = list(enumerate(sets))

        # Identity-based membership: duplicate cards would compare equal.
        highlighted_card_ids = {
            id(card)
            for _, card_set in highlighted_sets
            for card in card_set
        }

        # Draw only cards that participate in a highlighted Set.
        for c in cards:
            card = c["card"]
            if id(card) not in highlighted_card_ids:
                continue
            x1, y1, x2, y2 = card.bbox

            # Color the box by the first highlighted Set this card belongs to.
            color_idx = next(
                (si for si, card_set in highlighted_sets if card in card_set),
                highlighted_sets[0][0] if len(highlighted_sets) == 1 else 0,
            )
            color = SET_COLORS[color_idx % len(SET_COLORS)]
            draw.rectangle([x1, y1, x2, y2], outline=color, width=4)
            draw.text((x1, y1 - 20), card_to_chinese(c["attrs"]), fill=color, font=font)

        # Header line: either "Set k / n" or the overall count.
        if highlight_idx is not None:
            draw.text((10, 10), f"Set {highlight_idx + 1} / {len(sets)}", fill=(255, 255, 255), font=font)
        else:
            draw.text((10, 10), f"Found {len(sets)} Set(s)", fill=(255, 255, 255), font=font)

        return result
386
+
387
+
388
def main():
    """CLI entry point: solve a Set board photo and print the results."""
    import argparse

    parser = argparse.ArgumentParser(description="Solve Set game from image")
    parser.add_argument("image", type=str, help="Path to input image")
    parser.add_argument("--output", "-o", type=str, help="Path to save output image")
    parser.add_argument("--conf", type=float, default=0.25, help="Detection confidence")
    parser.add_argument("--show", action="store_true", help="Display result")
    args = parser.parse_args()

    result = SetSolver().solve(
        args.image,
        conf=args.conf,
        output_path=args.output,
        show=args.show,
    )

    banner = "=" * 50
    print("\n" + banner)
    print("结果 RESULTS")
    print(banner)
    print(f"检测到卡牌: {result['num_cards']}")
    print(f"找到Set: {result['num_sets']}")

    if result["cards"]:
        print("\n卡牌:")
        for entry in result["cards"]:
            print(f" {entry['chinese']}")

    if result["sets_chinese"]:
        print("\nSets:")
        for idx, names in enumerate(result["sets_chinese"], 1):
            print(f" Set {idx}: {' + '.join(names)}")


if __name__ == "__main__":
    main()
src/solver/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .set_finder import (
2
+ Card, Shape, Color, Number, Fill,
3
+ is_valid_set, find_all_sets, find_first_set,
4
+ generate_all_cards, card_to_index, index_to_card
5
+ )
6
+
7
+ __all__ = [
8
+ 'Card', 'Shape', 'Color', 'Number', 'Fill',
9
+ 'is_valid_set', 'find_all_sets', 'find_first_set',
10
+ 'generate_all_cards', 'card_to_index', 'index_to_card'
11
+ ]
src/solver/set_finder.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Set-finding algorithm.
3
+
4
+ A valid Set consists of 3 cards where, for each attribute,
5
+ the values are either ALL THE SAME or ALL DIFFERENT.
6
+ """
7
+
8
+ from dataclasses import dataclass
9
+ from enum import IntEnum
10
+ from itertools import combinations
11
+ from typing import List, Tuple
12
+
13
+
14
class Shape(IntEnum):
    DIAMOND = 0
    OVAL = 1
    SQUIGGLE = 2


class Color(IntEnum):
    RED = 0
    GREEN = 1
    PURPLE = 2


class Number(IntEnum):
    ONE = 0
    TWO = 1
    THREE = 2


class Fill(IntEnum):
    SOLID = 0
    STRIPED = 1
    EMPTY = 2


@dataclass
class Card:
    """A Set card described by its four game attributes.

    ``bbox`` is an optional image-space location used only for
    visualization; it is deliberately excluded from equality and hashing
    so the same logical card compares equal regardless of position.
    """
    shape: Shape
    color: Color
    number: Number
    fill: Fill

    # Optional: position in image (for visualization); x, y, w, h
    bbox: Tuple[float, float, float, float] = None

    def __hash__(self):
        # Hash only the game attributes, matching __eq__.
        return hash(self.to_tuple())

    def __eq__(self, other):
        if isinstance(other, Card):
            return self.to_tuple() == other.to_tuple()
        return False

    def to_tuple(self) -> Tuple[int, int, int, int]:
        """Return the four attributes as a tuple of ints."""
        return (self.shape, self.color, self.number, self.fill)

    @classmethod
    def from_tuple(cls, attrs: Tuple[int, int, int, int], bbox=None) -> "Card":
        """Build a card from a tuple of attribute indices."""
        shape_idx, color_idx, number_idx, fill_idx = attrs
        return cls(
            shape=Shape(shape_idx),
            color=Color(color_idx),
            number=Number(number_idx),
            fill=Fill(fill_idx),
            bbox=bbox,
        )

    def __repr__(self):
        count = ("one", "two", "three")[self.number]
        parts = (
            count,
            self.fill.name.lower(),
            self.color.name.lower(),
            self.shape.name.lower(),
        )
        return "{} {} {} {}(s)".format(*parts)
78
+
79
+
80
def is_valid_set(card1: Card, card2: Card, card3: Card) -> bool:
    """
    Return True if the three cards form a valid Set.

    Rule: for every attribute, the three values must be either all the
    same (1 distinct value) or all different (3 distinct values).
    Exactly two distinct values on any attribute breaks the Set.
    """
    trio = (card1, card2, card3)
    return all(
        len({getattr(card, attr) for card in trio}) != 2
        for attr in ("shape", "color", "number", "fill")
    )


def find_all_sets(cards: List[Card]) -> List[Tuple[Card, Card, Card]]:
    """
    Return every valid Set among the given cards.

    Brute force over C(n, 3) triples — 220 checks for 12 cards and
    1330 for the in-game maximum of 21, both trivially fast.
    """
    return [trio for trio in combinations(cards, 3) if is_valid_set(*trio)]


def find_first_set(cards: List[Card]) -> Tuple[Card, Card, Card] | None:
    """Return the first valid Set encountered, or None when none exists."""
    return next(
        (trio for trio in combinations(cards, 3) if is_valid_set(*trio)),
        None,
    )
119
+
120
+
121
+ # --- Utilities ---
122
+
123
def generate_all_cards() -> List[Card]:
    """Return the full 81-card Set deck (every attribute combination)."""
    return [
        Card(shape=s, color=c, number=n, fill=f)
        for s in Shape
        for c in Color
        for n in Number
        for f in Fill
    ]


def card_to_index(card: Card) -> int:
    """Map a card to its unique base-3 index in [0, 80]."""
    index = 0
    # Digit order (most to least significant): shape, color, number, fill.
    for digit in (card.shape, card.color, card.number, card.fill):
        index = index * 3 + digit
    return index


def index_to_card(idx: int) -> Card:
    """Inverse of card_to_index: decode an index in [0, 80] into a card."""
    idx, fill = divmod(idx, 3)
    idx, number = divmod(idx, 3)
    shape, color = divmod(idx, 3)
    return Card(Shape(shape), Color(color), Number(number), Fill(fill))
149
+
150
+
151
+ # --- Demo ---
152
+
153
if __name__ == "__main__":
    # Demo: deal a random board and report every Set on it.
    import random

    deck = generate_all_cards()
    print(f"Total cards in deck: {len(deck)}")

    board = random.sample(deck, 12)
    print(f"\nDealt {len(board)} cards:")
    for position, card in enumerate(board, start=1):
        print(f" {position}. {card}")

    found = find_all_sets(board)
    print(f"\nFound {len(found)} valid Set(s):")
    for position, trio in enumerate(found, start=1):
        print(f"\n Set {position}:")
        for member in trio:
            print(f" - {member}")
src/train/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Training scripts
src/train/classifier.py ADDED
@@ -0,0 +1,361 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Train a card attribute classifier on the existing labeled images.
3
+
4
+ Uses MobileNetV3-Small for iPhone compatibility.
5
+ Multi-head output: predicts all 4 attributes simultaneously.
6
+ """
7
+
8
+ import os
9
+ import json
10
+ from pathlib import Path
11
+ from typing import Tuple, Dict, List
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+ from torch.utils.data import Dataset, DataLoader, random_split
16
+ from torchvision import transforms, models
17
+ from torchvision.io import read_image, ImageReadMode
18
+ from PIL import Image
19
+ import numpy as np
20
+ from tqdm import tqdm
21
+
22
+ # === Config ===
23
+
24
+ DATA_DIR = Path(__file__).parent.parent.parent / "training_images"
25
+ SYNTHETIC_DATA_DIR = Path(__file__).parent.parent.parent / "training_images_synthetic"
26
+ WEIGHTS_DIR = Path(__file__).parent.parent.parent / "weights"
27
+ WEIGHTS_DIR.mkdir(exist_ok=True)
28
+
29
+ # Attribute mappings (folder names → indices)
30
+ NUMBER_MAP = {"one": 0, "two": 1, "three": 2}
31
+ COLOR_MAP = {"red": 0, "green": 1, "blue": 2} # blue = purple in standard Set
32
+ SHAPE_MAP = {"diamond": 0, "oval": 1, "squiggle": 2}
33
+ FILL_MAP = {"empty": 0, "full": 1, "partial": 2} # partial = striped
34
+
35
+ # Reverse mappings for inference
36
+ NUMBER_NAMES = ["one", "two", "three"]
37
+ COLOR_NAMES = ["red", "green", "blue"]
38
+ SHAPE_NAMES = ["diamond", "oval", "squiggle"]
39
+ FILL_NAMES = ["empty", "full", "partial"]
40
+
41
+
42
+ # === Dataset ===
43
+
44
class SetCardDataset(Dataset):
    """Labeled Set card images laid out as number/color/shape/fill folders."""

    def __init__(self, data_dirs, transform=None):
        """Index every *.png under each data dir's attribute folder tree.

        Args:
            data_dirs: A single Path or a list of Paths to scan; missing
                directories are silently skipped.
            transform: Optional torchvision transform applied in __getitem__.
        """
        if isinstance(data_dirs, Path):
            data_dirs = [data_dirs]
        self.transform = transform
        self.samples: List[Tuple[Path, Dict[str, int]]] = []

        for data_dir in data_dirs:
            if not data_dir.exists():
                continue
            start = len(self.samples)
            # Folder layout encodes the labels: <number>/<color>/<shape>/<fill>/*.png
            for number, number_idx in NUMBER_MAP.items():
                for color, color_idx in COLOR_MAP.items():
                    for shape, shape_idx in SHAPE_MAP.items():
                        for fill, fill_idx in FILL_MAP.items():
                            folder = data_dir / number / color / shape / fill
                            if not folder.exists():
                                continue
                            for img_path in folder.glob("*.png"):
                                self.samples.append((img_path, {
                                    "number": number_idx,
                                    "color": color_idx,
                                    "shape": shape_idx,
                                    "fill": fill_idx,
                                }))
            print(f"Loaded {len(self.samples) - start} samples from {data_dir}")

        print(f"Total: {len(self.samples)} samples")

    @staticmethod
    def _labels_tensor(labels: Dict[str, int]) -> torch.Tensor:
        # Fixed attribute order matches the model's multi-head output.
        return torch.tensor(
            [labels["number"], labels["color"], labels["shape"], labels["fill"]],
            dtype=torch.long,
        )

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, labels = self.samples[idx]
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, self._labels_tensor(labels)

    def get_raw(self, idx):
        """Return the untransformed PIL image and its label tensor."""
        img_path, labels = self.samples[idx]
        return Image.open(img_path).convert("RGB"), self._labels_tensor(labels)
109
+
110
+
111
+ # === Model ===
112
+
113
class SetCardClassifier(nn.Module):
    """
    Multi-head classifier for the four Set card attributes.

    A MobileNetV3-Small backbone (mobile-friendly) feeds four independent
    3-way linear heads — number, color, shape, fill — so one forward pass
    predicts every attribute at once.
    """

    def __init__(self, pretrained: bool = True):
        super().__init__()

        # Backbone, optionally initialized with ImageNet weights.
        weights = models.MobileNet_V3_Small_Weights.DEFAULT if pretrained else None
        self.backbone = models.mobilenet_v3_small(weights=weights)

        # Feature width that fed the stock classifier, which we discard.
        in_features = self.backbone.classifier[0].in_features
        self.backbone.classifier = nn.Identity()

        # One independent 3-way head per card attribute.
        self.heads = nn.ModuleDict({
            name: nn.Linear(in_features, 3)
            for name in ("number", "color", "shape", "fill")
        })

    def forward(self, x):
        """Return a dict of raw logits keyed by attribute name."""
        features = self.backbone(x)
        return {name: head(features) for name, head in self.heads.items()}
150
+
151
+
152
+ # === Training ===
153
+
154
def train_epoch(model, loader, optimizer, criterion, device):
    """Run one training epoch; return (mean loss, per-attribute accuracy)."""
    model.train()
    heads = ("number", "color", "shape", "fill")
    fill_weight = 2.0  # fill is the hardest attribute — weight its loss up
    running_loss = 0.0
    hits = dict.fromkeys(heads, 0)
    seen = 0

    for images, labels in tqdm(loader, desc="Training", leave=False):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)

        # Sum per-head losses (fill weighted) and tally correct predictions.
        batch_loss = 0
        for i, key in enumerate(heads):
            head_loss = criterion(outputs[key], labels[:, i])
            batch_loss = batch_loss + (fill_weight * head_loss if key == "fill" else head_loss)
            hits[key] += (outputs[key].argmax(dim=1) == labels[:, i]).sum().item()

        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()
        seen += labels.size(0)

    return running_loss / len(loader), {k: v / seen for k, v in hits.items()}
185
+
186
+
187
def evaluate(model, loader, criterion, device):
    """Evaluate `model` on `loader` without gradient tracking.

    Unlike train_epoch, the fill head is NOT up-weighted here: the reported
    loss is the plain sum of the four cross-entropies.

    Returns:
        (mean loss per batch, dict of per-head accuracy).
    """
    attrs = ("number", "color", "shape", "fill")

    model.eval()
    running_loss = 0.0
    hits = dict.fromkeys(attrs, 0)
    seen = 0

    with torch.no_grad():
        for images, labels in tqdm(loader, desc="Evaluating", leave=False):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)

            batch_loss = 0
            for col, attr in enumerate(attrs):
                batch_loss = batch_loss + criterion(outputs[attr], labels[:, col])
                hits[attr] += (outputs[attr].argmax(dim=1) == labels[:, col]).sum().item()

            running_loss += batch_loss.item()
            seen += labels.size(0)

    return running_loss / len(loader), {k: v / seen for k, v in hits.items()}
212
+
213
+
214
def main():
    """Train the multi-head Set card classifier end to end.

    Loads clean + synthetic card crops, splits train/val/test, trains with
    an AdamW + cosine-annealing schedule, checkpoints the best model by
    average validation accuracy, then reports final test metrics and
    writes them to weights/training_results.json.
    """
    # === Hyperparameters ===
    BATCH_SIZE = 32
    EPOCHS = 50
    LR = 1e-3
    VAL_SPLIT = 0.15
    TEST_SPLIT = 0.10
    IMG_SIZE = 224

    # Prefer CUDA, then Apple-silicon MPS, else CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
    print(f"Using device: {device}")

    # === Data transforms ===
    # Heavy augmentation: the deployed pipeline feeds this model crops from
    # an imperfect detector, so training simulates sloppy crops, rotation,
    # perspective warp, lighting shifts and blur.
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(IMG_SIZE, scale=(0.7, 1.0)),  # Simulate imperfect detector crops
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomRotation(180),  # Cards can be any orientation
        transforms.RandomPerspective(distortion_scale=0.15, p=0.5),  # Perspective warp from detection
        transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.05),
        transforms.RandomGrayscale(p=0.05),  # Force model to not rely solely on color for fill
        transforms.GaussianBlur(kernel_size=5, sigma=(0.1, 2.0)),  # ~30% effective via random sigma
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Deterministic eval-time preprocessing (ImageNet normalization).
    val_transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # === Load dataset (clean + synthetic crops) ===
    data_dirs = [DATA_DIR]
    if SYNTHETIC_DATA_DIR.exists():
        data_dirs.append(SYNTHETIC_DATA_DIR)
    full_dataset = SetCardDataset(data_dirs, transform=None)  # No transform yet

    # Split into train/val/test (fixed seed so splits are reproducible).
    total = len(full_dataset)
    test_size = int(total * TEST_SPLIT)
    val_size = int(total * VAL_SPLIT)
    train_size = total - val_size - test_size

    train_dataset, val_dataset, test_dataset = random_split(
        full_dataset, [train_size, val_size, test_size],
        generator=torch.Generator().manual_seed(42)
    )

    print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")

    # Wrap with transform (can't change transform on Subset, so we wrap)
    class TransformDataset(torch.utils.data.Dataset):
        # Thin adapter applying a per-split transform over a Subset.
        def __init__(self, subset, transform):
            self.subset = subset
            self.transform = transform
        def __len__(self):
            return len(self.subset)
        def __getitem__(self, idx):
            image, label = self.subset[idx]
            if self.transform:
                image = self.transform(image)
            return image, label

    train_dataset = TransformDataset(train_dataset, train_transform)
    val_dataset = TransformDataset(val_dataset, val_transform)
    test_dataset = TransformDataset(test_dataset, val_transform)

    # Use num_workers=0 on macOS to avoid shared memory issues
    import platform
    num_workers = 0 if platform.system() == "Darwin" else 4

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=num_workers)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=num_workers)

    # === Model ===
    model = SetCardClassifier(pretrained=True).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

    # === Training loop ===
    best_val_acc = 0

    for epoch in range(EPOCHS):
        train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device)
        val_loss, val_acc = evaluate(model, val_loader, criterion, device)
        scheduler.step()

        # Average accuracy across all heads
        avg_train_acc = sum(train_acc.values()) / 4
        avg_val_acc = sum(val_acc.values()) / 4

        print(f"Epoch {epoch+1}/{EPOCHS}")
        print(f"  Train Loss: {train_loss:.4f}, Acc: {avg_train_acc:.4f}")
        print(f"  Val Loss: {val_loss:.4f}, Acc: {avg_val_acc:.4f}")
        print(f"  Val per-head: num={val_acc['number']:.3f} col={val_acc['color']:.3f} "
              f"shp={val_acc['shape']:.3f} fil={val_acc['fill']:.3f}")

        # Save best model (checkpoint keyed on mean of the four head accuracies)
        if avg_val_acc > best_val_acc:
            best_val_acc = avg_val_acc
            torch.save({
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "val_acc": val_acc,
            }, WEIGHTS_DIR / "classifier_best.pt")
            print(f"  Saved new best model (val_acc={avg_val_acc:.4f})")

    # === Final evaluation on test set ===
    print("\n" + "="*50)
    print("Final Test Evaluation")
    print("="*50)

    # Load best model
    # NOTE(review): loads on the same device it was saved from within this
    # run; if this checkpoint is ever reloaded elsewhere, pass map_location.
    checkpoint = torch.load(WEIGHTS_DIR / "classifier_best.pt")
    model.load_state_dict(checkpoint["model_state_dict"])

    test_loss, test_acc = evaluate(model, test_loader, criterion, device)
    avg_test_acc = sum(test_acc.values()) / 4

    print(f"Test Loss: {test_loss:.4f}")
    print(f"Test Accuracy (avg): {avg_test_acc:.4f}")
    print(f"  Number: {test_acc['number']:.4f}")
    print(f"  Color: {test_acc['color']:.4f}")
    print(f"  Shape: {test_acc['shape']:.4f}")
    print(f"  Fill: {test_acc['fill']:.4f}")

    # Save final results
    results = {
        "test_loss": test_loss,
        "test_accuracy": test_acc,
        "avg_test_accuracy": avg_test_acc,
        "train_size": train_size,
        "val_size": val_size,
        "test_size": test_size,
    }
    with open(WEIGHTS_DIR / "training_results.json", "w") as f:
        json.dump(results, f, indent=2)

    print(f"\nModel saved to {WEIGHTS_DIR / 'classifier_best.pt'}")
    print(f"Results saved to {WEIGHTS_DIR / 'training_results.json'}")
358
+
359
+
360
if __name__ == "__main__":
    # Script entry point: run the full training + evaluation pipeline.
    main()
src/web/__init__.py ADDED
File without changes
src/web/app.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Web-based real-time Set solver.
3
+
4
+ FastAPI backend serving a single HTML page with live camera feed.
5
+ Processes frames via the SetSolver pipeline and returns annotated results.
6
+ """
7
+
8
+ import base64
9
+ import io
10
+ import sys
11
+ from pathlib import Path
12
+
13
+ from fastapi import FastAPI, UploadFile, File
14
+ from fastapi.responses import HTMLResponse
15
+ from PIL import Image
16
+
17
+ # Add project root to path
18
+ sys.path.insert(0, str(Path(__file__).parent.parent.parent))
19
+
20
+ from src.inference.solve import SetSolver
21
+
22
app = FastAPI(title="Set Solver")

# Global solver instance (loaded once at startup); None until the
# startup hook has finished constructing the pipeline.
solver: SetSolver | None = None
26
+
27
+
28
@app.on_event("startup")
def load_solver():
    """Instantiate the detection + classification pipeline once at server startup."""
    # NOTE(review): on_event is deprecated in recent FastAPI versions;
    # consider migrating to the lifespan API when upgrading.
    global solver
    print("Loading Set Solver pipeline...")
    solver = SetSolver()
    print("Solver ready!")
34
+
35
+
36
@app.get("/", response_class=HTMLResponse)
def index():
    """Serve the single-page camera UI.

    Reads templates/index.html explicitly as UTF-8: Path.read_text()
    defaults to the platform locale encoding, which can raise or garble
    the page's non-ASCII characters (em dashes) on non-UTF-8 hosts.
    """
    html_path = Path(__file__).parent / "templates" / "index.html"
    return html_path.read_text(encoding="utf-8")
40
+
41
+
42
@app.post("/api/solve")
async def solve_frame(file: UploadFile = File(...)):
    """Accept a JPEG frame, run solver, return results.

    Response JSON: the solver's result dict with "result_images" replaced
    by "result_images_b64" (one annotated JPEG per found set), plus
    "per_set_cards_b64" (a list of three base64 card crops per set for
    the trophy bar in the UI).
    """
    contents = await file.read()
    image = Image.open(io.BytesIO(contents)).convert("RGB")

    # NOTE(review): assumes solve_from_image returns a dict containing
    # "result_images" (savable PIL images), "sets_bboxes" (per-set lists of
    # (x1, y1, x2, y2)), plus num_cards/num_sets consumed by the frontend —
    # confirm against src/inference/solve.py.
    result = solver.solve_from_image(image, conf=0.25)

    # Encode per-set annotated images as base64 JPEG
    result_images_b64 = []
    for img in result.pop("result_images"):
        buf = io.BytesIO()
        img.save(buf, format="JPEG", quality=85)
        result_images_b64.append(base64.b64encode(buf.getvalue()).decode("utf-8"))
    result["result_images_b64"] = result_images_b64

    # Crop cards per set for trophy display
    per_set_cards_b64 = []
    for bboxes in result.get("sets_bboxes", []):
        crops = []
        for bbox in bboxes:
            x1, y1, x2, y2 = bbox
            crop = image.crop((x1, y1, x2, y2))
            cbuf = io.BytesIO()
            crop.save(cbuf, format="JPEG", quality=90)
            crops.append(base64.b64encode(cbuf.getvalue()).decode("utf-8"))
        per_set_cards_b64.append(crops)
    result["per_set_cards_b64"] = per_set_cards_b64

    return result
72
+
73
+
74
if __name__ == "__main__":
    # Local development entry point; the Docker/Spaces deployment runs
    # uvicorn directly instead (see Dockerfile CMD).
    import argparse
    import subprocess
    import tempfile
    import uvicorn

    parser = argparse.ArgumentParser(description="Set Solver web server")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--no-ssl", action="store_true", help="Disable auto-generated SSL (camera requires HTTPS on non-localhost)")
    args = parser.parse_args()

    ssl_kwargs = {}
    if not args.no_ssl:
        # Generate a self-signed cert so mobile browsers allow camera access
        # (getUserMedia requires a secure context off localhost). Browsers
        # will still show a trust warning for the self-signed cert.
        cert_dir = Path(tempfile.mkdtemp())
        cert_file = cert_dir / "cert.pem"
        key_file = cert_dir / "key.pem"
        subprocess.run([
            "openssl", "req", "-x509", "-newkey", "rsa:2048",
            "-keyout", str(key_file), "-out", str(cert_file),
            "-days", "1", "-nodes",
            "-subj", "/CN=set-solver",
        ], check=True, capture_output=True)
        ssl_kwargs = {"ssl_certfile": str(cert_file), "ssl_keyfile": str(key_file)}
        proto = "https"
    else:
        proto = "http"

    # Show access URLs
    import socket
    hostname = socket.gethostname()
    try:
        # Best effort: may not resolve to the LAN address on every OS.
        local_ip = socket.gethostbyname(hostname)
    except socket.gaierror:
        local_ip = "127.0.0.1"
    print(f"\n Set Solver running at:")
    print(f" Local: {proto}://localhost:{args.port}")
    print(f" Network: {proto}://{local_ip}:{args.port}\n")

    uvicorn.run("src.web.app:app", host="0.0.0.0", port=args.port, reload=False, **ssl_kwargs)
src/web/templates/index.html ADDED
@@ -0,0 +1,383 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0, user-scalable=no">
6
+ <title>Set Solver</title>
7
+ <style>
8
+ * { margin: 0; padding: 0; box-sizing: border-box; }
9
+ body {
10
+ background: #000;
11
+ color: #fff;
12
+ font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif;
13
+ overflow: hidden;
14
+ height: 100dvh;
15
+ width: 100vw;
16
+ display: flex;
17
+ flex-direction: column;
18
+ }
19
+ #trophy {
20
+ display: none;
21
+ flex-direction: row;
22
+ justify-content: center;
23
+ align-items: center;
24
+ gap: 6px;
25
+ padding: 6px;
26
+ background: #111;
27
+ flex-shrink: 0;
28
+ }
29
+ #trophy.active { display: flex; }
30
+ #trophy img {
31
+ height: 60px;
32
+ max-width: 30vw;
33
+ border-radius: 4px;
34
+ border: 2px solid #4f4;
35
+ object-fit: contain;
36
+ }
37
+ #camera-container {
38
+ position: relative;
39
+ flex: 1;
40
+ display: flex;
41
+ align-items: center;
42
+ justify-content: center;
43
+ overflow: hidden;
44
+ }
45
+ video, #result-img {
46
+ max-width: 100%;
47
+ max-height: 100%;
48
+ object-fit: contain;
49
+ }
50
+ #result-img { display: none; }
51
+ #bottom-bar {
52
+ position: absolute;
53
+ bottom: 0; left: 0; right: 0;
54
+ display: flex;
55
+ flex-direction: column;
56
+ align-items: center;
57
+ padding-bottom: 16px;
58
+ z-index: 15;
59
+ pointer-events: none;
60
+ }
61
+ #set-nav {
62
+ display: none;
63
+ align-items: center;
64
+ gap: 12px;
65
+ margin-bottom: 10px;
66
+ pointer-events: auto;
67
+ }
68
+ #set-nav.active { display: flex; }
69
+ #set-nav .nav-arrow {
70
+ background: rgba(255,255,255,0.2);
71
+ border: none;
72
+ color: #fff;
73
+ font-size: 22px;
74
+ width: 40px; height: 40px;
75
+ border-radius: 50%;
76
+ cursor: pointer;
77
+ display: flex;
78
+ align-items: center;
79
+ justify-content: center;
80
+ }
81
+ #set-nav .nav-arrow:active { background: rgba(255,255,255,0.4); }
82
+ #set-label {
83
+ font-size: 14px;
84
+ color: #ccc;
85
+ min-width: 100px;
86
+ text-align: center;
87
+ }
88
+ #scan-btn {
89
+ border: none;
90
+ border-radius: 28px;
91
+ padding: 14px 48px;
92
+ font-size: 18px;
93
+ font-weight: 600;
94
+ cursor: pointer;
95
+ transition: background 0.2s;
96
+ pointer-events: auto;
97
+ }
98
+ #scan-btn.start {
99
+ background: #4f4;
100
+ color: #000;
101
+ }
102
+ #scan-btn.stop {
103
+ background: #f44;
104
+ color: #fff;
105
+ }
106
+ #scan-btn.restart {
107
+ background: #ff0;
108
+ color: #000;
109
+ }
110
+ #scan-btn:active { opacity: 0.7; }
111
+ #status-bar {
112
+ position: absolute;
113
+ top: 8px; left: 8px;
114
+ background: rgba(0,0,0,0.6);
115
+ border-radius: 8px;
116
+ padding: 4px 10px;
117
+ font-size: 13px;
118
+ z-index: 5;
119
+ }
120
+ #status-bar .dot {
121
+ display: inline-block;
122
+ width: 8px; height: 8px;
123
+ border-radius: 50%;
124
+ margin-right: 6px;
125
+ vertical-align: middle;
126
+ }
127
+ .dot.active { background: #4f4; }
128
+ .dot.inactive { background: #f44; }
129
+ .dot.processing { background: #ff4; }
130
+ .dot.idle { background: #888; }
131
+ </style>
132
+ </head>
133
+ <body>
134
+ <div id="trophy"></div>
135
+ <div id="camera-container">
136
+ <video id="video" autoplay playsinline muted></video>
137
+ <img id="result-img" alt="Result">
138
+ <div id="status-bar">
139
+ <span class="dot inactive" id="status-dot"></span>
140
+ <span id="status-text">Starting camera...</span>
141
+ </div>
142
+ <div id="bottom-bar">
143
+ <div id="set-nav">
144
+ <button class="nav-arrow" id="prev-btn">&larr;</button>
145
+ <span id="set-label"></span>
146
+ <button class="nav-arrow" id="next-btn">&rarr;</button>
147
+ </div>
148
+ <button id="scan-btn" class="start">Start</button>
149
+ </div>
150
+ </div>
151
+
152
+ <canvas id="capture-canvas" style="display:none;"></canvas>
153
+
154
+ <script>
155
+ const video = document.getElementById('video');
156
+ const resultImg = document.getElementById('result-img');
157
+ const trophy = document.getElementById('trophy');
158
+ const setNav = document.getElementById('set-nav');
159
+ const setLabel = document.getElementById('set-label');
160
+ const prevBtn = document.getElementById('prev-btn');
161
+ const nextBtn = document.getElementById('next-btn');
162
+ const scanBtn = document.getElementById('scan-btn');
163
+ const statusDot = document.getElementById('status-dot');
164
+ const statusText = document.getElementById('status-text');
165
+ const canvas = document.getElementById('capture-canvas');
166
+
167
+ let stream = null;
168
+ let scanning = false;
169
+ let processing = false;
170
+ let frozen = false; // true when showing results
171
+ let loopTimer = null;
172
+
173
+ // Result state for cycling through sets
174
+ let resultData = null;
175
+ let currentSetIdx = 0;
176
+
177
// Request the rear-facing camera and start the live preview in <video>.
async function startCamera() {
  // getUserMedia only exists in a secure context (HTTPS or localhost).
  if (!navigator.mediaDevices || !navigator.mediaDevices.getUserMedia) {
    statusDot.className = 'dot inactive';
    statusText.textContent = 'Camera API unavailable — use https://';
    console.error('mediaDevices not available. Page must be served over HTTPS (or localhost).');
    return;
  }
  try {
    stream = await navigator.mediaDevices.getUserMedia({
      video: { facingMode: 'environment', width: { ideal: 1280 }, height: { ideal: 720 } },
      audio: false,
    });
    // Where the browser exposes these capabilities, disable frame-altering
    // effects (background blur / face framing) and pin pan/tilt/zoom at
    // their current values.
    for (const track of stream.getVideoTracks()) {
      const caps = track.getCapabilities?.() || {};
      const settings = {};
      if ('backgroundBlur' in caps) settings.backgroundBlur = false;
      if ('faceFraming' in caps) settings.faceFraming = false;
      if ('pan' in caps) settings.pan = track.getSettings().pan;
      if ('tilt' in caps) settings.tilt = track.getSettings().tilt;
      if ('zoom' in caps) settings.zoom = track.getSettings().zoom;
      if (Object.keys(settings).length > 0) {
        // Best effort: some browsers reject advanced constraints.
        try { await track.applyConstraints({ advanced: [settings] }); } catch (e) { /* ignore */ }
      }
    }
    video.srcObject = stream;
    await video.play();
    statusDot.className = 'dot idle';
    statusText.textContent = 'Ready — press Start';
  } catch (err) {
    // Permission denied or no usable camera.
    statusDot.className = 'dot inactive';
    statusText.textContent = 'Camera access denied — check browser permissions';
    console.error('Camera error:', err);
  }
}
211
+
212
// Leave the frozen result view and return to the live preview,
// waiting for the user to press Start again (not scanning yet).
function restart() {
  frozen = false;
  scanning = false;

  // Drop any cached result state.
  resultData = null;
  currentSetIdx = 0;

  // Hide the result UI and bring back the camera feed.
  trophy.innerHTML = '';
  trophy.classList.remove('active');
  setNav.classList.remove('active');
  resultImg.style.display = 'none';
  video.style.display = 'block';

  // Reset the controls to their idle state.
  scanBtn.textContent = 'Start';
  scanBtn.className = 'start';
  statusDot.className = 'dot idle';
  statusText.textContent = 'Ready — press Start';
}
229
+
230
// Enter scan mode: Stop button, active status dot, capture loop at ~3 fps.
function startScanning() {
  scanning = true;
  scanBtn.textContent = 'Stop';
  scanBtn.className = 'stop';
  statusDot.className = 'dot active';
  statusText.textContent = 'Scanning...';

  if (loopTimer) clearInterval(loopTimer);
  loopTimer = setInterval(() => {
    // Skip ticks while a previous request is still in flight.
    if (!scanning || processing) return;
    captureAndSolve();
  }, 333);
}
242
+
243
// Leave scan mode and cancel the capture loop.
function stopScanning() {
  scanning = false;
  if (loopTimer) {
    clearInterval(loopTimer);
    loopTimer = null;
  }
  scanBtn.textContent = 'Start';
  scanBtn.className = 'start';
  statusDot.className = 'dot idle';
  statusText.textContent = 'Stopped';
}
251
+
252
// Grab one frame from the video, POST it to /api/solve, and freeze on the
// result if any Sets were found. The `processing` flag ensures at most one
// request is in flight at a time.
async function captureAndSolve() {
  if (!scanning || processing) return;
  processing = true;
  statusDot.className = 'dot processing';

  try {
    // Draw the current video frame into the hidden capture canvas.
    canvas.width = video.videoWidth;
    canvas.height = video.videoHeight;
    const ctx = canvas.getContext('2d');
    ctx.drawImage(video, 0, 0);

    // Encode the frame as JPEG and send it as multipart form data.
    const blob = await new Promise(resolve => canvas.toBlob(resolve, 'image/jpeg', 0.8));
    const formData = new FormData();
    formData.append('file', blob, 'frame.jpg');

    const resp = await fetch('/api/solve', { method: 'POST', body: formData });
    if (!resp.ok) throw new Error(`HTTP ${resp.status}`);
    const data = await resp.json();

    // The user may have pressed Stop while the request was in flight.
    if (!scanning) return;

    statusText.textContent = `${data.num_cards} cards`;
    statusDot.className = 'dot active';

    if (data.num_sets > 0) {
      showResult(data);
    }
  } catch (err) {
    console.error('Solve error:', err);
    if (scanning) statusDot.className = 'dot active';
  } finally {
    processing = false;
  }
}
286
+
287
// Freeze scanning and switch the UI into result mode: the annotated image
// replaces the live feed; trophy strip and set navigation show the result.
function showResult(data) {
  scanning = false;
  frozen = true;
  if (loopTimer) { clearInterval(loopTimer); loopTimer = null; }

  resultData = data;
  currentSetIdx = 0;

  video.style.display = 'none';
  resultImg.style.display = 'block';

  // Show nav if multiple sets
  if (data.num_sets > 1) {
    setNav.classList.add('active');
  }

  showCurrentSet();

  scanBtn.textContent = 'Restart';
  scanBtn.className = 'restart';
  statusDot.className = 'dot active';
  statusText.textContent = `Found ${data.num_sets} Set${data.num_sets > 1 ? 's' : ''}!`;

  // Audible confirmation via the Web Speech API (no-op if unsupported).
  speak('Set!');
}
312
+
313
// Render the currently selected set: annotated frame, trophy card strip,
// and the "Set i / n" navigation label.
function showCurrentSet() {
  if (!resultData) return;
  const idx = currentSetIdx;

  // Annotated image for this set.
  resultImg.src = `data:image/jpeg;base64,${resultData.result_images_b64[idx]}`;

  // Trophy strip: the three card crops of this set.
  const crops = resultData.per_set_cards_b64[idx];
  if (crops && crops.length === 3) {
    const tags = crops.map(b64 => `<img src="data:image/jpeg;base64,${b64}">`);
    trophy.innerHTML = tags.join('');
    trophy.classList.add('active');
  }

  setLabel.textContent = `Set ${idx + 1} / ${resultData.num_sets}`;
}
333
+
334
// Cycle backwards through the found sets (wraps around).
function prevSet() {
  const data = resultData;
  if (!data || data.num_sets <= 1) return;
  const n = data.num_sets;
  currentSetIdx = (currentSetIdx + n - 1) % n;
  showCurrentSet();
}
339
+
340
// Cycle forwards through the found sets (wraps around).
function nextSet() {
  const data = resultData;
  if (!data || data.num_sets <= 1) return;
  currentSetIdx = (currentSetIdx + 1) % data.num_sets;
  showCurrentSet();
}
345
+
346
// Announce `text` via the Web Speech API; silently no-op when unsupported.
function speak(text) {
  if (!('speechSynthesis' in window)) return;
  const utter = new SpeechSynthesisUtterance(text);
  utter.rate = 1.2;
  utter.pitch = 1.1;
  speechSynthesis.speak(utter);
}
354
+
355
// Main button cycles the three UI states:
// results shown -> restart, scanning -> stop, idle -> start.
scanBtn.addEventListener('click', () => {
  if (frozen) {
    restart();
  } else if (scanning) {
    stopScanning();
  } else {
    startScanning();
  }
});
prevBtn.addEventListener('click', prevSet);
nextBtn.addEventListener('click', nextSet);

// Keyboard shortcuts: space mirrors the main button, arrows cycle sets.
document.addEventListener('keydown', e => {
  if (e.key === ' ') {
    e.preventDefault();
    if (frozen) restart();
    else if (scanning) stopScanning();
    else startScanning();
  } else if (e.key === 'ArrowLeft') {
    prevSet();
  } else if (e.key === 'ArrowRight') {
    nextSet();
  }
});

// Kick off the camera as soon as the page loads.
startCamera();
381
+ </script>
382
+ </body>
383
+ </html>
weights/classifier_best.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a0c464367eccfcfd6599377c9af35f72cd23c524b01eda7e9a11ccb1e3ba3f6d
3
+ size 11465795
weights/detector/weights/best.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d65deae13124271df8739b700d2f893bca1eb7a7bc8ac870702e714b787ceee7
3
+ size 5453594