Hawk3388 committed on
Commit
4287017
·
1 Parent(s): c170470

new file: main.py

Browse files
Files changed (1) hide show
  1. main.py +591 -0
main.py ADDED
@@ -0,0 +1,591 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import os
3
+ import ollama
4
+ from pydantic import BaseModel
5
+ from google import genai
6
+ from google.genai import types
7
+ from dotenv import load_dotenv
8
+ from typing import List
9
+ from PIL import Image, ImageDraw, ImageFont
10
+ import numpy as np
11
+ from ultralytics import YOLO
12
+ from pathlib import Path
13
+
14
+ # Define Pydantic models outside the class
15
class Pair(BaseModel):
    """One answer returned by the LLM for a single numbered gap."""

    # Number printed on the red box in the marked image.
    key: int
    # Text the model proposes for that gap.
    value: str
18
+
19
class get_solution(BaseModel):
    """Structured LLM reply: the list of (box number, answer) pairs.

    Used both as the response schema handed to the LLM backends and to
    validate the JSON they return.
    """

    # One entry per gap the model answered.
    solutions: List[Pair]
21
+
22
class WorksheetSolver():
    """Detect the gaps of a worksheet image, ask an LLM for the answers and draw them in.

    Typical flow: `detect_gaps` (YOLO) -> `mark_gaps` (numbered red boxes) ->
    `solve_all_gaps` (LLM via Gemini, Ollama or a local transformers pipeline)
    -> `fill_gaps_in_image`.

    NOTE(review): this class was re-indented from a source whose indentation
    had been lost. Nesting was reconstructed from the control flow; the few
    places where more than one reading was possible carry their own note.
    """

    def __init__(
        self,
        path: str,
        gap_detection_model_path: str = "./model/gap_detection_model.pt",
        llm_model_name: str = "gemini-2.5-flash",
        think: bool = True,
        local: bool = False,
        thinking_budget: int = 2048,
        debug: bool = False,
        experimental: bool = False,
    ):
        """Validate the inputs and set up the detection model and LLM backend.

        Args:
            path: Path of the worksheet image. Non-PNG images are converted to
                a `<stem>_temp.png` file in the working directory.
            gap_detection_model_path: Path of the trained YOLO weights.
            llm_model_name: Name of the LLM (Gemini model, Ollama model or
                transformers model id, depending on the mode).
            think: Whether the LLM may use a thinking phase.
            local: Use a local backend (Ollama / transformers) instead of Gemini.
            thinking_budget: Token budget for the thinking phase.
            debug: Print timings/raw output and keep the temporary images.
            experimental: With `local=True`, use a transformers pipeline with a
                thinking-budget logits processor instead of Ollama.

        Raises:
            SystemExit: If the worksheet image or the YOLO weights are missing.
        """
        self.model_path = gap_detection_model_path
        self.model_name = llm_model_name
        self.local = local
        self.path = path
        self.debug = debug
        # Stored unconditionally: the experimental branch below reads it even
        # when think=False, which previously raised AttributeError.
        self.thinking_budget = thinking_budget
        self.think = think
        self.experimental = experimental
        # Stays None when no Gemini client could be configured.
        self.client = None

        if self.debug:
            import time
            self.time = time

        if not Path(self.path).exists():
            print(f"❌ Worksheet image not found: {self.path}")
            print(f"💡 Please check the path to the image and try again.")
            # `exit()` is a helper injected by the `site` module and may be
            # missing (python -S, frozen executables); raise SystemExit directly.
            raise SystemExit(1)
        elif not self.path.lower().endswith(".png"):
            print(f"✅ Worksheet image found: {self.path}")
            temp_path = f"{Path(self.path).stem}_temp.png"
            with Image.open(self.path) as img:
                img.save(temp_path)
            self.path = temp_path

        if not Path(self.model_path).exists():
            print(f"❌ Trained model not found: {self.model_path}")
            print(f"💡 Run train_yolo.py first!")
            print(f"\nIf available, change MODEL_PATH to the correct location")
            raise SystemExit(1)

        if not self.local and not self.experimental:
            if os.path.exists(".env"):
                load_dotenv()
                self.client = genai.Client(api_key=os.getenv("GOOGLE_API_KEY"))
            else:
                # Deliberately not fatal: gap detection still works without a key.
                print(f"❌ .env file with Google API key not found!")
                print(f"💡 Please create a .env file with your Google API key as GOOGLE_API_KEY=your_key and try again.")

        if self.experimental and self.local:
            # Heavy optional dependencies are imported only in this mode.
            from transformers.generation import LogitsProcessor
            from transformers import AutoTokenizer, pipeline, BitsAndBytesConfig
            from lmformatenforcer import JsonSchemaParser
            from lmformatenforcer.integrations.transformers import build_transformers_prefix_allowed_tokens_fn
            import torch

            class ThinkingTokenBudgetProcessor(LogitsProcessor):
                """
                A processor where after a maximum number of tokens are generated,
                a </think> token is added at the end to stop the thinking generation,
                and then it will continue to generate the response.
                """
                def __init__(self, tokenizer, max_thinking_tokens=None):
                    self.tokenizer = tokenizer
                    self.max_thinking_tokens = max_thinking_tokens
                    self.think_end_token = self.tokenizer.encode("</think>", add_special_tokens=False)[0]
                    self.nl_token = self.tokenizer.encode("\n", add_special_tokens=False)[0]
                    self.tokens_generated = 0
                    self.stopped_thinking = False
                    self.neg_inf = float('-inf')

                def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
                    self.tokens_generated += 1
                    # Budget of zero: only newline / </think> are allowed right away.
                    if self.max_thinking_tokens == 0 and not self.stopped_thinking and self.tokens_generated > 0:
                        scores[:] = self.neg_inf
                        scores[0][self.nl_token] = 0
                        scores[0][self.think_end_token] = 0
                        self.stopped_thinking = True
                        return scores

                    if self.max_thinking_tokens is not None and not self.stopped_thinking:
                        # Past 95% of the budget: rescale the newline and </think> scores.
                        if (self.tokens_generated / self.max_thinking_tokens) > .95:
                            scores[0][self.nl_token] = scores[0][self.think_end_token] * (1 + (self.tokens_generated / self.max_thinking_tokens))
                            scores[0][self.think_end_token] = (
                                scores[0][self.think_end_token] * (1 + (self.tokens_generated / self.max_thinking_tokens))
                            )

                        # Budget reached: force a newline, then force </think>.
                        if self.tokens_generated >= (self.max_thinking_tokens - 1):
                            if self.tokens_generated == self.max_thinking_tokens-1:
                                scores[:] = self.neg_inf
                                scores[0][self.nl_token] = 0
                            else:
                                scores[:] = self.neg_inf
                                scores[0][self.think_end_token] = 0
                                self.stopped_thinking = True

                    return scores

            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4"
            )

            # The LLM id lives in self.model_name; self.model is only assigned
            # further down (YOLO), so using it here raised AttributeError.
            tokenizer = AutoTokenizer.from_pretrained(self.model_name)

            # Thinking is budgeted in experimental mode regardless of `think`
            # (both branches of the original if/else built the same processor).
            processor = ThinkingTokenBudgetProcessor(tokenizer, max_thinking_tokens=self.thinking_budget)

            schema_parser = JsonSchemaParser(get_solution.model_json_schema())
            self.prefix_function = build_transformers_prefix_allowed_tokens_fn(tokenizer, schema_parser)

            self.pipe = pipeline(
                "image-text-to-text",
                model=self.model_name,
                max_new_tokens=4096,
                logits_processor=[processor],
                device=0,
                model_kwargs={"quantization_config": quantization_config}
            )

        # Gap detector; needed in every mode.
        self.model = YOLO(self.model_path)

        self.image = None
        self.detected_gaps = []

    def load_image(self, image_path: str):
        """Load the image into `self.image` and return a copy for processing.

        Raises:
            FileNotFoundError: If OpenCV cannot read `image_path`.
        """
        self.image = cv2.imread(image_path)
        if self.image is None:
            raise FileNotFoundError(f"Image {image_path} not found!")
        return self.image.copy()

    def calculate_iou(self, box1: list, box2: list):
        """Return the Intersection over Union (IoU) of two boxes.

        Both boxes are given as [x1, y1, x2, y2]. Returns 0.0 when the boxes
        do not intersect or when the union has no area.
        """
        x1_inter = max(box1[0], box2[0])
        y1_inter = max(box1[1], box2[1])
        x2_inter = min(box1[2], box2[2])
        y2_inter = min(box1[3], box2[3])

        if x2_inter < x1_inter or y2_inter < y1_inter:
            return 0.0

        inter_area = (x2_inter - x1_inter) * (y2_inter - y1_inter)

        box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
        box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])

        union_area = box1_area + box2_area - inter_area

        return inter_area / union_area if union_area > 0 else 0.0

    def filter_overlapping_boxes(self, boxes, iou_threshold=0.5):
        """Drop overlapping boxes, keeping the one with the highest confidence.

        Args:
            boxes: YOLO boxes object (must provide `xyxy`, `conf` and `len()`).
            iou_threshold: Boxes whose IoU with an already kept box exceeds
                this value are discarded (0.5 = 50%).

        Returns:
            Sorted list of the indices of the boxes to keep.
        """
        if len(boxes) == 0:
            return []

        coords = boxes.xyxy.cpu().numpy()  # rows of [x1, y1, x2, y2]
        confidences = boxes.conf.cpu().numpy()

        keep = []
        # Visit boxes from most to least confident so the survivor of every
        # overlapping group is the most confident one.
        for i in np.argsort(-confidences):
            overlaps_kept = any(
                self.calculate_iou(coords[i], coords[kept_idx]) > iou_threshold
                for kept_idx in keep
            )
            if not overlaps_kept:
                keep.append(int(i))

        return sorted(keep)  # Back in original order

    def sort_reading_order(self, boxes, min_overlap_ratio: float = 0.3):
        """Sort boxes in reading order: top to bottom by line, left to right within a line.

        Boxes on the same text line often have slightly different y values, so
        boxes are grouped into a line when they overlap vertically.

        Args:
            boxes: Sequence of boxes. 4-element boxes are read as
                (x1, y1, x2, y2); for other lengths the bottom edge is taken as
                `box[1] + box[3]`.
            min_overlap_ratio: A box joins the current line when the vertical
                overlap divided by the smaller of (box height, line height)
                is greater than this value.

        Returns:
            A new list in reading order (the input itself when it is empty).
        """
        if not boxes:
            return boxes

        def bottom_of(b):
            return b[3] if len(b) == 4 else b[1] + b[3]

        # Sort roughly by the top edge first
        boxes_sorted = sorted(boxes, key=lambda b: b[1])

        lines = []
        current_line = [boxes_sorted[0]]
        # Vertical extent of the line that is currently being collected
        line_y_min = boxes_sorted[0][1]
        line_y_max = bottom_of(boxes_sorted[0])

        for box in boxes_sorted[1:]:
            box_y_top = box[1]
            box_y_bottom = bottom_of(box)
            box_height = box_y_bottom - box_y_top
            line_height = line_y_max - line_y_min

            overlap = min(line_y_max, box_y_bottom) - max(line_y_min, box_y_top)
            # Clamp to 1 so degenerate (zero-height) boxes cannot divide by zero
            min_height = max(min(box_height, line_height), 1)

            if overlap > 0 and overlap / min_height > min_overlap_ratio:
                # Same line: grow the line's vertical extent
                current_line.append(box)
                line_y_min = min(line_y_min, box_y_top)
                line_y_max = max(line_y_max, box_y_bottom)
            else:
                # New line
                lines.append(current_line)
                current_line = [box]
                line_y_min = box_y_top
                line_y_max = box_y_bottom

        lines.append(current_line)

        # Lines are already top to bottom; order each one by x
        result = []
        for line in lines:
            line.sort(key=lambda b: b[0])
            result.extend(line)

        return result

    def detect_gaps(self, conf: float = 0.10, iou_threshold: float = 0.5):
        """Run the YOLO model on `self.path` and collect the gap boxes.

        Args:
            conf: Confidence threshold passed to the model.
            iou_threshold: Overlap threshold for `filter_overlapping_boxes`.

        Returns:
            Tuple `(gaps, img)`: `gaps` is the list of (x1, y1, x2, y2) int
            tuples in reading order (also stored in `self.detected_gaps`);
            `img` is a copy of the image the model ran on, or None when the
            model returned no result.
        """
        self.detected_gaps = []
        # Defined up front so an empty result list cannot leave it unbound.
        img = None

        results = self.model.predict(source=self.path, conf=conf)

        for r in results:
            if len(r.boxes) > 0:
                keep_indices = self.filter_overlapping_boxes(r.boxes, iou_threshold=iou_threshold)
                print(f"🔍 After overlap filtering: {len(keep_indices)} boxes")
            else:
                keep_indices = []
            if len(keep_indices) == 0:
                print("\n❌ No gaps detected!")
                print("💡 Check:")
                print(" - Is the image a worksheet?")
                print(" - Was the model trained correctly?")
                print(" - Try lower conf (e.g. 0.1)")
            else:
                for idx in keep_indices:
                    box = r.boxes[idx]
                    x1, y1, x2, y2 = box.xyxy[0].cpu().numpy().astype(int)
                    self.detected_gaps.append((int(x1), int(y1), int(x2), int(y2)))
            # NOTE(review): placed at result level so an image is returned even
            # when no gap was kept - confirm against the original indentation.
            img = r.orig_img.copy()

        # Sort in reading order (line by line)
        self.detected_gaps = self.sort_reading_order(self.detected_gaps)

        return self.detected_gaps, img

    def mark_gaps(self, image, gaps):
        """Draw a numbered red box for every gap onto `image` (in place) and return it."""
        for i, gap in enumerate(gaps):
            x1, y1, x2, y2 = gap
            # Red box around the gap
            cv2.rectangle(image, (x1, y1), (x2, y2), (0, 0, 255), 2)
            # 1-based number at the top left of the box
            label = str(i + 1)
            label_size, _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.4, 1)
            # Filled background for better readability
            cv2.rectangle(image, (x1, y1 - label_size[1] - 4), (x1 + label_size[0] + 2, y1), (0, 0, 255), -1)
            cv2.putText(image, label, (x1 + 1, y1 - 3), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255), 1)
        return image

    def ask_ai_about_all_gaps(self, marked_image):
        """Ask the configured LLM for the content of all gaps at once.

        The marked image is written to `<stem>_marked.png` and sent together
        with the unmarked worksheet. Outside debug mode the temporary images
        are removed afterwards.

        Returns:
            The parsed `get_solution` object, or None when no valid answer was
            obtained (invalid JSON, or no backend for the chosen mode).

        Raises:
            RuntimeError: If Gemini is selected but no client was configured.
        """
        if not self.experimental and not self.local and getattr(self, "client", None) is None:
            raise RuntimeError("Gemini client is not configured - create a .env file with GOOGLE_API_KEY")

        if self.debug:
            start_time = self.time.time()
        thinking = None
        # None unless a backend produced a valid answer; the caller treats a
        # falsy result as "no response".
        output = None
        marked_image_path = f"{Path(self.path).stem}_marked.png"
        cv2.imwrite(marked_image_path, marked_image)

        prompt = f"""Look at the two images: one with red numbered boxes marking {len(self.detected_gaps)} gaps, one without markings.

For each red box, read its number label and fill in the missing word(s) from the worksheet.

Rules:
- Answer in the worksheet's language.
- Only the missing word(s), nothing else.
- Match each answer to the correct box number.
- If a box doesn't need filling, because it is already filled or is not a gap, answer with "none".
- Do NOT overthink. These are simple language exercises. Answer quickly and directly. Only reason for about 10 sentences.
- Look at the sheets carefully and use them as context for your answers.
- Only answer in this exact JSON format: {{"solutions": [{{"key": box_number, "value": answer}}]}}"""

        if not self.experimental:
            if not self.local:
                # Gemini: keep both images open for the duration of the request.
                with Image.open(marked_image_path) as image, Image.open(self.path) as original_image:
                    response = self.client.models.generate_content(
                        model=self.model_name,
                        contents=[image, original_image, prompt],
                        config=types.GenerateContentConfig(
                            response_mime_type="application/json",
                            response_schema=get_solution,
                            thinking_config=types.ThinkingConfig(thinking_budget=self.thinking_budget if self.think else 0),
                        ),
                    )
                output = response.parsed
            else:
                if self.model_name == "qwen3-vl:8b-thinking" and self.think:
                    print("you are using an experimantal thinking model - we will stream the response and switch to an instruct model if it seems to get stuck in thinking mode")
                    response = ollama.chat(
                        model=self.model_name,
                        messages=[{"role": "user", "content": prompt, "images": [marked_image_path, self.path]}],
                        format=get_solution.model_json_schema(),
                        options={"num_ctx": 8192},
                        stream=True
                    )
                    full_response = ""
                    thinking = ""
                    finished = True
                    for chunk in response:
                        if chunk.message.content:
                            full_response += chunk.message.content
                            print(chunk.message.content, end="", flush=True)
                        elif chunk.message.thinking:
                            print(chunk.message.thinking, end="", flush=True)
                            thinking += chunk.message.thinking
                            if len(thinking) > 12000:
                                # NOTE(review): nesting reconstructed - the stream is
                                # abandoned once a paragraph break shows up near the end
                                # of an over-long thinking text. `split("\n\n")[0]` keeps
                                # only the FIRST paragraph; verify that is intended.
                                if "\n\n" in thinking.strip()[-10:]:
                                    thinking = thinking.split("\n\n")[0]
                                    del response
                                    print(len(thinking))
                                    finished = False
                                    break

                    if not finished:
                        # Hand the (truncated) reasoning to the instruct variant.
                        final_response = ollama.chat(
                            model=self.model_name.replace("thinking", "instruct"),
                            messages=[{"role": "user", "content": prompt, "images": [marked_image_path, self.path]},
                                      {"role": "assistant", "content": thinking}],
                            format=get_solution.model_json_schema(),
                            options={"num_ctx": 8192}
                        )

                        output = get_solution.model_validate_json(final_response.message.content)
                    else:
                        output = get_solution.model_validate_json(full_response)
                else:
                    # Only pass an explicit think flag to models that advertise the capability.
                    supports_thinking = 'thinking' in ollama.show(self.model_name).capabilities
                    response = ollama.chat(
                        model=self.model_name,
                        messages=[{"role": "user", "content": prompt, "images": [marked_image_path, self.path]}],
                        format=get_solution.model_json_schema(),
                        think=bool(self.think) if supports_thinking else None,
                        options={"num_ctx": 8192}
                    )
                    if response.message.thinking:
                        thinking = response.message.thinking
                    try:
                        output = get_solution.model_validate_json(response.message.content)
                    except Exception as e:
                        # Best effort: report and fall through with output=None.
                        print(f"Error validating JSON response: {e}")
                        if self.debug:
                            if thinking:
                                print(f"Thinking content:\n{thinking}")
                            print(f"Full response content:\n{response.message.content}")
                            print(f"⏱️ Debug mode ON - timing enabled")
                            end_time = self.time.time()
                            print(f"⏱️ Time taken: {end_time - start_time:.2f} seconds")
        else:
            if self.local:
                messages = [{"role": "user", "content": [
                    {"type": "image", "image_path": marked_image_path},
                    {"type": "image", "image_path": self.path},
                    {"type": "text", "text": prompt},
                ]}]
                response = self.pipe(messages, enable_thinking=self.think, prefix_allowed_tokens_fn=self.prefix_function)[0]["generated_text"][-1]["content"]
                # Everything after the last </think> is the JSON answer.
                response = response.split("</think>")
                output = get_solution.model_validate_json(response[-1])

        if not self.debug:
            # Remove the temporary PNG conversion and the marked image.
            if os.path.exists(self.path) and self.path.endswith("_temp.png"):
                os.remove(self.path)
            if os.path.exists(marked_image_path):
                os.remove(marked_image_path)
        else:
            print(f"⏱️ Debug mode ON - timing enabled")
            end_time = self.time.time()
            print(f"⏱️ Time taken: {end_time - start_time:.2f} seconds")
            if thinking:
                print(f"Thinking: {thinking}")
            print(f"AI output:\n{output}")

        return output

    def solve_all_gaps(self, marked_image):
        """Ask the LLM about all detected gaps and map the answers to gap positions.

        Returns:
            Dict mapping the 0-based gap index to
            `{'position': (x1, y1, x2, y2), 'solution': answer}`. Empty when
            there are no gaps or no answer was received. Answers whose box
            number is out of range are ignored.
        """
        if not self.detected_gaps:
            print("No gaps found!")
            return {}

        print(f"🤖 Analyzing all {len(self.detected_gaps)} gaps with Ollama...")

        # Ask about all gaps at once
        print("📤 Sending image to Ollama...")
        solutions_data = self.ask_ai_about_all_gaps(marked_image)

        if solutions_data:
            print("📥 Structured Ollama response received!")

            solutions = {}

            # solutions_data.solutions is a list of Pair objects
            for pair in solutions_data.solutions:
                try:
                    gap_id = pair.key
                    answer = pair.value
                    gap_index = gap_id - 1  # box labels are 1-based

                    if 0 <= gap_index < len(self.detected_gaps):
                        solutions[gap_index] = {
                            'position': self.detected_gaps[gap_index],
                            'solution': answer
                        }
                except (ValueError, KeyError) as e:
                    print(f"Error processing gap {gap_id}: {e}")
                    continue

            return solutions
        else:
            print("❌ No response received from Ollama.")
            return {}

    def fill_gaps_in_image(self, image_path: str, solutions: dict, output_path: str = "worksheet_solved.png"):
        """Draw the solutions into the worksheet and save the result.

        Args:
            image_path: Image to draw on.
            solutions: Mapping as returned by `solve_all_gaps`.
            output_path: Where the filled worksheet is written.

        Returns:
            The filled image as an OpenCV (BGR) array.

        Raises:
            FileNotFoundError: If `image_path` cannot be read.
        """
        # Draw with PIL rather than OpenCV so Unicode text (umlauts) renders.
        cv_image = self.load_image(image_path)
        pil_image = Image.fromarray(cv2.cvtColor(cv_image, cv2.COLOR_BGR2RGB))

        draw = ImageDraw.Draw(pil_image)

        def load_font(size):
            """Return (font, scalable); falls back to PIL's default font."""
            for candidate in ("arial.ttf", "C:/Windows/Fonts/arial.ttf", "DejaVuSans.ttf"):
                try:
                    return ImageFont.truetype(candidate, size), True
                except OSError:
                    continue
            return ImageFont.load_default(), False

        padding = 4
        min_font_size = 8

        for gap_index, solution_data in solutions.items():
            # Position is (x1, y1, x2, y2)
            x1, y1, x2, y2 = solution_data['position']
            w = x2 - x1
            h = y2 - y1
            solution = solution_data['solution']

            if not solution or solution.lower() == 'none':
                continue

            # Shrink from a large size until the text fits into the box.
            font_size = 40
            while True:
                font, scalable = load_font(font_size)
                bbox = draw.textbbox((0, 0), solution, font=font)
                text_width = bbox[2] - bbox[0]
                text_height = bbox[3] - bbox[1]

                fits = text_width <= w - padding and text_height <= h - padding
                # The default font cannot be resized, so stop immediately.
                if fits or not scalable or font_size <= min_font_size:
                    break
                font_size -= 1

            # Centre the text. textbbox reports where the ink starts relative
            # to the drawing origin, so subtract that offset; without it the
            # text sits too low and too far right.
            text_x = x1 + (w - text_width) // 2 - bbox[0]
            text_y = y1 + (h - text_height) // 2 - bbox[1]

            # Draw text in black
            draw.text((text_x, text_y), solution, fill=(0, 0, 0), font=font)

        # Convert back to OpenCV and save
        result_image = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
        cv2.imwrite(output_path, result_image)
        print(f"Solved worksheet saved as: {output_path}")
        return result_image
