Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| from torch.utils.data import Dataset, DataLoader | |
| from torchvision import transforms | |
| from PIL import Image, ImageFont, ImageDraw | |
| import numpy as np | |
| import os | |
| import string | |
| import cv2 | |
| from torchvision.transforms.functional import to_pil_image | |
| import matplotlib.pyplot as plt | |
| import math | |
| from datetime import datetime | |
| import re | |
| from termcolor import colored | |
| from pyctcdecode import BeamSearchDecoderCTC, Alphabet | |
| from difflib import SequenceMatcher | |
# --------- Globals --------- #
CHARS = string.ascii_letters + string.digits + string.punctuation
CHAR2IDX = {c: i + 1 for i, c in enumerate(CHARS)}  # Start from 1
CHAR2IDX["<BLANK>"] = 0  # CTC blank
IDX2CHAR = {v: k for k, v in CHAR2IDX.items()}
BLANK_IDX = 0
IMAGE_HEIGHT = 32
IMAGE_WIDTH = 128
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
font_path = None
ocr_model = None
# Vocabulary for the beam-search decoder, one entry per model output index.
# The blank slot is the empty string: pyctcdecode recognises "" (or "<pad>"/"[pad]")
# as the CTC blank. Any other blank label makes it append its own "" entry, so
# the vocabulary would be one entry longer than the model's output classes.
# NOTE(review): pyctcdecode may also map "|" to " " when " " is absent from the
# vocabulary -- confirm against the installed version if "|" predictions matter.
labels = ["" if i == BLANK_IDX else IDX2CHAR.get(i, "") for i in range(len(IDX2CHAR))]
# Beam-search CTC decoder whose vocabulary follows the model's output index order.
alphabet = Alphabet.build_alphabet(labels)
decoder = BeamSearchDecoderCTC(alphabet)

# Working directories used for uploaded fonts, trained models and generated label images.
for _required_dir in ("./fonts", "./models", "./labels"):
    os.makedirs(_required_dir, exist_ok=True)
