| import tkinter as tk
|
| from tkinter import font as tkfont, ttk, messagebox, filedialog
|
| from PIL import Image, ImageDraw, ImageTk
|
| import numpy as np
|
| import onnxruntime as ort
|
| import cv2
|
| import random
|
| import webbrowser
|
| import sqlite3
|
| import io
|
| import datetime
|
| import sys
|
| import os
|
| from win32event import CreateMutex
|
| from win32api import GetLastError
|
| from winerror import ERROR_ALREADY_EXISTS
|
|
|
|
|
|
|
|
|
def ensure_single_instance():
    """Create the application's named mutex, exiting if it already exists.

    When ``GetLastError()`` reports ``ERROR_ALREADY_EXISTS`` right after
    ``CreateMutex`` a warning dialog is shown and the process exits with
    status 0.

    Returns:
        The handle returned by ``CreateMutex``. The caller stores it in a
        variable for the rest of the run (presumably so the handle stays
        open -- confirm against the pywin32 handle semantics).
    """
    mutex = CreateMutex(None, False, 'Global\\AmharicOCR_SingleInstance_Mutex')
    if GetLastError() == ERROR_ALREADY_EXISTS:
        # No Tk root exists yet at this point.  Without an explicit parent,
        # tkinter creates a visible, empty default root window next to the
        # dialog, so use a withdrawn root as the parent instead.
        hidden_root = tk.Tk()
        hidden_root.withdraw()
        try:
            messagebox.showwarning(
                "Already Running",
                "Amharic OCR is already running!",
                parent=hidden_root,
            )
        finally:
            hidden_root.destroy()
        sys.exit(0)
    return mutex
|
|
|
|
|
|
|
|
|
def resource_path(relative_path):
    """Return *relative_path* joined onto the application's base directory.

    The base directory is ``sys._MEIPASS`` when that attribute exists (the
    PyInstaller case, per the original docstring) and the absolute form of
    the current working directory otherwise.
    """
    if hasattr(sys, "_MEIPASS"):
        root_dir = sys._MEIPASS
    else:
        root_dir = os.path.abspath(".")
    return os.path.join(root_dir, relative_path)
|
|
|
def get_db_path():
    """Return the history database path inside ``~/AmharicOCR``.

    The ``AmharicOCR`` directory under the user's home directory is created
    when missing.

    Returns:
        Path of ``amharic_ocr_history.db`` inside that directory.
    """
    app_dir = os.path.join(os.path.expanduser("~"), "AmharicOCR")
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(app_dir, exist_ok=True)
    return os.path.join(app_dir, "amharic_ocr_history.db")
|
|
|
def get_model_path():
    """Return the location of ``cnn_output.onnx`` as resolved by ``resource_path``."""
    model_filename = 'cnn_output.onnx'
    return resource_path(model_filename)