529
+
530
+ # Main program
531
def main():
    """Interactive entry point: detect the gaps of one worksheet and optionally fill them.

    Asks for the image path on stdin, runs gap detection and - after a
    confirmation prompt in debug mode, unconditionally otherwise - lets the
    LLM solve the gaps and writes the filled worksheet.
    """
    # Best results with gemini-3-flash-preview (local: qwen3.5:35b for 16 GB VRAM + 32 GB RAM)
    # For Gemini you have to use a Google API-key in a .env file
    # For Ollama models you have to set local=True

    path = input("📂 Please enter the path to the worksheet image: ").strip()
    llm_model_name = "qwen3.5:35b"
    think = True
    local = True
    debug = True
    solver = WorksheetSolver(path, llm_model_name=llm_model_name, think=think, local=local, debug=debug)

    ask = False
    print("🔍 Loading image and detecting gaps...")
    try:
        gaps, img = solver.detect_gaps()

        print(f"✅ {len(gaps)} gaps found!")

        marked_image = solver.mark_gaps(img, gaps)

        # Gaps are corner coordinates, not x/y/width/height as the old label claimed.
        print("\n📍 Detected gaps (x1, y1, x2, y2):")
        for i, gap in enumerate(gaps):
            print(f" Gap {i+1}: {gap}")

        if solver.debug:
            # Ask user if AI analysis is desired
            user_input = input("\n🤖 Should an AI analyze and fill the gaps? (y/n): ").lower().strip()
            if user_input in ['y', 'yes']:
                ask = True
        else:
            ask = True

        if ask:
            solutions = solver.solve_all_gaps(marked_image)

            if solutions:
                print("\n✨ Solutions found:")
                for i, sol in solutions.items():
                    print(f" Gap {i+1}: '{sol['solution']}'")

                solver.fill_gaps_in_image(path, solutions)

                print("\n📁 Result saved. Press any key to exit...")
            else:
                print("❌ No solutions received.")
        else:
            print("📁 Gap detection only")

    except FileNotFoundError as e:
        print(f"❌ Error: {e}")
    except Exception as e:
        # Top-level boundary: report anything unexpected instead of a traceback.
        print(f"❌ Unexpected error: {e}")


if __name__ == "__main__":
    main()
587
+
588
+ # TODO:
589
+ # - better image detection with support for more kinds of worksheets
590
+ # - Add support for multiple files (batch processing)
591
+ # - Create an executable (.exe) for easy use without Python setup (Command: pyinstaller solver.spec)