File size: 15,329 Bytes
8a34385
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91a905e
 
8a34385
 
 
 
 
 
 
91a905e
8a34385
 
 
 
 
 
 
 
 
 
 
 
 
 
91a905e
 
8a34385
 
 
 
91a905e
8a34385
 
91a905e
 
 
8a34385
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1f8aa89
8a34385
 
 
 
 
1f8aa89
 
 
 
 
 
 
 
 
 
 
 
8a34385
7e1047a
 
 
 
8a34385
 
 
 
1f8aa89
8a34385
1f8aa89
8a34385
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
"""
End-to-end Set solver pipeline.

Photo → Detect cards → Classify each → Find Sets → Visualize
"""

import sys
from pathlib import Path
from typing import List, Tuple, Optional

import torch
from PIL import Image, ImageDraw, ImageFont
from ultralytics import YOLO
import numpy as np

# Add parent to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent.parent))

from src.train.classifier import (
    SetCardClassifier,
    NUMBER_NAMES, COLOR_NAMES, SHAPE_NAMES, FILL_NAMES,
)
from src.solver.set_finder import Card, Shape, Color, Number, Fill, find_all_sets


# Weight locations: the repository's own weights/ directory (three levels
# above this file) and a per-user data directory under the home folder.
WEIGHTS_DIR = Path(__file__).parent.parent.parent.joinpath("weights")
DATA_WEIGHTS_DIR = Path.home().joinpath("data", "set-solver", "weights")

# Lookup tables for the compact Chinese card notation, written in the order
# number-fill-colour-shape, e.g. {1,2,3}{实,空,线}{红,绿,紫}{菱,圆,弯}.
# Keys are the classifier's English label names.
CHINESE_NUMBER = dict(one="1", two="2", three="3")
CHINESE_FILL = dict(full="实", empty="空", partial="线")
CHINESE_COLOR = dict(red="红", green="绿", blue="紫")
CHINESE_SHAPE = dict(diamond="菱", oval="圆", squiggle="弯")


def card_to_chinese(attrs: dict) -> str:
    """Render a card's attributes as Chinese shorthand such as '2实红菱'.

    Args:
        attrs: Mapping that must contain the keys ``number``, ``fill``,
            ``color`` and ``shape`` (a missing key raises ``KeyError``).

    Returns:
        The four translated symbols concatenated in number-fill-colour-shape
        order. A value with no translation is passed through unchanged.
    """
    lookups = (
        (CHINESE_NUMBER, "number"),
        (CHINESE_FILL, "fill"),
        (CHINESE_COLOR, "color"),
        (CHINESE_SHAPE, "shape"),
    )
    num, fill, color, shape = [table.get(attrs[key], attrs[key]) for table, key in lookups]
    return f"{num}{fill}{color}{shape}"

# Outline colours (RGB) cycled through when highlighting individual Sets.
SET_COLORS = [
    (255, 0, 0),  # red
    (0, 255, 0),  # green
    (0, 0, 255),  # blue
    (255, 255, 0),  # yellow
    (255, 0, 255),  # magenta
    (0, 255, 255),  # cyan
    (255, 128, 0),  # orange
    (128, 0, 255),  # purple
]


