donghyun
committed on
Commit
ยท
8672bad
1
Parent(s):
1a7b7d2
Add OCR code, modules, and weights
Browse files- .gitignore +22 -0
- README.md +2 -0
- ai_modules/__init__.py +60 -0
- ai_modules/models/HRCenterNet.py +204 -0
- ai_modules/models/__init__.py +5 -0
- ai_modules/models/modules.py +111 -0
- ai_modules/models/resnet.py +186 -0
- ai_modules/nlp/__init__.py +26 -0
- ai_modules/nlp/mlm_predictor.py +118 -0
- ai_modules/nlp/punctuation_restorer.py +326 -0
- ai_modules/nlp/utils.py +87 -0
- ai_modules/nlp_engine.py +321 -0
- ai_modules/ocr_engine.py +767 -0
- ai_modules/preprocessor_unified.py +605 -0
- dong_ocr.py +349 -0
- requirements.txt +8 -0
- weights/best.pth +3 -0
- weights/best_5000.pt +3 -0
.gitignore
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.pyc
|
| 4 |
+
*.pyo
|
| 5 |
+
|
| 6 |
+
# Environment & Secrets
|
| 7 |
+
.env
|
| 8 |
+
*.json
|
| 9 |
+
weights/*.json
|
| 10 |
+
|
| 11 |
+
# Logs & Temp
|
| 12 |
+
*.log
|
| 13 |
+
*.tmp
|
| 14 |
+
*.temp
|
| 15 |
+
|
| 16 |
+
# Output files
|
| 17 |
+
*_bbox.*
|
| 18 |
+
*_ocr_result.json
|
| 19 |
+
|
| 20 |
+
# OS files
|
| 21 |
+
.DS_Store
|
| 22 |
+
Thumbs.db
|
README.md
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
|
ai_modules/__init__.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
"""
|
| 3 |
+
================================================================================
|
| 4 |
+
Epitext AI Unified Preprocessing Module
|
| 5 |
+
================================================================================
|
| 6 |
+
|
| 7 |
+
ํตํฉ ์ด๋ฏธ์ง ์ ์ฒ๋ฆฌ ํจํค์ง (Swin Gray + OCR ๋์ ์์ฑ)
|
| 8 |
+
|
| 9 |
+
ํ ๋ฒ์ ํจ์ ํธ์ถ๋ก ๋ ๊ฐ์ง ์ ์ฒ๋ฆฌ ์๋ฃ:
|
| 10 |
+
|
| 11 |
+
1๏ธโฃ Swin Gray: ๊ทธ๋ ์ด ๋น์ด์งํ (์ ๋ณด ์์ค ์ต์) โ JPG 3์ฑ๋
|
| 12 |
+
|
| 13 |
+
2๏ธโฃ OCR: ์ด์งํ (๋ช
ํํ ํ๋ฐฑ) โ PNG 1์ฑ๋
|
| 14 |
+
|
| 15 |
+
๋ฒ์ : 1.0.0
|
| 16 |
+
์ํ: โ
Production Ready
|
| 17 |
+
|
| 18 |
+
์ฃผ์ ํน์ง:
|
| 19 |
+
|
| 20 |
+
โ
ํจ์จ์ฑ: ์์ญ ๊ฒ์ถ 1ํ (๋ ๊ฐ์ง ๋ชจ๋ ์ฌ์ฉ)
|
| 21 |
+
|
| 22 |
+
โ
๋ฐฐ๊ฒฝ ๋ณด์ฅ: Swin (๋ฐ์) + OCR (ํ์์)
|
| 23 |
+
|
| 24 |
+
โ
ํ๋ณธ ์ง์: ์๋ ๊ฒ์ถ ์ต์
|
| 25 |
+
|
| 26 |
+
โ
์ค์ ๊ฐ๋ฅ: JSON ๊ธฐ๋ฐ ์ปค์คํฐ๋ง์ด์ง
|
| 27 |
+
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
from .preprocessor_unified import (
|
| 31 |
+
UnifiedImagePreprocessor,
|
| 32 |
+
get_preprocessor,
|
| 33 |
+
preprocess_image_unified
|
| 34 |
+
)
|
| 35 |
+
from .ocr_engine import (
|
| 36 |
+
get_ocr_engine,
|
| 37 |
+
OCREngine,
|
| 38 |
+
ocr_and_detect
|
| 39 |
+
)
|
| 40 |
+
from .nlp_engine import (
|
| 41 |
+
get_nlp_engine,
|
| 42 |
+
NLPEngine,
|
| 43 |
+
process_text_with_nlp
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
__version__ = "1.0.0"
|
| 47 |
+
__author__ = "Epitext Team"
|
| 48 |
+
|
| 49 |
+
__all__ = [
|
| 50 |
+
"UnifiedImagePreprocessor",
|
| 51 |
+
"get_preprocessor",
|
| 52 |
+
"preprocess_image_unified",
|
| 53 |
+
"get_ocr_engine",
|
| 54 |
+
"OCREngine",
|
| 55 |
+
"ocr_and_detect",
|
| 56 |
+
"get_nlp_engine",
|
| 57 |
+
"NLPEngine",
|
| 58 |
+
"process_text_with_nlp"
|
| 59 |
+
]
|
| 60 |
+
|
ai_modules/models/HRCenterNet.py
ADDED
|
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
from ai_modules.models.modules import BasicBlock, Bottleneck
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class StageModule(nn.Module):
    """One HRNet stage: parallel multi-resolution branches followed by fusion.

    Branch ``i`` runs at 1/2**i of the stage's base resolution with
    ``c * 2**i`` channels. After the per-branch residual blocks, every
    output branch sums contributions from all input branches, with
    up-/down-sampling adapters where resolutions differ.
    """

    def __init__(self, stage, output_branches, c, bn_momentum):
        """
        Args:
            stage: number of parallel input branches.
            output_branches: number of fused output branches (<= stage).
            c: channel width of the highest-resolution branch.
            bn_momentum: momentum for the branch BasicBlock BatchNorms.
        """
        super(StageModule, self).__init__()
        self.stage = stage
        self.output_branches = output_branches

        # Per-branch feature extractors: 4 BasicBlocks each; width doubles
        # with every lower-resolution branch.
        self.branches = nn.ModuleList()
        for i in range(self.stage):
            w = c * (2 ** i)
            branch = nn.Sequential(
                BasicBlock(w, w, bn_momentum=bn_momentum),
                BasicBlock(w, w, bn_momentum=bn_momentum),
                BasicBlock(w, w, bn_momentum=bn_momentum),
                BasicBlock(w, w, bn_momentum=bn_momentum),
            )
            self.branches.append(branch)

        self.fuse_layers = nn.ModuleList()
        # for each output_branches (i.e. each branch in all cases but the very last one)
        for i in range(self.output_branches):
            self.fuse_layers.append(nn.ModuleList())
            for j in range(self.stage):  # for each branch
                if i == j:
                    # Identity path between a branch and itself.
                    self.fuse_layers[-1].append(nn.Sequential())  # Used in place of "None" because it is callable
                elif i < j:
                    # Lower-res -> higher-res: 1x1 conv shrinks channels,
                    # then nearest-neighbour upsampling by 2**(j - i).
                    self.fuse_layers[-1].append(nn.Sequential(
                        nn.Conv2d(c * (2 ** j), c * (2 ** i), kernel_size=(1, 1), stride=(1, 1), bias=False),
                        nn.BatchNorm2d(c * (2 ** i), eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                        nn.Upsample(scale_factor=(2.0 ** (j - i)), mode='nearest'),
                    ))
                elif i > j:
                    # Higher-res -> lower-res: chain of stride-2 3x3 convs;
                    # only the final conv changes the channel count.
                    ops = []
                    for k in range(i - j - 1):
                        ops.append(nn.Sequential(
                            nn.Conv2d(c * (2 ** j), c * (2 ** j), kernel_size=(3, 3), stride=(2, 2), padding=(1, 1),
                                      bias=False),
                            nn.BatchNorm2d(c * (2 ** j), eps=1e-05, momentum=0.1, affine=True,
                                           track_running_stats=True),
                            nn.ReLU(inplace=True),
                        ))
                    ops.append(nn.Sequential(
                        nn.Conv2d(c * (2 ** j), c * (2 ** i), kernel_size=(3, 3), stride=(2, 2), padding=(1, 1),
                                  bias=False),
                        nn.BatchNorm2d(c * (2 ** i), eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                    ))
                    self.fuse_layers[-1].append(nn.Sequential(*ops))

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Fuse a list of per-branch tensors into ``output_branches`` tensors.

        Args:
            x: list of tensors, one per input branch.

        Returns:
            List of fused tensors, one per output branch.
        """
        assert len(self.branches) == len(x)

        x = [branch(b) for branch, b in zip(self.branches, x)]

        # Output i is sum over all input branches j of fuse_layers[i][j](x[j]).
        x_fused = []
        for i in range(len(self.fuse_layers)):
            for j in range(0, len(self.branches)):
                if j == 0:
                    # First term starts the accumulator for output i.
                    x_fused.append(self.fuse_layers[i][0](x[0]))
                else:
                    x_fused[i] = x_fused[i] + self.fuse_layers[i][j](x[j])

        for i in range(len(x_fused)):
            x_fused[i] = self.relu(x_fused[i])

        return x_fused
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class _HRCenterNet(nn.Module):
    """HRNet backbone with a CenterNet-style sigmoid head.

    Maintains a high-resolution branch throughout and repeatedly exchanges
    information across parallel multi-resolution branches; the head outputs
    ``nof_joints`` sigmoid-activated channels at 1/4 of input resolution.

    NOTE(review): module names and nesting (e.g. the "Double Sequential"
    transitions) intentionally mirror the official pretrained weight layout;
    do not rename attributes or checkpoint loading breaks.
    """

    def __init__(self, c=48, nof_joints=17, bn_momentum=0.1):
        """
        Args:
            c: base channel width of the highest-resolution branch.
            nof_joints: number of output heatmap channels.
            bn_momentum: BatchNorm momentum for stem/transition layers.
        """
        super(_HRCenterNet, self).__init__()

        # Input (stem net): two stride-2 3x3 convs -> features at 1/4 input resolution.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        self.bn1 = nn.BatchNorm2d(64, eps=1e-05, momentum=bn_momentum, affine=True, track_running_stats=True)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        self.bn2 = nn.BatchNorm2d(64, eps=1e-05, momentum=bn_momentum, affine=True, track_running_stats=True)
        self.relu = nn.ReLU(inplace=True)

        # Stage 1 (layer1) - First group of bottleneck (resnet) modules
        downsample = nn.Sequential(
            nn.Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False),
            nn.BatchNorm2d(256, eps=1e-05, momentum=bn_momentum, affine=True, track_running_stats=True),
        )
        self.layer1 = nn.Sequential(
            Bottleneck(64, 64, downsample=downsample),
            Bottleneck(256, 64),
            Bottleneck(256, 64),
            Bottleneck(256, 64),
        )

        # Fusion layer 1 (transition1) - Creation of the first two branches (one full and one half resolution)
        self.transition1 = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(256, c, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
                nn.BatchNorm2d(c, eps=1e-05, momentum=bn_momentum, affine=True, track_running_stats=True),
                nn.ReLU(inplace=True),
            ),
            nn.Sequential(nn.Sequential(  # Double Sequential to fit with official pretrained weights
                nn.Conv2d(256, c * (2 ** 1), kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False),
                nn.BatchNorm2d(c * (2 ** 1), eps=1e-05, momentum=bn_momentum, affine=True, track_running_stats=True),
                nn.ReLU(inplace=True),
            )),
        ])

        # Stage 2 (stage2) - Second module with 1 group of bottleneck (resnet) modules. This has 2 branches
        self.stage2 = nn.Sequential(
            StageModule(stage=2, output_branches=2, c=c, bn_momentum=bn_momentum),
        )

        # Fusion layer 2 (transition2) - Creation of the third branch (1/4 resolution)
        self.transition2 = nn.ModuleList([
            nn.Sequential(),  # None, - Used in place of "None" because it is callable
            nn.Sequential(),  # None, - Used in place of "None" because it is callable
            nn.Sequential(nn.Sequential(  # Double Sequential to fit with official pretrained weights
                nn.Conv2d(c * (2 ** 1), c * (2 ** 2), kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False),
                nn.BatchNorm2d(c * (2 ** 2), eps=1e-05, momentum=bn_momentum, affine=True, track_running_stats=True),
                nn.ReLU(inplace=True),
            )),  # ToDo Why the new branch derives from the "upper" branch only?
        ])

        # Stage 3 (stage3) - Third module with 4 groups of bottleneck (resnet) modules. This has 3 branches
        self.stage3 = nn.Sequential(
            StageModule(stage=3, output_branches=3, c=c, bn_momentum=bn_momentum),
            StageModule(stage=3, output_branches=3, c=c, bn_momentum=bn_momentum),
            StageModule(stage=3, output_branches=3, c=c, bn_momentum=bn_momentum),
            StageModule(stage=3, output_branches=3, c=c, bn_momentum=bn_momentum),
        )

        # Fusion layer 3 (transition3) - Creation of the fourth branch (1/8 resolution)
        self.transition3 = nn.ModuleList([
            nn.Sequential(),  # None, - Used in place of "None" because it is callable
            nn.Sequential(),  # None, - Used in place of "None" because it is callable
            nn.Sequential(),  # None, - Used in place of "None" because it is callable
            nn.Sequential(nn.Sequential(  # Double Sequential to fit with official pretrained weights
                nn.Conv2d(c * (2 ** 2), c * (2 ** 3), kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False),
                nn.BatchNorm2d(c * (2 ** 3), eps=1e-05, momentum=bn_momentum, affine=True, track_running_stats=True),
                nn.ReLU(inplace=True),
            )),  # ToDo Why the new branch derives from the "upper" branch only?
        ])

        # Stage 4 (stage4) - Fourth module with 3 groups of bottleneck (resnet) modules. This has 4 branches
        # Note: the last StageModule keeps only the highest-resolution output
        # (output_branches=1), which feeds the head.
        self.stage4 = nn.Sequential(
            StageModule(stage=4, output_branches=4, c=c, bn_momentum=bn_momentum),
            StageModule(stage=4, output_branches=4, c=c, bn_momentum=bn_momentum),
            StageModule(stage=4, output_branches=1, c=c, bn_momentum=bn_momentum),
        )

        # Final layer (final_layer): 1x1 conv head producing nof_joints
        # sigmoid-activated channels.
        self.final_layer = nn.Sequential(
            nn.Conv2d(c, 32, kernel_size=(1, 1), stride=(1, 1)),
            nn.BatchNorm2d(32, eps=1e-05, momentum=bn_momentum, affine=True, track_running_stats=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, nof_joints, kernel_size=(1, 1), stride=(1, 1)),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Run the full backbone + head; returns the sigmoid output map."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)

        x = self.layer1(x)
        x = [trans(x) for trans in self.transition1]  # Since now, x is a list (# == nof branches)

        x = self.stage2(x)
        # x = [trans(x[-1]) for trans in self.transition2]    # New branch derives from the "upper" branch only
        x = [
            self.transition2[0](x[0]),
            self.transition2[1](x[1]),
            self.transition2[2](x[-1])
        ]  # New branch derives from the "upper" branch only

        x = self.stage3(x)
        # x = [trans(x) for trans in self.transition3]    # New branch derives from the "upper" branch only
        x = [
            self.transition3[0](x[0]),
            self.transition3[1](x[1]),
            self.transition3[2](x[2]),
            self.transition3[3](x[-1])
        ]  # New branch derives from the "upper" branch only

        x = self.stage4(x)

        # stage4's last module returns a single-branch list; take the
        # highest-resolution map for the head.
        x = self.final_layer(x[0])

        return x
|
| 196 |
+
|
| 197 |
+
def HRCenterNet(args):
    """Build an HRCenterNet detector and optionally load a checkpoint.

    Args:
        args: namespace-like object; ``args.log_dir`` is either None or the
            path of a saved ``state_dict`` to load into the model.

    Returns:
        An ``_HRCenterNet`` instance (c=32, nof_joints=5, bn_momentum=0.1).
    """
    model = _HRCenterNet(32, 5, 0.1)

    if args.log_dir is not None:
        # map_location keeps checkpoint loading working on CPU-only hosts
        # even when the weights were saved from a CUDA device; the model can
        # still be moved to GPU by the caller afterwards.
        model.load_state_dict(torch.load(args.log_dir, map_location="cpu"))

    return model
|
ai_modules/models/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
"""
|
| 3 |
+
OCR ๋ชจ๋ธ ๋ชจ๋ ํจํค์ง
|
| 4 |
+
"""
|
| 5 |
+
|
ai_modules/models/modules.py
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class Bottleneck(nn.Module):
    """ResNet bottleneck residual block: 1x1 reduce -> 3x3 -> 1x1 expand.

    Output channel count is ``planes * expansion``; ``downsample`` (if given)
    adapts the shortcut to that shape.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, bn_momentum=0.1):
        super(Bottleneck, self).__init__()
        # Channel layout: inplanes -> planes -> planes -> planes * expansion.
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes, momentum=bn_momentum)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes, momentum=bn_momentum)
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion, momentum=bn_momentum)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Shortcut branch: identity, or the downsample adapter when shapes differ.
        shortcut = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        out += shortcut
        return self.relu(out)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# class Bottleneck_Tranpose(nn.Module):
|
| 44 |
+
# expansion = 4
|
| 45 |
+
|
| 46 |
+
# def __init__(self, inplanes, planes, stride=1, downsample=None, bn_momentum=0.1):
|
| 47 |
+
# super(Bottleneck, self).__init__()
|
| 48 |
+
# nn.ConvTranspose2d(c, 64, (3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1)),
|
| 49 |
+
|
| 50 |
+
# self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
|
| 51 |
+
# self.bn1 = nn.BatchNorm2d(planes, momentum=bn_momentum)
|
| 52 |
+
# self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
|
| 53 |
+
# self.bn2 = nn.BatchNorm2d(planes, momentum=bn_momentum)
|
| 54 |
+
# self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
|
| 55 |
+
# self.bn3 = nn.BatchNorm2d(planes * self.expansion, momentum=bn_momentum)
|
| 56 |
+
# self.relu = nn.ReLU(inplace=True)
|
| 57 |
+
# self.downsample = downsample
|
| 58 |
+
# self.stride = stride
|
| 59 |
+
|
| 60 |
+
# def forward(self, x):
|
| 61 |
+
# residual = x
|
| 62 |
+
|
| 63 |
+
# out = self.conv1(x)
|
| 64 |
+
# out = self.bn1(out)
|
| 65 |
+
# out = self.relu(out)
|
| 66 |
+
|
| 67 |
+
# out = self.conv2(out)
|
| 68 |
+
# out = self.bn2(out)
|
| 69 |
+
# out = self.relu(out)
|
| 70 |
+
|
| 71 |
+
# out = self.conv3(out)
|
| 72 |
+
# out = self.bn3(out)
|
| 73 |
+
|
| 74 |
+
# if self.downsample is not None:
|
| 75 |
+
# residual = self.downsample(x)
|
| 76 |
+
|
| 77 |
+
# out += residual
|
| 78 |
+
# out = self.relu(out)
|
| 79 |
+
|
| 80 |
+
# return out
|
| 81 |
+
|
| 82 |
+
class BasicBlock(nn.Module):
    """ResNet basic residual block: two 3x3 convolutions with a shortcut.

    Bug fix: ``conv2`` previously took ``inplanes`` input channels, but its
    input is the output of ``conv1``, which already has ``planes`` channels —
    the block crashed whenever ``inplanes != planes``. All existing callers
    use ``inplanes == planes``, so the fixed weight shapes are identical and
    checkpoints remain loadable.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, bn_momentum=0.1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes, momentum=bn_momentum)
        self.relu = nn.ReLU(inplace=True)
        # conv2 consumes conv1's output, so its in_channels must be `planes`
        # (was `inplanes`).
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes, momentum=bn_momentum)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        # Adapt the shortcut when stride/channel shapes differ.
        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out
|
ai_modules/models/resnet.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import PIL
|
| 3 |
+
|
| 4 |
+
from torch import nn
|
| 5 |
+
from torchvision import transforms
|
| 6 |
+
|
| 7 |
+
class BasicBlock(nn.Module):
    """Torchvision-style ResNet basic residual block (two 3x3 convolutions)."""

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Shortcut: identity, or the downsample adapter when shapes differ.
        shortcut = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        out += shortcut
        return self.relu(out)
|
| 45 |
+
|
| 46 |
+
class ResNet(nn.Module):
    """Generic ResNet classifier over single-channel (grayscale) images.

    Note: ``conv1`` takes 1 input channel (not the usual 3) because this model
    classifies grayscale character crops.
    NOTE(review): ``zero_init_residual`` is accepted but never used here.
    """

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        """
        Args:
            block: residual block class (e.g. BasicBlock); must expose ``expansion``.
            layers: number of blocks in each of the four stages.
            num_classes: size of the final fully-connected output.
            zero_init_residual: unused (kept for interface compatibility).
            groups / width_per_group: grouped-conv parameters forwarded to blocks.
            replace_stride_with_dilation: per-stage flags to trade stride for dilation.
            norm_layer: normalization layer factory; defaults to BatchNorm2d.
        """
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        # 1 input channel: grayscale character images.
        self.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # Standard Kaiming init for convs; BN/GN weights to 1, biases to 0.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Build one stage of ``blocks`` residual blocks.

        Mutates ``self.inplanes`` and ``self.dilation`` as it goes, so call
        order across stages matters.
        """
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        # First block needs a projection shortcut when shape changes.
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``."""
        return self._forward_impl(x)
|
| 132 |
+
|
| 133 |
+
def conv1x1(in_planes, out_planes, stride=1):
    """Return a bias-free 1x1 convolution (pure channel projection)."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
|
| 136 |
+
|
| 137 |
+
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Return a bias-free 3x3 convolution whose padding tracks the dilation
    so spatial size is preserved at stride 1."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )
|
| 141 |
+
|
| 142 |
+
class ResnetCustom(torch.nn.Module):
    """Character recognizer: ResNet-18-style classifier over 64x64 grayscale crops.

    Loads a checkpoint containing both the model weights ('model') and the
    id->char vocabulary ('vocab'/'id2char'); predictions below the confidence
    threshold are mapped to the placeholder character stored at id -1.
    """

    def __init__(self, weight_fn):
        # weight_fn: path to a torch checkpoint with keys 'model' and
        # 'vocab' -> 'id2char'.
        super().__init__()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        weight = torch.load(weight_fn, map_location=self.device)
        self.id2charDict = weight['vocab']['id2char']
        num_classes = len(self.id2charDict)
        # Sentinel entry for low-confidence predictions (see forward()).
        self.id2charDict[-1] = "โ " # unrecognized token
        # Every crop is converted to a 1-channel 64x64 tensor before inference.
        self.transform = transforms.Compose([transforms.Grayscale(),
                                             transforms.Resize((64,64)),
                                             transforms.ToTensor()])

        self.net = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes)
        self.net.load_state_dict(weight['model'])
        self.net = self.net.to(self.device)
        # Inference-only: freeze BatchNorm statistics / disable dropout.
        self.net.eval()
        #self.net(torch.rand((64,1,64,64)))
        print(f'{weight_fn} loaded!')

    def forward(self, images: list, bs=256, conf_thres=0.5):
        '''
        Run batched character recognition.

        input
            images: list of PIL images
            bs: mini-batch size used to chunk the inference
            conf_thres: softmax confidence below which a crop is mapped to
                the "unrecognized" placeholder (id -1)
        return
            chars: list of recognized chars (one per input image)
        '''
        chars = []
        for i in range(0, len(images), bs):
            inp = []
            for image in images[i: i+bs]:
                inp.append(self.transform(image))
            inp = torch.stack(inp, dim=0).to(self.device)
            out = self.net(inp)
            out = torch.nn.functional.softmax(out, dim=1)
            conf, indice = torch.max(out, dim=1)
            # Low-confidence predictions become the sentinel id -1.
            indice[conf<conf_thres] = -1
            chars += [self.id2charDict[x] for x in indice.tolist()]

        return chars
|
| 181 |
+
|
| 182 |
+
if __name__ == "__main__":
    # Smoke test: load the recognizer checkpoint and classify two sample crops.
    net = ResnetCustom(weight_fn="best_5000.pt")
    inp = [PIL.Image.open('0.jpg'), PIL.Image.open('1.png')]
    print(net(inp))
|
| 186 |
+
|
ai_modules/nlp/__init__.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Korean Historical Text Processor NLP Module
|
| 3 |
+
ํ๊ตญ์ด ๊ณ ์ ํ
์คํธ์ ๊ตฌ๋์ ๋ณต์ ๋ฐ MLM ์์ธก์ ์ํ ๋ชจ๋์
๋๋ค.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
__version__ = "1.0.0"
|
| 7 |
+
__author__ = "EPITEXT"
|
| 8 |
+
|
| 9 |
+
from .punctuation_restorer import PunctuationRestorer
|
| 10 |
+
from .mlm_predictor import MLMPredictor
|
| 11 |
+
from .utils import (
|
| 12 |
+
remove_punctuation,
|
| 13 |
+
extract_mask_info,
|
| 14 |
+
replace_mask_with_symbol,
|
| 15 |
+
normalize_mask_tokens,
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
__all__ = [
|
| 19 |
+
"PunctuationRestorer",
|
| 20 |
+
"MLMPredictor",
|
| 21 |
+
"remove_punctuation",
|
| 22 |
+
"extract_mask_info",
|
| 23 |
+
"replace_mask_with_symbol",
|
| 24 |
+
"normalize_mask_tokens",
|
| 25 |
+
]
|
| 26 |
+
|
ai_modules/nlp/mlm_predictor.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
MLM(Masked Language Model) ์์ธก ๋ชจ๋
|
| 3 |
+
BERT ๊ธฐ๋ฐ MLM์ ์ฌ์ฉํ์ฌ ๋ง์คํน๋ ํ ํฐ์ ์์ธกํฉ๋๋ค.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from typing import List, Dict
|
| 8 |
+
from transformers import AutoTokenizer, AutoModelForMaskedLM
|
| 9 |
+
from .utils import normalize_mask_tokens
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class MLMPredictor:
|
| 13 |
+
"""MLM ์์ธก์ ๋ด๋นํ๋ ํด๋์ค"""
|
| 14 |
+
|
| 15 |
+
def __init__(self, config: Dict, device: str = "cpu"):
|
| 16 |
+
"""
|
| 17 |
+
MLM ์์ธก๊ธฐ๋ฅผ ์ด๊ธฐํํฉ๋๋ค.
|
| 18 |
+
|
| 19 |
+
Args:
|
| 20 |
+
config: ์ค์ ๋์
๋๋ฆฌ (nlp_config.json์์ ๋ก๋)
|
| 21 |
+
device: ์ฐ์ฐ ๋๋ฐ์ด์ค ('cpu' ๋๋ 'cuda')
|
| 22 |
+
"""
|
| 23 |
+
mlm_cfg = config['mlm_model']
|
| 24 |
+
self.model_name = mlm_cfg['model_name']
|
| 25 |
+
self.top_k = mlm_cfg['top_k']
|
| 26 |
+
self.max_length = mlm_cfg['max_length']
|
| 27 |
+
self.device = device
|
| 28 |
+
self.tokenizer = None
|
| 29 |
+
self.model = None
|
| 30 |
+
|
| 31 |
+
    def load_model(self) -> None:
        """Load the tokenizer and MLM model into memory and move them to the configured device."""
        print(f"[MLM] ๋ชจ๋ธ ๋ก๋ ์ค: {self.model_name}")

        # use_fast=False: NOTE(review) — presumably the slow tokenizer is
        # needed for consistent [MASK] handling with this model; confirm.
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_name,
            use_fast=False
        )
        self.model = AutoModelForMaskedLM.from_pretrained(self.model_name)
        self.model.to(self.device)
        # Inference-only: disable dropout etc.
        self.model.eval()

        print(f"[MLM] โ MLM ๋ชจ๋ธ ๋ก๋ ์๋ฃ")
|
| 44 |
+
|
| 45 |
+
def predict_masks(
|
| 46 |
+
self,
|
| 47 |
+
text: str
|
| 48 |
+
) -> List[List[Dict[str, any]]]:
|
| 49 |
+
"""
|
| 50 |
+
ํ
์คํธ ๋ด์ [MASK] ํ ํฐ์ ์์ธกํฉ๋๋ค.
|
| 51 |
+
|
| 52 |
+
Args:
|
| 53 |
+
text: ๋ง์คํฌ๊ฐ ํฌํจ๋ ํ
์คํธ
|
| 54 |
+
|
| 55 |
+
Returns:
|
| 56 |
+
๊ฐ ๋ง์คํฌ ์์น๋ณ top-k ์์ธก ๊ฒฐ๊ณผ ๋ฆฌ์คํธ
|
| 57 |
+
"""
|
| 58 |
+
# [MASK1], [MASK2] -> [MASK] ์ ๊ทํ
|
| 59 |
+
text_normalized = normalize_mask_tokens(text)
|
| 60 |
+
|
| 61 |
+
print(f"[MLM] ์
๋ ฅ ํ
์คํธ ์ํ: {text_normalized[:100]}...")
|
| 62 |
+
print(f"[MLM] [MASK] ํ ํฐ ๊ฐ์: {text_normalized.count('[MASK]')}")
|
| 63 |
+
|
| 64 |
+
# ํ ํฌ๋์ด์ฆ
|
| 65 |
+
inputs = self.tokenizer(
|
| 66 |
+
text_normalized,
|
| 67 |
+
return_tensors="pt",
|
| 68 |
+
truncation=True,
|
| 69 |
+
max_length=self.max_length
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
# ๋๋ฐ์ด์ค๋ก ์ด๋
|
| 73 |
+
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
| 74 |
+
|
| 75 |
+
# [MASK] ์์น ์ฐพ๊ธฐ
|
| 76 |
+
mask_indices = torch.where(
|
| 77 |
+
inputs["input_ids"] == self.tokenizer.mask_token_id
|
| 78 |
+
)[1]
|
| 79 |
+
|
| 80 |
+
print(f"[MLM] ํ ํฌ๋์ด์ ๊ฐ ์ฐพ์ [MASK] ์์น ๊ฐ์: {len(mask_indices)}")
|
| 81 |
+
|
| 82 |
+
if len(mask_indices) == 0:
|
| 83 |
+
print("[MLM] โ ๏ธ ๊ฒฝ๊ณ : [MASK] ํ ํฐ์ ์ฐพ์ ์ ์์ต๋๋ค!")
|
| 84 |
+
sample_tokens = self.tokenizer.convert_ids_to_tokens(
|
| 85 |
+
inputs['input_ids'][0][:50]
|
| 86 |
+
)
|
| 87 |
+
print(f"[MLM] ํ ํฐํ๋ ์
๋ ฅ ์ํ: {sample_tokens}")
|
| 88 |
+
return []
|
| 89 |
+
|
| 90 |
+
# ์์ธก ์ํ
|
| 91 |
+
with torch.no_grad():
|
| 92 |
+
outputs = self.model(**inputs)
|
| 93 |
+
logits = outputs.logits
|
| 94 |
+
|
| 95 |
+
# ๊ฐ ๋ง์คํฌ ์์น๋ณ๋ก top-k ์์ธก
|
| 96 |
+
all_predictions = []
|
| 97 |
+
for mask_idx in mask_indices:
|
| 98 |
+
mask_logits = logits[0, mask_idx, :]
|
| 99 |
+
|
| 100 |
+
# ์ ์ฒด ์ดํ์ ๋ํด softmax ๊ณ์ฐ ํ top-k ์ ํ
|
| 101 |
+
all_probs = torch.nn.functional.softmax(mask_logits, dim=-1)
|
| 102 |
+
top_k_probs, top_k_indices = torch.topk(all_probs, self.top_k)
|
| 103 |
+
|
| 104 |
+
top_k_tokens = self.tokenizer.convert_ids_to_tokens(
|
| 105 |
+
top_k_indices.tolist()
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
predictions = [
|
| 109 |
+
{
|
| 110 |
+
"token": token,
|
| 111 |
+
"probability": float(prob)
|
| 112 |
+
}
|
| 113 |
+
for token, prob in zip(top_k_tokens, top_k_probs.tolist())
|
| 114 |
+
]
|
| 115 |
+
all_predictions.append(predictions)
|
| 116 |
+
|
| 117 |
+
return all_predictions
|
| 118 |
+
|
ai_modules/nlp/punctuation_restorer.py
ADDED
|
@@ -0,0 +1,326 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
๊ตฌ๋์ ๋ณต์ ๋ชจ๋
|
| 3 |
+
Hugging Face ๋ชจ๋ธ์ ์ฌ์ฉํ์ฌ ํ๊ตญ์ด ๊ณ ์ ํ
์คํธ์ ๊ตฌ๋์ ์ ๋ณต์ํฉ๋๋ค.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import json
|
| 7 |
+
import torch
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Dict, List, Tuple
|
| 10 |
+
from collections import Counter
|
| 11 |
+
from huggingface_hub import snapshot_download
|
| 12 |
+
from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class PunctuationRestorer:
    """Restores punctuation in Korean classical text with a token-classification model.

    The model (downloaded from the Hugging Face Hub) labels each character with
    the punctuation that should follow it; labels are predicted over sliding
    windows and merged by majority vote.
    """

    def __init__(self, config: Dict, cache_dir: str, device: str = "cpu"):
        """Initialize the restorer.

        Args:
            config: Configuration dictionary (loaded from nlp_config.json).
                Must contain a 'punc_model' section with 'model_tag',
                'max_length', 'window_size' and 'overlap' keys.
            cache_dir: Base directory for the model cache; the model lives
                in its 'punc' subdirectory.
            device: Compute device ('cpu' or 'cuda').
        """
        punc_cfg = config['punc_model']
        self.model_tag = punc_cfg['model_tag']
        self.max_length = punc_cfg['max_length']
        self.window_size = punc_cfg['window_size']
        self.overlap = punc_cfg['overlap']

        self.cache_dir = Path(cache_dir) / "punc"
        self.device = device
        # Populated by load_model(): {"model", "tokenizer", "pipe", "label2id"}.
        self.model_info = None

    def download_model(self) -> None:
        """Download the model snapshot from the Hugging Face Hub if not cached."""
        self.cache_dir.parent.mkdir(parents=True, exist_ok=True)

        # Only download when the cache directory is missing or empty.
        if not self.cache_dir.exists() or not any(self.cache_dir.iterdir()):
            print(f"[PUNC] ๋ชจ๋ธ ๋ค์ด๋ก๋ ์ค: {self.model_tag}")
            snapshot_download(
                repo_id=self.model_tag,
                repo_type="model",
                local_dir=str(self.cache_dir),
                # NOTE(review): deprecated in recent huggingface_hub
                # versions (copies are the default there) — confirm the
                # pinned hub version before removing.
                local_dir_use_symlinks=False,
            )
        else:
            print(f"[PUNC] ์บ์๋ ๋ชจ๋ธ ์ฌ์ฉ: {self.cache_dir}")

    def load_model(self) -> None:
        """Load model, tokenizer, NER pipeline and label map into memory.

        Raises:
            FileNotFoundError: If no model weights or no label2id.json can
                be found under the cache directory.
        """
        # fp16 on GPU, fp32 on CPU.
        torch_dtype = torch.float16 if "cuda" in self.device else torch.float32

        # Locate the weight file (prefer safetensors, fall back to .bin).
        fnames = sorted(self.cache_dir.rglob("*.safetensors"))
        if len(fnames) == 0:
            # No safetensors — try the legacy PyTorch format.
            fnames = sorted(self.cache_dir.rglob("*.bin"))

        if len(fnames) == 0:
            raise FileNotFoundError(f"๋ชจ๋ธ ํ์ผ์ ์ฐพ์ ์ ์์ต๋๋ค: {self.cache_dir}")

        # The directory holding the weights is treated as the HF model root.
        hface_path = fnames[0].parent

        # Load tokenizer and token-classification model.
        tokenizer = AutoTokenizer.from_pretrained(
            str(hface_path),
            model_max_length=self.max_length
        )
        model = AutoModelForTokenClassification.from_pretrained(
            str(hface_path),
            device_map=self.device if "cuda" in self.device else None,
            torch_dtype=torch_dtype
        )
        if "cuda" not in self.device:
            # device_map was None on CPU, so move explicitly.
            model = model.to(self.device)
        model.eval()

        # Build the NER pipeline (device=0 -> first GPU, -1 -> CPU).
        ner_pipeline = pipeline(
            task="ner",
            model=model,
            tokenizer=tokenizer,
            device=0 if "cuda" in self.device else -1
        )

        # Load the label -> punctuation-id mapping shipped with the model;
        # check the model dir first, then its parent.
        label2id_path = hface_path / "label2id.json"
        if not label2id_path.is_file():
            label2id_path = hface_path.parent / "label2id.json"
        if not label2id_path.is_file():
            raise FileNotFoundError(f"label2id.json์ ์ฐพ์ ์ ์์ต๋๋ค: {hface_path}")

        label2id = json.loads(label2id_path.read_text(encoding="utf-8"))

        self.model_info = {
            "model": model,
            "tokenizer": tokenizer,
            "pipe": ner_pipeline,
            "label2id": label2id
        }

        print(f"[PUNC] โ ๊ตฌ๋์ ๋ณต์ ๋ชจ๋ธ ๋ก๋ ์๋ฃ")

    def restore_punctuation(
        self,
        text: str,
        add_space: bool = True,
        reduce: bool = True,
    ) -> str:
        """Restore punctuation with a sliding-window pass over *text*.

        Args:
            text: Input text (punctuation-free).
            add_space: Insert a space after selected punctuation marks.
            reduce: Collapse the model's punctuation inventory to a small set.

        Returns:
            The text with punctuation inserted after the labeled characters.

        NOTE(review): assumes load_model() was called first — model_info is
        dereferenced without a guard; confirm callers always load first.
        """
        if not text.strip():
            return ""

        # Build the label -> punctuation-string mapping.
        label2punc = self._build_label2punc(add_space, reduce)

        # Predict one label per character via overlapping windows.
        labels = self._predict_labels_sliding(text, self.window_size, self.overlap)

        # Defensive length alignment: pad with "O" or truncate so that
        # labels and characters zip one-to-one.
        if len(labels) < len(text):
            labels += ["O"] * (len(text) - len(labels))
        elif len(labels) > len(text):
            labels = labels[:len(text)]

        # Insert the mapped punctuation after each character.
        result = ""
        for ch, label in zip(text, labels):
            result += ch
            punc = label2punc.get(label, "")
            result += punc

        return result.strip()

    def _predict_labels_sliding(
        self,
        text: str,
        window_size: int,
        overlap: int
    ) -> List[str]:
        """Predict a label for every character using overlapping windows.

        Each character may be covered by several windows; the final label is
        decided by majority vote over the non-"O" candidates.

        Args:
            text: Input text.
            window_size: Characters per window.
            overlap: Characters shared by consecutive windows.

        Returns:
            One label per character ("O" when no window voted).
        """
        n = len(text)
        if n == 0:
            return []

        # Candidate labels collected per character position.
        labels_per_pos = [[] for _ in range(n)]
        # Guard against overlap >= window_size producing a zero/negative stride.
        stride = max(1, window_size - overlap)
        start = 0

        while start < n:
            end = min(start + window_size, n)
            sub_text = text[start:end]

            try:
                # Run NER prediction on this window.
                sub_preds = self.model_info["pipe"](sub_text)
                _, sub_labels = self._align_predictions(sub_text, sub_preds)
            except Exception as e:
                # On failure, treat the whole window as unlabeled ("O")
                # rather than aborting the document.
                print(f"[PUNC] ์์ธก ์ค๋ฅ (start={start}): {e}")
                sub_labels = ["O"] * len(sub_text)

            # Record each non-"O" label at its global position.
            for i, label in enumerate(sub_labels):
                gidx = start + i
                if gidx >= n:
                    break
                if label != "O":
                    labels_per_pos[gidx].append(label)

            if end == n:
                break
            start += stride

        # Majority vote per position; ties resolved by Counter insertion order.
        final_labels = []
        for cand_list in labels_per_pos:
            if not cand_list:
                final_labels.append("O")
            else:
                c = Counter(cand_list)
                label, _ = c.most_common(1)[0]
                final_labels.append(label)

        return final_labels

    @staticmethod
    def _align_predictions(text: str, predictions: List[dict]) -> Tuple[List[str], List[str]]:
        """Project pipeline predictions onto per-character labels.

        Args:
            text: The window text the predictions were computed on.
            predictions: NER pipeline output; each entry carries character
                offsets 'start'/'end' and an 'entity' label.

        Returns:
            (characters, labels) — one label per character; "O" where no
            prediction ends. The label is attached to the entity's LAST
            character (end - 1), i.e. the character the punctuation follows.
        """
        words = list(text)
        labels = ["O" for _ in range(len(words))]

        for pred in predictions:
            idx = pred["end"] - 1
            if 0 <= idx < len(labels):
                labels[idx] = pred["entity"]

        return words, labels

    def _build_label2punc(self, add_space: bool, reduce: bool) -> Dict[str, str]:
        """Build the mapping from model labels to punctuation strings.

        Args:
            add_space: Append a space after selected punctuation marks.
            reduce: Collapse punctuation to a simplified set.

        Returns:
            Mapping label -> punctuation string ("" for "O").
        """
        label2id = self.model_info["label2id"]
        # label2id maps punctuation -> id; the model emits "B-<id>" entity
        # tags, so invert it into "B-<id>" -> punctuation.
        # NOTE(review): assumes the model's tag scheme is exactly "B-<id>"
        # — confirm against the shipped label2id.json.
        label2punc = {f"B-{v}": k for k, v in label2id.items()}
        label2punc["O"] = ""

        # Optionally collapse each punctuation string to one canonical mark.
        if reduce:
            new_label2punc = {}
            for label, punc in label2punc.items():
                if label == "O":
                    new_label2punc[label] = ""
                else:
                    reduced = self._reduce_punc(punc)
                    new_label2punc[label] = reduced
            label2punc = new_label2punc

        # Optionally append a space after sentence-like punctuation.
        if add_space:
            special_puncs = "!,:;?ใ"
            label2punc = {
                k: self._insert_space(v, special_puncs)
                for k, v in label2punc.items()
            }
            # Keep "O" strictly empty even after the space pass.
            label2punc["O"] = ""

        return label2punc

    @staticmethod
    def _reduce_punc(text: str) -> str:
        """Collapse a punctuation string to a single canonical mark.

        Comma-like marks map to ",", terminal marks to the full stop class;
        the most frequent mark wins, ties resolved by the priority order.

        Args:
            text: Punctuation string emitted by the label map.

        Returns:
            A single canonical punctuation character, or "" when none maps.
        """
        reduce_map = {
            ",": ",", "-": ",", "/": ",", ":": ",", "|": ",",
            "ยท": ",", "ใ": ",",
            "?": "?", "!": "ใ", ".": "ใ", ";": "ใ", "ใ": "ใ",
        }

        # Map every character; unmapped characters are dropped.
        text = "".join([reduce_map.get(c, "") for c in text])
        # Priority order used for tie-breaking below.
        punc_order = "?ใ,,"

        if len(set(text).intersection(punc_order)) == 0:
            return ""

        # Pick the most frequent canonical mark.
        counts = {c: text.count(c) for c in punc_order}
        max_count = max(counts.values())
        max_keys = {k for k, v in counts.items() if v == max_count}

        if len(max_keys) == 1:
            return max_keys.pop()

        # Tie: choose by priority order.
        for c in punc_order:
            if c in max_keys:
                return c

        return ""

    @staticmethod
    def _insert_space(text: str, chars: str) -> str:
        """Insert a space after every occurrence of the given characters.

        Args:
            text: Original text.
            chars: Characters after which a space is inserted.

        Returns:
            The text with spaces inserted.
        """
        result = ""
        for c in text:
            result += c
            if c in chars:
                result += " "
        return result
|
| 326 |
+
|
ai_modules/nlp/utils.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
์ ํธ๋ฆฌํฐ ํจ์ ๋ชจ๋
|
| 3 |
+
ํ์ผ ์
์ถ๋ ฅ, ํ
์คํธ ์ ์ฒ๋ฆฌ ๋ฑ ๊ณตํต ๊ธฐ๋ฅ์ ์ ๊ณตํฉ๋๋ค.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import re
|
| 7 |
+
import unicodedata
|
| 8 |
+
from typing import Dict, Any
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def remove_punctuation(text: str) -> str:
    """Strip punctuation and whitespace from *text* while keeping [MASK...] tokens.

    Args:
        text: Original text.

    Returns:
        The text with punctuation and separator characters removed;
        bracketed mask tokens are preserved verbatim.
    """
    kept = []
    pos = 0
    size = len(text)

    while pos < size:
        # Preserve tokens of the form [MASK...] untouched.
        if text[pos] == '[' and 'MASK' in text[pos:pos + 10]:
            close = text.find(']', pos)
            if close != -1:
                kept.append(text[pos:close + 1])
                pos = close + 1
                continue

        # Keep everything except punctuation (P*) and separators (Z*).
        ch = text[pos]
        if unicodedata.category(ch)[0] not in "PZ":
            kept.append(ch)
        pos += 1

    return "".join(kept)
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def replace_mask_with_symbol(text: str, symbol: str = "โก") -> str:
    """Substitute every numbered mask token ([MASK1], [MASK2], ...) with *symbol*.

    Args:
        text: Original text.
        symbol: Replacement string.

    Returns:
        Text with all numbered mask tokens replaced; a bare [MASK]
        (no digits) is left untouched.
    """
    numbered_mask = re.compile(r'\[MASK\d+\]')
    return numbered_mask.sub(symbol, text)
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def normalize_mask_tokens(text: str) -> str:
    """Collapse numbered mask tokens ([MASK1], [MASK2], ...) into plain [MASK].

    Args:
        text: Original text.

    Returns:
        The normalized text (bare [MASK] tokens are unchanged).
    """
    numbered_mask = re.compile(r'\[MASK\d+\]')
    return numbered_mask.sub('[MASK]', text)
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def extract_mask_info(json_data: Dict[str, Any]) -> list:
    """Extract mask metadata from an OCR-result JSON payload.

    Scans ``json_data['results']`` for entries whose 'type' contains
    'MASK' and returns their order/type, sorted by reading order.

    Args:
        json_data: Input JSON data; expected to hold a 'results' list of
            dicts carrying 'type' and 'order' keys.

    Returns:
        List of {'order': ..., 'type': ...} dicts sorted by 'order'.
    """
    mask_info = []
    for item in json_data.get('results', []):
        if 'MASK' in item.get('type', ''):
            mask_info.append({
                # .get with a default keeps a malformed entry (missing
                # 'order') from raising KeyError — consistent with how
                # nlp_engine reads the same records.
                'order': item.get('order', 0),
                'type': item['type']
            })
    mask_info.sort(key=lambda x: x['order'])
    return mask_info
| 87 |
+
|
ai_modules/nlp_engine.py
ADDED
|
@@ -0,0 +1,321 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
NLP ํตํฉ ์์ง
|
| 3 |
+
๊ตฌ๋์ ๋ณต์ ๋ฐ MLM ์์ธก์ ํตํฉ ๊ด๋ฆฌํ๋ ์์ง์
๋๋ค.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import json
|
| 8 |
+
import torch
|
| 9 |
+
import logging
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from typing import Dict, Any, Optional, List
|
| 12 |
+
|
| 13 |
+
from .nlp.punctuation_restorer import PunctuationRestorer
|
| 14 |
+
from .nlp.mlm_predictor import MLMPredictor
|
| 15 |
+
from .nlp.utils import remove_punctuation, replace_mask_with_symbol
|
| 16 |
+
|
| 17 |
+
logger = logging.getLogger(__name__)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def load_nlp_config(config_path: Optional[str] = None) -> Dict[str, Any]:
    """Load the NLP configuration file.

    Args:
        config_path: Path to the configuration file; when None, the bundled
            default (config/nlp_config.json next to this module) is used.

    Returns:
        The parsed configuration dictionary.

    Raises:
        FileNotFoundError: If the configuration file does not exist.
    """
    default_path = Path(__file__).parent / "config" / "nlp_config.json"
    config_path = default_path if config_path is None else Path(config_path)

    if not config_path.exists():
        raise FileNotFoundError(f"NLP ์ค์ ํ์ผ์ ์ฐพ์ ์ ์์ต๋๋ค: {config_path}")

    return json.loads(config_path.read_text(encoding='utf-8'))
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class NLPEngine:
    """Unified NLP engine coordinating punctuation restoration and MLM prediction.

    Sub-models are constructed lazily on first use (see _load_models);
    building the engine only reads the configuration and selects a device.
    """

    def __init__(self, config_path: Optional[str] = None):
        """Initialize the engine.

        Args:
            config_path: Path to the configuration file (None -> default path).
        """
        self.config = load_nlp_config(config_path)

        # Device selection: 'auto' picks CUDA when available, otherwise CPU;
        # any other configured value is used verbatim.
        dev_cfg = self.config.get('device', 'auto')
        if dev_cfg == 'auto':
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = dev_cfg

        logger.info(f"[NLP] Device: {self.device}")

        # Model cache root: the AI_MODEL_DIR env var overrides the bundled
        # default (../models relative to this module).
        self.base_model_dir = os.getenv(
            'AI_MODEL_DIR',
            str(Path(__file__).parent.parent / "models")
        )

        # Sub-modules, lazily created in _load_models().
        self.punc_restorer = None
        self.mlm_predictor = None

    def _load_models(self):
        """Load both sub-models into memory if not already loaded (idempotent)."""
        if self.punc_restorer is None:
            logger.info("[NLP] ๊ตฌ๋์ ๋ณต์ ๋ชจ๋ธ ๋ก๋ ์ค...")
            self.punc_restorer = PunctuationRestorer(
                self.config,
                self.base_model_dir,
                self.device
            )
            # Download (no-op when cached), then load into memory.
            self.punc_restorer.download_model()
            self.punc_restorer.load_model()

        if self.mlm_predictor is None:
            logger.info("[NLP] MLM ๋ชจ๋ธ ๋ก๋ ์ค...")
            self.mlm_predictor = MLMPredictor(self.config, self.device)
            self.mlm_predictor.load_model()

    def process_text(
        self,
        raw_text: str,
        ocr_results: Optional[List[Dict]] = None,
        add_space: bool = True,
        reduce_punc: bool = True
    ) -> Dict[str, Any]:
        """Run the full text-processing pipeline.

        Steps:
        1. Remove existing punctuation (preprocessing; [MASK] tokens kept).
        2. Restore punctuation with the punctuation model.
        3. Predict [MASK] tokens with the MLM.

        Args:
            raw_text: Original text (may contain punctuation and [MASKn] tokens).
            ocr_results: Optional OCR result list carrying 'order'/'type'
                metadata for each mask; when absent, mask metadata is
                recovered by scanning raw_text.
            add_space: Insert a space after restored punctuation marks.
            reduce_punc: Simplify restored punctuation to a small set.

        Returns:
            On success: {'punctuated_text_with_masks', 'results', 'statistics'}.
            On failure: {'success': False, 'error': str}.
            NOTE(review): the success payload carries no 'success' key —
            callers must detect failure via 'success' being present/False.
        """
        self._load_models()

        try:
            # 1. Preprocess: strip punctuation, preserving [MASK] tokens.
            clean_text = remove_punctuation(raw_text)
            logger.info(f"[NLP] ๊ตฌ๋์ ์ ๊ฑฐ ์๋ฃ: {len(clean_text)} ๊ธ์")

            # 2. Restore punctuation.
            punctuated_text = self.punc_restorer.restore_punctuation(
                clean_text,
                add_space=add_space,
                reduce=reduce_punc
            )
            logger.info(f"[NLP] ๊ตฌ๋์ ๋ณต์ ์๋ฃ: {len(punctuated_text)} ๊ธ์")

            # 3. MLM prediction on the punctuated text.
            mask_predictions = self.mlm_predictor.predict_masks(punctuated_text)
            logger.info(f"[NLP] MLM ์์ธก ์๋ฃ: {len(mask_predictions)}๊ฐ ๋ง์คํฌ")

            # 4. Build the display text: [MASKn] -> configured symbol.
            mask_replacement = self.config['tokens']['mask_replacement']
            final_text = replace_mask_with_symbol(
                punctuated_text,
                mask_replacement
            )

            # Extract mask metadata from OCR results, or fall back to
            # scanning the original text.
            mask_info_list = []
            if ocr_results:
                # Use OCR results to get order and type.
                for item in ocr_results:
                    if 'MASK' in item.get('type', ''):
                        mask_info_list.append({
                            'order': item.get('order', 0),
                            'type': item.get('type', 'MASK2'),
                            'text': item.get('text', '')
                        })
            else:
                # Fallback: recover [MASK...] tokens from the raw text,
                # assigning sequential order in appearance order.
                i = 0
                while i < len(raw_text):
                    if raw_text[i] == '[' and 'MASK' in raw_text[i:i+10]:
                        end = raw_text.find(']', i)
                        if end != -1:
                            mask_text = raw_text[i:end+1]
                            mask_type = 'MASK1' if 'MASK1' in mask_text else 'MASK2'
                            mask_info_list.append({
                                'order': len(mask_info_list),  # Sequential order
                                'type': mask_type,
                                'text': mask_text
                            })
                            i = end + 1
                            continue
                    i += 1

            # Pair each prediction with its mask metadata (positional
            # pairing — assumes prediction order matches mask order).
            formatted_results = []
            for idx, pred_list in enumerate(mask_predictions):
                if idx < len(mask_info_list):
                    mask_info = mask_info_list[idx]
                    formatted_results.append({
                        "order": mask_info['order'],
                        "type": mask_info['type'],
                        "top_10": pred_list[:10]  # Top-10 predictions
                    })
                else:
                    # Fallback when metadata is shorter than predictions.
                    formatted_results.append({
                        "order": idx,
                        "type": "MASK2",
                        "top_10": pred_list[:10]
                    })

            # Aggregate top-1 probability statistics across all masks.
            top1_probs = [preds[0]['probability'] for preds in mask_predictions if preds]
            statistics = {
                "top1_probability_avg": float(sum(top1_probs) / len(top1_probs)) if top1_probs else 0.0,
                "top1_probability_min": float(min(top1_probs)) if top1_probs else 0.0,
                "top1_probability_max": float(max(top1_probs)) if top1_probs else 0.0,
                "total_masks": len(mask_predictions)
            }

            return {
                "punctuated_text_with_masks": final_text,
                "results": formatted_results,
                "statistics": statistics
            }

        except Exception as e:
            # Top-level boundary: log with traceback and return an error dict.
            logger.error(f"[NLP] ์ฒ๋ฆฌ ์ค ์ค๋ฅ: {e}", exc_info=True)
            return {
                "success": False,
                "error": str(e)
            }

    def restore_punctuation_only(
        self,
        text: str,
        add_space: bool = True,
        reduce_punc: bool = True
    ) -> Dict[str, Any]:
        """Run only the punctuation-restoration stage (no MLM prediction).

        Args:
            text: Input text.
            add_space: Insert a space after restored punctuation marks.
            reduce_punc: Simplify restored punctuation to a small set.

        Returns:
            {'success': True, 'original_text', 'clean_text', 'punctuated_text'}
            on success, or {'success': False, 'error': str} on failure.
        """
        self._load_models()

        try:
            clean_text = remove_punctuation(text)
            punctuated_text = self.punc_restorer.restore_punctuation(
                clean_text,
                add_space=add_space,
                reduce=reduce_punc
            )

            return {
                "success": True,
                "original_text": text,
                "clean_text": clean_text,
                "punctuated_text": punctuated_text
            }
        except Exception as e:
            logger.error(f"[NLP] ๊ตฌ๋์ ๋ณต์ ์ค ์ค๋ฅ: {e}", exc_info=True)
            return {
                "success": False,
                "error": str(e)
            }

    def predict_masks_only(
        self,
        text: str
    ) -> Dict[str, Any]:
        """Run only the MLM prediction stage (no punctuation restoration).

        Args:
            text: Text containing [MASK] tokens.

        Returns:
            {'success': True, 'predictions', 'mask_count'} on success,
            or {'success': False, 'error': str} on failure.
        """
        self._load_models()

        try:
            mask_predictions = self.mlm_predictor.predict_masks(text)

            return {
                "success": True,
                "predictions": mask_predictions,
                "mask_count": len(mask_predictions)
            }
        except Exception as e:
            logger.error(f"[NLP] MLM ์์ธก ์ค ์ค๋ฅ: {e}", exc_info=True)
            return {
                "success": False,
                "error": str(e)
            }
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
# ================================================================================
|
| 278 |
+
# Global Accessor
|
| 279 |
+
# ================================================================================
|
| 280 |
+
_nlp_engine = None
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
def get_nlp_engine(config_path: Optional[str] = None) -> NLPEngine:
    """Return the process-wide NLPEngine instance (singleton pattern).

    Args:
        config_path: Path to the configuration file (None -> default path).
            Only honored on the first call, which creates the instance.

    Returns:
        The shared NLPEngine instance.
    """
    global _nlp_engine
    engine = _nlp_engine
    if engine is None:
        engine = NLPEngine(config_path)
        _nlp_engine = engine
    return engine
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
def process_text_with_nlp(
    text: str,
    ocr_results: Optional[List[Dict]] = None,
    config_path: Optional[str] = None,
    add_space: bool = True,
    reduce_punc: bool = True
) -> Dict[str, Any]:
    """Convenience wrapper: run *text* through the full NLP pipeline.

    Args:
        text: Input text.
        ocr_results: OCR result list carrying 'order'/'type' metadata.
        config_path: Path to the configuration file.
        add_space: Insert a space after restored punctuation marks.
        reduce_punc: Simplify restored punctuation to a small set.

    Returns:
        The result dictionary from NLPEngine.process_text.
    """
    return get_nlp_engine(config_path).process_text(
        text,
        ocr_results=ocr_results,
        add_space=add_space,
        reduce_punc=reduce_punc,
    )
|
| 321 |
+
|
ai_modules/ocr_engine.py
ADDED
|
@@ -0,0 +1,767 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
"""
|
| 3 |
+
================================================================================
|
| 4 |
+
OCR Ensemble Module for Epitext AI Project
|
| 5 |
+
================================================================================
|
| 6 |
+
๋ชจ๋๋ช
: ocr_engine.py (v12.0.0 - Production Ready)
|
| 7 |
+
์์ฑ์ผ: 2025-12-03
|
| 8 |
+
๋ชฉ์ : Google Vision API + HRCenterNet ์์๋ธ ๊ธฐ๋ฐ ํ์ OCR ๋ฐ ์์ ์์ญ ํ์ง
|
| 9 |
+
์ํ: Production Ready
|
| 10 |
+
================================================================================
|
| 11 |
+
"""
|
| 12 |
+
import os
|
| 13 |
+
import sys
|
| 14 |
+
import io
|
| 15 |
+
import cv2
|
| 16 |
+
import json
|
| 17 |
+
import numpy as np
|
| 18 |
+
import torch
|
| 19 |
+
import torchvision
|
| 20 |
+
import re
|
| 21 |
+
import logging
|
| 22 |
+
from torch.autograd import Variable
|
| 23 |
+
from pathlib import Path
|
| 24 |
+
from PIL import Image
|
| 25 |
+
from typing import Dict, List, Optional, Tuple, Any
|
| 26 |
+
|
| 27 |
+
# ================================================================================
|
| 28 |
+
# Logging Configuration
|
| 29 |
+
# ================================================================================
|
| 30 |
+
logger = logging.getLogger(__name__)
|
| 31 |
+
|
| 32 |
+
# ================================================================================
|
| 33 |
+
# External Model Imports
|
| 34 |
+
# ================================================================================
|
| 35 |
+
try:
|
| 36 |
+
from ai_modules.models.resnet import ResnetCustom
|
| 37 |
+
from ai_modules.models.HRCenterNet import _HRCenterNet
|
| 38 |
+
logger.info("[INIT] ์ธ๋ถ ๋ชจ๋ธ ์ํฌํธ ์๋ฃ: ResnetCustom, HRCenterNet")
|
| 39 |
+
except ImportError as e:
|
| 40 |
+
logger.error(f"[INIT] ๋ชจ๋ธ ์ํฌํธ ์คํจ: {e}")
|
| 41 |
+
raise
|
| 42 |
+
|
| 43 |
+
# ================================================================================
|
| 44 |
+
# Google Vision API Import
|
| 45 |
+
# ================================================================================
|
| 46 |
+
try:
|
| 47 |
+
from google.cloud import vision
|
| 48 |
+
HAS_GOOGLE_VISION = True
|
| 49 |
+
except ImportError:
|
| 50 |
+
HAS_GOOGLE_VISION = False
|
| 51 |
+
logger.warning("[INIT] google-cloud-vision ํจํค์ง๊ฐ ์ค์น๋์ง ์์์ต๋๋ค.")
|
| 52 |
+
|
| 53 |
+
# ================================================================================
|
| 54 |
+
# Utility Functions
|
| 55 |
+
# ================================================================================
|
| 56 |
+
def is_hanja(text: str) -> bool:
    """Return True if *text* starts with a CJK unified ideograph (U+4E00..U+9FFF).

    Only the first character is inspected, mirroring an anchored regex match.
    Empty strings are never Hanja.
    """
    if not text:
        return False
    return '\u4e00' <= text[0] <= '\u9fff'
|
| 59 |
+
|
| 60 |
+
def calculate_pixel_density(binary_img: np.ndarray, box: Dict) -> float:
    """Return the fraction of non-zero pixels of *binary_img* inside *box*.

    The box (keys ``min_x``/``min_y``/``max_x``/``max_y``) is clipped to the
    image bounds; a box that clips to an empty region yields 0.0.

    Args:
        binary_img: Single-channel (2-D) array, non-zero pixels count as "ink".
        box: Dict with pixel coordinates of the region of interest.

    Returns:
        Non-zero pixel ratio in [0.0, 1.0].
    """
    h, w = binary_img.shape
    # Clip the requested box to the image so slicing never goes out of range.
    x1 = max(0, int(box['min_x']))
    y1 = max(0, int(box['min_y']))
    x2 = min(w, int(box['max_x']))
    y2 = min(h, int(box['max_y']))
    if x2 <= x1 or y2 <= y1:
        return 0.0
    roi = binary_img[y1:y2, x1:x2]
    # np.count_nonzero is equivalent to cv2.countNonZero on single-channel
    # arrays; using NumPy keeps this helper free of the OpenCV dependency.
    return int(np.count_nonzero(roi)) / ((x2 - x1) * (y2 - y1))
|
| 69 |
+
|
| 70 |
+
def load_ocr_config(config_path: Optional[str] = None) -> Dict:
    """Load the OCR configuration JSON.

    Defaults to ``config/ocr_config.json`` next to this module when no
    explicit path is given.
    """
    if config_path is None:
        config_path = str(Path(__file__).parent / "config" / "ocr_config.json")
    return json.loads(Path(config_path).read_text(encoding='utf-8'))
|
| 77 |
+
|
| 78 |
+
# ================================================================================
|
| 79 |
+
# Text Detection Class
|
| 80 |
+
# ================================================================================
|
| 81 |
+
class TextDetector:
    """Character-level text detector built on HRCenterNet.

    Loads a checkpoint onto the given device and exposes :meth:`detect`,
    which returns per-character bounding boxes (x, y, w, h) plus heatmap
    confidence scores for an input page image.
    """

    def __init__(self, device: torch.device, det_ckpt: str, config: Dict):
        # Network input/output resolutions come from the model_config section.
        self.device = device
        self.config = config
        self.input_size = config['model_config']['input_size']
        self.output_size = config['model_config']['output_size']

        self.model = _HRCenterNet(32, 5, 0.1)
        if not os.path.exists(det_ckpt):
            raise FileNotFoundError(f"์ฒดํฌํฌ์ธํธ ํ์ผ ์์: {det_ckpt}")

        state = torch.load(det_ckpt, map_location=self.device)
        self.model.load_state_dict(state)
        self.model = self.model.to(self.device)
        self.model.eval()

        # Detection always runs on a square-resized copy of the page.
        self.transform = torchvision.transforms.Compose([
            torchvision.transforms.Resize((self.input_size, self.input_size)),
            torchvision.transforms.ToTensor()
        ])

    @torch.no_grad()
    def detect(self, image) -> Tuple[List, List]:
        """Run detection on a path / ndarray / PIL image.

        Returns:
            (boxes, scores) where each box is [x, y, w, h] in original-image
            pixels, clipped to the image, and scores are heatmap peaks.
        """
        # Accept a file path, a NumPy array, or a PIL image.
        if isinstance(image, str): img = Image.open(image).convert("RGB")
        elif isinstance(image, np.ndarray): img = Image.fromarray(image).convert("RGB")
        else: img = image.convert("RGB")

        image_tensor = self.transform(img).unsqueeze_(0)
        inp = Variable(image_tensor).to(self.device, dtype=torch.float)

        predict = self.model(inp)
        predict_np = predict.data.cpu().numpy()
        # Model emits 5 channels: center heatmap, y/x offsets, width/height maps.
        heatmap, offset_y, offset_x, width_map, height_map = predict_np[0]

        bbox, score_list = [], []
        # Scale factors from the model's output grid back to original pixels:
        # Hc from image height, Wc from image width.
        Hc, Wc = img.size[1] / self.output_size, img.size[0] / self.output_size

        # Load NMS thresholds from config.
        nms_cfg = self.config.get('nms_config', {})
        nms_score = nms_cfg.get('primary_threshold', 0.12)

        idxs = np.where(heatmap.reshape(-1, 1) >= nms_score)[0]
        if len(idxs) == 0:
            # Nothing found at the primary threshold: retry with a looser one.
            nms_score = nms_cfg.get('fallback_threshold', 0.08)
            idxs = np.where(heatmap.reshape(-1, 1) >= nms_score)[0]

        for j in idxs:
            # Flat index -> (row, col) on the output grid.
            row = j // self.output_size
            col = j - row * self.output_size
            # NOTE(review): bias_x is scaled by the *height* ratio Hc and
            # bias_y by the *width* ratio Wc, and below the vertical extent
            # uses "width" while the horizontal extent uses "height". The
            # axis naming looks swapped; presumably it matches the model's
            # training convention — confirm against the training code.
            bias_x = offset_x[row, col] * Hc
            bias_y = offset_y[row, col] * Wc
            width = width_map[row, col] * self.output_size * Hc
            height = height_map[row, col] * self.output_size * Wc

            score_list.append(float(heatmap[row, col]))
            # Map grid coordinates (plus sub-cell offsets) to pixel space.
            row = row * Hc + bias_y
            col = col * Wc + bias_x

            top = row - width / 2.0
            left = col - height / 2.0
            bottom = row + width / 2.0
            right = col + height / 2.0
            # Store as [x, y, w, h]; w/h clamped to be non-negative.
            bbox.append([left, top, max(0.0, right - left), max(0.0, bottom - top)])

        if not bbox: return [], []

        # Convert to xyxy for torchvision NMS, then keep the surviving boxes.
        xyxy = [[x, y, x+w, y+h] for x, y, w, h in bbox]
        keep = torchvision.ops.nms(
            torch.tensor(xyxy, dtype=torch.float32),
            scores=torch.tensor(score_list, dtype=torch.float32),
            iou_threshold=nms_cfg.get('iou_threshold', 0.05)
        ).cpu().numpy().tolist()

        res_boxes, res_scores = [], []
        W, H = img.size
        for k in keep:
            idx = int(k)
            x, y, w, h = bbox[idx]
            # Clip each kept box to the image and drop degenerate (<=1px) ones.
            x = max(0.0, min(x, W - 1.0))
            y = max(0.0, min(y, H - 1.0))
            w = max(0.0, min(w, W - x))
            h = max(0.0, min(h, H - y))
            if w > 1 and h > 1:
                res_boxes.append([x, y, w, h])
                res_scores.append(score_list[idx])

        return res_boxes, res_scores
|
| 168 |
+
|
| 169 |
+
# ================================================================================
|
| 170 |
+
# Merging Logics (Config ์ ์ฉ)
|
| 171 |
+
# ================================================================================
|
| 172 |
+
def merge_vertical_fragments(boxes, scores, config):
    """Merge vertically adjacent detection fragments into single character boxes.

    Runs a fixed-point loop: in each pass, boxes are sorted top-to-bottom and
    each box may absorb one nearby box below it (aligned horizontally, close
    vertically, and not producing an overly tall merged box). The loop repeats
    until a full pass performs no merge.

    Args:
        boxes: List of [x, y, w, h] boxes.
        scores: Parallel list of detection scores.
        config: Full OCR config; uses merge_config.vertical_fragments thresholds.

    Returns:
        (merged_boxes, merged_scores) in the same [x, y, w, h] format.
    """
    if not boxes: return [], []
    # Pre-compute corner/center coordinates once per box for cheap comparisons.
    rects = [{'x': b[0], 'y': b[1], 'w': b[2], 'h': b[3],
              'x2': b[0]+b[2], 'y2': b[1]+b[3],
              'cx': b[0]+b[2]/2, 'cy': b[1]+b[3]/2, 'score': s}
             for b, s in zip(boxes, scores)]

    cfg = config['merge_config']['vertical_fragments']

    while True:
        rects.sort(key=lambda r: r['y'])
        merged = False
        new_rects, skip_indices = [], set()

        for i in range(len(rects)):
            if i in skip_indices: continue
            current = rects[i]
            best_cand_idx = -1

            # Only look at the next few boxes below (sorted by y), a small
            # window is enough since fragments of one glyph are adjacent.
            for j in range(i + 1, min(i + 5, len(rects))):
                if j in skip_indices: continue
                candidate = rects[j]

                avg_w = (current['w'] + candidate['w']) / 2
                # Must be roughly in the same column...
                if abs(current['cx'] - candidate['cx']) > avg_w * cfg['horizontal_center_ratio']: continue
                # ...and close enough vertically.
                if (candidate['y'] - current['y2']) > avg_w * cfg['vertical_gap_ratio']: continue

                new_h = max(current['y2'], candidate['y2']) - min(current['y'], candidate['y'])
                new_w = max(current['x2'], candidate['x2']) - min(current['x'], candidate['x'])

                # Reject merges that would create an implausibly tall box.
                is_safe_ratio = (new_h / new_w) < cfg['aspect_ratio_limit']
                # Two near-square boxes are probably two whole characters, not
                # fragments — unless they actually overlap vertically.
                cur_square = (current['h'] / current['w']) > 0.85
                cand_square = (candidate['h'] / candidate['w']) > 0.85
                is_overlapped = (candidate['y'] - current['y2']) < -avg_w * 0.2

                if is_safe_ratio and (not (cur_square and cand_square) or is_overlapped):
                    best_cand_idx = j
                    break

            if best_cand_idx != -1:
                # Replace the pair with their union; keep the better score.
                cand = rects[best_cand_idx]
                nx, ny = min(current['x'], cand['x']), min(current['y'], cand['y'])
                nx2, ny2 = max(current['x2'], cand['x2']), max(current['y2'], cand['y2'])
                new_rects.append({
                    'x': nx, 'y': ny, 'w': nx2-nx, 'h': ny2-ny,
                    'x2': nx2, 'y2': ny2, 'cx': (nx+nx2)/2, 'cy': (ny+ny2)/2,
                    'score': max(current['score'], cand['score'])
                })
                skip_indices.add(best_cand_idx)
                merged = True
            else:
                new_rects.append(current)
        rects = new_rects
        # Stop once a complete pass made no change (fixed point reached).
        if not merged: break

    return [[r['x'], r['y'], r['w'], r['h']] for r in rects], [r['score'] for r in rects]
|
| 228 |
+
|
| 229 |
+
def merge_google_symbols(symbols, config):
    """Merge vertically split / duplicated Google Vision symbols.

    Same fixed-point strategy as merge_vertical_fragments, but operating on
    Google symbol dicts: a symbol absorbs a nearby lower symbol when they are
    column-aligned and either (a) touching with a safe merged aspect ratio and
    not both already square, or (b) recognized as the same text (duplicate).
    The upper symbol's text is kept.

    Args:
        symbols: List of symbol dicts (text, min/max x/y, width/height,
            center_x/center_y, confidence, source).
        config: Full OCR config; uses merge_config.google_symbols thresholds.

    Returns:
        The merged list of symbol dicts.
    """
    if not symbols: return []
    cfg = config['merge_config']['google_symbols']

    while True:
        symbols.sort(key=lambda s: s['min_y'])
        merged = False
        new_symbols, skip_indices = [], set()

        for i in range(len(symbols)):
            if i in skip_indices: continue
            curr = symbols[i]
            best_cand_idx = -1

            # Small look-ahead window below the current symbol.
            for j in range(i + 1, min(i + 5, len(symbols))):
                if j in skip_indices: continue
                cand = symbols[j]

                avg_w = (curr['width'] + cand['width']) / 2
                # Must share a column.
                if abs(curr['center_x'] - cand['center_x']) > avg_w * cfg['horizontal_center_ratio']: continue

                gap = cand['min_y'] - curr['max_y']
                is_touching = gap < (avg_w * cfg['vertical_gap_ratio'])

                new_h = max(curr['max_y'], cand['max_y']) - min(curr['min_y'], cand['min_y'])
                new_w = max(curr['max_x'], cand['max_x']) - min(curr['min_x'], cand['min_x'])

                # Two square symbols are likely two distinct characters.
                is_both_square = (curr['height']/curr['width'] > 0.85) and (cand['height']/cand['width'] > 0.85)
                is_safe_ratio = (new_h / new_w) < cfg['aspect_ratio_limit']
                # Identical text nearby is treated as a duplicate detection.
                is_duplicate = (curr['text'] == cand['text'])

                if (is_touching and is_safe_ratio and not is_both_square) or is_duplicate:
                    best_cand_idx = j
                    break

            if best_cand_idx != -1:
                # Union of the two boxes; upper symbol's text wins, best
                # confidence is kept.
                cand = symbols[best_cand_idx]
                merged_sym = {
                    'text': curr['text'],
                    'min_x': min(curr['min_x'], cand['min_x']), 'min_y': min(curr['min_y'], cand['min_y']),
                    'max_x': max(curr['max_x'], cand['max_x']), 'max_y': max(curr['max_y'], cand['max_y']),
                    'confidence': max(curr['confidence'], cand['confidence']),
                    'source': 'Google'
                }
                merged_sym['width'] = merged_sym['max_x'] - merged_sym['min_x']
                merged_sym['height'] = merged_sym['max_y'] - merged_sym['min_y']
                merged_sym['center_x'] = (merged_sym['min_x'] + merged_sym['max_x']) / 2
                merged_sym['center_y'] = (merged_sym['min_y'] + merged_sym['max_y']) / 2
                new_symbols.append(merged_sym)
                skip_indices.add(best_cand_idx)
                merged = True
            else:
                new_symbols.append(curr)
        symbols = new_symbols
        # Repeat until a pass performs no merge.
        if not merged: break
    return symbols
|
| 285 |
+
|
| 286 |
+
# ================================================================================
|
| 287 |
+
# Models Execution
|
| 288 |
+
# ================================================================================
|
| 289 |
+
def get_google_ocr(content: bytes, config: Dict, google_json_path: Optional[str] = None) -> List[Dict]:
    """Run Google Vision document OCR on raw image bytes.

    Keeps only Hanja symbols, converts each to a bounding-box dict, and merges
    vertically split / duplicated symbols. Returns [] when the Vision client is
    unavailable or the request fails.

    Args:
        content: Encoded image bytes.
        config: Full OCR config (used by merge_google_symbols).
        google_json_path: Optional path to a service-account credentials file.

    Returns:
        List of symbol dicts with text, box coordinates, confidence and source.
    """
    if not HAS_GOOGLE_VISION:
        return []
    if google_json_path and os.path.exists(google_json_path):
        os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = google_json_path

    try:
        client = vision.ImageAnnotatorClient()
        response = client.document_text_detection(
            image=vision.Image(content=content),
            image_context=vision.ImageContext(language_hints=["zh-Hant"]),
        )
        annotation = response.full_text_annotation
        if not annotation:
            return []

        collected = []
        for page in annotation.pages:
            for blk in page.blocks:
                for para in blk.paragraphs:
                    for word in para.words:
                        for sym in word.symbols:
                            if not is_hanja(sym.text):
                                continue
                            xs = [p.x for p in sym.bounding_box.vertices]
                            ys = [p.y for p in sym.bounding_box.vertices]
                            x0, x1 = min(xs), max(xs)
                            y0, y1 = min(ys), max(ys)
                            collected.append({
                                'text': sym.text,
                                'center_x': (x0 + x1) / 2, 'center_y': (y0 + y1) / 2,
                                'min_x': x0, 'max_x': x1, 'min_y': y0, 'max_y': y1,
                                'width': x1 - x0, 'height': y1 - y0,
                                'confidence': sym.confidence, 'source': 'Google'
                            })

        before = len(collected)
        collected = merge_google_symbols(collected, config)
        if len(collected) < before:
            logger.info(f"[OCR] Google ๋ณํฉ: {before} -> {len(collected)}๊ฐ")
        return collected
    except Exception as e:
        logger.error(f"[OCR] Google Vision Error: {e}")
        return []
|
| 327 |
+
|
| 328 |
+
def get_custom_model_ocr(image_path, binary_img, detector, recognizer, config):
    """Run the in-house detector+recognizer pipeline on one page image.

    Pipeline: detect character boxes -> merge vertical fragments -> recognize
    each crop -> filter low-quality results -> classify heavily inked regions
    as MASK tokens. Returns [] on any failure.

    Args:
        image_path: Path to the page image (opened as RGB).
        binary_img: Binarized page (2-D array) used for pixel-density checks.
        detector: TextDetector instance.
        recognizer: Callable mapping a list of PIL crops to recognized chars.
        config: Full OCR config (filtering_thresholds, ink_detection_thresholds).

    Returns:
        List of symbol dicts (text/type/box/confidence/source/density).
    """
    try:
        pil_img = Image.open(image_path).convert("RGB")
        boxes, scores = detector.detect(pil_img)
        if not boxes: return []

        # Merge vertically split fragments before recognition.
        original_count = len(boxes)
        boxes, scores = merge_vertical_fragments(boxes, scores, config)
        if len(boxes) < original_count:
            logger.info(f"[OCR] Custom ๋ณํฉ: {original_count} -> {len(boxes)}๊ฐ")

        # Page-level size statistics used by the title/outlier filter below.
        all_heights = [b[3] for b in boxes]
        all_widths = [b[2] for b in boxes]
        median_h = np.median(all_heights) if all_heights else 0
        median_w = np.median(all_widths) if all_widths else 0

        # Recognize every detected crop in one batch.
        crops = [pil_img.crop((int(b[0]), int(b[1]), int(b[0]+b[2]), int(b[1]+b[3]))) for b in boxes]
        chars = recognizer(crops) if crops else []

        # Filter & classify each recognized box (thresholds from config).
        symbols = []
        img_h, _ = binary_img.shape
        ft = config['filtering_thresholds']
        it = config['ink_detection_thresholds']

        for char, (x, y, w, h), score in zip(chars, boxes, scores):
            # Drop empty / placeholder recognitions.
            if not char or char == "โ ": continue

            box_dict = {'min_x': x, 'min_y': y, 'max_x': x+w, 'max_y': y+h}
            density = calculate_pixel_density(binary_img, box_dict)

            # Hard filters: unconditionally reject very low score/density.
            if score < ft['min_score_hard'] or density < ft['density_min_hard']: continue
            # Smart filter: reject only when BOTH score and density are weak.
            if score < ft['smart_score_threshold'] and density < ft['smart_density_threshold']: continue

            # Title removal: drop oversized boxes, and top-of-page boxes that
            # are much larger than the body text (likely headings).
            is_huge = (h > median_h * 3.5) if median_h > 0 else False
            is_top = (y < img_h * 0.15) and (h > median_h * 2.5 or w > median_w * 2.5) if median_h > 0 else False
            if median_h > 0 and (is_huge or is_top): continue

            # Ink masking: heavy ink -> MASK1, partial ink -> MASK2; otherwise
            # keep the recognized char only if it is actually Hanja.
            final_text, final_type = char, 'TEXT'
            if density >= it['density_ink_heavy']:
                final_text, final_type = '[MASK1]', 'MASK1'
            elif density >= it['density_ink_partial']:
                final_text, final_type = '[MASK2]', 'MASK2'
            else:
                if not is_hanja(char): continue

            symbols.append({
                'text': final_text, 'type': final_type,
                'center_x': x+w/2, 'center_y': y+h/2,
                'min_x': x, 'max_x': x+w, 'min_y': y, 'max_y': y+h,
                'width': w, 'height': h,
                'confidence': float(score), 'source': 'Custom', 'density': density
            })

        logger.info(f"[OCR] Custom Model ์๋ฃ: {len(symbols)}๊ฐ")
        return symbols
    except Exception as e:
        logger.error(f"[OCR] Custom Model Error: {e}")
        return []
|
| 394 |
+
|
| 395 |
+
# ================================================================================
|
| 396 |
+
# Ensemble Reconstruction (Full Logic from Script)
|
| 397 |
+
# ================================================================================
|
| 398 |
+
def ensemble_reconstruction(google_syms, custom_syms, binary_img, config):
    """Fuse Google and custom-model OCR symbols into ordered text columns.

    Stages: global size/ink filtering -> resize & de-duplicate custom boxes
    against Google boxes -> group all symbols into vertical columns (read
    right-to-left) -> per-column de-duplication, gap inference, mask filling
    and overlap resolution. Produces the final box list and one text line per
    column.

    Args:
        google_syms: Symbol dicts from Google Vision.
        custom_syms: Symbol dicts from the custom pipeline.
        binary_img: Binarized page (2-D) used for all density checks.
        config: Full OCR config (ensemble_config, filtering_thresholds,
            ink_detection_thresholds).

    Returns:
        (final_boxes, lines): flat list of symbol dicts, and the per-column
        concatenated text strings.
    """
    logger.info("[ENSEMBLE] ์์๋ธ ์ฌ๊ตฌ์ฑ ์์...")
    img_h, img_w = binary_img.shape
    ec = config['ensemble_config']
    ft = config['filtering_thresholds']
    it = config['ink_detection_thresholds']

    # --- Helper Functions (closures over binary_img and the config dicts) ---
    def filter_excessive_masks(nodes):
        # Drop long consecutive runs of MASK nodes (>= threshold) — such runs
        # are usually noise, not genuinely obscured characters.
        filtered, buffer = [], []
        threshold = ec['excessive_mask_threshold']
        for node in nodes:
            if 'MASK' in node.get('type', 'TEXT'): buffer.append(node)
            else:
                if buffer:
                    if len(buffer) < threshold: filtered.extend(buffer)
                    buffer = []
                filtered.append(node)
        if buffer and len(buffer) < threshold: filtered.extend(buffer)
        return filtered

    def merge_split_masks(nodes, avg_h):
        # Merge two adjacent MASK nodes into one when the combined height is
        # still plausible for a single character (< 1.8 * avg height).
        if not nodes: return []
        merged, skip = [], False
        for i in range(len(nodes)):
            if skip: skip = False; continue
            curr = nodes[i]
            if i == len(nodes)-1: merged.append(curr); break

            next_node = nodes[i+1]
            if 'MASK' in curr.get('type','TEXT') and 'MASK' in next_node.get('type','TEXT'):
                combined_h = next_node['max_y'] - curr['min_y']
                if combined_h < avg_h * 1.8:
                    new_node = curr.copy()
                    new_node.update({'max_y': next_node['max_y'], 'height': next_node['max_y'] - curr['min_y']})
                    density = calculate_pixel_density(binary_img, new_node)
                    new_node['density'] = density

                    # The merged region may turn out empty: drop both nodes.
                    if density < ft['density_min_hard']:
                        skip = True; continue

                    # Re-classify the merged mask by its ink density.
                    m_type = 'MASK1' if density >= it['density_ink_heavy'] else 'MASK2'
                    new_node.update({'type': m_type, 'text': f'[{m_type}]'})
                    merged.append(new_node)
                    skip = True
                    continue
            merged.append(curr)
        return merged

    def resolve_overlaps(boxes):
        # Split vertical overlap between consecutive boxes at the midpoint
        # (mutates the boxes in place).
        if not boxes: return []
        boxes.sort(key=lambda x: x['min_y'])
        for i in range(len(boxes)-1):
            curr, next_box = boxes[i], boxes[i+1]
            # Skip pairs with no horizontal intersection.
            if min(curr['max_x'], next_box['max_x']) - max(curr['min_x'], next_box['min_x']) <= 0: continue

            if curr['max_y'] > next_box['min_y']:
                mid_y = (curr['max_y'] + next_box['min_y']) / 2
                curr['max_y'], curr['height'] = mid_y, mid_y - curr['min_y']
                next_box['min_y'], next_box['height'] = mid_y, next_box['max_y'] - mid_y
        return boxes

    def filter_google_overlaps(g_boxes, c_boxes):
        # Remove custom boxes that duplicate a Google box (box intersection or
        # near-coincident centers). MASK boxes are never removed.
        if not g_boxes: return c_boxes
        filtered = []
        for c in c_boxes:
            is_dup = False
            for g in g_boxes:
                dx = abs(c['center_x'] - g['center_x'])
                dy = abs(c['center_y'] - g['center_y'])
                # MASK is preserved even if overlapping
                if 'MASK' in c.get('type', 'TEXT'): pass
                elif (min(c['max_x'], g['max_x']) > max(c['min_x'], g['min_x']) and
                      min(c['max_y'], g['max_y']) > max(c['min_y'], g['min_y'])) or \
                     (dx < g['width']*0.4 and dy < g['height']*0.4):
                    is_dup = True; break
            if not is_dup: filtered.append(c)
        return filtered

    def infer_gaps(col, step_y, avg_w):
        # Detect vertical gaps wider than the expected character step and fill
        # them with inferred MASK boxes where the page actually has ink.
        if not col: return []
        col.sort(key=lambda s: s['center_y'])
        filled = []
        for i, curr in enumerate(col):
            if i > 0:
                prev = col[i-1]
                gap = curr['center_y'] - prev['center_y']
                if gap > step_y * ec['gap_inference_ratio']:
                    # Estimate how many characters are missing in this gap.
                    missing = int(round(gap/step_y)) - 1
                    if missing > 0:
                        step = gap / (missing + 1)
                        for k in range(1, missing + 1):
                            ny = prev['center_y'] + k*step
                            nb = {'min_x': curr['center_x'] - avg_w/2, 'max_x': curr['center_x'] + avg_w/2,
                                  'min_y': max(0, ny - step_y*0.4), 'max_y': min(img_h, ny + step_y*0.4)}
                            nb.update({'height': nb['max_y']-nb['min_y'], 'width': nb['max_x']-nb['min_x'],
                                       'center_x': (nb['min_x']+nb['max_x'])/2, 'center_y': (nb['min_y']+nb['max_y'])/2})

                            # Only fill positions that actually contain ink.
                            d = calculate_pixel_density(binary_img, nb)
                            if d < ft['density_min_hard']: continue

                            mt = 'MASK1' if d >= it['density_ink_heavy'] else 'MASK2'
                            nb.update({'text': f'[{mt}]', 'type': mt, 'density': d, 'confidence': 0.0, 'source': 'Inferred'})
                            filled.append(nb)
            filled.append(curr)
        return filled

    def check_ink_on_google(g_syms):
        # Annotate Google symbols with density; re-type heavily inked ones as
        # MASKs and drop near-empty boxes (likely hallucinations).
        filtered = []
        for s in g_syms:
            d = calculate_pixel_density(binary_img, s)
            s['density'] = d
            if d >= it['density_ink_heavy']: s.update({'type': 'MASK1', 'text': '[MASK1]'})
            elif d >= it['density_ink_partial']: s.update({'type': 'MASK2', 'text': '[MASK2]'})
            elif d < ft['density_min_hard']: continue  # Hallucination check
            else: s['type'] = 'TEXT'
            filtered.append(s)
        return filtered

    # --- Preprocessing ---
    all_h = ([s['height'] for s in google_syms] + [s['height'] for s in custom_syms])
    # 30.0 is a fallback character height when nothing was detected at all.
    median_h = np.median(all_h) if all_h else 30.0

    # Filter Height & Check Ink
    def global_remove_tall_and_top(boxes, median_h, threshold=2.0):
        # Remove implausibly tall boxes, and tall boxes near the top of the
        # page (likely titles).
        if not boxes: return []
        filtered = []
        for b in boxes:
            if b['height'] > median_h * threshold: continue
            if b['min_y'] < img_h * 0.15 and b['height'] > median_h * 2.5: continue
            filtered.append(b)
        return filtered

    if google_syms:
        # Google boxes get a stricter height filter than custom boxes.
        google_syms = global_remove_tall_and_top(google_syms, median_h, threshold=2.0)
        google_syms = check_ink_on_google(google_syms)
    if custom_syms:
        custom_syms = global_remove_tall_and_top(custom_syms, median_h, threshold=3.5)

    # Resize & Filter Custom — reference sizes are taken from Google symbols.
    avg_w = np.mean([s['width'] for s in google_syms]) if google_syms else 0
    median_w = np.median([s['width'] for s in google_syms]) if google_syms else 0

    processed_custom = []
    for s in custom_syms:
        # MASK boxes pass through unchanged.
        if 'MASK' in s.get('type', 'TEXT'):
            processed_custom.append(s); continue

        # Keep only custom boxes that are at least a reasonable fraction of a
        # typical (Google-sized) character.
        if (s['width']*s['height'] > (median_w*median_h)*0.2 and
            s['width'] > median_w*0.3 and s['height'] > median_h*0.3):

            # Resize undersized boxes up toward the typical character size,
            # keeping them centered and clipped to the image.
            if s['width'] < median_w*0.8 or s['height'] < median_h*0.8:
                tw = max(s['width'], median_w*0.9)
                th = max(s['height'], median_h*0.9)
                cx, cy = s['center_x'], s['center_y']
                s.update({'min_x': max(0, cx-tw/2), 'max_x': min(img_w, cx+tw/2),
                          'min_y': max(0, cy-th/2), 'max_y': min(img_h, cy+th/2)})
                s.update({'width': s['max_x']-s['min_x'], 'height': s['max_y']-s['min_y']})
            processed_custom.append(s)

    custom_syms = filter_google_overlaps(google_syms, processed_custom)

    if not google_syms and not custom_syms: return [], []

    # --- Column Grouping (right-to-left reading order) ---
    all_syms = google_syms + custom_syms
    columns = []
    if all_syms:
        for s in sorted(all_syms, key=lambda x: -x['center_x']):
            found = False
            for col in columns:
                # Assign to a column whose running mean center_x is close.
                cx = sum(c['center_x'] for c in col) / len(col)
                if abs(s['center_x'] - cx) < (avg_w if avg_w else s['width']) * ec['column_grouping_ratio']:
                    col.append(s); found = True; break
            if not found: columns.append([s])

    # Vertical Step Calculation — typical distance between character centers,
    # estimated from plausible (0.8x–1.5x median height) gaps across columns.
    global_steps = []
    for col in columns:
        col.sort(key=lambda s: s['center_y'])
        for k in range(len(col)-1):
            step = col[k+1]['center_y'] - col[k]['center_y']
            if median_h * 0.8 < step < median_h * 1.5: global_steps.append(step)
    global_step = np.median(global_steps) if global_steps else median_h * 1.1

    # --- Reconstruction (per column) ---
    final_boxes, lines = [], []
    for col in columns:
        col.sort(key=lambda s: s['center_y'])
        # Prefer a column-local step estimate; fall back to the global one.
        local_steps = [col[k+1]['center_y'] - col[k]['center_y'] for k in range(len(col)-1)
                       if median_h*0.8 < (col[k+1]['center_y'] - col[k]['center_y']) < median_h*1.5]
        step_y = np.median(local_steps) if local_steps else global_step

        # Deduplication in column: symbols closer than 0.6 * median height are
        # considered the same position; MASKs beat text, denser MASKs beat
        # sparser ones, and Google text is preferred over custom text.
        unique_col = []
        if col:
            prev = col[0]
            unique_col.append(prev)
            for k in range(1, len(col)):
                curr = col[k]
                dist_y = abs(curr['center_y'] - prev['center_y'])
                is_same_text = (curr.get('text') == prev.get('text'))
                is_close = (dist_y < median_h * 0.6)

                if is_close:
                    prev_is_mask = 'MASK' in prev.get('type', 'TEXT')
                    curr_is_mask = 'MASK' in curr.get('type', 'TEXT')

                    if prev_is_mask and curr_is_mask:
                        # Keep whichever mask has the higher ink density.
                        if prev['density'] < curr['density']:
                            unique_col.pop()
                            unique_col.append(curr)
                            prev = curr
                        continue
                    elif prev_is_mask and not curr_is_mask:
                        # Existing mask wins over overlapping text.
                        continue
                    elif not prev_is_mask and curr_is_mask:
                        # Incoming mask replaces overlapping text.
                        unique_col.pop()
                        unique_col.append(curr)
                        prev = curr
                        continue

                if is_same_text and is_close:
                    # Same character at the same position: prefer Google.
                    if prev.get('source') == 'Google':
                        continue
                    elif curr.get('source') == 'Google':
                        unique_col.pop()
                        unique_col.append(curr)
                        prev = curr
                    else:
                        continue
                else:
                    unique_col.append(curr)
                    prev = curr

        col = infer_gaps(unique_col, step_y, avg_w if avg_w else median_h)

        # Gap Filling with Masks: any remaining vertical gap with enough ink
        # becomes a mask box. (NOTE(review): gap-fill dicts omit width/height/
        # center keys — downstream use appears limited to text/box edges.)
        filled_col, cy = [], col[0]['min_y'] if col else 0
        for item in col:
            gap = item['min_y'] - cy
            if gap > step_y * 1.2:
                mb = {'min_x': item['center_x'] - (avg_w if avg_w else median_h)/2,
                      'max_x': item['center_x'] + (avg_w if avg_w else median_h)/2,
                      'min_y': max(0, cy + gap*0.1), 'max_y': min(img_h, item['min_y'] - gap*0.1)}
                d = calculate_pixel_density(binary_img, mb)
                if d >= ft['density_min_hard']:
                    mt = 'MASK1' if d >= it['density_ink_heavy'] else 'MASK2'
                    if d >= it['density_ink_partial']:
                        filled_col.append({'text': f'[{mt}]', 'type': mt, 'density': d,
                                           'min_x': mb['min_x'], 'max_x': mb['max_x'],
                                           'min_y': mb['min_y'], 'max_y': mb['max_y'],
                                           'confidence': 0.0, 'source': 'GapFill'})

            # Drop low-density non-mask items (likely spurious detections).
            if item.get('density', 0) < ft['density_min_hard'] and 'MASK' not in item.get('type','TEXT'):
                cy = item['max_y']; continue

            filled_col.append(item)
            cy = item['max_y']

        filled_col = merge_split_masks(filled_col, median_h)
        filled_col = filter_excessive_masks(filled_col)
        filled_col = resolve_overlaps(filled_col)

        final_boxes.extend(filled_col)
        lines.append("".join([s['text'] for s in filled_col]))

    logger.info(f"[ENSEMBLE] ์๋ฃ: {len(final_boxes)}๊ฐ ๋ฐ์ค, {len(lines)}๊ฐ ์ด")
    return final_boxes, lines
|
| 668 |
+
|
| 669 |
+
# ================================================================================
|
| 670 |
+
# OCREngine Class
|
| 671 |
+
# ================================================================================
|
| 672 |
+
class OCREngine:
    """End-to-end OCR pipeline.

    Combines local preprocessing, Google Vision OCR, a custom
    detection/recognition model, and an ensemble reconstruction step.
    Models are loaded lazily on the first call to :meth:`run_ocr`.
    """

    def __init__(self, config_path: Optional[str] = None):
        """Initialize configuration and resolve model/credential paths.

        Args:
            config_path: Optional path to a JSON OCR config file; defaults
                are used when omitted.

        Raises:
            ValueError: If a required environment variable is missing or the
                Google credentials file does not exist.
        """
        self.config = load_ocr_config(config_path)

        # Resolve weight/credential locations from the environment (.env).
        base_path = os.getenv('OCR_WEIGHTS_BASE_PATH')
        if not base_path:
            raise ValueError("OCR_WEIGHTS_BASE_PATH environment variable is required. Please set it in your .env file.")

        self.det_ckpt = os.path.join(base_path, os.getenv('OCR_DETECTION_MODEL', 'best.pth'))
        self.rec_ckpt = os.path.join(base_path, os.getenv('OCR_RECOGNITION_MODEL', 'best_5000.pt'))

        # BUG FIX: os.getenv() without a default returns None when the
        # variable is unset, and os.path.join(base_path, None) then raises
        # TypeError instead of the intended ValueError. Validate first.
        google_json_name = os.getenv('GOOGLE_CREDENTIALS_JSON')
        if not google_json_name:
            raise ValueError("GOOGLE_CREDENTIALS_JSON environment variable is required and file must exist. Please set it in your .env file.")
        self.google_json = os.path.join(base_path, google_json_name)

        if not os.path.exists(self.google_json):
            raise ValueError(f"GOOGLE_CREDENTIALS_JSON environment variable is required and file must exist. Please set it in your .env file.")

        # Existence verified above; export for the Google client library.
        os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = self.google_json

        # Device selection: honor explicit config, otherwise auto-detect CUDA.
        dev_cfg = self.config['model_config']['device']
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if dev_cfg == 'auto' else torch.device(dev_cfg)
        self.detector = None
        self.recognizer = None

    def _load_models(self):
        """Lazily instantiate the detector and recognizer (idempotent)."""
        if not self.detector:
            self.detector = TextDetector(self.device, self.det_ckpt, self.config)
        if not self.recognizer:
            self.recognizer = ResnetCustom(weight_fn=self.rec_ckpt)
            self.recognizer.to(self.device)

    def run_ocr(self, image_path: str) -> Dict:
        """Run the full OCR pipeline on one image.

        Args:
            image_path: Path to the input image file.

        Returns:
            On success: ``{"image": <filename>, "results": [<box dicts>]}``
            where each box dict carries order/text/type/box/confidence/source.
            On failure: ``{"success": False, "error": <message>}``.
        """
        try:
            self._load_models()

            # 1. Preprocessing (kept identical to the reference v12 script).
            img_bgr = cv2.imread(image_path)
            if img_bgr is None:
                raise ValueError(f"Image not found: {image_path}")

            img_gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
            img_blur = cv2.medianBlur(img_gray, 3)
            _, img_binary = cv2.threshold(img_blur, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
            kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
            img_binary = cv2.morphologyEx(img_binary, cv2.MORPH_CLOSE, kernel)

            # 2. Google Vision OCR on the raw image bytes.
            with io.open(image_path, 'rb') as f:
                content = f.read()
            google_syms = get_google_ocr(content, self.config, self.google_json)

            # 3. Custom detection + recognition model.
            custom_syms = get_custom_model_ocr(image_path, img_binary, self.detector, self.recognizer, self.config)

            # 4. Ensemble both symbol sources into the final layout.
            final_boxes, result_lines = ensemble_reconstruction(google_syms, custom_syms, img_binary, self.config)

            # Format results according to the output specification.
            formatted_results = []
            for order, box in enumerate(final_boxes):
                formatted_results.append({
                    "order": order,
                    "text": box.get('text', ''),
                    "type": box.get('type', 'TEXT'),
                    "box": [
                        float(box.get('min_x', 0)),
                        float(box.get('min_y', 0)),
                        float(box.get('max_x', 0)),
                        float(box.get('max_y', 0))
                    ],
                    "confidence": float(box.get('confidence', 0.0)),
                    "source": box.get('source', 'Unknown')
                })

            return {
                "image": os.path.basename(image_path),
                "results": formatted_results
            }
        except Exception as e:
            logger.error(f"[OCR] Execution Failed: {e}", exc_info=True)
            return {"success": False, "error": str(e)}
|
| 755 |
+
|
| 756 |
+
# ================================================================================
# Global Accessor
# ================================================================================
_engine = None


def get_ocr_engine(config_path: Optional[str] = None) -> OCREngine:
    """Return the process-wide :class:`OCREngine`, creating it on first use."""
    global _engine
    if _engine is None:
        _engine = OCREngine(config_path)
    return _engine


def ocr_and_detect(image_path: str, config_path: Optional[str] = None, bbox: Optional[Tuple[int, int, int, int]] = None, device: str = "cuda") -> Dict:
    """Convenience wrapper around the singleton engine.

    NOTE(review): ``bbox`` and ``device`` are accepted for API compatibility
    but are currently not forwarded to the engine.
    """
    engine = get_ocr_engine(config_path)
    return engine.run_ocr(image_path)
|
ai_modules/preprocessor_unified.py
ADDED
|
@@ -0,0 +1,605 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Epitext_Back/ai_modules/preprocessor_unified.py
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
================================================================================
|
| 5 |
+
Unified Image Preprocessing Module for Epitext AI Project
|
| 6 |
+
================================================================================
|
| 7 |
+
|
| 8 |
+
๋ชจ๋๋ช
: preprocessor_unified.py (v1.0.0 - Production Ready)
|
| 9 |
+
์์ฑ์ผ: 2025-12-02
|
| 10 |
+
๋ชฉ์ : ํ์ ์ด๋ฏธ์ง๋ฅผ Swin Gray์ OCR์ฉ์ผ๋ก ๋์์ ์ ์ฒ๋ฆฌ
|
| 11 |
+
์ํ: Production Ready
|
| 12 |
+
|
| 13 |
+
ํต์ฌ ๊ธฐ๋ฅ:
|
| 14 |
+
ํ ๋ฒ์ ๋ ๊ฐ์ง ์ ์ฒ๋ฆฌ ์๋ฃ:
|
| 15 |
+
1. Swin Gray: ๊ทธ๋ ์ด ๋น์ด์งํ -> 3์ฑ๋ (์ ๋ณด ์์ค ์ต์)
|
| 16 |
+
2. OCR: ์ด์งํ -> 1์ฑ๋ (๋ช
ํํ ํ๋ฐฑ)
|
| 17 |
+
|
| 18 |
+
์๋ ๋ฐฐ๊ฒฝ ๋ณด์ฅ:
|
| 19 |
+
- Swin: ๋ฐ์๋ฐฐ๊ฒฝ (>=127)
|
| 20 |
+
- OCR: ํฐ๋ฐฐ๊ฒฝ + ๊ฒ์ ๊ธ์ (255/0)
|
| 21 |
+
|
| 22 |
+
ํ๋ณธ ์๋ ๊ฒ์ถ: ํฐ ์ด๋์ด ์์ญ ์๋ณ
|
| 23 |
+
์์ญ ๊ฒ์ถ 1ํ: ํจ์จ์ฑ
|
| 24 |
+
์ค์ ํ์ผ ์ง์: JSON ๊ธฐ๋ฐ ์ปค์คํฐ๋ง์ด์ง
|
| 25 |
+
๋ก๊น
์ง์: DEBUG, INFO, WARNING, ERROR
|
| 26 |
+
|
| 27 |
+
์์กด์ฑ:
|
| 28 |
+
- opencv-python >= 4.8.0
|
| 29 |
+
- numpy >= 1.24.0
|
| 30 |
+
|
| 31 |
+
๋จ์ผ ํจ์:
|
| 32 |
+
preprocess_image_unified(input_path, output_swin_path, output_ocr_path, ...)
|
| 33 |
+
|
| 34 |
+
์ฌ์ฉ ์์:
|
| 35 |
+
>>> from ai_modules.preprocessor_unified import preprocess_image_unified
|
| 36 |
+
>>> result = preprocess_image_unified(
|
| 37 |
+
... "input.jpg",
|
| 38 |
+
... "swin.jpg",
|
| 39 |
+
... "ocr.png"
|
| 40 |
+
... )
|
| 41 |
+
|
| 42 |
+
================================================================================
|
| 43 |
+
"""
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
import cv2
|
| 47 |
+
import numpy as np
|
| 48 |
+
from pathlib import Path
|
| 49 |
+
import json
|
| 50 |
+
import logging
|
| 51 |
+
from typing import Dict, Optional, Tuple
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
# ================================================================================
|
| 55 |
+
# Logging Configuration
|
| 56 |
+
# ================================================================================
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
# Module-level logging: timestamped "[LEVEL] message" records at INFO and above.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - [%(levelname)s] %(message)s'
)
logger = logging.getLogger(__name__)


# ================================================================================
# Constants
# ================================================================================


# Default preprocessing parameters; each can be overridden via the JSON config.
DEFAULT_MARGIN = 10                        # crop margin in pixels (see _apply_margin)
DEFAULT_BRIGHTNESS_THRESHOLD = 127         # mean-brightness cutoff used for inversion decisions
DEFAULT_RUBBING_MIN_AREA_RATIO = 0.1       # min rubbing-contour area as a fraction of the image
DEFAULT_TEXT_MIN_AREA = 16                 # absolute minimum contour-bbox area in pixels
DEFAULT_TEXT_AREA_RATIO = 0.00005          # relative minimum contour-bbox area (fraction of image)
DEFAULT_MORPHOLOGY_KERNEL_SIZE = (2, 2)    # kernel for text-region close/open cleanup
DEFAULT_MORPHOLOGY_CLOSE_ITERATIONS = 3
DEFAULT_MORPHOLOGY_OPEN_ITERATIONS = 2
DEFAULT_RUBBING_KERNEL_SIZE = (5, 5)       # kernel for rubbing-region close/open cleanup
DEFAULT_RUBBING_CLOSE_ITERATIONS = 10
DEFAULT_RUBBING_OPEN_ITERATIONS = 5
| 84 |
+
|
| 85 |
+
# ================================================================================
|
| 86 |
+
# Main Preprocessing Class
|
| 87 |
+
# ================================================================================
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class UnifiedImagePreprocessor:
    """Unified image preprocessor producing Swin-Gray and OCR outputs in one pass.

    A single call detects and crops the content region once, then writes two
    artifacts:

    * Swin: bright-background grayscale expanded to 3 channels (JPG).
    * OCR:  Otsu-binarized image with white background / black glyphs (PNG).

    Attributes:
        config (dict): Preprocessing parameters (module defaults merged with
            an optional JSON config file).

    Example:
        >>> prep = UnifiedImagePreprocessor()
        >>> result = prep.preprocess_unified("input.jpg", "swin.jpg", "ocr.png")
    """

    def __init__(self, config_path: Optional[str] = None) -> None:
        """Initialize the preprocessor.

        Args:
            config_path: Optional path to a JSON configuration file.
        """
        self.config = self._load_config(config_path)
        logger.info("[INIT] UnifiedImagePreprocessor v1.0.0 initialized")

    def _load_config(self, config_path: Optional[str]) -> Dict:
        """Build the effective config: module defaults overlaid with JSON values."""
        default_config = {
            "margin": DEFAULT_MARGIN,
            "brightness_threshold": DEFAULT_BRIGHTNESS_THRESHOLD,
            "rubbing_min_area_ratio": DEFAULT_RUBBING_MIN_AREA_RATIO,
            "text_min_area": DEFAULT_TEXT_MIN_AREA,
            "text_area_ratio": DEFAULT_TEXT_AREA_RATIO,
            "morphology_kernel_size": DEFAULT_MORPHOLOGY_KERNEL_SIZE,
            "morphology_close_iterations": DEFAULT_MORPHOLOGY_CLOSE_ITERATIONS,
            "morphology_open_iterations": DEFAULT_MORPHOLOGY_OPEN_ITERATIONS,
            "rubbing_kernel_size": DEFAULT_RUBBING_KERNEL_SIZE,
            "rubbing_close_iterations": DEFAULT_RUBBING_CLOSE_ITERATIONS,
            "rubbing_open_iterations": DEFAULT_RUBBING_OPEN_ITERATIONS,
        }

        # Fall back to the packaged config file when no path is supplied.
        if config_path is None:
            default_config_path = Path(__file__).parent / "config" / "preprocess_config.json"
            if default_config_path.exists():
                config_path = str(default_config_path)

        if config_path and Path(config_path).exists():
            try:
                with open(config_path, 'r', encoding='utf-8') as f:
                    user_config = json.load(f)
                # Keys starting with '_' (e.g. "_description") are metadata, not settings.
                user_config_clean = {k: v for k, v in user_config.items() if not k.startswith('_')}
                default_config.update(user_config_clean)
                logger.info(f"[CONFIG] Loaded config file: {config_path}")
            except Exception as e:
                logger.warning(f"[CONFIG] Failed to load config file: {e} - using defaults")

        return default_config

    def _find_rubbing_bbox(self, gray_image: np.ndarray) -> Optional[Tuple[int, int, int, int]]:
        """Detect the rubbing region (largest dark rectangle).

        Args:
            gray_image: Grayscale image.

        Returns:
            ``(x, y, w, h)`` of the largest dark contour, or ``None`` when no
            contour reaches the configured minimum area ratio.
        """
        H_img, W_img = gray_image.shape

        # Step 1: extract dark regions (fixed threshold at 127).
        _, dark_mask = cv2.threshold(gray_image, 127, 255, cv2.THRESH_BINARY_INV)

        # Step 2: morphological cleanup (close gaps, then drop speckle).
        kernel_rub = np.ones(self.config["rubbing_kernel_size"], np.uint8)
        dark_mask = cv2.morphologyEx(
            dark_mask, cv2.MORPH_CLOSE, kernel_rub,
            iterations=self.config["rubbing_close_iterations"]
        )
        dark_mask = cv2.morphologyEx(
            dark_mask, cv2.MORPH_OPEN, kernel_rub,
            iterations=self.config["rubbing_open_iterations"]
        )

        # Step 3: contour detection.
        contours, _ = cv2.findContours(dark_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if not contours:
            return None

        # Step 4: take the largest contour.
        largest = max(contours, key=cv2.contourArea)
        area = cv2.contourArea(largest)

        # Step 5: reject contours smaller than the configured image fraction.
        min_area = (H_img * W_img) * self.config["rubbing_min_area_ratio"]
        if area < min_area:
            return None

        return cv2.boundingRect(largest)

    def _find_text_bbox(self, gray_image: np.ndarray) -> Tuple[int, int, int, int]:
        """Detect the text region.

        Args:
            gray_image: Grayscale image.

        Returns:
            ``(x, y, w, h)`` bounding box of all sufficiently large contours,
            or the full image when none qualify.
        """
        H_img, W_img = gray_image.shape

        # Step 1: Otsu binarization (ink becomes foreground).
        _, binary = cv2.threshold(
            gray_image, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU
        )

        # Step 2: morphological cleanup.
        kernel_morph = np.ones(self.config["morphology_kernel_size"], np.uint8)
        binary = cv2.morphologyEx(
            binary, cv2.MORPH_CLOSE, kernel_morph,
            iterations=self.config["morphology_close_iterations"]
        )
        binary = cv2.morphologyEx(
            binary, cv2.MORPH_OPEN, kernel_morph,
            iterations=self.config["morphology_open_iterations"]
        )

        # Step 3: contour detection.
        contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        # Step 4: minimum area = max(absolute floor, image-relative floor).
        min_area = max(
            self.config["text_min_area"],
            int((H_img * W_img) * self.config["text_area_ratio"])
        )

        # Step 5: keep contours whose bounding-box area passes the floor.
        # BUG FIX: the original passed the cv2.boundingRect() tuple into
        # cv2.contourArea(), which expects a point array; use the bounding-box
        # area (w * h) for the size filter instead.
        valid_contours = []
        for cnt in contours:
            _, _, cnt_w, cnt_h = cv2.boundingRect(cnt)
            if cnt_w * cnt_h >= min_area:
                valid_contours.append(cnt)

        # Step 6: union bounding box of all valid contours.
        if valid_contours:
            all_points = np.vstack(valid_contours)
            return cv2.boundingRect(all_points)
        else:
            return (0, 0, W_img, H_img)

    def _apply_margin(
        self,
        bbox: Tuple[int, int, int, int],
        gray_image: np.ndarray,
        margin_val: int
    ) -> Tuple[int, int, int, int]:
        """Expand *bbox* by *margin_val* pixels on each side, clamped to the image."""
        x, y, w, h = bbox
        H_img, W_img = gray_image.shape

        x_new = max(0, x - margin_val)
        y_new = max(0, y - margin_val)
        w_new = min(W_img - x_new, w + 2 * margin_val)
        h_new = min(H_img - y_new, h + 2 * margin_val)

        return (x_new, y_new, w_new, h_new)

    def _ensure_bright_background(
        self,
        gray_cropped: np.ndarray
    ) -> Tuple[np.ndarray, Dict]:
        """Guarantee a bright background (for Swin).

        Inverts the image when its mean brightness falls below the configured
        threshold, then re-checks once.

        Returns:
            Tuple of (processed grayscale image, processing-info dict).
        """
        mean_brightness = np.mean(gray_cropped)
        is_inverted = False

        if mean_brightness < self.config["brightness_threshold"]:
            gray_bright = cv2.bitwise_not(gray_cropped)
            is_inverted = True
        else:
            gray_bright = gray_cropped.copy()

        # Re-check: undo the inversion if it made the image darker overall.
        final_brightness = np.mean(gray_bright)
        if final_brightness < self.config["brightness_threshold"]:
            gray_bright = cv2.bitwise_not(gray_bright)
            is_inverted = not is_inverted
            final_brightness = np.mean(gray_bright)

        return gray_bright, {
            "mean_brightness_before": float(mean_brightness),
            "mean_brightness_after": float(final_brightness),
            "is_inverted": is_inverted,
            "is_bright_bg": final_brightness >= self.config["brightness_threshold"]
        }

    def _ensure_white_background(
        self,
        gray_cropped: np.ndarray
    ) -> Tuple[np.ndarray, Dict]:
        """Guarantee a white background (for OCR).

        Returns:
            Tuple of (binarized image, processing-info dict).
        """
        # Step 1: Otsu binarization.
        _, binary = cv2.threshold(
            gray_cropped, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU
        )

        # Step 2: judge polarity from the binary image's mean value.
        mean_brightness = np.mean(binary)

        # Step 3: invert when the foreground dominates (dark overall).
        if mean_brightness < self.config["brightness_threshold"]:
            binary_final = cv2.bitwise_not(binary)
            polarity = "inverted"
        else:
            binary_final = binary
            polarity = "normal"

        final_brightness = np.mean(binary_final)

        return binary_final, {
            "mean_brightness_before": float(mean_brightness),
            "mean_brightness_after": float(final_brightness),
            "polarity": polarity,
            "is_white_bg": final_brightness > self.config["brightness_threshold"]
        }

    def preprocess_unified(
        self,
        input_image_path: str,
        output_swin_path: str,
        output_ocr_path: str,
        margin: Optional[int] = None,
        use_rubbing: bool = False
    ) -> Dict:
        """Unified preprocessing: produce the Swin-Gray and OCR images together.

        Region detection runs only once; both outputs are derived from the
        same crop.

        Args:
            input_image_path: Input image path.
            output_swin_path: Swin Gray output path (JPG).
            output_ocr_path: OCR output path (PNG).
            margin: Crop margin in pixels; falls back to config when ``None``.
            use_rubbing: Enable rubbing-region detection (default: False).

        Returns:
            dict: On success, a payload with ``success=True``, the detected
            ``bbox``/``region_type``, and per-output ``swin``/``ocr`` blocks
            (paths, shapes, brightness diagnostics). On failure,
            ``{"success": False, "message": <error>}``.
        """
        # BUG FIX: `margin or config` treated an explicit margin of 0 as
        # "unset"; use an explicit None check instead.
        margin_val = margin if margin is not None else self.config["margin"]

        try:
            # Step 1: load image.
            img_bgr = cv2.imread(str(input_image_path), cv2.IMREAD_COLOR)
            if img_bgr is None:
                raise ValueError(f"Failed to load image: {input_image_path}")

            original_shape = img_bgr.shape
            logger.info(f"[LOAD] Image loaded: {input_image_path} {original_shape}")

            # Step 2: grayscale conversion.
            gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)

            # Step 3: region detection (rubbing or text).
            if use_rubbing:
                detected_bbox = self._find_rubbing_bbox(gray)
                region_type = "rubbing"
                logger.info("[DETECT] Rubbing-region detection mode")
            else:
                detected_bbox = None
                region_type = "text"
                logger.info("[DETECT] Text-region detection mode")

            H_img, W_img = gray.shape

            # Step 4: crop + margin.
            if detected_bbox is not None:
                bbox_final = self._apply_margin(detected_bbox, gray, margin_val)
                logger.info(f"[DETECT] {region_type} region detected: {bbox_final}")
            else:
                # Rubbing not found, or text mode -> fall back accordingly.
                if use_rubbing:
                    bbox_final = (0, 0, W_img, H_img)
                    logger.warning("[DETECT] Rubbing not found - using full image")
                else:
                    bbox_text = self._find_text_bbox(gray)
                    bbox_final = self._apply_margin(bbox_text, gray, margin_val)
                    logger.info(f"[DETECT] Text region detected: {bbox_final}")

            x, y, w, h = bbox_final
            gray_cropped = gray[y:y+h, x:x+w]

            logger.info(f"[CROP] Crop complete: {gray_cropped.shape}")

            # Step 5: Swin Gray processing (bright background, 3 channels).
            gray_bright, info_swin = self._ensure_bright_background(gray_cropped)
            swin_output_3ch = cv2.cvtColor(gray_bright, cv2.COLOR_GRAY2BGR)

            # Step 6: OCR processing (binarized, white background).
            binary_final, info_ocr = self._ensure_white_background(gray_cropped)

            # Step 7: save both outputs.
            output_swin_path_obj = Path(output_swin_path)
            output_swin_path_obj.parent.mkdir(parents=True, exist_ok=True)
            swin_success = cv2.imwrite(str(output_swin_path_obj), swin_output_3ch)

            output_ocr_path_obj = Path(output_ocr_path)
            output_ocr_path_obj.parent.mkdir(parents=True, exist_ok=True)
            ocr_success = cv2.imwrite(str(output_ocr_path_obj), binary_final)

            if not swin_success or not ocr_success:
                raise ValueError("Failed to save output image(s)")

            logger.info(f"[SAVE] Swin saved: {output_swin_path_obj}")
            logger.info(f"[SAVE] OCR saved: {output_ocr_path_obj}")

            # Result payload.
            return {
                "success": True,
                "version": "Unified Swin Gray + OCR (v1.0.0)",
                "original_shape": original_shape,
                "bbox": bbox_final,
                "region_type": region_type,
                "region_detected": detected_bbox is not None,

                # Swin output block.
                "swin": {
                    "output_path": str(output_swin_path_obj).replace("\\", "/"),
                    "output_shape": swin_output_3ch.shape,
                    "color_type": "Grayscale 3-channel (B=G=R, non-binarized 256 levels)",
                    "is_inverted": info_swin["is_inverted"],
                    "mean_brightness_before": info_swin["mean_brightness_before"],
                    "mean_brightness_after": info_swin["mean_brightness_after"],
                    "is_bright_bg": info_swin["is_bright_bg"]
                },

                # OCR output block.
                "ocr": {
                    "output_path": str(output_ocr_path_obj).replace("\\", "/"),
                    "output_shape": binary_final.shape,
                    "polarity": info_ocr["polarity"],
                    "mean_brightness_before": info_ocr["mean_brightness_before"],
                    "mean_brightness_after": info_ocr["mean_brightness_after"],
                    "is_white_bg": info_ocr["is_white_bg"]
                },

                "message": "[DONE] Unified preprocessing complete (Swin + OCR)"
            }

        except Exception as e:
            logger.error(f"[ERROR] Unified preprocessing failed: {e}")
            return {
                "success": False,
                "message": str(e)
            }
|
| 523 |
+
|
| 524 |
+
|
| 525 |
+
# ================================================================================
|
| 526 |
+
# Global Instance & Convenience Functions
|
| 527 |
+
# ================================================================================
|
| 528 |
+
|
| 529 |
+
|
| 530 |
+
_global_preprocessor = None


def get_preprocessor(config_path: Optional[str] = None) -> UnifiedImagePreprocessor:
    """Return the process-wide preprocessor instance, constructing it lazily."""
    global _global_preprocessor
    if _global_preprocessor is None:
        _global_preprocessor = UnifiedImagePreprocessor(config_path)
    return _global_preprocessor
|
| 539 |
+
|
| 540 |
+
|
| 541 |
+
def preprocess_image_unified(
    input_path: str,
    output_swin_path: str,
    output_ocr_path: str,
    margin: Optional[int] = None,
    use_rubbing: bool = False
) -> Dict:
    """Convenience wrapper: run unified preprocessing via the shared instance.

    Args:
        input_path: Input image path.
        output_swin_path: Swin output path.
        output_ocr_path: OCR output path.
        margin: Optional crop margin in pixels.
        use_rubbing: Enable rubbing-region detection.

    Returns:
        dict: Result payload from
        :meth:`UnifiedImagePreprocessor.preprocess_unified`.
    """
    return get_preprocessor().preprocess_unified(
        input_path,
        output_swin_path,
        output_ocr_path,
        margin,
        use_rubbing
    )
|
| 569 |
+
|
| 570 |
+
|
| 571 |
+
# ================================================================================
|
| 572 |
+
# Usage Example
|
| 573 |
+
# ================================================================================
|
| 574 |
+
|
| 575 |
+
|
| 576 |
+
if __name__ == "__main__":
    # Smoke test: run the unified pipeline on a sample image and log results.
    # (Garbled/mojibake log strings from the original repaired to English.)
    logger.info("=" * 80)
    logger.info("[TEST] Unified Image Preprocessor v1.0.0 - test start")
    logger.info("=" * 80)

    try:
        prep = UnifiedImagePreprocessor()

        result = prep.preprocess_unified(
            "test_input.jpg",
            "test_swin.jpg",
            "test_ocr.png"
        )

        if result["success"]:
            logger.info("[TEST] Unified preprocessing succeeded!")
            logger.info(f"[TEST] Swin: {result['swin']['output_path']}")
            logger.info(f"[TEST] OCR: {result['ocr']['output_path']}")
            logger.info(f"[TEST] Swin bright background: {'Yes' if result['swin']['is_bright_bg'] else 'No'}")
            logger.info(f"[TEST] OCR white background: {'Yes' if result['ocr']['is_white_bg'] else 'No'}")
        else:
            logger.error(f"[TEST] Failed: {result['message']}")

    except Exception as e:
        logger.error(f"[TEST] Exception: {e}")

    logger.info("=" * 80)
|
dong_ocr.py
ADDED
|
@@ -0,0 +1,349 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
"""
Standalone OCR script.

Hanja (Chinese-character) OCR and damaged-region detection based on an
ensemble of the Google Vision API and HRCenterNet.

Revisions:
1. Added logic that watches X-coordinate shifts to automatically split the
   output into columns.
2. Applied "safe crop" (floor/ceil) to [MASK] and other coordinates to avoid
   losing fractional-pixel regions
   -> applied not only to the visualization but to the JSON result data
   itself, so stored coordinates are integers.
"""

import os
import sys
import json
import logging
import cv2
import math
import numpy as np
from pathlib import Path
from dotenv import load_dotenv

# Make the script's own directory importable so `ai_modules` resolves
# regardless of the current working directory.
current_dir = os.path.dirname(os.path.abspath(__file__))
if current_dir not in sys.path:
    sys.path.insert(0, current_dir)

# Load environment variables (API credentials, weight paths) from .env.
load_dotenv()

# Logging setup
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - [%(levelname)s] %(message)s'
)
logger = logging.getLogger("DONG_OCR")

# OCR engine and preprocessing module imports; exit early with a clear
# error when the project modules cannot be loaded.
try:
    from ai_modules.ocr_engine import get_ocr_engine
    from ai_modules.preprocessor_unified import preprocess_image_unified
except ImportError as e:
    logger.error(f"❌ 모듈 import 실패: {e}")
    sys.exit(1)
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def format_ocr_results(raw_results, image_filename):
    """
    Convert raw OCR detections into the requested JSON format.

    Each entry is normalized to a ``box`` of ``[min_x, min_y, max_x, max_y]``
    integers.  "Safe crop" is applied during the int conversion: min
    coordinates are floored and max coordinates are ceiled so no fractional
    pixel region is lost; negative coordinates are clamped to 0.

    Args:
        raw_results (list[dict] | None): raw detections from the OCR engine.
            Coordinates may arrive as a ``box`` list, as individual
            ``min_x``/``min_y``/``max_x``/``max_y`` keys, as
            ``x``/``y``/``width``/``height``, or as ``center_x``/``center_y``
            plus ``width``/``height``.
        image_filename (str): file name recorded in the output.

    Returns:
        dict: ``{"image": image_filename, "results": [...]}`` where each
        result carries ``order``, ``text``, ``type``, ``box``,
        ``confidence`` and ``source``.
    """
    formatted_list = []

    if raw_results is None:
        raw_results = []

    if not raw_results:
        return {"image": image_filename, "results": []}

    order_counter = 0
    for item in raw_results:
        if not isinstance(item, dict):
            continue

        min_x, min_y, max_x, max_y = 0.0, 0.0, 0.0, 0.0

        # 1. Entry already carries a 'box' list of 4 coordinates.
        if 'box' in item and isinstance(item['box'], list) and len(item['box']) == 4:
            try:
                min_x, min_y, max_x, max_y = map(float, item['box'])
            except (TypeError, ValueError):
                pass  # fall through to the per-key extraction below

        # 2. No usable 'box': derive coordinates from individual keys.
        # NOTE(review): this `min_x == 0 and max_x == 0` sentinel also matches
        # a legitimate box whose left and right edges sit exactly at x == 0;
        # kept as-is to preserve existing behavior.
        if min_x == 0 and max_x == 0:
            mx = item.get('min_x')
            my = item.get('min_y')
            Mx = item.get('max_x')
            My = item.get('max_y')

            if mx is None: mx = item.get('x', 0)
            if my is None: my = item.get('y', 0)
            if Mx is None:
                Mx = item.get('x2')
                if Mx is None:
                    width = item.get('width', 0)
                    Mx = mx + width if width > 0 else 0
            if My is None:
                My = item.get('y2')
                if My is None:
                    height = item.get('height', 0)
                    My = my + height if height > 0 else 0

            try:
                min_x, min_y, max_x, max_y = float(mx), float(my), float(Mx), float(My)
            except (TypeError, ValueError):
                continue

        # 3. Still all-zero: last resort, rebuild the box from its center.
        if min_x == 0 and min_y == 0 and max_x == 0 and max_y == 0:
            width = item.get('width', 0)
            height = item.get('height', 0)
            if width > 0 and height > 0:
                cx, cy = item.get('center_x', width / 2), item.get('center_y', height / 2)
                min_x, min_y = cx - width / 2, cy - height / 2
                max_x, max_y = cx + width / 2, cy + height / 2
            else:
                continue

        # Drop degenerate (zero or negative area) boxes.
        if max_x <= min_x or max_y <= min_y:
            continue

        # Safe crop on the JSON data itself: floor the min edge, ceil the
        # max edge, then clamp negatives to 0.
        min_x = max(0, int(math.floor(min_x)))
        min_y = max(0, int(math.floor(min_y)))
        max_x = int(math.ceil(max_x))
        max_y = int(math.ceil(max_y))

        formatted_list.append({
            "order": order_counter,
            "text": item.get('text', ''),
            "type": item.get('type', 'TEXT'),
            "box": [min_x, min_y, max_x, max_y],
            "confidence": float(item.get('confidence', 0.0)),
            "source": item.get('source', 'Unknown'),
        })
        order_counter += 1

    return {"image": image_filename, "results": formatted_list}
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def draw_bboxes(image_path, results, output_path):
    """Draw bounding boxes onto the image (safe-crop applied) and save it.

    Boxes are colored by source/type (Google, Custom, MASK1, MASK2).  Images
    are read via ``np.fromfile`` + ``cv2.imdecode`` and written via
    ``cv2.imencode`` + ``tofile`` so non-ASCII paths work on Windows.

    Args:
        image_path (str): image to annotate (original or preprocessed).
        results (list[dict]): formatted OCR items with 'box', 'text',
            'type' and 'source' keys (as produced by ``format_ocr_results``).
        output_path (str): destination path; .jpg/.jpeg is saved at JPEG
            quality 95, anything else with PNG compression level 3.
    """
    try:
        # Decode through a byte buffer so paths with non-ASCII characters load.
        img_array = np.fromfile(image_path, np.uint8)
        img = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
        if img is None:
            img = cv2.imread(image_path)
        if img is None:
            return

        # Hoisted out of the loop: the image size never changes while drawing.
        h, w = img.shape[:2]

        box_count = 0
        colors = {
            'Google': (0, 255, 0), 'Custom': (255, 0, 255),
            'MASK1': (255, 0, 0), 'MASK2': (0, 0, 255), 'Default': (0, 255, 255)
        }

        for item in results:
            box = item.get('box', [])
            if len(box) != 4:
                continue
            try:
                # format_ocr_results already emits ints, but stay robust to
                # float coordinates: floor the min edge, ceil the max edge.
                x1 = int(math.floor(float(box[0])))
                y1 = int(math.floor(float(box[1])))
                x2 = int(math.ceil(float(box[2])))
                y2 = int(math.ceil(float(box[3])))
            except (TypeError, ValueError):
                continue

            # Clip to image bounds; force at least a 1px wide/tall box.
            x1 = max(0, min(x1, w - 1))
            y1 = max(0, min(y1, h - 1))
            x2 = max(x1 + 1, min(x2, w))
            y2 = max(y1 + 1, min(y2, h))

            text = item.get('text', '')
            source = item.get('source', '')
            itype = item.get('type', 'TEXT')

            if 'MASK1' in itype or '[MASK1]' in text:
                color = colors['MASK1']
            elif 'MASK2' in itype or '[MASK2]' in text:
                color = colors['MASK2']
            elif source in colors:
                color = colors[source]
            else:
                color = colors['Default']

            cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)

            # Short recognized text is drawn above the box; masks get a tag.
            if itype == 'TEXT' and len(text) <= 2:
                cv2.putText(img, text, (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
            elif 'MASK' in itype:
                label = '[M1]' if itype == 'MASK1' else '[M2]'
                cv2.putText(img, label, (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.4, color, 1)

            box_count += 1

        # Encode in memory, then write via tofile (non-ASCII path safe).
        ext = os.path.splitext(output_path)[1].lower()
        params = [int(cv2.IMWRITE_JPEG_QUALITY), 95] if ext in ['.jpg', '.jpeg'] else [int(cv2.IMWRITE_PNG_COMPRESSION), 3]
        result, encoded_img = cv2.imencode(ext, img, params)
        if result:
            with open(output_path, mode='wb') as f:
                encoded_img.tofile(f)
            logger.info(f"🖼️ B-Box 이미지 저장됨: {output_path} ({box_count}개 박스)")

    except Exception as e:
        logger.error(f"❌ 시각화 중 오류: {e}")
| 195 |
+
|
| 196 |
+
def run_ocr(image_path, use_preprocessing=True):
    """Run OCR on one image, print recognized columns, and save the results.

    Pipeline: optional unified preprocessing -> OCR engine -> JSON
    formatting -> column-split console output -> JSON + bbox-image files
    next to the input.

    Args:
        image_path (str): path of the image to analyze.
        use_preprocessing (bool): when True, run the unified preprocessor
            first and OCR the preprocessed PNG instead of the original.

    Returns:
        bool: True on success, False on any failure (missing file, engine
        error, OCR failure, or unexpected exception).
    """
    if not os.path.exists(image_path):
        logger.error(f"❌ 이미지 없음: {image_path}")
        return False

    logger.info(f"🔍 OCR 분석 시작: {image_path}")

    try:
        # 1. Preprocessing (best-effort: on failure we fall back to the
        #    original image rather than aborting).
        ocr_image_path = image_path
        preprocess_result = {'success': False}
        if use_preprocessing:
            logger.info("📸 이미지 전처리 중...")
            base_dir = os.path.dirname(os.path.abspath(image_path))
            base_name = os.path.splitext(os.path.basename(image_path))[0]
            # NOTE(review): these *_temp files are never deleted afterwards —
            # confirm whether leaving them next to the input is intended.
            swin_path = os.path.join(base_dir, f"{base_name}_swin_temp.jpg")
            ocr_preprocessed_path = os.path.join(base_dir, f"{base_name}_ocr_temp.png")

            preprocess_result = preprocess_image_unified(
                input_path=image_path, output_swin_path=swin_path,
                output_ocr_path=ocr_preprocessed_path, use_rubbing=True
            )
            if preprocess_result.get('success'):
                ocr_image_path = ocr_preprocessed_path
                logger.info(f"✅ 전처리 완료: {ocr_preprocessed_path}")
            else:
                logger.warning(f"⚠️ 전처리 실패: {preprocess_result.get('message')}")

        # 2. Run the OCR engine.
        engine = get_ocr_engine()
        logger.info("✅ OCR 엔진 로드 완료")

        try:
            raw_result = engine.run_ocr(ocr_image_path)
        except Exception as e:
            logger.error(f"❌ OCR 실행 예외: {e}")
            return False

        if not raw_result: return False

        # Some engine versions omit the 'success' flag but still return a
        # results list — treat that as success too.
        is_success = raw_result.get('success', False)
        if not is_success and 'results' in raw_result and isinstance(raw_result['results'], list):
            is_success = True

        if not is_success:
            logger.error(f"❌ OCR 실패: {raw_result.get('error')}")
            return False

        logger.info("\n" + "="*60)
        logger.info("✅ OCR 분석 완료")

        # 3. Normalize the raw detections into the output JSON format.
        formatted_result = format_ocr_results(raw_result.get('results', []), os.path.basename(image_path))
        results_list = formatted_result.get('results', [])

        # 4. [Column-split output] group glyphs into columns by X coordinate.
        logger.info("\n" + "🔍 [ 인식된 텍스트 결과 (자동 열 구분) ] " + "-"*25)

        if not results_list:
            logger.info(" (결과 없음)")
        else:
            columns = []
            current_col_text = []

            # X center of the first glyph, used as the initial reference.
            first_box = results_list[0]['box']
            prev_cx = (first_box[0] + first_box[2]) / 2

            for item in results_list:
                box = item['box']
                curr_cx = (box[0] + box[2]) / 2

                # Text extraction (MASK entries render as their tag).
                text = item.get('text', '')
                if item.get('type') in ['MASK1', 'MASK2']:
                    text = f"[{item.get('type')}]"

                # === Core column-split logic ===
                # If the X center moved more than 50px from the previous
                # glyph, treat it as a new column (in vertical writing a
                # line break shows up as a large X jump).
                if abs(curr_cx - prev_cx) > 50:
                    if current_col_text:
                        columns.append("".join(current_col_text))
                    current_col_text = []
                    prev_cx = curr_cx  # rebase the reference on the new column

                current_col_text.append(text)
                # Within a column there can be slight X jitter and slanted
                # glyphs, so the reference X is refreshed on every glyph
                # rather than pinned to one representative value.
                prev_cx = curr_cx

            # Flush the last column.
            if current_col_text:
                columns.append("".join(current_col_text))

            # Print one line per detected column.
            for idx, col_text in enumerate(columns, 1):
                logger.info(f" [열 {idx:02d}] {col_text}")

        logger.info("-" * 60 + "\n")

        # 5. Save the JSON result next to the original image.
        json_path = os.path.splitext(image_path)[0] + "_ocr_result.json"
        with open(json_path, 'w', encoding='utf-8') as f:
            json.dump(formatted_result, f, ensure_ascii=False, indent=2)
        logger.info(f"💾 JSON 결과 저장됨: {json_path}")

        # 6. Save the bounding-box visualization (drawn on the preprocessed
        #    image when preprocessing succeeded, else on the original).
        output_img_path = os.path.splitext(image_path)[0] + "_bbox.jpg"
        bbox_image_path = ocr_image_path if use_preprocessing and preprocess_result.get('success') else image_path
        draw_bboxes(bbox_image_path, results_list, output_img_path)

        # 7. Statistics.
        # NOTE(review): one shared dict counts both 'source' and 'type'
        # values; an item whose source and type share a key name would be
        # counted twice — confirm the key sets are disjoint in practice.
        counts = {'Google':0, 'Custom':0, 'MASK1':0, 'MASK2':0, 'TEXT':0}
        for r in results_list:
            if r['source'] in counts: counts[r['source']] += 1
            if r['type'] in counts: counts[r['type']] += 1

        logger.info("📊 최종 통계")
        logger.info(f" - 🟢 Google: {counts['Google']}개")
        logger.info(f" - 🟣 Custom: {counts['Custom']}개")
        logger.info(f" - 🔵 MASK1: {counts['MASK1']}개")
        logger.info(f" - 🔴 MASK2: {counts['MASK2']}개")
        logger.info(f" - 📝 TEXT: {counts['TEXT']}개")
        logger.info("="*60)

        return True

    except Exception as e:
        logger.error(f"❌ 오류 발생: {e}", exc_info=True)
        return False
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
def main():
    """CLI entry point: validate arguments and environment, then run OCR."""
    args = sys.argv
    if len(args) < 2:
        print("사용법: python dong_ocr.py <이미지>")
        sys.exit(1)

    # Both variables must be set (and non-empty) for the engine to start.
    required_env = ('OCR_WEIGHTS_BASE_PATH', 'GOOGLE_CREDENTIALS_JSON')
    if not all(os.getenv(name) for name in required_env):
        logger.error("❌ 환경변수 미설정")
        sys.exit(1)

    succeeded = run_ocr(args[1])
    if succeeded:
        logger.info("✅ 작업 완료!")
        sys.exit(0)
    logger.error("❌ 작업 실패")
    sys.exit(1)


if __name__ == "__main__":
    main()
|
requirements.txt
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
opencv-python
|
| 2 |
+
numpy
|
| 3 |
+
torch
|
| 4 |
+
torchvision
|
| 5 |
+
python-dotenv
|
| 6 |
+
Pillow
|
| 7 |
+
google-cloud-vision
|
| 8 |
+
huggingface-hub>=0.34.0,<1.0
|
weights/best.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:31db0eec96515dc820475df31245d7fe51ffcc56c76dc70df4e2bf83ff21d7e6
|
| 3 |
+
size 115004284
|
weights/best_5000.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d3180e120cb82a5dc2474448a3b61ea72b53e2507a2dc367bda76dc222a35ec6
|
| 3 |
+
size 62505977
|