""" End-to-end Set solver pipeline. Photo → Detect cards → Classify each → Find Sets → Visualize """ import sys from pathlib import Path from typing import List, Tuple, Optional import torch from PIL import Image, ImageDraw, ImageFont from ultralytics import YOLO import numpy as np # Add parent to path for imports sys.path.insert(0, str(Path(__file__).parent.parent.parent)) from src.train.classifier import ( SetCardClassifier, NUMBER_NAMES, COLOR_NAMES, SHAPE_NAMES, FILL_NAMES, ) from src.solver.set_finder import Card, Shape, Color, Number, Fill, find_all_sets WEIGHTS_DIR = Path(__file__).parent.parent.parent / "weights" DATA_WEIGHTS_DIR = Path.home() / "data" / "set-solver" / "weights" # Chinese shorthand names: {1,2,3}-{实,空,线}-{红,绿,紫}-{菱,圆,弯} CHINESE_NUMBER = {"one": "1", "two": "2", "three": "3"} CHINESE_FILL = {"full": "实", "empty": "空", "partial": "线"} CHINESE_COLOR = {"red": "红", "green": "绿", "blue": "紫"} CHINESE_SHAPE = {"diamond": "菱", "oval": "圆", "squiggle": "弯"} def card_to_chinese(attrs: dict) -> str: """Convert card attributes to Chinese shorthand like '2实红菱'.""" num = CHINESE_NUMBER.get(attrs['number'], attrs['number']) fill = CHINESE_FILL.get(attrs['fill'], attrs['fill']) color = CHINESE_COLOR.get(attrs['color'], attrs['color']) shape = CHINESE_SHAPE.get(attrs['shape'], attrs['shape']) return f"{num}{fill}{color}{shape}" # Colors for highlighting Sets (RGB) SET_COLORS = [ (255, 0, 0), # Red (0, 255, 0), # Green (0, 0, 255), # Blue (255, 255, 0), # Yellow (255, 0, 255), # Magenta (0, 255, 255), # Cyan (255, 128, 0), # Orange (128, 0, 255), # Purple ] class SetSolver: """End-to-end Set solver.""" def __init__( self, detector_path: Optional[Path] = None, classifier_path: Optional[Path] = None, device: Optional[str] = None, ): if device is None: device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" self.device = device # Load detector if detector_path is None: # Check ~/data first, then repo weights data_path = DATA_WEIGHTS_DIR / "detector" / "weights" / "best.pt" repo_path = WEIGHTS_DIR / "detector" / "weights" / "best.pt" detector_path = data_path if data_path.exists() else repo_path print(f"Loading detector from {detector_path}") self.detector = YOLO(str(detector_path)) # Load classifier if classifier_path is None: classifier_path = WEIGHTS_DIR / "classifier_best.pt" print(f"Loading classifier from {classifier_path}") self.classifier = SetCardClassifier(pretrained=False) checkpoint = torch.load(classifier_path, map_location=device) self.classifier.load_state_dict(checkpoint["model_state_dict"]) self.classifier.to(device) self.classifier.eval() # Classifier transform from torchvision import transforms self.transform = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) def detect_cards(self, image: Image.Image, conf: float = 0.5) -> List[dict]: """ Detect cards in image. Returns list of detections with bounding boxes. Filters out oversized detections that likely merged two cards. """ results = self.detector(image, conf=conf, verbose=False) detections = [] for result in results: boxes = result.boxes for box in boxes: x1, y1, x2, y2 = box.xyxy[0].cpu().numpy() c = box.conf[0].cpu().item() w, h = x2 - x1, y2 - y1 detections.append({ "bbox": (int(x1), int(y1), int(x2), int(y2)), "confidence": c, "area": w * h, }) # Filter out merged detections: if a box is >2x the median area, # it's likely covering two cards if len(detections) >= 3: areas = sorted(d["area"] for d in detections) median_area = areas[len(areas) // 2] detections = [d for d in detections if d["area"] <= median_area * 2.2] return detections def classify_card(self, card_image: Image.Image) -> dict: """Classify a cropped card image.""" img_tensor = self.transform(card_image).unsqueeze(0).to(self.device) with torch.no_grad(): outputs = self.classifier(img_tensor) result = {} for key, names in [ ("number", NUMBER_NAMES), ("color", COLOR_NAMES), ("shape", SHAPE_NAMES), ("fill", FILL_NAMES), ]: probs = torch.softmax(outputs[key], dim=1)[0] pred_idx = probs.argmax().item() result[key] = names[pred_idx] result[f"{key}_conf"] = probs[pred_idx].item() return result def detection_to_card(self, attrs: dict, bbox: Tuple[int, int, int, int]) -> Card: """Convert classification result to Card object.""" # Map classifier output to solver enums # Training data uses "blue" but standard Set calls it "purple" color_map = {"red": "RED", "green": "GREEN", "blue": "PURPLE"} # Training data uses "partial" for striped, "full" for solid fill_map = {"empty": "EMPTY", "full": "SOLID", "partial": "STRIPED"} return Card( shape=Shape[attrs["shape"].upper()], color=Color[color_map[attrs["color"]]], number=Number[attrs["number"].upper()], fill=Fill[fill_map[attrs["fill"]]], bbox=bbox, ) def solve_from_image( self, image: Image.Image, conf: float = 0.7, cls_conf: float = 0.8, ) -> dict: """ Solve a Set game from a PIL Image directly. Args: image: PIL Image (RGB) conf: Detection confidence threshold cls_conf: Classification confidence threshold (min across all attributes) Returns: Dict with detected cards, found Sets, and annotated result image """ image = image.convert("RGB") detections = self.detect_cards(image, conf=conf) cards = [] for det in detections: x1, y1, x2, y2 = det["bbox"] card_crop = image.crop((x1, y1, x2, y2)) attrs = self.classify_card(card_crop) card = self.detection_to_card(attrs, det["bbox"]) min_cls_conf = min(attrs.get("number_conf", 0), attrs.get("color_conf", 0), attrs.get("shape_conf", 0), attrs.get("fill_conf", 0)) cards.append({ "card": card, "attrs": attrs, "detection": det, "cls_confident": min_cls_conf >= cls_conf, }) # Only use cards that pass classification threshold for Set finding confident_cards = [c["card"] for c in cards if c["cls_confident"]] sets = find_all_sets(confident_cards) # Generate one annotated image per set (each highlighting only that set) result_images = [] if sets: for i in range(len(sets)): result_images.append(self._draw_results(image, cards, sets, highlight_idx=i)) else: result_images.append(self._draw_results(image, cards, sets)) return { "num_cards": len(cards), "cards": [ { "attrs": c["attrs"], "chinese": card_to_chinese(c["attrs"]), "bbox": c["detection"]["bbox"], "confidence": c["detection"]["confidence"], } for c in cards ], "num_sets": len(sets), "sets": [ [str(card) for card in s] for s in sets ], "sets_chinese": [ [card_to_chinese(next(c["attrs"] for c in cards if c["card"] is card)) for card in s] for s in sets ], "sets_bboxes": [ [card.bbox for card in s] for s in sets ], "result_images": result_images, } def solve( self, image_path: str, conf: float = 0.5, output_path: Optional[str] = None, show: bool = False, ) -> dict: """ Solve a Set game from image. Args: image_path: Path to input image conf: Detection confidence threshold output_path: Path to save annotated output image show: Whether to display the result Returns: Dict with detected cards and found Sets """ # Load image image = Image.open(image_path).convert("RGB") print(f"Loaded image: {image.size}") # Detect cards print("Detecting cards...") detections = self.detect_cards(image, conf=conf) print(f"Found {len(detections)} cards") # Classify each card print("Classifying cards...") cards = [] for det in detections: x1, y1, x2, y2 = det["bbox"] card_crop = image.crop((x1, y1, x2, y2)) attrs = self.classify_card(card_crop) card = self.detection_to_card(attrs, det["bbox"]) cards.append({ "card": card, "attrs": attrs, "detection": det, }) # Find Sets print("Finding Sets...") card_objects = [c["card"] for c in cards] sets = find_all_sets(card_objects) print(f"Found {len(sets)} valid Set(s)") # Draw results result_image = self._draw_results(image, cards, sets) if output_path: result_image.save(output_path) print(f"Saved result to {output_path}") if show: result_image.show() return { "num_cards": len(cards), "cards": [ { "attrs": c["attrs"], "chinese": card_to_chinese(c["attrs"]), "bbox": c["detection"]["bbox"], "confidence": c["detection"]["confidence"], } for c in cards ], "num_sets": len(sets), "sets": [ [str(card) for card in s] for s in sets ], "sets_chinese": [ [card_to_chinese(next(c["attrs"] for c in cards if c["card"] is card)) for card in s] for s in sets ], "result_image": result_image, } def _draw_results( self, image: Image.Image, cards: List[dict], sets: List[Tuple[Card, Card, Card]], highlight_idx: Optional[int] = None, ) -> Image.Image: """Draw bounding boxes and Set highlights on image. Args: highlight_idx: If set, only highlight this one set (0-based). If None, highlight all sets. """ result = image.copy() draw = ImageDraw.Draw(result) # Try to load a Chinese-compatible font font = None font_paths = [ "/System/Library/Fonts/PingFang.ttc", # macOS "/System/Library/Fonts/STHeiti Light.ttc", # macOS "/usr/share/fonts/truetype/droid/DroidSansFallbackFull.ttf", # Linux "C:\\Windows\\Fonts\\msyh.ttc", # Windows ] for font_path in font_paths: try: font = ImageFont.truetype(font_path, 18) break except: continue if font is None: font = ImageFont.load_default() # Determine which set(s) to highlight if highlight_idx is not None and 0 <= highlight_idx < len(sets): highlighted_sets = [(highlight_idx, sets[highlight_idx])] else: highlighted_sets = list(enumerate(sets)) # Build set of highlighted card ids highlighted_card_ids = set() for _, card_set in highlighted_sets: for card in card_set: highlighted_card_ids.add(id(card)) # Draw all detected cards for c in cards: card = c["card"] attrs = c["attrs"] x1, y1, x2, y2 = card.bbox if id(card) in highlighted_card_ids: color_idx = highlighted_sets[0][0] if len(highlighted_sets) == 1 else 0 for si, card_set in highlighted_sets: if card in card_set: color_idx = si break color = SET_COLORS[color_idx % len(SET_COLORS)] width = 4 else: color = (128, 128, 128) width = 2 draw.rectangle([x1, y1, x2, y2], outline=color, width=width) det_conf = c["detection"]["confidence"] cls_conf = min(attrs.get("number_conf", 1), attrs.get("color_conf", 1), attrs.get("shape_conf", 1), attrs.get("fill_conf", 1)) label = f"{card_to_chinese(attrs)} d{det_conf:.0%} c{cls_conf:.0%}" draw.text((x1, y1 - 20), label, fill=color, font=font) # Draw Set info if highlight_idx is not None: draw.text((10, 10), f"{len(cards)} cards — Set {highlight_idx + 1} / {len(sets)}", fill=(255, 255, 255), font=font) else: draw.text((10, 10), f"{len(cards)} cards — {len(sets)} Set(s)", fill=(255, 255, 255), font=font) return result def main(): import argparse parser = argparse.ArgumentParser(description="Solve Set game from image") parser.add_argument("image", type=str, help="Path to input image") parser.add_argument("--output", "-o", type=str, help="Path to save output image") parser.add_argument("--conf", type=float, default=0.25, help="Detection confidence") parser.add_argument("--show", action="store_true", help="Display result") args = parser.parse_args() solver = SetSolver() result = solver.solve( args.image, conf=args.conf, output_path=args.output, show=args.show, ) print("\n" + "="*50) print("结果 RESULTS") print("="*50) print(f"检测到卡牌: {result['num_cards']}") print(f"找到Set: {result['num_sets']}") if result['cards']: print("\n卡牌:") for c in result['cards']: print(f" {c['chinese']}") if result['sets_chinese']: print("\nSets:") for i, s in enumerate(result['sets_chinese'], 1): print(f" Set {i}: {' + '.join(s)}") if __name__ == "__main__": main()