del _required_dir
# --------- Dataset --------- #
class OCRDataset(Dataset):
    """Synthetic OCR dataset: random strings drawn from CHARS rendered in one font.

    Each item is ``(image, label_encoded, label_length)``:
    ``image`` is a normalized ``(1, IMAGE_HEIGHT, IMAGE_WIDTH)`` tensor,
    ``label_encoded`` holds the CHAR2IDX index of every character, and
    ``label_length`` is the number of characters.

    Args:
        font_path: Path to a font file readable by ``ImageFont.truetype``.
        size: Number of random labels to generate.
        label_length_range: ``(low, high)`` passed to ``np.random.randint``,
            so label lengths fall in ``[low, high)`` (upper bound exclusive).
    """

    def __init__(self, font_path, size=1000, label_length_range=(4, 7)):
        self.font = ImageFont.truetype(font_path, 32)
        self.label_length_range = label_length_range
        self.samples = [
            "".join(np.random.choice(list(CHARS), np.random.randint(*self.label_length_range)))
            for _ in range(size)
        ]
        # Resizing happens in __getitem__ (aspect-preserving, see _fit_to_canvas),
        # so every step below already sees an IMAGE_HEIGHT x IMAGE_WIDTH image.
        self.transform = transforms.Compose([
            transforms.ToTensor(),  # must be first: the steps below operate on tensors
            transforms.Normalize((0.5,), (0.5,)),  # [0, 1] -> [-1, 1]; white becomes 1.0
            transforms.RandomApply([transforms.GaussianBlur(kernel_size=3)], p=0.3),
            # fill=1.0 is white after Normalize. The default fill (0) would paint
            # mid-gray borders that never appear in the white-padded inference input.
            transforms.RandomApply(
                [transforms.RandomAffine(degrees=10, translate=(0.1, 0.1), fill=1.0)], p=0.3
            ),
        ])

    def __len__(self):
        return len(self.samples)

    @staticmethod
    def _fit_to_canvas(img):
        """Scale ``img`` to IMAGE_HEIGHT keeping its aspect ratio, then pad right with white to IMAGE_WIDTH.

        This follows the scale-to-height-then-pad geometry used at inference
        instead of stretching every label across the full width. Renders still
        wider than IMAGE_WIDTH after scaling are squeezed to fit rather than
        cropped, so no character present in the label is cut off.
        """
        w, h = img.size
        new_w = min(IMAGE_WIDTH, max(1, round(w * IMAGE_HEIGHT / h)))
        canvas = Image.new("L", (IMAGE_WIDTH, IMAGE_HEIGHT), 255)
        canvas.paste(img.resize((new_w, IMAGE_HEIGHT)), (0, 0))
        return canvas

    def __getitem__(self, idx):
        label = self.samples[idx]
        pad = 8
        # Height from the font's ascent + descent so descenders are never clipped
        # and every label is rendered on the same vertical scale.
        ascent, descent = self.font.getmetrics()
        img_w = int(self.font.getlength(label) + 2 * pad)
        img_h = int(ascent + descent + 2 * pad)
        img = Image.new("L", (img_w, img_h), 255)
        draw = ImageDraw.Draw(img)
        draw.text((pad, pad), label, font=self.font, fill=0)
        img = self.transform(self._fit_to_canvas(img))
        label_encoded = torch.tensor([CHAR2IDX[c] for c in label], dtype=torch.long)
        label_length = torch.tensor(len(label_encoded), dtype=torch.long)
        return img, label_encoded, label_length

    def render_text(self, text):
        """Render ``text`` centred on a white IMAGE_WIDTH x IMAGE_HEIGHT canvas and return the PIL image."""
        img = Image.new("L", (IMAGE_WIDTH, IMAGE_HEIGHT), color=255)
        draw = ImageDraw.Draw(img)
        left, top, right, bottom = self.font.getbbox(text)
        w, h = right - left, bottom - top
        # getbbox is relative to the drawing origin; subtract its offset so the ink itself is centred.
        origin = ((IMAGE_WIDTH - w) // 2 - left, (IMAGE_HEIGHT - h) // 2 - top)
        draw.text(origin, text, font=self.font, fill=0)
        return img
# --------- Model --------- #
class OCRModel(nn.Module):
    """CRNN-style recognizer: two conv/pool stages, a 2-layer BiLSTM and a per-timestep classifier.

    Input ``x`` is ``(B, 1, H, W)``. The output holds unnormalised class scores
    shaped ``(B, T, num_classes)`` with ``T = W - 2``, because each pool keeps
    width stride 1 with kernel 2. The LSTM input size ``64 * 8`` assumes
    ``H == 32``: two height halvings leave 8 rows of 64 channels.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Pools use kernel (2, 2) with stride (2, 1): height halves, width shrinks by one.
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d((2, 2), (2, 1)),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d((2, 2), (2, 1)),
        )
        feature_size = 64 * 8  # channels * remaining height for 32-pixel-high input
        self.rnn = nn.LSTM(feature_size, 128, num_layers=2, bidirectional=True, batch_first=True)
        self.fc = nn.Linear(2 * 128, num_classes)
        # Push the blank class (index 0) strongly negative so it is discouraged early in training.
        with torch.no_grad():
            self.fc.bias[0] = -5.0

    def forward(self, x):
        batch, _, _, _ = x.size()
        feats = self.conv(x)                                # (B, 64, H/4, T)
        feats = feats.permute(0, 3, 1, 2)                   # (B, T, 64, H/4)
        feats = feats.reshape(batch, feats.size(1), -1)     # (B, T, 64 * H/4)
        seq, _ = self.rnn(feats)                            # (B, T, 256)
        return self.fc(seq)                                 # (B, T, num_classes)
def color_char(c, conf):
    """Wrap ``c`` in an ANSI terminal colour escape chosen by ``conf``, followed by a reset.

    ``conf`` is split into six bins across the first six escapes. Values of 1.0
    and above map to the plain reset code. This module defines ``color_char``
    again further down (an HTML version), and that later definition is the one
    bound after import.
    """
    palette = ("\033[31m", "\033[33m", "\033[32m", "\033[36m", "\033[34m", "\033[35m", "\033[0m")
    top = len(palette) - 1
    level = int(conf * top)
    if level > top:
        level = top
    return f"{palette[level]}{c}\033[0m"
def sanitize_filename(name):
    """Replace every character other than ASCII letters, digits, '_' and '-' with '_'.

    The mapping is not injective: e.g. "a.b" and "a,b" both become "a_b".
    """
    unsafe_chars = re.compile(r'[^a-zA-Z0-9_-]')
    return unsafe_chars.sub('_', name)
def greedy_decode(log_probs, debug=False):
    """Greedy (best-path) CTC decoding of a single sequence.

    Takes the arg-max class at each timestep, collapses consecutive repeats,
    drops BLANK_IDX, and maps the remaining indices to characters via IDX2CHAR.

    Args:
        log_probs: Scores shaped ``(T, B, C)`` (only batch element 0 is decoded)
            or ``(T, C)``.
        debug: If True, print the raw arg-max index sequence.

    Returns:
        The decoded string.
    """
    if log_probs.dim() == 3:
        # Select batch element 0 explicitly; squeezing the batch axis would give
        # nested lists (and a crash on the dict lookup) whenever B > 1.
        log_probs = log_probs[:, 0, :]
    pred = log_probs.argmax(dim=-1).tolist()  # list of ints, one per timestep
    if debug:
        print(f"Decoded indices: {pred}")
    decoded = []
    prev = BLANK_IDX
    for p in pred:
        # Emit only on a change of class, and never for the blank.
        if p != prev and p != BLANK_IDX:
            decoded.append(IDX2CHAR.get(p, ""))
        prev = p
    return "".join(decoded)
# --------- Custom Collate --------- #
def custom_collate_fn(batch):
    """Collate ``(image, label, length)`` samples into the tensors the CTC loss consumes.

    Returns ``(images, targets, target_lengths)``:
    - ``images``: images stacked along a new batch dimension;
    - ``targets``: all label index tensors concatenated into one flat 1-D tensor;
    - ``target_lengths``: the per-sample label lengths as an int64 tensor.

    The per-sample length tensors in ``batch`` are ignored and recomputed from
    the labels.
    """
    images, label_seqs, _ = zip(*batch)
    stacked_images = torch.stack(images, 0)
    lengths = [len(seq) for seq in label_seqs]
    flat_targets = torch.cat(label_seqs)
    return stacked_images, flat_targets, torch.tensor(lengths, dtype=torch.long)
# --------- Model Save/Load --------- #
def list_saved_models():
    """Return the ``.pth`` file names in ./models, sorted by name, for the model dropdown.

    Returns an empty list when ./models is missing or is not a directory.
    Sub-directories are skipped even when their name ends in ``.pth``, because
    they cannot be loaded as weights.
    """
    model_dir = "./models"
    if not os.path.isdir(model_dir):
        return []
    return sorted(
        f for f in os.listdir(model_dir)
        if f.endswith(".pth") and os.path.isfile(os.path.join(model_dir, f))
    )
def save_model(model, path):
    """Write ``model``'s state_dict (parameters and buffers only, not the class) to ``path`` via torch.save."""
    weights = model.state_dict()
    torch.save(weights, path)
def load_model(filename):
    """Load ``./models/<filename>`` into a fresh OCRModel and make it the active ``ocr_model``.

    Args:
        filename: Bare file name chosen in the model dropdown; None/empty when
            nothing is selected.

    Returns:
        A status string for the UI. Failures (no selection, invalid name,
        missing file, unreadable or mismatched weights) are reported in the
        string instead of being raised.
    """
    global ocr_model
    if not filename:
        return "Please select a model file first."
    # Accept only bare names inside ./models; path components could point the
    # loader at arbitrary files elsewhere on disk.
    if filename in (".", "..") or os.path.basename(filename) != filename:
        return f"Invalid model filename '{filename}'."
    model_dir = "./models"
    path = os.path.join(model_dir, filename)
    if not os.path.exists(path):
        return f"Model file '{path}' does not exist."
    try:
        try:
            # weights_only=True refuses arbitrary pickled objects; save_model only writes state_dicts.
            state_dict = torch.load(path, map_location=device, weights_only=True)
        except TypeError:
            # torch releases before 1.13 do not accept weights_only.
            state_dict = torch.load(path, map_location=device)
        model = OCRModel(num_classes=len(CHAR2IDX))
        model.load_state_dict(state_dict)
    except Exception as exc:  # UI callback boundary: surface the error in the status box
        return f"Failed to load model '{path}': {exc}"
    model.to(device)
    model.eval()
    ocr_model = model
    return f"Model '{path}' loaded."
| # --------- Gradio Functions --------- # | |
| def train_model(font_file, epochs=100, learning_rate=0.001): | |
| import time | |
| global font_path, ocr_model | |
| # Ensure directories exist | |
| os.makedirs("./fonts", exist_ok=True) | |
| os.makedirs("./models", exist_ok=True) | |
| # Save uploaded font to ./fonts | |
| font_name = os.path.splitext(os.path.basename(font_file.name))[0] | |
| font_path = f"./fonts/{font_name}.ttf" | |
| with open(font_file.name, "rb") as uploaded: | |
| with open(font_path, "wb") as f: | |
| f.write(uploaded.read()) | |
| # Curriculum learning: label length grows over time | |
| def get_dataset_for_epoch(epoch): | |
| if epoch < epochs // 3: | |
| label_len = (3, 4) | |
| elif epoch < 2 * epochs // 3: | |
| label_len = (4, 6) | |
| else: | |
| label_len = (5, 7) | |
| return OCRDataset(font_path, label_length_range=label_len) | |
| # Visualize one sample | |
| dataset = get_dataset_for_epoch(0) | |
| img, label, _ = dataset[0] | |
| print("Label:", ''.join([IDX2CHAR[i.item()] for i in label])) | |
| plt.imshow(img.permute(1, 2, 0).squeeze(), cmap='gray') | |
| plt.show() | |
| # Model setup | |
| model = OCRModel(num_classes=len(CHAR2IDX)).to(device) | |
| criterion = nn.CTCLoss(blank=BLANK_IDX) | |
| optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) | |
| scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5) | |
| for epoch in range(epochs): | |
| dataset = get_dataset_for_epoch(epoch) | |
| dataloader = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=custom_collate_fn) | |
| model.train() | |
| running_loss = 0.0 | |
| # Warmup learning rate | |
| if epoch < 5: | |
| warmup_lr = learning_rate * 0.2 | |
| for param_group in optimizer.param_groups: | |
| param_group['lr'] = warmup_lr | |
| else: | |
| for param_group in optimizer.param_groups: | |
| param_group['lr'] = learning_rate | |
| for img, targets, target_lengths in dataloader: | |
| img = img.to(device) | |
| targets = targets.to(device) | |
| target_lengths = target_lengths.to(device) | |
| output = model(img) | |
| seq_len = output.size(1) | |
| batch_size = img.size(0) | |
| input_lengths = torch.full((batch_size,), seq_len, dtype=torch.long).to(device) | |
| log_probs = output.log_softmax(2).transpose(0, 1) | |
| loss = criterion(log_probs, targets, input_lengths, target_lengths) | |
| optimizer.zero_grad() | |
| loss.backward() | |
| optimizer.step() | |
| running_loss += loss.item() | |
| avg_loss = running_loss / len(dataloader) | |
| scheduler.step(avg_loss) | |
| print(f"[{epoch + 1}/{epochs}] Loss: {avg_loss:.4f}") | |
| # Save the model to ./models | |
| timestamp = time.strftime("%Y%m%d%H%M%S") | |
| model_name = f"{font_name}_{epochs}ep_lr{learning_rate:.0e}_{timestamp}.pth" | |
| model_path = os.path.join("./models", model_name) | |
| save_model(model, model_path) | |
| ocr_model = model | |
| return f"✅ Training complete! Model saved as '{model_path}'" | |
def preprocess_image(image: Image.Image):
    """Binarize a word-strip image and fit it to the model's IMAGE_HEIGHT x IMAGE_WIDTH input.

    Pipeline:
    1. Convert to grayscale.
    2. Apply an adaptive threshold.
    3. Make the majority colour white.
    4. Scale to IMAGE_HEIGHT, keeping the aspect ratio.
    5. Pad on the right with white up to IMAGE_WIDTH.

    Strips still wider than IMAGE_WIDTH after scaling are squeezed to
    IMAGE_WIDTH instead of cropped, so trailing characters are not silently lost.

    Returns:
        A mode "L" PIL image of size (IMAGE_WIDTH, IMAGE_HEIGHT).
    """
    img_cv = np.array(image.convert("L"))
    # With THRESH_BINARY_INV, pixels darker than their local Gaussian-weighted
    # mean minus 15 become 255 and all others 0.
    img_bin = cv2.adaptiveThreshold(
        img_cv, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 25, 15
    )
    # Flip when needed so the majority class (assumed to be background) is white
    # and strokes are black, like the black-on-white training renders.
    white_px = (img_bin == 255).sum()
    black_px = (img_bin == 0).sum()
    if black_px > white_px:
        img_bin = 255 - img_bin

    h, w = img_bin.shape
    scale = IMAGE_HEIGHT / h
    # Keep at least one column; squeeze over-wide strips rather than cropping them.
    new_w = min(max(1, int(w * scale)), IMAGE_WIDTH)
    resized = cv2.resize(img_bin, (new_w, IMAGE_HEIGHT), interpolation=cv2.INTER_AREA)
    padded = np.pad(resized, ((0, 0), (0, IMAGE_WIDTH - new_w)), constant_values=255)
    return to_pil_image(padded)
