# Author: Tian Wang
# Change: Add 70% detection and 80% classifier thresholds
# Commit: 91a905e
"""
End-to-end Set solver pipeline.
Photo → Detect cards → Classify each → Find Sets → Visualize
"""
import sys
from pathlib import Path
from typing import List, Tuple, Optional
import torch
from PIL import Image, ImageDraw, ImageFont
from ultralytics import YOLO
import numpy as np
# Add parent to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from src.train.classifier import (
SetCardClassifier,
NUMBER_NAMES, COLOR_NAMES, SHAPE_NAMES, FILL_NAMES,
)
from src.solver.set_finder import Card, Shape, Color, Number, Fill, find_all_sets
# Candidate locations for trained model weights.
WEIGHTS_DIR = Path(__file__).parent.parent.parent / "weights"
DATA_WEIGHTS_DIR = Path.home() / "data" / "set-solver" / "weights"

# Chinese shorthand names: {1,2,3}-{实,空,线}-{红,绿,紫}-{菱,圆,弯}
CHINESE_NUMBER = {"one": "1", "two": "2", "three": "3"}
CHINESE_FILL = {"full": "实", "empty": "空", "partial": "线"}
CHINESE_COLOR = {"red": "红", "green": "绿", "blue": "紫"}
CHINESE_SHAPE = {"diamond": "菱", "oval": "圆", "squiggle": "弯"}


def card_to_chinese(attrs: dict) -> str:
    """Convert card attributes to Chinese shorthand like '2实红菱'.

    Unknown attribute values fall through unchanged, so the result is
    always printable even for unexpected classifier output.
    """
    pieces = (
        CHINESE_NUMBER.get(attrs["number"], attrs["number"]),
        CHINESE_FILL.get(attrs["fill"], attrs["fill"]),
        CHINESE_COLOR.get(attrs["color"], attrs["color"]),
        CHINESE_SHAPE.get(attrs["shape"], attrs["shape"]),
    )
    return "".join(pieces)


