modified: app.py
Browse files
app.py
CHANGED
|
@@ -2,89 +2,18 @@ import os
|
|
| 2 |
import tempfile
|
| 3 |
import uuid
|
| 4 |
import warnings
|
| 5 |
-
import re
|
| 6 |
|
| 7 |
import gradio as gr
|
| 8 |
-
import requests
|
| 9 |
from PIL import Image
|
| 10 |
-
from pathlib import Path
|
| 11 |
|
| 12 |
from main import WorksheetSolver
|
| 13 |
|
| 14 |
warnings.filterwarnings("ignore")
|
| 15 |
|
| 16 |
-
def get_gap_model() -> str:
|
| 17 |
-
download = False
|
| 18 |
-
|
| 19 |
-
os.makedirs("./model", exist_ok=True)
|
| 20 |
-
folder_path = Path("./model")
|
| 21 |
-
model_folder_names = [p.name for p in folder_path.iterdir() if p.is_dir()]
|
| 22 |
-
|
| 23 |
-
if model_folder_names:
|
| 24 |
-
latest_version = sorted(model_folder_names, key=lambda s: list(map(int, s.lstrip("v").split("."))), reverse=True)[0]
|
| 25 |
-
model_path = folder_path / latest_version / "gap_detection_model.pt"
|
| 26 |
-
if not model_path.exists():
|
| 27 |
-
download = True
|
| 28 |
-
else:
|
| 29 |
-
download = True
|
| 30 |
-
|
| 31 |
-
release_response = requests.get(RELEASES_URL)
|
| 32 |
-
if release_response.status_code == 200:
|
| 33 |
-
pattern = re.compile(r"<h2[^>]*>(v\d+\.\d+\.\d+)</h2>")
|
| 34 |
-
versions = pattern.findall(release_response.text)
|
| 35 |
-
if not versions:
|
| 36 |
-
raise Exception("Could not determine the latest model version from GitHub releases.")
|
| 37 |
-
else:
|
| 38 |
-
raise Exception(f"Failed to fetch releases from GitHub: {release_response.status_code}")
|
| 39 |
-
|
| 40 |
-
for version in versions:
|
| 41 |
-
GAP_MODEL_URL = f"https://github.com/Hawk3388/solver/releases/download/{version}/gap_detection_model.pt"
|
| 42 |
-
if not url_exists(GAP_MODEL_URL):
|
| 43 |
-
continue
|
| 44 |
-
if download:
|
| 45 |
-
gd_model_path = str(folder_path / version / "gap_detection_model.pt")
|
| 46 |
-
with requests.get(GAP_MODEL_URL, stream=True, timeout=60) as response:
|
| 47 |
-
with open(gd_model_path, "wb") as model_file:
|
| 48 |
-
for chunk in response.iter_content(chunk_size=8192):
|
| 49 |
-
if chunk:
|
| 50 |
-
model_file.write(chunk)
|
| 51 |
-
break
|
| 52 |
-
else:
|
| 53 |
-
compare_versions = sorted([latest_version, version], key=lambda s: list(map(int, s.lstrip("v").split("."))), reverse=True)
|
| 54 |
-
newer_version = compare_versions[0]
|
| 55 |
-
if newer_version != latest_version:
|
| 56 |
-
gd_model_path = str(folder_path / newer_version / "gap_detection_model.pt")
|
| 57 |
-
with requests.get(GAP_MODEL_URL, stream=True, timeout=60) as response:
|
| 58 |
-
with open(gd_model_path, "wb") as model_file:
|
| 59 |
-
for chunk in response.iter_content(chunk_size=8192):
|
| 60 |
-
if chunk:
|
| 61 |
-
model_file.write(chunk)
|
| 62 |
-
break
|
| 63 |
-
else:
|
| 64 |
-
gd_model_path = str(model_path)
|
| 65 |
-
|
| 66 |
-
return gd_model_path
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
def url_exists(url: str, timeout: float = 5.0) -> bool:
|
| 70 |
-
try:
|
| 71 |
-
r = requests.head(url, allow_redirects=True, timeout=timeout)
|
| 72 |
-
return (200 <= r.status_code < 400)
|
| 73 |
-
except requests.RequestException as e:
|
| 74 |
-
return False
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
def _is_allowed_image(filename: str) -> bool:
|
| 78 |
-
return "." in filename and filename.rsplit(".", 1)[1].lower() in ALLOWED_EXTENSIONS
|
| 79 |
-
|
| 80 |
-
|
| 81 |
def solve_worksheet(image_path: str):
|
| 82 |
if not image_path:
|
| 83 |
raise gr.Error("Please upload an image first.")
|
| 84 |
|
| 85 |
-
if not _is_allowed_image(image_path):
|
| 86 |
-
raise gr.Error("Please upload a valid image file (PNG, JPG, JPEG, WEBP, BMP).")
|
| 87 |
-
|
| 88 |
with tempfile.TemporaryDirectory() as tmp_dir:
|
| 89 |
unique_id = uuid.uuid4().hex
|
| 90 |
input_path = os.path.join(tmp_dir, f"{unique_id}.png")
|
|
@@ -95,7 +24,6 @@ def solve_worksheet(image_path: str):
|
|
| 95 |
|
| 96 |
solver = WorksheetSolver(
|
| 97 |
input_path,
|
| 98 |
-
gap_detection_model_path=MODEL_PATH,
|
| 99 |
llm_model_name="gemini-3-flash-preview",
|
| 100 |
think=True,
|
| 101 |
local=False,
|
|
@@ -122,7 +50,6 @@ def solve_worksheet(image_path: str):
|
|
| 122 |
except Exception as error:
|
| 123 |
raise gr.Error(f"Processing error: {error}") from error
|
| 124 |
|
| 125 |
-
|
| 126 |
def build_app() -> gr.Blocks:
|
| 127 |
with gr.Blocks(title="Worksheet Solver", css="""
|
| 128 |
.app-shell {max-width: 1200px; margin: 0 auto;}
|
|
@@ -160,10 +87,6 @@ def build_app() -> gr.Blocks:
|
|
| 160 |
|
| 161 |
return demo
|
| 162 |
|
| 163 |
-
ALLOWED_EXTENSIONS = {"png", "jpg", "jpeg", "webp", "bmp"}
|
| 164 |
-
RELEASES_URL = "https://github.com/Hawk3388/solver/releases"
|
| 165 |
-
MODEL_PATH = get_gap_model()
|
| 166 |
-
|
| 167 |
demo = build_app()
|
| 168 |
|
| 169 |
if __name__ == "__main__":
|
|
|
|
| 2 |
import tempfile
|
| 3 |
import uuid
|
| 4 |
import warnings
|
|
|
|
| 5 |
|
| 6 |
import gradio as gr
|
|
|
|
| 7 |
from PIL import Image
|
|
|
|
| 8 |
|
| 9 |
from main import WorksheetSolver
|
| 10 |
|
| 11 |
warnings.filterwarnings("ignore")
|
| 12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
def solve_worksheet(image_path: str):
|
| 14 |
if not image_path:
|
| 15 |
raise gr.Error("Please upload an image first.")
|
| 16 |
|
|
|
|
|
|
|
|
|
|
| 17 |
with tempfile.TemporaryDirectory() as tmp_dir:
|
| 18 |
unique_id = uuid.uuid4().hex
|
| 19 |
input_path = os.path.join(tmp_dir, f"{unique_id}.png")
|
|
|
|
| 24 |
|
| 25 |
solver = WorksheetSolver(
|
| 26 |
input_path,
|
|
|
|
| 27 |
llm_model_name="gemini-3-flash-preview",
|
| 28 |
think=True,
|
| 29 |
local=False,
|
|
|
|
| 50 |
except Exception as error:
|
| 51 |
raise gr.Error(f"Processing error: {error}") from error
|
| 52 |
|
|
|
|
| 53 |
def build_app() -> gr.Blocks:
|
| 54 |
with gr.Blocks(title="Worksheet Solver", css="""
|
| 55 |
.app-shell {max-width: 1200px; margin: 0 auto;}
|
|
|
|
| 87 |
|
| 88 |
return demo
|
| 89 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
demo = build_app()
|
| 91 |
|
| 92 |
if __name__ == "__main__":
|
main.py
CHANGED
|
@@ -10,6 +10,8 @@ from PIL import Image, ImageDraw, ImageFont
|
|
| 10 |
import numpy as np
|
| 11 |
from ultralytics import YOLO
|
| 12 |
from pathlib import Path
|
|
|
|
|
|
|
| 13 |
|
| 14 |
# Define Pydantic models outside the class
|
| 15 |
class Pair(BaseModel):
|
|
@@ -20,8 +22,11 @@ class get_solution(BaseModel):
|
|
| 20 |
solutions: List[Pair]
|
| 21 |
|
| 22 |
class WorksheetSolver():
|
| 23 |
-
def __init__(self, path:str, gap_detection_model_path: str = "
|
| 24 |
-
|
|
|
|
|
|
|
|
|
|
| 25 |
self.model_name = llm_model_name
|
| 26 |
self.local = local
|
| 27 |
self.path = path
|
|
@@ -30,6 +35,15 @@ class WorksheetSolver():
|
|
| 30 |
self.thinking_budget = thinking_budget
|
| 31 |
self.think = think
|
| 32 |
self.experimental = experimental
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
|
| 34 |
if self.debug:
|
| 35 |
import time
|
|
@@ -39,11 +53,16 @@ class WorksheetSolver():
|
|
| 39 |
print(f"💡 Please check the path to the image and try again.")
|
| 40 |
exit()
|
| 41 |
else:
|
| 42 |
-
if
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
if not Path(self.model_path).exists():
|
| 48 |
print(f"❌ Trained model not found: {self.model_path}")
|
| 49 |
print(f"💡 Run train_yolo.py first!")
|
|
@@ -57,11 +76,9 @@ class WorksheetSolver():
|
|
| 57 |
elif os.getenv("GOOGLE_API_KEY"):
|
| 58 |
self.client = genai.Client(api_key=os.getenv("GOOGLE_API_KEY"))
|
| 59 |
else:
|
| 60 |
-
|
| 61 |
-
print(f"💡 Please create a .env file with your Google API key as GOOGLE_API_KEY=your_key and try again.")
|
| 62 |
except Exception:
|
| 63 |
-
|
| 64 |
-
print(f"💡 Please create a .env file with your Google API key as GOOGLE_API_KEY=your_key and try again.")
|
| 65 |
if self.experimental and self.local:
|
| 66 |
|
| 67 |
from transformers.generation import LogitsProcessor
|
|
@@ -141,14 +158,6 @@ class WorksheetSolver():
|
|
| 141 |
|
| 142 |
self.model = YOLO(self.model_path)
|
| 143 |
|
| 144 |
-
self.image = None
|
| 145 |
-
self.detected_gaps = []
|
| 146 |
-
self.gap_groups = [] # Groups of gap indices
|
| 147 |
-
self.gap_to_group = {} # Maps gap index to group index
|
| 148 |
-
self.ungrouped_gap_indices = []
|
| 149 |
-
self.answer_units = [] # Line groups + single ungrouped boxes
|
| 150 |
-
self.gap_to_answer_unit = {} # Maps any gap index to answer unit index
|
| 151 |
-
|
| 152 |
def load_image(self, image_path: str):
|
| 153 |
"""Load image and create a copy for processing"""
|
| 154 |
self.image = cv2.imread(image_path)
|
|
@@ -156,6 +165,70 @@ class WorksheetSolver():
|
|
| 156 |
raise FileNotFoundError(f"Image {image_path} not found!")
|
| 157 |
return self.image.copy()
|
| 158 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 159 |
def calculate_iou(self, box1: list, box2: list):
|
| 160 |
"""
|
| 161 |
Calculates Intersection over Union (IoU) between two boxes
|
|
@@ -815,7 +888,7 @@ def main():
|
|
| 815 |
# For Ollama models you have to set local=True
|
| 816 |
|
| 817 |
path = input("📂 Please enter the path to the worksheet image: ").strip()
|
| 818 |
-
llm_model_name = "
|
| 819 |
think = True
|
| 820 |
local = True
|
| 821 |
debug = True
|
|
|
|
| 10 |
import numpy as np
|
| 11 |
from ultralytics import YOLO
|
| 12 |
from pathlib import Path
|
| 13 |
+
import re
|
| 14 |
+
import requests
|
| 15 |
|
| 16 |
# Define Pydantic models outside the class
|
| 17 |
class Pair(BaseModel):
|
|
|
|
| 22 |
solutions: List[Pair]
|
| 23 |
|
| 24 |
class WorksheetSolver():
|
| 25 |
+
def __init__(self, path:str, gap_detection_model_path: str = "", llm_model_name: str = "gemini-2.5-flash", think: bool = True, local: bool = False, thinking_budget: int = 2048, debug: bool = False, experimental: bool = False):
|
| 26 |
+
if gap_detection_model_path:
|
| 27 |
+
self.model_path = gap_detection_model_path
|
| 28 |
+
else:
|
| 29 |
+
self.model_path = self.get_gap_model()
|
| 30 |
self.model_name = llm_model_name
|
| 31 |
self.local = local
|
| 32 |
self.path = path
|
|
|
|
| 35 |
self.thinking_budget = thinking_budget
|
| 36 |
self.think = think
|
| 37 |
self.experimental = experimental
|
| 38 |
+
|
| 39 |
+
self.image = None
|
| 40 |
+
self.allowed_extensions = {'png', 'jpg', 'jpeg', 'webp', 'bmp'}
|
| 41 |
+
self.detected_gaps = []
|
| 42 |
+
self.gap_groups = [] # Groups of gap indices
|
| 43 |
+
self.gap_to_group = {} # Maps gap index to group index
|
| 44 |
+
self.ungrouped_gap_indices = []
|
| 45 |
+
self.answer_units = [] # Line groups + single ungrouped boxes
|
| 46 |
+
self.gap_to_answer_unit = {} # Maps any gap index to answer unit index
|
| 47 |
|
| 48 |
if self.debug:
|
| 49 |
import time
|
|
|
|
| 53 |
print(f"💡 Please check the path to the image and try again.")
|
| 54 |
exit()
|
| 55 |
else:
|
| 56 |
+
if self.is_allowed_image(self.path):
|
| 57 |
+
if not self.path.lower().endswith(".png"):
|
| 58 |
+
print(f"✅ Worksheet image found: {self.path}")
|
| 59 |
+
img = Image.open(self.path)
|
| 60 |
+
img.save(f"{Path(self.path).stem}_temp.png")
|
| 61 |
+
self.path = f"{Path(self.path).stem}_temp.png"
|
| 62 |
+
else:
|
| 63 |
+
print(f"❌ Invalid file type: {self.path}")
|
| 64 |
+
print(f"💡 Please upload an image file with one of the following extensions: {', '.join(self.allowed_extensions)}")
|
| 65 |
+
exit()
|
| 66 |
if not Path(self.model_path).exists():
|
| 67 |
print(f"❌ Trained model not found: {self.model_path}")
|
| 68 |
print(f"💡 Run train_yolo.py first!")
|
|
|
|
| 76 |
elif os.getenv("GOOGLE_API_KEY"):
|
| 77 |
self.client = genai.Client(api_key=os.getenv("GOOGLE_API_KEY"))
|
| 78 |
else:
|
| 79 |
+
raise ValueError("❌ .env file with Google API key not found!\n💡 Please create a .env file with your Google API key as GOOGLE_API_KEY=your_key and try again.")
|
|
|
|
| 80 |
except Exception:
|
| 81 |
+
raise ValueError("❌ .env file with Google API key not found!\n💡 Please create a .env file with your Google API key as GOOGLE_API_KEY=your_key and try again.")
|
|
|
|
| 82 |
if self.experimental and self.local:
|
| 83 |
|
| 84 |
from transformers.generation import LogitsProcessor
|
|
|
|
| 158 |
|
| 159 |
self.model = YOLO(self.model_path)
|
| 160 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 161 |
def load_image(self, image_path: str):
|
| 162 |
"""Load image and create a copy for processing"""
|
| 163 |
self.image = cv2.imread(image_path)
|
|
|
|
| 165 |
raise FileNotFoundError(f"Image {image_path} not found!")
|
| 166 |
return self.image.copy()
|
| 167 |
|
| 168 |
+
def get_gap_model(self) -> str:
    """Return the path of the gap-detection model, downloading it if needed.

    Scans ``./model`` for version-named sub-folders that already contain
    ``gap_detection_model.pt``, reads the published versions from the GitHub
    releases page and downloads the first release whose model file is
    reachable when there is no local copy or the release is newer.

    Returns:
        Path of the model file to load, as a string.

    Raises:
        Exception: If the releases page cannot be fetched, lists no
            versions, or no release offers the model file and there is no
            local copy to fall back to.
        requests.RequestException: If the releases request or the model
            download fails at the HTTP level.
    """
    releases_url = "https://github.com/Hawk3388/solver/releases"
    model_filename = "gap_detection_model.pt"

    def version_key(name: str) -> list:
        # "v1.2.3" -> [1, 2, 3] so versions compare numerically, not as text.
        return [int(part) for part in name.lstrip("v").split(".")]

    def is_version(name: str) -> bool:
        # Unrelated folders under ./model must not crash the version sort.
        try:
            version_key(name)
        except ValueError:
            return False
        return True

    def download(url: str, destination: Path) -> None:
        # The version folder does not exist yet on a fresh download.
        destination.parent.mkdir(parents=True, exist_ok=True)
        # Write to a side file first so an interrupted download can never be
        # mistaken for a complete model on the next start.
        partial = destination.with_name(destination.name + ".part")
        try:
            with requests.get(url, stream=True, timeout=60) as response:
                # Do not store an HTML error page as the model file.
                response.raise_for_status()
                with open(partial, "wb") as model_file:
                    for chunk in response.iter_content(chunk_size=8192):
                        if chunk:
                            model_file.write(chunk)
            os.replace(partial, destination)
        finally:
            if partial.exists():
                partial.unlink()

    folder_path = Path("./model")
    os.makedirs(folder_path, exist_ok=True)

    # Only folders that actually hold a model count as installed versions.
    local_versions = [
        entry.name
        for entry in folder_path.iterdir()
        if entry.is_dir()
        and is_version(entry.name)
        and (entry / model_filename).exists()
    ]
    latest_version = max(local_versions, key=version_key) if local_versions else None
    local_model = folder_path / latest_version / model_filename if latest_version else None

    release_response = requests.get(releases_url, timeout=30)
    if release_response.status_code != 200:
        raise Exception(f"Failed to fetch releases from GitHub: {release_response.status_code}")
    versions = re.findall(r"<h2[^>]*>(v\d+\.\d+\.\d+)</h2>", release_response.text)
    if not versions:
        raise Exception("Could not determine the latest model version from GitHub releases.")

    for version in versions:
        gap_model_url = f"https://github.com/Hawk3388/solver/releases/download/{version}/{model_filename}"
        if not self.url_exists(gap_model_url):
            continue
        # The installed model is kept unless this release is strictly newer.
        if local_model is not None and version_key(version) <= version_key(latest_version):
            return str(local_model)
        target = folder_path / version / model_filename
        download(gap_model_url, target)
        return str(target)

    # No listed release exposes the model file: fall back to the local copy.
    if local_model is not None:
        return str(local_model)
    raise Exception("No GitHub release provides a downloadable gap detection model.")
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
def url_exists(self, url: str, timeout: float = 5.0) -> bool:
    """Report whether *url* answers a HEAD request with a 2xx or 3xx status.

    Redirects are followed. Any ``requests`` error is treated as "the URL
    does not exist" and yields ``False``.
    """
    try:
        head_response = requests.head(url, allow_redirects=True, timeout=timeout)
    except requests.RequestException:
        return False
    else:
        status = head_response.status_code
        return status >= 200 and status < 400
|
| 228 |
+
|
| 229 |
+
def is_allowed_image(self, filename: str) -> bool:
    """Tell whether *filename* ends in one of ``self.allowed_extensions``.

    The extension is the text after the last dot, compared case-insensitively.
    A name without any dot is never allowed.
    """
    _stem, dot, extension = filename.rpartition(".")
    if not dot:
        return False
    return extension.lower() in self.allowed_extensions
|
| 231 |
+
|
| 232 |
def calculate_iou(self, box1: list, box2: list):
|
| 233 |
"""
|
| 234 |
Calculates Intersection over Union (IoU) between two boxes
|
|
|
|
| 888 |
# For Ollama models you have to set local=True
|
| 889 |
|
| 890 |
path = input("📂 Please enter the path to the worksheet image: ").strip()
|
| 891 |
+
llm_model_name = "gemma4:26b"
|
| 892 |
think = True
|
| 893 |
local = True
|
| 894 |
debug = True
|