|
|
|
|
|
|
|
|
|
class HistoryDB:
    """SQLite store for prediction history (image, character, probability, flag).

    Every row handed to callers has the fixed column order
    ``(id, image_data, predicted_char, probability, is_correct, timestamp)``.
    """

    # Columns are always selected by name.  ``ALTER TABLE ... ADD COLUMN``
    # appends the new column, so on a database migrated by update_schema()
    # ``SELECT *`` would return ``is_correct`` *after* ``timestamp`` and
    # positional indexing would mix the two up.
    _SELECT_ROWS = (
        "SELECT id, image_data, predicted_char, probability, is_correct, timestamp "
        "FROM history"
    )

    def __init__(self, db_name=None):
        """Open (or create) the database and make sure the schema is current.

        Args:
            db_name: Database file name; defaults to ``get_db_path()``.
        """
        if db_name is None:
            db_name = get_db_path()
        self.conn = sqlite3.connect(db_name, check_same_thread=False)
        self.create_table()
        self.update_schema()

    def create_table(self):
        """Create the ``history`` table when it does not exist yet."""
        cursor = self.conn.cursor()
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS history (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                image_data BLOB,
                predicted_char TEXT,
                probability REAL,
                is_correct INTEGER DEFAULT 1,
                timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
            )
        ''')
        self.conn.commit()

    def update_schema(self):
        """Add the ``is_correct`` column to older tables and backfill NULLs with 1.

        Failures are printed rather than raised, as in the original code.
        """
        cursor = self.conn.cursor()
        cursor.execute("PRAGMA table_info(history)")
        columns = [info[1] for info in cursor.fetchall()]

        if 'is_correct' not in columns:
            try:
                cursor.execute("ALTER TABLE history ADD COLUMN is_correct INTEGER DEFAULT 1")
                self.conn.commit()
            except Exception as e:
                print(f"Error adding column: {e}")

        try:
            cursor.execute("UPDATE history SET is_correct = 1 WHERE is_correct IS NULL")
            self.conn.commit()
        except Exception as e:
            print(f"Error updating records: {e}")

    def add_record(self, image_pil, char, prob, limit):
        """Insert one prediction and trim the table to the newest *limit* rows.

        Args:
            image_pil: Object with a ``save(fp, format=...)`` method; it is
                stored as PNG bytes.
            char: Predicted character.
            prob: Probability stored with the prediction.
            limit: Number of most recent rows to keep.

        Returns:
            The id of the inserted row.
        """
        png_buffer = io.BytesIO()
        image_pil.save(png_buffer, format='PNG')

        cursor = self.conn.cursor()
        cursor.execute(
            "INSERT INTO history (image_data, predicted_char, probability, is_correct) VALUES (?, ?, ?, 1)",
            (png_buffer.getvalue(), char, prob),
        )
        # Read the id immediately so it cannot be affected by the DELETE below.
        new_id = cursor.lastrowid
        self.conn.commit()

        cursor.execute(
            "DELETE FROM history WHERE id NOT IN (SELECT id FROM history ORDER BY id DESC LIMIT ?)",
            (limit,),
        )
        self.conn.commit()
        return new_id

    def toggle_correctness(self, record_id):
        """Flip ``is_correct`` for *record_id*.

        Returns:
            The new status (0 or 1), or ``None`` when the row does not exist.
        """
        cursor = self.conn.cursor()
        cursor.execute("SELECT is_correct FROM history WHERE id = ?", (record_id,))
        result = cursor.fetchone()
        if not result:
            return None
        new_status = 0 if result[0] == 1 else 1
        cursor.execute("UPDATE history SET is_correct = ? WHERE id = ?", (new_status, record_id))
        self.conn.commit()
        return new_status

    def delete_record(self, record_id):
        """Delete the row with id *record_id* (no error when it is absent)."""
        cursor = self.conn.cursor()
        cursor.execute("DELETE FROM history WHERE id = ?", (record_id,))
        self.conn.commit()

    def get_page(self, page_num, per_page=10):
        """Return one page of rows, newest first, plus the total row count.

        Args:
            page_num: 1-based page number.
            per_page: Rows per page.

        Returns:
            ``(rows, total_count)`` where each row is
            ``(id, image_data, predicted_char, probability, is_correct, timestamp)``.
        """
        offset = max(page_num - 1, 0) * per_page
        cursor = self.conn.cursor()
        cursor.execute(
            self._SELECT_ROWS + " ORDER BY id DESC LIMIT ? OFFSET ?",
            (per_page, offset),
        )
        rows = cursor.fetchall()
        cursor.execute("SELECT count(*) FROM history")
        total_count = cursor.fetchone()[0]
        return rows, total_count

    @staticmethod
    def _sql_literal(value):
        """Render *value* as a SQL literal (NULL, hex blob, number or quoted text)."""
        if value is None:
            return "NULL"
        if isinstance(value, (bytes, bytearray, memoryview)):
            return "x'" + bytes(value).hex() + "'"
        if isinstance(value, (int, float)):
            return str(value)
        return "'" + str(value).replace("'", "''") + "'"

    def export_to_sql(self):
        """Return the whole table as a script of ``INSERT`` statements.

        The script starts with two comment lines, ``PRAGMA synchronous=OFF;``
        and ``BEGIN TRANSACTION;`` and ends with ``COMMIT;``.  NULL columns
        are written as ``NULL`` instead of raising.
        """
        cursor = self.conn.cursor()
        cursor.execute(self._SELECT_ROWS)
        rows = cursor.fetchall()

        sql_lines = [
            "-- Amharic OCR Dataset Export\n",
            "-- Generated on: " + str(datetime.datetime.now()) + "\n\n",
            "PRAGMA synchronous=OFF;\n",
            "BEGIN TRANSACTION;\n",
        ]
        for row in rows:
            values = ", ".join(self._sql_literal(column) for column in row)
            sql_lines.append(
                "INSERT INTO history (id, image_data, predicted_char, probability, is_correct, timestamp) "
                f"VALUES ({values});\n"
            )
        sql_lines.append("COMMIT;\n")
        return "".join(sql_lines)

    def close(self):
        """Close the underlying SQLite connection."""
        self.conn.close()
|
|
|
|
|
|
|
|
|
def flatten(l):
    """Return the leaves of an arbitrarily nested ``list`` as one flat list.

    Only objects whose type is exactly ``list`` are descended into; anything
    else (tuples, strings, list subclasses, ...) is kept as a single leaf.
    A non-list argument yields ``[l]`` and empty lists contribute nothing.

    The traversal is iterative, so its depth is bounded by the nesting depth
    rather than by the number of elements (the recursive version hit the
    recursion limit on long lists and copied a slice per element).
    """
    if type(l) is not list:
        return [l]

    flat = []
    pending = [iter(l)]  # one iterator per list currently being walked
    while pending:
        for item in pending[-1]:
            if type(item) is list:
                # Descend; the outer iterator resumes once this one is done.
                pending.append(iter(item))
                break
            flat.append(item)
        else:
            pending.pop()
    return flat
|
|
|
def fidel(bet):
    """Return the list of characters for column index *bet*.

    * ``bet`` below 8: 34 characters produced by walking code points.  The
      walk starts at ``4608 + (bet - 1)`` for ``bet`` 1..7 and at 4608 for
      any other value below 8; the step between consecutive entries is 8
      except after the positions listed in ``step_after``.
    * ``bet`` 8..12: a copy of the corresponding hard-coded character list.
    * anything else: an empty list.
    """
    fixed_rows = {
        8: ['ቈ','ኈ','ኰ','ዀ','ጐ'],
        9: ['ሏ','ሗ','ሟ','ሧ','ሯ','ሷ','ሿ','ቊ','ቧ','ቯ',
            'ቷ','ቿ','ኊ','ኗ','ኟ','ኧ','ኲ','ዂ','ዟ','ዧ',
            'ዷ','ጇ','ጒ','ጧ','ጯ','ጷ','ጿ','ፏ','ፗ'],
        10: ['ቋ','ኋ','ኳ','ዃ','ጓ'],
        11: ['ቌ','ኌ','ኴ','ዄ','ጔ'],
        12: ['ቍ','ኍ','ኵ','ዅ','ጕ'],
    }
    first_code_point = {order: 4608 + (order - 1) for order in range(1, 8)}
    # Step applied *after* emitting the character at the given position;
    # every other position advances by 8.
    step_after = {8: 32, 13: 16, 17: 16, 18: 16, 24: 16, 26: 24}

    fideloch = list(fixed_rows.get(bet, ()))
    if bet < 8:
        code_point = first_code_point.get(bet, 4608)
        for position in range(34):
            fideloch.append(chr(code_point))
            code_point += step_after.get(position, 8)
    return fideloch
|
|
|
def fidel_form(min_bet, max_bet):
    """Return ``[fidel(b) for b in range(min_bet, max_bet)]`` (upper bound exclusive)."""
    return [fidel(order) for order in range(min_bet, max_bet)]
|
|
|
# Flat class-id -> character table, built once at import time from the
# character lists for bet 1..12.
laters = flatten(fidel_form(1, 13))


def get_char_from_id(class_id):
    """Return the character at index *class_id* in ``laters``, or ``"?"`` when out of range."""
    in_range = 0 <= class_id < len(laters)
    return laters[class_id] if in_range else "?"
|
|
|
|
|
|
|
|
|
def ensure_shape(img):
    """Return *img* reshaped to ``(32, 32, 1)`` and converted to float32."""
    target_shape = (32, 32, 1)
    reshaped = img.reshape(target_shape)
    return reshaped.astype(np.float32)
|
|
|
def erode_image(img):
    """Apply one ``cv2.erode`` pass with a 2x2 kernel of ones.

    The input is scaled by 255 and cast to uint8 for OpenCV (so values are
    presumably expected in [0, 1] -- confirm with callers); the result is
    scaled back by 1/255 and returned through ``ensure_shape``.
    """
    structuring_element = np.ones((2, 2), np.uint8)
    as_bytes = (img * 255).astype(np.uint8)
    processed = cv2.erode(as_bytes, structuring_element, iterations=1)
    rescaled = processed.astype('float32') / 255.0
    return ensure_shape(rescaled)
|
|
|
def dilate_image(img):
    """Apply one ``cv2.dilate`` pass with a 2x2 kernel of ones.

    The input is scaled by 255 and cast to uint8 for OpenCV (so values are
    presumably expected in [0, 1] -- confirm with callers); the result is
    scaled back by 1/255 and returned through ``ensure_shape``.
    """
    structuring_element = np.ones((2, 2), np.uint8)
    as_bytes = (img * 255).astype(np.uint8)
    processed = cv2.dilate(as_bytes, structuring_element, iterations=1)
    rescaled = processed.astype('float32') / 255.0
    return ensure_shape(rescaled)
|
|
|
def zoom_image(img):
    """Zoom *img* by a factor picked at random from 0.9 and 1.1.

    * Factor 1.1: a centred crop of ``int(size / factor)`` pixels per side
      is resized back to the original width and height.
    * Factor 0.9: the image is resized to ``int(size * factor)`` and padded
      back to the original size with ``cv2.BORDER_REFLECT_101``.

    The result is returned through ``ensure_shape``.
    """
    factor = random.choice([0.9, 1.1])
    height, width = img.shape[:2]

    if factor <= 1.0:
        # Zoom out: shrink, then pad evenly (extra pixel goes bottom/right).
        small_h, small_w = int(height * factor), int(width * factor)
        shrunk = cv2.resize(img, (small_w, small_h))
        top = (height - small_h) // 2
        bottom = height - small_h - top
        left = (width - small_w) // 2
        right = width - small_w - left
        result = cv2.copyMakeBorder(shrunk, top, bottom, left, right, cv2.BORDER_REFLECT_101)
    else:
        # Zoom in: centred crop, then scale back up.
        crop_h, crop_w = int(height / factor), int(width / factor)
        row0 = (height - crop_h) // 2
        col0 = (width - crop_w) // 2
        window = img[row0:row0 + crop_h, col0:col0 + crop_w]
        result = cv2.resize(window, (width, height))

    return ensure_shape(result)
|
|
|
def shift_brightness(img):
    """Add -0.1 or +0.1 (picked at random) to *img* and clip the result to [0, 1].

    The result is returned through ``ensure_shape``.
    """
    delta = random.choice([-0.1, 0.1])
    shifted = img.astype('float32') + delta
    return ensure_shape(np.clip(shifted, 0, 1))
|
|
|
def blur_image_light(img):
    """Gaussian-blur *img* with a 3x3 kernel (sigma 0, i.e. derived by OpenCV).

    The input is scaled by 255 and cast to uint8 before blurring, then
    scaled back by 1/255 and returned through ``ensure_shape``.
    """
    kernel_size = (3, 3)
    as_bytes = (img * 255).astype(np.uint8)
    smoothed = cv2.GaussianBlur(as_bytes, kernel_size, 0)
    rescaled = smoothed.astype('float32') / 255.0
    return ensure_shape(rescaled)
|
|
|
def blur_image_heavy(img):
    """Gaussian-blur *img* with a 7x7 kernel (sigma 0, i.e. derived by OpenCV).

    The input is scaled by 255 and cast to uint8 before blurring, then
    scaled back by 1/255 and returned through ``ensure_shape``.
    """
    kernel_size = (7, 7)
    as_bytes = (img * 255).astype(np.uint8)
    smoothed = cv2.GaussianBlur(as_bytes, kernel_size, 0)
    rescaled = smoothed.astype('float32') / 255.0
    return ensure_shape(rescaled)
|
|
|
def blur_image_moderate(img):
    """Gaussian-blur *img* with a 5x5 kernel (sigma 0, i.e. derived by OpenCV).

    The input is scaled by 255 and cast to uint8 before blurring, then
    scaled back by 1/255 and returned through ``ensure_shape``.
    """
    kernel_size = (5, 5)
    as_bytes = (img * 255).astype(np.uint8)
    smoothed = cv2.GaussianBlur(as_bytes, kernel_size, 0)
    rescaled = smoothed.astype('float32') / 255.0
    return ensure_shape(rescaled)
|
|
|
|
|
|
|
|
|
class SplashScreen:
    """Borderless, screen-centred loading window with a status line and a progress bar."""

    def __init__(self, parent_root):
        """Create and immediately draw the splash as a ``Toplevel`` of *parent_root*."""
        width, height = 400, 150
        background = "#2196F3"

        window = tk.Toplevel(parent_root)
        self.splash = window
        window.title("Loading...")
        window.geometry(f"{width}x{height}")
        window.resizable(False, False)
        window.overrideredirect(True)

        # Centre the window on the screen.
        window.update_idletasks()
        x = (window.winfo_screenwidth() // 2) - (width // 2)
        y = (window.winfo_screenheight() // 2) - (height // 2)
        window.geometry(f"{width}x{height}+{x}+{y}")

        main_frame = tk.Frame(window, bg=background)
        main_frame.pack(fill=tk.BOTH, expand=True)

        title_label = tk.Label(
            main_frame, text="Geez OCR", font=("Arial", 20, "bold"),
            bg=background, fg="white",
        )
        title_label.pack(pady=(20, 10))

        self.message_label = tk.Label(
            main_frame, text="Geez OCR library loading in background...",
            font=("Arial", 11), bg=background, fg="white",
        )
        self.message_label.pack(pady=5)

        progress_frame = tk.Frame(main_frame, bg=background)
        progress_frame.pack(pady=10, padx=40, fill=tk.X)

        self.progress_bar = ttk.Progressbar(progress_frame, mode='indeterminate', length=320)
        self.progress_bar.pack(fill=tk.X)
        self.progress_bar.start(10)

        window.update()

    def update_message(self, text):
        """Replace the status line with *text* and redraw the splash."""
        self.message_label.config(text=text)
        self.splash.update()

    def destroy(self):
        """Stop the progress animation and close the splash window."""
        self.progress_bar.stop()
        self.splash.destroy()
|
|
|
|
|
|
|
|
|
class DrawingApp:
    """Main window: drawing canvas, test-time-augmentation analysis and history browser.

    The constructor only builds widgets; ``load_model()`` and
    ``load_history()`` are called separately by the start-up code.
    """

    # Rows per history page; used for both the query and the page count.
    HISTORY_PER_PAGE = 10
    # Label texts that are status placeholders rather than recognised characters.
    _NOT_COPYABLE = ("?", "...", "Err", "Processing...", "")
    # Temporary feedback texts shown after a copy.
    _COPIED_CHAR_TEXT = "Copied!"
    _COPIED_METHOD_TEXT = "Copied to clipboard!"

    def __init__(self, root):
        """Build the three panels inside *root* and open the history database."""
        self.root = root
        self.root.minsize(1400, 800)

        self.db = HistoryDB()
        self.current_history_page = 1
        self.history_limit = 100
        # Last page count computed by load_history(); bounds next_page().
        self._total_history_pages = 1

        # Set by load_model(); stay None when loading fails.
        self.session = None
        self.input_name = None

        root.grid_rowconfigure(0, weight=1)
        root.grid_columnconfigure(0, weight=4)
        root.grid_columnconfigure(1, weight=3)
        root.grid_columnconfigure(2, weight=3)

        self._build_left_panel(root)
        self._build_right_panel(root)
        self._build_history_panel(root)

    # ------------------------------------------------------------------
    # Widget construction
    # ------------------------------------------------------------------
    def _show_placeholder(self):
        """Draw the grey "Draw here" hint in the middle of the canvas."""
        self.placeholder_tag = self.canvas.create_text(
            self.canvas_size/2, self.canvas_size/2, text="Draw here", fill="#e0e0e0",
            font=("Arial", 18, "bold"), tags="placeholder"
        )

    def _create_winner_block(self, parent, title):
        """Add one titled result cell to *parent*; return ``(char_label, info_label)``."""
        frame = tk.Frame(parent, bg="#e3f2fd")
        frame.pack(side=tk.LEFT, expand=True, fill=tk.BOTH, padx=3)

        tk.Label(frame, text=title, font=("Arial", 8, "bold"), bg="#e3f2fd", fg="#0d47a1").pack()

        lbl_char = tk.Label(frame, text="", font=("Amharic Unicode", 24), bg="#e3f2fd", fg="blue", cursor="hand2")
        lbl_char.pack()
        lbl_char.bind("<Button-1>", lambda e, l=lbl_char: self.copy_specific_char(l))

        lbl_info = tk.Label(frame, text="--", font=("Arial", 7), bg="#e3f2fd", fg="#555")
        lbl_info.pack()
        return lbl_char, lbl_info

    def _build_left_panel(self, root):
        """Create the canvas, the tool controls and the result labels (grid column 0)."""
        self.left_panel = tk.Frame(root, bg="white")
        self.left_panel.grid(row=0, column=0, sticky="nsew", padx=(10, 5), pady=10)
        for row_index in range(3):
            self.left_panel.grid_rowconfigure(row_index, weight=0)
        self.left_panel.grid_columnconfigure(0, weight=1)

        # Drawing canvas.
        self.canvas_size = 320
        self.canvas_container = tk.Frame(self.left_panel, width=self.canvas_size+10, height=self.canvas_size+10, bg="white")
        self.canvas_container.grid(row=0, column=0, sticky="")
        self.canvas_container.grid_propagate(False)

        self.canvas = tk.Canvas(self.canvas_container, width=self.canvas_size, height=self.canvas_size, bg="white", cursor="cross")
        self.canvas.pack(padx=5, pady=5)
        self._show_placeholder()

        self.strokes = []
        self.old_x = None
        self.old_y = None
        self.canvas.bind('<Button-1>', self.start_draw)
        self.canvas.bind('<B1-Motion>', self.draw)

        # Tool controls.
        self.controls_frame = tk.Frame(self.left_panel, bg="#f9f9f9", bd=1, relief=tk.RAISED)
        self.controls_frame.grid(row=1, column=0, sticky="ew", pady=5)

        self.pen_width_var = tk.IntVar(value=10)
        self.tool_mode = tk.StringVar(value="pen")

        c_inner = tk.Frame(self.controls_frame, bg="#f9f9f9")
        c_inner.pack(pady=10)

        tk.Radiobutton(c_inner, text="Pen", variable=self.tool_mode, value="pen", bg="#f9f9f9").grid(row=0, column=0, padx=5)
        tk.Radiobutton(c_inner, text="Eraser", variable=self.tool_mode, value="eraser", bg="#f9f9f9").grid(row=0, column=1, padx=5)

        tk.Label(c_inner, text="Size", bg="#f9f9f9").grid(row=0, column=2, padx=5)
        tk.Scale(c_inner, from_=7, to=13, orient=tk.HORIZONTAL, variable=self.pen_width_var, length=150).grid(row=0, column=3, padx=10)

        btn_frame = tk.Frame(c_inner, bg="#f9f9f9")
        btn_frame.grid(row=1, column=0, columnspan=4, pady=10)

        self.btn_predict = tk.Button(btn_frame, text="PREDICT", command=self.run_tta_analysis, bg="#2196F3", fg="white", font=("Arial", 10, "bold"), width=15)
        self.btn_predict.pack(side=tk.LEFT, padx=10)
        tk.Button(btn_frame, text="CLEAR", command=self.clear_canvas, bg="#f44336", fg="white", font=("Arial", 10, "bold"), width=15).pack(side=tk.LEFT, padx=10)

        # Prediction results.
        self.results_container = tk.Frame(self.left_panel, bg="white")
        self.results_container.grid(row=2, column=0, sticky="new", pady=10)

        self.lbl_final_char = tk.Label(self.results_container, text="?", font=("Amharic Unicode", 60, "bold"), fg="blue", cursor="hand2")
        self.lbl_final_char.pack(pady=5)
        self.lbl_final_char.bind("<Button-1>", self.copy_to_clipboard)

        self.lbl_final_method = tk.Label(self.results_container, text="Loading model...", font=("Arial", 10), fg="#555")
        self.lbl_final_method.pack(pady=(0, 5))

        self.winners_frame = tk.Frame(self.results_container, bg="#e3f2fd", bd=2, relief=tk.GROOVE)
        self.winners_frame.pack(fill=tk.X, padx=10, pady=5)

        self.w_lbl_acc, self.w_info_acc = self._create_winner_block(self.winners_frame, "Top Accuracy")
        self.w_lbl_unq, self.w_info_unq = self._create_winner_block(self.winners_frame, "Unique (High Conf)")
        self.w_lbl_vote, self.w_info_vote = self._create_winner_block(self.winners_frame, "Agreed Vote")

    def _build_right_panel(self, root):
        """Create the credits, user guide, note and augmentation grid (grid column 1)."""
        self.right_panel = tk.Frame(root, bg="#f0f0f0")
        self.right_panel.grid(row=0, column=1, sticky="nsew", padx=5, pady=10)
        for row_index in range(4):
            self.right_panel.grid_rowconfigure(row_index, weight=0)
        self.right_panel.grid_rowconfigure(4, weight=1)
        self.right_panel.grid_columnconfigure(0, weight=1)

        # Credits.
        self.credits_frame = tk.Frame(self.right_panel, bg="#f0f0f0", bd=2, relief=tk.GROOVE)
        self.credits_frame.grid(row=0, column=0, sticky="ew", pady=(0, 10))

        tk.Label(self.credits_frame, text="Made with ❤️ by Yared Kassa", font=("Arial", 12, "bold"), bg="#f0f0f0").pack(pady=2)

        contact_frame = tk.Frame(self.credits_frame, bg="#f0f0f0")
        contact_frame.pack(pady=2)

        email_lbl = tk.Label(contact_frame, text="yaredoffice@gmail.com", font=("Arial", 10, "bold"), fg="#0000EE", cursor="hand2", bg="#f0f0f0")
        email_lbl.pack(side=tk.LEFT, padx=5)
        email_lbl.bind("<Button-1>", lambda e: webbrowser.open("mailto:yaredoffice@gmail.com"))

        tg_lbl = tk.Label(contact_frame, text="https://t.me/yaredoffice", font=("Arial", 10, "bold"), fg="#0088cc", cursor="hand2", bg="#f0f0f0")
        tg_lbl.pack(side=tk.LEFT, padx=5)
        tg_lbl.bind("<Button-1>", lambda e: webbrowser.open("https://t.me/yaredoffice"))

        # User guide.
        guide_frame = tk.LabelFrame(self.right_panel, text="User Guide", font=("Arial", 9, "bold"), bg="#f0f0f0")
        guide_frame.grid(row=1, column=0, sticky="new", pady=(0, 5), padx=5)

        guide_text = (
            "⚠ IMPORTANT:\n"
            "1. Test SINGLE Geez characters only.\n"
            "2. Do NOT test full words.\n"
            "3. No Geez numbers support.\n"
            "4. Export SQL after 100+ tests to contribute."
        )
        tk.Label(guide_frame, text=guide_text, font=("Arial", 8), bg="#f0f0f0", justify=tk.LEFT, fg="#333", wraplength=380).pack(padx=8, pady=5, fill=tk.X)

        # Blue note with link.
        blue_note_frame = tk.Frame(self.right_panel, bg="#e3f2fd")
        blue_note_frame.grid(row=2, column=0, sticky="ew", padx=5, pady=(0,5))

        wrap_len = 360

        lbl_part1 = tk.Label(blue_note_frame, text="Word OCR coming soon. Model available at ", font=("Arial", 8), bg="#e3f2fd", fg="#0d47a1", wraplength=wrap_len, justify=tk.LEFT)
        lbl_part1.pack(anchor="w", padx=5, pady=2)

        link_frame = tk.Frame(blue_note_frame, bg="#e3f2fd")
        link_frame.pack(anchor="w", padx=5)
        link_lbl = tk.Label(link_frame, text="https://t.me/yaredone", font=("Arial", 8, "bold"), fg="#0000EE", cursor="hand2", bg="#e3f2fd")
        link_lbl.pack(side=tk.LEFT)
        link_lbl.bind("<Button-1>", lambda e: webbrowser.open("https://t.me/yaredone"))

        # Augmentation grid: 10 cells laid out as 2 rows x 5 columns.
        tk.Label(self.right_panel, text="TTA Analysis Grid", font=("Arial", 11, "bold"), bg="#f0f0f0").grid(row=3, column=0, sticky="w", pady=(5, 0), padx=5)

        self.aug_grid_frame = tk.Frame(self.right_panel, bg="#f0f0f0")
        self.aug_grid_frame.grid(row=4, column=0, sticky="nsew", padx=5, pady=5)

        self.grid_images = []
        self.grid_labels = []

        for i in range(10):
            frame = tk.Frame(self.aug_grid_frame, bg="white", bd=1, relief=tk.RAISED)
            r, c = divmod(i, 5)
            frame.grid(row=r, column=c, padx=2, pady=2, sticky="nsew")

            lbl_img = tk.Label(frame, bg="black", width=8, height=4)
            lbl_img.pack()

            lbl_txt = tk.Label(frame, text="...", font=("Arial", 8), bg="white", wraplength=70)
            lbl_txt.pack()

            self.grid_images.append(lbl_img)
            self.grid_labels.append(lbl_txt)

        for i in range(5):
            self.aug_grid_frame.grid_columnconfigure(i, weight=1)

    def _build_history_panel(self, root):
        """Create the export bar, scrollable history list and pager (grid column 2)."""
        self.history_panel = tk.LabelFrame(root, text="History (Training Data)", font=("Arial", 11, "bold"))
        self.history_panel.grid(row=0, column=2, sticky="nsew", padx=(5, 10), pady=10)
        self.history_panel.grid_rowconfigure(2, weight=1)
        self.history_panel.grid_columnconfigure(0, weight=1)

        # Export button and limit selector.
        top_hist_frame = tk.Frame(self.history_panel)
        top_hist_frame.grid(row=0, column=0, sticky="ew", padx=5, pady=5)

        tk.Button(top_hist_frame, text="📂 Export SQL", command=self.export_sql, bg="#4CAF50", fg="white", font=("Arial", 9, "bold")).pack(side=tk.LEFT, padx=5)

        tk.Label(top_hist_frame, text="Limit:", bg="#f0f0f0").pack(side=tk.LEFT, padx=2)
        self.limit_var = tk.StringVar(value="100")
        limit_combo = ttk.Combobox(top_hist_frame, textvariable=self.limit_var, values=["100", "1000", "2000", "3000", "4000", "5000"], state="readonly", width=8)
        limit_combo.pack(side=tk.LEFT, padx=2)
        limit_combo.bind("<<ComboboxSelected>>", self.change_history_limit)

        # Column headers.
        hist_header_frame = tk.Frame(self.history_panel)
        hist_header_frame.grid(row=1, column=0, sticky="ew", padx=5, pady=5)

        tk.Label(hist_header_frame, text="Img", width=4, font=("Arial", 8, "bold")).pack(side=tk.LEFT, padx=2)
        tk.Label(hist_header_frame, text="Char", width=5, font=("Arial", 8, "bold")).pack(side=tk.LEFT, padx=2)
        tk.Label(hist_header_frame, text="Prob", width=6, font=("Arial", 8, "bold")).pack(side=tk.LEFT, padx=2)
        tk.Label(hist_header_frame, text="Status", width=8, font=("Arial", 8, "bold")).pack(side=tk.LEFT, padx=2)
        tk.Label(hist_header_frame, text="Actions", width=8, font=("Arial", 8, "bold")).pack(side=tk.RIGHT, padx=2)

        # Scrollable list: a frame embedded in a canvas.
        self.history_canvas = tk.Canvas(self.history_panel, bg="white")
        self.history_scrollbar = tk.Scrollbar(self.history_panel, orient="vertical", command=self.history_canvas.yview)

        self.history_scrollable_frame = tk.Frame(self.history_canvas, bg="white")
        self.history_scrollable_frame.bind(
            "<Configure>",
            lambda e: self.history_canvas.configure(scrollregion=self.history_canvas.bbox("all"))
        )

        self.history_canvas.create_window((0, 0), window=self.history_scrollable_frame, anchor="nw")
        self.history_canvas.configure(yscrollcommand=self.history_scrollbar.set)

        self.history_canvas.grid(row=2, column=0, sticky="nsew", padx=5, pady=5)
        self.history_scrollbar.grid(row=2, column=1, sticky="ns")

        # Pager.
        self.pagination_frame = tk.Frame(self.history_panel)
        self.pagination_frame.grid(row=3, column=0, columnspan=2, sticky="ew", padx=5, pady=5)

        self.btn_prev = tk.Button(self.pagination_frame, text="Prev", command=self.prev_page, width=8)
        self.btn_prev.pack(side=tk.LEFT, padx=5)

        self.lbl_page_info = tk.Label(self.pagination_frame, text="Page 1")
        self.lbl_page_info.pack(side=tk.LEFT, expand=True)

        self.btn_next = tk.Button(self.pagination_frame, text="Next", command=self.next_page, width=8)
        self.btn_next.pack(side=tk.RIGHT, padx=5)

    # ------------------------------------------------------------------
    # Model and settings
    # ------------------------------------------------------------------
    def load_model(self):
        """Create the ONNX Runtime session; on failure show the error and leave ``session`` unset."""
        try:
            model_path = get_model_path()
            session = ort.InferenceSession(model_path, providers=['CPUExecutionProvider'])
            input_name = session.get_inputs()[0].name
        except Exception as e:
            self.lbl_final_char.config(text="Err", fg="red")
            self.lbl_final_method.config(text="Model not found", fg="red")
            messagebox.showerror("Model Error", f"Failed to load model:\n{e}\n\nPlace cnn_output.onnx next to the exe.")
            print(f"Error: {e}")
            return

        # Publish both attributes together so a half-initialised session is never used.
        self.session = session
        self.input_name = input_name
        print(f"ONNX Model Loaded Successfully from: {model_path}")
        self.lbl_final_method.config(text="Draw to Start...", fg="#555")

    def change_history_limit(self, event):
        """Combobox callback: store the selected history limit (used by the next insert)."""
        try:
            self.history_limit = int(self.limit_var.get())
            messagebox.showinfo("Limit Updated", f"History limit set to {self.history_limit} records.")
        except ValueError:
            pass

    # ------------------------------------------------------------------
    # Drawing
    # ------------------------------------------------------------------
    def start_draw(self, event):
        """Mouse-press callback: remember the stroke start and remove the placeholder."""
        self.old_x = event.x
        self.old_y = event.y
        self.canvas.delete("placeholder")

    def draw(self, event):
        """Mouse-drag callback: draw a segment on the canvas and record it in ``strokes``."""
        w = self.pen_width_var.get()
        # Compare against None: 0 is a valid coordinate at the canvas edge.
        if self.old_x is not None and self.old_y is not None:
            is_pen = self.tool_mode.get() == "pen"
            color = "black" if is_pen else "white"
            self.canvas.create_line(self.old_x, self.old_y, event.x, event.y,
                                    width=w, fill=color,
                                    capstyle=tk.ROUND, smooth=False)
            stroke_color_val = 0 if is_pen else 255
            self.strokes.append((self.old_x, self.old_y, event.x, event.y, w, stroke_color_val))
        self.old_x = event.x
        self.old_y = event.y

    def clear_canvas(self):
        """Erase the drawing and reset every result widget."""
        self.canvas.delete("all")
        self.strokes = []
        self._show_placeholder()

        self.lbl_final_char.config(text="", fg="blue")
        self.lbl_final_method.config(text="Draw to Start...", fg="#555")
        for lbl in [self.w_lbl_acc, self.w_lbl_unq, self.w_lbl_vote]:
            lbl.config(text="")
        for lbl in self.grid_labels:
            lbl.config(text="...")
        for lbl in self.grid_images:
            lbl.config(image="", bg="black")

    # ------------------------------------------------------------------
    # Clipboard
    # ------------------------------------------------------------------
    def _set_clipboard(self, text):
        """Replace the clipboard contents with *text*."""
        toplevel = self.root.winfo_toplevel()
        toplevel.clipboard_clear()
        toplevel.clipboard_append(text)

    def _restore_method_label(self, text):
        """Put *text* back on the method label unless something else replaced the feedback."""
        if self.lbl_final_method.cget("text") == self._COPIED_METHOD_TEXT:
            self.lbl_final_method.config(text=text, fg="#555")

    def _restore_char_label(self, label_widget, text, fg):
        """Undo the "Copied!" feedback on *label_widget*."""
        if label_widget.cget("text") == self._COPIED_CHAR_TEXT:
            label_widget.config(text=text, fg=fg)
        else:
            # The text was replaced meanwhile (new prediction / clear); keep it
            # and only undo the temporary colour.
            label_widget.config(fg=fg)

    def copy_to_clipboard(self, event):
        """Click callback: copy the main predicted character and flash a confirmation."""
        char = self.lbl_final_char.cget("text")
        if char in self._NOT_COPYABLE:
            return
        self._set_clipboard(char)

        original_text = self.lbl_final_method.cget("text")
        if original_text == self._COPIED_METHOD_TEXT:
            # Feedback is already showing; the pending restore will bring the
            # real text back, so do not capture the feedback as "original".
            return
        self.lbl_final_method.config(text=self._COPIED_METHOD_TEXT, fg="green")
        self.root.after(2000, lambda: self._restore_method_label(original_text))

    def copy_specific_char(self, label_widget):
        """Copy the character shown on *label_widget* and flash "Copied!" on it."""
        char = label_widget.cget("text")
        # While the feedback text is showing there is no character to copy.
        if char in self._NOT_COPYABLE or char == self._COPIED_CHAR_TEXT:
            return
        self._set_clipboard(char)

        original_fg = label_widget.cget("fg")
        label_widget.config(text=self._COPIED_CHAR_TEXT, fg="green")
        self.root.after(1500, lambda: self._restore_char_label(label_widget, char, original_fg))

    # ------------------------------------------------------------------
    # Prediction
    # ------------------------------------------------------------------
    def get_base_image_array(self):
        """Render the recorded strokes to a 32x32 greyscale image.

        Returns:
            ``(batch, image)``: a float32 array of shape ``(1, 32, 32, 1)``
            scaled to [0, 1] and the 32x32 PIL image it was made from.  With
            no strokes the array is all zeros and the image is blank white.
        """
        if not self.strokes:
            # NOTE(review): the array is zeros while the image is white (255);
            # kept as in the original -- callers return early on empty strokes.
            return np.zeros((1, 32, 32, 1), dtype='float32'), Image.new("L", (32, 32), 255)

        pil_image = Image.new("L", (self.canvas_size, self.canvas_size), 255)
        draw = ImageDraw.Draw(pil_image)
        for (x1, y1, x2, y2, w_stroke, col_val) in self.strokes:
            draw.line([(x1, y1), (x2, y2)], fill=col_val, width=w_stroke)
        pil_image = pil_image.resize((32, 32), Image.Resampling.LANCZOS)
        img_array = np.array(pil_image).astype('float32') / 255.0
        return img_array.reshape(1, 32, 32, 1), pil_image

    def run_tta_analysis(self):
        """Predict the drawn character from ten augmented copies and store the result.

        Three candidates are computed: the class with the highest mean
        probability, the class of the single most confident augmentation, and
        a vote among the five classes with the highest mean probability
        (augmentations whose best probability is below 0.35 do not vote).
        The vote wins when at least five augmentations voted and the winner
        got at least 60% of those votes; otherwise the mean-probability class
        is used.  The final result is written to the history database.
        """
        if not self.strokes:
            return
        if getattr(self, 'session', None) is None:
            messagebox.showerror("Error", "Model not loaded!")
            return

        self.btn_predict.config(state="disabled")
        self.lbl_final_char.config(text="...", fg="orange")
        self.lbl_final_method.config(text="Processing...", fg="orange")
        self.root.update_idletasks()

        try:
            base_img_batch, base_pil_img = self.get_base_image_array()
            base_img_single = base_img_batch[0]

            # zoom_image() and shift_brightness() choose their direction at
            # random, so the In/Out and Bright/Dark names are slot labels only.
            aug_list = [
                ("Blur Light", blur_image_light(base_img_single)),
                ("Blur Heavy", blur_image_heavy(base_img_single)),
                ("Blur Mod 1", blur_image_moderate(base_img_single)),
                ("Blur Mod 2", blur_image_moderate(base_img_single)),
                ("Zoom In", zoom_image(base_img_single)),
                ("Zoom Out", zoom_image(base_img_single)),
                ("Erode", erode_image(base_img_single)),
                ("Dilate", dilate_image(base_img_single)),
                ("Bright", shift_brightness(base_img_single)),
                ("Dark", shift_brightness(base_img_single))
            ]

            batch_input = np.array([img for _, img in aug_list])
            preds = self.session.run(None, {self.input_name: batch_input})[0]

            # Candidate 1: highest mean probability over all augmentations.
            avg_prob_vector = np.mean(preds, axis=0)
            idx_acc = np.argmax(avg_prob_vector)
            prob_acc = avg_prob_vector[idx_acc]

            # Candidate 2: class of the single most confident augmentation.
            max_probs_per_aug = np.max(preds, axis=1)
            idx_aug_best = np.argmax(max_probs_per_aug)
            prob_unq = max_probs_per_aug[idx_aug_best]
            idx_unq = np.argmax(preds[idx_aug_best])

            # Candidate 3: vote among the five best classes by mean probability.
            top_5_indices = avg_prob_vector.argsort()[-5:][::-1]
            votes = {idx: 0 for idx in top_5_indices}
            valid_votes_count = 0

            for i, (name, img) in enumerate(aug_list):
                pred_vec = preds[i]
                max_prob = np.max(pred_vec)

                img_display = (img.reshape(32, 32) * 255).astype(np.uint8)
                pil_img = Image.fromarray(img_display).resize((120, 120), Image.Resampling.NEAREST)
                photo = ImageTk.PhotoImage(pil_img)
                self.grid_images[i].config(image=photo, bg="white")
                # Keep a reference so the image is not garbage collected.
                self.grid_images[i].image = photo

                if max_prob < 0.35:
                    status_text = f"{name}\nIgnored"
                    text_color = "gray"
                else:
                    valid_votes_count += 1
                    candidate_probs = pred_vec[top_5_indices]
                    best_local_idx = np.argmax(candidate_probs)
                    voted_for_idx = top_5_indices[best_local_idx]
                    votes[voted_for_idx] += 1
                    voted_char = get_char_from_id(voted_for_idx)
                    status_text = f"{name}\n{voted_char}\n{max_prob*100:.0f}%"
                    text_color = "green"

                self.grid_labels[i].config(text=status_text, fg=text_color)

            if valid_votes_count > 0:
                # Ties resolve to the earliest key, i.e. the higher mean probability.
                idx_vote = max(votes, key=votes.get)
                count_vote = votes[idx_vote]
                prob_vote = avg_prob_vector[idx_vote]
            else:
                idx_vote = idx_acc
                count_vote = 0
                prob_vote = prob_acc

            if valid_votes_count >= 5 and count_vote >= (valid_votes_count * 0.6):
                final_idx = idx_vote
                final_method = f"Majority ({count_vote}/{valid_votes_count})"
                final_conf = prob_vote
            else:
                final_idx = idx_acc
                final_method = "Accuracy Based"
                final_conf = prob_acc

            final_char = get_char_from_id(final_idx)
            self.lbl_final_char.config(text=final_char, fg="blue")
            self.lbl_final_method.config(text=f"{final_method} | Conf: {final_conf*100:.1f}%", fg="#555")

            self.w_lbl_acc.config(text=get_char_from_id(idx_acc))
            self.w_info_acc.config(text=f"{prob_acc*100:.1f}%")

            self.w_lbl_unq.config(text=get_char_from_id(idx_unq))
            self.w_info_unq.config(text=f"{prob_unq*100:.1f}%")

            self.w_lbl_vote.config(text=get_char_from_id(idx_vote))
            self.w_info_vote.config(text=f"{count_vote}/{valid_votes_count}")

            self.db.add_record(base_pil_img, final_char, float(final_conf), self.history_limit)
            self.load_history()

        except Exception as e:
            # UI boundary: report the failure in the window instead of crashing.
            self.lbl_final_char.config(text="Err", fg="red")
            self.lbl_final_method.config(text=str(e), fg="red")
            print(f"Error: {e}")
        finally:
            self.btn_predict.config(state="normal")

    # ------------------------------------------------------------------
    # History
    # ------------------------------------------------------------------
    def _page_count(self, total_count):
        """Number of pages needed for *total_count* rows (at least 1)."""
        per_page = self.HISTORY_PER_PAGE
        return max(1, (total_count + per_page - 1) // per_page)

    def load_history(self):
        """Rebuild the history list for the current page and update the pager."""
        for widget in self.history_scrollable_frame.winfo_children():
            widget.destroy()

        per_page = self.HISTORY_PER_PAGE
        rows, total_count = self.db.get_page(self.current_history_page, per_page)
        total_pages = self._page_count(total_count)

        # Clamp *before* testing for emptiness: after deleting the last row of
        # the last page the current page is empty although earlier pages exist.
        if self.current_history_page > total_pages:
            self.current_history_page = total_pages
            rows, total_count = self.db.get_page(self.current_history_page, per_page)
            total_pages = self._page_count(total_count)
        self._total_history_pages = total_pages

        if not rows:
            tk.Label(self.history_scrollable_frame, text="No history yet.", bg="white").pack()
            self.lbl_page_info.config(text="Page 1/1")
            self.btn_prev.config(state="disabled")
            self.btn_next.config(state="disabled")
            return

        self.lbl_page_info.config(text=f"{self.current_history_page}/{total_pages}")
        self.btn_prev.config(state="normal" if self.current_history_page > 1 else "disabled")
        self.btn_next.config(state="normal" if self.current_history_page < total_pages else "disabled")

        for row in rows:
            self._add_history_row(row)

    def _add_history_row(self, row):
        """Append the widgets for one history *row* to the scrollable list."""
        r_id = row[0]
        img_blob = row[1]
        char = row[2]
        prob = row[3]
        # Rows without an is_correct column are treated as correct.
        is_correct = row[4] if len(row) > 5 else 1

        if is_correct == 0:
            bg_color = "#ffcdd2"
            status_text = "Incorrect"
            status_fg = "red"
        else:
            bg_color = "#e8f5e9"
            status_text = "Correct"
            status_fg = "green"

        row_frame = tk.Frame(self.history_scrollable_frame, bg=bg_color, bd=1, relief=tk.RIDGE)
        row_frame.pack(fill=tk.X, pady=1)

        try:
            pil_img = Image.open(io.BytesIO(img_blob))
            pil_img = pil_img.resize((35, 35), Image.Resampling.LANCZOS)
            photo = ImageTk.PhotoImage(pil_img)
            lbl_img = tk.Label(row_frame, image=photo, bg=bg_color)
            # Keep a reference so the thumbnail is not garbage collected.
            lbl_img.image = photo
            lbl_img.pack(side=tk.LEFT, padx=3)
        except Exception:
            # Unreadable thumbnail: show a marker but keep the row usable.
            tk.Label(row_frame, text="ImgErr", bg=bg_color, width=4).pack(side=tk.LEFT, padx=3)

        tk.Label(row_frame, text=char, width=5, font=("Amharic Unicode", 14), bg=bg_color).pack(side=tk.LEFT, padx=2)
        tk.Label(row_frame, text=f"{prob:.0%}", width=6, bg=bg_color).pack(side=tk.LEFT, padx=2)
        tk.Label(row_frame, text=status_text, width=8, bg=bg_color, fg=status_fg).pack(side=tk.LEFT, padx=2)
        tk.Frame(row_frame, bg=bg_color).pack(side=tk.LEFT, expand=True, fill=tk.X)

        actions_frame = tk.Frame(row_frame, bg=bg_color)
        actions_frame.pack(side=tk.RIGHT, padx=2)

        btn_del = tk.Button(actions_frame, text="🗑", fg="white", bg="#757575", font=("Arial", 7), width=2,
                            command=lambda rid=r_id: self.delete_history_row(rid))
        btn_del.pack(side=tk.LEFT, padx=1)

        x_bg = "#d32f2f" if is_correct == 1 else "#4CAF50"
        btn_x = tk.Button(actions_frame, text="X", fg="white", bg=x_bg, font=("Arial", 8, "bold"), width=2,
                          command=lambda rid=r_id: self.flag_as_wrong(rid))
        btn_x.pack(side=tk.LEFT, padx=1)

    def flag_as_wrong(self, record_id):
        """Ask for confirmation, then toggle the correctness flag of *record_id*."""
        if not messagebox.askyesno("Toggle Status", "Toggle correctness status for this prediction?"):
            return
        try:
            new_status = self.db.toggle_correctness(record_id)
            if new_status is not None:
                self.load_history()
        except Exception as e:
            messagebox.showerror("Database Error", f"Failed to update status: {e}")

    def delete_history_row(self, record_id):
        """Ask for confirmation, then delete *record_id* and refresh the list."""
        if messagebox.askyesno("Delete Row", "Delete this history entry permanently?"):
            self.db.delete_record(record_id)
            self.load_history()

    def prev_page(self):
        """Show the previous history page, if there is one."""
        if self.current_history_page > 1:
            self.current_history_page -= 1
            self.load_history()

    def next_page(self):
        """Show the next history page, if there is one."""
        if self.current_history_page >= self._total_history_pages:
            return
        self.current_history_page += 1
        self.load_history()

    def export_sql(self):
        """Ask for a target file and write the history table to it as SQL."""
        timestamp_str = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        default_name = f"amharic_ocr_dataset_{timestamp_str}.sql"

        file_path = filedialog.asksaveasfilename(initialfile=default_name,
                                                 defaultextension=".sql",
                                                 filetypes=[("SQL Files", "*.sql"), ("All Files", "*.*")],
                                                 title="Save Dataset")
        if not file_path:
            return

        try:
            # Build the dump only once a file was chosen, and inside the try so
            # database errors are reported in the dialog as well.
            sql_data = self.db.export_to_sql()
            with open(file_path, "w", encoding="utf-8") as f:
                f.write(sql_data)
            messagebox.showinfo("Export Successful", f"Dataset saved to {file_path}")
        except Exception as e:
            messagebox.showerror("Export Error", str(e))
|
|
|
if __name__ == "__main__":
    # Start-up sequence: single-instance check, hidden root, splash while the
    # widgets, model and history load, then reveal the main window.
    mutex = ensure_single_instance()

    root = tk.Tk()
    root.withdraw()

    splash = SplashScreen(root)
    root.update()

    app = DrawingApp(root)

    startup_steps = (
        ("Loading ONNX model...", app.load_model),
        ("Loading history database...", app.load_history),
    )
    for status_message, run_step in startup_steps:
        splash.update_message(status_message)
        root.update()
        run_step()

    splash.destroy()
    root.deiconify()
    root.title("Amharic OCR - Advanced Data Collector")
    root.geometry("1500x900")

    root.mainloop()
|
|
|
|
|