# Colors for highlighting Sets (RGB)
SET_COLORS = [
    (255, 0, 0),      # Red
    (0, 255, 0),      # Green
    (0, 0, 255),      # Blue
    (255, 255, 0),    # Yellow
    (255, 0, 255),    # Magenta
    (0, 255, 255),    # Cyan
    (255, 128, 0),    # Orange
    (128, 0, 255),    # Purple
]
class SetSolver:
    """End-to-end Set solver.

    Pipeline: photo -> YOLO card detection -> per-card attribute
    classification -> exhaustive Set search -> annotated result image(s).
    """

    def __init__(
        self,
        detector_path: Optional[Path] = None,
        classifier_path: Optional[Path] = None,
        device: Optional[str] = None,
    ):
        """Load detector and classifier weights onto the chosen device.

        Args:
            detector_path: YOLO weights file. Defaults to the copy under
                ~/data if present, otherwise the repo `weights/` copy.
            classifier_path: Classifier checkpoint. Defaults to
                `weights/classifier_best.pt` in the repo.
            device: Torch device string. Auto-selects cuda > mps > cpu.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
        self.device = device

        # Load detector
        if detector_path is None:
            # Check ~/data first, then repo weights
            data_path = DATA_WEIGHTS_DIR / "detector" / "weights" / "best.pt"
            repo_path = WEIGHTS_DIR / "detector" / "weights" / "best.pt"
            detector_path = data_path if data_path.exists() else repo_path
        print(f"Loading detector from {detector_path}")
        self.detector = YOLO(str(detector_path))

        # Load classifier (weights come from a checkpoint, so no pretrained init)
        if classifier_path is None:
            classifier_path = WEIGHTS_DIR / "classifier_best.pt"
        print(f"Loading classifier from {classifier_path}")
        self.classifier = SetCardClassifier(pretrained=False)
        checkpoint = torch.load(classifier_path, map_location=device)
        self.classifier.load_state_dict(checkpoint["model_state_dict"])
        self.classifier.to(device)
        self.classifier.eval()

        # Classifier transform: 224x224 + ImageNet normalization, matching
        # the values the classifier was trained with.
        from torchvision import transforms
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def detect_cards(self, image: Image.Image, conf: float = 0.5) -> List[dict]:
        """
        Detect cards in image.

        Args:
            image: Input photo (PIL RGB).
            conf: YOLO confidence threshold.

        Returns:
            List of dicts with "bbox" (x1, y1, x2, y2 ints), "confidence",
            and "area". Oversized boxes that likely merged two cards are
            filtered out.
        """
        results = self.detector(image, conf=conf, verbose=False)
        detections = []
        for result in results:
            boxes = result.boxes
            for box in boxes:
                x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
                c = box.conf[0].cpu().item()
                w, h = x2 - x1, y2 - y1
                detections.append({
                    "bbox": (int(x1), int(y1), int(x2), int(y2)),
                    "confidence": c,
                    "area": w * h,
                })
        # Filter out merged detections: if a box is >2.2x the median area,
        # it's likely covering two cards. Needs >=3 boxes for a meaningful median.
        if len(detections) >= 3:
            areas = sorted(d["area"] for d in detections)
            median_area = areas[len(areas) // 2]
            detections = [d for d in detections if d["area"] <= median_area * 2.2]
        return detections

    def classify_card(self, card_image: Image.Image) -> dict:
        """Classify a cropped card image.

        Returns:
            Dict with "number"/"color"/"shape"/"fill" name strings plus a
            "<attr>_conf" softmax probability for each attribute.
        """
        img_tensor = self.transform(card_image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            outputs = self.classifier(img_tensor)
        result = {}
        for key, names in [
            ("number", NUMBER_NAMES),
            ("color", COLOR_NAMES),
            ("shape", SHAPE_NAMES),
            ("fill", FILL_NAMES),
        ]:
            probs = torch.softmax(outputs[key], dim=1)[0]
            pred_idx = probs.argmax().item()
            result[key] = names[pred_idx]
            result[f"{key}_conf"] = probs[pred_idx].item()
        return result

    def detection_to_card(self, attrs: dict, bbox: Tuple[int, int, int, int]) -> Card:
        """Convert classification result to Card object."""
        # Map classifier output to solver enums
        # Training data uses "blue" but standard Set calls it "purple"
        color_map = {"red": "RED", "green": "GREEN", "blue": "PURPLE"}
        # Training data uses "partial" for striped, "full" for solid
        fill_map = {"empty": "EMPTY", "full": "SOLID", "partial": "STRIPED"}
        return Card(
            shape=Shape[attrs["shape"].upper()],
            color=Color[color_map[attrs["color"]]],
            number=Number[attrs["number"].upper()],
            fill=Fill[fill_map[attrs["fill"]]],
            bbox=bbox,
        )

    def _card_summaries(self, cards: List[dict]) -> List[dict]:
        """JSON-friendly summary (attrs, Chinese name, bbox, conf) per card."""
        return [
            {
                "attrs": c["attrs"],
                "chinese": card_to_chinese(c["attrs"]),
                "bbox": c["detection"]["bbox"],
                "confidence": c["detection"]["confidence"],
            }
            for c in cards
        ]

    def _sets_chinese(self, cards: List[dict], sets: List[Tuple[Card, Card, Card]]) -> List[List[str]]:
        """Chinese shorthand for each card of each found Set.

        Cards are matched back to their attrs by object identity.
        """
        attrs_by_id = {id(c["card"]): c["attrs"] for c in cards}
        return [[card_to_chinese(attrs_by_id[id(card)]) for card in s] for s in sets]

    def solve_from_image(
        self,
        image: Image.Image,
        conf: float = 0.7,
        cls_conf: float = 0.8,
    ) -> dict:
        """
        Solve a Set game from a PIL Image directly.

        Args:
            image: PIL Image (RGB)
            conf: Detection confidence threshold
            cls_conf: Classification confidence threshold (min across all attributes)

        Returns:
            Dict with detected cards, found Sets, and annotated result image
        """
        image = image.convert("RGB")
        detections = self.detect_cards(image, conf=conf)

        cards = []
        for det in detections:
            x1, y1, x2, y2 = det["bbox"]
            card_crop = image.crop((x1, y1, x2, y2))
            attrs = self.classify_card(card_crop)
            card = self.detection_to_card(attrs, det["bbox"])
            # A card is "confident" only if every attribute clears cls_conf.
            min_cls_conf = min(attrs.get("number_conf", 0), attrs.get("color_conf", 0),
                               attrs.get("shape_conf", 0), attrs.get("fill_conf", 0))
            cards.append({
                "card": card,
                "attrs": attrs,
                "detection": det,
                "cls_confident": min_cls_conf >= cls_conf,
            })

        # Only use cards that pass classification threshold for Set finding
        confident_cards = [c["card"] for c in cards if c["cls_confident"]]
        sets = find_all_sets(confident_cards)

        # Generate one annotated image per set (each highlighting only that set);
        # when no Set exists, still emit a single annotated overview image.
        if sets:
            result_images = [
                self._draw_results(image, cards, sets, highlight_idx=i)
                for i in range(len(sets))
            ]
        else:
            result_images = [self._draw_results(image, cards, sets)]

        return {
            "num_cards": len(cards),
            "cards": self._card_summaries(cards),
            "num_sets": len(sets),
            "sets": [[str(card) for card in s] for s in sets],
            "sets_chinese": self._sets_chinese(cards, sets),
            "sets_bboxes": [[card.bbox for card in s] for s in sets],
            "result_images": result_images,
        }

    def solve(
        self,
        image_path: str,
        conf: float = 0.5,
        output_path: Optional[str] = None,
        show: bool = False,
    ) -> dict:
        """
        Solve a Set game from image.

        Unlike solve_from_image, no classification-confidence filter is
        applied: every detected card participates in Set finding.

        Args:
            image_path: Path to input image
            conf: Detection confidence threshold
            output_path: Path to save annotated output image
            show: Whether to display the result

        Returns:
            Dict with detected cards and found Sets
        """
        # Load image
        image = Image.open(image_path).convert("RGB")
        print(f"Loaded image: {image.size}")

        # Detect cards
        print("Detecting cards...")
        detections = self.detect_cards(image, conf=conf)
        print(f"Found {len(detections)} cards")

        # Classify each card
        print("Classifying cards...")
        cards = []
        for det in detections:
            x1, y1, x2, y2 = det["bbox"]
            card_crop = image.crop((x1, y1, x2, y2))
            attrs = self.classify_card(card_crop)
            card = self.detection_to_card(attrs, det["bbox"])
            cards.append({
                "card": card,
                "attrs": attrs,
                "detection": det,
            })

        # Find Sets
        print("Finding Sets...")
        card_objects = [c["card"] for c in cards]
        sets = find_all_sets(card_objects)
        print(f"Found {len(sets)} valid Set(s)")

        # Draw results
        result_image = self._draw_results(image, cards, sets)
        if output_path:
            result_image.save(output_path)
            print(f"Saved result to {output_path}")
        if show:
            result_image.show()

        return {
            "num_cards": len(cards),
            "cards": self._card_summaries(cards),
            "num_sets": len(sets),
            "sets": [[str(card) for card in s] for s in sets],
            "sets_chinese": self._sets_chinese(cards, sets),
            "result_image": result_image,
        }

    @staticmethod
    def _load_font() -> ImageFont.ImageFont:
        """Return a CJK-capable truetype font if one is installed, else PIL's default."""
        font_paths = [
            "/System/Library/Fonts/PingFang.ttc",  # macOS
            "/System/Library/Fonts/STHeiti Light.ttc",  # macOS
            "/usr/share/fonts/truetype/droid/DroidSansFallbackFull.ttf",  # Linux
            "C:\\Windows\\Fonts\\msyh.ttc",  # Windows
        ]
        for font_path in font_paths:
            try:
                return ImageFont.truetype(font_path, 18)
            except OSError:
                # Font file missing/unreadable on this platform; try the next.
                continue
        return ImageFont.load_default()

    def _draw_results(
        self,
        image: Image.Image,
        cards: List[dict],
        sets: List[Tuple[Card, Card, Card]],
        highlight_idx: Optional[int] = None,
    ) -> Image.Image:
        """Draw bounding boxes and Set highlights on image.

        Args:
            highlight_idx: If set, only highlight this one set (0-based).
                If None, highlight all sets.
        """
        result = image.copy()
        draw = ImageDraw.Draw(result)
        font = self._load_font()

        # Determine which set(s) to highlight
        if highlight_idx is not None and 0 <= highlight_idx < len(sets):
            highlighted_sets = [(highlight_idx, sets[highlight_idx])]
        else:
            highlighted_sets = list(enumerate(sets))

        # Build set of highlighted card ids (identity, not equality)
        highlighted_card_ids = {
            id(card) for _, card_set in highlighted_sets for card in card_set
        }

        # Draw all detected cards
        for c in cards:
            card = c["card"]
            attrs = c["attrs"]
            x1, y1, x2, y2 = card.bbox
            if id(card) in highlighted_card_ids:
                # Color the box after the first highlighted Set containing the card.
                color_idx = 0
                for si, card_set in highlighted_sets:
                    if card in card_set:
                        color_idx = si
                        break
                color = SET_COLORS[color_idx % len(SET_COLORS)]
                width = 4
            else:
                # Grey out cards not part of any highlighted Set.
                color = (128, 128, 128)
                width = 2
            draw.rectangle([x1, y1, x2, y2], outline=color, width=width)
            det_conf = c["detection"]["confidence"]
            cls_conf = min(attrs.get("number_conf", 1), attrs.get("color_conf", 1),
                           attrs.get("shape_conf", 1), attrs.get("fill_conf", 1))
            label = f"{card_to_chinese(attrs)} d{det_conf:.0%} c{cls_conf:.0%}"
            draw.text((x1, y1 - 20), label, fill=color, font=font)

        # Draw Set info
        if highlight_idx is not None:
            draw.text((10, 10), f"{len(cards)} cards — Set {highlight_idx + 1} / {len(sets)}", fill=(255, 255, 255), font=font)
        else:
            draw.text((10, 10), f"{len(cards)} cards — {len(sets)} Set(s)", fill=(255, 255, 255), font=font)
        return result
def main():
    """CLI entry point: parse args, run the solver, print a bilingual summary."""
    import argparse

    parser = argparse.ArgumentParser(description="Solve Set game from image")
    parser.add_argument("image", type=str, help="Path to input image")
    parser.add_argument("--output", "-o", type=str, help="Path to save output image")
    parser.add_argument("--conf", type=float, default=0.25, help="Detection confidence")
    parser.add_argument("--show", action="store_true", help="Display result")
    args = parser.parse_args()

    solver = SetSolver()
    result = solver.solve(
        args.image,
        conf=args.conf,
        output_path=args.output,
        show=args.show,
    )

    banner = "=" * 50
    print("\n" + banner)
    print("结果 RESULTS")
    print(banner)
    print(f"检测到卡牌: {result['num_cards']}")
    print(f"找到Set: {result['num_sets']}")
    if result["cards"]:
        print("\n卡牌:")
        for card_info in result["cards"]:
            print(f" {card_info['chinese']}")
    if result["sets_chinese"]:
        print("\nSets:")
        for idx, combo in enumerate(result["sets_chinese"], 1):
            print(f" Set {idx}: {' + '.join(combo)}")


if __name__ == "__main__":
    main()