Hawk3388 committed on
Commit
03d6964
·
1 Parent(s): 3dfc234

modified: app.py

Browse files

modified: main.py
modified: model/gap_detection_model.pt

Files changed (3) hide show
  1. app.py +58 -10
  2. main.py +326 -56
  3. model/gap_detection_model.pt +2 -2
app.py CHANGED
@@ -2,10 +2,12 @@ import os
2
  import tempfile
3
  import uuid
4
  import warnings
 
5
 
6
  import gradio as gr
7
  import requests
8
  from PIL import Image
 
9
 
10
  from main import WorksheetSolver
11
 
@@ -13,24 +15,70 @@ warnings.filterwarnings("ignore")
13
 
14
  ALLOWED_EXTENSIONS = {"png", "jpg", "jpeg", "webp", "bmp"}
15
  GAP_DETECTION_MODEL_PATH = "./model/gap_detection_model.pt"
16
- GAP_MODEL_URL = "https://github.com/Hawk3388/solver/releases/download/v1.1.0/gap_detection_model.pt"
17
 
18
 
19
  def ensure_gap_model() -> str:
20
- os.makedirs("./model", exist_ok=True)
21
- if os.path.exists(GAP_DETECTION_MODEL_PATH):
22
- return GAP_DETECTION_MODEL_PATH
23
 
24
- with requests.get(GAP_MODEL_URL, stream=True, timeout=60) as response:
25
- response.raise_for_status()
26
- with open(GAP_DETECTION_MODEL_PATH, "wb") as model_file:
27
- for chunk in response.iter_content(chunk_size=8192):
28
- if chunk:
29
- model_file.write(chunk)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
  return GAP_DETECTION_MODEL_PATH
32
 
33
 
 
 
 
 
 
 
 
 
34
  def _is_allowed_image(filename: str) -> bool:
35
  return "." in filename and filename.rsplit(".", 1)[1].lower() in ALLOWED_EXTENSIONS
36
 
 
2
  import tempfile
3
  import uuid
4
  import warnings
5
+ import re
6
 
7
  import gradio as gr
8
  import requests
9
  from PIL import Image
10
+ from pathlib import Path
11
 
12
  from main import WorksheetSolver
13
 
 
15
 
16
  ALLOWED_EXTENSIONS = {"png", "jpg", "jpeg", "webp", "bmp"}
17
  GAP_DETECTION_MODEL_PATH = "./model/gap_detection_model.pt"
18
+ RELEASES_URL = "https://github.com/Hawk3388/solver/releases"
19
 
20
 
21
def ensure_gap_model() -> str:
    """Ensure the newest available gap-detection model exists locally.

    Looks for versioned model folders under ``./model/<vX.Y.Z>/``, scrapes
    the GitHub releases page for published version tags, and downloads the
    model asset when none is installed or a newer release is available.

    Returns:
        Path (as str) to the local ``gap_detection_model.pt`` file.

    Raises:
        Exception: if the releases page cannot be fetched or parsed, or no
            downloadable model asset exists and none is installed locally.
    """
    model_root = Path("./model")
    model_root.mkdir(parents=True, exist_ok=True)

    def _version_key(tag: str):
        # "v1.2.10" -> [1, 2, 10] so versions compare numerically, not lexically.
        return list(map(int, tag.lstrip("v").split(".")))

    # Newest locally installed version, if any.
    local_versions = [p.name for p in model_root.iterdir() if p.is_dir()]
    latest_local = max(local_versions, key=_version_key) if local_versions else None
    local_model = (model_root / latest_local / "gap_detection_model.pt") if latest_local else None
    need_download = local_model is None or not local_model.exists()

    # Scrape version tags (e.g. "v1.1.0") from the releases page HTML.
    # NOTE(review): HTML scraping is fragile; the GitHub releases REST API
    # would be a sturdier source of tags.
    release_response = requests.get(RELEASES_URL, timeout=30)
    if release_response.status_code != 200:
        raise Exception(f"Failed to fetch releases from GitHub: {release_response.status_code}")
    versions = re.findall(r"<h2[^>]*>(v\d+\.\d+\.\d+)</h2>", release_response.text)
    if not versions:
        raise Exception("Could not determine the latest model version from GitHub releases.")

    for version in versions:
        url = f"https://github.com/Hawk3388/solver/releases/download/{version}/gap_detection_model.pt"
        if not url_exists(url):
            # This release has no model asset; try the next tag.
            continue
        newer_available = (
            latest_local is not None and _version_key(version) > _version_key(latest_local)
        )
        if need_download or newer_available:
            # Download into the versioned folder so the returned path exists.
            target = model_root / version / "gap_detection_model.pt"
            target.parent.mkdir(parents=True, exist_ok=True)
            with requests.get(url, stream=True, timeout=60) as response:
                response.raise_for_status()
                with open(target, "wb") as model_file:
                    for chunk in response.iter_content(chunk_size=8192):
                        if chunk:
                            model_file.write(chunk)
            return str(target)
        # Installed model is already the newest downloadable release.
        return str(local_model)

    # No release exposed a model asset; fall back to whatever is installed.
    if local_model is not None and local_model.exists():
        return str(local_model)
    raise Exception("No downloadable gap detection model found in any release.")
