Tian Wang commited on
Commit
91a905e
·
1 Parent(s): 1f8aa89

Add 70% detection and 80% classifier thresholds

Browse files
Files changed (2) hide show
  1. src/inference/solve.py +9 -3
  2. src/web/app.py +1 -1
src/inference/solve.py CHANGED
@@ -165,7 +165,8 @@ class SetSolver:
165
  def solve_from_image(
166
  self,
167
  image: Image.Image,
168
- conf: float = 0.5,
 
169
  ) -> dict:
170
  """
171
  Solve a Set game from a PIL Image directly.
@@ -173,6 +174,7 @@ class SetSolver:
173
  Args:
174
  image: PIL Image (RGB)
175
  conf: Detection confidence threshold
 
176
 
177
  Returns:
178
  Dict with detected cards, found Sets, and annotated result image
@@ -187,14 +189,18 @@ class SetSolver:
187
  card_crop = image.crop((x1, y1, x2, y2))
188
  attrs = self.classify_card(card_crop)
189
  card = self.detection_to_card(attrs, det["bbox"])
 
 
190
  cards.append({
191
  "card": card,
192
  "attrs": attrs,
193
  "detection": det,
 
194
  })
195
 
196
- card_objects = [c["card"] for c in cards]
197
- sets = find_all_sets(card_objects)
 
198
 
199
  # Generate one annotated image per set (each highlighting only that set)
200
  result_images = []
 
165
  def solve_from_image(
166
  self,
167
  image: Image.Image,
168
+ conf: float = 0.7,
169
+ cls_conf: float = 0.8,
170
  ) -> dict:
171
  """
172
  Solve a Set game from a PIL Image directly.
 
174
  Args:
175
  image: PIL Image (RGB)
176
  conf: Detection confidence threshold
177
+ cls_conf: Classification confidence threshold (min across all attributes)
178
 
179
  Returns:
180
  Dict with detected cards, found Sets, and annotated result image
 
189
  card_crop = image.crop((x1, y1, x2, y2))
190
  attrs = self.classify_card(card_crop)
191
  card = self.detection_to_card(attrs, det["bbox"])
192
+ min_cls_conf = min(attrs.get("number_conf", 0), attrs.get("color_conf", 0),
193
+ attrs.get("shape_conf", 0), attrs.get("fill_conf", 0))
194
  cards.append({
195
  "card": card,
196
  "attrs": attrs,
197
  "detection": det,
198
+ "cls_confident": min_cls_conf >= cls_conf,
199
  })
200
 
201
+ # Only use cards that pass classification threshold for Set finding
202
+ confident_cards = [c["card"] for c in cards if c["cls_confident"]]
203
+ sets = find_all_sets(confident_cards)
204
 
205
  # Generate one annotated image per set (each highlighting only that set)
206
  result_images = []
src/web/app.py CHANGED
@@ -45,7 +45,7 @@ async def solve_frame(file: UploadFile = File(...)):
45
  contents = await file.read()
46
  image = Image.open(io.BytesIO(contents)).convert("RGB")
47
 
48
- result = solver.solve_from_image(image, conf=0.5)
49
 
50
  # Encode per-set annotated images as base64 JPEG
51
  result_images_b64 = []
 
45
  contents = await file.read()
46
  image = Image.open(io.BytesIO(contents)).convert("RGB")
47
 
48
+ result = solver.solve_from_image(image, conf=0.7, cls_conf=0.8)
49
 
50
  # Encode per-set annotated images as base64 JPEG
51
  result_images_b64 = []