donghyun committed on
Commit ·
8672bad
1
Parent(s): 1a7b7d2
Add OCR code, modules, and weights
Browse files- .gitignore +22 -0
- README.md +2 -0
- ai_modules/__init__.py +60 -0
- ai_modules/models/HRCenterNet.py +204 -0
- ai_modules/models/__init__.py +5 -0
- ai_modules/models/modules.py +111 -0
- ai_modules/models/resnet.py +186 -0
- ai_modules/nlp/__init__.py +26 -0
- ai_modules/nlp/mlm_predictor.py +118 -0
- ai_modules/nlp/punctuation_restorer.py +326 -0
- ai_modules/nlp/utils.py +87 -0
- ai_modules/nlp_engine.py +321 -0
- ai_modules/ocr_engine.py +767 -0
- ai_modules/preprocessor_unified.py +605 -0
- dong_ocr.py +349 -0
- requirements.txt +8 -0
- weights/best.pth +3 -0
- weights/best_5000.pt +3 -0
.gitignore
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.pyc
|
| 4 |
+
*.pyo
|
| 5 |
+
|
| 6 |
+
# Environment & Secrets
|
| 7 |
+
.env
|
| 8 |
+
*.json
|
| 9 |
+
weights/*.json
|
| 10 |
+
|
| 11 |
+
# Logs & Temp
|
| 12 |
+
*.log
|
| 13 |
+
*.tmp
|
| 14 |
+
*.temp
|
| 15 |
+
|
| 16 |
+
# Output files
|
| 17 |
+
*_bbox.*
|
| 18 |
+
*_ocr_result.json
|
| 19 |
+
|
| 20 |
+
# OS files
|
| 21 |
+
.DS_Store
|
| 22 |
+
Thumbs.db
|
README.md
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
|
ai_modules/__init__.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
"""
================================================================================
Epitext AI Unified Preprocessing Module
================================================================================

Unified image preprocessing package (produces Swin Gray + OCR outputs together).

One function call completes both preprocessing variants:

1. Swin Gray: non-binarized grayscale (minimal information loss) -> 3-channel JPG

2. OCR: binarized (clean black and white) -> 1-channel PNG

Version: 1.0.0
Status: Production Ready

Key features:

- Efficiency: region detection runs once (shared by both outputs)

- Background guarantee: Swin (bright) + OCR (white)

- Rubbing (takbon) images: optional auto-detection

- Configurable: JSON-based customization
"""

from .preprocessor_unified import (
    UnifiedImagePreprocessor,
    get_preprocessor,
    preprocess_image_unified
)
from .ocr_engine import (
    get_ocr_engine,
    OCREngine,
    ocr_and_detect
)
from .nlp_engine import (
    get_nlp_engine,
    NLPEngine,
    process_text_with_nlp
)

__version__ = "1.0.0"
__author__ = "Epitext Team"

# Public API of the package; mirrors the three sub-engine import groups above.
__all__ = [
    "UnifiedImagePreprocessor",
    "get_preprocessor",
    "preprocess_image_unified",
    "get_ocr_engine",
    "OCREngine",
    "ocr_and_detect",
    "get_nlp_engine",
    "NLPEngine",
    "process_text_with_nlp"
]
ai_modules/models/HRCenterNet.py
ADDED
|
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
from ai_modules.models.modules import BasicBlock, Bottleneck
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class StageModule(nn.Module):
    """One HRNet stage: parallel multi-resolution branches plus a fusion step.

    Branch ``i`` operates at 1/2**i of the stage's base resolution with
    ``c * 2**i`` channels and is refined by four BasicBlocks. The fusion step
    makes every output branch the sum of all input branches, upsampling
    (1x1 conv + nearest) or downsampling (chained stride-2 3x3 convs) as
    required to match resolutions.
    """

    def __init__(self, stage, output_branches, c, bn_momentum):
        super(StageModule, self).__init__()
        self.stage = stage
        self.output_branches = output_branches

        # One residual sub-network per resolution branch.
        self.branches = nn.ModuleList()
        for branch_idx in range(self.stage):
            width = c * (2 ** branch_idx)
            self.branches.append(nn.Sequential(
                BasicBlock(width, width, bn_momentum=bn_momentum),
                BasicBlock(width, width, bn_momentum=bn_momentum),
                BasicBlock(width, width, bn_momentum=bn_momentum),
                BasicBlock(width, width, bn_momentum=bn_momentum),
            ))

        # fuse_layers[i][j] maps branch j's output onto branch i's shape.
        self.fuse_layers = nn.ModuleList()
        for i in range(self.output_branches):
            self.fuse_layers.append(nn.ModuleList())
            for j in range(self.stage):
                if i == j:
                    # Identity; an empty Sequential is used in place of None
                    # because it is callable.
                    self.fuse_layers[-1].append(nn.Sequential())
                elif i < j:
                    # Source is lower resolution: project channels, upsample.
                    self.fuse_layers[-1].append(nn.Sequential(
                        nn.Conv2d(c * (2 ** j), c * (2 ** i), kernel_size=(1, 1), stride=(1, 1), bias=False),
                        nn.BatchNorm2d(c * (2 ** i), eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                        nn.Upsample(scale_factor=(2.0 ** (j - i)), mode='nearest'),
                    ))
                else:  # i > j
                    # Source is higher resolution: (i - j) stride-2 stages;
                    # only the last one changes the channel count.
                    downsamplers = []
                    for _ in range(i - j - 1):
                        downsamplers.append(nn.Sequential(
                            nn.Conv2d(c * (2 ** j), c * (2 ** j), kernel_size=(3, 3), stride=(2, 2), padding=(1, 1),
                                      bias=False),
                            nn.BatchNorm2d(c * (2 ** j), eps=1e-05, momentum=0.1, affine=True,
                                           track_running_stats=True),
                            nn.ReLU(inplace=True),
                        ))
                    downsamplers.append(nn.Sequential(
                        nn.Conv2d(c * (2 ** j), c * (2 ** i), kernel_size=(3, 3), stride=(2, 2), padding=(1, 1),
                                  bias=False),
                        nn.BatchNorm2d(c * (2 ** i), eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                    ))
                    self.fuse_layers[-1].append(nn.Sequential(*downsamplers))

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """x: list of tensors, one per branch. Returns the fused list."""
        assert len(self.branches) == len(x)

        # Refine each branch at its own resolution.
        x = [branch(feat) for branch, feat in zip(self.branches, x)]

        # Output i = relu( sum_j fuse_layers[i][j](x[j]) ).
        x_fused = []
        for row in self.fuse_layers:
            acc = row[0](x[0])
            for j in range(1, len(self.branches)):
                acc = acc + row[j](x[j])
            x_fused.append(acc)

        return [self.relu(feat) for feat in x_fused]
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class _HRCenterNet(nn.Module):
    """HRNet-style backbone with a CenterNet-like sigmoid heatmap head.

    Args:
        c: base channel width of the highest-resolution branch; branch i
            uses c * 2**i channels.
        nof_joints: number of output channels of the final head (one
            sigmoid heatmap per channel).
        bn_momentum: momentum for all BatchNorm layers.

    Input: (N, 3, H, W) image tensor. Output: (N, nof_joints, H/4, W/4)
    maps in [0, 1] (stem downsamples by 4; the head keeps the top branch).
    """

    def __init__(self, c=48, nof_joints=17, bn_momentum=0.1):
        super(_HRCenterNet, self).__init__()

        # Input (stem net): two stride-2 3x3 convs -> overall 1/4 resolution.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        self.bn1 = nn.BatchNorm2d(64, eps=1e-05, momentum=bn_momentum, affine=True, track_running_stats=True)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        self.bn2 = nn.BatchNorm2d(64, eps=1e-05, momentum=bn_momentum, affine=True, track_running_stats=True)
        self.relu = nn.ReLU(inplace=True)

        # Stage 1 (layer1) - First group of bottleneck (resnet) modules.
        # The first Bottleneck projects 64 -> 256 channels via `downsample`.
        downsample = nn.Sequential(
            nn.Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False),
            nn.BatchNorm2d(256, eps=1e-05, momentum=bn_momentum, affine=True, track_running_stats=True),
        )
        self.layer1 = nn.Sequential(
            Bottleneck(64, 64, downsample=downsample),
            Bottleneck(256, 64),
            Bottleneck(256, 64),
            Bottleneck(256, 64),
        )

        # Fusion layer 1 (transition1) - Creation of the first two branches
        # (one full and one half resolution).
        self.transition1 = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(256, c, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
                nn.BatchNorm2d(c, eps=1e-05, momentum=bn_momentum, affine=True, track_running_stats=True),
                nn.ReLU(inplace=True),
            ),
            nn.Sequential(nn.Sequential(  # Double Sequential to fit with official pretrained weights
                nn.Conv2d(256, c * (2 ** 1), kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False),
                nn.BatchNorm2d(c * (2 ** 1), eps=1e-05, momentum=bn_momentum, affine=True, track_running_stats=True),
                nn.ReLU(inplace=True),
            )),
        ])

        # Stage 2 (stage2) - one StageModule over 2 branches.
        self.stage2 = nn.Sequential(
            StageModule(stage=2, output_branches=2, c=c, bn_momentum=bn_momentum),
        )

        # Fusion layer 2 (transition2) - Creation of the third branch (1/4 resolution).
        self.transition2 = nn.ModuleList([
            nn.Sequential(),  # None, - Used in place of "None" because it is callable
            nn.Sequential(),  # None, - Used in place of "None" because it is callable
            nn.Sequential(nn.Sequential(  # Double Sequential to fit with official pretrained weights
                nn.Conv2d(c * (2 ** 1), c * (2 ** 2), kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False),
                nn.BatchNorm2d(c * (2 ** 2), eps=1e-05, momentum=bn_momentum, affine=True, track_running_stats=True),
                nn.ReLU(inplace=True),
            )),  # ToDo Why the new branch derives from the "upper" branch only?
        ])

        # Stage 3 (stage3) - four StageModules over 3 branches.
        self.stage3 = nn.Sequential(
            StageModule(stage=3, output_branches=3, c=c, bn_momentum=bn_momentum),
            StageModule(stage=3, output_branches=3, c=c, bn_momentum=bn_momentum),
            StageModule(stage=3, output_branches=3, c=c, bn_momentum=bn_momentum),
            StageModule(stage=3, output_branches=3, c=c, bn_momentum=bn_momentum),
        )

        # Fusion layer 3 (transition3) - Creation of the fourth branch (1/8 resolution).
        self.transition3 = nn.ModuleList([
            nn.Sequential(),  # None, - Used in place of "None" because it is callable
            nn.Sequential(),  # None, - Used in place of "None" because it is callable
            nn.Sequential(),  # None, - Used in place of "None" because it is callable
            nn.Sequential(nn.Sequential(  # Double Sequential to fit with official pretrained weights
                nn.Conv2d(c * (2 ** 2), c * (2 ** 3), kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False),
                nn.BatchNorm2d(c * (2 ** 3), eps=1e-05, momentum=bn_momentum, affine=True, track_running_stats=True),
                nn.ReLU(inplace=True),
            )),  # ToDo Why the new branch derives from the "upper" branch only?
        ])

        # Stage 4 (stage4) - three StageModules over 4 branches. The last
        # module has output_branches=1: only the highest-resolution branch
        # survives, feeding the head below.
        self.stage4 = nn.Sequential(
            StageModule(stage=4, output_branches=4, c=c, bn_momentum=bn_momentum),
            StageModule(stage=4, output_branches=4, c=c, bn_momentum=bn_momentum),
            StageModule(stage=4, output_branches=1, c=c, bn_momentum=bn_momentum),
        )

        # Final layer (final_layer): c -> 32 -> nof_joints, sigmoid so every
        # output channel is a per-pixel score in [0, 1].
        self.final_layer = nn.Sequential(
            nn.Conv2d(c, 32, kernel_size=(1, 1), stride=(1, 1)),
            nn.BatchNorm2d(32, eps=1e-05, momentum=bn_momentum, affine=True, track_running_stats=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, nof_joints, kernel_size=(1, 1), stride=(1, 1)),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Run the stem, four HRNet stages and the heatmap head."""
        # Stem: two stride-2 convs.
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)

        x = self.layer1(x)
        x = [trans(x) for trans in self.transition1]  # Since now, x is a list (# == nof branches)

        x = self.stage2(x)
        # x = [trans(x[-1]) for trans in self.transition2]  # New branch derives from the "upper" branch only
        x = [
            self.transition2[0](x[0]),
            self.transition2[1](x[1]),
            self.transition2[2](x[-1])
        ]  # New branch derives from the "upper" branch only

        x = self.stage3(x)
        # x = [trans(x) for trans in self.transition3]  # New branch derives from the "upper" branch only
        x = [
            self.transition3[0](x[0]),
            self.transition3[1](x[1]),
            self.transition3[2](x[2]),
            self.transition3[3](x[-1])
        ]  # New branch derives from the "upper" branch only

        x = self.stage4(x)

        # Only the top-resolution branch reaches the head (stage4 ends with
        # output_branches=1, so x is a single-element list).
        x = self.final_layer(x[0])

        return x
|
| 196 |
+
|
| 197 |
+
def HRCenterNet(args):
    """Build the HRCenterNet detector, optionally loading a checkpoint.

    Args:
        args: a namespace with a ``log_dir`` attribute. When ``log_dir`` is
            not None it is treated as a path to a state-dict checkpoint and
            loaded into the model.

    Returns:
        An ``_HRCenterNet`` instance (c=32, nof_joints=5, bn_momentum=0.1).
    """
    model = _HRCenterNet(c=32, nof_joints=5, bn_momentum=0.1)

    # Fix: use the idiomatic identity test instead of `not (x == None)`
    # (PEP 8 E711); behavior is identical for plain path strings/None.
    if args.log_dir is not None:
        model.load_state_dict(torch.load(args.log_dir))

    return model
|
ai_modules/models/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
"""
OCR model modules package.
"""
|
ai_modules/models/modules.py
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class Bottleneck(nn.Module):
    """ResNet bottleneck residual block: 1x1 reduce, 3x3, 1x1 expand (x4).

    Output channels are ``planes * expansion``. When the shortcut's shape
    differs from the main path's, pass a ``downsample`` module that projects
    the input accordingly.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, bn_momentum=0.1):
        super(Bottleneck, self).__init__()
        expanded = planes * self.expansion
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes, momentum=bn_momentum)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes, momentum=bn_momentum)
        self.conv3 = nn.Conv2d(planes, expanded, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(expanded, momentum=bn_momentum)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Shortcut path: identity, or a projection when shapes differ.
        shortcut = x if self.downsample is None else self.downsample(x)

        # Main path: reduce -> spatial conv -> expand (no relu before the add).
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        out += shortcut
        return self.relu(out)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# class Bottleneck_Tranpose(nn.Module):
|
| 44 |
+
# expansion = 4
|
| 45 |
+
|
| 46 |
+
# def __init__(self, inplanes, planes, stride=1, downsample=None, bn_momentum=0.1):
|
| 47 |
+
# super(Bottleneck, self).__init__()
|
| 48 |
+
# nn.ConvTranspose2d(c, 64, (3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1)),
|
| 49 |
+
|
| 50 |
+
# self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
|
| 51 |
+
# self.bn1 = nn.BatchNorm2d(planes, momentum=bn_momentum)
|
| 52 |
+
# self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
|
| 53 |
+
# self.bn2 = nn.BatchNorm2d(planes, momentum=bn_momentum)
|
| 54 |
+
# self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
|
| 55 |
+
# self.bn3 = nn.BatchNorm2d(planes * self.expansion, momentum=bn_momentum)
|
| 56 |
+
# self.relu = nn.ReLU(inplace=True)
|
| 57 |
+
# self.downsample = downsample
|
| 58 |
+
# self.stride = stride
|
| 59 |
+
|
| 60 |
+
# def forward(self, x):
|
| 61 |
+
# residual = x
|
| 62 |
+
|
| 63 |
+
# out = self.conv1(x)
|
| 64 |
+
# out = self.bn1(out)
|
| 65 |
+
# out = self.relu(out)
|
| 66 |
+
|
| 67 |
+
# out = self.conv2(out)
|
| 68 |
+
# out = self.bn2(out)
|
| 69 |
+
# out = self.relu(out)
|
| 70 |
+
|
| 71 |
+
# out = self.conv3(out)
|
| 72 |
+
# out = self.bn3(out)
|
| 73 |
+
|
| 74 |
+
# if self.downsample is not None:
|
| 75 |
+
# residual = self.downsample(x)
|
| 76 |
+
|
| 77 |
+
# out += residual
|
| 78 |
+
# out = self.relu(out)
|
| 79 |
+
|
| 80 |
+
# return out
|
| 81 |
+
|
| 82 |
+
class BasicBlock(nn.Module):
    """ResNet basic residual block: two 3x3 convs with an additive shortcut.

    Output channels equal ``planes`` (expansion 1). Provide ``downsample``
    when the shortcut must be reshaped (stride != 1 or inplanes != planes).
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, bn_momentum=0.1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes, momentum=bn_momentum)
        self.relu = nn.ReLU(inplace=True)
        # Bug fix: conv2 consumes conv1's output, which has `planes` channels.
        # The original declared it with `inplanes` input channels, which only
        # happened to work because every call site passes inplanes == planes.
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes, momentum=bn_momentum)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return relu(conv-bn-relu-conv-bn(x) + shortcut)."""
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        # Reshape the shortcut when a projection was supplied.
        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out
|
ai_modules/models/resnet.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import PIL
|
| 3 |
+
|
| 4 |
+
from torch import nn
|
| 5 |
+
from torchvision import transforms
|
| 6 |
+
|
| 7 |
+
class BasicBlock(nn.Module):
    """Torchvision-style basic residual block (two 3x3 convs, expansion 1)."""

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        # This block does not implement grouped/wide/dilated variants.
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Shortcut: identity, or the supplied projection when shapes differ.
        identity = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        out += identity
        return self.relu(out)
|
| 45 |
+
|
| 46 |
+
class ResNet(nn.Module):
    """Torchvision-style ResNet backbone adapted to single-channel input.

    NOTE(review): unlike stock torchvision, the stem conv takes 1 input
    channel, so this model expects grayscale tensors of shape (N, 1, H, W).

    Args:
        block: residual block class (e.g. BasicBlock); must expose
            ``expansion`` and the torchvision block constructor signature.
        layers: number of blocks in each of the four stages.
        num_classes: size of the final fully-connected layer.
        zero_init_residual: accepted for interface compatibility.
            NOTE(review): this flag is never used in this implementation.
        groups / width_per_group: passed through to the blocks.
        replace_stride_with_dilation: per-stage flag to trade stride for
            dilation (3 elements, one per stage 2-4).
        norm_layer: normalization layer factory; defaults to BatchNorm2d.
    """

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        # Stem: 7x7 stride-2 conv over a SINGLE input channel (grayscale).
        self.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # Four residual stages; stages 2-4 halve resolution (or dilate).
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # He init for convs; unit-gamma / zero-beta for norm layers.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Stack `blocks` residual blocks; the first may downsample/project."""
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            # Replace spatial stride with dilation (keeps resolution).
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            # Projection shortcut so the addition shapes match.
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        # Global average pool -> flatten -> classifier logits.
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    def forward(self, x):
        """Return (N, num_classes) logits for grayscale input (N, 1, H, W)."""
        return self._forward_impl(x)
|
| 132 |
+
|
| 133 |
+
def conv1x1(in_planes, out_planes, stride=1):
    """1x1 (pointwise) convolution without bias."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
|
| 136 |
+
|
| 137 |
+
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding (padding == dilation keeps spatial size
    when stride == 1); no bias."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )
|
| 141 |
+
|
| 142 |
+
class ResnetCustom(torch.nn.Module):
    """Single-character classifier: a ResNet-18-shaped net over 64x64
    grayscale crops, restored from a checkpoint file.

    The checkpoint is expected to contain ``weight['vocab']['id2char']``
    (id -> character mapping) and ``weight['model']`` (state dict).
    """

    def __init__(self, weight_fn):
        """Load the checkpoint at `weight_fn` and build the network in eval mode.

        NOTE(review): torch.load unpickles arbitrary objects — only load
        trusted checkpoint files.
        """
        super().__init__()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        weight = torch.load(weight_fn, map_location=self.device)
        self.id2charDict = weight['vocab']['id2char']
        # Vocabulary size is taken BEFORE adding the sentinel entry below.
        num_classes = len(self.id2charDict)
        self.id2charDict[-1] = "■"  # unrecognized token
        # Every input image is converted to a 1x64x64 grayscale tensor.
        self.transform = transforms.Compose([transforms.Grayscale(),
                                             transforms.Resize((64,64)),
                                             transforms.ToTensor()])

        self.net = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes)
        self.net.load_state_dict(weight['model'])
        self.net = self.net.to(self.device)
        self.net.eval()
        #self.net(torch.rand((64,1,64,64)))
        print(f'{weight_fn} loaded!')

    def forward(self, images:PIL.Image, bs=256, conf_thres=0.5):
        '''
        input
            images: list of PIL images
            bs: mini-batch size used for inference
            conf_thres: softmax confidence below which a prediction is
                replaced by the "■" unrecognized token
        return
            chars: list of recognized chars (one per input image)
        '''
        chars = []
        # Process the images in mini-batches of at most `bs`.
        for i in range(0, len(images), bs):
            inp = []
            for image in images[i: i+bs]:
                inp.append(self.transform(image))
            inp = torch.stack(inp, dim=0).to(self.device)
            out = self.net(inp)
            out = torch.nn.functional.softmax(out, dim=1)
            conf, indice = torch.max(out, dim=1)
            # Low-confidence predictions map to id -1 (the "■" sentinel).
            indice[conf<conf_thres] = -1
            chars += [self.id2charDict[x] for x in indice.tolist()]

        return chars
|
| 181 |
+
|
| 182 |
+
if __name__ == "__main__":
    # Manual smoke test: requires best_5000.pt plus 0.jpg / 1.png in the
    # working directory.
    net = ResnetCustom(weight_fn="best_5000.pt")
    inp = [PIL.Image.open('0.jpg'), PIL.Image.open('1.png')]
    print(net(inp))
|
| 186 |
+
|
ai_modules/nlp/__init__.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Korean Historical Text Processor NLP Module
|
| 3 |
+
한국어 고전 텍스트의 구두점 복원 및 MLM 예측을 위한 모듈입니다.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
__version__ = "1.0.0"
|
| 7 |
+
__author__ = "EPITEXT"
|
| 8 |
+
|
| 9 |
+
from .punctuation_restorer import PunctuationRestorer
|
| 10 |
+
from .mlm_predictor import MLMPredictor
|
| 11 |
+
from .utils import (
|
| 12 |
+
remove_punctuation,
|
| 13 |
+
extract_mask_info,
|
| 14 |
+
replace_mask_with_symbol,
|
| 15 |
+
normalize_mask_tokens,
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
__all__ = [
|
| 19 |
+
"PunctuationRestorer",
|
| 20 |
+
"MLMPredictor",
|
| 21 |
+
"remove_punctuation",
|
| 22 |
+
"extract_mask_info",
|
| 23 |
+
"replace_mask_with_symbol",
|
| 24 |
+
"normalize_mask_tokens",
|
| 25 |
+
]
|
| 26 |
+
|
ai_modules/nlp/mlm_predictor.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
MLM(Masked Language Model) 예측 모듈
|
| 3 |
+
BERT 기반 MLM을 사용하여 마스킹된 토큰을 예측합니다.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from typing import List, Dict
|
| 8 |
+
from transformers import AutoTokenizer, AutoModelForMaskedLM
|
| 9 |
+
from .utils import normalize_mask_tokens
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class MLMPredictor:
|
| 13 |
+
"""MLM 예측을 담당하는 클래스"""
|
| 14 |
+
|
| 15 |
+
    def __init__(self, config: Dict, device: str = "cpu"):
        """
        Initialize the MLM predictor.

        Args:
            config: settings dictionary (loaded from nlp_config.json);
                must contain an 'mlm_model' section with 'model_name',
                'top_k' and 'max_length'.
            device: computation device ('cpu' or 'cuda')
        """
        # Only configuration is read here; the heavy model/tokenizer are
        # loaded lazily by load_model().
        mlm_cfg = config['mlm_model']
        self.model_name = mlm_cfg['model_name']
        self.top_k = mlm_cfg['top_k']
        self.max_length = mlm_cfg['max_length']
        self.device = device
        self.tokenizer = None
        self.model = None
|
| 30 |
+
|
| 31 |
+
    def load_model(self) -> None:
        """Load the tokenizer and MLM model into memory and set eval mode."""
        print(f"[MLM] 모델 로드 중: {self.model_name}")

        # use_fast=False selects the slow (pure-Python) tokenizer.
        # NOTE(review): presumably needed for this model's vocab handling —
        # confirm before switching to the fast tokenizer.
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_name,
            use_fast=False
        )
        self.model = AutoModelForMaskedLM.from_pretrained(self.model_name)
        self.model.to(self.device)
        self.model.eval()

        print(f"[MLM] ✓ MLM 모델 로드 완료")
|
| 44 |
+
|
| 45 |
+
def predict_masks(
|
| 46 |
+
self,
|
| 47 |
+
text: str
|
| 48 |
+
) -> List[List[Dict[str, any]]]:
|
| 49 |
+
"""
|
| 50 |
+
텍스트 내의 [MASK] 토큰을 예측합니다.
|
| 51 |
+
|
| 52 |
+
Args:
|
| 53 |
+
text: 마스크가 포함된 텍스트
|
| 54 |
+
|
| 55 |
+
Returns:
|
| 56 |
+
각 마스크 위치별 top-k 예측 결과 리스트
|
| 57 |
+
"""
|
| 58 |
+
# [MASK1], [MASK2] -> [MASK] 정규화
|
| 59 |
+
text_normalized = normalize_mask_tokens(text)
|
| 60 |
+
|
| 61 |
+
print(f"[MLM] 입력 텍스트 샘플: {text_normalized[:100]}...")
|
| 62 |
+
print(f"[MLM] [MASK] 토큰 개수: {text_normalized.count('[MASK]')}")
|
| 63 |
+
|
| 64 |
+
# 토크나이즈
|
| 65 |
+
inputs = self.tokenizer(
|
| 66 |
+
text_normalized,
|
| 67 |
+
return_tensors="pt",
|
| 68 |
+
truncation=True,
|
| 69 |
+
max_length=self.max_length
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
# 디바이스로 이동
|
| 73 |
+
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
| 74 |
+
|
| 75 |
+
# [MASK] 위치 찾기
|
| 76 |
+
mask_indices = torch.where(
|
| 77 |
+
inputs["input_ids"] == self.tokenizer.mask_token_id
|
| 78 |
+
)[1]
|
| 79 |
+
|
| 80 |
+
print(f"[MLM] 토크나이저가 찾은 [MASK] 위치 개수: {len(mask_indices)}")
|
| 81 |
+
|
| 82 |
+
if len(mask_indices) == 0:
|
| 83 |
+
print("[MLM] ⚠️ 경고: [MASK] 토큰을 찾을 수 없습니다!")
|
| 84 |
+
sample_tokens = self.tokenizer.convert_ids_to_tokens(
|
| 85 |
+
inputs['input_ids'][0][:50]
|
| 86 |
+
)
|
| 87 |
+
print(f"[MLM] 토큰화된 입력 샘플: {sample_tokens}")
|
| 88 |
+
return []
|
| 89 |
+
|
| 90 |
+
# 예측 수행
|
| 91 |
+
with torch.no_grad():
|
| 92 |
+
outputs = self.model(**inputs)
|
| 93 |
+
logits = outputs.logits
|
| 94 |
+
|
| 95 |
+
# 각 마스크 위치별로 top-k 예측
|
| 96 |
+
all_predictions = []
|
| 97 |
+
for mask_idx in mask_indices:
|
| 98 |
+
mask_logits = logits[0, mask_idx, :]
|
| 99 |
+
|
| 100 |
+
# 전체 어휘에 대해 softmax 계산 후 top-k 선택
|
| 101 |
+
all_probs = torch.nn.functional.softmax(mask_logits, dim=-1)
|
| 102 |
+
top_k_probs, top_k_indices = torch.topk(all_probs, self.top_k)
|
| 103 |
+
|
| 104 |
+
top_k_tokens = self.tokenizer.convert_ids_to_tokens(
|
| 105 |
+
top_k_indices.tolist()
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
predictions = [
|
| 109 |
+
{
|
| 110 |
+
"token": token,
|
| 111 |
+
"probability": float(prob)
|
| 112 |
+
}
|
| 113 |
+
for token, prob in zip(top_k_tokens, top_k_probs.tolist())
|
| 114 |
+
]
|
| 115 |
+
all_predictions.append(predictions)
|
| 116 |
+
|
| 117 |
+
return all_predictions
|
| 118 |
+
|
ai_modules/nlp/punctuation_restorer.py
ADDED
|
@@ -0,0 +1,326 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
구두점 복원 모듈
|
| 3 |
+
Hugging Face 모델을 사용하여 한국어 고전 텍스트의 구두점을 복원합니다.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import json
|
| 7 |
+
import torch
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Dict, List, Tuple
|
| 10 |
+
from collections import Counter
|
| 11 |
+
from huggingface_hub import snapshot_download
|
| 12 |
+
from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class PunctuationRestorer:
    """Restore punctuation in unpunctuated (classical) text.

    Wraps a Hugging Face token-classification model: each character is
    labelled with the punctuation mark (if any) that should follow it.
    Call order: download_model() -> load_model() -> restore_punctuation().
    """

    def __init__(self, config: Dict, cache_dir: str, device: str = "cpu"):
        """
        Initialize the punctuation restorer.

        Args:
            config: Settings dictionary (loaded from nlp_config.json);
                must contain a 'punc_model' section with 'model_tag',
                'max_length', 'window_size' and 'overlap'.
            cache_dir: Base model cache directory; files go under
                <cache_dir>/punc.
            device: Compute device ('cpu' or 'cuda').
        """
        punc_cfg = config['punc_model']
        self.model_tag = punc_cfg['model_tag']
        self.max_length = punc_cfg['max_length']
        self.window_size = punc_cfg['window_size']
        self.overlap = punc_cfg['overlap']

        self.cache_dir = Path(cache_dir) / "punc"
        self.device = device
        # Populated by load_model(): model, tokenizer, pipeline, label map.
        self.model_info = None

    def download_model(self) -> None:
        """Download the model snapshot from Hugging Face, unless cached."""
        self.cache_dir.parent.mkdir(parents=True, exist_ok=True)

        # Only download when the cache dir is missing or empty.
        if not self.cache_dir.exists() or not any(self.cache_dir.iterdir()):
            print(f"[PUNC] 모델 다운로드 중: {self.model_tag}")
            snapshot_download(
                repo_id=self.model_tag,
                repo_type="model",
                local_dir=str(self.cache_dir),
                local_dir_use_symlinks=False,
            )
        else:
            print(f"[PUNC] 캐시된 모델 사용: {self.cache_dir}")

    def load_model(self) -> None:
        """Load model, tokenizer, NER pipeline and label map into memory.

        Raises:
            FileNotFoundError: if no weight file or no label2id.json is
                found under the cache directory.
        """
        # fp16 only makes sense on GPU.
        torch_dtype = torch.float16 if "cuda" in self.device else torch.float32

        # Locate a weight file; prefer safetensors, fall back to .bin.
        fnames = sorted(self.cache_dir.rglob("*.safetensors"))
        if len(fnames) == 0:
            fnames = sorted(self.cache_dir.rglob("*.bin"))

        if len(fnames) == 0:
            raise FileNotFoundError(f"모델 파일을 찾을 수 없습니다: {self.cache_dir}")

        # The directory holding the first weight file is the model root.
        hface_path = fnames[0].parent

        # Load tokenizer and token-classification model.
        tokenizer = AutoTokenizer.from_pretrained(
            str(hface_path),
            model_max_length=self.max_length
        )
        model = AutoModelForTokenClassification.from_pretrained(
            str(hface_path),
            device_map=self.device if "cuda" in self.device else None,
            torch_dtype=torch_dtype
        )
        if "cuda" not in self.device:
            model = model.to(self.device)
        model.eval()

        # Build the NER pipeline (device index 0 = first GPU, -1 = CPU).
        ner_pipeline = pipeline(
            task="ner",
            model=model,
            tokenizer=tokenizer,
            device=0 if "cuda" in self.device else -1
        )

        # Load the label mapping; also try one directory up.
        label2id_path = hface_path / "label2id.json"
        if not label2id_path.is_file():
            label2id_path = hface_path.parent / "label2id.json"
        if not label2id_path.is_file():
            raise FileNotFoundError(f"label2id.json을 찾을 수 없습니다: {hface_path}")

        label2id = json.loads(label2id_path.read_text(encoding="utf-8"))

        self.model_info = {
            "model": model,
            "tokenizer": tokenizer,
            "pipe": ner_pipeline,
            "label2id": label2id
        }

        print(f"[PUNC] ✓ 구두점 복원 모델 로드 완료")

    def restore_punctuation(
        self,
        text: str,
        add_space: bool = True,
        reduce: bool = True,
    ) -> str:
        """
        Restore punctuation using a sliding window over the text.

        load_model() must have been called first (model_info is None
        otherwise).

        Args:
            text: Input text.
            add_space: Whether to insert a space after punctuation.
            reduce: Whether to simplify punctuation marks.

        Returns:
            Text with punctuation inserted after the labelled characters.
        """
        if not text.strip():
            return ""

        # Build the label -> punctuation-string mapping.
        label2punc = self._build_label2punc(add_space, reduce)

        # Predict one label per character via sliding windows.
        labels = self._predict_labels_sliding(text, self.window_size, self.overlap)

        # Pad/trim labels so they align 1:1 with the characters.
        if len(labels) < len(text):
            labels += ["O"] * (len(text) - len(labels))
        elif len(labels) > len(text):
            labels = labels[:len(text)]

        # Emit each character followed by its predicted punctuation.
        result = ""
        for ch, label in zip(text, labels):
            result += ch
            punc = label2punc.get(label, "")
            result += punc

        return result.strip()

    def _predict_labels_sliding(
        self,
        text: str,
        window_size: int,
        overlap: int
    ) -> List[str]:
        """
        Predict a label for every character via overlapping windows.

        Overlapping windows vote per position; the majority label wins.

        Args:
            text: Input text.
            window_size: Window length in characters.
            overlap: Number of characters shared by adjacent windows.

        Returns:
            One label per character ("O" = no punctuation).
        """
        n = len(text)
        if n == 0:
            return []

        # Candidate labels collected per character position.
        labels_per_pos = [[] for _ in range(n)]
        stride = max(1, window_size - overlap)
        start = 0

        while start < n:
            end = min(start + window_size, n)
            sub_text = text[start:end]

            try:
                # Run NER on this window and align predictions to chars.
                sub_preds = self.model_info["pipe"](sub_text)
                _, sub_labels = self._align_predictions(sub_text, sub_preds)
            except Exception as e:
                # Fail soft: label the whole window "O" and keep going.
                print(f"[PUNC] 예측 오류 (start={start}): {e}")
                sub_labels = ["O"] * len(sub_text)

            # Record non-"O" candidates at their global positions.
            for i, label in enumerate(sub_labels):
                gidx = start + i
                if gidx >= n:
                    break
                if label != "O":
                    labels_per_pos[gidx].append(label)

            if end == n:
                break
            start += stride

        # Majority vote per position; no candidates means "O".
        final_labels = []
        for cand_list in labels_per_pos:
            if not cand_list:
                final_labels.append("O")
            else:
                c = Counter(cand_list)
                label, _ = c.most_common(1)[0]
                final_labels.append(label)

        return final_labels

    @staticmethod
    def _align_predictions(text: str, predictions: List[dict]) -> Tuple[List[str], List[str]]:
        """
        Map NER span predictions onto per-character labels.

        The label is attached to the LAST character of each predicted
        span (pred["end"] - 1), since punctuation follows that character.

        Args:
            text: Original window text.
            predictions: Pipeline output dicts with "end" and "entity".

        Returns:
            Tuple of (character list, label list), same length as *text*.
        """
        words = list(text)
        labels = ["O" for _ in range(len(words))]

        for pred in predictions:
            idx = pred["end"] - 1
            if 0 <= idx < len(labels):
                labels[idx] = pred["entity"]

        return words, labels

    def _build_label2punc(self, add_space: bool, reduce: bool) -> Dict[str, str]:
        """
        Build the mapping from model labels to punctuation strings.

        Args:
            add_space: Whether to append a space after certain marks.
            reduce: Whether to simplify punctuation to a small set.

        Returns:
            Dict mapping "B-<id>" labels (and "O") to punctuation text.
        """
        # label2id maps punctuation string -> id; invert to "B-<id>" -> punc.
        label2id = self.model_info["label2id"]
        label2punc = {f"B-{v}": k for k, v in label2id.items()}
        label2punc["O"] = ""

        # Optionally collapse each punctuation string to one simple mark.
        if reduce:
            new_label2punc = {}
            for label, punc in label2punc.items():
                if label == "O":
                    new_label2punc[label] = ""
                else:
                    reduced = self._reduce_punc(punc)
                    new_label2punc[label] = reduced
            label2punc = new_label2punc

        # Optionally append a space after sentence-level marks.
        if add_space:
            special_puncs = "!,:;?。"
            label2punc = {
                k: self._insert_space(v, special_puncs)
                for k, v in label2punc.items()
            }
            label2punc["O"] = ""

        return label2punc

    @staticmethod
    def _reduce_punc(text: str) -> str:
        """
        Simplify a punctuation string to a single mark (?, 。 or ,).

        Unmapped characters are dropped; the most frequent surviving
        mark wins, with ties broken by the punc_order priority.

        Args:
            text: Raw punctuation string from the label map.

        Returns:
            A single simplified mark, or "" when nothing survives.
        """
        reduce_map = {
            ",": ",", "-": ",", "/": ",", ":": ",", "|": ",",
            "·": ",", "、": ",",
            "?": "?", "!": "。", ".": "。", ";": "。", "。": "。",
        }

        text = "".join([reduce_map.get(c, "") for c in text])
        punc_order = "?。,,"

        if len(set(text).intersection(punc_order)) == 0:
            return ""

        # Pick the most frequent mark.
        counts = {c: text.count(c) for c in punc_order}
        max_count = max(counts.values())
        max_keys = {k for k, v in counts.items() if v == max_count}

        if len(max_keys) == 1:
            return max_keys.pop()

        # On a tie, take the highest-priority mark in punc_order.
        for c in punc_order:
            if c in max_keys:
                return c

        return ""

    @staticmethod
    def _insert_space(text: str, chars: str) -> str:
        """
        Insert a space after each occurrence of the given characters.

        Args:
            text: Original text.
            chars: Characters after which a space is added.

        Returns:
            Text with spaces inserted.
        """
        result = ""
        for c in text:
            result += c
            if c in chars:
                result += " "
        return result
|
| 326 |
+
|
ai_modules/nlp/utils.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
유틸리티 함수 모듈
|
| 3 |
+
파일 입출력, 텍스트 전처리 등 공통 기능을 제공합니다.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import re
|
| 7 |
+
import unicodedata
|
| 8 |
+
from typing import Dict, Any
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def remove_punctuation(text: str) -> str:
    """
    Strip punctuation and whitespace from *text*, preserving [MASK...]
    tokens verbatim (brackets included).

    Args:
        text: Source text.

    Returns:
        Text with punctuation (Unicode category P) and separators
        (category Z) removed.
    """
    kept = []
    pos = 0
    length = len(text)

    while pos < length:
        # Pass a [MASK...]-style token through untouched.
        if text[pos:pos + 1] == '[' and 'MASK' in text[pos:pos + 10]:
            close = text.find(']', pos)
            if close != -1:
                kept.append(text[pos:close + 1])
                pos = close + 1
                continue

        # Keep everything except punctuation (P*) and separators (Z*).
        if unicodedata.category(text[pos])[0] not in "PZ":
            kept.append(text[pos])
        pos += 1

    return "".join(kept)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def replace_mask_with_symbol(text: str, symbol: str = "□") -> str:
    """
    Substitute every numbered mask token ([MASK1], [MASK2], ...) in
    *text* with *symbol*.

    Note: a bare [MASK] without digits is left untouched.

    Args:
        text: Source text.
        symbol: Replacement character.

    Returns:
        Text with numbered mask tokens replaced.
    """
    numbered_mask = re.compile(r'\[MASK\d+\]')
    return numbered_mask.sub(symbol, text)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def normalize_mask_tokens(text: str) -> str:
    """
    Collapse numbered mask tokens ([MASK1], [MASK2], ...) to the plain
    [MASK] token a tokenizer recognizes.

    Args:
        text: Source text.

    Returns:
        Text with every numbered mask normalized to [MASK].
    """
    numbered_mask = re.compile(r'\[MASK\d+\]')
    return numbered_mask.sub('[MASK]', text)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def extract_mask_info(json_data: Dict[str, Any]) -> list:
    """
    Extract mask entries from OCR-style JSON data.

    Args:
        json_data: Input JSON data; entries live under 'results' and a
            mask entry has 'MASK' in its 'type' string.

    Returns:
        Mask info dicts ({'order', 'type'}) sorted by 'order'.
        Entries missing 'order' default to 0, matching the fallback
        behaviour used elsewhere in the pipeline (previously this
        raised KeyError on such entries).
    """
    mask_info = [
        # .get('order', 0) instead of item['order']: robust against
        # entries without an explicit order, consistent with callers.
        {'order': item.get('order', 0), 'type': item['type']}
        for item in json_data.get('results', [])
        if 'MASK' in item.get('type', '')
    ]
    mask_info.sort(key=lambda x: x['order'])
    return mask_info
|
| 87 |
+
|
ai_modules/nlp_engine.py
ADDED
|
@@ -0,0 +1,321 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
NLP 통합 엔진
|
| 3 |
+
구두점 복원 및 MLM 예측을 통합 관리하는 엔진입니다.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import json
|
| 8 |
+
import torch
|
| 9 |
+
import logging
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from typing import Dict, Any, Optional, List
|
| 12 |
+
|
| 13 |
+
from .nlp.punctuation_restorer import PunctuationRestorer
|
| 14 |
+
from .nlp.mlm_predictor import MLMPredictor
|
| 15 |
+
from .nlp.utils import remove_punctuation, replace_mask_with_symbol
|
| 16 |
+
|
| 17 |
+
logger = logging.getLogger(__name__)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def load_nlp_config(config_path: Optional[str] = None) -> Dict[str, Any]:
    """
    Load the NLP configuration JSON.

    Args:
        config_path: Path to the config file; None falls back to
            <this package>/config/nlp_config.json.

    Returns:
        The parsed configuration dictionary.

    Raises:
        FileNotFoundError: if the config file does not exist.
    """
    path = (
        Path(__file__).parent / "config" / "nlp_config.json"
        if config_path is None
        else Path(config_path)
    )

    if not path.exists():
        raise FileNotFoundError(f"NLP 설정 파일을 찾을 수 없습니다: {path}")

    return json.loads(path.read_text(encoding='utf-8'))
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class NLPEngine:
    """Facade over the NLP pipeline: punctuation restoration + MLM mask
    prediction. Models are loaded lazily on first use."""

    def __init__(self, config_path: Optional[str] = None):
        """
        Initialize the NLP engine (no models are loaded yet).

        Args:
            config_path: Path to the config file (None uses the default).
        """
        self.config = load_nlp_config(config_path)

        # Device selection: 'auto' picks CUDA when available.
        dev_cfg = self.config.get('device', 'auto')
        if dev_cfg == 'auto':
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = dev_cfg

        logger.info(f"[NLP] Device: {self.device}")

        # Model cache root: env var override, else <package parent>/models.
        self.base_model_dir = os.getenv(
            'AI_MODEL_DIR',
            str(Path(__file__).parent.parent / "models")
        )

        # Sub-modules are created lazily in _load_models().
        self.punc_restorer = None
        self.mlm_predictor = None

    def _load_models(self):
        """Load both sub-models into memory on first call (idempotent)."""
        if self.punc_restorer is None:
            logger.info("[NLP] 구두점 복원 모델 로드 중...")
            self.punc_restorer = PunctuationRestorer(
                self.config,
                self.base_model_dir,
                self.device
            )
            self.punc_restorer.download_model()
            self.punc_restorer.load_model()

        if self.mlm_predictor is None:
            logger.info("[NLP] MLM 모델 로드 중...")
            self.mlm_predictor = MLMPredictor(self.config, self.device)
            self.mlm_predictor.load_model()

    def process_text(
        self,
        raw_text: str,
        ocr_results: Optional[List[Dict]] = None,
        add_space: bool = True,
        reduce_punc: bool = True
    ) -> Dict[str, Any]:
        """
        Full text pipeline:
        1. strip punctuation (preprocessing),
        2. restore punctuation,
        3. predict [MASK] tokens.

        Args:
            raw_text: Original text (may contain punctuation and
                [MASK1]/[MASK2]-style tokens).
            ocr_results: Optional OCR result list whose entries carry
                'order'/'type'/'text' for each mask; used to annotate
                the formatted predictions. When absent, mask order/type
                are recovered by scanning raw_text.
            add_space: Whether to insert a space after punctuation.
            reduce_punc: Whether to simplify punctuation marks.

        Returns:
            On success: dict with 'punctuated_text_with_masks',
            'results' (per-mask top-10 predictions) and 'statistics'.
            On failure: {'success': False, 'error': str}.
        """
        self._load_models()

        try:
            # 1. Preprocess: remove punctuation, keep [MASK] tokens.
            clean_text = remove_punctuation(raw_text)
            logger.info(f"[NLP] 구두점 제거 완료: {len(clean_text)} 글자")

            # 2. Restore punctuation.
            punctuated_text = self.punc_restorer.restore_punctuation(
                clean_text,
                add_space=add_space,
                reduce=reduce_punc
            )
            logger.info(f"[NLP] 구두점 복원 완료: {len(punctuated_text)} 글자")

            # 3. Predict the masked tokens.
            mask_predictions = self.mlm_predictor.predict_masks(punctuated_text)
            logger.info(f"[NLP] MLM 예측 완료: {len(mask_predictions)}개 마스크")

            # 4. Build the display text ([MASKn] -> replacement symbol).
            mask_replacement = self.config['tokens']['mask_replacement']
            final_text = replace_mask_with_symbol(
                punctuated_text,
                mask_replacement
            )

            # Gather per-mask metadata from OCR results, or from raw_text.
            mask_info_list = []
            if ocr_results:
                # Use OCR results to get order and type.
                for item in ocr_results:
                    if 'MASK' in item.get('type', ''):
                        mask_info_list.append({
                            'order': item.get('order', 0),
                            'type': item.get('type', 'MASK2'),
                            'text': item.get('text', '')
                        })
            else:
                # Fallback: scan raw_text for [MASK...] tokens.
                i = 0
                while i < len(raw_text):
                    if raw_text[i] == '[' and 'MASK' in raw_text[i:i+10]:
                        end = raw_text.find(']', i)
                        if end != -1:
                            mask_text = raw_text[i:end+1]
                            mask_type = 'MASK1' if 'MASK1' in mask_text else 'MASK2'
                            mask_info_list.append({
                                'order': len(mask_info_list),  # sequential order
                                'type': mask_type,
                                'text': mask_text
                            })
                            i = end + 1
                            continue
                    i += 1

            # Pair each prediction list with its mask metadata, by index.
            # NOTE(review): pairing assumes predictions and mask_info_list
            # are in the same order — verify when mask counts disagree.
            formatted_results = []
            for idx, pred_list in enumerate(mask_predictions):
                if idx < len(mask_info_list):
                    mask_info = mask_info_list[idx]
                    formatted_results.append({
                        "order": mask_info['order'],
                        "type": mask_info['type'],
                        "top_10": pred_list[:10]  # top-10 predictions
                    })
                else:
                    # Fallback when mask_info_list is shorter.
                    formatted_results.append({
                        "order": idx,
                        "type": "MASK2",
                        "top_10": pred_list[:10]
                    })

            # Aggregate top-1 probability statistics over all masks.
            top1_probs = [preds[0]['probability'] for preds in mask_predictions if preds]
            statistics = {
                "top1_probability_avg": float(sum(top1_probs) / len(top1_probs)) if top1_probs else 0.0,
                "top1_probability_min": float(min(top1_probs)) if top1_probs else 0.0,
                "top1_probability_max": float(max(top1_probs)) if top1_probs else 0.0,
                "total_masks": len(mask_predictions)
            }

            return {
                "punctuated_text_with_masks": final_text,
                "results": formatted_results,
                "statistics": statistics
            }

        except Exception as e:
            logger.error(f"[NLP] 처리 중 오류: {e}", exc_info=True)
            return {
                "success": False,
                "error": str(e)
            }

    def restore_punctuation_only(
        self,
        text: str,
        add_space: bool = True,
        reduce_punc: bool = True
    ) -> Dict[str, Any]:
        """
        Run punctuation restoration only (no MLM prediction).

        Args:
            text: Input text.
            add_space: Whether to insert a space after punctuation.
            reduce_punc: Whether to simplify punctuation marks.

        Returns:
            On success: {'success': True, 'original_text', 'clean_text',
            'punctuated_text'}; on failure: {'success': False, 'error'}.
        """
        self._load_models()

        try:
            clean_text = remove_punctuation(text)
            punctuated_text = self.punc_restorer.restore_punctuation(
                clean_text,
                add_space=add_space,
                reduce=reduce_punc
            )

            return {
                "success": True,
                "original_text": text,
                "clean_text": clean_text,
                "punctuated_text": punctuated_text
            }
        except Exception as e:
            logger.error(f"[NLP] 구두점 복원 중 오류: {e}", exc_info=True)
            return {
                "success": False,
                "error": str(e)
            }

    def predict_masks_only(
        self,
        text: str
    ) -> Dict[str, Any]:
        """
        Run MLM prediction only (no punctuation restoration).

        Args:
            text: Text containing mask tokens.

        Returns:
            On success: {'success': True, 'predictions', 'mask_count'};
            on failure: {'success': False, 'error'}.
        """
        self._load_models()

        try:
            mask_predictions = self.mlm_predictor.predict_masks(text)

            return {
                "success": True,
                "predictions": mask_predictions,
                "mask_count": len(mask_predictions)
            }
        except Exception as e:
            logger.error(f"[NLP] MLM 예측 중 오류: {e}", exc_info=True)
            return {
                "success": False,
                "error": str(e)
            }
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
# ================================================================================
|
| 278 |
+
# Global Accessor
|
| 279 |
+
# ================================================================================
|
| 280 |
+
_nlp_engine = None


def get_nlp_engine(config_path: Optional[str] = None) -> NLPEngine:
    """
    Return the process-wide NLPEngine instance (singleton pattern).

    Args:
        config_path: Path to the config file (None uses the default);
            only honoured on the first call, which creates the instance.

    Returns:
        The shared NLPEngine instance.
    """
    global _nlp_engine
    if _nlp_engine is not None:
        return _nlp_engine
    _nlp_engine = NLPEngine(config_path)
    return _nlp_engine
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
def process_text_with_nlp(
    text: str,
    ocr_results: Optional[List[Dict]] = None,
    config_path: Optional[str] = None,
    add_space: bool = True,
    reduce_punc: bool = True
) -> Dict[str, Any]:
    """
    Convenience wrapper: run *text* through the full NLP pipeline using
    the shared engine instance.

    Args:
        text: Input text.
        ocr_results: OCR result list (carries order/type info per mask).
        config_path: Path to the config file.
        add_space: Whether to insert a space after punctuation.
        reduce_punc: Whether to simplify punctuation marks.

    Returns:
        Processing result dictionary from NLPEngine.process_text.
    """
    return get_nlp_engine(config_path).process_text(
        text,
        ocr_results=ocr_results,
        add_space=add_space,
        reduce_punc=reduce_punc,
    )
|
| 321 |
+
|
ai_modules/ocr_engine.py
ADDED
|
@@ -0,0 +1,767 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
"""
|
| 3 |
+
================================================================================
|
| 4 |
+
OCR Ensemble Module for Epitext AI Project
|
| 5 |
+
================================================================================
|
| 6 |
+
모듈명: ocr_engine.py (v12.0.0 - Production Ready)
|
| 7 |
+
작성일: 2025-12-03
|
| 8 |
+
목적: Google Vision API + HRCenterNet 앙상블 기반 한자 OCR 및 손상 영역 탐지
|
| 9 |
+
상태: Production Ready
|
| 10 |
+
================================================================================
|
| 11 |
+
"""
|
| 12 |
+
import os
|
| 13 |
+
import sys
|
| 14 |
+
import io
|
| 15 |
+
import cv2
|
| 16 |
+
import json
|
| 17 |
+
import numpy as np
|
| 18 |
+
import torch
|
| 19 |
+
import torchvision
|
| 20 |
+
import re
|
| 21 |
+
import logging
|
| 22 |
+
from torch.autograd import Variable
|
| 23 |
+
from pathlib import Path
|
| 24 |
+
from PIL import Image
|
| 25 |
+
from typing import Dict, List, Optional, Tuple, Any
|
| 26 |
+
|
| 27 |
+
# ================================================================================
|
| 28 |
+
# Logging Configuration
|
| 29 |
+
# ================================================================================
|
| 30 |
+
logger = logging.getLogger(__name__)
|
| 31 |
+
|
| 32 |
+
# ================================================================================
|
| 33 |
+
# External Model Imports
|
| 34 |
+
# ================================================================================
|
| 35 |
+
try:
|
| 36 |
+
from ai_modules.models.resnet import ResnetCustom
|
| 37 |
+
from ai_modules.models.HRCenterNet import _HRCenterNet
|
| 38 |
+
logger.info("[INIT] 외부 모델 임포트 완료: ResnetCustom, HRCenterNet")
|
| 39 |
+
except ImportError as e:
|
| 40 |
+
logger.error(f"[INIT] 모델 임포트 실패: {e}")
|
| 41 |
+
raise
|
| 42 |
+
|
| 43 |
+
# ================================================================================
|
| 44 |
+
# Google Vision API Import
|
| 45 |
+
# ================================================================================
|
| 46 |
+
try:
|
| 47 |
+
from google.cloud import vision
|
| 48 |
+
HAS_GOOGLE_VISION = True
|
| 49 |
+
except ImportError:
|
| 50 |
+
HAS_GOOGLE_VISION = False
|
| 51 |
+
logger.warning("[INIT] google-cloud-vision 패키지가 설치되지 않았습니다.")
|
| 52 |
+
|
| 53 |
+
# ================================================================================
|
| 54 |
+
# Utility Functions
|
| 55 |
+
# ================================================================================
|
| 56 |
+
def is_hanja(text: str) -> bool:
    """Return True when *text* starts with a CJK Unified Ideograph.

    Only the FIRST character is examined (range U+4E00..U+9FFF); an empty
    string is never hanja.
    """
    if not text:
        return False
    return '\u4e00' <= text[0] <= '\u9fff'
|
| 59 |
+
|
| 60 |
+
def calculate_pixel_density(binary_img: np.ndarray, box: Dict) -> float:
    """Return the fraction of non-zero ("ink") pixels inside *box*.

    Args:
        binary_img: 2-D binarized page image; non-zero pixels count as ink.
        box: Dict with 'min_x'/'min_y'/'max_x'/'max_y' pixel coordinates.
            Coordinates may fall outside the image; they are clamped.

    Returns:
        Non-zero pixel ratio in [0.0, 1.0]; 0.0 for a degenerate/empty box.
    """
    x1, y1 = int(box['min_x']), int(box['min_y'])
    x2, y2 = int(box['max_x']), int(box['max_y'])
    h, w = binary_img.shape
    # Clamp the box to the image bounds before slicing.
    x1, y1 = max(0, x1), max(0, y1)
    x2, y2 = min(w, x2), min(h, y2)
    if x2 <= x1 or y2 <= y1:
        return 0.0
    roi = binary_img[y1:y2, x1:x2]
    # np.count_nonzero is equivalent to cv2.countNonZero on a single-channel
    # array and keeps this pure-numpy helper free of the OpenCV dependency.
    return float(np.count_nonzero(roi)) / ((x2 - x1) * (y2 - y1))
|
| 69 |
+
|
| 70 |
+
def load_ocr_config(config_path: Optional[str] = None) -> Dict:
    """Load the OCR JSON configuration file.

    When *config_path* is ``None``, falls back to
    ``<module dir>/config/ocr_config.json``.
    """
    if config_path is None:
        default_path = Path(__file__).parent / "config" / "ocr_config.json"
        config_path = str(default_path)

    with open(config_path, 'r', encoding='utf-8') as fh:
        return json.load(fh)
|
| 77 |
+
|
| 78 |
+
# ================================================================================
|
| 79 |
+
# Text Detection Class
|
| 80 |
+
# ================================================================================
|
| 81 |
+
class TextDetector:
    """HRCenterNet-based single-character detector.

    Wraps ``_HRCenterNet``: loads a checkpoint, resizes the input image to a
    fixed square, decodes the center-point heatmap into pixel-space boxes and
    prunes overlapping candidates with NMS.
    """

    def __init__(self, device: torch.device, det_ckpt: str, config: Dict):
        """Load the detection checkpoint onto *device*.

        Args:
            device: Target torch device for inference.
            det_ckpt: Path to the HRCenterNet state-dict file.
            config: Parsed OCR config; reads ``model_config.input_size``,
                ``model_config.output_size`` and ``nms_config``.

        Raises:
            FileNotFoundError: If *det_ckpt* does not exist.
        """
        self.device = device
        self.config = config
        self.input_size = config['model_config']['input_size']
        self.output_size = config['model_config']['output_size']

        # Constructor args (32, 5, 0.1) are fixed to match the shipped
        # checkpoint; their meaning is defined in ai_modules.models.HRCenterNet
        # -- verify there before changing.
        self.model = _HRCenterNet(32, 5, 0.1)
        if not os.path.exists(det_ckpt):
            raise FileNotFoundError(f"체크포인트 파일 없음: {det_ckpt}")

        state = torch.load(det_ckpt, map_location=self.device)
        self.model.load_state_dict(state)
        self.model = self.model.to(self.device)
        self.model.eval()

        # The model always sees a fixed square input; coordinates are mapped
        # back to the original resolution inside detect().
        self.transform = torchvision.transforms.Compose([
            torchvision.transforms.Resize((self.input_size, self.input_size)),
            torchvision.transforms.ToTensor()
        ])

    @torch.no_grad()
    def detect(self, image) -> Tuple[List, List]:
        """Detect character boxes on *image*.

        Args:
            image: File path, ``np.ndarray`` or ``PIL.Image``.

        Returns:
            ``(boxes, scores)`` where each box is ``[x, y, w, h]`` in
            original-image pixels and each score is the heatmap peak value.
        """
        # Normalize all accepted input types to an RGB PIL image.
        if isinstance(image, str): img = Image.open(image).convert("RGB")
        elif isinstance(image, np.ndarray): img = Image.fromarray(image).convert("RGB")
        else: img = image.convert("RGB")

        image_tensor = self.transform(img).unsqueeze_(0)
        inp = Variable(image_tensor).to(self.device, dtype=torch.float)

        predict = self.model(inp)
        predict_np = predict.data.cpu().numpy()
        # Output head order: center heatmap, y/x offsets, width/height maps.
        heatmap, offset_y, offset_x, width_map, height_map = predict_np[0]

        bbox, score_list = [], []
        # Scale factors from heatmap cells back to original pixels.
        # NOTE(review): Hc is derived from the image *height* yet is applied
        # to x-offset/width below (and Wc to y/height). This mirrors the
        # upstream HRCenterNet decoding -- confirm against the reference
        # implementation before "fixing" the apparent swap.
        Hc, Wc = img.size[1] / self.output_size, img.size[0] / self.output_size

        # Load NMS score thresholds from config.
        nms_cfg = self.config.get('nms_config', {})
        nms_score = nms_cfg.get('primary_threshold', 0.12)

        idxs = np.where(heatmap.reshape(-1, 1) >= nms_score)[0]
        if len(idxs) == 0:
            # Nothing above the primary threshold: retry with the looser one.
            nms_score = nms_cfg.get('fallback_threshold', 0.08)
            idxs = np.where(heatmap.reshape(-1, 1) >= nms_score)[0]

        for j in idxs:
            # Flat heatmap index -> (row, col) heatmap cell.
            row = j // self.output_size
            col = j - row * self.output_size
            bias_x = offset_x[row, col] * Hc
            bias_y = offset_y[row, col] * Wc
            width = width_map[row, col] * self.output_size * Hc
            height = height_map[row, col] * self.output_size * Wc

            score_list.append(float(heatmap[row, col]))
            # Re-center the cell coordinates into original-image pixels.
            row = row * Hc + bias_y
            col = col * Wc + bias_x

            top = row - width / 2.0
            left = col - height / 2.0
            bottom = row + width / 2.0
            right = col + height / 2.0
            bbox.append([left, top, max(0.0, right - left), max(0.0, bottom - top)])

        if not bbox: return [], []

        # NMS over xyxy boxes keeps the highest-scoring of each overlap group.
        xyxy = [[x, y, x+w, y+h] for x, y, w, h in bbox]
        keep = torchvision.ops.nms(
            torch.tensor(xyxy, dtype=torch.float32),
            scores=torch.tensor(score_list, dtype=torch.float32),
            iou_threshold=nms_cfg.get('iou_threshold', 0.05)
        ).cpu().numpy().tolist()

        res_boxes, res_scores = [], []
        W, H = img.size
        for k in keep:
            idx = int(k)
            x, y, w, h = bbox[idx]
            # Clip survivors to the image and drop degenerate (<= 1 px) boxes.
            x = max(0.0, min(x, W - 1.0))
            y = max(0.0, min(y, H - 1.0))
            w = max(0.0, min(w, W - x))
            h = max(0.0, min(h, H - y))
            if w > 1 and h > 1:
                res_boxes.append([x, y, w, h])
                res_scores.append(score_list[idx])

        return res_boxes, res_scores
|
| 168 |
+
|
| 169 |
+
# ================================================================================
|
| 170 |
+
# Merging Logics (Config 적용)
|
| 171 |
+
# ================================================================================
|
| 172 |
+
def merge_vertical_fragments(boxes, scores, config):
    """Merge vertically-split detector boxes until a fixpoint is reached.

    HRCenterNet sometimes splits one tall character into stacked fragments.
    Each sweep sorts boxes top-to-bottom and merges horizontally-aligned,
    nearly-touching pairs; sweeps repeat until one pass performs no merge.

    Args:
        boxes: List of ``[x, y, w, h]`` boxes.
        scores: Parallel list of detection scores.
        config: OCR config; uses ``merge_config.vertical_fragments`` ratios.

    Returns:
        ``(boxes, scores)`` after merging (same formats as the inputs).
    """
    if not boxes: return [], []
    # Expand each box into a dict with derived corners/centers for fast reuse.
    rects = [{'x': b[0], 'y': b[1], 'w': b[2], 'h': b[3],
              'x2': b[0]+b[2], 'y2': b[1]+b[3],
              'cx': b[0]+b[2]/2, 'cy': b[1]+b[3]/2, 'score': s}
             for b, s in zip(boxes, scores)]

    cfg = config['merge_config']['vertical_fragments']

    while True:
        rects.sort(key=lambda r: r['y'])
        merged = False
        new_rects, skip_indices = [], set()

        for i in range(len(rects)):
            if i in skip_indices: continue
            current = rects[i]
            best_cand_idx = -1

            # Only inspect the next few boxes below (list is y-sorted).
            for j in range(i + 1, min(i + 5, len(rects))):
                if j in skip_indices: continue
                candidate = rects[j]

                avg_w = (current['w'] + candidate['w']) / 2
                # Candidates must share a column and be vertically close.
                if abs(current['cx'] - candidate['cx']) > avg_w * cfg['horizontal_center_ratio']: continue
                if (candidate['y'] - current['y2']) > avg_w * cfg['vertical_gap_ratio']: continue

                new_h = max(current['y2'], candidate['y2']) - min(current['y'], candidate['y'])
                new_w = max(current['x2'], candidate['x2']) - min(current['x'], candidate['x'])

                # Two roughly-square boxes are likely two real characters;
                # merge them only when they materially overlap.
                is_safe_ratio = (new_h / new_w) < cfg['aspect_ratio_limit']
                cur_square = (current['h'] / current['w']) > 0.85
                cand_square = (candidate['h'] / candidate['w']) > 0.85
                is_overlapped = (candidate['y'] - current['y2']) < -avg_w * 0.2

                if is_safe_ratio and (not (cur_square and cand_square) or is_overlapped):
                    best_cand_idx = j
                    break

            if best_cand_idx != -1:
                # Replace the pair with their bounding union; keep max score.
                cand = rects[best_cand_idx]
                nx, ny = min(current['x'], cand['x']), min(current['y'], cand['y'])
                nx2, ny2 = max(current['x2'], cand['x2']), max(current['y2'], cand['y2'])
                new_rects.append({
                    'x': nx, 'y': ny, 'w': nx2-nx, 'h': ny2-ny,
                    'x2': nx2, 'y2': ny2, 'cx': (nx+nx2)/2, 'cy': (ny+ny2)/2,
                    'score': max(current['score'], cand['score'])
                })
                skip_indices.add(best_cand_idx)
                merged = True
            else:
                new_rects.append(current)
        rects = new_rects
        if not merged: break

    return [[r['x'], r['y'], r['w'], r['h']] for r in rects], [r['score'] for r in rects]
|
| 228 |
+
|
| 229 |
+
def merge_google_symbols(symbols, config):
    """Merge vertically-split / duplicated Google OCR symbols to a fixpoint.

    Google Vision occasionally splits one tall character into two stacked
    symbols or reports the same character twice. Each sweep sorts symbols
    top-to-bottom and merges aligned pairs that nearly touch (or whose texts
    are identical); sweeps repeat until a pass performs no merge.

    Args:
        symbols: Symbol dicts with min/max x/y, width/height, center_x/y,
            text and confidence keys (as built by ``get_google_ocr``).
        config: OCR config; uses ``merge_config.google_symbols`` ratios.

    Returns:
        The (possibly shorter) list of merged symbol dicts.
    """
    if not symbols: return []
    cfg = config['merge_config']['google_symbols']

    while True:
        symbols.sort(key=lambda s: s['min_y'])
        merged = False
        new_symbols, skip_indices = [], set()

        for i in range(len(symbols)):
            if i in skip_indices: continue
            curr = symbols[i]
            best_cand_idx = -1

            # Only look at the next few symbols below (list is y-sorted).
            for j in range(i + 1, min(i + 5, len(symbols))):
                if j in skip_indices: continue
                cand = symbols[j]

                avg_w = (curr['width'] + cand['width']) / 2
                # Must share a column: horizontal centers aligned.
                if abs(curr['center_x'] - cand['center_x']) > avg_w * cfg['horizontal_center_ratio']: continue

                gap = cand['min_y'] - curr['max_y']
                is_touching = gap < (avg_w * cfg['vertical_gap_ratio'])

                new_h = max(curr['max_y'], cand['max_y']) - min(curr['min_y'], cand['min_y'])
                new_w = max(curr['max_x'], cand['max_x']) - min(curr['min_x'], cand['min_x'])

                # Two roughly-square symbols are probably two real
                # characters; merge only non-square fragments -- except when
                # the texts match (duplicate report).
                is_both_square = (curr['height']/curr['width'] > 0.85) and (cand['height']/cand['width'] > 0.85)
                is_safe_ratio = (new_h / new_w) < cfg['aspect_ratio_limit']
                is_duplicate = (curr['text'] == cand['text'])

                if (is_touching and is_safe_ratio and not is_both_square) or is_duplicate:
                    best_cand_idx = j
                    break

            if best_cand_idx != -1:
                # Union of the two boxes; keep the upper symbol's text and
                # the higher confidence.
                cand = symbols[best_cand_idx]
                merged_sym = {
                    'text': curr['text'],
                    'min_x': min(curr['min_x'], cand['min_x']), 'min_y': min(curr['min_y'], cand['min_y']),
                    'max_x': max(curr['max_x'], cand['max_x']), 'max_y': max(curr['max_y'], cand['max_y']),
                    'confidence': max(curr['confidence'], cand['confidence']),
                    'source': 'Google'
                }
                merged_sym['width'] = merged_sym['max_x'] - merged_sym['min_x']
                merged_sym['height'] = merged_sym['max_y'] - merged_sym['min_y']
                merged_sym['center_x'] = (merged_sym['min_x'] + merged_sym['max_x']) / 2
                merged_sym['center_y'] = (merged_sym['min_y'] + merged_sym['max_y']) / 2
                new_symbols.append(merged_sym)
                skip_indices.add(best_cand_idx)
                merged = True
            else:
                new_symbols.append(curr)
        symbols = new_symbols
        if not merged: break
    return symbols
|
| 285 |
+
|
| 286 |
+
# ================================================================================
|
| 287 |
+
# Models Execution
|
| 288 |
+
# ================================================================================
|
| 289 |
+
def get_google_ocr(content: bytes, config: Dict, google_json_path: Optional[str] = None) -> List[Dict]:
    """Run Google Vision document OCR and return Hanja symbol dicts.

    Sends *content* to the Vision ``document_text_detection`` endpoint with a
    Traditional-Chinese language hint, flattens the page/block/paragraph/word
    hierarchy into per-symbol dicts, keeps only symbols whose text starts
    with a Hanja character, and merges split/duplicated symbols.

    Args:
        content: Raw image bytes.
        config: OCR config (forwarded to ``merge_google_symbols``).
        google_json_path: Optional service-account JSON path; when it exists
            it is exported as GOOGLE_APPLICATION_CREDENTIALS (process-wide
            side effect).

    Returns:
        List of symbol dicts with 'source': 'Google'; empty list when the
        Vision client is unavailable, the response is empty, or any error
        occurs (errors are logged, never raised).
    """
    if not HAS_GOOGLE_VISION: return []
    if google_json_path and os.path.exists(google_json_path):
        os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = google_json_path

    try:
        client = vision.ImageAnnotatorClient()
        image = vision.Image(content=content)
        context = vision.ImageContext(language_hints=["zh-Hant"])
        response = client.document_text_detection(image=image, image_context=context)

        if not response.full_text_annotation: return []

        symbols = []
        # Flatten the Vision hierarchy down to individual symbols.
        for page in response.full_text_annotation.pages:
            for block in page.blocks:
                for paragraph in block.paragraphs:
                    for word in paragraph.words:
                        for s in word.symbols:
                            # Non-Hanja symbols (punctuation, Latin, etc.)
                            # are dropped here.
                            if not is_hanja(s.text): continue
                            v = s.bounding_box.vertices
                            x, y = [p.x for p in v], [p.y for p in v]
                            symbols.append({
                                'text': s.text,
                                'center_x': (min(x)+max(x))/2, 'center_y': (min(y)+max(y))/2,
                                'min_x': min(x), 'max_x': max(x), 'min_y': min(y), 'max_y': max(y),
                                'width': max(x)-min(x), 'height': max(y)-min(y),
                                'confidence': s.confidence, 'source': 'Google'
                            })

        # Collapse vertically-split / duplicated symbols.
        original_count = len(symbols)
        symbols = merge_google_symbols(symbols, config)
        if len(symbols) < original_count:
            logger.info(f"[OCR] Google 병합: {original_count} -> {len(symbols)}개")
        return symbols
    except Exception as e:
        # Best-effort: OCR failures degrade to "no Google symbols".
        logger.error(f"[OCR] Google Vision Error: {e}")
        return []
|
| 327 |
+
|
| 328 |
+
def get_custom_model_ocr(image_path, binary_img, detector, recognizer, config):
    """Run the custom detector+recognizer pipeline on one page image.

    Detects character boxes with HRCenterNet, merges vertical fragments,
    recognizes each crop, filters weak/implausible detections, and converts
    ink-blotted boxes into '[MASK1]'/'[MASK2]' placeholder symbols.

    Args:
        image_path: Path of the page image (re-opened as RGB here).
        binary_img: 2-D binarized page used for ink-density checks.
        recognizer: Callable (ResnetCustom) mapping a list of PIL crops to a
            list of characters; "■" is treated as its reject token.
        detector: ``TextDetector`` instance.
        config: OCR config (``filtering_thresholds`` /
            ``ink_detection_thresholds``).

    Returns:
        List of symbol dicts with 'source': 'Custom'; empty list on any
        error (errors are logged, never raised).
    """
    try:
        pil_img = Image.open(image_path).convert("RGB")
        boxes, scores = detector.detect(pil_img)
        if not boxes: return []

        # Merge vertically-split detections before recognition.
        original_count = len(boxes)
        boxes, scores = merge_vertical_fragments(boxes, scores, config)
        if len(boxes) < original_count:
            logger.info(f"[OCR] Custom 병합: {original_count} -> {len(boxes)}개")

        # Page-level size statistics for the title/outlier filters below.
        all_heights = [b[3] for b in boxes]
        all_widths = [b[2] for b in boxes]
        median_h = np.median(all_heights) if all_heights else 0
        median_w = np.median(all_widths) if all_widths else 0

        # Recognize all crops in one batched call.
        crops = [pil_img.crop((int(b[0]), int(b[1]), int(b[0]+b[2]), int(b[1]+b[3]))) for b in boxes]
        chars = recognizer(crops) if crops else []

        # Filter & mask using config thresholds.
        symbols = []
        img_h, _ = binary_img.shape
        ft = config['filtering_thresholds']
        it = config['ink_detection_thresholds']

        for char, (x, y, w, h), score in zip(chars, boxes, scores):
            # "■" is the recognizer's explicit reject token.
            if not char or char == "■": continue

            box_dict = {'min_x': x, 'min_y': y, 'max_x': x+w, 'max_y': y+h}
            density = calculate_pixel_density(binary_img, box_dict)

            # Hard filters: very low score or nearly blank region.
            if score < ft['min_score_hard'] or density < ft['density_min_hard']: continue
            # Smart filter: a borderline score survives only with enough ink.
            if score < ft['smart_score_threshold'] and density < ft['smart_density_threshold']: continue

            # Title removal: drop oversized boxes, and large boxes near the
            # top of the page (headlines are not body text).
            is_huge = (h > median_h * 3.5) if median_h > 0 else False
            is_top = (y < img_h * 0.15) and (h > median_h * 2.5 or w > median_w * 2.5) if median_h > 0 else False
            if median_h > 0 and (is_huge or is_top): continue

            # Masking: very dense ink means the glyph is illegible -- replace
            # the recognized text with a mask placeholder.
            final_text, final_type = char, 'TEXT'
            if density >= it['density_ink_heavy']:
                final_text, final_type = '[MASK1]', 'MASK1'
            elif density >= it['density_ink_partial']:
                final_text, final_type = '[MASK2]', 'MASK2'
            else:
                # Plain text must actually be a Hanja character.
                if not is_hanja(char): continue

            symbols.append({
                'text': final_text, 'type': final_type,
                'center_x': x+w/2, 'center_y': y+h/2,
                'min_x': x, 'max_x': x+w, 'min_y': y, 'max_y': y+h,
                'width': w, 'height': h,
                'confidence': float(score), 'source': 'Custom', 'density': density
            })

        logger.info(f"[OCR] Custom Model 완료: {len(symbols)}개")
        return symbols
    except Exception as e:
        # Best-effort: pipeline failures degrade to "no custom symbols".
        logger.error(f"[OCR] Custom Model Error: {e}")
        return []
|
| 394 |
+
|
| 395 |
+
# ================================================================================
|
| 396 |
+
# Ensemble Reconstruction (Full Logic from Script)
|
| 397 |
+
# ================================================================================
|
| 398 |
+
def ensemble_reconstruction(google_syms, custom_syms, binary_img, config):
    """Fuse Google and custom OCR symbols into ordered text columns.

    Pipeline: sanity-filter both symbol sets, re-check ink on Google symbols,
    drop custom symbols duplicated by Google, group survivors into vertical
    columns (right-to-left), estimate the per-column character pitch, fill
    vertical gaps with inferred [MASK] boxes, deduplicate/merge/overlap-fix
    each column, and emit one joined text string per column.

    Args:
        google_syms: Symbol dicts from ``get_google_ocr`` (may be mutated).
        custom_syms: Symbol dicts from ``get_custom_model_ocr`` (may be
            mutated).
        binary_img: 2-D binarized page for ink-density checks.
        config: OCR config; reads ``ensemble_config``,
            ``filtering_thresholds`` and ``ink_detection_thresholds``.

    Returns:
        ``(final_boxes, lines)``: the flat list of kept symbol dicts and the
        per-column concatenated text strings.
    """
    logger.info("[ENSEMBLE] 앙상블 재구성 시작...")
    img_h, img_w = binary_img.shape
    ec = config['ensemble_config']
    ft = config['filtering_thresholds']
    it = config['ink_detection_thresholds']

    # --- Helper Functions (close over binary_img / thresholds above) ---
    def filter_excessive_masks(nodes):
        # Drop runs of consecutive MASK nodes that are "too long" -- long
        # runs are assumed to be layout artifacts rather than damaged text.
        filtered, buffer = [], []
        threshold = ec['excessive_mask_threshold']
        for node in nodes:
            if 'MASK' in node.get('type', 'TEXT'): buffer.append(node)
            else:
                if buffer:
                    if len(buffer) < threshold: filtered.extend(buffer)
                    buffer = []
                filtered.append(node)
        # Flush a trailing mask run at end of column.
        if buffer and len(buffer) < threshold: filtered.extend(buffer)
        return filtered

    def merge_split_masks(nodes, avg_h):
        # Merge two adjacent MASK nodes when their union is still roughly
        # one character tall; re-derive density/type for the union.
        if not nodes: return []
        merged, skip = [], False
        for i in range(len(nodes)):
            if skip: skip = False; continue
            curr = nodes[i]
            if i == len(nodes)-1: merged.append(curr); break

            next_node = nodes[i+1]
            if 'MASK' in curr.get('type','TEXT') and 'MASK' in next_node.get('type','TEXT'):
                combined_h = next_node['max_y'] - curr['min_y']
                if combined_h < avg_h * 1.8:
                    new_node = curr.copy()
                    new_node.update({'max_y': next_node['max_y'], 'height': next_node['max_y'] - curr['min_y']})
                    density = calculate_pixel_density(binary_img, new_node)
                    new_node['density'] = density

                    # Union turned out nearly blank: drop both nodes.
                    if density < ft['density_min_hard']:
                        skip = True; continue

                    m_type = 'MASK1' if density >= it['density_ink_heavy'] else 'MASK2'
                    new_node.update({'type': m_type, 'text': f'[{m_type}]'})
                    merged.append(new_node)
                    skip = True
                    continue
            merged.append(curr)
        return merged

    def resolve_overlaps(boxes):
        # Split any vertical overlap between consecutive boxes of the same
        # column at the midpoint (mutates the boxes in place).
        if not boxes: return []
        boxes.sort(key=lambda x: x['min_y'])
        for i in range(len(boxes)-1):
            curr, next_box = boxes[i], boxes[i+1]
            # Skip pairs with no horizontal intersection.
            if min(curr['max_x'], next_box['max_x']) - max(curr['min_x'], next_box['min_x']) <= 0: continue

            if curr['max_y'] > next_box['min_y']:
                mid_y = (curr['max_y'] + next_box['min_y']) / 2
                curr['max_y'], curr['height'] = mid_y, mid_y - curr['min_y']
                next_box['min_y'], next_box['height'] = mid_y, next_box['max_y'] - mid_y
        return boxes

    def filter_google_overlaps(g_boxes, c_boxes):
        # Keep only custom boxes that Google did not already cover.
        if not g_boxes: return c_boxes
        filtered = []
        for c in c_boxes:
            is_dup = False
            for g in g_boxes:
                dx = abs(c['center_x'] - g['center_x'])
                dy = abs(c['center_y'] - g['center_y'])
                # MASK is preserved even if overlapping
                if 'MASK' in c.get('type', 'TEXT'): pass
                elif (min(c['max_x'], g['max_x']) > max(c['min_x'], g['min_x']) and
                      min(c['max_y'], g['max_y']) > max(c['min_y'], g['min_y'])) or \
                     (dx < g['width']*0.4 and dy < g['height']*0.4):
                    is_dup = True; break
            if not is_dup: filtered.append(c)
        return filtered

    def infer_gaps(col, step_y, avg_w):
        # Insert MASK boxes where the vertical pitch implies missing
        # characters, keeping only positions that actually contain ink.
        if not col: return []
        col.sort(key=lambda s: s['center_y'])
        filled = []
        for i, curr in enumerate(col):
            if i > 0:
                prev = col[i-1]
                gap = curr['center_y'] - prev['center_y']
                if gap > step_y * ec['gap_inference_ratio']:
                    missing = int(round(gap/step_y)) - 1
                    if missing > 0:
                        step = gap / (missing + 1)
                        for k in range(1, missing + 1):
                            ny = prev['center_y'] + k*step
                            nb = {'min_x': curr['center_x'] - avg_w/2, 'max_x': curr['center_x'] + avg_w/2,
                                  'min_y': max(0, ny - step_y*0.4), 'max_y': min(img_h, ny + step_y*0.4)}
                            nb.update({'height': nb['max_y']-nb['min_y'], 'width': nb['max_x']-nb['min_x'],
                                       'center_x': (nb['min_x']+nb['max_x'])/2, 'center_y': (nb['min_y']+nb['max_y'])/2})

                            d = calculate_pixel_density(binary_img, nb)
                            # No ink -> genuine whitespace, not a lost glyph.
                            if d < ft['density_min_hard']: continue

                            mt = 'MASK1' if d >= it['density_ink_heavy'] else 'MASK2'
                            nb.update({'text': f'[{mt}]', 'type': mt, 'density': d, 'confidence': 0.0, 'source': 'Inferred'})
                            filled.append(nb)
            filled.append(curr)
        return filled

    def check_ink_on_google(g_syms):
        # Re-classify Google symbols by measured ink density; drop nearly
        # blank ones as hallucinations. Mutates the symbol dicts.
        filtered = []
        for s in g_syms:
            d = calculate_pixel_density(binary_img, s)
            s['density'] = d
            if d >= it['density_ink_heavy']: s.update({'type': 'MASK1', 'text': '[MASK1]'})
            elif d >= it['density_ink_partial']: s.update({'type': 'MASK2', 'text': '[MASK2]'})
            elif d < ft['density_min_hard']: continue  # Hallucination check
            else: s['type'] = 'TEXT'
            filtered.append(s)
        return filtered

    # --- Preprocessing ---
    all_h = ([s['height'] for s in google_syms] + [s['height'] for s in custom_syms])
    # 30.0 is a fallback pitch when there are no symbols at all.
    median_h = np.median(all_h) if all_h else 30.0

    # Filter Height & Check Ink
    def global_remove_tall_and_top(boxes, median_h, threshold=2.0):
        # Drop abnormally tall boxes and tall boxes near the page top
        # (likely titles/decorations rather than body characters).
        if not boxes: return []
        filtered = []
        for b in boxes:
            if b['height'] > median_h * threshold: continue
            if b['min_y'] < img_h * 0.15 and b['height'] > median_h * 2.5: continue
            filtered.append(b)
        return filtered

    if google_syms:
        google_syms = global_remove_tall_and_top(google_syms, median_h, threshold=2.0)
        google_syms = check_ink_on_google(google_syms)
    if custom_syms:
        # Custom detections get a looser height cutoff than Google's.
        custom_syms = global_remove_tall_and_top(custom_syms, median_h, threshold=3.5)

    # Resize & Filter Custom -- Google's sizes are treated as the reference.
    avg_w = np.mean([s['width'] for s in google_syms]) if google_syms else 0
    median_w = np.median([s['width'] for s in google_syms]) if google_syms else 0

    processed_custom = []
    for s in custom_syms:
        # Masks bypass size filtering entirely.
        if 'MASK' in s.get('type', 'TEXT'):
            processed_custom.append(s); continue

        # Keep only custom boxes of plausible character size.
        if (s['width']*s['height'] > (median_w*median_h)*0.2 and
            s['width'] > median_w*0.3 and s['height'] > median_h*0.3):

            # Resize logic: grow undersized boxes toward the median size,
            # centered on the original box, clamped to the image.
            if s['width'] < median_w*0.8 or s['height'] < median_h*0.8:
                tw = max(s['width'], median_w*0.9)
                th = max(s['height'], median_h*0.9)
                cx, cy = s['center_x'], s['center_y']
                s.update({'min_x': max(0, cx-tw/2), 'max_x': min(img_w, cx+tw/2),
                          'min_y': max(0, cy-th/2), 'max_y': min(img_h, cy+th/2)})
                s.update({'width': s['max_x']-s['min_x'], 'height': s['max_y']-s['min_y']})
            processed_custom.append(s)

    custom_syms = filter_google_overlaps(google_syms, processed_custom)

    if not google_syms and not custom_syms: return [], []

    # --- Column Grouping (right-to-left, classical vertical layout) ---
    all_syms = google_syms + custom_syms
    columns = []
    if all_syms:
        for s in sorted(all_syms, key=lambda x: -x['center_x']):
            found = False
            for col in columns:
                # Assign to a column when near its running mean center x.
                cx = sum(c['center_x'] for c in col) / len(col)
                if abs(s['center_x'] - cx) < (avg_w if avg_w else s['width']) * ec['column_grouping_ratio']:
                    col.append(s); found = True; break
            if not found: columns.append([s])

    # Vertical Step Calculation: typical center-to-center pitch across the
    # whole page (steps far from ~1 char height are treated as noise).
    global_steps = []
    for col in columns:
        col.sort(key=lambda s: s['center_y'])
        for k in range(len(col)-1):
            step = col[k+1]['center_y'] - col[k]['center_y']
            if median_h * 0.8 < step < median_h * 1.5: global_steps.append(step)
    global_step = np.median(global_steps) if global_steps else median_h * 1.1

    # --- Reconstruction: per-column cleanup and text assembly ---
    final_boxes, lines = [], []
    for col in columns:
        col.sort(key=lambda s: s['center_y'])
        # Prefer a column-local pitch; fall back to the page-wide one.
        local_steps = [col[k+1]['center_y'] - col[k]['center_y'] for k in range(len(col)-1)
                       if median_h*0.8 < (col[k+1]['center_y'] - col[k]['center_y']) < median_h*1.5]
        step_y = np.median(local_steps) if local_steps else global_step

        # Deduplication in column: collapse near-coincident symbols.
        # Priority: MASK beats TEXT; denser MASK beats lighter MASK; for
        # identical texts, Google beats Custom.
        unique_col = []
        if col:
            prev = col[0]
            unique_col.append(prev)
            for k in range(1, len(col)):
                curr = col[k]
                dist_y = abs(curr['center_y'] - prev['center_y'])
                is_same_text = (curr.get('text') == prev.get('text'))
                is_close = (dist_y < median_h * 0.6)

                if is_close:
                    prev_is_mask = 'MASK' in prev.get('type', 'TEXT')
                    curr_is_mask = 'MASK' in curr.get('type', 'TEXT')

                    if prev_is_mask and curr_is_mask:
                        # Two overlapping masks: keep the denser one.
                        if prev['density'] < curr['density']:
                            unique_col.pop()
                            unique_col.append(curr)
                            prev = curr
                        continue
                    elif prev_is_mask and not curr_is_mask:
                        # Mask already recorded wins over overlapping text.
                        continue
                    elif not prev_is_mask and curr_is_mask:
                        # Incoming mask replaces the overlapping text.
                        unique_col.pop()
                        unique_col.append(curr)
                        prev = curr
                        continue

                if is_same_text and is_close:
                    # Same character reported twice: prefer the Google copy.
                    if prev.get('source') == 'Google':
                        continue
                    elif curr.get('source') == 'Google':
                        unique_col.pop()
                        unique_col.append(curr)
                        prev = curr
                    else:
                        continue
                else:
                    unique_col.append(curr)
                    prev = curr

        col = infer_gaps(unique_col, step_y, avg_w if avg_w else median_h)

        # Gap Filling with Masks: large leftover vertical gaps get one
        # density-checked MASK box spanning (most of) the gap.
        filled_col, cy = [], col[0]['min_y'] if col else 0
        for item in col:
            gap = item['min_y'] - cy
            if gap > step_y * 1.2:
                mb = {'min_x': item['center_x'] - (avg_w if avg_w else median_h)/2,
                      'max_x': item['center_x'] + (avg_w if avg_w else median_h)/2,
                      'min_y': max(0, cy + gap*0.1), 'max_y': min(img_h, item['min_y'] - gap*0.1)}
                d = calculate_pixel_density(binary_img, mb)
                if d >= ft['density_min_hard']:
                    mt = 'MASK1' if d >= it['density_ink_heavy'] else 'MASK2'
                    if d >= it['density_ink_partial']:
                        filled_col.append({'text': f'[{mt}]', 'type': mt, 'density': d,
                                           'min_x': mb['min_x'], 'max_x': mb['max_x'],
                                           'min_y': mb['min_y'], 'max_y': mb['max_y'],
                                           'confidence': 0.0, 'source': 'GapFill'})

            # Drop nearly-blank TEXT items (late hallucination check).
            if item.get('density', 0) < ft['density_min_hard'] and 'MASK' not in item.get('type','TEXT'):
                cy = item['max_y']; continue

            filled_col.append(item)
            cy = item['max_y']

        filled_col = merge_split_masks(filled_col, median_h)
        filled_col = filter_excessive_masks(filled_col)
        filled_col = resolve_overlaps(filled_col)

        final_boxes.extend(filled_col)
        lines.append("".join([s['text'] for s in filled_col]))

    logger.info(f"[ENSEMBLE] 완료: {len(final_boxes)}개 박스, {len(lines)}개 열")
    return final_boxes, lines
|
| 668 |
+
|
| 669 |
+
# ================================================================================
|
| 670 |
+
# OCREngine Class
|
| 671 |
+
# ================================================================================
|
| 672 |
+
class OCREngine:
    """End-to-end OCR pipeline: preprocessing, dual OCR (Google Vision +
    custom detector/recognizer), and ensemble reconstruction.

    Models are loaded lazily on the first ``run_ocr`` call, so constructing
    the engine is cheap and does not require the GPU/weights yet.
    """

    def __init__(self, config_path: Optional[str] = None):
        """Resolve configuration, weight paths, credentials and device.

        Args:
            config_path: Optional path to an OCR config file (falls back to
                packaged defaults via ``load_ocr_config``).

        Raises:
            ValueError: if ``OCR_WEIGHTS_BASE_PATH`` or
                ``GOOGLE_CREDENTIALS_JSON`` is unset, or the credentials
                file does not exist.
        """
        self.config = load_ocr_config(config_path)

        # Load paths from env
        base_path = os.getenv('OCR_WEIGHTS_BASE_PATH')
        if not base_path:
            raise ValueError("OCR_WEIGHTS_BASE_PATH environment variable is required. Please set it in your .env file.")

        self.det_ckpt = os.path.join(base_path, os.getenv('OCR_DETECTION_MODEL', 'best.pth'))
        self.rec_ckpt = os.path.join(base_path, os.getenv('OCR_RECOGNITION_MODEL', 'best_5000.pt'))

        # FIX: os.getenv() without a default may return None, which made the
        # original os.path.join(base_path, None) raise an opaque TypeError
        # before the intended ValueError could fire. Validate first.
        google_json_name = os.getenv('GOOGLE_CREDENTIALS_JSON')
        if not google_json_name:
            raise ValueError("GOOGLE_CREDENTIALS_JSON environment variable is required and file must exist. Please set it in your .env file.")
        self.google_json = os.path.join(base_path, google_json_name)

        if not os.path.exists(self.google_json):
            raise ValueError(f"GOOGLE_CREDENTIALS_JSON environment variable is required and file must exist. Please set it in your .env file.")

        # Existence verified above (the original re-checked redundantly here);
        # export for the Google Vision client library.
        os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = self.google_json

        # Device: honor an explicit config value, otherwise auto-detect CUDA.
        dev_cfg = self.config['model_config']['device']
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if dev_cfg == 'auto' else torch.device(dev_cfg)
        self.detector = None    # lazily created TextDetector
        self.recognizer = None  # lazily created ResnetCustom

    def _load_models(self) -> None:
        """Lazily instantiate the detector and recognizer on first use."""
        if not self.detector:
            self.detector = TextDetector(self.device, self.det_ckpt, self.config)
        if not self.recognizer:
            self.recognizer = ResnetCustom(weight_fn=self.rec_ckpt)
            self.recognizer.to(self.device)

    def run_ocr(self, image_path: str) -> Dict:
        """Run the full OCR pipeline on one image.

        Steps: grayscale -> median blur -> inverse Otsu binarization ->
        morphological close, then Google Vision OCR, custom-model OCR, and
        ensemble reconstruction of both symbol sets.

        Args:
            image_path: Path to the input image file.

        Returns:
            On success: ``{"image": <filename>, "results": [<box dicts>]}``.
            On failure: ``{"success": False, "error": <message>}`` (shape
            kept as-is for backward compatibility with existing callers).
        """
        try:
            self._load_models()

            # 1. Preprocessing (Exact Match to v12 Script)
            img_bgr = cv2.imread(image_path)
            if img_bgr is None:
                raise ValueError(f"Image not found: {image_path}")

            img_gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
            img_blur = cv2.medianBlur(img_gray, 3)
            _, img_binary = cv2.threshold(img_blur, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
            kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
            img_binary = cv2.morphologyEx(img_binary, cv2.MORPH_CLOSE, kernel)

            # 2. Google Vision
            with io.open(image_path, 'rb') as f:
                content = f.read()
            google_syms = get_google_ocr(content, self.config, self.google_json)

            # 3. Custom Model
            custom_syms = get_custom_model_ocr(image_path, img_binary, self.detector, self.recognizer, self.config)

            # 4. Ensemble (per-column text lines are produced but not returned)
            final_boxes, _result_lines = ensemble_reconstruction(google_syms, custom_syms, img_binary, self.config)

            # Format results according to specification
            formatted_results = []
            for order, box in enumerate(final_boxes):
                formatted_results.append({
                    "order": order,
                    "text": box.get('text', ''),
                    "type": box.get('type', 'TEXT'),
                    "box": [
                        float(box.get('min_x', 0)),
                        float(box.get('min_y', 0)),
                        float(box.get('max_x', 0)),
                        float(box.get('max_y', 0))
                    ],
                    "confidence": float(box.get('confidence', 0.0)),
                    "source": box.get('source', 'Unknown')
                })

            return {
                "image": os.path.basename(image_path),
                "results": formatted_results
            }
        except Exception as e:
            logger.error(f"[OCR] Execution Failed: {e}", exc_info=True)
            return {"success": False, "error": str(e)}
|
| 755 |
+
|
| 756 |
+
# ================================================================================
|
| 757 |
+
# Global Accessor
|
| 758 |
+
# ================================================================================
|
| 759 |
+
# Process-wide OCREngine singleton (created lazily by get_ocr_engine).
_engine = None

def get_ocr_engine(config_path: Optional[str] = None) -> OCREngine:
    """Return the shared OCREngine, constructing it on the first call.

    ``config_path`` only takes effect when the engine is first created;
    subsequent calls return the existing instance unchanged.
    """
    global _engine
    if _engine is None:
        _engine = OCREngine(config_path)
    return _engine
|
| 765 |
+
|
| 766 |
+
def ocr_and_detect(image_path: str, config_path: Optional[str] = None, bbox: Optional[Tuple[int, int, int, int]] = None, device: str = "cuda") -> Dict:
    """Run the full OCR pipeline on ``image_path`` via the shared engine.

    NOTE(review): ``bbox`` and ``device`` are accepted but currently ignored —
    device selection comes from the engine config. Confirm with callers
    before removing them from the signature.
    """
    return get_ocr_engine(config_path).run_ocr(image_path)
|
ai_modules/preprocessor_unified.py
ADDED
|
@@ -0,0 +1,605 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Epitext_Back/ai_modules/preprocessor_unified.py
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
================================================================================
|
| 5 |
+
Unified Image Preprocessing Module for Epitext AI Project
|
| 6 |
+
================================================================================
|
| 7 |
+
|
| 8 |
+
모듈명: preprocessor_unified.py (v1.0.0 - Production Ready)
|
| 9 |
+
작성일: 2025-12-02
|
| 10 |
+
목적: 한자 이미지를 Swin Gray와 OCR용으로 동시에 전처리
|
| 11 |
+
상태: Production Ready
|
| 12 |
+
|
| 13 |
+
핵심 기능:
|
| 14 |
+
한 번에 두 가지 전처리 완료:
|
| 15 |
+
1. Swin Gray: 그레이 비이진화 -> 3채널 (정보 손실 최소)
|
| 16 |
+
2. OCR: 이진화 -> 1채널 (명확한 흑백)
|
| 17 |
+
|
| 18 |
+
자동 배경 보장:
|
| 19 |
+
- Swin: 밝은배경 (>=127)
|
| 20 |
+
- OCR: 흰배경 + 검정글자 (255/0)
|
| 21 |
+
|
| 22 |
+
탁본 자동 검출: 큰 어두운 영역 식별
|
| 23 |
+
영역 검출 1회: 효율성
|
| 24 |
+
설정 파일 지원: JSON 기반 커스터마이징
|
| 25 |
+
로깅 지원: DEBUG, INFO, WARNING, ERROR
|
| 26 |
+
|
| 27 |
+
의존성:
|
| 28 |
+
- opencv-python >= 4.8.0
|
| 29 |
+
- numpy >= 1.24.0
|
| 30 |
+
|
| 31 |
+
단일 함수:
|
| 32 |
+
preprocess_image_unified(input_path, output_swin_path, output_ocr_path, ...)
|
| 33 |
+
|
| 34 |
+
사용 예시:
|
| 35 |
+
>>> from ai_modules.preprocessor_unified import preprocess_image_unified
|
| 36 |
+
>>> result = preprocess_image_unified(
|
| 37 |
+
... "input.jpg",
|
| 38 |
+
... "swin.jpg",
|
| 39 |
+
... "ocr.png"
|
| 40 |
+
... )
|
| 41 |
+
|
| 42 |
+
================================================================================
|
| 43 |
+
"""
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
import cv2
|
| 47 |
+
import numpy as np
|
| 48 |
+
from pathlib import Path
|
| 49 |
+
import json
|
| 50 |
+
import logging
|
| 51 |
+
from typing import Dict, Optional, Tuple
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
# ================================================================================
|
| 55 |
+
# Logging Configuration
|
| 56 |
+
# ================================================================================
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
# Module-wide logging: timestamped, level-tagged messages at INFO and above.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - [%(levelname)s] %(message)s'
)
# Module-level logger, keyed by this module's import path.
logger = logging.getLogger(__name__)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
# ================================================================================
|
| 67 |
+
# Constants
|
| 68 |
+
# ================================================================================
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
# Default preprocessing parameters (each overridable via the JSON config).
DEFAULT_MARGIN = 10                        # crop margin in pixels
DEFAULT_BRIGHTNESS_THRESHOLD = 127         # mean-brightness cutoff for "bright background"
DEFAULT_RUBBING_MIN_AREA_RATIO = 0.1       # min dark-blob area as a fraction of the image
DEFAULT_TEXT_MIN_AREA = 16                 # absolute floor for a text contour (px^2)
DEFAULT_TEXT_AREA_RATIO = 0.00005          # relative floor for a text contour
DEFAULT_MORPHOLOGY_KERNEL_SIZE = (2, 2)    # kernel for text-region morphology
DEFAULT_MORPHOLOGY_CLOSE_ITERATIONS = 3    # close: fill gaps inside glyphs
DEFAULT_MORPHOLOGY_OPEN_ITERATIONS = 2     # open: drop small speckles
DEFAULT_RUBBING_KERNEL_SIZE = (5, 5)       # kernel for rubbing-region morphology
DEFAULT_RUBBING_CLOSE_ITERATIONS = 10      # heavy close to merge the rubbing blob
DEFAULT_RUBBING_OPEN_ITERATIONS = 5        # open to remove noise around the blob
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
# ================================================================================
|
| 86 |
+
# Main Preprocessing Class
|
| 87 |
+
# ================================================================================
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class UnifiedImagePreprocessor:
|
| 91 |
+
"""
|
| 92 |
+
통합 이미지 전처리 클래스 (Swin + OCR)
|
| 93 |
+
|
| 94 |
+
한 번의 처리로 Swin Gray와 OCR용 이미지를 모두 생성합니다.
|
| 95 |
+
|
| 96 |
+
Attributes:
|
| 97 |
+
config (dict): 전처리 설정 파라미터
|
| 98 |
+
|
| 99 |
+
Example:
|
| 100 |
+
>>> prep = UnifiedImagePreprocessor()
|
| 101 |
+
>>> result = prep.preprocess_unified("input.jpg", "swin.jpg", "ocr.png")
|
| 102 |
+
"""
|
| 103 |
+
|
| 104 |
+
def __init__(self, config_path: Optional[str] = None) -> None:
|
| 105 |
+
"""
|
| 106 |
+
UnifiedImagePreprocessor 초기화
|
| 107 |
+
|
| 108 |
+
Args:
|
| 109 |
+
config_path (str, optional): 설정 파일 경로 (JSON)
|
| 110 |
+
"""
|
| 111 |
+
self.config = self._load_config(config_path)
|
| 112 |
+
logger.info("[INIT] UnifiedImagePreprocessor v1.0.0 초기화 완료")
|
| 113 |
+
|
| 114 |
+
def _load_config(self, config_path: Optional[str]) -> Dict:
|
| 115 |
+
"""설정 파일 로드"""
|
| 116 |
+
default_config = {
|
| 117 |
+
"margin": DEFAULT_MARGIN,
|
| 118 |
+
"brightness_threshold": DEFAULT_BRIGHTNESS_THRESHOLD,
|
| 119 |
+
"rubbing_min_area_ratio": DEFAULT_RUBBING_MIN_AREA_RATIO,
|
| 120 |
+
"text_min_area": DEFAULT_TEXT_MIN_AREA,
|
| 121 |
+
"text_area_ratio": DEFAULT_TEXT_AREA_RATIO,
|
| 122 |
+
"morphology_kernel_size": DEFAULT_MORPHOLOGY_KERNEL_SIZE,
|
| 123 |
+
"morphology_close_iterations": DEFAULT_MORPHOLOGY_CLOSE_ITERATIONS,
|
| 124 |
+
"morphology_open_iterations": DEFAULT_MORPHOLOGY_OPEN_ITERATIONS,
|
| 125 |
+
"rubbing_kernel_size": DEFAULT_RUBBING_KERNEL_SIZE,
|
| 126 |
+
"rubbing_close_iterations": DEFAULT_RUBBING_CLOSE_ITERATIONS,
|
| 127 |
+
"rubbing_open_iterations": DEFAULT_RUBBING_OPEN_ITERATIONS,
|
| 128 |
+
}
|
| 129 |
+
|
| 130 |
+
# 기본 설정 파일 경로 (config_path가 없을 때)
|
| 131 |
+
if config_path is None:
|
| 132 |
+
default_config_path = Path(__file__).parent / "config" / "preprocess_config.json"
|
| 133 |
+
if default_config_path.exists():
|
| 134 |
+
config_path = str(default_config_path)
|
| 135 |
+
|
| 136 |
+
if config_path and Path(config_path).exists():
|
| 137 |
+
try:
|
| 138 |
+
with open(config_path, 'r', encoding='utf-8') as f:
|
| 139 |
+
user_config = json.load(f)
|
| 140 |
+
# _description 필드는 제외하고 업데이트
|
| 141 |
+
user_config_clean = {k: v for k, v in user_config.items() if not k.startswith('_')}
|
| 142 |
+
default_config.update(user_config_clean)
|
| 143 |
+
logger.info(f"[CONFIG] 설정 파일 로드: {config_path}")
|
| 144 |
+
except Exception as e:
|
| 145 |
+
logger.warning(f"[CONFIG] 설정 파일 로드 실패: {e} - 기본 설정 사용")
|
| 146 |
+
|
| 147 |
+
return default_config
|
| 148 |
+
|
| 149 |
+
def _find_rubbing_bbox(self, gray_image: np.ndarray) -> Optional[Tuple[int, int, int, int]]:
|
| 150 |
+
"""
|
| 151 |
+
탁본 영역 검출 (큰 어두운 사각형 찾기)
|
| 152 |
+
|
| 153 |
+
Args:
|
| 154 |
+
gray_image (np.ndarray): 그레이스케일 이미지
|
| 155 |
+
|
| 156 |
+
Returns:
|
| 157 |
+
tuple: (x, y, w, h) 또는 None
|
| 158 |
+
"""
|
| 159 |
+
H_img, W_img = gray_image.shape
|
| 160 |
+
|
| 161 |
+
# Step 1: 어두운 영역 추출
|
| 162 |
+
_, dark_mask = cv2.threshold(gray_image, 127, 255, cv2.THRESH_BINARY_INV)
|
| 163 |
+
|
| 164 |
+
# Step 2: 모폴로지 연산
|
| 165 |
+
kernel_rub = np.ones(self.config["rubbing_kernel_size"], np.uint8)
|
| 166 |
+
dark_mask = cv2.morphologyEx(
|
| 167 |
+
dark_mask, cv2.MORPH_CLOSE, kernel_rub,
|
| 168 |
+
iterations=self.config["rubbing_close_iterations"]
|
| 169 |
+
)
|
| 170 |
+
dark_mask = cv2.morphologyEx(
|
| 171 |
+
dark_mask, cv2.MORPH_OPEN, kernel_rub,
|
| 172 |
+
iterations=self.config["rubbing_open_iterations"]
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
# Step 3: 컨투어 검출
|
| 176 |
+
contours, _ = cv2.findContours(dark_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
| 177 |
+
|
| 178 |
+
if not contours:
|
| 179 |
+
return None
|
| 180 |
+
|
| 181 |
+
# Step 4: 가장 큰 컨투어
|
| 182 |
+
largest = max(contours, key=cv2.contourArea)
|
| 183 |
+
area = cv2.contourArea(largest)
|
| 184 |
+
|
| 185 |
+
# Step 5: 면적 검증
|
| 186 |
+
min_area = (H_img * W_img) * self.config["rubbing_min_area_ratio"]
|
| 187 |
+
if area < min_area:
|
| 188 |
+
return None
|
| 189 |
+
|
| 190 |
+
return cv2.boundingRect(largest)
|
| 191 |
+
|
| 192 |
+
def _find_text_bbox(self, gray_image: np.ndarray) -> Tuple[int, int, int, int]:
|
| 193 |
+
"""
|
| 194 |
+
텍스트 영역 검출
|
| 195 |
+
|
| 196 |
+
Args:
|
| 197 |
+
gray_image (np.ndarray): 그레이스케일 이미지
|
| 198 |
+
|
| 199 |
+
Returns:
|
| 200 |
+
tuple: (x, y, w, h)
|
| 201 |
+
"""
|
| 202 |
+
H_img, W_img = gray_image.shape
|
| 203 |
+
|
| 204 |
+
# Step 1: Otsu 이진화
|
| 205 |
+
_, binary = cv2.threshold(
|
| 206 |
+
gray_image, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU
|
| 207 |
+
)
|
| 208 |
+
|
| 209 |
+
# Step 2: 모폴로지 연산
|
| 210 |
+
kernel_morph = np.ones(self.config["morphology_kernel_size"], np.uint8)
|
| 211 |
+
binary = cv2.morphologyEx(
|
| 212 |
+
binary, cv2.MORPH_CLOSE, kernel_morph,
|
| 213 |
+
iterations=self.config["morphology_close_iterations"]
|
| 214 |
+
)
|
| 215 |
+
binary = cv2.morphologyEx(
|
| 216 |
+
binary, cv2.MORPH_OPEN, kernel_morph,
|
| 217 |
+
iterations=self.config["morphology_open_iterations"]
|
| 218 |
+
)
|
| 219 |
+
|
| 220 |
+
# Step 3: 컨투어 검출
|
| 221 |
+
contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
| 222 |
+
|
| 223 |
+
# Step 4: 최소 면적 설정
|
| 224 |
+
min_area = max(
|
| 225 |
+
self.config["text_min_area"],
|
| 226 |
+
int((H_img * W_img) * self.config["text_area_ratio"])
|
| 227 |
+
)
|
| 228 |
+
|
| 229 |
+
# Step 5: 유효한 컨투어 필터링
|
| 230 |
+
valid_contours = [
|
| 231 |
+
cnt for cnt in contours
|
| 232 |
+
if cv2.contourArea(cv2.boundingRect(cnt)) >= min_area
|
| 233 |
+
]
|
| 234 |
+
|
| 235 |
+
# Step 6: 경계박스 계산
|
| 236 |
+
if valid_contours:
|
| 237 |
+
all_points = np.vstack(valid_contours)
|
| 238 |
+
return cv2.boundingRect(all_points)
|
| 239 |
+
else:
|
| 240 |
+
return (0, 0, W_img, H_img)
|
| 241 |
+
|
| 242 |
+
def _apply_margin(
|
| 243 |
+
self,
|
| 244 |
+
bbox: Tuple[int, int, int, int],
|
| 245 |
+
gray_image: np.ndarray,
|
| 246 |
+
margin_val: int
|
| 247 |
+
) -> Tuple[int, int, int, int]:
|
| 248 |
+
"""여백 추가"""
|
| 249 |
+
x, y, w, h = bbox
|
| 250 |
+
H_img, W_img = gray_image.shape
|
| 251 |
+
|
| 252 |
+
x_new = max(0, x - margin_val)
|
| 253 |
+
y_new = max(0, y - margin_val)
|
| 254 |
+
w_new = min(W_img - x_new, w + 2 * margin_val)
|
| 255 |
+
h_new = min(H_img - y_new, h + 2 * margin_val)
|
| 256 |
+
|
| 257 |
+
return (x_new, y_new, w_new, h_new)
|
| 258 |
+
|
| 259 |
+
def _ensure_bright_background(
|
| 260 |
+
self,
|
| 261 |
+
gray_cropped: np.ndarray
|
| 262 |
+
) -> Tuple[np.ndarray, Dict]:
|
| 263 |
+
"""
|
| 264 |
+
밝은배경 보장 (Swin용)
|
| 265 |
+
|
| 266 |
+
Returns:
|
| 267 |
+
tuple: (처리된 그레이 이미지, 처리 정보)
|
| 268 |
+
"""
|
| 269 |
+
mean_brightness = np.mean(gray_cropped)
|
| 270 |
+
is_inverted = False
|
| 271 |
+
|
| 272 |
+
if mean_brightness < self.config["brightness_threshold"]:
|
| 273 |
+
gray_bright = cv2.bitwise_not(gray_cropped)
|
| 274 |
+
is_inverted = True
|
| 275 |
+
else:
|
| 276 |
+
gray_bright = gray_cropped.copy()
|
| 277 |
+
|
| 278 |
+
# 재확인
|
| 279 |
+
final_brightness = np.mean(gray_bright)
|
| 280 |
+
if final_brightness < self.config["brightness_threshold"]:
|
| 281 |
+
gray_bright = cv2.bitwise_not(gray_bright)
|
| 282 |
+
is_inverted = not is_inverted
|
| 283 |
+
final_brightness = np.mean(gray_bright)
|
| 284 |
+
|
| 285 |
+
return gray_bright, {
|
| 286 |
+
"mean_brightness_before": float(mean_brightness),
|
| 287 |
+
"mean_brightness_after": float(final_brightness),
|
| 288 |
+
"is_inverted": is_inverted,
|
| 289 |
+
"is_bright_bg": final_brightness >= self.config["brightness_threshold"]
|
| 290 |
+
}
|
| 291 |
+
|
| 292 |
+
def _ensure_white_background(
|
| 293 |
+
self,
|
| 294 |
+
gray_cropped: np.ndarray
|
| 295 |
+
) -> Tuple[np.ndarray, Dict]:
|
| 296 |
+
"""
|
| 297 |
+
흰배경 보장 (OCR용)
|
| 298 |
+
|
| 299 |
+
Returns:
|
| 300 |
+
tuple: (처리된 이진 이미지, 처리 정보)
|
| 301 |
+
"""
|
| 302 |
+
# Step 1: 이진화
|
| 303 |
+
_, binary = cv2.threshold(
|
| 304 |
+
gray_cropped, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU
|
| 305 |
+
)
|
| 306 |
+
|
| 307 |
+
# Step 2: 폴라리티 판단
|
| 308 |
+
mean_brightness = np.mean(binary)
|
| 309 |
+
|
| 310 |
+
# Step 3: 필요시 반전
|
| 311 |
+
if mean_brightness < self.config["brightness_threshold"]:
|
| 312 |
+
binary_final = cv2.bitwise_not(binary)
|
| 313 |
+
polarity = "inverted"
|
| 314 |
+
else:
|
| 315 |
+
binary_final = binary
|
| 316 |
+
polarity = "normal"
|
| 317 |
+
|
| 318 |
+
final_brightness = np.mean(binary_final)
|
| 319 |
+
|
| 320 |
+
return binary_final, {
|
| 321 |
+
"mean_brightness_before": float(mean_brightness),
|
| 322 |
+
"mean_brightness_after": float(final_brightness),
|
| 323 |
+
"polarity": polarity,
|
| 324 |
+
"is_white_bg": final_brightness > self.config["brightness_threshold"]
|
| 325 |
+
}
|
| 326 |
+
|
| 327 |
+
def preprocess_unified(
|
| 328 |
+
self,
|
| 329 |
+
input_image_path: str,
|
| 330 |
+
output_swin_path: str,
|
| 331 |
+
output_ocr_path: str,
|
| 332 |
+
margin: Optional[int] = None,
|
| 333 |
+
use_rubbing: bool = False
|
| 334 |
+
) -> Dict:
|
| 335 |
+
"""
|
| 336 |
+
통합 전처리 (Swin Gray + OCR 동시 생성)
|
| 337 |
+
|
| 338 |
+
한 번의 함수 호출로 Swin Gray와 OCR용 이미지를 모두 생성합니다.
|
| 339 |
+
탁본 및 텍스트 영역 검출은 1회만 수행되어 효율성을 보장합니다.
|
| 340 |
+
|
| 341 |
+
Args:
|
| 342 |
+
input_image_path (str): 입력 이미지 경로
|
| 343 |
+
output_swin_path (str): Swin Gray 출력 경로 (JPG)
|
| 344 |
+
output_ocr_path (str): OCR 출력 경로 (PNG)
|
| 345 |
+
margin (int, optional): 크롭 여백 (픽셀)
|
| 346 |
+
use_rubbing (bool): 탁본 검출 여부 (기본: False)
|
| 347 |
+
|
| 348 |
+
Returns:
|
| 349 |
+
dict: 처리 결과
|
| 350 |
+
성공 시: {
|
| 351 |
+
"success": True,
|
| 352 |
+
"original_shape": (H, W, C),
|
| 353 |
+
"bbox": (x, y, w, h),
|
| 354 |
+
"region_type": "text" or "rubbing",
|
| 355 |
+
"region_detected": bool,
|
| 356 |
+
|
| 357 |
+
"swin": {
|
| 358 |
+
"output_path": str,
|
| 359 |
+
"output_shape": (H, W, 3),
|
| 360 |
+
"is_bright_bg": bool,
|
| 361 |
+
...
|
| 362 |
+
},
|
| 363 |
+
|
| 364 |
+
"ocr": {
|
| 365 |
+
"output_path": str,
|
| 366 |
+
"output_shape": (H, W),
|
| 367 |
+
"is_white_bg": bool,
|
| 368 |
+
...
|
| 369 |
+
}
|
| 370 |
+
}
|
| 371 |
+
|
| 372 |
+
실패 시: {
|
| 373 |
+
"success": False,
|
| 374 |
+
"message": str
|
| 375 |
+
}
|
| 376 |
+
|
| 377 |
+
Processing Steps:
|
| 378 |
+
1. 이미지 로드
|
| 379 |
+
2. 그레이스케일 변환
|
| 380 |
+
3. 영역 검출 (탁본 또는 텍스트, 1회만)
|
| 381 |
+
4. 크롭 + 여백
|
| 382 |
+
5. Swin Gray 처리 (밝은배경 보장)
|
| 383 |
+
6. OCR 처리 (이진화 + 흰배경 보장)
|
| 384 |
+
7. 동시 저장
|
| 385 |
+
|
| 386 |
+
Output:
|
| 387 |
+
- Swin: JPG 3채널 (비이진화 256단계)
|
| 388 |
+
- OCR: PNG 1채널 (이진화)
|
| 389 |
+
|
| 390 |
+
Example:
|
| 391 |
+
>>> prep = UnifiedImagePreprocessor()
|
| 392 |
+
>>> result = prep.preprocess_unified(
|
| 393 |
+
... "input.jpg",
|
| 394 |
+
... "swin.jpg",
|
| 395 |
+
... "ocr.png"
|
| 396 |
+
... )
|
| 397 |
+
>>> if result["success"]:
|
| 398 |
+
... swin_output = result["swin"]["output_path"]
|
| 399 |
+
... ocr_output = result["ocr"]["output_path"]
|
| 400 |
+
"""
|
| 401 |
+
margin_val = margin or self.config["margin"]
|
| 402 |
+
|
| 403 |
+
try:
|
| 404 |
+
# ====================================================================
|
| 405 |
+
# Step 1: 이미지 로드
|
| 406 |
+
# ====================================================================
|
| 407 |
+
img_bgr = cv2.imread(str(input_image_path), cv2.IMREAD_COLOR)
|
| 408 |
+
if img_bgr is None:
|
| 409 |
+
raise ValueError(f"이미지 로드 실패: {input_image_path}")
|
| 410 |
+
|
| 411 |
+
original_shape = img_bgr.shape
|
| 412 |
+
logger.info(f"[LOAD] 이미지 로드: {input_image_path} {original_shape}")
|
| 413 |
+
|
| 414 |
+
# ====================================================================
|
| 415 |
+
# Step 2: 그레이스케일 변환
|
| 416 |
+
# ====================================================================
|
| 417 |
+
gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
|
| 418 |
+
|
| 419 |
+
# ====================================================================
|
| 420 |
+
# Step 3: 영역 검출 (탁본 또는 텍스트)
|
| 421 |
+
# ====================================================================
|
| 422 |
+
if use_rubbing:
|
| 423 |
+
detected_bbox = self._find_rubbing_bbox(gray)
|
| 424 |
+
region_type = "rubbing"
|
| 425 |
+
logger.info("[DETECT] 탁본 영역 검출 모드")
|
| 426 |
+
else:
|
| 427 |
+
detected_bbox = None
|
| 428 |
+
region_type = "text"
|
| 429 |
+
logger.info("[DETECT] 텍스트 영역 검출 모드")
|
| 430 |
+
|
| 431 |
+
H_img, W_img = gray.shape
|
| 432 |
+
|
| 433 |
+
# ====================================================================
|
| 434 |
+
# Step 4: 크롭 + 여백
|
| 435 |
+
# ====================================================================
|
| 436 |
+
if detected_bbox is not None:
|
| 437 |
+
bbox_final = self._apply_margin(detected_bbox, gray, margin_val)
|
| 438 |
+
logger.info(f"[DETECT] {region_type} 영역 검출: {bbox_final}")
|
| 439 |
+
else:
|
| 440 |
+
# 탁본 미검출 또는 텍스트 모드 -> 텍스트 검출
|
| 441 |
+
if use_rubbing:
|
| 442 |
+
bbox_final = (0, 0, W_img, H_img)
|
| 443 |
+
logger.warning("[DETECT] 탁본 미검출 - 전체 이미지 사용")
|
| 444 |
+
else:
|
| 445 |
+
bbox_text = self._find_text_bbox(gray)
|
| 446 |
+
bbox_final = self._apply_margin(bbox_text, gray, margin_val)
|
| 447 |
+
logger.info(f"[DETECT] 텍스트 영역 검출: {bbox_final}")
|
| 448 |
+
|
| 449 |
+
x, y, w, h = bbox_final
|
| 450 |
+
gray_cropped = gray[y:y+h, x:x+w]
|
| 451 |
+
|
| 452 |
+
logger.info(f"[CROP] 크롭 완료: {gray_cropped.shape}")
|
| 453 |
+
|
| 454 |
+
# ====================================================================
|
| 455 |
+
# Step 5: Swin Gray 처리
|
| 456 |
+
# ====================================================================
|
| 457 |
+
gray_bright, info_swin = self._ensure_bright_background(gray_cropped)
|
| 458 |
+
swin_output_3ch = cv2.cvtColor(gray_bright, cv2.COLOR_GRAY2BGR)
|
| 459 |
+
|
| 460 |
+
# ====================================================================
|
| 461 |
+
# Step 6: OCR 처리
|
| 462 |
+
# ====================================================================
|
| 463 |
+
binary_final, info_ocr = self._ensure_white_background(gray_cropped)
|
| 464 |
+
|
| 465 |
+
# ====================================================================
|
| 466 |
+
# Step 7: 동시 저장
|
| 467 |
+
# ====================================================================
|
| 468 |
+
output_swin_path_obj = Path(output_swin_path)
|
| 469 |
+
output_swin_path_obj.parent.mkdir(parents=True, exist_ok=True)
|
| 470 |
+
swin_success = cv2.imwrite(str(output_swin_path_obj), swin_output_3ch)
|
| 471 |
+
|
| 472 |
+
output_ocr_path_obj = Path(output_ocr_path)
|
| 473 |
+
output_ocr_path_obj.parent.mkdir(parents=True, exist_ok=True)
|
| 474 |
+
ocr_success = cv2.imwrite(str(output_ocr_path_obj), binary_final)
|
| 475 |
+
|
| 476 |
+
if not swin_success or not ocr_success:
|
| 477 |
+
raise ValueError("이미지 저장 실패")
|
| 478 |
+
|
| 479 |
+
logger.info(f"[SAVE] Swin 저장: {output_swin_path_obj}")
|
| 480 |
+
logger.info(f"[SAVE] OCR 저장: {output_ocr_path_obj}")
|
| 481 |
+
|
| 482 |
+
# ====================================================================
|
| 483 |
+
# 결과 반환
|
| 484 |
+
# ====================================================================
|
| 485 |
+
return {
|
| 486 |
+
"success": True,
|
| 487 |
+
"version": "Unified Swin Gray + OCR (v1.0.0)",
|
| 488 |
+
"original_shape": original_shape,
|
| 489 |
+
"bbox": bbox_final,
|
| 490 |
+
"region_type": region_type,
|
| 491 |
+
"region_detected": detected_bbox is not None,
|
| 492 |
+
|
| 493 |
+
# Swin 부분
|
| 494 |
+
"swin": {
|
| 495 |
+
"output_path": str(output_swin_path_obj).replace("\\", "/"),
|
| 496 |
+
"output_shape": swin_output_3ch.shape,
|
| 497 |
+
"color_type": "Grayscale 3채널 (B=G=R, 비이진화 256단계)",
|
| 498 |
+
"is_inverted": info_swin["is_inverted"],
|
| 499 |
+
"mean_brightness_before": info_swin["mean_brightness_before"],
|
| 500 |
+
"mean_brightness_after": info_swin["mean_brightness_after"],
|
| 501 |
+
"is_bright_bg": info_swin["is_bright_bg"]
|
| 502 |
+
},
|
| 503 |
+
|
| 504 |
+
# OCR 부분
|
| 505 |
+
"ocr": {
|
| 506 |
+
"output_path": str(output_ocr_path_obj).replace("\\", "/"),
|
| 507 |
+
"output_shape": binary_final.shape,
|
| 508 |
+
"polarity": info_ocr["polarity"],
|
| 509 |
+
"mean_brightness_before": info_ocr["mean_brightness_before"],
|
| 510 |
+
"mean_brightness_after": info_ocr["mean_brightness_after"],
|
| 511 |
+
"is_white_bg": info_ocr["is_white_bg"]
|
| 512 |
+
},
|
| 513 |
+
|
| 514 |
+
"message": "[DONE] 통합 전처리 완료 (Swin + OCR)"
|
| 515 |
+
}
|
| 516 |
+
|
| 517 |
+
except Exception as e:
|
| 518 |
+
logger.error(f"[ERROR] 통합 전처리 실패: {e}")
|
| 519 |
+
return {
|
| 520 |
+
"success": False,
|
| 521 |
+
"message": str(e)
|
| 522 |
+
}
|
| 523 |
+
|
| 524 |
+
|
| 525 |
+
# ================================================================================
|
| 526 |
+
# Global Instance & Convenience Functions
|
| 527 |
+
# ================================================================================
|
| 528 |
+
|
| 529 |
+
|
| 530 |
+
# Lazily-created module-level singleton shared by preprocess_image_unified().
_global_preprocessor = None


def get_preprocessor(config_path: Optional[str] = None) -> UnifiedImagePreprocessor:
    """Return the shared preprocessor instance, creating it on first call.

    Note: ``config_path`` only takes effect on the first call; later calls
    return the already-created instance unchanged.
    """
    global _global_preprocessor
    if _global_preprocessor is None:
        _global_preprocessor = UnifiedImagePreprocessor(config_path)
    return _global_preprocessor
|
| 539 |
+
|
| 540 |
+
|
| 541 |
+
def preprocess_image_unified(
    input_path: str,
    output_swin_path: str,
    output_ocr_path: str,
    margin: Optional[int] = None,
    use_rubbing: bool = False,
    config_path: Optional[str] = None
) -> Dict:
    """
    Convenience wrapper: unified preprocessing via the shared preprocessor.

    Args:
        input_path: Input image path.
        output_swin_path: Swin Gray output path.
        output_ocr_path: OCR output path.
        margin: Optional crop margin in pixels.
        use_rubbing: Rubbing-detection mode.
        config_path: Optional JSON config path, forwarded to the singleton.
            GENERALIZATION (backward-compatible new kwarg): the original
            wrapper could not pass a config through to get_preprocessor().
            Only effective on the first call that creates the singleton.

    Returns:
        dict: Processing result from preprocess_unified().
    """
    prep = get_preprocessor(config_path)
    return prep.preprocess_unified(
        input_path,
        output_swin_path,
        output_ocr_path,
        margin,
        use_rubbing
    )
|
| 569 |
+
|
| 570 |
+
|
| 571 |
+
# ================================================================================
|
| 572 |
+
# Usage Example
|
| 573 |
+
# ================================================================================
|
| 574 |
+
|
| 575 |
+
|
| 576 |
+
if __name__ == "__main__":
    # Smoke test: run the unified pipeline once on a sample image and log
    # the key outcome fields.
    logger.info("=" * 80)
    logger.info("[TEST] Unified Image Preprocessor v1.0.0 - 테스트 시작")
    logger.info("=" * 80)

    try:
        preprocessor = UnifiedImagePreprocessor()
        outcome = preprocessor.preprocess_unified(
            "test_input.jpg", "test_swin.jpg", "test_ocr.png"
        )

        if outcome["success"]:
            logger.info("[TEST] 통합 전처리 성공!")
            logger.info(f"[TEST] Swin: {outcome['swin']['output_path']}")
            logger.info(f"[TEST] OCR: {outcome['ocr']['output_path']}")
            logger.info(f"[TEST] Swin 밝은배경: {'Yes' if outcome['swin']['is_bright_bg'] else 'No'}")
            logger.info(f"[TEST] OCR 흰배경: {'Yes' if outcome['ocr']['is_white_bg'] else 'No'}")
        else:
            logger.error(f"[TEST] 실패: {outcome['message']}")

    except Exception as e:
        logger.error(f"[TEST] 예외: {e}")

    logger.info("=" * 80)
|
dong_ocr.py
ADDED
|
@@ -0,0 +1,349 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
"""
Standalone OCR script.

Hanja (Chinese character) OCR and damaged-region detection based on an
ensemble of the Google Vision API and HRCenterNet.

Changes:
1. Added logic that detects shifts in the X coordinate and automatically
   splits the printed output into columns.
2. Applied a "safe crop" (floor for min, ceil for max coordinates) so that
   fractional-pixel regions such as [MASK] boxes are never lost
   -> applied to the JSON result data itself, not only the visualization,
      removing all fractional coordinates.
"""

import os
import sys
import json
import logging
import cv2
import math
import numpy as np
from pathlib import Path
from dotenv import load_dotenv

# Add this script's directory to the Python path so the local
# `ai_modules` package resolves regardless of the working directory.
current_dir = os.path.dirname(os.path.abspath(__file__))
if current_dir not in sys.path:
    sys.path.insert(0, current_dir)

# Load environment variables from a .env file, if one exists.
load_dotenv()

# Logging configuration.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - [%(levelname)s] %(message)s'
)
logger = logging.getLogger("DONG_OCR")

# Import the OCR engine and the preprocessing module; abort when missing,
# since nothing below can run without them.
try:
    from ai_modules.ocr_engine import get_ocr_engine
    from ai_modules.preprocessor_unified import preprocess_image_unified
except ImportError as e:
    logger.error(f"❌ 모듈 import 실패: {e}")
    sys.exit(1)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def format_ocr_results(raw_results, image_filename):
    """
    Convert raw OCR results into the requested JSON format.

    Each item's coordinates are normalized to an integer
    ``[min_x, min_y, max_x, max_y]`` box using a "safe crop": min
    coordinates are floored and max coordinates are ceiled, so
    fractional-pixel regions (e.g. [MASK] boxes) are never lost.

    Args:
        raw_results (list[dict] | None): raw detection items from the engine.
        image_filename (str): basename of the analyzed image.

    Returns:
        dict: ``{"image": image_filename, "results": [...]}`` where each
        result carries order, text, type, integer box, confidence, source.
    """
    formatted_list = []

    if raw_results is None:
        raw_results = []

    if not raw_results:
        return {"image": image_filename, "results": []}

    order_counter = 0
    for item in raw_results:
        if not isinstance(item, dict):
            continue

        min_x, min_y, max_x, max_y = 0.0, 0.0, 0.0, 0.0

        # 1. A ready-made 'box' list takes priority.
        if 'box' in item and isinstance(item['box'], list) and len(item['box']) == 4:
            try:
                min_x, min_y, max_x, max_y = map(float, item['box'])
            except (TypeError, ValueError):
                pass  # fall through to the individual-key fallback below

        # 2. No usable box: fall back to individual coordinate keys.
        if min_x == 0 and max_x == 0:
            mx = item.get('min_x')
            my = item.get('min_y')
            Mx = item.get('max_x')
            My = item.get('max_y')

            if mx is None: mx = item.get('x', 0)
            if my is None: my = item.get('y', 0)
            if Mx is None:
                Mx = item.get('x2')
                if Mx is None:
                    width = item.get('width', 0)
                    Mx = mx + width if width > 0 else 0
            if My is None:
                My = item.get('y2')
                if My is None:
                    height = item.get('height', 0)
                    My = my + height if height > 0 else 0

            try:
                min_x, min_y, max_x, max_y = float(mx), float(my), float(Mx), float(My)
            except (TypeError, ValueError):
                continue

        # 3. Last resort: derive the box from center/width/height.
        if min_x == 0 and min_y == 0 and max_x == 0 and max_y == 0:
            width = item.get('width', 0)
            height = item.get('height', 0)
            if width > 0 and height > 0:
                cx, cy = item.get('center_x', width / 2), item.get('center_y', height / 2)
                min_x, min_y = cx - width / 2, cy - height / 2
                max_x, max_y = cx + width / 2, cy + height / 2
            else:
                continue

        # Discard degenerate (zero/negative area) boxes.
        if max_x <= min_x or max_y <= min_y:
            continue

        # === Safe crop applied to the JSON data itself (drop fractions) ===
        # Floor the min corner and ceil the max corner, so the integer box
        # always covers the original fractional region.
        min_x = int(math.floor(min_x))
        min_y = int(math.floor(min_y))
        max_x = int(math.ceil(max_x))
        max_y = int(math.ceil(max_y))

        # Clamp to non-negative coordinates.
        min_x = max(0, min_x)
        min_y = max(0, min_y)
        # ==================================================================

        new_item = {
            "order": order_counter,
            "text": item.get('text', ''),
            "type": item.get('type', 'TEXT'),
            "box": [min_x, min_y, max_x, max_y],
            "confidence": float(item.get('confidence', 0.0)),
            "source": item.get('source', 'Unknown')
        }
        formatted_list.append(new_item)
        order_counter += 1

    return {"image": image_filename, "results": formatted_list}
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def draw_bboxes(image_path, results, output_path):
    """
    Draw bounding boxes for OCR results onto the image and save it.

    A "safe crop" (floor on min, ceil on max) is re-applied to each box and
    coordinates are clipped to the image bounds. Box color depends on the
    detection source/type; short TEXT items and MASK items also get a label.

    Args:
        image_path (str): path of the image to draw on.
        results (list[dict]): formatted OCR items, each with a 'box' list.
        output_path (str): destination path (.jpg/.jpeg -> JPEG q95,
            otherwise PNG compression 3).
    """
    try:
        # np.fromfile + imdecode handles non-ASCII paths that cv2.imread
        # cannot always open directly; fall back to imread just in case.
        img_array = np.fromfile(image_path, np.uint8)
        img = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
        if img is None:
            img = cv2.imread(image_path)
            if img is None: return

        # Image dimensions are loop-invariant: compute them once.
        h, w = img.shape[:2]

        box_count = 0
        colors = {
            'Google': (0, 255, 0), 'Custom': (255, 0, 255),
            'MASK1': (255, 0, 0), 'MASK2': (0, 0, 255), 'Default': (0, 255, 255)
        }

        for item in results:
            box = item.get('box', [])
            if len(box) != 4: continue
            try:
                # format_ocr_results already emits integers, but tolerate
                # floats by re-applying the safe crop (floor min, ceil max).
                x1 = int(math.floor(float(box[0])))
                y1 = int(math.floor(float(box[1])))
                x2 = int(math.ceil(float(box[2])))
                y2 = int(math.ceil(float(box[3])))
            except (TypeError, ValueError): continue

            # Clip so the rectangle stays inside the image.
            x1 = max(0, min(x1, w-1))
            y1 = max(0, min(y1, h-1))
            x2 = max(x1+1, min(x2, w))
            y2 = max(y1+1, min(y2, h))

            text = item.get('text', '')
            source = item.get('source', '')
            itype = item.get('type', 'TEXT')

            # Color priority: MASK1 > MASK2 > known source > default.
            if 'MASK1' in itype or '[MASK1]' in text: color = colors['MASK1']
            elif 'MASK2' in itype or '[MASK2]' in text: color = colors['MASK2']
            elif source in colors: color = colors[source]
            else: color = colors['Default']

            cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)

            # Label only short text (fits above the box) and MASK markers.
            if itype == 'TEXT' and len(text) <= 2:
                cv2.putText(img, text, (x1, y1-5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
            elif 'MASK' in itype:
                label = '[M1]' if itype == 'MASK1' else '[M2]'
                cv2.putText(img, label, (x1, y1-5), cv2.FONT_HERSHEY_SIMPLEX, 0.4, color, 1)

            box_count += 1

        # imencode + tofile: same non-ASCII path workaround for writing.
        ext = os.path.splitext(output_path)[1].lower()
        params = [int(cv2.IMWRITE_JPEG_QUALITY), 95] if ext in ['.jpg', '.jpeg'] else [int(cv2.IMWRITE_PNG_COMPRESSION), 3]
        result, encoded_img = cv2.imencode(ext, img, params)
        if result:
            with open(output_path, mode='wb') as f: encoded_img.tofile(f)
            logger.info(f"🖼️ B-Box 이미지 저장됨: {output_path} ({box_count}개 박스)")

    except Exception as e:
        # Visualization is best-effort: log and continue without raising.
        logger.error(f"❌ 시각화 중 오류: {e}")
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def run_ocr(image_path, use_preprocessing=True):
    """Run OCR on *image_path*, print column-split text, and save JSON + bbox outputs.

    Returns True on success, False on any failure (missing file, engine
    error, or unexpected exception).
    """
    if not os.path.exists(image_path):
        logger.error(f"❌ 이미지 없음: {image_path}")
        return False

    logger.info(f"🚀 OCR 분석 시작: {image_path}")

    try:
        # 1. Preprocessing (optional): produce Swin- and OCR-oriented variants.
        ocr_image_path = image_path
        preprocess_result = {'success': False}
        if use_preprocessing:
            logger.info("📸 이미지 전처리 중...")
            base_dir = os.path.dirname(os.path.abspath(image_path))
            base_name = os.path.splitext(os.path.basename(image_path))[0]
            swin_path = os.path.join(base_dir, f"{base_name}_swin_temp.jpg")
            ocr_preprocessed_path = os.path.join(base_dir, f"{base_name}_ocr_temp.png")

            preprocess_result = preprocess_image_unified(
                input_path=image_path, output_swin_path=swin_path,
                output_ocr_path=ocr_preprocessed_path, use_rubbing=True
            )
            if preprocess_result.get('success'):
                # OCR runs on the preprocessed image from here on.
                ocr_image_path = ocr_preprocessed_path
                logger.info(f"✅ 전처리 완료: {ocr_preprocessed_path}")
            else:
                # Preprocessing failure is non-fatal: fall back to the original.
                logger.warning(f"⚠️ 전처리 실패: {preprocess_result.get('message')}")

        # 2. Run the OCR engine.
        engine = get_ocr_engine()
        logger.info("✅ OCR 엔진 로드 완료")

        try:
            raw_result = engine.run_ocr(ocr_image_path)
        except Exception as e:
            logger.error(f"❌ OCR 실행 예외: {e}")
            return False

        if not raw_result: return False

        is_success = raw_result.get('success', False)
        # Some engine versions omit 'success'; treat a present results list
        # as success in that case.
        if not is_success and 'results' in raw_result and isinstance(raw_result['results'], list):
            is_success = True

        if not is_success:
            logger.error(f"❌ OCR 실패: {raw_result.get('error')}")
            return False

        logger.info("\n" + "="*60)
        logger.info("✅ OCR 분석 완료")

        # 3. Normalize the data (integer safe-crop boxes, ordering).
        formatted_result = format_ocr_results(raw_result.get('results', []), os.path.basename(image_path))
        results_list = formatted_result.get('results', [])

        # 4. [Column-splitting output] derive columns from X coordinates.
        logger.info("\n" + "📜 [ 인식된 텍스트 결과 (자동 열 구분) ] " + "-"*25)

        if not results_list:
            logger.info(" (결과 없음)")
        else:
            columns = []
            current_col_text = []

            # X center of the first character seeds the reference position.
            first_box = results_list[0]['box']
            prev_cx = (first_box[0] + first_box[2]) / 2

            for item in results_list:
                box = item['box']
                curr_cx = (box[0] + box[2]) / 2

                # Extract the text (MASK items print as their marker).
                text = item.get('text', '')
                if item.get('type') in ['MASK1', 'MASK2']:
                    text = f"[{item.get('type')}]"

                # === Core column-splitting logic ===
                # If the X center differs from the previous character by more
                # than 50 px, treat it as a new column (in vertical writing a
                # line break shows up as a large X shift).
                if abs(curr_cx - prev_cx) > 50:
                    if current_col_text:
                        columns.append("".join(current_col_text))
                        current_col_text = []
                    prev_cx = curr_cx  # reset the reference X for the new column

                current_col_text.append(text)
                # Within a column there may be small X jitter; one could keep a
                # fixed "representative" X per column instead, but since the
                # characters may be slanted we update the reference every time.
                prev_cx = curr_cx

            # Flush the final column.
            if current_col_text:
                columns.append("".join(current_col_text))

            # Print one log line per detected column.
            for idx, col_text in enumerate(columns, 1):
                logger.info(f" [열 {idx:02d}] {col_text}")

        logger.info("-" * 60 + "\n")

        # 5. Save the JSON result next to the original image.
        json_path = os.path.splitext(image_path)[0] + "_ocr_result.json"
        with open(json_path, 'w', encoding='utf-8') as f:
            json.dump(formatted_result, f, ensure_ascii=False, indent=2)
        logger.info(f"💾 JSON 결과 저장됨: {json_path}")

        # 6. Save the visualization (drawn on the preprocessed image if used).
        output_img_path = os.path.splitext(image_path)[0] + "_bbox.jpg"
        bbox_image_path = ocr_image_path if use_preprocessing and preprocess_result.get('success') else image_path
        draw_bboxes(bbox_image_path, results_list, output_img_path)

        # 7. Statistics (counts keyed by both source and type).
        counts = {'Google':0, 'Custom':0, 'MASK1':0, 'MASK2':0, 'TEXT':0}
        for r in results_list:
            if r['source'] in counts: counts[r['source']] += 1
            if r['type'] in counts: counts[r['type']] += 1

        logger.info("📊 최종 통계")
        logger.info(f" - 🟢 Google: {counts['Google']}개")
        logger.info(f" - 🟣 Custom: {counts['Custom']}개")
        logger.info(f" - 🔵 MASK1: {counts['MASK1']}개")
        logger.info(f" - 🔴 MASK2: {counts['MASK2']}개")
        logger.info(f" - 📝 TEXT: {counts['TEXT']}개")
        logger.info("="*60)

        return True

    except Exception as e:
        # Top-level boundary: log with traceback and report failure.
        logger.error(f"❌ 오류 발생: {e}", exc_info=True)
        return False
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
def main():
    """CLI entry point: validate arguments and environment, then run OCR once."""
    if len(sys.argv) < 2:
        print("사용법: python dong_ocr.py <이미지>")
        sys.exit(1)

    env_ready = os.getenv('OCR_WEIGHTS_BASE_PATH') and os.getenv('GOOGLE_CREDENTIALS_JSON')
    if not env_ready:
        logger.error("❌ 환경변수 미설정")
        sys.exit(1)

    succeeded = run_ocr(sys.argv[1])
    if succeeded:
        logger.info("✅ 작업 완료!")
    else:
        logger.error("❌ 작업 실패")
    sys.exit(0 if succeeded else 1)


if __name__ == "__main__":
    main()
|
requirements.txt
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
opencv-python
|
| 2 |
+
numpy
|
| 3 |
+
torch
|
| 4 |
+
torchvision
|
| 5 |
+
python-dotenv
|
| 6 |
+
Pillow
|
| 7 |
+
google-cloud-vision
|
| 8 |
+
huggingface-hub>=0.34.0,<1.0
|
weights/best.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:31db0eec96515dc820475df31245d7fe51ffcc56c76dc70df4e2bf83ff21d7e6
|
| 3 |
+
size 115004284
|
weights/best_5000.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d3180e120cb82a5dc2474448a3b61ea72b53e2507a2dc367bda76dc222a35ec6
|
| 3 |
+
size 62505977
|