72
 
73
 
74
def url_exists(url: str, timeout: float = 5.0) -> bool:
    """Return True if *url* answers a HEAD request with a non-error status.

    Redirects are followed; a 2xx/3xx final status counts as existing.
    Any network failure (DNS, timeout, connection reset) is treated as
    "does not exist" rather than raised.
    """
    try:
        # Keep the try body minimal: only the request itself can raise.
        response = requests.head(url, allow_redirects=True, timeout=timeout)
    except requests.RequestException:
        return False
    return 200 <= response.status_code < 400
80
+
81
+
82
def _is_allowed_image(filename: str) -> bool:
    """Return True when *filename* carries an extension listed in ALLOWED_EXTENSIONS."""
    if "." not in filename:
        return False
    extension = filename.rsplit(".", 1)[1]
    return extension.lower() in ALLOWED_EXTENSIONS
84
 
main.py CHANGED
@@ -143,6 +143,11 @@ class WorksheetSolver():
143
 
144
  self.image = None
145
  self.detected_gaps = []
 
 
 
 
 
146
 
147
  def load_image(self, image_path: str):
148
  """Load image and create a copy for processing"""
@@ -231,11 +236,11 @@ class WorksheetSolver():
231
  current_line = [boxes_sorted[0]]
232
  # y-center and height of the current line
233
  line_y_min = boxes_sorted[0][1]
234
- line_y_max = boxes_sorted[0][3] if len(boxes_sorted[0]) == 4 else boxes_sorted[0][1] + boxes_sorted[0][3]
235
 
236
  for box in boxes_sorted[1:]:
237
  box_y_top = box[1]
238
- box_y_bottom = box[3] if len(box) == 4 else box[1] + box[3]
239
  box_height = box_y_bottom - box_y_top
240
  line_height = line_y_max - line_y_min
241
 
@@ -266,8 +271,172 @@ class WorksheetSolver():
266
 
267
  return result
268
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
269
  def detect_gaps(self):
270
  self.detected_gaps = []
 
271
 
272
  results = self.model.predict(source=self.path, conf=0.10)
273
 
@@ -286,51 +455,100 @@ class WorksheetSolver():
286
  else:
287
  for idx in keep_indices:
288
  box = r.boxes[idx]
 
 
289
  x1, y1, x2, y2 = box.xyxy[0].cpu().numpy().astype(int)
290
- self.detected_gaps.append((int(x1), int(y1), int(x2), int(y2)))
291
  img = r.orig_img.copy()
292
 
293
  # Sort in reading order (line by line)
294
  self.detected_gaps = self.sort_reading_order(self.detected_gaps)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
295
 
296
  return self.detected_gaps, img
297
 
298
  def mark_gaps(self, image, gaps):