# ROYGBIV-style colour ramp, ordered from low to high confidence.
CONFIDENCE_COLORS = [
    "#FF0000",  # Red
    "#FF7F00",  # Orange
    "#FFFF00",  # Yellow
    "#00FF00",  # Green
    "#00BFFF",  # Sky Blue
    "#0000FF",  # Blue
    "#8B00FF",  # Violet
]


def confidence_to_color(conf):
    """Map a confidence in [0.0, 1.0] to a hex colour from CONFIDENCE_COLORS.

    The range is split into ``len(CONFIDENCE_COLORS) - 1`` equal bins, and
    exactly 1.0 maps to the last colour. Out-of-range values are clamped:
    negative confidences map to the first colour instead of wrapping to the end
    of the list through negative indexing.
    """
    last = len(CONFIDENCE_COLORS) - 1
    index = max(0, min(int(conf * last), last))
    return CONFIDENCE_COLORS[index]
def color_char(c, conf):
    """Wrap character ``c`` in a bold 12pt ``<span>`` coloured by ``confidence_to_color(conf)``.

    ``c`` is HTML-escaped because the character set includes punctuation such
    as ``<``, ``>``, ``&`` and quotes. Unescaped, these would be interpreted as
    markup in the HTML output. This definition replaces the ANSI-terminal
    ``color_char`` defined earlier in the module.
    """
    import html  # only this helper needs escaping

    color = confidence_to_color(conf)
    return f'<span style="color:{color}; font-size:12pt; font-weight:bold;">{html.escape(str(c))}</span>'