class SetSolver:
    """End-to-end Set solver: detect cards, classify them, find Sets, draw results.

    A YOLO detector proposes card bounding boxes. A ``SetCardClassifier``
    predicts the number/color/shape/fill of each crop. ``find_all_sets``
    searches the resulting ``Card`` objects for valid Sets.
    """

    # Classifier label -> solver enum member name. The training data calls
    # the third colour "blue" (standard Set calls it purple) and uses
    # "full"/"partial" for solid/striped fills.
    _COLOR_MAP = {"red": "RED", "green": "GREEN", "blue": "PURPLE"}
    _FILL_MAP = {"empty": "EMPTY", "full": "SOLID", "partial": "STRIPED"}

    # Attribute heads whose confidences are combined by _min_cls_conf.
    _ATTRIBUTE_KEYS = ("number", "color", "shape", "fill")

    # Font files with CJK glyph coverage, tried in order for drawing labels.
    _FONT_PATHS = (
        "/System/Library/Fonts/PingFang.ttc",  # macOS
        "/System/Library/Fonts/STHeiti Light.ttc",  # macOS
        "/usr/share/fonts/truetype/droid/DroidSansFallbackFull.ttf",  # Linux
        "C:\\Windows\\Fonts\\msyh.ttc",  # Windows
    )
    _FONT_SIZE = 18

    # Outline colour for cards that are not part of a highlighted Set.
    _UNHIGHLIGHTED_COLOR = (128, 128, 128)

    def __init__(
        self,
        detector_path: Optional[Path] = None,
        classifier_path: Optional[Path] = None,
        device: Optional[str] = None,
    ):
        """Load the detector and classifier weights.

        Args:
            detector_path: YOLO weights. Defaults to
                ``~/data/set-solver/weights/detector/weights/best.pt`` if it
                exists, otherwise the repository's ``weights/detector/weights/best.pt``.
            classifier_path: Classifier checkpoint containing a
                ``model_state_dict`` entry. Defaults to ``weights/classifier_best.pt``.
            device: Torch device string. Defaults to the first available of
                CUDA, MPS, then CPU.
        """
        if device is None:
            if torch.cuda.is_available():
                device = "cuda"
            elif torch.backends.mps.is_available():
                device = "mps"
            else:
                device = "cpu"
        self.device = device

        # Load detector
        if detector_path is None:
            # Check ~/data first, then repo weights
            data_path = DATA_WEIGHTS_DIR / "detector" / "weights" / "best.pt"
            repo_path = WEIGHTS_DIR / "detector" / "weights" / "best.pt"
            detector_path = data_path if data_path.exists() else repo_path
        print(f"Loading detector from {detector_path}")
        self.detector = YOLO(str(detector_path))

        # Load classifier
        if classifier_path is None:
            classifier_path = WEIGHTS_DIR / "classifier_best.pt"
        print(f"Loading classifier from {classifier_path}")
        self.classifier = SetCardClassifier(pretrained=False)
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(classifier_path, map_location=device)
        self.classifier.load_state_dict(checkpoint["model_state_dict"])
        self.classifier.to(device)
        self.classifier.eval()

        # Classifier transform: 224x224 input with ImageNet mean/std normalisation.
        from torchvision import transforms
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

        # Label font, loaded lazily on first draw (see _load_font).
        self._font = None

    def detect_cards(
        self,
        image: Image.Image,
        conf: float = 0.5,
        max_area_ratio: float = 2.2,
    ) -> List[dict]:
        """
        Detect cards in image.

        Args:
            image: Image passed straight to the YOLO detector.
            conf: Minimum detector confidence for a box to be kept.
            max_area_ratio: With three or more detections, boxes whose area
                exceeds ``max_area_ratio`` times the median box area are
                dropped as likely merges of two adjacent cards.

        Returns:
            List of dicts with ``bbox`` (integer ``(x1, y1, x2, y2)``),
            ``confidence`` (float) and ``area`` (box area from the unrounded
            coordinates). Boxes that become zero-width or zero-height after
            integer truncation are skipped, because an empty crop cannot be
            classified.
        """
        results = self.detector(image, conf=conf, verbose=False)

        detections = []
        for result in results:
            for box in result.boxes:
                x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
                bbox = (int(x1), int(y1), int(x2), int(y2))
                if bbox[2] <= bbox[0] or bbox[3] <= bbox[1]:
                    continue
                w, h = x2 - x1, y2 - y1
                detections.append({
                    "bbox": bbox,
                    "confidence": box.conf[0].cpu().item(),
                    "area": w * h,
                })

        # The median area is a robust estimate of a single card's size once
        # there are at least three boxes; anything much larger probably
        # covers two cards.
        if len(detections) >= 3:
            areas = sorted(d["area"] for d in detections)
            median_area = areas[len(areas) // 2]
            detections = [d for d in detections if d["area"] <= median_area * max_area_ratio]

        return detections

    def classify_card(self, card_image: Image.Image) -> dict:
        """Classify a cropped card image.

        Returns:
            Dict with, for each of ``number``, ``color``, ``shape`` and
            ``fill``, the predicted label under that key and its softmax
            probability under ``<key>_conf``.
        """
        img_tensor = self.transform(card_image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            outputs = self.classifier(img_tensor)

        result = {}
        for key, names in (
            ("number", NUMBER_NAMES),
            ("color", COLOR_NAMES),
            ("shape", SHAPE_NAMES),
            ("fill", FILL_NAMES),
        ):
            # Batch of one: take row 0 of this head's softmax.
            probs = torch.softmax(outputs[key], dim=1)[0]
            pred_idx = probs.argmax().item()
            result[key] = names[pred_idx]
            result[f"{key}_conf"] = probs[pred_idx].item()

        return result

    def detection_to_card(self, attrs: dict, bbox: Tuple[int, int, int, int]) -> Card:
        """Convert a classification result (see ``classify_card``) to a ``Card``.

        Raises:
            KeyError: if a label has no matching solver enum member.
        """
        return Card(
            shape=Shape[attrs["shape"].upper()],
            color=Color[self._COLOR_MAP[attrs["color"]]],
            number=Number[attrs["number"].upper()],
            fill=Fill[self._FILL_MAP[attrs["fill"]]],
            bbox=bbox,
        )

    @staticmethod
    def _min_cls_conf(attrs: dict, missing: float) -> float:
        """Return the lowest per-attribute classifier confidence.

        ``missing`` stands in for any absent ``<key>_conf`` entry.
        """
        return min(attrs.get(f"{key}_conf", missing) for key in SetSolver._ATTRIBUTE_KEYS)

    def _classify_detections(
        self,
        image: Image.Image,
        detections: List[dict],
        cls_conf: Optional[float] = None,
    ) -> List[dict]:
        """Crop and classify every detection.

        Returns:
            One dict per detection with the ``Card`` (``card``), the raw
            classifier output (``attrs``) and the detection dict
            (``detection``). When ``cls_conf`` is given, ``cls_confident``
            records whether every attribute confidence is at least ``cls_conf``.
        """
        cards = []
        for det in detections:
            attrs = self.classify_card(image.crop(det["bbox"]))
            entry = {
                "card": self.detection_to_card(attrs, det["bbox"]),
                "attrs": attrs,
                "detection": det,
            }
            if cls_conf is not None:
                entry["cls_confident"] = self._min_cls_conf(attrs, missing=0.0) >= cls_conf
            cards.append(entry)
        return cards

    @staticmethod
    def _summarize(cards: List[dict], sets: List[Tuple[Card, Card, Card]]) -> dict:
        """Build the card/Set summary fields shared by ``solve`` and ``solve_from_image``."""
        # Sets contain the same Card objects as ``cards``, so identity maps them back.
        attrs_by_card = {id(c["card"]): c["attrs"] for c in cards}
        return {
            "num_cards": len(cards),
            "cards": [
                {
                    "attrs": c["attrs"],
                    "chinese": card_to_chinese(c["attrs"]),
                    "bbox": c["detection"]["bbox"],
                    "confidence": c["detection"]["confidence"],
                }
                for c in cards
            ],
            "num_sets": len(sets),
            "sets": [[str(card) for card in s] for s in sets],
            "sets_chinese": [
                [card_to_chinese(attrs_by_card[id(card)]) for card in s]
                for s in sets
            ],
        }

    def solve_from_image(
        self,
        image: Image.Image,
        conf: float = 0.7,
        cls_conf: float = 0.8,
    ) -> dict:
        """
        Solve a Set game from a PIL Image directly.

        Args:
            image: PIL Image (converted to RGB internally)
            conf: Detection confidence threshold
            cls_conf: Classification confidence threshold (min across all
                attributes). Only cards meeting it take part in Set finding;
                all cards are still reported and drawn. ``None`` disables the filter.

        Returns:
            Dict with detected cards, found Sets, their bounding boxes
            (``sets_bboxes``) and ``result_images``: one annotated image per
            Set, each highlighting only that Set, or a single image when no
            Set is found.
        """
        image = image.convert("RGB")

        detections = self.detect_cards(image, conf=conf)
        cards = self._classify_detections(image, detections, cls_conf=cls_conf)

        confident_cards = [c["card"] for c in cards if c.get("cls_confident", True)]
        sets = find_all_sets(confident_cards)

        if sets:
            result_images = [
                self._draw_results(image, cards, sets, highlight_idx=i)
                for i in range(len(sets))
            ]
        else:
            result_images = [self._draw_results(image, cards, sets)]

        return {
            **self._summarize(cards, sets),
            "sets_bboxes": [[card.bbox for card in s] for s in sets],
            "result_images": result_images,
        }

    def solve(
        self,
        image_path: str,
        conf: float = 0.5,
        output_path: Optional[str] = None,
        show: bool = False,
    ) -> dict:
        """
        Solve a Set game from an image file, printing progress to stdout.

        Unlike ``solve_from_image``, every classified card is used for Set
        finding (no classification-confidence filter). A single annotated
        image highlighting all Sets is produced.

        Args:
            image_path: Path to input image
            conf: Detection confidence threshold
            output_path: Path to save annotated output image
            show: Whether to display the result

        Returns:
            Dict with detected cards, found Sets and the annotated ``result_image``
        """
        # The context manager releases the file handle; convert() yields an
        # independent in-memory image.
        with Image.open(image_path) as src:
            image = src.convert("RGB")
        print(f"Loaded image: {image.size}")

        print("Detecting cards...")
        detections = self.detect_cards(image, conf=conf)
        print(f"Found {len(detections)} cards")

        print("Classifying cards...")
        cards = self._classify_detections(image, detections)

        print("Finding Sets...")
        sets = find_all_sets([c["card"] for c in cards])
        print(f"Found {len(sets)} valid Set(s)")

        result_image = self._draw_results(image, cards, sets)

        if output_path:
            result_image.save(output_path)
            print(f"Saved result to {output_path}")

        if show:
            result_image.show()

        return {
            **self._summarize(cards, sets),
            "result_image": result_image,
        }

    def _load_font(self):
        """Return a CJK-capable TrueType font if one is installed, else PIL's default.

        The result is cached on the instance, so repeated draws (one per Set in
        ``solve_from_image``) do not re-probe the filesystem.
        """
        font = getattr(self, "_font", None)
        if font is None:
            for font_path in self._FONT_PATHS:
                try:
                    font = ImageFont.truetype(font_path, self._FONT_SIZE)
                    break
                except OSError:
                    # Font file missing or unreadable on this platform.
                    continue
            else:
                font = ImageFont.load_default()
            self._font = font
        return font

    def _draw_results(
        self,
        image: Image.Image,
        cards: List[dict],
        sets: List[Tuple[Card, Card, Card]],
        highlight_idx: Optional[int] = None,
    ) -> Image.Image:
        """Draw bounding boxes and Set highlights on a copy of ``image``.

        Args:
            image: Source image; it is not modified.
            cards: Entries with ``card``, ``attrs`` and ``detection`` keys.
            sets: Sets found among the ``Card`` objects in ``cards``.
            highlight_idx: If set (and in range), only highlight this one set
                (0-based). Otherwise highlight all sets.

        Returns:
            The annotated copy.
        """
        result = image.copy()
        draw = ImageDraw.Draw(result)
        font = self._load_font()

        # Determine which set(s) to highlight
        if highlight_idx is not None and 0 <= highlight_idx < len(sets):
            highlighted_sets = [(highlight_idx, sets[highlight_idx])]
        else:
            highlighted_sets = list(enumerate(sets))

        # Set index per highlighted card, keyed by object identity so that two
        # equal-looking detections are never confused. A card in several
        # highlighted Sets takes the colour of the first one.
        set_idx_by_card = {}
        for set_idx, card_set in highlighted_sets:
            for card in card_set:
                set_idx_by_card.setdefault(id(card), set_idx)

        for c in cards:
            card = c["card"]
            attrs = c["attrs"]
            x1, y1, x2, y2 = card.bbox

            set_idx = set_idx_by_card.get(id(card))
            if set_idx is None:
                color, width = self._UNHIGHLIGHTED_COLOR, 2
            else:
                color, width = SET_COLORS[set_idx % len(SET_COLORS)], 4
            draw.rectangle([x1, y1, x2, y2], outline=color, width=width)

            det_conf = c["detection"]["confidence"]
            cls_conf = self._min_cls_conf(attrs, missing=1.0)
            label = f"{card_to_chinese(attrs)} d{det_conf:.0%} c{cls_conf:.0%}"
            # The label sits above the box; clamp so it stays visible for
            # cards touching the top edge.
            draw.text((x1, max(0, y1 - 20)), label, fill=color, font=font)

        # Draw Set info
        if highlight_idx is not None:
            draw.text((10, 10), f"{len(cards)} cards — Set {highlight_idx + 1} / {len(sets)}", fill=(255, 255, 255), font=font)
        else:
            draw.text((10, 10), f"{len(cards)} cards — {len(sets)} Set(s)", fill=(255, 255, 255), font=font)

        return result


def main():
    """Command-line entry point: solve one image and print a bilingual summary.

    Parses ``image``, ``--output/-o``, ``--conf`` (default 0.25) and
    ``--show``, runs ``SetSolver.solve`` and prints the card and Set lists.
    """
    import argparse

    cli = argparse.ArgumentParser(description="Solve Set game from image")
    cli.add_argument("image", type=str, help="Path to input image")
    cli.add_argument("--output", "-o", type=str, help="Path to save output image")
    cli.add_argument("--conf", type=float, default=0.25, help="Detection confidence")
    cli.add_argument("--show", action="store_true", help="Display result")
    opts = cli.parse_args()

    solver = SetSolver()
    outcome = solver.solve(
        opts.image,
        conf=opts.conf,
        output_path=opts.output,
        show=opts.show,
    )

    rule = "=" * 50
    print("\n" + rule)
    print("结果 RESULTS")
    print(rule)
    print(f"检测到卡牌: {outcome['num_cards']}")
    print(f"找到Set: {outcome['num_sets']}")

    card_entries = outcome["cards"]
    if card_entries:
        print("\n卡牌:")
        for entry in card_entries:
            print(f"  {entry['chinese']}")

    chinese_sets = outcome["sets_chinese"]
    if chinese_sets:
        print("\nSets:")
        for ordinal, names in enumerate(chinese_sets, start=1):
            print(f"  Set {ordinal}: {' + '.join(names)}")


if __name__ == "__main__":
    main()