299
- """Mark detected gaps in the image with numbers"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
300
 
301
- for i, gap in enumerate(gaps):
302
- x1, y1, x2, y2 = gap
303
- # Draw red box
304
  cv2.rectangle(image, (x1, y1), (x2, y2), (0, 0, 255), 2)
305
- # Number at top left of the box
306
- label = str(i + 1)
307
  label_size, _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.4, 1)
308
- # Background for better readability
309
  cv2.rectangle(image, (x1, y1 - label_size[1] - 4), (x1 + label_size[0] + 2, y1), (0, 0, 255), -1)
310
- cv2.putText(image, label, (x1 + 1, y1 - 3), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255), 1)
311
  return image
312
 
313
  def ask_ai_about_all_gaps(self, marked_image):
314
- """Ask Gemini about the content of ALL gaps at once - just like test3"""
315
  if self.debug:
316
  start_time = self.time.time()
317
- # Save the marked image (with boxes) just as test3 expects
318
  thinking = None
319
  marked_image_path = f"{Path(self.path).stem}_marked.png"
320
  cv2.imwrite(marked_image_path, marked_image)
321
 
322
- prompt = f"""Look at the two images: one with red numbered boxes marking {len(self.detected_gaps)} gaps, one without markings.
 
 
 
 
 
 
 
 
 
 
 
 
 
323
 
324
- For each red box, read its number label and fill in the missing word(s) from the worksheet.
 
 
 
 
325
 
326
  Rules:
327
  - Answer in the worksheet's language.
328
- - Only the missing word(s), nothing else.
329
- - Match each answer to the correct box number.
330
- - If a box doesn't need filling, because it is already filled or is not a gap, answer with "none".
331
  - Do NOT overthink. These are simple language exercises. Answer quickly and directly. Only reason for about 10 sentences.
332
  - Look at the sheets carefully and use them as context for your answers.