def predict_text(image: Image.Image, ground_truth: str = None, debug: bool = False):
    """Recognize the text in ``image`` with the loaded model and return an HTML summary.

    Args:
        image: Word-strip image from the UI; None when nothing was uploaded.
        ground_truth: Optional expected text. When given, a similarity score
            (difflib SequenceMatcher ratio) is appended to the output.
        debug: If True, print the decoded text and confidence to stdout.

    Returns:
        An HTML string with the colourized prediction and confidence, or a plain
        message when no model is loaded or no image was provided.
    """
    if ocr_model is None:
        return "Please load or train a model first."
    if image is None:
        return "Please upload an image first."
    processed = preprocess_image(image)
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),  # same normalization as training
    ])
    img_tensor = transform(processed).unsqueeze(0).to(device)  # (1, C, H, W)
    ocr_model.eval()
    with torch.no_grad():
        output = ocr_model(img_tensor)  # (1, T, C)
        log_probs = output.log_softmax(2)[0]  # (T, C)
    # Beam-search decode, then drop any literal "<BLANK>" token text the decoder may emit.
    pred_text = decoder.decode(log_probs.cpu().numpy()).replace("<BLANK>", "")
    # Confidence: mean over timesteps of the top class probability (blank frames included).
    avg_conf = log_probs.exp().max(dim=1).values.mean().item()
    # Every character gets the same sequence-level confidence colour for now.
    pretty_output = "".join(color_char(c, avg_conf) for c in pred_text)
    sim_score = ""
    if ground_truth:
        # SequenceMatcher.ratio() is a Ratcliff/Obershelp similarity, not a Levenshtein score.
        similarity = SequenceMatcher(None, ground_truth, pred_text).ratio()
        sim_score = f"<br><strong>Similarity:</strong> {similarity:.2%}"
    if debug:
        print("Decoded Text:", pred_text)
        print("Average Confidence:", avg_conf)
        if ground_truth:
            print("Ground Truth:", ground_truth)
    return f"<strong>Prediction:</strong> <strong>{pretty_output}</strong><br><strong>Confidence:</strong> {avg_conf:.2%}{sim_score}"
