Hawk3388 committed on
Commit
8b07cf9
·
1 Parent(s): bb554b6

modified: app.py

Browse files
Files changed (2) hide show
  1. app.py +0 -77
  2. main.py +93 -20
app.py CHANGED
@@ -2,89 +2,18 @@ import os
2
  import tempfile
3
  import uuid
4
  import warnings
5
- import re
6
 
7
  import gradio as gr
8
- import requests
9
  from PIL import Image
10
- from pathlib import Path
11
 
12
  from main import WorksheetSolver
13
 
14
  warnings.filterwarnings("ignore")
15
 
16
- def get_gap_model() -> str:
17
- download = False
18
-
19
- os.makedirs("./model", exist_ok=True)
20
- folder_path = Path("./model")
21
- model_folder_names = [p.name for p in folder_path.iterdir() if p.is_dir()]
22
-
23
- if model_folder_names:
24
- latest_version = sorted(model_folder_names, key=lambda s: list(map(int, s.lstrip("v").split("."))), reverse=True)[0]
25
- model_path = folder_path / latest_version / "gap_detection_model.pt"
26
- if not model_path.exists():
27
- download = True
28
- else:
29
- download = True
30
-
31
- release_response = requests.get(RELEASES_URL)
32
- if release_response.status_code == 200:
33
- pattern = re.compile(r"<h2[^>]*>(v\d+\.\d+\.\d+)</h2>")
34
- versions = pattern.findall(release_response.text)
35
- if not versions:
36
- raise Exception("Could not determine the latest model version from GitHub releases.")
37
- else:
38
- raise Exception(f"Failed to fetch releases from GitHub: {release_response.status_code}")
39
-
40
- for version in versions:
41
- GAP_MODEL_URL = f"https://github.com/Hawk3388/solver/releases/download/{version}/gap_detection_model.pt"
42
- if not url_exists(GAP_MODEL_URL):
43
- continue
44
- if download:
45
- gd_model_path = str(folder_path / version / "gap_detection_model.pt")
46
- with requests.get(GAP_MODEL_URL, stream=True, timeout=60) as response:
47
- with open(gd_model_path, "wb") as model_file:
48
- for chunk in response.iter_content(chunk_size=8192):
49
- if chunk:
50
- model_file.write(chunk)
51
- break
52
- else:
53
- compare_versions = sorted([latest_version, version], key=lambda s: list(map(int, s.lstrip("v").split("."))), reverse=True)
54
- newer_version = compare_versions[0]
55
- if newer_version != latest_version:
56
- gd_model_path = str(folder_path / newer_version / "gap_detection_model.pt")
57
- with requests.get(GAP_MODEL_URL, stream=True, timeout=60) as response:
58
- with open(gd_model_path, "wb") as model_file:
59
- for chunk in response.iter_content(chunk_size=8192):
60
- if chunk:
61
- model_file.write(chunk)
62
- break
63
- else:
64
- gd_model_path = str(model_path)
65
-
66
- return gd_model_path
67
-
68
-
69
- def url_exists(url: str, timeout: float = 5.0) -> bool:
70
- try:
71
- r = requests.head(url, allow_redirects=True, timeout=timeout)
72
- return (200 <= r.status_code < 400)
73
- except requests.RequestException as e:
74
- return False
75
-
76
-
77
- def _is_allowed_image(filename: str) -> bool:
78
- return "." in filename and filename.rsplit(".", 1)[1].lower() in ALLOWED_EXTENSIONS
79
-
80
-
81
  def solve_worksheet(image_path: str):
82
  if not image_path:
83
  raise gr.Error("Please upload an image first.")
84
 
85
- if not _is_allowed_image(image_path):
86
- raise gr.Error("Please upload a valid image file (PNG, JPG, JPEG, WEBP, BMP).")
87
-
88
  with tempfile.TemporaryDirectory() as tmp_dir:
89
  unique_id = uuid.uuid4().hex
90
  input_path = os.path.join(tmp_dir, f"{unique_id}.png")
@@ -95,7 +24,6 @@ def solve_worksheet(image_path: str):
95
 