333
- - Only answer in this exact JSON format: {{"solutions": [{{"key": box_number, "value": answer}}]}}"""
334
 
335
  if not self.experimental:
336
  if not self.local:
@@ -434,64 +652,72 @@ Rules:
434
  return output
435
 
436
  def solve_all_gaps(self, marked_image):
437
- """Solve all detected gaps with Ollama - structured!"""
438
  if not self.detected_gaps:
439
  print("No gaps found!")
440
  return {}
 
 
 
441
 
442
- print(f"🤖 Analyzing all {len(self.detected_gaps)} gaps with Ollama...")
443
 
444
- # Ask Ollama about all gaps at once
445
- print("📤 Sending image to Ollama...")
446
  solutions_data = self.ask_ai_about_all_gaps(marked_image)
447
 
448
  if solutions_data:
449
- print("📥 Structured Ollama response received!")
450
 
451
  # Convert structured response to our format
452
  solutions = {}
453
 
454
- # solutions_data.solutions is now a list of Pair objects
455
  for pair in solutions_data.solutions:
456
  try:
457
- gap_id = pair.key
458
  answer = pair.value
459
- gap_index = gap_id - 1 # 0-based
460
 
461
- if 0 <= gap_index < len(self.detected_gaps):
462
- solutions[gap_index] = {
463
- 'position': self.detected_gaps[gap_index],
 
464
  'solution': answer
465
  }
466
  except (ValueError, KeyError) as e:
467
- print(f"Error processing gap {gap_id}: {e}")
468
  continue
469
 
470
  return solutions
471
  else:
472
- print("❌ No response received from Ollama.")
473
  return {}
474
 
475
  def fill_gaps_in_image(self, image_path: str, solutions: dict, output_path: str = "worksheet_solved.png"):
476
- """Fill the solutions into the image"""
477
  # Load OpenCV image and convert to PIL (for Unicode/umlauts)
478
  cv_image = self.load_image(image_path)
479
  pil_image = Image.fromarray(cv2.cvtColor(cv_image, cv2.COLOR_BGR2RGB))
480
 
481
  draw = ImageDraw.Draw(pil_image)
482
 
483
- for gap_index, solution_data in solutions.items():
484
- # Position is (x1, y1, x2, y2)
485
- x1, y1, x2, y2 = solution_data['position']
486
- w = x2 - x1
487
- h = y2 - y1
488
  solution = solution_data['solution']
489
 
490
  if not solution or solution.lower() == 'none':
491
  continue
492
 
493
- # Find dynamic font size
494
- font_size = 40 # Start large
 
 
 
 
 
 
 
495
  min_font_size = 8
496
  font = None
497
 
@@ -505,27 +731,61 @@ Rules:
505
  font = ImageFont.load_default()
506
  break
507
 
 
508
  bbox = draw.textbbox((0, 0), solution, font=font)
509
  text_width = bbox[2] - bbox[0]
510
  text_height = bbox[3] - bbox[1]
511
 
 
512
  padding = 4
513
- if text_width <= w - padding and text_height <= h - padding:
514
- break
 
 
515
 
516
  font_size -= 1
517
 
518
- # Measure text size with final font
519
- bbox = draw.textbbox((0, 0), solution, font=font)
520
- text_width = bbox[2] - bbox[0]
521
- text_height = bbox[3] - bbox[1]
522
-
523
- # Position text centered in the box
524
- text_x = x1 + (w - text_width) // 2
525
- text_y = y1 + (h - text_height) // 2
526
 
527
- # Draw text in black
528
- draw.text((text_x, text_y), solution, fill=(0, 0, 0), font=font)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
529
 
530
  # Convert back to OpenCV and save
531
  result_image = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
@@ -551,13 +811,21 @@ def main():
551
  try:
552
  gaps, img = solver.detect_gaps()
553
 
554
- print(f"✅ {len(gaps)} gaps found!")
555
 
556
  marked_image = solver.mark_gaps(img, gaps)
557
 
558
  print("\n📍 Detected gaps (x, y, width, height):")
559
  for i, gap in enumerate(gaps):
560
- print(f" Gap {i+1}: {gap}")
 
 
 
 
 
 
 
 
561
 
562
  if solver.debug:
563
  # Ask user if AI analysis is desired
@@ -572,8 +840,10 @@ def main():
572
 
573
  if solutions:
574
  print("\n✨ Solutions found:")
575
- for i, sol in solutions.items():
576
- print(f" Gap {i+1}: '{sol['solution']}'")
 
 
577
 
578
  solver.fill_gaps_in_image(path, solutions)
579
 
 
143
 
144
  self.image = None
145
  self.detected_gaps = []
146
+ self.gap_groups = [] # Groups of gap indices
147
+ self.gap_to_group = {} # Maps gap index to group index
148
+ self.ungrouped_gap_indices = []
149
+ self.answer_units = [] # Line groups + single ungrouped boxes
150
+ self.gap_to_answer_unit = {} # Maps any gap index to answer unit index
151
 
152
  def load_image(self, image_path: str):
153
  """Load image and create a copy for processing"""
 
236
  current_line = [boxes_sorted[0]]
237
  # y-center and height of the current line
238
  line_y_min = boxes_sorted[0][1]
239
+ line_y_max = boxes_sorted[0][3]
240
 
241
  for box in boxes_sorted[1:]:
242
  box_y_top = box[1]
243
+ box_y_bottom = box[3]
244
  box_height = box_y_bottom - box_y_top
245
  line_height = line_y_max - line_y_min
246
 
 
271
 
272
  return result
273
 
274
def is_line_class(self, class_name):
    """Check whether a YOLO class label is exactly 'line' (ignoring case and padding)."""
    normalized = str(class_name).strip().lower()
    return normalized == "line"
277
+
278
+ def _unit_bbox(self, unit, gaps):
279
+ """Return merged bbox (x1, y1, x2, y2) for an answer unit."""
280
+ boxes = [gaps[i][:4] for i in unit if 0 <= i < len(gaps)]
281
+ if not boxes:
282
+ return (0, 0, 0, 0)
283
+ return (
284
+ min(b[0] for b in boxes),
285
+ min(b[1] for b in boxes),
286
+ max(b[2] for b in boxes),
287
+ max(b[3] for b in boxes),
288
+ )
289
+
290
def sort_answer_units_reading_order(self, units, gaps):
    """Order answer units in reading order: top-to-bottom, left-to-right.

    Two units belong to the same visual row when their merged bounding
    boxes overlap vertically by more than 30% of the smaller height.
    """
    if not units:
        return []

    # Pair each unit with its merged bbox: (y1, x1, y2, height, unit).
    entries = []
    for unit in units:
        x1, y1, x2, y2 = self._unit_bbox(unit, gaps)
        entries.append((y1, x1, y2, max(1, y2 - y1), unit))
    entries.sort(key=lambda e: e[0])

    ordered = []
    row = [entries[0]]
    row_top, row_bottom = entries[0][0], entries[0][2]

    def _flush(current_row):
        # Within a row, reading order runs left to right.
        current_row.sort(key=lambda e: e[1])
        ordered.extend(e[4] for e in current_row)

    for entry in entries[1:]:
        top, left, bottom, height, _unit = entry
        overlap = min(row_bottom, bottom) - max(row_top, top)
        row_height = max(1, row_bottom - row_top)
        reference = max(1, min(row_height, height))

        if overlap > 0 and (overlap / reference) > 0.3:
            # Enough vertical overlap: extend the current row.
            row.append(entry)
            row_top = min(row_top, top)
            row_bottom = max(row_bottom, bottom)
        else:
            _flush(row)
            row = [entry]
            row_top, row_bottom = top, bottom

    _flush(row)
    return ordered
338
+
339
def group_gaps_by_proximity(self, gaps):
    """Chain vertically stacked 'line' boxes into groups.

    Walks the boxes from top to bottom and links each 'line' box to the
    next 'line' box directly beneath it — close enough vertically and
    with sufficient horizontal overlap. Boxes of any other class are
    never grouped and receive no group mapping.

    Returns:
        (groups, gap_to_group): *groups* is a list of index lists (each
        sorted ascending); *gap_to_group* maps each grouped gap index
        (0-based) to its group index.
    """
    if not gaps:
        return [], {}

    # Visit boxes top-to-bottom while keeping their original indices.
    order = sorted(range(len(gaps)), key=lambda i: gaps[i][1])

    # Boxes count as "stacked" when their vertical spacing stays below
    # 1.5x the average box height.
    total_height = sum(gap[3] - gap[1] for gap in gaps)
    distance_threshold = (total_height / len(gaps)) * 1.5

    def _class_of(gap):
        # Detections without an explicit class default to 'line'.
        return gap[4] if len(gap) > 4 else "line"

    groups = []
    gap_to_group = {}
    taken = set()

    for position, anchor_idx in enumerate(order):
        if anchor_idx in taken:
            continue
        anchor = gaps[anchor_idx]
        if not self.is_line_class(_class_of(anchor)):
            # Only exact 'line' detections may start or join a chain.
            continue

        chain = [anchor_idx]
        taken.add(anchor_idx)
        ax1, ay1, ax2, ay2 = anchor[:4]

        for candidate_idx in order[position + 1:]:
            if candidate_idx in taken:
                continue
            candidate = gaps[candidate_idx]
            if not self.is_line_class(_class_of(candidate)):
                continue
            cx1, cy1, cx2, cy2 = candidate[:4]

            # The candidate must sit just below the current chain tail;
            # anything farther away terminates the chain.
            if not (0 < cy1 - ay2 < distance_threshold):
                break

            # ...and share enough horizontal extent with it:
            # at least 30% of the narrower box, or 15px absolute.
            h_overlap = max(0, min(ax2, cx2) - max(ax1, cx1))
            narrower = min(ax2 - ax1, cx2 - cx1)
            if h_overlap > narrower * 0.3 or h_overlap > 15:
                chain.append(candidate_idx)
                taken.add(candidate_idx)
                # The newly linked box becomes the chain tail.
                ax1, ay1, ax2, ay2 = cx1, cy1, cx2, cy2
            else:
                break

        chain.sort()
        group_id = len(groups)
        for idx in chain:
            gap_to_group[idx] = group_id
        groups.append(chain)

    return groups, gap_to_group
436
+
437
  def detect_gaps(self):
438
  self.detected_gaps = []
439
+ img = self.load_image(self.path)
440
 
441
  results = self.model.predict(source=self.path, conf=0.10)
442
 
 
455
  else:
456
  for idx in keep_indices:
457
  box = r.boxes[idx]
458
+ class_id = int(box.cls[0])
459
+ class_name = r.names[class_id]
460
  x1, y1, x2, y2 = box.xyxy[0].cpu().numpy().astype(int)
461
+ self.detected_gaps.append((int(x1), int(y1), int(x2), int(y2), class_name))
462
  img = r.orig_img.copy()
463
 
464
  # Sort in reading order (line by line)
465
  self.detected_gaps = self.sort_reading_order(self.detected_gaps)
466
+
467
+ # Group gaps by proximity (vertically aligned and close together)
468
+ self.gap_groups, self.gap_to_group = self.group_gaps_by_proximity(self.detected_gaps)
469
+ self.ungrouped_gap_indices = [i for i in range(len(self.detected_gaps)) if i not in self.gap_to_group]
470
+
471
+ # Build answer units for the AI:
472
+ # - grouped line boxes stay grouped
473
+ # - each ungrouped box (e.g. class gap) becomes its own single unit
474
+ unsorted_units = list(self.gap_groups) + [[idx] for idx in self.ungrouped_gap_indices]
475
+ self.answer_units = self.sort_answer_units_reading_order(unsorted_units, self.detected_gaps)
476
+ self.gap_to_answer_unit = {}
477
+ for unit_idx, unit in enumerate(self.answer_units):
478
+ for gap_idx in unit:
479
+ self.gap_to_answer_unit[gap_idx] = unit_idx
480
+
481
+ print(f"📊 Line-boxes grouped into {len(self.gap_groups)} groups")
482
+ for i, group in enumerate(self.gap_groups):
483
+ print(f" Group {i+1}: {len(group)} gaps (indices: {group})")
484
+ print(f"📌 Ungrouped boxes (e.g. gap): {len(self.ungrouped_gap_indices)}")
485
+ print(f"🧠 Total AI answer units: {len(self.answer_units)}")
486
 
487
  return self.detected_gaps, img
488
 
489
  def mark_gaps(self, image, gaps):
490
+ """Draw one red box per answer unit (group) instead of per single line."""
491
+
492
+ if not self.answer_units:
493
+ return image
494
+
495
+ for unit_idx, unit in enumerate(self.answer_units):
496
+ unit_boxes = [gaps[i][:4] for i in unit if 0 <= i < len(gaps)]
497
+ if not unit_boxes:
498
+ continue
499
+
500
+ # Surround the whole group with one box.
501
+ x1 = min(b[0] for b in unit_boxes)
502
+ y1 = min(b[1] for b in unit_boxes)
503
+ x2 = max(b[2] for b in unit_boxes)
504
+ y2 = max(b[3] for b in unit_boxes)
505
 
 
 
 
506
  cv2.rectangle(image, (x1, y1), (x2, y2), (0, 0, 255), 2)
507
+
508
+ label = str(unit_idx + 1)
509
  label_size, _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.4, 1)
 
510
  cv2.rectangle(image, (x1, y1 - label_size[1] - 4), (x1 + label_size[0] + 2, y1), (0, 0, 255), -1)
511
+ cv2.putText(image, (label), (x1 + 1, y1 - 3), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255), 1)
512
  return image
513
 
514
  def ask_ai_about_all_gaps(self, marked_image):
515
+ """Ask Gemini about the content of ALL gap groups at once"""
516
  if self.debug:
517
  start_time = self.time.time()
518
+
519
  thinking = None
520
  marked_image_path = f"{Path(self.path).stem}_marked.png"
521
  cv2.imwrite(marked_image_path, marked_image)
522
 
523
+ # Build description of answer units
524
+ group_descriptions = []
525
+ for i, group in enumerate(self.answer_units):
526
+ group_num = i + 1
527
+ first_idx = group[0]
528
+ class_name = str(self.detected_gaps[first_idx][4]) if len(self.detected_gaps[first_idx]) > 4 else "gap"
529
+ if len(group) > 1:
530
+ group_descriptions.append(f"Group {group_num}: {len(group)} stacked line boxes (marked as {group_num})")
531
+ else:
532
+ group_descriptions.append(f"Group {group_num}: 1 single {class_name} box (marked as {group_num})")
533
+
534
+ group_text = "\n".join(group_descriptions)
535
+
536
+ prompt = f"""Look at the two images: one with red numbered boxes marking {len(self.answer_units)} answer groups, one without markings.
537
 
