# Source: geez-char-ocr / fhtf_onnx.py
# Uploaded by Yaredoffice using huggingface_hub (revision fdf6f8c, verified).
import tkinter as tk
from tkinter import font as tkfont, ttk, messagebox, filedialog
from PIL import Image, ImageDraw, ImageTk
import numpy as np
import onnxruntime as ort
import cv2
import random
import webbrowser
import sqlite3
import io
import datetime
import sys
import os
from win32event import CreateMutex
from win32api import GetLastError
from winerror import ERROR_ALREADY_EXISTS
# ==========================================
# SINGLE INSTANCE ENFORCEMENT
# ==========================================
def ensure_single_instance():
    """Abort start-up when another copy of the application is already running.

    Creates the named mutex ``Global\\AmharicOCR_SingleInstance_Mutex``. If the
    mutex already existed, a warning dialog is shown and the process exits with
    status 0.

    Returns:
        The mutex handle returned by ``CreateMutex``. The caller keeps it
        referenced while the app runs (presumably the handle is released when
        the object is garbage-collected -- confirm against pywin32).
    """
    mutex = CreateMutex(None, False, 'Global\\AmharicOCR_SingleInstance_Mutex')
    if GetLastError() == ERROR_ALREADY_EXISTS:
        # No Tk root exists yet at this point. Without an explicit hidden root,
        # messagebox would create a visible, empty default window next to the
        # warning dialog.
        warning_root = tk.Tk()
        warning_root.withdraw()
        try:
            messagebox.showwarning("Already Running", "Amharic OCR is already running!")
        finally:
            warning_root.destroy()
        sys.exit(0)
    return mutex
# ==========================================
# RESOURCE PATH HELPER
# ==========================================
def resource_path(relative_path):
    """Resolve *relative_path* against the application's resource directory.

    When ``sys._MEIPASS`` is present (set by a PyInstaller bundle) it is used
    as the base directory; otherwise the current working directory is used.
    """
    if hasattr(sys, "_MEIPASS"):
        bundle_root = sys._MEIPASS
    else:
        bundle_root = os.path.abspath(".")
    return os.path.join(bundle_root, relative_path)
def get_db_path():
    """Return the history database path inside ``~/AmharicOCR``.

    The ``AmharicOCR`` directory is created in the user's home directory when
    it does not exist yet.

    Returns:
        Absolute path of ``amharic_ocr_history.db`` in that directory.
    """
    app_dir = os.path.join(os.path.expanduser("~"), "AmharicOCR")
    # exist_ok avoids the check-then-create race of a separate exists() test
    # (e.g. two processes starting at the same moment).
    os.makedirs(app_dir, exist_ok=True)
    return os.path.join(app_dir, "amharic_ocr_history.db")
def get_model_path():
    """Return the location of the ONNX model file ``cnn_output.onnx``.

    The file name is resolved through ``resource_path``, so the same call
    serves both a bundled executable and a plain source checkout.
    """
    model_filename = 'cnn_output.onnx'
    return resource_path(model_filename)