96
  solver = WorksheetSolver(
97
  input_path,
98
- gap_detection_model_path=MODEL_PATH,
99
  llm_model_name="gemini-3-flash-preview",
100
  think=True,
101
  local=False,
@@ -122,7 +50,6 @@ def solve_worksheet(image_path: str):
122
  except Exception as error:
123
  raise gr.Error(f"Processing error: {error}") from error
124
 
125
-
126
  def build_app() -> gr.Blocks:
127
  with gr.Blocks(title="Worksheet Solver", css="""
128
  .app-shell {max-width: 1200px; margin: 0 auto;}
@@ -160,10 +87,6 @@ def build_app() -> gr.Blocks:
160
 
161
  return demo
162
 
163
- ALLOWED_EXTENSIONS = {"png", "jpg", "jpeg", "webp", "bmp"}
164
- RELEASES_URL = "https://github.com/Hawk3388/solver/releases"
165
- MODEL_PATH = get_gap_model()
166
-
167
  demo = build_app()
168
 
169
  if __name__ == "__main__":
 
2
  import tempfile
3
  import uuid
4
  import warnings
 
5
 
6
  import gradio as gr
 
7
  from PIL import Image
 
8
 
9
  from main import WorksheetSolver
10
 
11
  warnings.filterwarnings("ignore")
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  def solve_worksheet(image_path: str):
14
  if not image_path:
15
  raise gr.Error("Please upload an image first.")
16
 
 
 
 
17
  with tempfile.TemporaryDirectory() as tmp_dir:
18
  unique_id = uuid.uuid4().hex
19
  input_path = os.path.join(tmp_dir, f"{unique_id}.png")
 
24
 
25
  solver = WorksheetSolver(
26
  input_path,
 
27
  llm_model_name="gemini-3-flash-preview",
28
  think=True,
29
  local=False,
 
50
  except Exception as error:
51
  raise gr.Error(f"Processing error: {error}") from error
52
 
 
53
  def build_app() -> gr.Blocks:
54
  with gr.Blocks(title="Worksheet Solver", css="""
55
  .app-shell {max-width: 1200px; margin: 0 auto;}
 
87
 
88
  return demo
89
 
 
 
 
 
90
  demo = build_app()
91
 
92
  if __name__ == "__main__":
main.py CHANGED
@@ -10,6 +10,8 @@ from PIL import Image, ImageDraw, ImageFont
10
  import numpy as np
11
  from ultralytics import YOLO
12
  from pathlib import Path
 
 
13
 
14
  # Define Pydantic models outside the class
15
  class Pair(BaseModel):
@@ -20,8 +22,11 @@ class get_solution(BaseModel):
20
  solutions: List[Pair]
21
 
22
  class WorksheetSolver():
23
- def __init__(self, path:str, gap_detection_model_path: str = "./model/gap_detection_model.pt", llm_model_name: str = "gemini-2.5-flash", think: bool = True, local: bool = False, thinking_budget: int = 2048, debug: bool = False, experimental: bool = False):
24
- self.model_path = gap_detection_model_path
 
 
 
25
  self.model_name = llm_model_name
26
  self.local = local
27
  self.path = path
@@ -30,6 +35,15 @@ class WorksheetSolver():
30
  self.thinking_budget = thinking_budget
31
  self.think = think
32
  self.experimental = experimental
 
 
 
 
 
 
 
 
 
33
 
34
  if self.debug:
35
  import time
@@ -39,11 +53,16 @@ class WorksheetSolver():
39
  print(f"💡 Please check the path to the image and try again.")
40
  exit()
41
  else:
42
- if not self.path.lower().endswith(".png"):
43
- print(f"✅ Worksheet image found: {self.path}")
44
- img = Image.open(self.path)
45
- img.save(f"{Path(self.path).stem}_temp.png")
46
- self.path = f"{Path(self.path).stem}_temp.png"
 
 
 
 
 
47
  if not Path(self.model_path).exists():
48
  print(f"❌ Trained model not found: {self.model_path}")
49
  print(f"💡 Run train_yolo.py first!")
@@ -57,11 +76,9 @@ class WorksheetSolver():
57
  elif os.getenv("GOOGLE_API_KEY"):
58
  self.client = genai.Client(api_key=os.getenv("GOOGLE_API_KEY"))
59
  else:
60
- print(f"❌ .env file with Google API key not found!")
61
- print(f"💡 Please create a .env file with your Google API key as GOOGLE_API_KEY=your_key and try again.")
62
  except Exception:
63
- print(f"❌ .env file with Google API key not found!")
64
- print(f"💡 Please create a .env file with your Google API key as GOOGLE_API_KEY=your_key and try again.")
65
  if self.experimental and self.local:
66
 
67
  from transformers.generation import LogitsProcessor
@@ -141,14 +158,6 @@ class WorksheetSolver():
141
 
142
  self.model = YOLO(self.model_path)
143
 
144
- self.image = None
145
- self.detected_gaps = []
146
- self.gap_groups = [] # Groups of gap indices
147
- self.gap_to_group = {} # Maps gap index to group index
148
- self.ungrouped_gap_indices = []
149
- self.answer_units = [] # Line groups + single ungrouped boxes
150
- self.gap_to_answer_unit = {} # Maps any gap index to answer unit index
151
-
152
  def load_image(self, image_path: str):
153
  """Load image and create a copy for processing"""
154
  self.image = cv2.imread(image_path)
@@ -156,6 +165,70 @@ class WorksheetSolver():
156
  raise FileNotFoundError(f"Image {image_path} not found!")
157
  return self.image.copy()
158
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
  def calculate_iou(self, box1: list, box2: list):
160
  """
161
  Calculates Intersection over Union (IoU) between two boxes
@@ -815,7 +888,7 @@ def main():
815
  # For Ollama models you have to set local=True
816
 
817
  path = input("📂 Please enter the path to the worksheet image: ").strip()
818
- llm_model_name = "qwen3.5:35b"
819
  think = True
820
  local = True
821
  debug = True
 
10
  import numpy as np
11
  from ultralytics import YOLO
12
  from pathlib import Path
13
+ import re
14
+ import requests
15
 
16
  # Define Pydantic models outside the class
17
  class Pair(BaseModel):
 
22
  solutions: List[Pair]
23
 
24
  class WorksheetSolver():
25
+ def __init__(self, path:str, gap_detection_model_path: str = "", llm_model_name: str = "gemini-2.5-flash", think: bool = True, local: bool = False, thinking_budget: int = 2048, debug: bool = False, experimental: bool = False):
26
+ if gap_detection_model_path:
27
+ self.model_path = gap_detection_model_path
28
+ else:
29
+ self.model_path = self.get_gap_model()
30
  self.model_name = llm_model_name
31
  self.local = local
32
  self.path = path
 
35
  self.thinking_budget = thinking_budget
36
  self.think = think
37
  self.experimental = experimental
38
+
39
+ self.image = None
40
+ self.allowed_extensions = {'png', 'jpg', 'jpeg', 'webp', 'bmp'}
41
+ self.detected_gaps = []
42
+ self.gap_groups = [] # Groups of gap indices
43
+ self.gap_to_group = {} # Maps gap index to group index
44
+ self.ungrouped_gap_indices = []
45
+ self.answer_units = [] # Line groups + single ungrouped boxes
46
+ self.gap_to_answer_unit = {} # Maps any gap index to answer unit index
47
 
48
  if self.debug:
49
  import time
 
53
  print(f"💡 Please check the path to the image and try again.")
54
  exit()
55
  else:
56
+ if self.is_allowed_image(self.path):
57
+ if not self.path.lower().endswith(".png"):
58
+ print(f"✅ Worksheet image found: {self.path}")
59
+ img = Image.open(self.path)
60
+ img.save(f"{Path(self.path).stem}_temp.png")
61
+ self.path = f"{Path(self.path).stem}_temp.png"
62
+ else:
63
+ print(f"❌ Invalid file type: {self.path}")
64
+ print(f"💡 Please upload an image file with one of the following extensions: {', '.join(self.allowed_extensions)}")
65
+ exit()
66
  if not Path(self.model_path).exists():
67
  print(f"❌ Trained model not found: {self.model_path}")
68
  print(f"💡 Run train_yolo.py first!")
 
76
  elif os.getenv("GOOGLE_API_KEY"):
77
  self.client = genai.Client(api_key=os.getenv("GOOGLE_API_KEY"))
78
  else:
79
+ raise ValueError("❌ .env file with Google API key not found!\n💡 Please create a .env file with your Google API key as GOOGLE_API_KEY=your_key and try again.")
 
80
  except Exception:
81
+ raise ValueError("❌ .env file with Google API key not found!\n💡 Please create a .env file with your Google API key as GOOGLE_API_KEY=your_key and try again.")
 
82
  if self.experimental and self.local:
83
 
84
  from transformers.generation import LogitsProcessor
 
158
 
159
  self.model = YOLO(self.model_path)
160
 
 
 
 
 
 
 
 
 
161
  def load_image(self, image_path: str):
162
  """Load image and create a copy for processing"""
163
  self.image = cv2.imread(image_path)
 
165
  raise FileNotFoundError(f"Image {image_path} not found!")
166
  return self.image.copy()
167
 
168
def get_gap_model(self) -> str:
    """Return the path of the newest gap-detection model, downloading it when needed.

    Version folders under ``./model`` (for example ``./model/v1.2.0``) are
    scanned for ``gap_detection_model.pt``. The GitHub releases page is then
    queried and, if a release newer than the newest usable local copy offers
    the model file, that file is downloaded into its own version folder.

    If the releases page cannot be read or a download fails while a local
    model is already present, the local model is used instead of failing.

    Returns:
        Path to the model file as a string.

    Raises:
        Exception: If no local model exists and the releases page cannot be
            fetched, lists no versions, offers no downloadable model file,
            or the download itself fails.
    """
    releases_url = "https://github.com/Hawk3388/solver/releases"
    model_filename = "gap_detection_model.pt"
    version_folder = re.compile(r"v?\d+(?:\.\d+)*")

    def version_key(name: str) -> list:
        # "v1.10.0" -> [1, 10, 0] so versions compare numerically, not as text.
        return [int(part) for part in name.lstrip("v").split(".")]

    models_root = Path("./model")
    models_root.mkdir(parents=True, exist_ok=True)

    # Only folders that look like versions are considered; anything else in
    # ./model would otherwise break the integer parsing in version_key.
    local_candidates = sorted(
        (
            entry.name
            for entry in models_root.iterdir()
            if entry.is_dir() and version_folder.fullmatch(entry.name)
        ),
        key=version_key,
        reverse=True,
    )

    # Pick the newest folder that really contains the model file; an empty
    # folder left behind by an interrupted download must not count.
    local_version = None
    local_model = None
    for name in local_candidates:
        candidate = models_root / name / model_filename
        if candidate.exists():
            local_version, local_model = name, candidate
            break

    def use_local_or_fail(message: str, cause=None) -> str:
        # A usable local model beats crashing because GitHub is unreachable.
        if local_model is None:
            raise Exception(message) from cause
        print(f"⚠️ {message} Using local model: {local_model}")
        return str(local_model)

    try:
        release_response = requests.get(releases_url, timeout=30)
    except requests.RequestException as error:
        return use_local_or_fail(f"Failed to fetch releases from GitHub: {error}", error)

    if release_response.status_code != 200:
        return use_local_or_fail(
            f"Failed to fetch releases from GitHub: {release_response.status_code}"
        )

    versions = set(
        re.findall(r"<h2[^>]*>(v\d+\.\d+\.\d+)</h2>", release_response.text)
    )
    if not versions:
        return use_local_or_fail(
            "Could not determine the latest model version from GitHub releases."
        )

    # Walk the releases newest-first instead of relying on the page order.
    for version in sorted(versions, key=version_key, reverse=True):
        if local_version is not None and version_key(version) <= version_key(local_version):
            # Everything from here on is no newer than the local copy.
            break

        model_url = (
            f"https://github.com/Hawk3388/solver/releases/download/{version}/{model_filename}"
        )
        if not self.url_exists(model_url):
            # This release has no model asset attached; try the next one.
            continue

        target = models_root / version / model_filename
        target.parent.mkdir(parents=True, exist_ok=True)
        # Download to a side file first so an aborted transfer never leaves a
        # truncated file under the final model name.
        partial = target.with_name(target.name + ".part")
        try:
            with requests.get(model_url, stream=True, timeout=60) as response:
                response.raise_for_status()
                with open(partial, "wb") as model_file:
                    for chunk in response.iter_content(chunk_size=8192):
                        if chunk:
                            model_file.write(chunk)
            partial.replace(target)
        except (requests.RequestException, OSError) as error:
            partial.unlink(missing_ok=True)
            return use_local_or_fail(f"Failed to download {model_url}: {error}", error)
        return str(target)

    if local_model is not None:
        return str(local_model)
    raise Exception("Could not find a downloadable gap detection model in the GitHub releases.")
220
+
221
+
222
def url_exists(self, url: str, timeout: float = 5.0) -> bool:
    """Report whether a HEAD request to ``url`` answers with a 2xx or 3xx status.

    Redirects are followed. Any ``requests`` error (connection failure,
    timeout, invalid URL, ...) is treated as "does not exist".
    """
    try:
        status = requests.head(url, allow_redirects=True, timeout=timeout).status_code
    except requests.RequestException:
        return False
    return 200 <= status < 400
228
+
229
def is_allowed_image(self, filename: str) -> bool:
    """Check whether ``filename`` ends in one of ``self.allowed_extensions``.

    The extension is the text after the last dot, compared case-insensitively;
    a name without any dot is rejected.
    """
    _, dot, extension = filename.rpartition(".")
    return dot == "." and extension.lower() in self.allowed_extensions
231
+
232
  def calculate_iou(self, box1: list, box2: list):
233
  """
234
  Calculates Intersection over Union (IoU) between two boxes
 
888
  # For Ollama models you have to set local=True
889
 
890
  path = input("📂 Please enter the path to the worksheet image: ").strip()
891
+ llm_model_name = "gemma4:26b"
892
  think = True
893
  local = True
894
  debug = True