# Label-image rendering constants.
# CHARS is defined once in the Globals section; it is not redefined here so the
# two definitions cannot drift apart.
FONT_SIZE = 32  # point size for rendered label text
PADDING = 8  # blank margin in pixels around rendered label text
LABEL_DIR = "./labels"  # root directory for saved label images
def generate_labels(font_file=None, num_labels: int = 25):
    """Render random labels as images, save each under LABEL_DIR/<sanitized label>/, and return them.

    Args:
        font_file: Path to a .ttf/.otf font, or None / "None" to use Pillow's
            default font.
        num_labels: How many labels to generate. Coerced to a non-negative int,
            since a Gradio Number input may deliver a float or None.

    Returns:
        A list of PIL images (one per label), or a single error image when
        anything fails.

    Side effect:
        Sets the module-level ``font_path`` to the resolved font path (or None).
    """
    global font_path
    try:
        if font_file and font_file != "None":
            font_path = os.path.abspath(font_file)
        else:
            font_path = None
        if font_path is None or not os.path.exists(font_path):
            font = ImageFont.load_default()
        else:
            font = ImageFont.truetype(font_path, FONT_SIZE)
        os.makedirs(LABEL_DIR, exist_ok=True)
        count = max(0, int(num_labels or 0))
        texts = ["".join(np.random.choice(list(CHARS), np.random.randint(4, 7))) for _ in range(count)]
        images = []
        for text in texts:
            left, top, right, bottom = font.getbbox(text)
            img = Image.new("L", (right - left + 2 * PADDING, bottom - top + 2 * PADDING), color=255)
            # getbbox is relative to the draw origin; offset by (left, top) so the ink
            # sits exactly PADDING pixels from each edge instead of being shifted or clipped.
            ImageDraw.Draw(img).text((PADDING - left, PADDING - top), text, font=font, fill=0)
            # Note: distinct labels can sanitize to the same directory name; files stay
            # unique because of the microsecond timestamp.
            label_dir = os.path.join(LABEL_DIR, sanitize_filename(text))
            os.makedirs(label_dir, exist_ok=True)
            timestamp = datetime.now().strftime("%Y%m%d%H%M%S%f")
            img.save(os.path.join(label_dir, f"{timestamp}.png"))
            images.append(img)
        return images
    except Exception as e:  # UI callback boundary: show the error in the gallery
        print("Error in generate_labels:", e)
        error_img = Image.new("RGB", (512, 128), color=(255, 255, 255))
        draw = ImageDraw.Draw(error_img)
        draw.text((10, 50), f"Error: {str(e)}", fill=(255, 0, 0))
        return [error_img]
