{ "cells": [ { "cell_type": "code", "execution_count": 12, "id": "8e64687d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using device: cpu\n" ] } ], "source": [ "import torch\n", "import torch.nn as nn\n", "import numpy as np\n", "from PIL import Image\n", "import matplotlib.pyplot as plt\n", "import os\n", "import sys\n", "from torchvision import transforms\n", "import torch.nn.functional as F\n", "import pandas as pd\n", "from collections import OrderedDict\n", "import cv2\n", "\n", "sys.path.append('./')\n", "from modules.model import Model\n", "from modules.utils import CTCLabelConverter, AttnLabelConverter\n", "\n", "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", "print(f'Using device: {device}')" ] }, { "cell_type": "code", "execution_count": 13, "id": "7399f564", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Total characters: 132\n", "Numbers: ०१२३४५६७८९0123456789\n", "Language characters sample: अआइईउऊऋएऐओऔअंअःकखगघङचछजझञटठडढणतथदधनपफबभमयरलवशषसहक्षत्रज्ञािीुूृेैोौंःँॅॉ...\n", "Character Set Analysis:\n", "Total unique characters: 132\n", "Numbers: 20 characters\n", "Symbols: 40 characters\n", "Language characters: 72 characters\n", "०१२३४५६७८९0123456789!\"#$%&'()*+,-./:;<=>?@[\\]^_`{}~।॥—‘’“”… अआइईउऊऋएऐओऔअंअःकखगघङचछजझञटठडढणतथदधनपफबभमयरलवशषसहक्षत्रज्ञािीुूृेैोौंःँॅॉ\n" ] } ], "source": [ "class NepaliTextRecognitionConfig:\n", " def __init__(self):\n", " self.number = '०१२३४५६७८९0123456789'\n", " self.symbol = \"!\\\"#$%&'()*+,-./:;<=>?@[\\\\]^_`{}~।॥—‘’“”… \"\n", " self.lang_char = 'अआइईउऊऋएऐओऔअंअःकखगघङचछजझञटठडढणतथदधनपफबभमयरलवशषसहक्षत्रज्ञािीुूृेैोौंःँॅॉ'\n", " \n", " self.character = self.number + self.symbol + self.lang_char\n", " print(f\"Total characters: {len(self.character)}\")\n", " print(f\"Numbers: {self.number}\")\n", " print(f\"Language characters sample: {self.lang_char}...\")\n", " \n", " self.Transformation = 'None'\n", " self.FeatureExtraction = 'ResNet'\n", " self.SequenceModeling = 'BiLSTM'\n", " self.Prediction = 'Attn'\n", " \n", " self.imgH = 80 \n", " self.imgW = 1220\n", " self.input_channel = 3\n", " self.output_channel = 256\n", " self.hidden_size = 256\n", " self.num_fiducial = 20\n", " \n", " self.batch_max_length = 200\n", " self.sensitive = True\n", " self.PAD = False\n", " self.rgb = True\n", " self.contrast_adjust = False\n", " \n", " self.decode = 'greedy' #beam search\n", " self.num_class = len(self.character)\n", "\n", "def print_character_set(config):\n", " print(\"Character Set Analysis:\")\n", " print(f\"Total unique characters: {len(config.character)}\")\n", " print(f\"Numbers: {len(config.number)} characters\")\n", " print(f\"Symbols: {len(config.symbol)} characters\") \n", " print(f\"Language characters: {len(config.lang_char)} characters\")\n", "\n", "config = NepaliTextRecognitionConfig()\n", "print_character_set(config)\n", "print(config.character)" ] }, { "cell_type": "code", "execution_count": 14, "id": "fe642ef1", "metadata": {}, "outputs": [], "source": [ "def load_crnn_model(model_path, config):\n", " if config.Prediction == \"CTC\":\n", " converter = CTCLabelConverter(config.character)\n", " elif config.Prediction == \"Attn\":\n", " converter = AttnLabelConverter(config.character)\n", " config.num_class = len(converter.character)\n", "\n", " # print(f\"Model Configuration:\")\n", " # print(f\" Input size: {config.imgH}x{config.imgW}\")\n", " # print(f\" Input channels: {config.input_channel}\")\n", " # print(f\" Number of classes: 
"    # print(f\" Number of classes: {config.num_class}\")\n", "    # print(f\" Transformation: {config.Transformation}\")\n", "    # print(f\" Feature Extraction: {config.FeatureExtraction}\")\n", "    # print(f\" Sequence Modeling: {config.SequenceModeling}\")\n", "    # print(f\" Prediction: {config.Prediction}\")\n", "\n", "    model = Model(config)\n", "\n", "    if os.path.exists(model_path):\n", "        #print(f\"Loading model from: {model_path}\")\n", "        state_dict = torch.load(model_path, map_location=device)\n", "\n", "        if isinstance(state_dict, dict) and 'state_dict' in state_dict:\n", "            state_dict = state_dict['state_dict']\n", "\n", "        if all(key.startswith('module.') for key in state_dict.keys()):\n", "            # Checkpoint was saved from a DataParallel model; strip the 'module.' prefix.\n", "            new_state_dict = OrderedDict()\n", "            for k, v in state_dict.items():\n", "                name = k[7:]\n", "                new_state_dict[name] = v\n", "            state_dict = new_state_dict\n", "        try:\n", "            model.load_state_dict(state_dict)\n", "            #print(\"✓ Model loaded successfully!\")\n", "        except Exception as e:\n", "            print(f\"Error loading state dict: {e}\")\n", "            print(\"Trying to load with strict=False...\")\n", "            model.load_state_dict(state_dict, strict=False)\n", "            print(\"✓ Model loaded with strict=False\")\n", "\n", "    else:\n", "        raise FileNotFoundError(f\"Model file not found: {model_path}\")\n", "\n", "    model = model.to(device)\n", "    model.eval()\n", "\n", "    return model, converter\n", "\n", "def count_parameters(model):\n", "    return sum(p.numel() for p in model.parameters() if p.requires_grad)\n", "\n", "def center_and_resize_image(img, target_size=(1220, 80)):\n", "    \"\"\"\n", "    Shrink the image to fit inside target_size while maintaining aspect ratio,\n", "    then center it on a black background of target_size.\n", "    \"\"\"\n", "    if isinstance(img, str):\n", "        img = Image.open(img)\n", "    target_w, target_h = target_size\n", "    img_w, img_h = img.size\n", "\n", "    if img_w > target_w or img_h > target_h:\n", "        img.thumbnail((target_w, target_h), Image.LANCZOS)\n", "    new_img = Image.new(\"RGB\", (target_w, target_h), color=\"black\")\n", "    paste_x = (target_w - img.width) // 2\n", "    paste_y = (target_h - img.height) // 2\n", "    new_img.paste(img, (paste_x, paste_y))\n", "\n", "    return new_img\n", "\n", "def preprocess_crnn_image(image_path, config, return_tensor=True):\n", "    \"\"\"\n", "    Complete preprocessing pipeline for the text recognition model.\n", "\n", "    Args:\n", "        image_path: Path to image or PIL Image\n", "        config: Model configuration object\n", "        return_tensor: Whether to return a PyTorch tensor or a PIL Image\n", "\n", "    Returns:\n", "        Preprocessed image tensor or PIL Image\n", "    \"\"\"\n", "    input_channels = getattr(config, 'input_channel', 1)\n", "\n", "    processed_pil = center_and_resize_image(image_path, (config.imgW, config.imgH))\n", "    if input_channels == 1:\n", "        # center_and_resize_image always returns RGB; collapse to one channel here.\n", "        processed_pil = processed_pil.convert('L')\n", "\n", "    if not return_tensor:\n", "        return processed_pil\n", "\n", "    image_np = np.array(processed_pil)\n", "\n", "    image_np = image_np.astype(np.float32) / 255.0\n", "\n", "    if getattr(config, 'contrast_adjust', False):\n", "        # NOTE: this standardization only affects the single-channel path below;\n", "        # the RGB branch re-reads processed_pil through torchvision transforms.\n", "        image_np = (image_np - np.mean(image_np)) / (np.std(image_np) + 1e-8)\n", "\n", "    if input_channels == 1:\n", "        image_tensor = torch.from_numpy(image_np).unsqueeze(0)  # (1, H, W)\n", "        image_tensor = (image_tensor - 0.5) / 0.5\n", "    else:\n", "        transform = transforms.Compose([\n", "            transforms.ToTensor(),\n", "            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n", "        ])\n", "        image_tensor = transform(processed_pil)\n", "\n", "    # Add batch dimension: (1, C, H, W)\n", "    image_tensor = image_tensor.unsqueeze(0)\n", "\n", "    return image_tensor\n", "\n", 
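"# A minimal sanity-check sketch for the pipeline above (the crop path is a\n", "# hypothetical placeholder; point it at a real file before running):\n", "#   t = preprocess_crnn_image('crops/sample.png', config)\n", "#   assert t.shape == (1, config.input_channel, config.imgH, config.imgW)\n", "\n", 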
"def visualize_crnn_preprocessing(image_path, config):\n", "\n", " \"\"\"Visualize the original and preprocessed image for text\"\"\"\n", " \n", " original_image = Image.open(image_path)\n", " if config.rgb:\n", " original_image = original_image.convert('RGB')\n", " else:\n", " original_image = original_image.convert('L')\n", " \n", " processed_tensor = preprocess_crnn_image(image_path, config)\n", " processed_image = processed_tensor.squeeze(0).permute(1, 2, 0).numpy()\n", " \n", " if config.rgb:\n", " # Denormalize for visualization\n", " mean = np.array([0.485, 0.456, 0.406])\n", " std = np.array([0.229, 0.224, 0.225])\n", " processed_image = processed_image * std + mean\n", " processed_image = np.clip(processed_image, 0, 1)\n", " else:\n", " processed_image = (processed_image * 0.5) + 0.5\n", " processed_image = np.clip(processed_image, 0, 1)\n", " \n", " fig, axes = plt.subplots(1, 2, figsize=(15, 6))\n", " \n", " # Original image\n", " axes[0].imshow(original_image, cmap='gray' if not config.rgb else None)\n", " axes[0].set_title(f'Original Image\\nSize: {original_image.size}')\n", " axes[0].axis('off')\n", " \n", " # Processed image\n", " if config.rgb:\n", " axes[1].imshow(processed_image)\n", " else:\n", " axes[1].imshow(processed_image[:, :, 0], cmap='gray')\n", " axes[1].set_title(f'Preprocessed Image\\nSize: {config.imgW}x{config.imgH}')\n", " axes[1].axis('off')\n", " \n", " plt.tight_layout()\n", " plt.show()\n", " \n", " #print(f\"Original image size: {original_image.size}\")\n", " #print(f\"Processed tensor shape: {processed_tensor.shape}\")" ] }, { "cell_type": "code", "execution_count": null, "id": "7d85fa5b", "metadata": {}, "outputs": [], "source": [ "def inference_crnn_single_image(model, converter, image_path, config):\n", " \n", " image_tensor = preprocess_crnn_image(image_path, config)\n", " image_tensor = image_tensor.to(device)\n", " batch_size = image_tensor.size(0)\n", " text_for_pred = torch.LongTensor(batch_size, config.batch_max_length).fill_(0).to(device)\n", " with torch.no_grad():\n", " #preds = model(image_tensor, text_for_pred)\n", " preds = model(image_tensor, text_for_pred, is_train=False)\n", " # Greedy decoding for CTC\n", " if config.Prediction == \"CTC\":\n", " preds_size = torch.IntTensor([preds.size(1)] * batch_size)\n", " _, preds_index = preds.max(2)\n", " preds_index = preds_index.view(-1)\n", " preds_str = converter.decode_greedy(preds_index.data, preds_size.data)\n", " return preds_str\n", " elif config.Prediction == \"Attn\":\n", " # remove last timestep to align with loss logic\n", " preds = preds[:, :config.batch_max_length - 1, :]\n", "\n", " # greedy index selection\n", " _, preds_index = preds.max(2)\n", "\n", " # decode raw strings\n", " preds_str = converter.decode(preds_index, torch.IntTensor([config.batch_max_length] * batch_size))\n", "\n", " # softmax + max prob (same as your snippet)\n", " preds_prob = F.softmax(preds, dim=2)\n", " preds_max_prob, _ = preds_prob.max(dim=2)\n", "\n", " final_result = []\n", " for pred, pred_max_prob in zip(preds_str, preds_max_prob):\n", " eos_pos = pred.find('[s]')\n", " if eos_pos != -1:\n", " pred = pred[:eos_pos]\n", " pred_max_prob = pred_max_prob[:eos_pos]\n", "\n", " final_result.append(pred)\n", "\n", " return final_result\n", "\n", " #preds_str = converter.decode_beamsearch(preds, beamWidth=4)\n", " #preds_prob = F.softmax(preds, dim=2)\n", " #preds_max_prob, _ = preds_prob.max(dim=2)\n", " \n", "\n", " return None\n", " \n", "all_entries = os.listdir(\"crops\")\n", "files_only = 
[os.path.join(\"crops\", entry) for entry in all_entries if os.path.isfile(os.path.join(\"crops\", entry))]\n", "\n", "print(config.Prediction)\n", "for index, string in enumerate(files_only):\n", " if index == 1:\n", " break\n", " visualize_crnn_preprocessing(string, config)\n", " model, converter = load_crnn_model(\"models/iter_60000.pth\",config)\n", " print(inference_crnn_single_image(model,converter, string,config))" ] }, { "cell_type": "code", "execution_count": 11, "id": "35b4c6b9", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "No Transformation module specified\n", "(16, 3, 80, 1220)\n", "Model exported to ONNX format: models/ResNetBiLSTMCTCv1Batch16.onnx\n", "✓ ONNX model validation passed!\n", "ONNX Model Information:\n", " Input shape: (16, 3, 80, 1220)\n", " Output classes: 133\n", " Opset version: 11\n" ] }, { "data": { "text/plain": [ "'models/ResNetBiLSTMCTCv1Batch16.onnx'" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import torch.onnx\n", "import onnx\n", "from onnx import checker\n", "#CTC\n", "def export_crnn_to_onnx(model, converter, config, output_path, input_shape=None,batch_size=1, opset_version=11):\n", " \"\"\"\n", " Export CRNN text recognition model to ONNX format.\n", " \n", " Args:\n", " model: The PyTorch CRNN model\n", " converter: The label converter (CTCLabelConverter or AttnLabelConverter)\n", " config: Model configuration object\n", " output_path (str): Path where to save the ONNX model\n", " input_shape (tuple): Input tensor shape (batch_size, channels, height, width)\n", " Default: (1, 3, 80, 1220) based on config\n", " opset_version (int): ONNX opset version (default: 11)\n", " \n", " Returns:\n", " str: Path to the exported ONNX model\n", " \"\"\"\n", " model.eval()\n", " \n", " if input_shape is None:\n", " input_shape = (batch_size, config.input_channel, config.imgH, config.imgW)\n", " print(input_shape)\n", " #dummy_input = torch.randn(input_shape, device=device)\n", " dummy_input = torch.randn(batch_size, config.input_channel, config.imgH, config.imgW).to(device)\n", " dummy_text = torch.LongTensor(batch_size, config.batch_max_length).fill_(0).to(device) #CTC\n", " #dummy_text = torch.LongTensor(1, config.batch_max_length +1).fill_(0).to(device)# Attn\n", "\n", " with torch.no_grad():\n", " torch.onnx.export(\n", " model,\n", " (dummy_input, dummy_text),\n", " output_path,\n", " export_params=True,\n", " opset_version=opset_version,\n", " #do_constant_folding=True,\n", " input_names=['input', 'text'],\n", " output_names=['output'],\n", " # dynamic_axes={\n", " # 'input': {0: 'batch_size'},\n", " # 'output': {0: 'batch_size'}\n", " # },\n", " verbose=False\n", " )\n", " \n", " print(f\"Model exported to ONNX format: {output_path}\")\n", " \n", " try:\n", " onnx_model = onnx.load(output_path)\n", " checker.check_model(onnx_model)\n", " print(\"✓ ONNX model validation passed!\")\n", " \n", " print(f\"ONNX Model Information:\")\n", " print(f\" Input shape: {input_shape}\")\n", " print(f\" Output classes: {len(converter.character)}\")\n", " print(f\" Opset version: {opset_version}\")\n", " \n", " return output_path\n", " \n", " except Exception as e:\n", " print(f\"❌ ONNX model validation failed: {e}\")\n", " return None\n", "\n", "model, converter = load_crnn_model(\"models/iter_60000.pth\",config)\n", "\n", "export_crnn_to_onnx(model, converter,config,\"models/ResNetBiLSTMCTCv1Batch16.onnx\", batch_size=16)" ] }, { "cell_type": "code", "execution_count": 16, "id": 
"dd4d6dc9", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "No Transformation module specified\n", "(32, 3, 80, 1220)\n", "False\n", "Model exported to ONNX format: models/ResNetBiLSTMAttnv1Batch32.onnx\n", "✓ ONNX model validation passed!\n", "ONNX Model Information:\n", " Input shape: (32, 3, 80, 1220)\n", " Output classes: 134\n", " Opset version: 11\n" ] }, { "data": { "text/plain": [ "'models/ResNetBiLSTMAttnv1Batch32.onnx'" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import torch.onnx\n", "import onnx\n", "from onnx import checker\n", "#Attn\n", "def export_crnn_to_onnx(model, converter, config, output_path, input_shape=None, opset_version=11,batch_size=1):\n", " \"\"\"\n", " Export CRNN text recognition model to ONNX format.\n", " \n", " Args:\n", " model: The PyTorch CRNN model\n", " converter: The label converter (CTCLabelConverter or AttnLabelConverter)\n", " config: Model configuration object\n", " output_path (str): Path where to save the ONNX model\n", " input_shape (tuple): Input tensor shape (batch_size, channels, height, width)\n", " Default: (1, 3, 80, 1220) based on config\n", " opset_version (int): ONNX opset version (default: 11)\n", " \n", " Returns:\n", " str: Path to the exported ONNX model\n", " \"\"\"\n", " model.eval()\n", " \n", " if input_shape is None:\n", " input_shape = (batch_size, config.input_channel, config.imgH, config.imgW)\n", " print(input_shape)\n", " #dummy_input = torch.randn(input_shape, device=device)\n", " dummy_input = torch.randn(batch_size, config.input_channel, config.imgH, config.imgW).to(device)\n", " #dummy_text = torch.LongTensor(1, config.batch_max_length).fill_(0).to(device) #CTC\n", " dummy_text = torch.LongTensor(batch_size, config.batch_max_length +1).fill_(0).to(device)# Attn\n", "\n", " with torch.no_grad():\n", " torch.onnx.export(\n", " model,\n", " (dummy_input),\n", " output_path,\n", " export_params=True,\n", " opset_version=opset_version,\n", " #do_constant_folding=True,\n", " input_names=['input'],\n", " output_names=['output'],\n", " # dynamic_axes={\n", " # 'input': {0: 'batch_size'},\n", " # 'output': {0: 'batch_size', 1: 'sequence_length'}\n", " #},\n", " verbose=False\n", " )\n", " \n", " print(f\"Model exported to ONNX format: {output_path}\")\n", " \n", " try:\n", " onnx_model = onnx.load(output_path)\n", " checker.check_model(onnx_model)\n", " print(\"✓ ONNX model validation passed!\")\n", " \n", " print(f\"ONNX Model Information:\")\n", " print(f\" Input shape: {input_shape}\")\n", " print(f\" Output classes: {len(converter.character)}\")\n", " print(f\" Opset version: {opset_version}\")\n", " \n", " return output_path\n", " \n", " except Exception as e:\n", " print(f\"❌ ONNX model validation failed: {e}\")\n", " return None\n", "\n", "model, converter = load_crnn_model(\"models/iter_70000.pth\",config)\n", "\n", "export_crnn_to_onnx(model, converter,config,\"models/ResNetBiLSTMAttnv1Batch32.onnx\",batch_size=32)" ] }, { "cell_type": "code", "execution_count": null, "id": "a961d8f0", "metadata": {}, "outputs": [], "source": [ "\n", "import onnxruntime as ort\n", "import torch\n", "import torch.nn.functional as F\n", "import numpy as np\n", "import os\n", "\n", "def load_onnx_model(onnx_path):\n", " \"\"\"Load ONNX model as a runtime session.\"\"\"\n", " session = ort.InferenceSession(onnx_path)\n", " return session\n", "\n", "def inference_crnn_single_image_onnx(session, converter, image_path, config):\n", " # Preprocess 
image\n", " image_tensor = preprocess_crnn_image(image_path, config) # [1, C, H, W]\n", " image_np = image_tensor.cpu().numpy().astype(np.float32)\n", "\n", " batch_size = image_np.shape[0]\n", "\n", " # Run ONNX inference\n", " if config.Prediction == \"Attn\":\n", " text_input = np.zeros((batch_size, config.batch_max_length+1), dtype=np.int64)\n", " outputs = session.run(None, {\"input\": image_np })\n", " preds = outputs[0] # [batch, seq_len, num_classes]\n", " else:\n", " outputs = session.run(None, {\"input\": image_np })\n", " preds = outputs[0] # [batch, seq_len, num_classes]\n", "\n", " if config.Prediction == \"CTC\":\n", " preds_size = np.array([preds.shape[1]] * batch_size, dtype=np.int32)\n", " preds_index = preds.argmax(axis=2).flatten()\n", " preds_str = converter.decode_greedy(preds_index, preds_size)\n", " return preds_str\n", "\n", " elif config.Prediction == \"Attn\":\n", " preds = preds[:, :config.batch_max_length - 1, :]\n", " preds_index = preds.argmax(axis=2)\n", " preds_str = converter.decode(torch.from_numpy(preds_index), \n", " torch.IntTensor([config.batch_max_length] * batch_size))\n", "\n", " preds_prob = F.softmax(torch.from_numpy(preds), dim=2)\n", " preds_max_prob, _ = preds_prob.max(dim=2)\n", "\n", " final_result = []\n", " for pred, pred_max_prob in zip(preds_str, preds_max_prob):\n", " eos_pos = pred.find('[s]')\n", " if eos_pos != -1:\n", " pred = pred[:eos_pos]\n", " pred_max_prob = pred_max_prob[:eos_pos]\n", " final_result.append(pred)\n", " return final_result\n", "\n", " return None\n", "\n", "# =========================\n", "# Example usage\n", "# =========================\n", "model, converter = load_crnn_model(\"models/iter_60000.pth\",config)\n", "onnx_session = load_onnx_model(\"models/ResNetBiLSTMCTCv1.onnx\")\n", "\n", "all_entries = os.listdir(\"crops\")\n", "files_only = [os.path.join(\"crops\", entry) for entry in all_entries if os.path.isfile(os.path.join(\"crops\", entry))]\n", "\n", "for index, string in enumerate(files_only):\n", " if index == 5:\n", " break\n", " visualize_crnn_preprocessing(string, config)\n", " print(inference_crnn_single_image_onnx(onnx_session, converter, string, config))\n" ] }, { "cell_type": "code", "execution_count": null, "id": "86ac9a27", "metadata": {}, "outputs": [], "source": [ "\n", "import onnxruntime as ort\n", "import torch\n", "import torch.nn.functional as F\n", "import numpy as np\n", "import os\n", "\n", "def load_onnx_model(onnx_path):\n", " \"\"\"Load ONNX model as a runtime session.\"\"\"\n", " session = ort.InferenceSession(onnx_path)\n", " return session\n", "\n", "def inference_crnn_single_image_onnx(session, converter, image_path, config):\n", " # Preprocess image\n", " image_tensor = preprocess_crnn_image(image_path, config) # [1, C, H, W]\n", " image_np = image_tensor.cpu().numpy().astype(np.float32)\n", "\n", " batch_size = image_np.shape[0]\n", "\n", " # Run ONNX inference\n", " if config.Prediction == \"Attn\":\n", " text_input = np.zeros((batch_size, config.batch_max_length+1), dtype=np.int64)\n", " outputs = session.run(None, {\"input\": image_np})\n", " preds = outputs[0] # [batch, seq_len, num_classes]\n", " else:\n", " outputs = session.run(None, {\"input\": image_np })\n", " preds = outputs[0] # [batch, seq_len, num_classes]\n", "\n", " if config.Prediction == \"CTC\":\n", " preds_size = np.array([preds.shape[1]] * batch_size, dtype=np.int32)\n", " preds_index = preds.argmax(axis=2).flatten()\n", " preds_str = converter.decode_greedy(preds_index, preds_size)\n", " return 
preds_str\n", "\n", " elif config.Prediction == \"Attn\":\n", " preds = preds[:, :config.batch_max_length - 1, :]\n", " preds_index = preds.argmax(axis=2)\n", " preds_str = converter.decode(torch.from_numpy(preds_index), \n", " torch.IntTensor([config.batch_max_length] * batch_size))\n", "\n", " preds_prob = F.softmax(torch.from_numpy(preds), dim=2)\n", " preds_max_prob, _ = preds_prob.max(dim=2)\n", "\n", " final_result = []\n", " for pred, pred_max_prob in zip(preds_str, preds_max_prob):\n", " eos_pos = pred.find('[s]')\n", " if eos_pos != -1:\n", " pred = pred[:eos_pos]\n", " pred_max_prob = pred_max_prob[:eos_pos]\n", " final_result.append(pred)\n", " return final_result\n", "\n", " return None\n", "\n", "# =========================\n", "# Example usage\n", "# =========================\n", "model, converter = load_crnn_model(\"models/iter_70000.pth\",config)\n", "onnx_session = load_onnx_model(\"models/ResNetBiLSTMAttnv1Batch32.onnx\")\n", "\n", "all_entries = os.listdir(\"crops\")\n", "files_only = [os.path.join(\"crops\", entry) for entry in all_entries if os.path.isfile(os.path.join(\"crops\", entry))]\n", "\n", "for index, string in enumerate(files_only):\n", " if index == 5:\n", " break\n", " visualize_crnn_preprocessing(string, config)\n", " print(inference_crnn_single_image_onnx(onnx_session, converter, string, config))\n" ] }, { "cell_type": "code", "execution_count": null, "id": "24cc9f37", "metadata": {}, "outputs": [], "source": [ "\n", "import onnxruntime as ort\n", "import torch\n", "import torch.nn.functional as F\n", "import numpy as np\n", "import os\n", "#Batch\n", "def load_onnx_model(onnx_path):\n", " \"\"\"Load ONNX model as a runtime session.\"\"\"\n", " session = ort.InferenceSession(onnx_path)\n", " return session\n", "\n", "def preprocess_crnn_image_batch(image_paths, config):\n", " matrix = [preprocess_crnn_image(image_path,config)[0] for image_path in image_paths]\n", " image_tensors = torch.stack(matrix, dim=0)\n", " return image_tensors\n", "\n", "def inference_crnn_batch_image_onnx(session, converter, image_paths, config):\n", " # Preprocess image\n", " image_tensors = preprocess_crnn_image_batch(image_paths, config)\n", "\n", " # batch_size = image_tensors.shape[0]\n", "\n", " # if batch_size < 16:\n", " # pad = 16 - batch_size\n", " # padding = torch.empty(\n", " # pad,\n", " # config.input_channel,\n", " # config.imgH,\n", " # config.imgW,\n", " # device=image_tensors.device,\n", " # dtype=image_tensors.dtype\n", " # )\n", " # image_tensors = torch.cat([image_tensors, padding], dim=0)\n", "\n", " print(image_tensors.shape)\n", " image_np = image_tensors.cpu().numpy().astype(np.float32)\n", "\n", " batch_size = image_np.shape[0]\n", "\n", " # Run ONNX inference\n", " if config.Prediction == \"Attn\":\n", " text_input = np.zeros((batch_size, config.batch_max_length+1), dtype=np.int64)\n", " outputs = session.run(None, {\"input\": image_np })\n", " preds = outputs[0] # [batch, seq_len, num_classes]\n", " else:\n", " outputs = session.run(None, {\"input\": image_np })\n", " preds = outputs[0] # [batch, seq_len, num_classes]\n", "\n", " if config.Prediction == \"CTC\":\n", " preds_size = np.array([preds.shape[1]] * batch_size, dtype=np.int32)\n", " preds_index = preds.argmax(axis=2).flatten()\n", " preds_str = converter.decode_greedy(preds_index, preds_size)\n", " return preds_str\n", "\n", " elif config.Prediction == \"Attn\":\n", " preds = preds[:, :config.batch_max_length - 1, :]\n", " preds_index = preds.argmax(axis=2)\n", " preds_str = 
"def inference_crnn_batch_image_onnx(session, converter, image_paths, config):\n", "    # Preprocess images into one batch tensor\n", "    image_tensors = preprocess_crnn_image_batch(image_paths, config)\n", "\n", "    # Optional padding to the exported batch size, kept for reference:\n", "    # batch_size = image_tensors.shape[0]\n", "    # if batch_size < 16:\n", "    #     pad = 16 - batch_size\n", "    #     padding = torch.empty(\n", "    #         pad,\n", "    #         config.input_channel,\n", "    #         config.imgH,\n", "    #         config.imgW,\n", "    #         device=image_tensors.device,\n", "    #         dtype=image_tensors.dtype\n", "    #     )\n", "    #     image_tensors = torch.cat([image_tensors, padding], dim=0)\n", "\n", "    print(image_tensors.shape)\n", "    image_np = image_tensors.cpu().numpy().astype(np.float32)\n", "\n", "    batch_size = image_np.shape[0]\n", "\n", "    # Run ONNX inference; the exported graph exposes only the image input,\n", "    # so no decoder text tensor is fed here (for either head).\n", "    outputs = session.run(None, {\"input\": image_np})\n", "    preds = outputs[0]  # [batch, seq_len, num_classes]\n", "\n", "    if config.Prediction == \"CTC\":\n", "        preds_size = np.array([preds.shape[1]] * batch_size, dtype=np.int32)\n", "        preds_index = preds.argmax(axis=2).flatten()\n", "        preds_str = converter.decode_greedy(preds_index, preds_size)\n", "        return preds_str\n", "\n", "    elif config.Prediction == \"Attn\":\n", "        preds = preds[:, :config.batch_max_length - 1, :]\n", "        preds_index = preds.argmax(axis=2)\n", "        preds_str = converter.decode(torch.from_numpy(preds_index),\n", "                                     torch.IntTensor([config.batch_max_length] * batch_size))\n", "\n", "        preds_prob = F.softmax(torch.from_numpy(preds), dim=2)\n", "        preds_max_prob, _ = preds_prob.max(dim=2)\n", "\n", "        final_result = []\n", "        for pred, pred_max_prob in zip(preds_str, preds_max_prob):\n", "            eos_pos = pred.find('[s]')\n", "            if eos_pos != -1:\n", "                pred = pred[:eos_pos]\n", "                pred_max_prob = pred_max_prob[:eos_pos]\n", "            final_result.append(pred)\n", "        return final_result\n", "\n", "    return None  # fallback for unsupported Prediction heads\n", "\n", "# =========================\n", "# Example usage\n", "# =========================\n", "# load_crnn_model is reused here only to build the label converter.\n", "model, converter = load_crnn_model(\"models/iter_70000.pth\", config)\n", "onnx_session = load_onnx_model(\"models/ResNetBiLSTMAttnv2Batch.onnx\")\n", "\n", "all_entries = os.listdir(\"crops\")\n", "files_only = [os.path.join(\"crops\", entry) for entry in all_entries if os.path.isfile(os.path.join(\"crops\", entry))]\n", "\n", "if len(files_only) >= 2:\n", "    batch = files_only[:2]  # a two-image batch from the first two crops\n", "    print(inference_crnn_batch_image_onnx(onnx_session, converter, batch, config))\n" ] } ], "metadata": { "kernelspec": { "display_name": "OCRTrainer", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.19" } }, "nbformat": 4, "nbformat_minor": 5 }