538
+ Answer groups to fill:
539
+ {group_text}
540
+
541
+ For each group marked with its number label, provide ONE answer that should fill that group.
542
+ The answer will be distributed across the stacked lines (first line(s) filled first, then overflow to next line).
543
 
544
  Rules:
545
  - Answer in the worksheet's language.
546
+ - Provide text that makes sense when distributed line by line.
547
+ - Match each answer to the correct group number.
548
+ - If a group doesn't need filling, answer with "none".
549
  - Do NOT overthink. These are simple language exercises. Answer quickly and directly. Only reason for about 10 sentences.
550
  - Look at the sheets carefully and use them as context for your answers.
551
+ - Only answer in this exact JSON format: {{"solutions": [{{"key": group_number, "value": answer}}]}}"""
552
 
553
  if not self.experimental:
554
  if not self.local:
 
652
  return output
653
 
654
  def solve_all_gaps(self, marked_image):
655
+ """Solve all gap groups with Ollama - structured!"""
656
  if not self.detected_gaps:
657
  print("No gaps found!")
658
  return {}
659
+ if not self.answer_units:
660
+ print("No answer units found to solve.")
661
+ return {}
662
 
663
+ print(f"🤖 Analyzing all {len(self.answer_units)} answer units with AI...")
664
 
665
+ # Ask AI about all gap groups at once
666
+ print("📤 Sending image to AI...")
667
  solutions_data = self.ask_ai_about_all_gaps(marked_image)
668
 
669
  if solutions_data:
670
+ print("📥 Structured AI response received!")
671
 
672
  # Convert structured response to our format
673
  solutions = {}
674
 
675
+ # solutions_data.solutions is now a list of GroupPair objects
676
  for pair in solutions_data.solutions:
677
  try:
678
+ group_id = pair.key
679
  answer = pair.value
680
+ group_index = group_id - 1 # 0-based
681
 
682
+ if 0 <= group_index < len(self.answer_units):
683
+ gap_indices = self.answer_units[group_index]
684
+ solutions[group_index] = {
685
+ 'gap_indices': gap_indices,
686
  'solution': answer
687
  }
688
  except (ValueError, KeyError) as e:
689
+ print(f"Error processing group {group_id}: {e}")
690
  continue
691
 
692
  return solutions
693
  else:
694
+ print("❌ No response received from AI.")
695
  return {}
696
 
697
  def fill_gaps_in_image(self, image_path: str, solutions: dict, output_path: str = "worksheet_solved.png"):
698
+ """Fill the solutions into grouped gaps with text flowing across multiple boxes"""
699
  # Load OpenCV image and convert to PIL (for Unicode/umlauts)
700
  cv_image = self.load_image(image_path)
701
  pil_image = Image.fromarray(cv2.cvtColor(cv_image, cv2.COLOR_BGR2RGB))
702
 
703
  draw = ImageDraw.Draw(pil_image)
704
 
705
+ for group_index, solution_data in solutions.items():
706
+ gap_indices = solution_data['gap_indices']
 
 
 
707
  solution = solution_data['solution']
708
 
709
  if not solution or solution.lower() == 'none':
710
  continue
711
 
712
+ # Get all boxes for this group
713
+ boxes = [self.detected_gaps[idx] for idx in gap_indices]
714
+
715
+ # Calculate total available space
716
+ total_width = sum(box[2] - box[0] for box in boxes)
717
+ avg_height = boxes[0][3] - boxes[0][1]
718
+
719
+ # Find optimal font size for this solution
720
+ font_size = 40
721
  min_font_size = 8
722
  font = None
723
 
 
731
  font = ImageFont.load_default()
732
  break
733
 
734
+ # Test if text fits
735
  bbox = draw.textbbox((0, 0), solution, font=font)
736
  text_width = bbox[2] - bbox[0]
737
  text_height = bbox[3] - bbox[1]
738
 
739
+ # Check if it fits in available space (with padding)
740
  padding = 4
741
+ if text_height <= avg_height - padding:
742
+ # For width, use total available width or at least one box width
743
+ if text_width <= total_width - padding or text_width <= (boxes[0][2] - boxes[0][0]) - padding:
744
+ break
745
 
746
  font_size -= 1
747
 
748
+ # Distribute text across boxes in the group
749
+ words = solution.split()
750
+ current_box_idx = 0
751
+ x_offset = boxes[current_box_idx][0] # Start position in current box
 
 
 
 
752
 
753
+ for word in words:
754
+ if current_box_idx >= len(boxes):
755
+ break
756
+
757
+ # Get current box dimensions
758
+ x1, y1, x2, y2 = boxes[current_box_idx][:4]
759
+ box_width = x2 - x1
760
+ box_height = y2 - y1
761
+
762
+ # Measure word with space
763
+ word_with_space = word + " "
764
+ bbox = draw.textbbox((0, 0), word_with_space, font=font)
765
+ word_width = bbox[2] - bbox[0]
766
+ text_height = bbox[3] - bbox[1]
767
+
768
+ # Check if word fits in current box
769
+ available_width = (x2 - x_offset) - 4 # Subtract padding
770
+
771
+ if word_width <= available_width:
772
+ # Word fits in current box
773
+ text_y = y1 + (box_height - text_height) // 2
774
+ draw.text((x_offset, text_y), word_with_space, fill=(0, 0, 0), font=font)
775
+ x_offset += word_width
776
+ else:
777
+ # Word doesn't fit - move to next box
778
+ current_box_idx += 1
779
+
780
+ if current_box_idx < len(boxes):
781
+ x1, y1, x2, y2 = boxes[current_box_idx][:4]
782
+ x_offset = x1 + 2 # Small padding
783
+
784
+ # Now place the word in the new box
785
+ if word_width <= (x2 - x_offset) - 4:
786
+ text_y = y1 + (box_height - text_height) // 2
787
+ draw.text((x_offset, text_y), word_with_space, fill=(0, 0, 0), font=font)
788
+ x_offset += word_width
789
 
790
  # Convert back to OpenCV and save
791
  result_image = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
 
811
  try:
812
  gaps, img = solver.detect_gaps()
813
 
814
+ print(f"✅ {len(gaps)} boxes found, {len(solver.gap_groups)} line groups, {len(solver.ungrouped_gap_indices)} ungrouped!")
815
 
816
  marked_image = solver.mark_gaps(img, gaps)
817
 
818
  print("\n📍 Detected gaps (x, y, width, height):")
819
  for i, gap in enumerate(gaps):
820
+ unit_num = solver.gap_to_answer_unit.get(i)
821
+ if unit_num is not None:
822
+ print(f" Box {i+1} (Group {unit_num + 1}): {gap}")
823
+ else:
824
+ print(f" Box {i+1} (ungrouped): {gap}")
825
+
826
+ print("\n📊 Gap groups:")
827
+ for g_idx, group in enumerate(solver.gap_groups):
828
+ print(f" Group {g_idx+1}: gaps {[idx+1 for idx in group]}")
829
 
830
  if solver.debug:
831
  # Ask user if AI analysis is desired
 
840
 
841
  if solutions:
842
  print("\n✨ Solutions found:")
843
+ for group_idx, sol in solutions.items():
844
+ group_num = group_idx + 1
845
+ gap_indices = [idx+1 for idx in sol['gap_indices']]
846
+ print(f" Group {group_num} (gaps {gap_indices}): '{sol['solution']}'")
847
 
848
  solver.fill_gaps_in_image(path, solutions)
849
 
model/gap_detection_model.pt CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:a09d72ab83480428164c040356af5dce6b59fd42d305621901d9d234f0657c09
3
- size 53210085
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2593fee314b21afead4fc047f7c545b7e117ef37fba80bac452880e89ab1fb18
3
+ size 53167589