def list_fonts():
    """Return dropdown choices ``(label, value)`` for the fonts found in ./fonts.

    The list always starts with ``("None", "None")``, meaning "use the default
    font". It is followed by the .ttf/.otf files (case-insensitive extension)
    sorted by name. The same tuple format is returned when ./fonts is missing,
    so the dropdown always receives one consistent shape.
    """
    font_dir = "./fonts"
    choices = [("None", "None")]
    if os.path.isdir(font_dir):
        choices.extend(
            (name, os.path.join(font_dir, name))
            for name in sorted(os.listdir(font_dir))
            if name.lower().endswith((".ttf", ".otf"))
        )
    return choices
# Styles for the label gallery (fixed-height thumbnails in a scrollable grid)
# and for the recognized-text output.
custom_css = (
    "\n"
    "#label-gallery .gallery-item img {\n"
    "height: 43px; /* 32pt ≈ 43px */\n"
    "width: auto;\n"
    "object-fit: contain;\n"
    "padding: 4px;\n"
    "}\n"
    "#label-gallery {\n"
    "flex-grow: 1;\n"
    "overflow-y: auto;\n"
    "height: 100%;\n"
    "}\n"
    "#output-text {\n"
    "font-size: 12pt;\n"
    "}\n"
)
# --------- Gradio UI --------- #
# Components are laid out tab by tab; events are wired once every tab exists,
# so that finishing training can refresh dropdowns that live in other tabs.
with gr.Blocks(css=custom_css) as demo:
    with gr.Tab("【Train OCR Model】"):
        font_file = gr.File(label="Upload .ttf or .otf font", file_types=[".ttf", ".otf"])
        epochs_input = gr.Slider(minimum=1, maximum=4096, value=256, step=1, label="Epochs")
        # The default matches train_model's own default (1e-3, a typical Adam rate);
        # the lower floor allows smaller rates than 0.001 to be selected.
        lr_input = gr.Slider(minimum=0.0001, maximum=0.1, value=0.001, step=0.0001, label="Learning Rate")
        train_button = gr.Button("Train OCR Model")
        train_status = gr.Textbox(label="Status")

    with gr.Tab("【Generate Labels】"):
        font_file_labels = gr.Dropdown(
            choices=list_fonts(),
            label="Optional font for label image",
            interactive=True,
        )
        num_labels = gr.Number(value=20, label="Number of labels to generate", precision=0, interactive=True)
        gen_button = gr.Button("Generate Label Grid")
        label_gallery = gr.Gallery(
            label="Generated Labels",
            columns=16,  # 16 tiles per row
            object_fit="contain",  # maintain aspect ratio
            height="100%",  # allow full app height
            elem_id="label-gallery",  # for CSS targeting
        )

    with gr.Tab("【Recognize Text】"):
        model_list = gr.Dropdown(choices=list_saved_models(), label="Select OCR Model")
        refresh_btn = gr.Button("🔄 Refresh Models")
        load_model_btn = gr.Button("Load Model")
        image_input = gr.Image(type="pil", label="Upload word strip")
        predict_btn = gr.Button("Predict")
        output_text = gr.HTML(label="Recognized Text", elem_id="output-text")
        model_status = gr.Textbox(label="Model Load Status")

    # Training saves a new font and a new model; refresh both pickers afterwards
    # so they appear without restarting the app.
    train_button.click(
        fn=train_model, inputs=[font_file, epochs_input, lr_input], outputs=train_status
    ).then(
        fn=lambda: (gr.update(choices=list_fonts()), gr.update(choices=list_saved_models())),
        outputs=[font_file_labels, model_list],
    )
    gen_button.click(fn=generate_labels, inputs=[font_file_labels, num_labels], outputs=label_gallery)
    refresh_btn.click(fn=lambda: gr.update(choices=list_saved_models()), outputs=model_list)
    # Load the model only on explicit button click, not on dropdown change.
    load_model_btn.click(fn=load_model, inputs=model_list, outputs=model_status)
    predict_btn.click(fn=predict_text, inputs=image_input, outputs=output_text)

if __name__ == "__main__":
    demo.launch()