# ==========================================
# DATABASE MANAGER
# ==========================================
class HistoryDB:
    """SQLite-backed store for prediction history.

    Each record holds the PNG bytes of the drawn image, the predicted
    character, its probability, a correctness flag and a timestamp.

    Row-returning methods always yield columns in the order
    ``(id, image_data, predicted_char, probability, is_correct, timestamp)``.
    """

    # Explicit column order for every row-returning query. ``SELECT *`` is not
    # safe here: on databases created before ``is_correct`` existed, the
    # column is appended by ALTER TABLE *after* ``timestamp``, which would
    # silently swap the two values for positional readers.
    _COLUMNS = "id, image_data, predicted_char, probability, is_correct, timestamp"

    def __init__(self, db_name=None):
        """Open (and if needed create/upgrade) the history database.

        Args:
            db_name: Path of the SQLite file; defaults to ``get_db_path()``.
        """
        if db_name is None:
            db_name = get_db_path()
        self.conn = sqlite3.connect(db_name, check_same_thread=False)
        self.create_table()
        self.update_schema()

    def close(self):
        """Close the underlying SQLite connection."""
        self.conn.close()

    def create_table(self):
        """Create the ``history`` table when it does not exist yet."""
        cursor = self.conn.cursor()
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS history (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                image_data BLOB,
                predicted_char TEXT,
                probability REAL,
                is_correct INTEGER DEFAULT 1,
                timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
            )
        ''')
        self.conn.commit()

    def update_schema(self):
        """Upgrade older databases: add ``is_correct`` and backfill NULLs with 1.

        Failures are reported on stdout and otherwise ignored (best effort).
        """
        cursor = self.conn.cursor()
        cursor.execute("PRAGMA table_info(history)")
        columns = [info[1] for info in cursor.fetchall()]
        if 'is_correct' not in columns:
            try:
                cursor.execute("ALTER TABLE history ADD COLUMN is_correct INTEGER DEFAULT 1")
                self.conn.commit()
            except Exception as e:
                print(f"Error adding column: {e}")
        try:
            cursor.execute("UPDATE history SET is_correct = 1 WHERE is_correct IS NULL")
            self.conn.commit()
        except Exception as e:
            print(f"Error updating records: {e}")

    def add_record(self, image_pil, char, prob, limit):
        """Insert a prediction and trim the table to the newest *limit* rows.

        Args:
            image_pil: Object with a PIL-style ``save(fp, format=...)`` method;
                it is stored as PNG bytes.
            char: Predicted character.
            prob: Probability stored alongside the prediction.
            limit: Maximum number of rows kept (newest ids win).

        Returns:
            The id of the inserted row.
        """
        png_buffer = io.BytesIO()
        image_pil.save(png_buffer, format='PNG')
        cursor = self.conn.cursor()
        cursor.execute(
            "INSERT INTO history (image_data, predicted_char, probability, is_correct) VALUES (?, ?, ?, 1)",
            (png_buffer.getvalue(), char, prob),
        )
        # Capture the id right after the INSERT, before the cursor is reused.
        new_id = cursor.lastrowid
        self.conn.commit()
        cursor.execute(
            "DELETE FROM history WHERE id NOT IN (SELECT id FROM history ORDER BY id DESC LIMIT ?)",
            (limit,),
        )
        self.conn.commit()
        return new_id

    def toggle_correctness(self, record_id):
        """Flip ``is_correct`` for *record_id*.

        Returns:
            The new status (0 or 1), or ``None`` when the record does not exist.
        """
        cursor = self.conn.cursor()
        cursor.execute("SELECT is_correct FROM history WHERE id = ?", (record_id,))
        result = cursor.fetchone()
        if result is None:
            return None
        new_status = 0 if result[0] == 1 else 1
        cursor.execute("UPDATE history SET is_correct = ? WHERE id = ?", (new_status, record_id))
        self.conn.commit()
        return new_status

    def delete_record(self, record_id):
        """Delete the row with the given id (no error if it is absent)."""
        cursor = self.conn.cursor()
        cursor.execute("DELETE FROM history WHERE id = ?", (record_id,))
        self.conn.commit()

    def get_page(self, page_num, per_page=10):
        """Return one page of history, newest first.

        Args:
            page_num: 1-based page number.
            per_page: Rows per page.

        Returns:
            ``(rows, total_count)`` where each row follows ``_COLUMNS`` order
            and ``total_count`` is the number of rows in the whole table.
        """
        offset = (page_num - 1) * per_page
        cursor = self.conn.cursor()
        cursor.execute(
            f"SELECT {self._COLUMNS} FROM history ORDER BY id DESC LIMIT ? OFFSET ?",
            (per_page, offset),
        )
        rows = cursor.fetchall()
        cursor.execute("SELECT count(*) FROM history")
        total_count = cursor.fetchone()[0]
        return rows, total_count

    @staticmethod
    def _sql_text(value):
        """Render *value* as a quoted SQL string literal, or ``NULL``."""
        if value is None:
            return "NULL"
        return "'" + str(value).replace("'", "''") + "'"

    @staticmethod
    def _sql_blob(value):
        """Render *value* as an ``x'..'`` hex blob literal, or ``NULL``."""
        if value is None:
            return "NULL"
        return "x'" + bytes(value).hex() + "'"

    def export_to_sql(self):
        """Dump every history row as a script of SQL ``INSERT`` statements.

        NULL columns are written as ``NULL`` instead of raising; a NULL
        ``is_correct`` is exported as 1, matching ``update_schema``.

        Returns:
            The complete SQL script as one string.
        """
        cursor = self.conn.cursor()
        cursor.execute(f"SELECT {self._COLUMNS} FROM history")
        sql_lines = [
            "-- Amharic OCR Dataset Export\n",
            "-- Generated on: " + str(datetime.datetime.now()) + "\n\n",
            "PRAGMA synchronous=OFF;\n",
            "BEGIN TRANSACTION;\n",
        ]
        for row_id, image_data, char, prob, is_corr, ts in cursor.fetchall():
            prob_sql = "NULL" if prob is None else str(prob)
            corr_sql = 1 if is_corr is None else is_corr
            sql_lines.append(
                "INSERT INTO history (id, image_data, predicted_char, probability, is_correct, timestamp) "
                f"VALUES ({row_id}, {self._sql_blob(image_data)}, {self._sql_text(char)}, "
                f"{prob_sql}, {corr_sql}, {self._sql_text(ts)});\n"
            )
        sql_lines.append("COMMIT;\n")
        return "".join(sql_lines)
# ==========================================
# GEEZ MAPPING LOGIC
# ==========================================
def flatten(l):
    """Flatten arbitrarily nested lists into one flat list, preserving order.

    Args:
        l: A (possibly nested) list, or any other object.

    Returns:
        A new flat list of the non-list leaves. A non-list argument is
        returned wrapped in a single-element list; empty lists contribute
        nothing.
    """
    if not isinstance(l, list):
        return [l]
    flat = []
    # Explicit stack of iterators instead of recursion: the recursive form
    # recursed once per element (RecursionError on long lists) and copied the
    # tail with l[1:] at every step.
    pending = [iter(l)]
    while pending:
        for item in pending[-1]:
            if isinstance(item, list):
                pending.append(iter(item))
                break
            flat.append(item)
        else:
            pending.pop()
    return flat
def fidel(bet):
    """Return the list of characters belonging to group *bet*.

    For ``bet`` below 8, 34 characters are generated by walking code points
    from a start value (4608 for bet 1 up to 4614 for bet 7; any other value
    below 8 also starts at 4608) with a fixed table of step sizes. For
    ``bet`` 8 to 12 a fixed list of characters is returned. Any other value
    yields an empty list.
    """
    fixed_groups = {
        8: ['ቈ','ኈ','ኰ','ዀ','ጐ'],
        9: ['ሏ','ሗ','ሟ','ሧ','ሯ','ሷ','ሿ','ቊ','ቧ','ቯ',
            'ቷ','ቿ','ኊ','ኗ','ኟ','ኧ','ኲ','ዂ','ዟ','ዧ',
            'ዷ','ጇ','ጒ','ጧ','ጯ','ጷ','ጿ','ፏ','ፗ'],
        10: ['ቋ','ኋ','ኳ','ዃ','ጓ'],
        11: ['ቌ','ኌ','ኴ','ዄ','ጔ'],
        12: ['ቍ','ኍ','ኵ','ዅ','ጕ'],
    }
    if not bet < 8:
        return list(fixed_groups.get(bet, []))

    first_code_points = {1: 4608, 2: 4609, 3: 4610, 4: 4611, 5: 4612, 6: 4613, 7: 4614}
    code_point = first_code_points.get(bet, 4608)
    # Step applied after emitting the character at a given position; every
    # position not listed advances by 8.
    irregular_steps = {8: 32, 13: 16, 17: 16, 18: 16, 24: 16, 26: 24}
    generated = []
    for position in range(34):
        generated.append(chr(code_point))
        code_point += irregular_steps.get(position, 8)
    return generated
def fidel_form(min_bet, max_bet):
    """Collect ``fidel(b)`` for every ``b`` in ``range(min_bet, max_bet)``.

    Returns a list of lists, one inner list per group, in ascending order;
    ``max_bet`` itself is excluded.
    """
    return [fidel(group) for group in range(min_bet, max_bet)]
# Flat character table built from groups 1..12 of fidel(); get_char_from_id
# indexes into it with a class id.
laters = flatten(fidel_form(1, 13))
def get_char_from_id(class_id):
    """Map a class id to its character in ``laters``.

    Ids outside ``0 .. len(laters) - 1`` map to the placeholder ``"?"``.
    """
    is_known = 0 <= class_id < len(laters)
    return laters[class_id] if is_known else "?"
# ==========================================
# AUGMENTATION FUNCTIONS
# ==========================================
def ensure_shape(img):
    """Return *img* reshaped to ``(32, 32, 1)`` and converted to float32."""
    as_hwc = img.reshape((32, 32, 1))
    return as_hwc.astype(np.float32)
def erode_image(img):
    """Apply one pass of 2x2 erosion and return a ``(32, 32, 1)`` float32 image.

    The input is scaled by 255 and cast to uint8 for OpenCV, eroded, then
    scaled back by 1/255.
    """
    as_bytes = (img * 255).astype(np.uint8)
    structuring_element = np.ones((2, 2), np.uint8)
    result = cv2.erode(as_bytes, structuring_element, iterations=1)
    return ensure_shape(result.astype('float32') / 255.0)
def dilate_image(img):
    """Apply one pass of 2x2 dilation and return a ``(32, 32, 1)`` float32 image.

    The input is scaled by 255 and cast to uint8 for OpenCV, dilated, then
    scaled back by 1/255.
    """
    as_bytes = (img * 255).astype(np.uint8)
    structuring_element = np.ones((2, 2), np.uint8)
    result = cv2.dilate(as_bytes, structuring_element, iterations=1)
    return ensure_shape(result.astype('float32') / 255.0)
def zoom_image(img, zoom=None):
    """Zoom *img* in or out around its centre, keeping the original size.

    Args:
        img: Image array; its first two dimensions are used as height and
            width. The result is passed through ``ensure_shape``.
        zoom: Zoom factor. Values above 1.0 crop the centre and scale it up;
            values up to 1.0 shrink the image and pad the border by
            reflection. When ``None`` (the default) the factor is picked at
            random from 0.9 and 1.1, as before.

    Returns:
        The zoomed image as produced by ``ensure_shape``.

    Raises:
        ValueError: If *zoom* is not positive.
    """
    if zoom is None:
        zoom = random.choice([0.9, 1.1])
    if zoom <= 0:
        raise ValueError(f"zoom must be positive, got {zoom!r}")
    h, w = img.shape[:2]
    if zoom > 1.0:
        # Zoom in: keep the central 1/zoom portion and scale it back up.
        new_h, new_w = max(1, int(h / zoom)), max(1, int(w / zoom))
        y = (h - new_h) // 2
        x = (w - new_w) // 2
        cropped = img[y:y + new_h, x:x + new_w]
        resized = cv2.resize(cropped, (w, h))
    else:
        # Zoom out: shrink, then pad back to the original size. max(1, ...)
        # keeps cv2.resize from receiving a zero-sized target for tiny factors.
        new_h, new_w = max(1, int(h * zoom)), max(1, int(w * zoom))
        resized = cv2.resize(img, (new_w, new_h))
        pad_top = (h - new_h) // 2
        pad_bottom = h - new_h - pad_top
        pad_left = (w - new_w) // 2
        pad_right = w - new_w - pad_left
        resized = cv2.copyMakeBorder(resized, pad_top, pad_bottom, pad_left, pad_right, cv2.BORDER_REFLECT_101)
    return ensure_shape(resized)
def shift_brightness(img, val=None):
    """Add a constant offset to *img* and clip the result to ``[0, 1]``.

    Args:
        img: Image array (converted to float32 before the offset is added).
        val: Offset to add. When ``None`` (the default) it is picked at
            random from -0.1 and 0.1, as before; pass an explicit value for a
            deterministic brighten/darken.

    Returns:
        The shifted image as produced by ``ensure_shape``.
    """
    if val is None:
        val = random.choice([-0.1, 0.1])
    return ensure_shape(np.clip(img.astype('float32') + val, 0, 1))
def blur_image_light(img):
    """Gaussian-blur *img* with a 3x3 kernel; returns ``(32, 32, 1)`` float32.

    The input is scaled by 255 and cast to uint8 for OpenCV, then scaled back
    by 1/255 after blurring.
    """
    as_bytes = (img * 255).astype(np.uint8)
    softened = cv2.GaussianBlur(as_bytes, (3, 3), 0)
    rescaled = softened.astype('float32') / 255.0
    return ensure_shape(rescaled)
def blur_image_heavy(img):
    """Gaussian-blur *img* with a 7x7 kernel; returns ``(32, 32, 1)`` float32.

    The input is scaled by 255 and cast to uint8 for OpenCV, then scaled back
    by 1/255 after blurring.
    """
    as_bytes = (img * 255).astype(np.uint8)
    softened = cv2.GaussianBlur(as_bytes, (7, 7), 0)
    rescaled = softened.astype('float32') / 255.0
    return ensure_shape(rescaled)
def blur_image_moderate(img):
    """Gaussian-blur *img* with a 5x5 kernel; returns ``(32, 32, 1)`` float32.

    The input is scaled by 255 and cast to uint8 for OpenCV, then scaled back
    by 1/255 after blurring.
    """
    as_bytes = (img * 255).astype(np.uint8)
    softened = cv2.GaussianBlur(as_bytes, (5, 5), 0)
    rescaled = softened.astype('float32') / 255.0
    return ensure_shape(rescaled)
# ==========================================
# SPLASH SCREEN
# ==========================================
class SplashScreen:
    """Borderless 400x150 loading window with a status line and progress bar."""

    def __init__(self, parent_root):
        """Build the splash as a ``Toplevel`` of *parent_root* and show it."""
        window = tk.Toplevel(parent_root)
        self.splash = window
        window.title("Loading...")
        window.geometry("400x150")
        window.resizable(False, False)
        window.overrideredirect(True)

        # Place the window in the middle of the screen.
        window.update_idletasks()
        left = window.winfo_screenwidth() // 2 - 200
        top = window.winfo_screenheight() // 2 - 75
        window.geometry(f"400x150+{left}+{top}")

        blue = "#2196F3"
        body = tk.Frame(window, bg=blue)
        body.pack(fill=tk.BOTH, expand=True)

        heading = tk.Label(body, text="Geez OCR", font=("Arial", 20, "bold"),
                           bg=blue, fg="white")
        heading.pack(pady=(20, 10))

        # Status line; its text is replaced through update_message().
        self.message_label = tk.Label(body, text="Geez OCR library loading in background...",
                                      font=("Arial", 11), bg=blue, fg="white")
        self.message_label.pack(pady=5)

        bar_holder = tk.Frame(body, bg=blue)
        bar_holder.pack(pady=10, padx=40, fill=tk.X)
        self.progress_bar = ttk.Progressbar(bar_holder, mode='indeterminate', length=320)
        self.progress_bar.pack(fill=tk.X)
        self.progress_bar.start(10)

        window.update()

    def update_message(self, text):
        """Replace the status line with *text* and redraw the splash."""
        self.message_label.config(text=text)
        self.splash.update()

    def destroy(self):
        """Stop the progress animation and close the splash window."""
        self.progress_bar.stop()
        self.splash.destroy()
# ==========================================
# APP CLASS
# ==========================================
class DrawingApp:
    """Main window: drawing canvas, TTA prediction results and history list.

    The window has three columns: the drawing area with controls and results,
    an information column with the per-augmentation grid, and the paginated
    prediction history backed by ``HistoryDB``.

    ``load_model()`` and ``load_history()`` are not called by the constructor;
    the start-up code calls them after the widgets exist.
    """

    # Rows shown per history page.
    HISTORY_PAGE_SIZE = 10
    # Label texts that are status placeholders, never copyable characters.
    # "Copied!" is included so a second click during the flash is ignored.
    _NOT_COPYABLE = ("?", "...", "Err", "Processing...", "", "Copied!")

    def __init__(self, root):
        """Create all widgets inside *root* and open the history database."""
        self.root = root
        self.root.minsize(1400, 800)
        self.db = HistoryDB()
        self.current_history_page = 1
        self.history_limit = 100
        # LAYOUT
        root.grid_rowconfigure(0, weight=1)
        root.grid_columnconfigure(0, weight=4)
        root.grid_columnconfigure(1, weight=3)
        root.grid_columnconfigure(2, weight=3)
        self._build_drawing_panel(root)
        self._build_info_panel(root)
        self._build_history_panel(root)
        # Model and history will be loaded after splash screen

    # ------------------------------------------------------------------
    # UI construction helpers
    # ------------------------------------------------------------------
    def _draw_placeholder(self):
        """Draw the grey "Draw here" hint in the middle of the canvas."""
        self.placeholder_tag = self.canvas.create_text(
            self.canvas_size/2, self.canvas_size/2, text="Draw here", fill="#e0e0e0", font=("Arial", 18, "bold"), tag="placeholder"
        )

    def _create_winner_block(self, parent, title):
        """Create one titled result cell; returns ``(char_label, info_label)``."""
        frame = tk.Frame(parent, bg="#e3f2fd")
        frame.pack(side=tk.LEFT, expand=True, fill=tk.BOTH, padx=3)
        tk.Label(frame, text=title, font=("Arial", 8, "bold"), bg="#e3f2fd", fg="#0d47a1").pack()
        lbl_char = tk.Label(frame, text="", font=("Amharic Unicode", 24), bg="#e3f2fd", fg="blue", cursor="hand2")
        lbl_char.pack()
        lbl_char.bind("<Button-1>", lambda e, l=lbl_char: self.copy_specific_char(l))
        lbl_info = tk.Label(frame, text="--", font=("Arial", 7), bg="#e3f2fd", fg="#555")
        lbl_info.pack()
        return lbl_char, lbl_info

    def _build_drawing_panel(self, root):
        """Column 1: canvas, pen controls and the result labels."""
        self.left_panel = tk.Frame(root, bg="white")
        self.left_panel.grid(row=0, column=0, sticky="nsew", padx=(10, 5), pady=10)
        self.left_panel.grid_rowconfigure(0, weight=0)  # Canvas - fixed size
        self.left_panel.grid_rowconfigure(1, weight=0)  # Controls - fixed size
        self.left_panel.grid_rowconfigure(2, weight=0)  # Results - fixed size
        self.left_panel.grid_columnconfigure(0, weight=1)
        # Canvas - Square for accurate model input
        self.canvas_size = 320
        self.canvas_container = tk.Frame(self.left_panel, width=self.canvas_size+10, height=self.canvas_size+10, bg="white")
        self.canvas_container.grid(row=0, column=0, sticky="")
        self.canvas_container.grid_propagate(False)
        self.canvas = tk.Canvas(self.canvas_container, width=self.canvas_size, height=self.canvas_size, bg="white", cursor="cross")
        self.canvas.pack(padx=5, pady=5)
        self._draw_placeholder()
        # Each stroke segment is (x1, y1, x2, y2, width, grey value) and is
        # replayed onto a PIL image when predicting.
        self.strokes = []
        self.old_x = None
        self.old_y = None
        self.canvas.bind('<Button-1>', self.start_draw)
        self.canvas.bind('<B1-Motion>', self.draw)
        # Controls
        self.controls_frame = tk.Frame(self.left_panel, bg="#f9f9f9", bd=1, relief=tk.RAISED)
        self.controls_frame.grid(row=1, column=0, sticky="ew", pady=5)
        self.pen_width_var = tk.IntVar(value=10)
        self.tool_mode = tk.StringVar(value="pen")
        c_inner = tk.Frame(self.controls_frame, bg="#f9f9f9")
        c_inner.pack(pady=10)
        tk.Radiobutton(c_inner, text="Pen", variable=self.tool_mode, value="pen", bg="#f9f9f9").grid(row=0, column=0, padx=5)
        tk.Radiobutton(c_inner, text="Eraser", variable=self.tool_mode, value="eraser", bg="#f9f9f9").grid(row=0, column=1, padx=5)
        tk.Label(c_inner, text="Size", bg="#f9f9f9").grid(row=0, column=2, padx=5)
        tk.Scale(c_inner, from_=7, to=13, orient=tk.HORIZONTAL, variable=self.pen_width_var, length=150).grid(row=0, column=3, padx=10)
        btn_frame = tk.Frame(c_inner, bg="#f9f9f9")
        btn_frame.grid(row=1, column=0, columnspan=4, pady=10)
        self.btn_predict = tk.Button(btn_frame, text="PREDICT", command=self.run_tta_analysis, bg="#2196F3", fg="white", font=("Arial", 10, "bold"), width=15)
        self.btn_predict.pack(side=tk.LEFT, padx=10)
        tk.Button(btn_frame, text="CLEAR", command=self.clear_canvas, bg="#f44336", fg="white", font=("Arial", 10, "bold"), width=15).pack(side=tk.LEFT, padx=10)
        # Results
        self.results_container = tk.Frame(self.left_panel, bg="white")
        self.results_container.grid(row=2, column=0, sticky="new", pady=10)
        self.lbl_final_char = tk.Label(self.results_container, text="?", font=("Amharic Unicode", 60, "bold"), fg="blue", cursor="hand2")
        self.lbl_final_char.pack(pady=5)
        self.lbl_final_char.bind("<Button-1>", self.copy_to_clipboard)
        self.lbl_final_method = tk.Label(self.results_container, text="Loading model...", font=("Arial", 10), fg="#555")
        self.lbl_final_method.pack(pady=(0, 5))
        self.winners_frame = tk.Frame(self.results_container, bg="#e3f2fd", bd=2, relief=tk.GROOVE)
        self.winners_frame.pack(fill=tk.X, padx=10, pady=5)
        self.w_lbl_acc, self.w_info_acc = self._create_winner_block(self.winners_frame, "Top Accuracy")
        self.w_lbl_unq, self.w_info_unq = self._create_winner_block(self.winners_frame, "Unique (High Conf)")
        self.w_lbl_vote, self.w_info_vote = self._create_winner_block(self.winners_frame, "Agreed Vote")

    def _build_info_panel(self, root):
        """Column 2: credits, user guide, note and the TTA analysis grid."""
        self.right_panel = tk.Frame(root, bg="#f0f0f0")
        self.right_panel.grid(row=0, column=1, sticky="nsew", padx=5, pady=10)
        self.right_panel.grid_rowconfigure(0, weight=0)  # Credits - fixed
        self.right_panel.grid_rowconfigure(1, weight=0)  # Guide - fixed
        self.right_panel.grid_rowconfigure(2, weight=0)  # Blue note - fixed
        self.right_panel.grid_rowconfigure(3, weight=0)  # Label - fixed
        self.right_panel.grid_rowconfigure(4, weight=1)  # TTA Grid - expandable
        self.right_panel.grid_columnconfigure(0, weight=1)
        # Credits
        self.credits_frame = tk.Frame(self.right_panel, bg="#f0f0f0", bd=2, relief=tk.GROOVE)
        self.credits_frame.grid(row=0, column=0, sticky="ew", pady=(0, 10))
        tk.Label(self.credits_frame, text="Made with ❤️ by Yared Kassa", font=("Arial", 12, "bold"), bg="#f0f0f0").pack(pady=2)
        contact_frame = tk.Frame(self.credits_frame, bg="#f0f0f0")
        contact_frame.pack(pady=2)
        email_lbl = tk.Label(contact_frame, text="yaredoffice@gmail.com", font=("Arial", 10, "bold"), fg="#0000EE", cursor="hand2", bg="#f0f0f0")
        email_lbl.pack(side=tk.LEFT, padx=5)
        email_lbl.bind("<Button-1>", lambda e: webbrowser.open("mailto:yaredoffice@gmail.com"))
        tg_lbl = tk.Label(contact_frame, text="https://t.me/yaredoffice", font=("Arial", 10, "bold"), fg="#0088cc", cursor="hand2", bg="#f0f0f0")
        tg_lbl.pack(side=tk.LEFT, padx=5)
        tg_lbl.bind("<Button-1>", lambda e: webbrowser.open("https://t.me/yaredoffice"))
        # User Guide
        guide_frame = tk.LabelFrame(self.right_panel, text="User Guide", font=("Arial", 9, "bold"), bg="#f0f0f0")
        guide_frame.grid(row=1, column=0, sticky="new", pady=(0, 5), padx=5)
        guide_text = (
            "⚠ IMPORTANT:\n"
            "1. Test SINGLE Geez characters only.\n"
            "2. Do NOT test full words.\n"
            "3. No Geez numbers support.\n"
            "4. Export SQL after 100+ tests to contribute."
        )
        tk.Label(guide_frame, text=guide_text, font=("Arial", 8), bg="#f0f0f0", justify=tk.LEFT, fg="#333", wraplength=380).pack(padx=8, pady=5, fill=tk.X)
        # Blue Note
        blue_note_frame = tk.Frame(self.right_panel, bg="#e3f2fd")
        blue_note_frame.grid(row=2, column=0, sticky="ew", padx=5, pady=(0,5))
        wrap_len = 360
        lbl_part1 = tk.Label(blue_note_frame, text="Word OCR coming soon. Model available at ", font=("Arial", 8), bg="#e3f2fd", fg="#0d47a1", wraplength=wrap_len, justify=tk.LEFT)
        lbl_part1.pack(anchor="w", padx=5, pady=2)
        link_frame = tk.Frame(blue_note_frame, bg="#e3f2fd")
        link_frame.pack(anchor="w", padx=5)
        link_lbl = tk.Label(link_frame, text="https://t.me/yaredone", font=("Arial", 8, "bold"), fg="#0000EE", cursor="hand2", bg="#e3f2fd")
        link_lbl.pack(side=tk.LEFT)
        link_lbl.bind("<Button-1>", lambda e: webbrowser.open("https://t.me/yaredone"))
        # TTA Grid
        tk.Label(self.right_panel, text="TTA Analysis Grid", font=("Arial", 11, "bold"), bg="#f0f0f0").grid(row=3, column=0, sticky="w", pady=(5, 0), padx=5)
        self.aug_grid_frame = tk.Frame(self.right_panel, bg="#f0f0f0")
        self.aug_grid_frame.grid(row=4, column=0, sticky="nsew", padx=5, pady=5)
        self.grid_images = []
        self.grid_labels = []
        for i in range(10):
            frame = tk.Frame(self.aug_grid_frame, bg="white", bd=1, relief=tk.RAISED)
            r, c = divmod(i, 5)
            frame.grid(row=r, column=c, padx=2, pady=2, sticky="nsew")
            lbl_img = tk.Label(frame, bg="black", width=8, height=4)
            lbl_img.pack()
            lbl_txt = tk.Label(frame, text="...", font=("Arial", 8), bg="white", wraplength=70)
            lbl_txt.pack()
            self.grid_images.append(lbl_img)
            self.grid_labels.append(lbl_txt)
        for i in range(5):
            self.aug_grid_frame.grid_columnconfigure(i, weight=1)

    def _build_history_panel(self, root):
        """Column 3: export/limit controls, scrollable history, pagination."""
        self.history_panel = tk.LabelFrame(root, text="History (Training Data)", font=("Arial", 11, "bold"))
        self.history_panel.grid(row=0, column=2, sticky="nsew", padx=(5, 10), pady=10)
        self.history_panel.grid_rowconfigure(2, weight=1)
        self.history_panel.grid_columnconfigure(0, weight=1)
        # Top Controls
        top_hist_frame = tk.Frame(self.history_panel)
        top_hist_frame.grid(row=0, column=0, sticky="ew", padx=5, pady=5)
        tk.Button(top_hist_frame, text="📂 Export SQL", command=self.export_sql, bg="#4CAF50", fg="white", font=("Arial", 9, "bold")).pack(side=tk.LEFT, padx=5)
        tk.Label(top_hist_frame, text="Limit:", bg="#f0f0f0").pack(side=tk.LEFT, padx=2)
        self.limit_var = tk.StringVar(value="100")
        limit_combo = ttk.Combobox(top_hist_frame, textvariable=self.limit_var, values=["100", "1000", "2000", "3000", "4000", "5000"], state="readonly", width=8)
        limit_combo.pack(side=tk.LEFT, padx=2)
        limit_combo.bind("<<ComboboxSelected>>", self.change_history_limit)
        # Header
        hist_header_frame = tk.Frame(self.history_panel)
        hist_header_frame.grid(row=1, column=0, sticky="ew", padx=5, pady=5)
        tk.Label(hist_header_frame, text="Img", width=4, font=("Arial", 8, "bold")).pack(side=tk.LEFT, padx=2)
        tk.Label(hist_header_frame, text="Char", width=5, font=("Arial", 8, "bold")).pack(side=tk.LEFT, padx=2)
        tk.Label(hist_header_frame, text="Prob", width=6, font=("Arial", 8, "bold")).pack(side=tk.LEFT, padx=2)
        tk.Label(hist_header_frame, text="Status", width=8, font=("Arial", 8, "bold")).pack(side=tk.LEFT, padx=2)
        tk.Label(hist_header_frame, text="Actions", width=8, font=("Arial", 8, "bold")).pack(side=tk.RIGHT, padx=2)
        # List Area: a frame embedded in a canvas so it can be scrolled.
        self.history_canvas = tk.Canvas(self.history_panel, bg="white")
        self.history_scrollbar = tk.Scrollbar(self.history_panel, orient="vertical", command=self.history_canvas.yview)
        self.history_scrollable_frame = tk.Frame(self.history_canvas, bg="white")
        self.history_scrollable_frame.bind(
            "<Configure>",
            lambda e: self.history_canvas.configure(scrollregion=self.history_canvas.bbox("all"))
        )
        self.history_canvas.create_window((0, 0), window=self.history_scrollable_frame, anchor="nw")
        self.history_canvas.configure(yscrollcommand=self.history_scrollbar.set)
        self.history_canvas.grid(row=2, column=0, sticky="nsew", padx=5, pady=5)
        self.history_scrollbar.grid(row=2, column=1, sticky="ns")
        # Pagination
        self.pagination_frame = tk.Frame(self.history_panel)
        self.pagination_frame.grid(row=3, column=0, columnspan=2, sticky="ew", padx=5, pady=5)
        self.btn_prev = tk.Button(self.pagination_frame, text="Prev", command=self.prev_page, width=8)
        self.btn_prev.pack(side=tk.LEFT, padx=5)
        self.lbl_page_info = tk.Label(self.pagination_frame, text="Page 1")
        self.lbl_page_info.pack(side=tk.LEFT, expand=True)
        self.btn_next = tk.Button(self.pagination_frame, text="Next", command=self.next_page, width=8)
        self.btn_next.pack(side=tk.RIGHT, padx=5)

    # ------------------------------------------------------------------
    # Model
    # ------------------------------------------------------------------
    def load_model(self):
        """Create the ONNX inference session; report failures in the UI."""
        try:
            model_path = get_model_path()
            self.session = ort.InferenceSession(model_path, providers=['CPUExecutionProvider'])
            self.input_name = self.session.get_inputs()[0].name
            print(f"ONNX Model Loaded Successfully from: {model_path}")
            self.lbl_final_method.config(text="Draw to Start...", fg="#555")
        except Exception as e:
            self.lbl_final_char.config(text="Err", fg="red")
            self.lbl_final_method.config(text="Model not found", fg="red")
            messagebox.showerror("Model Error", f"Failed to load model:\n{e}\n\nPlace cnn_output.onnx next to the exe.")
            print(f"Error: {e}")

    def change_history_limit(self, event):
        """Apply the history limit chosen in the combobox."""
        try:
            self.history_limit = int(self.limit_var.get())
            messagebox.showinfo("Limit Updated", f"History limit set to {self.history_limit} records.")
        except ValueError:
            pass

    # ------------------------------------------------------------------
    # Drawing
    # ------------------------------------------------------------------
    def start_draw(self, event):
        """Remember the stroke start point and remove the placeholder text."""
        self.old_x = event.x
        self.old_y = event.y
        self.canvas.delete("placeholder")

    def draw(self, event):
        """Draw one segment from the previous pointer position to *event*."""
        w = self.pen_width_var.get()
        # Compare against None explicitly: 0 is a valid canvas coordinate and
        # a truthiness test would drop segments starting at the top/left edge.
        if self.old_x is not None and self.old_y is not None:
            is_pen = self.tool_mode.get() == "pen"
            self.canvas.create_line(self.old_x, self.old_y, event.x, event.y,
                                    width=w, fill="black" if is_pen else "white",
                                    capstyle=tk.ROUND, smooth=False)
            stroke_color_val = 0 if is_pen else 255
            self.strokes.append((self.old_x, self.old_y, event.x, event.y, w, stroke_color_val))
        self.old_x = event.x
        self.old_y = event.y

    # ------------------------------------------------------------------
    # Clipboard
    # ------------------------------------------------------------------
    def copy_to_clipboard(self, event):
        """Copy the final predicted character and flash a confirmation."""
        char = self.lbl_final_char.cget("text")
        if char in self._NOT_COPYABLE:
            return
        toplevel = self.root.winfo_toplevel()
        toplevel.clipboard_clear()
        toplevel.clipboard_append(char)
        flash_text = "Copied to clipboard!"
        shown = self.lbl_final_method.cget("text")
        if shown != flash_text:
            # Only remember real status text; on a repeated click the label
            # already shows the flash message and must not be "restored" to it.
            self._method_text_before_copy = shown
        restore_text = getattr(self, "_method_text_before_copy", "")
        self.lbl_final_method.config(text=flash_text, fg="green")

        def restore():
            # Leave the label alone if something else updated it meanwhile.
            if self.lbl_final_method.cget("text") == flash_text:
                self.lbl_final_method.config(text=restore_text, fg="#555")

        self.root.after(2000, restore)

    def copy_specific_char(self, label_widget):
        """Copy the character shown in *label_widget* and flash "Copied!"."""
        char = label_widget.cget("text")
        if char in self._NOT_COPYABLE:
            return
        toplevel = self.root.winfo_toplevel()
        toplevel.clipboard_clear()
        toplevel.clipboard_append(char)
        original_text = char
        original_fg = label_widget.cget("fg")
        label_widget.config(text="Copied!", fg="green")

        def restore():
            # Leave the label alone if it was cleared or updated meanwhile.
            if label_widget.cget("text") == "Copied!":
                label_widget.config(text=original_text, fg=original_fg)

        self.root.after(1500, restore)

    def clear_canvas(self):
        """Erase the drawing and reset every result widget."""
        self.canvas.delete("all")
        self.strokes = []
        self._draw_placeholder()
        self.lbl_final_char.config(text="", fg="blue")
        self.lbl_final_method.config(text="Draw to Start...", fg="#555")
        for lbl in [self.w_lbl_acc, self.w_lbl_unq, self.w_lbl_vote]:
            lbl.config(text="")
        for lbl in self.grid_labels:
            lbl.config(text="...")
        for lbl in self.grid_images:
            lbl.config(image="", bg="black")

    # ------------------------------------------------------------------
    # Prediction
    # ------------------------------------------------------------------
    def get_base_image_array(self):
        """Render the recorded strokes to a 32x32 greyscale image.

        Returns:
            ``(batch, pil_image)``: a float32 array of shape ``(1, 32, 32, 1)``
            scaled to ``[0, 1]`` and the 32x32 PIL image. With no strokes an
            all-zero array and a white image are returned.
        """
        if not self.strokes:
            return np.zeros((1, 32, 32, 1), dtype='float32'), Image.new("L", (32, 32), 255)
        pil_image = Image.new("L", (self.canvas_size, self.canvas_size), 255)
        draw = ImageDraw.Draw(pil_image)
        for (x1, y1, x2, y2, w_stroke, col_val) in self.strokes:
            draw.line([(x1, y1), (x2, y2)], fill=col_val, width=w_stroke)
        pil_image = pil_image.resize((32, 32), Image.Resampling.LANCZOS)
        img_array = np.array(pil_image).astype('float32') / 255.0
        return img_array.reshape(1, 32, 32, 1), pil_image

    def run_tta_analysis(self):
        """Predict the drawn character from ten augmented copies of the image.

        Three candidates are computed from the batch predictions: the class
        with the highest mean probability, the class of the single most
        confident augmentation, and a vote among the top-5 mean classes
        (augmentations whose best probability is below 0.35 do not vote).
        The vote wins when at least five augmentations voted and 60% agree;
        otherwise the mean-probability class is used. The result is stored in
        the history database.
        """
        if not self.strokes:
            return
        if not hasattr(self, 'session'):
            messagebox.showerror("Error", "Model not loaded!")
            return
        self.btn_predict.config(state="disabled")
        self.lbl_final_char.config(text="...", fg="orange")
        self.lbl_final_method.config(text="Processing...", fg="orange")
        self.root.update_idletasks()
        try:
            base_img_batch, base_pil_img = self.get_base_image_array()
            base_img_single = base_img_batch[0]
            aug_list = [
                ("Blur Light", blur_image_light(base_img_single)),
                ("Blur Heavy", blur_image_heavy(base_img_single)),
                ("Blur Mod 1", blur_image_moderate(base_img_single)),
                ("Blur Mod 2", blur_image_moderate(base_img_single)),
                ("Zoom In", zoom_image(base_img_single)),
                ("Zoom Out", zoom_image(base_img_single)),
                ("Erode", erode_image(base_img_single)),
                ("Dilate", dilate_image(base_img_single)),
                ("Bright", shift_brightness(base_img_single)),
                ("Dark", shift_brightness(base_img_single))
            ]
            batch_input = np.array([img for name, img in aug_list])
            preds = self.session.run(None, {self.input_name: batch_input})[0]

            # Candidate 1: best class of the averaged probability vector.
            avg_prob_vector = np.mean(preds, axis=0)
            idx_acc = np.argmax(avg_prob_vector)
            prob_acc = avg_prob_vector[idx_acc]

            # Candidate 2: class of the single most confident augmentation.
            max_probs_per_aug = np.max(preds, axis=1)
            idx_aug_best = np.argmax(max_probs_per_aug)
            prob_unq = max_probs_per_aug[idx_aug_best]
            idx_unq = np.argmax(preds[idx_aug_best])

            # Candidate 3: vote restricted to the top-5 averaged classes.
            top_5_indices = avg_prob_vector.argsort()[-5:][::-1]
            votes = {idx: 0 for idx in top_5_indices}
            valid_votes_count = 0
            for i, (name, img) in enumerate(aug_list):
                pred_vec = preds[i]
                max_prob = np.max(pred_vec)
                img_display = (img.reshape(32, 32) * 255).astype(np.uint8)
                pil_img = Image.fromarray(img_display).resize((120, 120), Image.Resampling.NEAREST)
                photo = ImageTk.PhotoImage(pil_img)
                self.grid_images[i].config(image=photo, bg="white")
                # Keep a reference so Tk does not lose the image.
                self.grid_images[i].image = photo
                if max_prob < 0.35:
                    status_text = f"{name}\nIgnored"
                    text_color = "gray"
                else:
                    valid_votes_count += 1
                    candidate_probs = pred_vec[top_5_indices]
                    best_local_idx = np.argmax(candidate_probs)
                    voted_for_idx = top_5_indices[best_local_idx]
                    votes[voted_for_idx] += 1
                    voted_char = get_char_from_id(voted_for_idx)
                    status_text = f"{name}\n{voted_char}\n{max_prob*100:.0f}%"
                    text_color = "green"
                self.grid_labels[i].config(text=status_text, fg=text_color)

            if valid_votes_count > 0:
                idx_vote = max(votes, key=votes.get)
                count_vote = votes[idx_vote]
                prob_vote = avg_prob_vector[idx_vote]
            else:
                idx_vote = idx_acc
                count_vote = 0
                prob_vote = prob_acc

            if valid_votes_count >= 5 and count_vote >= (valid_votes_count * 0.6):
                final_idx = idx_vote
                final_method = f"Majority ({count_vote}/{valid_votes_count})"
                final_conf = prob_vote
            else:
                final_idx = idx_acc
                final_method = "Accuracy Based"
                final_conf = prob_acc

            final_char = get_char_from_id(final_idx)
            self.lbl_final_char.config(text=final_char, fg="blue")
            self.lbl_final_method.config(text=f"{final_method} | Conf: {final_conf*100:.1f}%", fg="#555")
            self.w_lbl_acc.config(text=get_char_from_id(idx_acc))
            self.w_info_acc.config(text=f"{prob_acc*100:.1f}%")
            self.w_lbl_unq.config(text=get_char_from_id(idx_unq))
            self.w_info_unq.config(text=f"{prob_unq*100:.1f}%")
            self.w_lbl_vote.config(text=get_char_from_id(idx_vote))
            self.w_info_vote.config(text=f"{count_vote}/{valid_votes_count}")
            self.db.add_record(base_pil_img, final_char, float(final_conf), self.history_limit)
            self.load_history()
        except Exception as e:
            self.lbl_final_char.config(text="Err", fg="red")
            self.lbl_final_method.config(text=str(e), fg="red")
            print(f"Error: {e}")
        finally:
            self.btn_predict.config(state="normal")

    # ------------------------------------------------------------------
    # History
    # ------------------------------------------------------------------
    def load_history(self):
        """Rebuild the history list for the current page.

        The current page is clamped into the valid range *before* deciding
        that the history is empty, so deleting the last row of the last page
        falls back to the previous page instead of showing "No history yet.".
        """
        for widget in self.history_scrollable_frame.winfo_children():
            widget.destroy()

        page_size = self.HISTORY_PAGE_SIZE
        if self.current_history_page < 1:
            self.current_history_page = 1
        rows, total_count = self.db.get_page(self.current_history_page, per_page=page_size)
        total_pages = (total_count + page_size - 1) // page_size
        if total_count and self.current_history_page > total_pages:
            self.current_history_page = total_pages
            rows, total_count = self.db.get_page(self.current_history_page, per_page=page_size)
            total_pages = (total_count + page_size - 1) // page_size

        if not rows:
            self.current_history_page = 1
            tk.Label(self.history_scrollable_frame, text="No history yet.", bg="white").pack()
            self.lbl_page_info.config(text="Page 1/1")
            self.btn_prev.config(state="disabled")
            self.btn_next.config(state="disabled")
            return

        self.lbl_page_info.config(text=f"{self.current_history_page}/{total_pages}")
        self.btn_prev.config(state="normal" if self.current_history_page > 1 else "disabled")
        self.btn_next.config(state="normal" if self.current_history_page < total_pages else "disabled")
        for row in rows:
            self._add_history_row(row)

    def _add_history_row(self, row):
        """Append one history record (thumbnail, text, actions) to the list."""
        r_id = row[0]
        img_blob = row[1]
        char = row[2]
        prob = row[3]
        # Rows with fewer than six columns carry no correctness flag.
        is_correct = row[4] if len(row) > 5 else 1
        if is_correct == 0:
            bg_color = "#ffcdd2"
            status_text = "Incorrect"
            status_fg = "red"
        else:
            bg_color = "#e8f5e9"
            status_text = "Correct"
            status_fg = "green"
        row_frame = tk.Frame(self.history_scrollable_frame, bg=bg_color, bd=1, relief=tk.RIDGE)
        row_frame.pack(fill=tk.X, pady=1)
        try:
            pil_img = Image.open(io.BytesIO(img_blob))
            pil_img = pil_img.resize((35, 35), Image.Resampling.LANCZOS)
            photo = ImageTk.PhotoImage(pil_img)
            lbl_img = tk.Label(row_frame, image=photo, bg=bg_color)
            # Keep a reference so Tk does not lose the image.
            lbl_img.image = photo
            lbl_img.pack(side=tk.LEFT, padx=3)
        except Exception:
            # An unreadable thumbnail must not prevent the row from showing.
            tk.Label(row_frame, text="ImgErr", bg=bg_color, width=4).pack(side=tk.LEFT, padx=3)
        tk.Label(row_frame, text=char, width=5, font=("Amharic Unicode", 14), bg=bg_color).pack(side=tk.LEFT, padx=2)
        tk.Label(row_frame, text=f"{prob:.0%}", width=6, bg=bg_color).pack(side=tk.LEFT, padx=2)
        tk.Label(row_frame, text=status_text, width=8, bg=bg_color, fg=status_fg).pack(side=tk.LEFT, padx=2)
        tk.Frame(row_frame, bg=bg_color).pack(side=tk.LEFT, expand=True, fill=tk.X)
        actions_frame = tk.Frame(row_frame, bg=bg_color)
        actions_frame.pack(side=tk.RIGHT, padx=2)
        btn_del = tk.Button(actions_frame, text="🗑", fg="white", bg="#757575", font=("Arial", 7), width=2,
                            command=lambda rid=r_id: self.delete_history_row(rid))
        btn_del.pack(side=tk.LEFT, padx=1)
        x_bg = "#d32f2f" if is_correct == 1 else "#4CAF50"
        btn_x = tk.Button(actions_frame, text="X", fg="white", bg=x_bg, font=("Arial", 8, "bold"), width=2,
                          command=lambda rid=r_id: self.flag_as_wrong(rid))
        btn_x.pack(side=tk.LEFT, padx=1)

    def flag_as_wrong(self, record_id):
        """Ask for confirmation, then toggle the record's correctness flag."""
        if messagebox.askyesno("Toggle Status", "Toggle correctness status for this prediction?"):
            try:
                new_status = self.db.toggle_correctness(record_id)
                if new_status is not None:
                    self.load_history()
            except Exception as e:
                messagebox.showerror("Database Error", f"Failed to update status: {e}")

    def delete_history_row(self, record_id):
        """Ask for confirmation, then delete the record and refresh the list."""
        if messagebox.askyesno("Delete Row", "Delete this history entry permanently?"):
            self.db.delete_record(record_id)
            self.load_history()

    def prev_page(self):
        """Show the previous history page, if there is one."""
        if self.current_history_page > 1:
            self.current_history_page -= 1
            self.load_history()

    def next_page(self):
        """Show the next history page (load_history clamps past the end)."""
        self.current_history_page += 1
        self.load_history()

    def export_sql(self):
        """Ask for a destination file and write the history as an SQL script."""
        timestamp_str = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        default_name = f"amharic_ocr_dataset_{timestamp_str}.sql"
        file_path = filedialog.asksaveasfilename(initialfile=default_name,
                                                 defaultextension=".sql",
                                                 filetypes=[("SQL Files", "*.sql"), ("All Files", "*.*")],
                                                 title="Save Dataset")
        if not file_path:
            return
        try:
            # Build the dump only after the user confirmed a path, and inside
            # the try block so database errors are reported as well.
            sql_data = self.db.export_to_sql()
            with open(file_path, "w", encoding="utf-8") as f:
                f.write(sql_data)
            messagebox.showinfo("Export Successful", f"Dataset saved to {file_path}")
        except Exception as e:
            messagebox.showerror("Export Error", str(e))
def main():
    """Start the application: single-instance check, splash, UI, main loop."""
    # Held in a local for the whole run so the mutex object stays referenced.
    instance_lock = ensure_single_instance()

    # The main window stays hidden while the splash screen is visible.
    window = tk.Tk()
    window.withdraw()
    loading = SplashScreen(window)
    window.update()

    # Building the app creates every widget; model and history come next.
    application = DrawingApp(window)
    startup_steps = (
        ("Loading ONNX model...", application.load_model),
        ("Loading history database...", application.load_history),
    )
    for status_text, step in startup_steps:
        loading.update_message(status_text)
        window.update()
        step()

    # Swap the splash for the real window and hand control to Tk.
    loading.destroy()
    window.deiconify()
    window.title("Amharic OCR - Advanced Data Collector")
    window.geometry("1500x900")
    window.mainloop()


if __name__ == "__main__":
    main()