donghyun committed
Commit 8672bad · 1 Parent(s): 1a7b7d2

Add OCR code, modules, and weights

.gitignore ADDED
@@ -0,0 +1,22 @@
+ # Python
+ __pycache__/
+ *.pyc
+ *.pyo
+
+ # Environment & Secrets
+ .env
+ *.json
+ weights/*.json
+
+ # Logs & Temp
+ *.log
+ *.tmp
+ *.temp
+
+ # Output files
+ *_bbox.*
+ *_ocr_result.json
+
+ # OS files
+ .DS_Store
+ Thumbs.db
README.md ADDED
@@ -0,0 +1,2 @@
+
+
ai_modules/__init__.py ADDED
@@ -0,0 +1,60 @@
+ # -*- coding: utf-8 -*-
+ """
+ ================================================================================
+ Epitext AI Unified Preprocessing Module
+ ================================================================================
+
+ Unified image preprocessing package (produces the Swin Gray and OCR variants together).
+
+ A single function call completes both preprocessing steps:
+
+ 1. Swin Gray: grayscale, non-binarized (minimal information loss) -> 3-channel JPG
+ 2. OCR: binarized (crisp black and white) -> 1-channel PNG
+
+ Version: 1.0.0
+ Status: ✅ Production Ready
+
+ Key features:
+
+ ✅ Efficiency: region detection runs once (shared by both outputs)
+ ✅ Background guarantee: Swin (light) + OCR (white)
+ ✅ Rubbing support: optional automatic detection
+ ✅ Configurable: JSON-based customization
+ """
+
+ from .preprocessor_unified import (
+     UnifiedImagePreprocessor,
+     get_preprocessor,
+     preprocess_image_unified
+ )
+ from .ocr_engine import (
+     get_ocr_engine,
+     OCREngine,
+     ocr_and_detect
+ )
+ from .nlp_engine import (
+     get_nlp_engine,
+     NLPEngine,
+     process_text_with_nlp
+ )
+
+ __version__ = "1.0.0"
+ __author__ = "Epitext Team"
+
+ __all__ = [
+     "UnifiedImagePreprocessor",
+     "get_preprocessor",
+     "preprocess_image_unified",
+     "get_ocr_engine",
+     "OCREngine",
+     "ocr_and_detect",
+     "get_nlp_engine",
+     "NLPEngine",
+     "process_text_with_nlp"
+ ]
+
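For orientation, a minimal sketch of how the package-level API above is meant to be chained. The exact signatures and return keys of preprocess_image_unified and ocr_and_detect are not visible in this commit, so the input path and the "text"/"results" keys below are assumptions:

from ai_modules import preprocess_image_unified, ocr_and_detect, process_text_with_nlp

# 1) One call yields both preprocessed variants (Swin Gray JPG + binarized OCR PNG).
pre = preprocess_image_unified("sample_page.jpg")   # hypothetical input image

# 2) OCR plus damaged-region detection.
ocr = ocr_and_detect("sample_page.jpg")

# 3) Punctuation restoration and MLM prediction over the recognized text
#    (assumed return keys; adjust to the actual ocr_and_detect output).
nlp = process_text_with_nlp(ocr["text"], ocr_results=ocr.get("results"))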
ai_modules/models/HRCenterNet.py ADDED
@@ -0,0 +1,204 @@
+ import torch
+ from torch import nn
+ from ai_modules.models.modules import BasicBlock, Bottleneck
+
+
+ class StageModule(nn.Module):
+     def __init__(self, stage, output_branches, c, bn_momentum):
+         super(StageModule, self).__init__()
+         self.stage = stage
+         self.output_branches = output_branches
+
+         self.branches = nn.ModuleList()
+         for i in range(self.stage):
+             w = c * (2 ** i)
+             branch = nn.Sequential(
+                 BasicBlock(w, w, bn_momentum=bn_momentum),
+                 BasicBlock(w, w, bn_momentum=bn_momentum),
+                 BasicBlock(w, w, bn_momentum=bn_momentum),
+                 BasicBlock(w, w, bn_momentum=bn_momentum),
+             )
+             self.branches.append(branch)
+
+         self.fuse_layers = nn.ModuleList()
+         # for each output branch (i.e. each branch in all cases but the very last one)
+         for i in range(self.output_branches):
+             self.fuse_layers.append(nn.ModuleList())
+             for j in range(self.stage):  # for each branch
+                 if i == j:
+                     self.fuse_layers[-1].append(nn.Sequential())  # used in place of "None" because it is callable
+                 elif i < j:
+                     self.fuse_layers[-1].append(nn.Sequential(
+                         nn.Conv2d(c * (2 ** j), c * (2 ** i), kernel_size=(1, 1), stride=(1, 1), bias=False),
+                         nn.BatchNorm2d(c * (2 ** i), eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
+                         nn.Upsample(scale_factor=(2.0 ** (j - i)), mode='nearest'),
+                     ))
+                 elif i > j:
+                     ops = []
+                     for k in range(i - j - 1):
+                         ops.append(nn.Sequential(
+                             nn.Conv2d(c * (2 ** j), c * (2 ** j), kernel_size=(3, 3), stride=(2, 2), padding=(1, 1),
+                                       bias=False),
+                             nn.BatchNorm2d(c * (2 ** j), eps=1e-05, momentum=0.1, affine=True,
+                                            track_running_stats=True),
+                             nn.ReLU(inplace=True),
+                         ))
+                     ops.append(nn.Sequential(
+                         nn.Conv2d(c * (2 ** j), c * (2 ** i), kernel_size=(3, 3), stride=(2, 2), padding=(1, 1),
+                                   bias=False),
+                         nn.BatchNorm2d(c * (2 ** i), eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
+                     ))
+                     self.fuse_layers[-1].append(nn.Sequential(*ops))
+
+         self.relu = nn.ReLU(inplace=True)
+
+     def forward(self, x):
+         assert len(self.branches) == len(x)
+
+         x = [branch(b) for branch, b in zip(self.branches, x)]
+
+         x_fused = []
+         for i in range(len(self.fuse_layers)):
+             for j in range(0, len(self.branches)):
+                 if j == 0:
+                     x_fused.append(self.fuse_layers[i][0](x[0]))
+                 else:
+                     x_fused[i] = x_fused[i] + self.fuse_layers[i][j](x[j])
+
+         for i in range(len(x_fused)):
+             x_fused[i] = self.relu(x_fused[i])
+
+         return x_fused
+
+
+ class _HRCenterNet(nn.Module):
+     def __init__(self, c=48, nof_joints=17, bn_momentum=0.1):
+         super(_HRCenterNet, self).__init__()
+
+         # Input (stem net)
+         self.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
+         self.bn1 = nn.BatchNorm2d(64, eps=1e-05, momentum=bn_momentum, affine=True, track_running_stats=True)
+         self.conv2 = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
+         self.bn2 = nn.BatchNorm2d(64, eps=1e-05, momentum=bn_momentum, affine=True, track_running_stats=True)
+         self.relu = nn.ReLU(inplace=True)
+
+         # Stage 1 (layer1) - First group of bottleneck (resnet) modules
+         downsample = nn.Sequential(
+             nn.Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False),
+             nn.BatchNorm2d(256, eps=1e-05, momentum=bn_momentum, affine=True, track_running_stats=True),
+         )
+         self.layer1 = nn.Sequential(
+             Bottleneck(64, 64, downsample=downsample),
+             Bottleneck(256, 64),
+             Bottleneck(256, 64),
+             Bottleneck(256, 64),
+         )
+
+         # Fusion layer 1 (transition1) - Creation of the first two branches (one full and one half resolution)
+         self.transition1 = nn.ModuleList([
+             nn.Sequential(
+                 nn.Conv2d(256, c, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
+                 nn.BatchNorm2d(c, eps=1e-05, momentum=bn_momentum, affine=True, track_running_stats=True),
+                 nn.ReLU(inplace=True),
+             ),
+             nn.Sequential(nn.Sequential(  # Double Sequential to fit with official pretrained weights
+                 nn.Conv2d(256, c * (2 ** 1), kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False),
+                 nn.BatchNorm2d(c * (2 ** 1), eps=1e-05, momentum=bn_momentum, affine=True, track_running_stats=True),
+                 nn.ReLU(inplace=True),
+             )),
+         ])
+
+         # Stage 2 (stage2) - Second module with 1 group of bottleneck (resnet) modules. This has 2 branches
+         self.stage2 = nn.Sequential(
+             StageModule(stage=2, output_branches=2, c=c, bn_momentum=bn_momentum),
+         )
+
+         # Fusion layer 2 (transition2) - Creation of the third branch (1/4 resolution)
+         self.transition2 = nn.ModuleList([
+             nn.Sequential(),  # used in place of "None" because it is callable
+             nn.Sequential(),  # used in place of "None" because it is callable
+             nn.Sequential(nn.Sequential(  # Double Sequential to fit with official pretrained weights
+                 nn.Conv2d(c * (2 ** 1), c * (2 ** 2), kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False),
+                 nn.BatchNorm2d(c * (2 ** 2), eps=1e-05, momentum=bn_momentum, affine=True, track_running_stats=True),
+                 nn.ReLU(inplace=True),
+             )),  # ToDo: why does the new branch derive from the "upper" branch only?
+         ])
+
+         # Stage 3 (stage3) - Third module with 4 groups of bottleneck (resnet) modules. This has 3 branches
+         self.stage3 = nn.Sequential(
+             StageModule(stage=3, output_branches=3, c=c, bn_momentum=bn_momentum),
+             StageModule(stage=3, output_branches=3, c=c, bn_momentum=bn_momentum),
+             StageModule(stage=3, output_branches=3, c=c, bn_momentum=bn_momentum),
+             StageModule(stage=3, output_branches=3, c=c, bn_momentum=bn_momentum),
+         )
+
+         # Fusion layer 3 (transition3) - Creation of the fourth branch (1/8 resolution)
+         self.transition3 = nn.ModuleList([
+             nn.Sequential(),  # used in place of "None" because it is callable
+             nn.Sequential(),  # used in place of "None" because it is callable
+             nn.Sequential(),  # used in place of "None" because it is callable
+             nn.Sequential(nn.Sequential(  # Double Sequential to fit with official pretrained weights
+                 nn.Conv2d(c * (2 ** 2), c * (2 ** 3), kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False),
+                 nn.BatchNorm2d(c * (2 ** 3), eps=1e-05, momentum=bn_momentum, affine=True, track_running_stats=True),
+                 nn.ReLU(inplace=True),
+             )),  # ToDo: why does the new branch derive from the "upper" branch only?
+         ])
+
+         # Stage 4 (stage4) - Fourth module with 3 groups of bottleneck (resnet) modules. This has 4 branches
+         self.stage4 = nn.Sequential(
+             StageModule(stage=4, output_branches=4, c=c, bn_momentum=bn_momentum),
+             StageModule(stage=4, output_branches=4, c=c, bn_momentum=bn_momentum),
+             StageModule(stage=4, output_branches=1, c=c, bn_momentum=bn_momentum),  # last module keeps only the full-resolution branch
+         )
+
+         # Final layer (final_layer)
+         self.final_layer = nn.Sequential(
+             nn.Conv2d(c, 32, kernel_size=(1, 1), stride=(1, 1)),
+             nn.BatchNorm2d(32, eps=1e-05, momentum=bn_momentum, affine=True, track_running_stats=True),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(32, nof_joints, kernel_size=(1, 1), stride=(1, 1)),
+             nn.Sigmoid()
+         )
+
+     def forward(self, x):
+         x = self.conv1(x)
+         x = self.bn1(x)
+         x = self.relu(x)
+         x = self.conv2(x)
+         x = self.bn2(x)
+         x = self.relu(x)
+
+         x = self.layer1(x)
+         x = [trans(x) for trans in self.transition1]  # From here on, x is a list (# == number of branches)
+
+         x = self.stage2(x)
+         # x = [trans(x[-1]) for trans in self.transition2]  # New branch derives from the "upper" branch only
+         x = [
+             self.transition2[0](x[0]),
+             self.transition2[1](x[1]),
+             self.transition2[2](x[-1])
+         ]  # New branch derives from the "upper" branch only
+
+         x = self.stage3(x)
+         # x = [trans(x) for trans in self.transition3]  # New branch derives from the "upper" branch only
+         x = [
+             self.transition3[0](x[0]),
+             self.transition3[1](x[1]),
+             self.transition3[2](x[2]),
+             self.transition3[3](x[-1])
+         ]  # New branch derives from the "upper" branch only
+
+         x = self.stage4(x)
+
+         x = self.final_layer(x[0])
+
+         return x
+
+
+ def HRCenterNet(args):
+
+     model = _HRCenterNet(32, 5, 0.1)
+
+     if args.log_dir is not None:
+         model.load_state_dict(torch.load(args.log_dir))
+
+     return model
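As a sanity check on the architecture above, a short sketch of a raw forward pass, assuming the common 512-pixel input / 128-cell output configuration that the detector in ocr_engine.py reads from its config:

import torch
from ai_modules.models.HRCenterNet import _HRCenterNet

model = _HRCenterNet(32, 5, 0.1)   # c=32 channels, 5 output maps
model.eval()

with torch.no_grad():
    out = model(torch.rand(1, 3, 512, 512))   # the two stride-2 stem convs give a 4x downsample

# The five sigmoid maps are later unpacked as heatmap, offset_y, offset_x, width, height
# (see TextDetector.detect in ocr_engine.py).
print(out.shape)   # torch.Size([1, 5, 128, 128])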
ai_modules/models/__init__.py ADDED
@@ -0,0 +1,5 @@
+ # -*- coding: utf-8 -*-
+ """
+ OCR model module package
+ """
+
ai_modules/models/modules.py ADDED
@@ -0,0 +1,111 @@
+ import torch
+ from torch import nn
+
+
+ class Bottleneck(nn.Module):
+     expansion = 4
+
+     def __init__(self, inplanes, planes, stride=1, downsample=None, bn_momentum=0.1):
+         super(Bottleneck, self).__init__()
+         self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+         self.bn1 = nn.BatchNorm2d(planes, momentum=bn_momentum)
+         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
+         self.bn2 = nn.BatchNorm2d(planes, momentum=bn_momentum)
+         self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
+         self.bn3 = nn.BatchNorm2d(planes * self.expansion, momentum=bn_momentum)
+         self.relu = nn.ReLU(inplace=True)
+         self.downsample = downsample
+         self.stride = stride
+
+     def forward(self, x):
+         residual = x
+
+         out = self.conv1(x)
+         out = self.bn1(out)
+         out = self.relu(out)
+
+         out = self.conv2(out)
+         out = self.bn2(out)
+         out = self.relu(out)
+
+         out = self.conv3(out)
+         out = self.bn3(out)
+
+         if self.downsample is not None:
+             residual = self.downsample(x)
+
+         out += residual
+         out = self.relu(out)
+
+         return out
+
+
+ # class Bottleneck_Transpose(nn.Module):
+ #     expansion = 4
+
+ #     def __init__(self, inplanes, planes, stride=1, downsample=None, bn_momentum=0.1):
+ #         super(Bottleneck, self).__init__()
+ #         nn.ConvTranspose2d(c, 64, (3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1)),
+
+ #         self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+ #         self.bn1 = nn.BatchNorm2d(planes, momentum=bn_momentum)
+ #         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
+ #         self.bn2 = nn.BatchNorm2d(planes, momentum=bn_momentum)
+ #         self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
+ #         self.bn3 = nn.BatchNorm2d(planes * self.expansion, momentum=bn_momentum)
+ #         self.relu = nn.ReLU(inplace=True)
+ #         self.downsample = downsample
+ #         self.stride = stride
+
+ #     def forward(self, x):
+ #         residual = x
+
+ #         out = self.conv1(x)
+ #         out = self.bn1(out)
+ #         out = self.relu(out)
+
+ #         out = self.conv2(out)
+ #         out = self.bn2(out)
+ #         out = self.relu(out)
+
+ #         out = self.conv3(out)
+ #         out = self.bn3(out)
+
+ #         if self.downsample is not None:
+ #             residual = self.downsample(x)
+
+ #         out += residual
+ #         out = self.relu(out)
+
+ #         return out
+
+ class BasicBlock(nn.Module):
+     expansion = 1
+
+     def __init__(self, inplanes, planes, stride=1, downsample=None, bn_momentum=0.1):
+         super(BasicBlock, self).__init__()
+         self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
+         self.bn1 = nn.BatchNorm2d(planes, momentum=bn_momentum)
+         self.relu = nn.ReLU(inplace=True)
+         # conv2 consumes conv1's output, so its input channel count is `planes`
+         # (the original passed `inplanes`, which only worked because all callers use inplanes == planes)
+         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
+         self.bn2 = nn.BatchNorm2d(planes, momentum=bn_momentum)
+         self.downsample = downsample
+         self.stride = stride
+
+     def forward(self, x):
+         residual = x
+
+         out = self.conv1(x)
+         out = self.bn1(out)
+         out = self.relu(out)
+
+         out = self.conv2(out)
+         out = self.bn2(out)
+
+         if self.downsample is not None:
+             residual = self.downsample(x)
+
+         out += residual
+         out = self.relu(out)
+
+         return out
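A brief shape check for the two blocks, mirroring how HRCenterNet.py uses them (BasicBlock with equal in/out channels inside StageModule, Bottleneck with a 1x1 projection in layer1):

import torch
from torch import nn
from ai_modules.models.modules import BasicBlock, Bottleneck

x = torch.rand(1, 64, 32, 32)

# BasicBlock keeps the channel count (expansion = 1).
assert BasicBlock(64, 64)(x).shape == (1, 64, 32, 32)

# Bottleneck expands planes by 4, so the residual needs a matching 1x1 projection.
down = nn.Sequential(nn.Conv2d(64, 256, kernel_size=1, bias=False), nn.BatchNorm2d(256))
assert Bottleneck(64, 64, downsample=down)(x).shape == (1, 256, 32, 32)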
ai_modules/models/resnet.py ADDED
@@ -0,0 +1,186 @@
+ import torch
+ import PIL
+
+ from torch import nn
+ from torchvision import transforms
+
+ class BasicBlock(nn.Module):
+     expansion = 1
+
+     def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
+                  base_width=64, dilation=1, norm_layer=None):
+         super(BasicBlock, self).__init__()
+         if norm_layer is None:
+             norm_layer = nn.BatchNorm2d
+         if groups != 1 or base_width != 64:
+             raise ValueError('BasicBlock only supports groups=1 and base_width=64')
+         if dilation > 1:
+             raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
+         # Both self.conv1 and self.downsample layers downsample the input when stride != 1
+         self.conv1 = conv3x3(inplanes, planes, stride)
+         self.bn1 = norm_layer(planes)
+         self.relu = nn.ReLU(inplace=True)
+         self.conv2 = conv3x3(planes, planes)
+         self.bn2 = norm_layer(planes)
+         self.downsample = downsample
+         self.stride = stride
+
+     def forward(self, x):
+         identity = x
+
+         out = self.conv1(x)
+         out = self.bn1(out)
+         out = self.relu(out)
+
+         out = self.conv2(out)
+         out = self.bn2(out)
+
+         if self.downsample is not None:
+             identity = self.downsample(x)
+
+         out += identity
+         out = self.relu(out)
+
+         return out
+
+ class ResNet(nn.Module):
+     def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
+                  groups=1, width_per_group=64, replace_stride_with_dilation=None,
+                  norm_layer=None):
+         super(ResNet, self).__init__()
+         if norm_layer is None:
+             norm_layer = nn.BatchNorm2d
+         self._norm_layer = norm_layer
+
+         self.inplanes = 64
+         self.dilation = 1
+         if replace_stride_with_dilation is None:
+             # each element in the tuple indicates if we should replace
+             # the 2x2 stride with a dilated convolution instead
+             replace_stride_with_dilation = [False, False, False]
+         if len(replace_stride_with_dilation) != 3:
+             raise ValueError("replace_stride_with_dilation should be None "
+                              "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
+         self.groups = groups
+         self.base_width = width_per_group
+         # Single-channel stem: inputs are grayscale character crops
+         self.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3,
+                                bias=False)
+         self.bn1 = norm_layer(self.inplanes)
+         self.relu = nn.ReLU(inplace=True)
+         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+         self.layer1 = self._make_layer(block, 64, layers[0])
+         self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
+                                        dilate=replace_stride_with_dilation[0])
+         self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
+                                        dilate=replace_stride_with_dilation[1])
+         self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
+                                        dilate=replace_stride_with_dilation[2])
+         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+         self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+         for m in self.modules():
+             if isinstance(m, nn.Conv2d):
+                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+             elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+                 nn.init.constant_(m.weight, 1)
+                 nn.init.constant_(m.bias, 0)
+
+     def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
+         norm_layer = self._norm_layer
+         downsample = None
+         previous_dilation = self.dilation
+         if dilate:
+             self.dilation *= stride
+             stride = 1
+         if stride != 1 or self.inplanes != planes * block.expansion:
+             downsample = nn.Sequential(
+                 conv1x1(self.inplanes, planes * block.expansion, stride),
+                 norm_layer(planes * block.expansion),
+             )
+
+         layers = []
+         layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
+                             self.base_width, previous_dilation, norm_layer))
+         self.inplanes = planes * block.expansion
+         for _ in range(1, blocks):
+             layers.append(block(self.inplanes, planes, groups=self.groups,
+                                 base_width=self.base_width, dilation=self.dilation,
+                                 norm_layer=norm_layer))
+
+         return nn.Sequential(*layers)
+
+     def _forward_impl(self, x):
+         # See note [TorchScript super()]
+         x = self.conv1(x)
+         x = self.bn1(x)
+         x = self.relu(x)
+         x = self.maxpool(x)
+
+         x = self.layer1(x)
+         x = self.layer2(x)
+         x = self.layer3(x)
+         x = self.layer4(x)
+
+         x = self.avgpool(x)
+         x = torch.flatten(x, 1)
+         x = self.fc(x)
+
+         return x
+
+     def forward(self, x):
+         return self._forward_impl(x)
+
+ def conv1x1(in_planes, out_planes, stride=1):
+     """1x1 convolution"""
+     return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+ def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
+     """3x3 convolution with padding"""
+     return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+                      padding=dilation, groups=groups, bias=False, dilation=dilation)
+
+ class ResnetCustom(torch.nn.Module):
+     def __init__(self, weight_fn):
+         super().__init__()
+         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+         weight = torch.load(weight_fn, map_location=self.device)
+         self.id2charDict = weight['vocab']['id2char']
+         num_classes = len(self.id2charDict)
+         self.id2charDict[-1] = "■"  # unrecognized token
+         self.transform = transforms.Compose([transforms.Grayscale(),
+                                              transforms.Resize((64, 64)),
+                                              transforms.ToTensor()])
+
+         self.net = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes)
+         self.net.load_state_dict(weight['model'])
+         self.net = self.net.to(self.device)
+         self.net.eval()
+         # self.net(torch.rand((64, 1, 64, 64)))  # optional warm-up pass
+         print(f'{weight_fn} loaded!')
+
+     def forward(self, images: list, bs=256, conf_thres=0.5):
+         '''
+         input
+             images: list of PIL images
+         return
+             chars: list of recognized chars
+         '''
+         chars = []
+         for i in range(0, len(images), bs):
+             inp = []
+             for image in images[i: i+bs]:
+                 inp.append(self.transform(image))
+             inp = torch.stack(inp, dim=0).to(self.device)
+             out = self.net(inp)
+             out = torch.nn.functional.softmax(out, dim=1)
+             conf, indice = torch.max(out, dim=1)
+             indice[conf < conf_thres] = -1
+             chars += [self.id2charDict[x] for x in indice.tolist()]
+
+         return chars
+
+ if __name__ == "__main__":
+     net = ResnetCustom(weight_fn="best_5000.pt")
+     inp = [PIL.Image.open('0.jpg'), PIL.Image.open('1.png')]
+     print(net(inp))
+
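ResnetCustom assumes a specific checkpoint layout; a sketch of the structure the constructor above reads (the key names come from the loader itself, while the characters and path are illustrative):

import torch

# Structure ResnetCustom expects inside the .pt file:
# weight = {
#     'vocab': {'id2char': {0: '王', 1: '國', ...}},   # class id -> character table
#     'model': <state_dict for ResNet(BasicBlock, [2, 2, 2, 2], num_classes=len(id2char))>,
# }
weight = torch.load("weights/best_5000.pt", map_location="cpu")   # hypothetical path
print(len(weight['vocab']['id2char']), "character classes")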
ai_modules/nlp/__init__.py ADDED
@@ -0,0 +1,26 @@
+ """
+ Korean Historical Text Processor NLP Module
+ Module for punctuation restoration and MLM prediction on Korean classical texts.
+ """
+
+ __version__ = "1.0.0"
+ __author__ = "EPITEXT"
+
+ from .punctuation_restorer import PunctuationRestorer
+ from .mlm_predictor import MLMPredictor
+ from .utils import (
+     remove_punctuation,
+     extract_mask_info,
+     replace_mask_with_symbol,
+     normalize_mask_tokens,
+ )
+
+ __all__ = [
+     "PunctuationRestorer",
+     "MLMPredictor",
+     "remove_punctuation",
+     "extract_mask_info",
+     "replace_mask_with_symbol",
+     "normalize_mask_tokens",
+ ]
+
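The helpers re-exported above are pure functions and can be exercised directly; for example:

from ai_modules.nlp import remove_punctuation, normalize_mask_tokens

text = "王若曰、[MASK1]哉。"
print(remove_punctuation(text))     # 王若曰[MASK1]哉  (punctuation stripped, mask preserved)
print(normalize_mask_tokens(text))  # 王若曰、[MASK]哉。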
ai_modules/nlp/mlm_predictor.py ADDED
@@ -0,0 +1,118 @@
+ """
+ MLM (Masked Language Model) prediction module
+ Predicts masked tokens using a BERT-based MLM.
+ """
+
+ import torch
+ from typing import Any, Dict, List
+ from transformers import AutoTokenizer, AutoModelForMaskedLM
+ from .utils import normalize_mask_tokens
+
+
+ class MLMPredictor:
+     """Class responsible for MLM prediction"""
+
+     def __init__(self, config: Dict, device: str = "cpu"):
+         """
+         Initializes the MLM predictor.
+
+         Args:
+             config: configuration dictionary (loaded from nlp_config.json)
+             device: compute device ('cpu' or 'cuda')
+         """
+         mlm_cfg = config['mlm_model']
+         self.model_name = mlm_cfg['model_name']
+         self.top_k = mlm_cfg['top_k']
+         self.max_length = mlm_cfg['max_length']
+         self.device = device
+         self.tokenizer = None
+         self.model = None
+
+     def load_model(self) -> None:
+         """Loads the model into memory."""
+         print(f"[MLM] Loading model: {self.model_name}")
+
+         self.tokenizer = AutoTokenizer.from_pretrained(
+             self.model_name,
+             use_fast=False
+         )
+         self.model = AutoModelForMaskedLM.from_pretrained(self.model_name)
+         self.model.to(self.device)
+         self.model.eval()
+
+         print("[MLM] ✓ MLM model loaded")
+
+     def predict_masks(
+         self,
+         text: str
+     ) -> List[List[Dict[str, Any]]]:
+         """
+         Predicts the [MASK] tokens in a text.
+
+         Args:
+             text: text containing masks
+
+         Returns:
+             A list of top-k predictions for each mask position
+         """
+         # Normalize [MASK1], [MASK2] -> [MASK]
+         text_normalized = normalize_mask_tokens(text)
+
+         print(f"[MLM] Input text sample: {text_normalized[:100]}...")
+         print(f"[MLM] [MASK] token count: {text_normalized.count('[MASK]')}")
+
+         # Tokenize
+         inputs = self.tokenizer(
+             text_normalized,
+             return_tensors="pt",
+             truncation=True,
+             max_length=self.max_length
+         )
+
+         # Move to device
+         inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+         # Locate the [MASK] positions
+         mask_indices = torch.where(
+             inputs["input_ids"] == self.tokenizer.mask_token_id
+         )[1]
+
+         print(f"[MLM] [MASK] positions found by the tokenizer: {len(mask_indices)}")
+
+         if len(mask_indices) == 0:
+             print("[MLM] ⚠️ Warning: no [MASK] tokens found!")
+             sample_tokens = self.tokenizer.convert_ids_to_tokens(
+                 inputs['input_ids'][0][:50]
+             )
+             print(f"[MLM] Tokenized input sample: {sample_tokens}")
+             return []
+
+         # Run the prediction
+         with torch.no_grad():
+             outputs = self.model(**inputs)
+             logits = outputs.logits
+
+         # Top-k predictions for each mask position
+         all_predictions = []
+         for mask_idx in mask_indices:
+             mask_logits = logits[0, mask_idx, :]
+
+             # Softmax over the full vocabulary, then pick the top-k
+             all_probs = torch.nn.functional.softmax(mask_logits, dim=-1)
+             top_k_probs, top_k_indices = torch.topk(all_probs, self.top_k)
+
+             top_k_tokens = self.tokenizer.convert_ids_to_tokens(
+                 top_k_indices.tolist()
+             )
+
+             predictions = [
+                 {
+                     "token": token,
+                     "probability": float(prob)
+                 }
+                 for token, prob in zip(top_k_tokens, top_k_probs.tolist())
+             ]
+             all_predictions.append(predictions)
+
+         return all_predictions
+
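A minimal driver for the class above. The config dict mirrors the keys read in __init__; the model name is a placeholder, since nlp_config.json is not part of the visible diff:

from ai_modules.nlp.mlm_predictor import MLMPredictor

config = {"mlm_model": {
    "model_name": "bert-base-chinese",  # placeholder; the real tag comes from nlp_config.json
    "top_k": 10,
    "max_length": 512,
}}

predictor = MLMPredictor(config, device="cpu")
predictor.load_model()
preds = predictor.predict_masks("王若曰[MASK1]哉")  # [MASK1] is normalized to [MASK] internally
print(preds[0][:3])  # top-3 candidates for the first mask: [{'token': ..., 'probability': ...}, ...]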
ai_modules/nlp/punctuation_restorer.py ADDED
@@ -0,0 +1,326 @@
+ """
+ Punctuation restoration module
+ Restores punctuation in Korean classical texts using a Hugging Face model.
+ """
+
+ import json
+ import torch
+ from pathlib import Path
+ from typing import Dict, List, Tuple
+ from collections import Counter
+ from huggingface_hub import snapshot_download
+ from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline
+
+
+ class PunctuationRestorer:
+     """Class responsible for punctuation restoration"""
+
+     def __init__(self, config: Dict, cache_dir: str, device: str = "cpu"):
+         """
+         Initializes the punctuation restorer.
+
+         Args:
+             config: configuration dictionary (loaded from nlp_config.json)
+             cache_dir: model cache directory (base path)
+             device: compute device ('cpu' or 'cuda')
+         """
+         punc_cfg = config['punc_model']
+         self.model_tag = punc_cfg['model_tag']
+         self.max_length = punc_cfg['max_length']
+         self.window_size = punc_cfg['window_size']
+         self.overlap = punc_cfg['overlap']
+
+         self.cache_dir = Path(cache_dir) / "punc"
+         self.device = device
+         self.model_info = None
+
+     def download_model(self) -> None:
+         """Downloads the model from Hugging Face."""
+         self.cache_dir.parent.mkdir(parents=True, exist_ok=True)
+
+         if not self.cache_dir.exists() or not any(self.cache_dir.iterdir()):
+             print(f"[PUNC] Downloading model: {self.model_tag}")
+             snapshot_download(
+                 repo_id=self.model_tag,
+                 repo_type="model",
+                 local_dir=str(self.cache_dir),
+                 local_dir_use_symlinks=False,
+             )
+         else:
+             print(f"[PUNC] Using cached model: {self.cache_dir}")
+
+     def load_model(self) -> None:
+         """Loads the model into memory."""
+         torch_dtype = torch.float16 if "cuda" in self.device else torch.float32
+
+         # Locate the model files
+         fnames = sorted(self.cache_dir.rglob("*.safetensors"))
+         if len(fnames) == 0:
+             # Fall back to another format if there are no safetensors files
+             fnames = sorted(self.cache_dir.rglob("*.bin"))
+
+         if len(fnames) == 0:
+             raise FileNotFoundError(f"Model files not found: {self.cache_dir}")
+
+         hface_path = fnames[0].parent
+
+         # Load the tokenizer and model
+         tokenizer = AutoTokenizer.from_pretrained(
+             str(hface_path),
+             model_max_length=self.max_length
+         )
+         model = AutoModelForTokenClassification.from_pretrained(
+             str(hface_path),
+             device_map=self.device if "cuda" in self.device else None,
+             torch_dtype=torch_dtype
+         )
+         if "cuda" not in self.device:
+             model = model.to(self.device)
+         model.eval()
+
+         # Build the NER pipeline
+         ner_pipeline = pipeline(
+             task="ner",
+             model=model,
+             tokenizer=tokenizer,
+             device=0 if "cuda" in self.device else -1
+         )
+
+         # Load the label mapping
+         label2id_path = hface_path / "label2id.json"
+         if not label2id_path.is_file():
+             label2id_path = hface_path.parent / "label2id.json"
+         if not label2id_path.is_file():
+             raise FileNotFoundError(f"label2id.json not found: {hface_path}")
+
+         label2id = json.loads(label2id_path.read_text(encoding="utf-8"))
+
+         self.model_info = {
+             "model": model,
+             "tokenizer": tokenizer,
+             "pipe": ner_pipeline,
+             "label2id": label2id
+         }
+
+         print("[PUNC] ✓ Punctuation restoration model loaded")
+
+     def restore_punctuation(
+         self,
+         text: str,
+         add_space: bool = True,
+         reduce: bool = True,
+     ) -> str:
+         """
+         Restores punctuation using a sliding-window approach.
+
+         Args:
+             text: input text
+             add_space: whether to add a space after punctuation
+             reduce: whether to simplify punctuation
+
+         Returns:
+             Text with punctuation restored
+         """
+         if not text.strip():
+             return ""
+
+         # Build the label -> punctuation mapping
+         label2punc = self._build_label2punc(add_space, reduce)
+
+         # Predict labels with a sliding window
+         labels = self._predict_labels_sliding(text, self.window_size, self.overlap)
+
+         # Adjust lengths
+         if len(labels) < len(text):
+             labels += ["O"] * (len(text) - len(labels))
+         elif len(labels) > len(text):
+             labels = labels[:len(text)]
+
+         # Insert punctuation
+         result = ""
+         for ch, label in zip(text, labels):
+             result += ch
+             punc = label2punc.get(label, "")
+             result += punc
+
+         return result.strip()
+
+     def _predict_labels_sliding(
+         self,
+         text: str,
+         window_size: int,
+         overlap: int
+     ) -> List[str]:
+         """
+         Predicts a label for each character using a sliding window.
+
+         Args:
+             text: input text
+             window_size: window size
+             overlap: overlap size
+
+         Returns:
+             A list of labels, one per character
+         """
+         n = len(text)
+         if n == 0:
+             return []
+
+         # Candidate labels collected for each position
+         labels_per_pos = [[] for _ in range(n)]
+         stride = max(1, window_size - overlap)
+         start = 0
+
+         while start < n:
+             end = min(start + window_size, n)
+             sub_text = text[start:end]
+
+             try:
+                 # Run the NER prediction
+                 sub_preds = self.model_info["pipe"](sub_text)
+                 _, sub_labels = self._align_predictions(sub_text, sub_preds)
+             except Exception as e:
+                 # On failure, label everything 'O'
+                 print(f"[PUNC] Prediction error (start={start}): {e}")
+                 sub_labels = ["O"] * len(sub_text)
+
+             # Store the labels at their global positions
+             for i, label in enumerate(sub_labels):
+                 gidx = start + i
+                 if gidx >= n:
+                     break
+                 if label != "O":
+                     labels_per_pos[gidx].append(label)
+
+             if end == n:
+                 break
+             start += stride
+
+         # Decide the final label by majority vote
+         final_labels = []
+         for cand_list in labels_per_pos:
+             if not cand_list:
+                 final_labels.append("O")
+             else:
+                 c = Counter(cand_list)
+                 label, _ = c.most_common(1)[0]
+                 final_labels.append(label)
+
+         return final_labels
+
+     @staticmethod
+     def _align_predictions(text: str, predictions: List[dict]) -> Tuple[List[str], List[str]]:
+         """
+         Aligns NER predictions to per-character labels.
+
+         Args:
+             text: original text
+             predictions: NER prediction results
+
+         Returns:
+             A (character list, label list) tuple
+         """
+         words = list(text)
+         labels = ["O" for _ in range(len(words))]
+
+         for pred in predictions:
+             idx = pred["end"] - 1
+             if 0 <= idx < len(labels):
+                 labels[idx] = pred["entity"]
+
+         return words, labels
+
+     def _build_label2punc(self, add_space: bool, reduce: bool) -> Dict[str, str]:
+         """
+         Builds the dictionary that maps labels to punctuation marks.
+
+         Args:
+             add_space: whether to add a space after punctuation
+             reduce: whether to simplify punctuation
+
+         Returns:
+             A label -> punctuation mapping dictionary
+         """
+         label2id = self.model_info["label2id"]
+         label2punc = {f"B-{v}": k for k, v in label2id.items()}
+         label2punc["O"] = ""
+
+         # Simplify the punctuation
+         if reduce:
+             new_label2punc = {}
+             for label, punc in label2punc.items():
+                 if label == "O":
+                     new_label2punc[label] = ""
+                 else:
+                     reduced = self._reduce_punc(punc)
+                     new_label2punc[label] = reduced
+             label2punc = new_label2punc
+
+         # Add spaces
+         if add_space:
+             special_puncs = "!,:;?。"
+             label2punc = {
+                 k: self._insert_space(v, special_puncs)
+                 for k, v in label2punc.items()
+             }
+             label2punc["O"] = ""
+
+         return label2punc
+
+     @staticmethod
+     def _reduce_punc(text: str) -> str:
+         """
+         Simplifies punctuation (converts it to one of ?, 。, or ,).
+
+         Args:
+             text: punctuation string
+
+         Returns:
+             The simplified punctuation
+         """
+         reduce_map = {
+             ",": ",", "-": ",", "/": ",", ":": ",", "|": ",",
+             "·": ",", "、": ",",
+             "?": "?", "!": "。", ".": "。", ";": "。", "。": "。",
+         }
+
+         text = "".join([reduce_map.get(c, "") for c in text])
+         punc_order = "?。,,"
+
+         if len(set(text).intersection(punc_order)) == 0:
+             return ""
+
+         # Pick the most frequent punctuation mark
+         counts = {c: text.count(c) for c in punc_order}
+         max_count = max(counts.values())
+         max_keys = {k for k, v in counts.items() if v == max_count}
+
+         if len(max_keys) == 1:
+             return max_keys.pop()
+
+         # On a tie, pick by priority order
+         for c in punc_order:
+             if c in max_keys:
+                 return c
+
+         return ""
+
+     @staticmethod
+     def _insert_space(text: str, chars: str) -> str:
+         """
+         Inserts a space after the given characters.
+
+         Args:
+             text: original text
+             chars: characters after which to insert a space
+
+         Returns:
+             Text with spaces inserted
+         """
+         result = ""
+         for c in text:
+             result += c
+             if c in chars:
+                 result += " "
+         return result
+
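The window/overlap interaction above is worth pinning down: the stride is window_size - overlap, so a character can collect up to roughly window_size / stride votes, and _predict_labels_sliding resolves disagreements by majority. A small illustration of the voting step (the labels are made up):

from collections import Counter

window_size, overlap = 10, 4
stride = max(1, window_size - overlap)   # 6, as computed in _predict_labels_sliding

# Suppose two overlapping windows disagree about one position:
votes = ["B-。", "B-。", "B-,"]
label, _ = Counter(votes).most_common(1)[0]
print(label)   # B-。  -- the majority wins; positions with no votes fall back to "O"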
ai_modules/nlp/utils.py ADDED
@@ -0,0 +1,87 @@
+ """
+ Utility function module
+ Provides shared helpers such as file I/O and text preprocessing.
+ """
+
+ import re
+ import unicodedata
+ from typing import Dict, Any
+
+
+ def remove_punctuation(text: str) -> str:
+     """
+     Removes punctuation and whitespace from the text, preserving [MASK] tokens.
+
+     Args:
+         text: original text
+
+     Returns:
+         Text with punctuation removed
+     """
+     result = []
+     i = 0
+
+     while i < len(text):
+         # Preserve [MASK...]-style tokens
+         if text[i:i+1] == '[' and 'MASK' in text[i:i+10]:
+             end = text.find(']', i)
+             if end != -1:
+                 result.append(text[i:end+1])
+                 i = end + 1
+                 continue
+
+         # Handle ordinary characters (skip punctuation and whitespace)
+         if unicodedata.category(text[i])[0] not in "PZ":
+             result.append(text[i])
+         i += 1
+
+     return "".join(result)
+
+
+ def replace_mask_with_symbol(text: str, symbol: str = "□") -> str:
+     """
+     Replaces mask tokens such as [MASK1], [MASK2] with the given symbol.
+
+     Args:
+         text: original text
+         symbol: replacement symbol
+
+     Returns:
+         Text with the masks replaced
+     """
+     return re.sub(r'\[MASK\d+\]', symbol, text)
+
+
+ def normalize_mask_tokens(text: str) -> str:
+     """
+     Normalizes [MASK1], [MASK2], etc. to [MASK].
+
+     Args:
+         text: original text
+
+     Returns:
+         The normalized text
+     """
+     return re.sub(r'\[MASK\d+\]', '[MASK]', text)
+
+
+ def extract_mask_info(json_data: Dict[str, Any]) -> list:
+     """
+     Extracts mask information from JSON data.
+
+     Args:
+         json_data: input JSON data
+
+     Returns:
+         A list of mask info entries (including order and type)
+     """
+     mask_info = []
+     for item in json_data.get('results', []):
+         if 'MASK' in item.get('type', ''):
+             mask_info.append({
+                 'order': item['order'],
+                 'type': item['type']
+             })
+     mask_info.sort(key=lambda x: x['order'])
+     return mask_info
+
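extract_mask_info assumes the OCR JSON shape used elsewhere in this commit (a results list whose items carry order and type); a quick illustration:

from ai_modules.nlp.utils import extract_mask_info

ocr_json = {"results": [
    {"order": 2, "type": "MASK2", "text": "[MASK2]"},
    {"order": 0, "type": "HANJA", "text": "王"},
    {"order": 1, "type": "MASK1", "text": "[MASK1]"},
]}
print(extract_mask_info(ocr_json))
# [{'order': 1, 'type': 'MASK1'}, {'order': 2, 'type': 'MASK2'}]  -- MASK entries only, sorted by order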
ai_modules/nlp_engine.py ADDED
@@ -0,0 +1,321 @@
+ """
+ Unified NLP engine
+ Engine that manages punctuation restoration and MLM prediction together.
+ """
+
+ import os
+ import json
+ import torch
+ import logging
+ from pathlib import Path
+ from typing import Dict, Any, Optional, List
+
+ from .nlp.punctuation_restorer import PunctuationRestorer
+ from .nlp.mlm_predictor import MLMPredictor
+ from .nlp.utils import remove_punctuation, replace_mask_with_symbol
+
+ logger = logging.getLogger(__name__)
+
+
+ def load_nlp_config(config_path: Optional[str] = None) -> Dict[str, Any]:
+     """
+     Loads the NLP configuration file.
+
+     Args:
+         config_path: config file path (uses the default path when None)
+
+     Returns:
+         Configuration dictionary
+     """
+     if config_path is None:
+         config_path = Path(__file__).parent / "config" / "nlp_config.json"
+     else:
+         config_path = Path(config_path)
+
+     if not config_path.exists():
+         raise FileNotFoundError(f"NLP config file not found: {config_path}")
+
+     with open(config_path, 'r', encoding='utf-8') as f:
+         return json.load(f)
+
+
+ class NLPEngine:
+     """Unified NLP processing engine"""
+
+     def __init__(self, config_path: Optional[str] = None):
+         """
+         Initializes the NLP engine.
+
+         Args:
+             config_path: config file path (uses the default path when None)
+         """
+         self.config = load_nlp_config(config_path)
+
+         # Device setup
+         dev_cfg = self.config.get('device', 'auto')
+         if dev_cfg == 'auto':
+             self.device = "cuda" if torch.cuda.is_available() else "cpu"
+         else:
+             self.device = dev_cfg
+
+         logger.info(f"[NLP] Device: {self.device}")
+
+         # Model cache path (environment variable or default)
+         self.base_model_dir = os.getenv(
+             'AI_MODEL_DIR',
+             str(Path(__file__).parent.parent / "models")
+         )
+
+         # Sub-module initialization (lazy loading)
+         self.punc_restorer = None
+         self.mlm_predictor = None
+
+     def _load_models(self):
+         """Loads the models into memory on first use"""
+         if self.punc_restorer is None:
+             logger.info("[NLP] Loading punctuation restoration model...")
+             self.punc_restorer = PunctuationRestorer(
+                 self.config,
+                 self.base_model_dir,
+                 self.device
+             )
+             self.punc_restorer.download_model()
+             self.punc_restorer.load_model()
+
+         if self.mlm_predictor is None:
+             logger.info("[NLP] Loading MLM model...")
+             self.mlm_predictor = MLMPredictor(self.config, self.device)
+             self.mlm_predictor.load_model()
+
+     def process_text(
+         self,
+         raw_text: str,
+         ocr_results: Optional[List[Dict]] = None,
+         add_space: bool = True,
+         reduce_punc: bool = True
+     ) -> Dict[str, Any]:
+         """
+         Text processing pipeline:
+         1. Punctuation removal (preprocessing)
+         2. Punctuation restoration
+         3. [MASK] prediction
+
+         Args:
+             raw_text: raw text (may already contain punctuation)
+             ocr_results: optional OCR result list carrying order/type metadata
+             add_space: whether to add a space after punctuation
+             reduce_punc: whether to simplify punctuation
+
+         Returns:
+             Result dictionary
+         """
+         self._load_models()
+
+         try:
+             # 1. Preprocess (remove punctuation, keep [MASK])
+             clean_text = remove_punctuation(raw_text)
+             logger.info(f"[NLP] Punctuation removed: {len(clean_text)} chars")
+
+             # 2. Restore punctuation
+             punctuated_text = self.punc_restorer.restore_punctuation(
+                 clean_text,
+                 add_space=add_space,
+                 reduce=reduce_punc
+             )
+             logger.info(f"[NLP] Punctuation restored: {len(punctuated_text)} chars")
+
+             # 3. MLM prediction
+             mask_predictions = self.mlm_predictor.predict_masks(punctuated_text)
+             logger.info(f"[NLP] MLM prediction done: {len(mask_predictions)} masks")
+
+             # 4. Build the output text ([MASK] -> □)
+             mask_replacement = self.config['tokens']['mask_replacement']
+             final_text = replace_mask_with_symbol(
+                 punctuated_text,
+                 mask_replacement
+             )
+
+             # Extract mask info from OCR results or the original text
+             mask_info_list = []
+             if ocr_results:
+                 # Use OCR results to get order and type
+                 for item in ocr_results:
+                     if 'MASK' in item.get('type', ''):
+                         mask_info_list.append({
+                             'order': item.get('order', 0),
+                             'type': item.get('type', 'MASK2'),
+                             'text': item.get('text', '')
+                         })
+             else:
+                 # Fallback: extract from the text
+                 i = 0
+                 while i < len(raw_text):
+                     if raw_text[i] == '[' and 'MASK' in raw_text[i:i+10]:
+                         end = raw_text.find(']', i)
+                         if end != -1:
+                             mask_text = raw_text[i:end+1]
+                             mask_type = 'MASK1' if 'MASK1' in mask_text else 'MASK2'
+                             mask_info_list.append({
+                                 'order': len(mask_info_list),  # Sequential order
+                                 'type': mask_type,
+                                 'text': mask_text
+                             })
+                             i = end + 1
+                             continue
+                     i += 1
+
+             # Format results according to the specification
+             formatted_results = []
+             for idx, pred_list in enumerate(mask_predictions):
+                 if idx < len(mask_info_list):
+                     mask_info = mask_info_list[idx]
+                     formatted_results.append({
+                         "order": mask_info['order'],
+                         "type": mask_info['type'],
+                         "top_10": pred_list[:10]  # Top-10 predictions
+                     })
+                 else:
+                     # Fallback if mask_info_list is shorter
+                     formatted_results.append({
+                         "order": idx,
+                         "type": "MASK2",
+                         "top_10": pred_list[:10]
+                     })
+
+             # Calculate statistics
+             top1_probs = [preds[0]['probability'] for preds in mask_predictions if preds]
+             statistics = {
+                 "top1_probability_avg": float(sum(top1_probs) / len(top1_probs)) if top1_probs else 0.0,
+                 "top1_probability_min": float(min(top1_probs)) if top1_probs else 0.0,
+                 "top1_probability_max": float(max(top1_probs)) if top1_probs else 0.0,
+                 "total_masks": len(mask_predictions)
+             }
+
+             return {
+                 "punctuated_text_with_masks": final_text,
+                 "results": formatted_results,
+                 "statistics": statistics
+             }
+
+         except Exception as e:
+             logger.error(f"[NLP] Error during processing: {e}", exc_info=True)
+             return {
+                 "success": False,
+                 "error": str(e)
+             }
+
+     def restore_punctuation_only(
+         self,
+         text: str,
+         add_space: bool = True,
+         reduce_punc: bool = True
+     ) -> Dict[str, Any]:
+         """
+         Performs punctuation restoration only (no MLM prediction).
+
+         Args:
+             text: input text
+             add_space: whether to add a space after punctuation
+             reduce_punc: whether to simplify punctuation
+
+         Returns:
+             Punctuation restoration result
+         """
+         self._load_models()
+
+         try:
+             clean_text = remove_punctuation(text)
+             punctuated_text = self.punc_restorer.restore_punctuation(
+                 clean_text,
+                 add_space=add_space,
+                 reduce=reduce_punc
+             )
+
+             return {
+                 "success": True,
+                 "original_text": text,
+                 "clean_text": clean_text,
+                 "punctuated_text": punctuated_text
+             }
+         except Exception as e:
+             logger.error(f"[NLP] Error during punctuation restoration: {e}", exc_info=True)
+             return {
+                 "success": False,
+                 "error": str(e)
+             }
+
+     def predict_masks_only(
+         self,
+         text: str
+     ) -> Dict[str, Any]:
+         """
+         Performs MLM prediction only (no punctuation restoration).
+
+         Args:
+             text: text containing masks
+
+         Returns:
+             MLM prediction result
+         """
+         self._load_models()
+
+         try:
+             mask_predictions = self.mlm_predictor.predict_masks(text)
+
+             return {
+                 "success": True,
+                 "predictions": mask_predictions,
+                 "mask_count": len(mask_predictions)
+             }
+         except Exception as e:
+             logger.error(f"[NLP] Error during MLM prediction: {e}", exc_info=True)
+             return {
+                 "success": False,
+                 "error": str(e)
+             }
+
+
+ # ================================================================================
+ # Global Accessor
+ # ================================================================================
+ _nlp_engine = None
+
+
+ def get_nlp_engine(config_path: Optional[str] = None) -> NLPEngine:
+     """
+     Returns the global NLP engine instance (singleton pattern).
+
+     Args:
+         config_path: config file path (uses the default path when None)
+
+     Returns:
+         The NLPEngine instance
+     """
+     global _nlp_engine
+     if _nlp_engine is None:
+         _nlp_engine = NLPEngine(config_path)
+     return _nlp_engine
+
+
+ def process_text_with_nlp(
+     text: str,
+     ocr_results: Optional[List[Dict]] = None,
+     config_path: Optional[str] = None,
+     add_space: bool = True,
+     reduce_punc: bool = True
+ ) -> Dict[str, Any]:
+     """
+     Convenience function: processes text through the NLP pipeline.
+
+     Args:
+         text: input text
+         ocr_results: OCR result list (with order and type info)
+         config_path: config file path
+         add_space: whether to add a space after punctuation
+         reduce_punc: whether to simplify punctuation
+
+     Returns:
+         Result dictionary
+     """
+     engine = get_nlp_engine(config_path)
+     return engine.process_text(text, ocr_results=ocr_results, add_space=add_space, reduce_punc=reduce_punc)
+
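For reference, the shape process_text returns on success (field names from the code above, values illustrative). Note that the error path returns only success/error keys, so callers should check for "results" before indexing:

from ai_modules import process_text_with_nlp

out = process_text_with_nlp("王若曰、[MASK1]哉。")
# {
#   "punctuated_text_with_masks": "王若曰。□哉。",   # illustrative
#   "results": [{"order": 0, "type": "MASK1",
#                "top_10": [{"token": "...", "probability": 0.42}, ...]}],
#   "statistics": {"top1_probability_avg": ..., "top1_probability_min": ...,
#                  "top1_probability_max": ..., "total_masks": 1}
# }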
ai_modules/ocr_engine.py ADDED
@@ -0,0 +1,767 @@
+ # -*- coding: utf-8 -*-
+ """
+ ================================================================================
+ OCR Ensemble Module for Epitext AI Project
+ ================================================================================
+ Module: ocr_engine.py (v12.0.0 - Production Ready)
+ Date: 2025-12-03
+ Purpose: Hanja OCR and damaged-region detection via a Google Vision API + HRCenterNet ensemble
+ Status: Production Ready
+ ================================================================================
+ """
+ import os
+ import sys
+ import io
+ import cv2
+ import json
+ import numpy as np
+ import torch
+ import torchvision
+ import re
+ import logging
+ from torch.autograd import Variable
+ from pathlib import Path
+ from PIL import Image
+ from typing import Dict, List, Optional, Tuple, Any
+
+ # ================================================================================
+ # Logging Configuration
+ # ================================================================================
+ logger = logging.getLogger(__name__)
+
+ # ================================================================================
+ # External Model Imports
+ # ================================================================================
+ try:
+     from ai_modules.models.resnet import ResnetCustom
+     from ai_modules.models.HRCenterNet import _HRCenterNet
+     logger.info("[INIT] External model imports complete: ResnetCustom, HRCenterNet")
+ except ImportError as e:
+     logger.error(f"[INIT] Model import failed: {e}")
+     raise
+
+ # ================================================================================
+ # Google Vision API Import
+ # ================================================================================
+ try:
+     from google.cloud import vision
+     HAS_GOOGLE_VISION = True
+ except ImportError:
+     HAS_GOOGLE_VISION = False
+     logger.warning("[INIT] The google-cloud-vision package is not installed.")
+
+ # ================================================================================
+ # Utility Functions
+ # ================================================================================
+ def is_hanja(text: str) -> bool:
+     if not text: return False
+     return re.match(r'[\u4e00-\u9fff]', text) is not None
+
+ def calculate_pixel_density(binary_img: np.ndarray, box: Dict) -> float:
+     x1, y1 = int(box['min_x']), int(box['min_y'])
+     x2, y2 = int(box['max_x']), int(box['max_y'])
+     h, w = binary_img.shape
+     x1, y1 = max(0, x1), max(0, y1)
+     x2, y2 = min(w, x2), min(h, y2)
+     if x2 <= x1 or y2 <= y1: return 0.0
+     roi = binary_img[y1:y2, x1:x2]
+     return cv2.countNonZero(roi) / ((x2 - x1) * (y2 - y1))
+
+ def load_ocr_config(config_path: Optional[str] = None) -> Dict:
+     """Loads the configuration file."""
+     if config_path is None:
+         config_path = str(Path(__file__).parent / "config" / "ocr_config.json")
+
+     with open(config_path, 'r', encoding='utf-8') as f:
+         return json.load(f)
+
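A quick check of the two helpers above, assuming they are run with the module's definitions in scope; the box dict uses the min_x/min_y/max_x/max_y keys that the Google-symbol path produces later in this file:

import numpy as np

binary = np.zeros((100, 100), dtype=np.uint8)
binary[10:20, 10:20] = 255   # one fully inked 10x10 patch

print(is_hanja("漢"))   # True  - first char falls in the CJK Unified Ideographs range
print(is_hanja("abc"))  # False
print(calculate_pixel_density(binary, {"min_x": 10, "min_y": 10, "max_x": 20, "max_y": 20}))  # 1.0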
78
+ # ================================================================================
79
+ # Text Detection Class
80
+ # ================================================================================
81
+ class TextDetector:
82
+ def __init__(self, device: torch.device, det_ckpt: str, config: Dict):
83
+ self.device = device
84
+ self.config = config
85
+ self.input_size = config['model_config']['input_size']
86
+ self.output_size = config['model_config']['output_size']
87
+
88
+ self.model = _HRCenterNet(32, 5, 0.1)
89
+ if not os.path.exists(det_ckpt):
90
+ raise FileNotFoundError(f"์ฒดํฌํฌ์ธํŠธ ํŒŒ์ผ ์—†์Œ: {det_ckpt}")
91
+
92
+ state = torch.load(det_ckpt, map_location=self.device)
93
+ self.model.load_state_dict(state)
94
+ self.model = self.model.to(self.device)
95
+ self.model.eval()
96
+
97
+ self.transform = torchvision.transforms.Compose([
98
+ torchvision.transforms.Resize((self.input_size, self.input_size)),
99
+ torchvision.transforms.ToTensor()
100
+ ])
101
+
102
+ @torch.no_grad()
103
+ def detect(self, image) -> Tuple[List, List]:
104
+ if isinstance(image, str): img = Image.open(image).convert("RGB")
105
+ elif isinstance(image, np.ndarray): img = Image.fromarray(image).convert("RGB")
106
+ else: img = image.convert("RGB")
107
+
108
+ image_tensor = self.transform(img).unsqueeze_(0)
109
+ inp = Variable(image_tensor).to(self.device, dtype=torch.float)
110
+
111
+ predict = self.model(inp)
112
+ predict_np = predict.data.cpu().numpy()
113
+ heatmap, offset_y, offset_x, width_map, height_map = predict_np[0]
114
+
115
+ bbox, score_list = [], []
116
+ Hc, Wc = img.size[1] / self.output_size, img.size[0] / self.output_size
117
+
118
+ # Config์—์„œ NMS ์ž„๊ณ„๊ฐ’ ๋กœ๋“œ
119
+ nms_cfg = self.config.get('nms_config', {})
120
+ nms_score = nms_cfg.get('primary_threshold', 0.12)
121
+
122
+ idxs = np.where(heatmap.reshape(-1, 1) >= nms_score)[0]
123
+ if len(idxs) == 0:
124
+ nms_score = nms_cfg.get('fallback_threshold', 0.08)
125
+ idxs = np.where(heatmap.reshape(-1, 1) >= nms_score)[0]
126
+
127
+ for j in idxs:
128
+ row = j // self.output_size
129
+ col = j - row * self.output_size
130
+ bias_x = offset_x[row, col] * Hc
131
+ bias_y = offset_y[row, col] * Wc
132
+ width = width_map[row, col] * self.output_size * Hc
133
+ height = height_map[row, col] * self.output_size * Wc
134
+
135
+ score_list.append(float(heatmap[row, col]))
136
+ row = row * Hc + bias_y
137
+ col = col * Wc + bias_x
138
+
139
+ top = row - width / 2.0
140
+ left = col - height / 2.0
141
+ bottom = row + width / 2.0
142
+ right = col + height / 2.0
143
+ bbox.append([left, top, max(0.0, right - left), max(0.0, bottom - top)])
144
+
145
+ if not bbox: return [], []
146
+
147
+ xyxy = [[x, y, x+w, y+h] for x, y, w, h in bbox]
148
+ keep = torchvision.ops.nms(
149
+ torch.tensor(xyxy, dtype=torch.float32),
150
+ scores=torch.tensor(score_list, dtype=torch.float32),
151
+ iou_threshold=nms_cfg.get('iou_threshold', 0.05)
152
+ ).cpu().numpy().tolist()
153
+
154
+ res_boxes, res_scores = [], []
155
+ W, H = img.size
156
+ for k in keep:
157
+ idx = int(k)
158
+ x, y, w, h = bbox[idx]
159
+ x = max(0.0, min(x, W - 1.0))
160
+ y = max(0.0, min(y, H - 1.0))
161
+ w = max(0.0, min(w, W - x))
162
+ h = max(0.0, min(h, H - y))
163
+ if w > 1 and h > 1:
164
+ res_boxes.append([x, y, w, h])
165
+ res_scores.append(score_list[idx])
166
+
167
+ return res_boxes, res_scores
168
+
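+ # Minimal usage sketch (illustrative; assumes the checkpoint and config paths exist):
+ #   detector = TextDetector(torch.device("cpu"), "weights/best.pth", load_ocr_config())
+ #   boxes, scores = detector.detect("page.jpg")  # boxes are [x, y, w, h] in image pixels
+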
169
+ # ================================================================================
170
+ # Merging Logics (config-driven)
171
+ # ================================================================================
172
+ def merge_vertical_fragments(boxes, scores, config):
173
+ if not boxes: return [], []
174
+ rects = [{'x': b[0], 'y': b[1], 'w': b[2], 'h': b[3],
175
+ 'x2': b[0]+b[2], 'y2': b[1]+b[3],
176
+ 'cx': b[0]+b[2]/2, 'cy': b[1]+b[3]/2, 'score': s}
177
+ for b, s in zip(boxes, scores)]
178
+
179
+ cfg = config['merge_config']['vertical_fragments']
180
+
181
+ while True:
182
+ rects.sort(key=lambda r: r['y'])
183
+ merged = False
184
+ new_rects, skip_indices = [], set()
185
+
186
+ for i in range(len(rects)):
187
+ if i in skip_indices: continue
188
+ current = rects[i]
189
+ best_cand_idx = -1
190
+
191
+ for j in range(i + 1, min(i + 5, len(rects))):
192
+ if j in skip_indices: continue
193
+ candidate = rects[j]
194
+
195
+ avg_w = (current['w'] + candidate['w']) / 2
196
+ if abs(current['cx'] - candidate['cx']) > avg_w * cfg['horizontal_center_ratio']: continue
197
+ if (candidate['y'] - current['y2']) > avg_w * cfg['vertical_gap_ratio']: continue
198
+
199
+ new_h = max(current['y2'], candidate['y2']) - min(current['y'], candidate['y'])
200
+ new_w = max(current['x2'], candidate['x2']) - min(current['x'], candidate['x'])
201
+
202
+ is_safe_ratio = (new_h / new_w) < cfg['aspect_ratio_limit']
203
+ cur_square = (current['h'] / current['w']) > 0.85
204
+ cand_square = (candidate['h'] / candidate['w']) > 0.85
205
+ is_overlapped = (candidate['y'] - current['y2']) < -avg_w * 0.2
206
+
207
+ if is_safe_ratio and (not (cur_square and cand_square) or is_overlapped):
208
+ best_cand_idx = j
209
+ break
210
+
211
+ if best_cand_idx != -1:
212
+ cand = rects[best_cand_idx]
213
+ nx, ny = min(current['x'], cand['x']), min(current['y'], cand['y'])
214
+ nx2, ny2 = max(current['x2'], cand['x2']), max(current['y2'], cand['y2'])
215
+ new_rects.append({
216
+ 'x': nx, 'y': ny, 'w': nx2-nx, 'h': ny2-ny,
217
+ 'x2': nx2, 'y2': ny2, 'cx': (nx+nx2)/2, 'cy': (ny+ny2)/2,
218
+ 'score': max(current['score'], cand['score'])
219
+ })
220
+ skip_indices.add(best_cand_idx)
221
+ merged = True
222
+ else:
223
+ new_rects.append(current)
224
+ rects = new_rects
225
+ if not merged: break
226
+
227
+ return [[r['x'], r['y'], r['w'], r['h']] for r in rects], [r['score'] for r in rects]
228
+
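+ # Termination note (added): each pass re-sorts by y and greedily merges each box with at most
+ # one neighbor; every merge strictly reduces the box count, so the while-loop eventually makes
+ # a pass with no merges and exits. merge_google_symbols below follows the same pattern.
+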
229
+ def merge_google_symbols(symbols, config):
230
+ if not symbols: return []
231
+ cfg = config['merge_config']['google_symbols']
232
+
233
+ while True:
234
+ symbols.sort(key=lambda s: s['min_y'])
235
+ merged = False
236
+ new_symbols, skip_indices = [], set()
237
+
238
+ for i in range(len(symbols)):
239
+ if i in skip_indices: continue
240
+ curr = symbols[i]
241
+ best_cand_idx = -1
242
+
243
+ for j in range(i + 1, min(i + 5, len(symbols))):
244
+ if j in skip_indices: continue
245
+ cand = symbols[j]
246
+
247
+ avg_w = (curr['width'] + cand['width']) / 2
248
+ if abs(curr['center_x'] - cand['center_x']) > avg_w * cfg['horizontal_center_ratio']: continue
249
+
250
+ gap = cand['min_y'] - curr['max_y']
251
+ is_touching = gap < (avg_w * cfg['vertical_gap_ratio'])
252
+
253
+ new_h = max(curr['max_y'], cand['max_y']) - min(curr['min_y'], cand['min_y'])
254
+ new_w = max(curr['max_x'], cand['max_x']) - min(curr['min_x'], cand['min_x'])
255
+
256
+ is_both_square = (curr['height']/curr['width'] > 0.85) and (cand['height']/cand['width'] > 0.85)
257
+ is_safe_ratio = (new_h / new_w) < cfg['aspect_ratio_limit']
258
+ is_duplicate = (curr['text'] == cand['text'])
259
+
260
+ if (is_touching and is_safe_ratio and not is_both_square) or is_duplicate:
261
+ best_cand_idx = j
262
+ break
263
+
264
+ if best_cand_idx != -1:
265
+ cand = symbols[best_cand_idx]
266
+ merged_sym = {
267
+ 'text': curr['text'],
268
+ 'min_x': min(curr['min_x'], cand['min_x']), 'min_y': min(curr['min_y'], cand['min_y']),
269
+ 'max_x': max(curr['max_x'], cand['max_x']), 'max_y': max(curr['max_y'], cand['max_y']),
270
+ 'confidence': max(curr['confidence'], cand['confidence']),
271
+ 'source': 'Google'
272
+ }
273
+ merged_sym['width'] = merged_sym['max_x'] - merged_sym['min_x']
274
+ merged_sym['height'] = merged_sym['max_y'] - merged_sym['min_y']
275
+ merged_sym['center_x'] = (merged_sym['min_x'] + merged_sym['max_x']) / 2
276
+ merged_sym['center_y'] = (merged_sym['min_y'] + merged_sym['max_y']) / 2
277
+ new_symbols.append(merged_sym)
278
+ skip_indices.add(best_cand_idx)
279
+ merged = True
280
+ else:
281
+ new_symbols.append(curr)
282
+ symbols = new_symbols
283
+ if not merged: break
284
+ return symbols
285
+
286
+ # ================================================================================
287
+ # Models Execution
288
+ # ================================================================================
289
+ def get_google_ocr(content: bytes, config: Dict, google_json_path: Optional[str] = None) -> List[Dict]:
290
+ if not HAS_GOOGLE_VISION: return []
291
+ if google_json_path and os.path.exists(google_json_path):
292
+ os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = google_json_path
293
+
294
+ try:
295
+ client = vision.ImageAnnotatorClient()
296
+ image = vision.Image(content=content)
297
+ context = vision.ImageContext(language_hints=["zh-Hant"])
298
+ response = client.document_text_detection(image=image, image_context=context)
299
+
300
+ if not response.full_text_annotation: return []
301
+
302
+ symbols = []
303
+ for page in response.full_text_annotation.pages:
304
+ for block in page.blocks:
305
+ for paragraph in block.paragraphs:
306
+ for word in paragraph.words:
307
+ for s in word.symbols:
308
+ if not is_hanja(s.text): continue
309
+ v = s.bounding_box.vertices
310
+ x, y = [p.x for p in v], [p.y for p in v]
311
+ symbols.append({
312
+ 'text': s.text,
313
+ 'center_x': (min(x)+max(x))/2, 'center_y': (min(y)+max(y))/2,
314
+ 'min_x': min(x), 'max_x': max(x), 'min_y': min(y), 'max_y': max(y),
315
+ 'width': max(x)-min(x), 'height': max(y)-min(y),
316
+ 'confidence': s.confidence, 'source': 'Google'
317
+ })
318
+
319
+ original_count = len(symbols)
320
+ symbols = merge_google_symbols(symbols, config)
321
+ if len(symbols) < original_count:
322
+ logger.info(f"[OCR] Google ๋ณ‘ํ•ฉ: {original_count} -> {len(symbols)}๊ฐœ")
323
+ return symbols
324
+ except Exception as e:
325
+ logger.error(f"[OCR] Google Vision Error: {e}")
326
+ return []
327
+
328
+ def get_custom_model_ocr(image_path, binary_img, detector, recognizer, config):
329
+ try:
330
+ pil_img = Image.open(image_path).convert("RGB")
331
+ boxes, scores = detector.detect(pil_img)
332
+ if not boxes: return []
333
+
334
+ # Merge
335
+ original_count = len(boxes)
336
+ boxes, scores = merge_vertical_fragments(boxes, scores, config)
337
+ if len(boxes) < original_count:
338
+ logger.info(f"[OCR] Custom ๋ณ‘ํ•ฉ: {original_count} -> {len(boxes)}๊ฐœ")
339
+
340
+ # Stats
341
+ all_heights = [b[3] for b in boxes]
342
+ all_widths = [b[2] for b in boxes]
343
+ median_h = np.median(all_heights) if all_heights else 0
344
+ median_w = np.median(all_widths) if all_widths else 0
345
+
346
+ # Recognize
347
+ crops = [pil_img.crop((int(b[0]), int(b[1]), int(b[0]+b[2]), int(b[1]+b[3]))) for b in boxes]
348
+ chars = recognizer(crops) if crops else []
349
+
350
+ # Filter & Mask (Config values)
351
+ symbols = []
352
+ img_h, _ = binary_img.shape
353
+ ft = config['filtering_thresholds']
354
+ it = config['ink_detection_thresholds']
355
+
356
+ for char, (x, y, w, h), score in zip(chars, boxes, scores):
357
+ if not char or char == "โ– ": continue
358
+
359
+ box_dict = {'min_x': x, 'min_y': y, 'max_x': x+w, 'max_y': y+h}
360
+ density = calculate_pixel_density(binary_img, box_dict)
361
+
362
+ # Hard Filters
363
+ if score < ft['min_score_hard'] or density < ft['density_min_hard']: continue
364
+ # Smart Filters
365
+ if score < ft['smart_score_threshold'] and density < ft['smart_density_threshold']: continue
366
+
367
+ # Title Removal
368
+ is_huge = (h > median_h * 3.5) if median_h > 0 else False
369
+ is_top = (y < img_h * 0.15) and (h > median_h * 2.5 or w > median_w * 2.5) if median_h > 0 else False
370
+ if median_h > 0 and (is_huge or is_top): continue
371
+
372
+ # Masking
373
+ final_text, final_type = char, 'TEXT'
374
+ if density >= it['density_ink_heavy']:
375
+ final_text, final_type = '[MASK1]', 'MASK1'
376
+ elif density >= it['density_ink_partial']:
377
+ final_text, final_type = '[MASK2]', 'MASK2'
378
+ else:
379
+ if not is_hanja(char): continue
380
+
381
+ symbols.append({
382
+ 'text': final_text, 'type': final_type,
383
+ 'center_x': x+w/2, 'center_y': y+h/2,
384
+ 'min_x': x, 'max_x': x+w, 'min_y': y, 'max_y': y+h,
385
+ 'width': w, 'height': h,
386
+ 'confidence': float(score), 'source': 'Custom', 'density': density
387
+ })
388
+
389
+ logger.info(f"[OCR] Custom Model ์™„๋ฃŒ: {len(symbols)}๊ฐœ")
390
+ return symbols
391
+ except Exception as e:
392
+ logger.error(f"[OCR] Custom Model Error: {e}")
393
+ return []
394
+
395
+ # ================================================================================
396
+ # Ensemble Reconstruction (Full Logic from Script)
397
+ # ================================================================================
398
+ def ensemble_reconstruction(google_syms, custom_syms, binary_img, config):
399
+ logger.info("[ENSEMBLE] ์•™์ƒ๋ธ” ์žฌ๊ตฌ์„ฑ ์‹œ์ž‘...")
400
+ img_h, img_w = binary_img.shape
401
+ ec = config['ensemble_config']
402
+ ft = config['filtering_thresholds']
403
+ it = config['ink_detection_thresholds']
404
+
405
+ # --- Helper Functions ---
406
+ def filter_excessive_masks(nodes):
407
+ filtered, buffer = [], []
408
+ threshold = ec['excessive_mask_threshold']
409
+ for node in nodes:
410
+ if 'MASK' in node.get('type', 'TEXT'): buffer.append(node)
411
+ else:
412
+ if buffer:
413
+ if len(buffer) < threshold: filtered.extend(buffer)
414
+ buffer = []
415
+ filtered.append(node)
416
+ if buffer and len(buffer) < threshold: filtered.extend(buffer)
417
+ return filtered
418
+
419
+ def merge_split_masks(nodes, avg_h):
420
+ if not nodes: return []
421
+ merged, skip = [], False
422
+ for i in range(len(nodes)):
423
+ if skip: skip = False; continue
424
+ curr = nodes[i]
425
+ if i == len(nodes)-1: merged.append(curr); break
426
+
427
+ next_node = nodes[i+1]
428
+ if 'MASK' in curr.get('type','TEXT') and 'MASK' in next_node.get('type','TEXT'):
429
+ combined_h = next_node['max_y'] - curr['min_y']
430
+ if combined_h < avg_h * 1.8:
431
+ new_node = curr.copy()
432
+ new_node.update({'max_y': next_node['max_y'], 'height': next_node['max_y'] - curr['min_y']})
433
+ density = calculate_pixel_density(binary_img, new_node)
434
+ new_node['density'] = density
435
+
436
+ if density < ft['density_min_hard']:
437
+ skip = True; continue
438
+
439
+ m_type = 'MASK1' if density >= it['density_ink_heavy'] else 'MASK2'
440
+ new_node.update({'type': m_type, 'text': f'[{m_type}]'})
441
+ merged.append(new_node)
442
+ skip = True
443
+ continue
444
+ merged.append(curr)
445
+ return merged
446
+
447
+ def resolve_overlaps(boxes):
448
+ if not boxes: return []
449
+ boxes.sort(key=lambda x: x['min_y'])
450
+ for i in range(len(boxes)-1):
451
+ curr, next_box = boxes[i], boxes[i+1]
452
+ if min(curr['max_x'], next_box['max_x']) - max(curr['min_x'], next_box['min_x']) <= 0: continue
453
+
454
+ if curr['max_y'] > next_box['min_y']:
455
+ mid_y = (curr['max_y'] + next_box['min_y']) / 2
456
+ curr['max_y'], curr['height'] = mid_y, mid_y - curr['min_y']
457
+ next_box['min_y'], next_box['height'] = mid_y, next_box['max_y'] - mid_y
458
+ return boxes
459
+
460
+ def filter_google_overlaps(g_boxes, c_boxes):
461
+ if not g_boxes: return c_boxes
462
+ filtered = []
463
+ for c in c_boxes:
464
+ is_dup = False
465
+ for g in g_boxes:
466
+ dx = abs(c['center_x'] - g['center_x'])
467
+ dy = abs(c['center_y'] - g['center_y'])
468
+ # MASK is preserved even if overlapping
469
+ if 'MASK' in c.get('type', 'TEXT'): pass
470
+ elif (min(c['max_x'], g['max_x']) > max(c['min_x'], g['min_x']) and
471
+ min(c['max_y'], g['max_y']) > max(c['min_y'], g['min_y'])) or \
472
+ (dx < g['width']*0.4 and dy < g['height']*0.4):
473
+ is_dup = True; break
474
+ if not is_dup: filtered.append(c)
475
+ return filtered
476
+
477
+ def infer_gaps(col, step_y, avg_w):
478
+ if not col: return []
479
+ col.sort(key=lambda s: s['center_y'])
480
+ filled = []
481
+ for i, curr in enumerate(col):
482
+ if i > 0:
483
+ prev = col[i-1]
484
+ gap = curr['center_y'] - prev['center_y']
485
+ if gap > step_y * ec['gap_inference_ratio']:
486
+ missing = int(round(gap/step_y)) - 1
487
+ if missing > 0:
488
+ step = gap / (missing + 1)
489
+ for k in range(1, missing + 1):
490
+ ny = prev['center_y'] + k*step
491
+ nb = {'min_x': curr['center_x'] - avg_w/2, 'max_x': curr['center_x'] + avg_w/2,
492
+ 'min_y': max(0, ny - step_y*0.4), 'max_y': min(img_h, ny + step_y*0.4)}
493
+ nb.update({'height': nb['max_y']-nb['min_y'], 'width': nb['max_x']-nb['min_x'],
494
+ 'center_x': (nb['min_x']+nb['max_x'])/2, 'center_y': (nb['min_y']+nb['max_y'])/2})
495
+
496
+ d = calculate_pixel_density(binary_img, nb)
497
+ if d < ft['density_min_hard']: continue
498
+
499
+ mt = 'MASK1' if d >= it['density_ink_heavy'] else 'MASK2'
500
+ nb.update({'text': f'[{mt}]', 'type': mt, 'density': d, 'confidence': 0.0, 'source': 'Inferred'})
501
+ filled.append(nb)
502
+ filled.append(curr)
503
+ return filled
504
+
505
+ def check_ink_on_google(g_syms):
506
+ filtered = []
507
+ for s in g_syms:
508
+ d = calculate_pixel_density(binary_img, s)
509
+ s['density'] = d
510
+ if d >= it['density_ink_heavy']: s.update({'type': 'MASK1', 'text': '[MASK1]'})
511
+ elif d >= it['density_ink_partial']: s.update({'type': 'MASK2', 'text': '[MASK2]'})
512
+ elif d < ft['density_min_hard']: continue # Hallucination check
513
+ else: s['type'] = 'TEXT'
514
+ filtered.append(s)
515
+ return filtered
516
+
517
+ # --- Preprocessing ---
518
+ all_h = ([s['height'] for s in google_syms] + [s['height'] for s in custom_syms])
519
+ median_h = np.median(all_h) if all_h else 30.0
520
+
521
+ # Filter Height & Check Ink
522
+ def global_remove_tall_and_top(boxes, median_h, threshold=2.0):
523
+ if not boxes: return []
524
+ filtered = []
525
+ for b in boxes:
526
+ if b['height'] > median_h * threshold: continue
527
+ if b['min_y'] < img_h * 0.15 and b['height'] > median_h * 2.5: continue
528
+ filtered.append(b)
529
+ return filtered
530
+
531
+ if google_syms:
532
+ google_syms = global_remove_tall_and_top(google_syms, median_h, threshold=2.0)
533
+ google_syms = check_ink_on_google(google_syms)
534
+ if custom_syms:
535
+ custom_syms = global_remove_tall_and_top(custom_syms, median_h, threshold=3.5)
536
+
537
+ # Resize & Filter Custom
538
+ avg_w = np.mean([s['width'] for s in google_syms]) if google_syms else 0
539
+ median_w = np.median([s['width'] for s in google_syms]) if google_syms else 0
540
+
541
+ processed_custom = []
542
+ for s in custom_syms:
543
+ if 'MASK' in s.get('type', 'TEXT'):
544
+ processed_custom.append(s); continue
545
+
546
+ if (s['width']*s['height'] > (median_w*median_h)*0.2 and
547
+ s['width'] > median_w*0.3 and s['height'] > median_h*0.3):
548
+
549
+ # Resize logic
550
+ if s['width'] < median_w*0.8 or s['height'] < median_h*0.8:
551
+ tw = max(s['width'], median_w*0.9)
552
+ th = max(s['height'], median_h*0.9)
553
+ cx, cy = s['center_x'], s['center_y']
554
+ s.update({'min_x': max(0, cx-tw/2), 'max_x': min(img_w, cx+tw/2),
555
+ 'min_y': max(0, cy-th/2), 'max_y': min(img_h, cy+th/2)})
556
+ s.update({'width': s['max_x']-s['min_x'], 'height': s['max_y']-s['min_y']})
557
+ processed_custom.append(s)
558
+
559
+ custom_syms = filter_google_overlaps(google_syms, processed_custom)
560
+
561
+ if not google_syms and not custom_syms: return [], []
562
+
563
+ # --- Column Grouping ---
564
+ all_syms = google_syms + custom_syms
565
+ columns = []
566
+ if all_syms:
567
+ for s in sorted(all_syms, key=lambda x: -x['center_x']):
568
+ found = False
569
+ for col in columns:
570
+ cx = sum(c['center_x'] for c in col) / len(col)
571
+ if abs(s['center_x'] - cx) < (avg_w if avg_w else s['width']) * ec['column_grouping_ratio']:
572
+ col.append(s); found = True; break
573
+ if not found: columns.append([s])
574
+
575
+ # Vertical Step Calculation
576
+ global_steps = []
577
+ for col in columns:
578
+ col.sort(key=lambda s: s['center_y'])
579
+ for k in range(len(col)-1):
580
+ step = col[k+1]['center_y'] - col[k]['center_y']
581
+ if median_h * 0.8 < step < median_h * 1.5: global_steps.append(step)
582
+ global_step = np.median(global_steps) if global_steps else median_h * 1.1
583
+
584
+ # --- Reconstruction ---
585
+ final_boxes, lines = [], []
586
+ for col in columns:
587
+ col.sort(key=lambda s: s['center_y'])
588
+ local_steps = [col[k+1]['center_y'] - col[k]['center_y'] for k in range(len(col)-1)
589
+ if median_h*0.8 < (col[k+1]['center_y'] - col[k]['center_y']) < median_h*1.5]
590
+ step_y = np.median(local_steps) if local_steps else global_step
591
+
592
+ # Deduplication in column
593
+ unique_col = []
594
+ if col:
595
+ prev = col[0]
596
+ unique_col.append(prev)
597
+ for k in range(1, len(col)):
598
+ curr = col[k]
599
+ dist_y = abs(curr['center_y'] - prev['center_y'])
600
+ is_same_text = (curr.get('text') == prev.get('text'))
601
+ is_close = (dist_y < median_h * 0.6)
602
+
603
+ if is_close:
604
+ prev_is_mask = 'MASK' in prev.get('type', 'TEXT')
605
+ curr_is_mask = 'MASK' in curr.get('type', 'TEXT')
606
+
607
+ if prev_is_mask and curr_is_mask:
608
+ if prev['density'] < curr['density']:
609
+ unique_col.pop()
610
+ unique_col.append(curr)
611
+ prev = curr
612
+ continue
613
+ elif prev_is_mask and not curr_is_mask:
614
+ continue
615
+ elif not prev_is_mask and curr_is_mask:
616
+ unique_col.pop()
617
+ unique_col.append(curr)
618
+ prev = curr
619
+ continue
620
+
621
+ if is_same_text and is_close:
622
+ if prev.get('source') == 'Google':
623
+ continue
624
+ elif curr.get('source') == 'Google':
625
+ unique_col.pop()
626
+ unique_col.append(curr)
627
+ prev = curr
628
+ else:
629
+ continue
630
+ else:
631
+ unique_col.append(curr)
632
+ prev = curr
633
+
634
+ col = infer_gaps(unique_col, step_y, avg_w if avg_w else median_h)
635
+
636
+ # Gap Filling with Masks
637
+ filled_col, cy = [], col[0]['min_y'] if col else 0
638
+ for item in col:
639
+ gap = item['min_y'] - cy
640
+ if gap > step_y * 1.2:
641
+ mb = {'min_x': item['center_x'] - (avg_w if avg_w else median_h)/2,
642
+ 'max_x': item['center_x'] + (avg_w if avg_w else median_h)/2,
643
+ 'min_y': max(0, cy + gap*0.1), 'max_y': min(img_h, item['min_y'] - gap*0.1)}
644
+ d = calculate_pixel_density(binary_img, mb)
645
+ if d >= ft['density_min_hard']:
646
+ mt = 'MASK1' if d >= it['density_ink_heavy'] else 'MASK2'
647
+ if d >= it['density_ink_partial']:
648
+ filled_col.append({'text': f'[{mt}]', 'type': mt, 'density': d,
649
+ 'min_x': mb['min_x'], 'max_x': mb['max_x'],
650
+ 'min_y': mb['min_y'], 'max_y': mb['max_y'],
651
+ 'confidence': 0.0, 'source': 'GapFill'})
652
+
653
+ if item.get('density', 0) < ft['density_min_hard'] and 'MASK' not in item.get('type','TEXT'):
654
+ cy = item['max_y']; continue
655
+
656
+ filled_col.append(item)
657
+ cy = item['max_y']
658
+
659
+ filled_col = merge_split_masks(filled_col, median_h)
660
+ filled_col = filter_excessive_masks(filled_col)
661
+ filled_col = resolve_overlaps(filled_col)
662
+
663
+ final_boxes.extend(filled_col)
664
+ lines.append("".join([s['text'] for s in filled_col]))
665
+
666
+ logger.info(f"[ENSEMBLE] ์™„๋ฃŒ: {len(final_boxes)}๊ฐœ ๋ฐ•์Šค, {len(lines)}๊ฐœ ์—ด")
667
+ return final_boxes, lines
668
+
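+ # Pipeline summary (added): ensemble_reconstruction takes the merged Google and custom-model
+ # symbols, drops oversized/low-ink boxes, groups the rest into vertical columns by X center
+ # (ordered right to left), infers missing characters from the per-column step size, and fills
+ # damaged spots with [MASK1]/[MASK2] boxes. It returns (final_boxes, lines), where each entry
+ # of `lines` is one column read top to bottom, e.g. "ๅคฉๅœฐ[MASK2]้ปƒ" (illustrative).
+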
669
+ # ================================================================================
670
+ # OCREngine Class
671
+ # ================================================================================
672
+ class OCREngine:
673
+ def __init__(self, config_path: Optional[str] = None):
674
+ self.config = load_ocr_config(config_path)
675
+
676
+ # Load paths from env
677
+ base_path = os.getenv('OCR_WEIGHTS_BASE_PATH')
678
+ if not base_path:
679
+ raise ValueError("OCR_WEIGHTS_BASE_PATH environment variable is required. Please set it in your .env file.")
680
+
681
+ self.det_ckpt = os.path.join(base_path, os.getenv('OCR_DETECTION_MODEL', 'best.pth'))
682
+ self.rec_ckpt = os.path.join(base_path, os.getenv('OCR_RECOGNITION_MODEL', 'best_5000.pt'))
683
+ self.google_json = os.path.join(base_path, os.getenv('GOOGLE_CREDENTIALS_JSON'))
684
+
685
+ if not self.google_json or not os.path.exists(self.google_json):
686
+ raise ValueError(f"GOOGLE_CREDENTIALS_JSON environment variable is required and file must exist. Please set it in your .env file.")
687
+
688
+ if os.path.exists(self.google_json):
689
+ os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = self.google_json
690
+
691
+ # Device
692
+ dev_cfg = self.config['model_config']['device']
693
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if dev_cfg == 'auto' else torch.device(dev_cfg)
694
+ self.detector = None
695
+ self.recognizer = None
696
+
697
+ def _load_models(self):
698
+ if not self.detector:
699
+ self.detector = TextDetector(self.device, self.det_ckpt, self.config)
700
+ if not self.recognizer:
701
+ self.recognizer = ResnetCustom(weight_fn=self.rec_ckpt)
702
+ self.recognizer.to(self.device)
703
+
704
+ def run_ocr(self, image_path: str) -> Dict:
705
+ try:
706
+ self._load_models()
707
+
708
+ # 1. Preprocessing (Exact Match to v12 Script)
709
+ img_bgr = cv2.imread(image_path)
710
+ if img_bgr is None: raise ValueError(f"Image not found: {image_path}")
711
+
712
+ img_gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
713
+ img_blur = cv2.medianBlur(img_gray, 3)
714
+ _, img_binary = cv2.threshold(img_blur, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
715
+ kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
716
+ img_binary = cv2.morphologyEx(img_binary, cv2.MORPH_CLOSE, kernel)
717
+
718
+ # 2. Google Vision
719
+ with io.open(image_path, 'rb') as f: content = f.read()
720
+ google_syms = get_google_ocr(content, self.config, self.google_json)
721
+
722
+ # 3. Custom Model
723
+ custom_syms = get_custom_model_ocr(image_path, img_binary, self.detector, self.recognizer, self.config)
724
+
725
+ # 4. Ensemble
726
+ final_boxes, result_lines = ensemble_reconstruction(google_syms, custom_syms, img_binary, self.config)
727
+
728
+ # Format results according to specification
729
+ formatted_results = []
730
+ for order, box in enumerate(final_boxes):
731
+ formatted_results.append({
732
+ "order": order,
733
+ "text": box.get('text', ''),
734
+ "type": box.get('type', 'TEXT'),
735
+ "box": [
736
+ float(box.get('min_x', 0)),
737
+ float(box.get('min_y', 0)),
738
+ float(box.get('max_x', 0)),
739
+ float(box.get('max_y', 0))
740
+ ],
741
+ "confidence": float(box.get('confidence', 0.0)),
742
+ "source": box.get('source', 'Unknown')
743
+ })
744
+
745
+ # Extract image filename
746
+ image_filename = os.path.basename(image_path)
747
+
748
+ return {
749
+ "success": True,  # keep the flag consistent with the failure branch below
+ "image": image_filename,
750
+ "results": formatted_results
751
+ }
752
+ except Exception as e:
753
+ logger.error(f"[OCR] Execution Failed: {e}", exc_info=True)
754
+ return {"success": False, "error": str(e)}
755
+
756
+ # ================================================================================
757
+ # Global Accessor
758
+ # ================================================================================
759
+ _engine = None
760
+
761
+ def get_ocr_engine(config_path: Optional[str] = None) -> OCREngine:
762
+ global _engine
763
+ if _engine is None: _engine = OCREngine(config_path)
764
+ return _engine
765
+
766
+ def ocr_and_detect(image_path: str, config_path: Optional[str] = None, bbox: Optional[Tuple[int, int, int, int]] = None, device: str = "cuda") -> Dict:
767
+ return get_ocr_engine(config_path).run_ocr(image_path)  # bbox/device are accepted for compatibility but currently unused
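+
+ # Minimal usage sketch (illustrative): requires OCR_WEIGHTS_BASE_PATH and
+ # GOOGLE_CREDENTIALS_JSON to be set, e.g. via a .env file:
+ #   result = ocr_and_detect("rubbing.jpg")
+ #   for r in result.get("results", []):
+ #       print(r["order"], r["text"], r["box"], r["source"])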
ai_modules/preprocessor_unified.py ADDED
@@ -0,0 +1,605 @@
1
+ # Epitext_Back/ai_modules/preprocessor_unified.py
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ ================================================================================
5
+ Unified Image Preprocessing Module for Epitext AI Project
6
+ ================================================================================
7
+
8
+ ๋ชจ๋“ˆ๋ช…: preprocessor_unified.py (v1.0.0 - Production Ready)
9
+ ์ž‘์„ฑ์ผ: 2025-12-02
10
+ ๋ชฉ์ : ํ•œ์ž ์ด๋ฏธ์ง€๋ฅผ Swin Gray์™€ OCR์šฉ์œผ๋กœ ๋™์‹œ์— ์ „์ฒ˜๋ฆฌ
11
+ ์ƒํƒœ: Production Ready
12
+
13
+ ํ•ต์‹ฌ ๊ธฐ๋Šฅ:
14
+ ํ•œ ๋ฒˆ์— ๋‘ ๊ฐ€์ง€ ์ „์ฒ˜๋ฆฌ ์™„๋ฃŒ:
15
+ 1. Swin Gray: ๊ทธ๋ ˆ์ด ๋น„์ด์ง„ํ™” -> 3์ฑ„๋„ (์ •๋ณด ์†์‹ค ์ตœ์†Œ)
16
+ 2. OCR: ์ด์ง„ํ™” -> 1์ฑ„๋„ (๋ช…ํ™•ํ•œ ํ‘๋ฐฑ)
17
+
18
+ ์ž๋™ ๋ฐฐ๊ฒฝ ๋ณด์žฅ:
19
+ - Swin: ๋ฐ์€๋ฐฐ๊ฒฝ (>=127)
20
+ - OCR: ํฐ๋ฐฐ๊ฒฝ + ๊ฒ€์ •๊ธ€์ž (255/0)
21
+
22
+ ํƒ๋ณธ ์ž๋™ ๊ฒ€์ถœ: ํฐ ์–ด๋‘์šด ์˜์—ญ ์‹๋ณ„
23
+ ์˜์—ญ ๊ฒ€์ถœ 1ํšŒ: ํšจ์œจ์„ฑ
24
+ ์„ค์ • ํŒŒ์ผ ์ง€์›: JSON ๊ธฐ๋ฐ˜ ์ปค์Šคํ„ฐ๋งˆ์ด์ง•
25
+ ๋กœ๊น… ์ง€์›: DEBUG, INFO, WARNING, ERROR
26
+
27
+ ์˜์กด์„ฑ:
28
+ - opencv-python >= 4.8.0
29
+ - numpy >= 1.24.0
30
+
31
+ ๋‹จ์ผ ํ•จ์ˆ˜:
32
+ preprocess_image_unified(input_path, output_swin_path, output_ocr_path, ...)
33
+
34
+ ์‚ฌ์šฉ ์˜ˆ์‹œ:
35
+ >>> from ai_modules.preprocessor_unified import preprocess_image_unified
36
+ >>> result = preprocess_image_unified(
37
+ ... "input.jpg",
38
+ ... "swin.jpg",
39
+ ... "ocr.png"
40
+ ... )
41
+
42
+ ================================================================================
43
+ """
44
+
45
+
46
+ import cv2
47
+ import numpy as np
48
+ from pathlib import Path
49
+ import json
50
+ import logging
51
+ from typing import Dict, Optional, Tuple
52
+
53
+
54
+ # ================================================================================
55
+ # Logging Configuration
56
+ # ================================================================================
57
+
58
+
59
+ logging.basicConfig(
60
+ level=logging.INFO,
61
+ format='%(asctime)s - [%(levelname)s] %(message)s'
62
+ )
63
+ logger = logging.getLogger(__name__)
64
+
65
+
66
+ # ================================================================================
67
+ # Constants
68
+ # ================================================================================
69
+
70
+
71
+ # Default settings
72
+ DEFAULT_MARGIN = 10
73
+ DEFAULT_BRIGHTNESS_THRESHOLD = 127
74
+ DEFAULT_RUBBING_MIN_AREA_RATIO = 0.1
75
+ DEFAULT_TEXT_MIN_AREA = 16
76
+ DEFAULT_TEXT_AREA_RATIO = 0.00005
77
+ DEFAULT_MORPHOLOGY_KERNEL_SIZE = (2, 2)
78
+ DEFAULT_MORPHOLOGY_CLOSE_ITERATIONS = 3
79
+ DEFAULT_MORPHOLOGY_OPEN_ITERATIONS = 2
80
+ DEFAULT_RUBBING_KERNEL_SIZE = (5, 5)
81
+ DEFAULT_RUBBING_CLOSE_ITERATIONS = 10
82
+ DEFAULT_RUBBING_OPEN_ITERATIONS = 5
83
+
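+ # Illustrative preprocess_config.json sketch: keys mirror the defaults above, the values shown
+ # are the in-code defaults (JSON arrays load fine where tuples are expected), and any key
+ # starting with '_' (e.g. "_description") is ignored on load:
+ # {
+ #   "margin": 10,
+ #   "brightness_threshold": 127,
+ #   "rubbing_min_area_ratio": 0.1,
+ #   "text_min_area": 16,
+ #   "text_area_ratio": 0.00005,
+ #   "morphology_kernel_size": [2, 2],
+ #   "morphology_close_iterations": 3,
+ #   "morphology_open_iterations": 2,
+ #   "rubbing_kernel_size": [5, 5],
+ #   "rubbing_close_iterations": 10,
+ #   "rubbing_open_iterations": 5
+ # }
+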
84
+
85
+ # ================================================================================
86
+ # Main Preprocessing Class
87
+ # ================================================================================
88
+
89
+
90
+ class UnifiedImagePreprocessor:
91
+ """
92
+ Unified image preprocessing class (Swin + OCR).
93
+
94
+ Produces both the Swin Gray and the OCR image in a single pass.
95
+
96
+ Attributes:
97
+ config (dict): preprocessing configuration parameters
98
+
99
+ Example:
100
+ >>> prep = UnifiedImagePreprocessor()
101
+ >>> result = prep.preprocess_unified("input.jpg", "swin.jpg", "ocr.png")
102
+ """
103
+
104
+ def __init__(self, config_path: Optional[str] = None) -> None:
105
+ """
106
+ Initialize UnifiedImagePreprocessor.
107
+
108
+ Args:
109
+ config_path (str, optional): path to a JSON config file
110
+ """
111
+ self.config = self._load_config(config_path)
112
+ logger.info("[INIT] UnifiedImagePreprocessor v1.0.0 ์ดˆ๊ธฐํ™” ์™„๋ฃŒ")
113
+
114
+ def _load_config(self, config_path: Optional[str]) -> Dict:
115
+ """์„ค์ • ํŒŒ์ผ ๋กœ๋“œ"""
116
+ default_config = {
117
+ "margin": DEFAULT_MARGIN,
118
+ "brightness_threshold": DEFAULT_BRIGHTNESS_THRESHOLD,
119
+ "rubbing_min_area_ratio": DEFAULT_RUBBING_MIN_AREA_RATIO,
120
+ "text_min_area": DEFAULT_TEXT_MIN_AREA,
121
+ "text_area_ratio": DEFAULT_TEXT_AREA_RATIO,
122
+ "morphology_kernel_size": DEFAULT_MORPHOLOGY_KERNEL_SIZE,
123
+ "morphology_close_iterations": DEFAULT_MORPHOLOGY_CLOSE_ITERATIONS,
124
+ "morphology_open_iterations": DEFAULT_MORPHOLOGY_OPEN_ITERATIONS,
125
+ "rubbing_kernel_size": DEFAULT_RUBBING_KERNEL_SIZE,
126
+ "rubbing_close_iterations": DEFAULT_RUBBING_CLOSE_ITERATIONS,
127
+ "rubbing_open_iterations": DEFAULT_RUBBING_OPEN_ITERATIONS,
128
+ }
129
+
130
+ # Default config file path (used when config_path is not given)
131
+ if config_path is None:
132
+ default_config_path = Path(__file__).parent / "config" / "preprocess_config.json"
133
+ if default_config_path.exists():
134
+ config_path = str(default_config_path)
135
+
136
+ if config_path and Path(config_path).exists():
137
+ try:
138
+ with open(config_path, 'r', encoding='utf-8') as f:
139
+ user_config = json.load(f)
140
+ # Update with the user config, excluding '_description'-style metadata fields
141
+ user_config_clean = {k: v for k, v in user_config.items() if not k.startswith('_')}
142
+ default_config.update(user_config_clean)
143
+ logger.info(f"[CONFIG] ์„ค์ • ํŒŒ์ผ ๋กœ๋“œ: {config_path}")
144
+ except Exception as e:
145
+ logger.warning(f"[CONFIG] ์„ค์ • ํŒŒ์ผ ๋กœ๋“œ ์‹คํŒจ: {e} - ๊ธฐ๋ณธ ์„ค์ • ์‚ฌ์šฉ")
146
+
147
+ return default_config
148
+
149
+ def _find_rubbing_bbox(self, gray_image: np.ndarray) -> Optional[Tuple[int, int, int, int]]:
150
+ """
151
+ ํƒ๋ณธ ์˜์—ญ ๊ฒ€์ถœ (ํฐ ์–ด๋‘์šด ์‚ฌ๊ฐํ˜• ์ฐพ๊ธฐ)
152
+
153
+ Args:
154
+ gray_image (np.ndarray): ๊ทธ๋ ˆ์ด์Šค์ผ€์ผ ์ด๋ฏธ์ง€
155
+
156
+ Returns:
157
+ tuple: (x, y, w, h) ๋˜๋Š” None
158
+ """
159
+ H_img, W_img = gray_image.shape
160
+
161
+ # Step 1: extract dark regions
162
+ _, dark_mask = cv2.threshold(gray_image, 127, 255, cv2.THRESH_BINARY_INV)
163
+
164
+ # Step 2: ๋ชจํด๋กœ์ง€ ์—ฐ์‚ฐ
165
+ kernel_rub = np.ones(self.config["rubbing_kernel_size"], np.uint8)
166
+ dark_mask = cv2.morphologyEx(
167
+ dark_mask, cv2.MORPH_CLOSE, kernel_rub,
168
+ iterations=self.config["rubbing_close_iterations"]
169
+ )
170
+ dark_mask = cv2.morphologyEx(
171
+ dark_mask, cv2.MORPH_OPEN, kernel_rub,
172
+ iterations=self.config["rubbing_open_iterations"]
173
+ )
174
+
175
+ # Step 3: contour detection
176
+ contours, _ = cv2.findContours(dark_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
177
+
178
+ if not contours:
179
+ return None
180
+
181
+ # Step 4: take the largest contour
182
+ largest = max(contours, key=cv2.contourArea)
183
+ area = cv2.contourArea(largest)
184
+
185
+ # Step 5: area validation
186
+ min_area = (H_img * W_img) * self.config["rubbing_min_area_ratio"]
187
+ if area < min_area:
188
+ return None
189
+
190
+ return cv2.boundingRect(largest)
191
+
192
+ def _find_text_bbox(self, gray_image: np.ndarray) -> Tuple[int, int, int, int]:
193
+ """
194
+ ํ…์ŠคํŠธ ์˜์—ญ ๊ฒ€์ถœ
195
+
196
+ Args:
197
+ gray_image (np.ndarray): ๊ทธ๋ ˆ์ด์Šค์ผ€์ผ ์ด๋ฏธ์ง€
198
+
199
+ Returns:
200
+ tuple: (x, y, w, h)
201
+ """
202
+ H_img, W_img = gray_image.shape
203
+
204
+ # Step 1: Otsu binarization
205
+ _, binary = cv2.threshold(
206
+ gray_image, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU
207
+ )
208
+
209
+ # Step 2: ๋ชจํด๋กœ์ง€ ์—ฐ์‚ฐ
210
+ kernel_morph = np.ones(self.config["morphology_kernel_size"], np.uint8)
211
+ binary = cv2.morphologyEx(
212
+ binary, cv2.MORPH_CLOSE, kernel_morph,
213
+ iterations=self.config["morphology_close_iterations"]
214
+ )
215
+ binary = cv2.morphologyEx(
216
+ binary, cv2.MORPH_OPEN, kernel_morph,
217
+ iterations=self.config["morphology_open_iterations"]
218
+ )
219
+
220
+ # Step 3: contour detection
221
+ contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
222
+
223
+ # Step 4: minimum area
224
+ min_area = max(
225
+ self.config["text_min_area"],
226
+ int((H_img * W_img) * self.config["text_area_ratio"])
227
+ )
228
+
229
+ # Step 5: filter valid contours
230
+ # Compare on the bounding-rect area; cv2.contourArea expects a point array,
+ # so calling it on a boundingRect tuple would raise an error.
+ valid_contours = [
231
+ cnt for cnt in contours
232
+ if cv2.boundingRect(cnt)[2] * cv2.boundingRect(cnt)[3] >= min_area
233
+ ]
234
+
235
+ # Step 6: compute the bounding box
236
+ if valid_contours:
237
+ all_points = np.vstack(valid_contours)
238
+ return cv2.boundingRect(all_points)
239
+ else:
240
+ return (0, 0, W_img, H_img)
241
+
242
+ def _apply_margin(
243
+ self,
244
+ bbox: Tuple[int, int, int, int],
245
+ gray_image: np.ndarray,
246
+ margin_val: int
247
+ ) -> Tuple[int, int, int, int]:
248
+ """์—ฌ๋ฐฑ ์ถ”๊ฐ€"""
249
+ x, y, w, h = bbox
250
+ H_img, W_img = gray_image.shape
251
+
252
+ x_new = max(0, x - margin_val)
253
+ y_new = max(0, y - margin_val)
254
+ w_new = min(W_img - x_new, w + 2 * margin_val)
255
+ h_new = min(H_img - y_new, h + 2 * margin_val)
256
+
257
+ return (x_new, y_new, w_new, h_new)
258
+
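+ # Worked example (illustrative): for an 800x600 image, margin 10 and bbox (5, 5, 100, 100)
+ # give (0, 0, 120, 120): x/y are clamped at the border while the full 2*margin is still added
+ # to w/h, so the box extends slightly farther toward the opposite side.
+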
259
+ def _ensure_bright_background(
260
+ self,
261
+ gray_cropped: np.ndarray
262
+ ) -> Tuple[np.ndarray, Dict]:
263
+ """
264
+ ๋ฐ์€๋ฐฐ๊ฒฝ ๋ณด์žฅ (Swin์šฉ)
265
+
266
+ Returns:
267
+ tuple: (์ฒ˜๋ฆฌ๋œ ๊ทธ๋ ˆ์ด ์ด๋ฏธ์ง€, ์ฒ˜๋ฆฌ ์ •๋ณด)
268
+ """
269
+ mean_brightness = np.mean(gray_cropped)
270
+ is_inverted = False
271
+
272
+ if mean_brightness < self.config["brightness_threshold"]:
273
+ gray_bright = cv2.bitwise_not(gray_cropped)
274
+ is_inverted = True
275
+ else:
276
+ gray_bright = gray_cropped.copy()
277
+
278
+ # ์žฌํ™•์ธ
279
+ final_brightness = np.mean(gray_bright)
280
+ if final_brightness < self.config["brightness_threshold"]:
281
+ gray_bright = cv2.bitwise_not(gray_bright)
282
+ is_inverted = not is_inverted
283
+ final_brightness = np.mean(gray_bright)
284
+
285
+ return gray_bright, {
286
+ "mean_brightness_before": float(mean_brightness),
287
+ "mean_brightness_after": float(final_brightness),
288
+ "is_inverted": is_inverted,
289
+ "is_bright_bg": final_brightness >= self.config["brightness_threshold"]
290
+ }
291
+
292
+ def _ensure_white_background(
293
+ self,
294
+ gray_cropped: np.ndarray
295
+ ) -> Tuple[np.ndarray, Dict]:
296
+ """
297
+ Guarantee a white background (for OCR).
298
+
299
+ Returns:
300
+ tuple: (์ฒ˜๋ฆฌ๋œ ์ด์ง„ ์ด๋ฏธ์ง€, ์ฒ˜๋ฆฌ ์ •๋ณด)
301
+ """
302
+ # Step 1: ์ด์ง„ํ™”
303
+ _, binary = cv2.threshold(
304
+ gray_cropped, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU
305
+ )
306
+
307
+ # Step 2: ํด๋ผ๋ฆฌํ‹ฐ ํŒ๋‹จ
308
+ mean_brightness = np.mean(binary)
309
+
310
+ # Step 3: invert if needed
311
+ if mean_brightness < self.config["brightness_threshold"]:
312
+ binary_final = cv2.bitwise_not(binary)
313
+ polarity = "inverted"
314
+ else:
315
+ binary_final = binary
316
+ polarity = "normal"
317
+
318
+ final_brightness = np.mean(binary_final)
319
+
320
+ return binary_final, {
321
+ "mean_brightness_before": float(mean_brightness),
322
+ "mean_brightness_after": float(final_brightness),
323
+ "polarity": polarity,
324
+ "is_white_bg": final_brightness > self.config["brightness_threshold"]
325
+ }
326
+
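+ # Worked example (illustrative): if Otsu leaves ~70% of the pixels black, the binary mean is
+ # about 0.3 * 255 โ‰ˆ 76 < 127, so the image is inverted and polarity is reported as "inverted";
+ # the majority class (assumed to be the background) always ends up white.
+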
327
+ def preprocess_unified(
328
+ self,
329
+ input_image_path: str,
330
+ output_swin_path: str,
331
+ output_ocr_path: str,
332
+ margin: Optional[int] = None,
333
+ use_rubbing: bool = False
334
+ ) -> Dict:
335
+ """
336
+ Unified preprocessing (generates the Swin Gray and OCR images together).
337
+
338
+ One call produces both the Swin Gray and the OCR image.
339
+ Rubbing/text region detection runs only once, which keeps the pass efficient.
340
+
341
+ Args:
342
+ input_image_path (str): input image path
343
+ output_swin_path (str): Swin Gray output path (JPG)
344
+ output_ocr_path (str): OCR output path (PNG)
345
+ margin (int, optional): crop margin in pixels
346
+ use_rubbing (bool): enable rubbing detection (default: False)
347
+
348
+ Returns:
349
+ dict: processing result
350
+ On success: {
351
+ "success": True,
352
+ "original_shape": (H, W, C),
353
+ "bbox": (x, y, w, h),
354
+ "region_type": "text" or "rubbing",
355
+ "region_detected": bool,
356
+
357
+ "swin": {
358
+ "output_path": str,
359
+ "output_shape": (H, W, 3),
360
+ "is_bright_bg": bool,
361
+ ...
362
+ },
363
+
364
+ "ocr": {
365
+ "output_path": str,
366
+ "output_shape": (H, W),
367
+ "is_white_bg": bool,
368
+ ...
369
+ }
370
+ }
371
+
372
+ On failure: {
373
+ "success": False,
374
+ "message": str
375
+ }
376
+
377
+ Processing Steps:
378
+ 1. Load the image
379
+ 2. Convert to grayscale
380
+ 3. Detect the region (rubbing or text, once only)
381
+ 4. Crop + margin
382
+ 5. Swin Gray processing (bright background guaranteed)
383
+ 6. OCR processing (binarization + white background guaranteed)
384
+ 7. Save both outputs
385
+
386
+ Output:
387
+ - Swin: JPG, 3 channels (non-binarized, 256 gray levels)
388
+ - OCR: PNG, 1 channel (binarized)
389
+
390
+ Example:
391
+ >>> prep = UnifiedImagePreprocessor()
392
+ >>> result = prep.preprocess_unified(
393
+ ... "input.jpg",
394
+ ... "swin.jpg",
395
+ ... "ocr.png"
396
+ ... )
397
+ >>> if result["success"]:
398
+ ... swin_output = result["swin"]["output_path"]
399
+ ... ocr_output = result["ocr"]["output_path"]
400
+ """
401
+ margin_val = margin or self.config["margin"]
402
+
403
+ try:
404
+ # ====================================================================
405
+ # Step 1: ์ด๋ฏธ์ง€ ๋กœ๋“œ
406
+ # ====================================================================
407
+ img_bgr = cv2.imread(str(input_image_path), cv2.IMREAD_COLOR)
408
+ if img_bgr is None:
409
+ raise ValueError(f"์ด๋ฏธ์ง€ ๋กœ๋“œ ์‹คํŒจ: {input_image_path}")
410
+
411
+ original_shape = img_bgr.shape
412
+ logger.info(f"[LOAD] ์ด๋ฏธ์ง€ ๋กœ๋“œ: {input_image_path} {original_shape}")
413
+
414
+ # ====================================================================
415
+ # Step 2: ๊ทธ๋ ˆ์ด์Šค์ผ€์ผ ๋ณ€ํ™˜
416
+ # ====================================================================
417
+ gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
418
+
419
+ # ====================================================================
420
+ # Step 3: ์˜์—ญ ๊ฒ€์ถœ (ํƒ๋ณธ ๋˜๋Š” ํ…์ŠคํŠธ)
421
+ # ====================================================================
422
+ if use_rubbing:
423
+ detected_bbox = self._find_rubbing_bbox(gray)
424
+ region_type = "rubbing"
425
+ logger.info("[DETECT] ํƒ๋ณธ ์˜์—ญ ๊ฒ€์ถœ ๋ชจ๋“œ")
426
+ else:
427
+ detected_bbox = None
428
+ region_type = "text"
429
+ logger.info("[DETECT] ํ…์ŠคํŠธ ์˜์—ญ ๊ฒ€์ถœ ๋ชจ๋“œ")
430
+
431
+ H_img, W_img = gray.shape
432
+
433
+ # ====================================================================
434
+ # Step 4: ํฌ๋กญ + ์—ฌ๋ฐฑ
435
+ # ====================================================================
436
+ if detected_bbox is not None:
437
+ bbox_final = self._apply_margin(detected_bbox, gray, margin_val)
438
+ logger.info(f"[DETECT] {region_type} ์˜์—ญ ๊ฒ€์ถœ: {bbox_final}")
439
+ else:
440
+ # ํƒ๋ณธ ๋ฏธ๊ฒ€์ถœ ๋˜๋Š” ํ…์ŠคํŠธ ๋ชจ๋“œ -> ํ…์ŠคํŠธ ๊ฒ€์ถœ
441
+ if use_rubbing:
442
+ bbox_final = (0, 0, W_img, H_img)
443
+ logger.warning("[DETECT] ํƒ๋ณธ ๋ฏธ๊ฒ€์ถœ - ์ „์ฒด ์ด๋ฏธ์ง€ ์‚ฌ์šฉ")
444
+ else:
445
+ bbox_text = self._find_text_bbox(gray)
446
+ bbox_final = self._apply_margin(bbox_text, gray, margin_val)
447
+ logger.info(f"[DETECT] ํ…์ŠคํŠธ ์˜์—ญ ๊ฒ€์ถœ: {bbox_final}")
448
+
449
+ x, y, w, h = bbox_final
450
+ gray_cropped = gray[y:y+h, x:x+w]
451
+
452
+ logger.info(f"[CROP] ํฌ๋กญ ์™„๋ฃŒ: {gray_cropped.shape}")
453
+
454
+ # ====================================================================
455
+ # Step 5: Swin Gray processing
456
+ # ====================================================================
457
+ gray_bright, info_swin = self._ensure_bright_background(gray_cropped)
458
+ swin_output_3ch = cv2.cvtColor(gray_bright, cv2.COLOR_GRAY2BGR)
459
+
460
+ # ====================================================================
461
+ # Step 6: OCR processing
462
+ # ====================================================================
463
+ binary_final, info_ocr = self._ensure_white_background(gray_cropped)
464
+
465
+ # ====================================================================
466
+ # Step 7: ๋™์‹œ ์ €์žฅ
467
+ # ====================================================================
468
+ output_swin_path_obj = Path(output_swin_path)
469
+ output_swin_path_obj.parent.mkdir(parents=True, exist_ok=True)
470
+ swin_success = cv2.imwrite(str(output_swin_path_obj), swin_output_3ch)
471
+
472
+ output_ocr_path_obj = Path(output_ocr_path)
473
+ output_ocr_path_obj.parent.mkdir(parents=True, exist_ok=True)
474
+ ocr_success = cv2.imwrite(str(output_ocr_path_obj), binary_final)
475
+
476
+ if not swin_success or not ocr_success:
477
+ raise ValueError("์ด๋ฏธ์ง€ ์ €์žฅ ์‹คํŒจ")
478
+
479
+ logger.info(f"[SAVE] Swin ์ €์žฅ: {output_swin_path_obj}")
480
+ logger.info(f"[SAVE] OCR ์ €์žฅ: {output_ocr_path_obj}")
481
+
482
+ # ====================================================================
483
+ # Return the result
484
+ # ====================================================================
485
+ return {
486
+ "success": True,
487
+ "version": "Unified Swin Gray + OCR (v1.0.0)",
488
+ "original_shape": original_shape,
489
+ "bbox": bbox_final,
490
+ "region_type": region_type,
491
+ "region_detected": detected_bbox is not None,
492
+
493
+ # Swin part
494
+ "swin": {
495
+ "output_path": str(output_swin_path_obj).replace("\\", "/"),
496
+ "output_shape": swin_output_3ch.shape,
497
+ "color_type": "Grayscale 3์ฑ„๋„ (B=G=R, ๋น„์ด์ง„ํ™” 256๋‹จ๊ณ„)",
498
+ "is_inverted": info_swin["is_inverted"],
499
+ "mean_brightness_before": info_swin["mean_brightness_before"],
500
+ "mean_brightness_after": info_swin["mean_brightness_after"],
501
+ "is_bright_bg": info_swin["is_bright_bg"]
502
+ },
503
+
504
+ # OCR part
505
+ "ocr": {
506
+ "output_path": str(output_ocr_path_obj).replace("\\", "/"),
507
+ "output_shape": binary_final.shape,
508
+ "polarity": info_ocr["polarity"],
509
+ "mean_brightness_before": info_ocr["mean_brightness_before"],
510
+ "mean_brightness_after": info_ocr["mean_brightness_after"],
511
+ "is_white_bg": info_ocr["is_white_bg"]
512
+ },
513
+
514
+ "message": "[DONE] ํ†ตํ•ฉ ์ „์ฒ˜๋ฆฌ ์™„๋ฃŒ (Swin + OCR)"
515
+ }
516
+
517
+ except Exception as e:
518
+ logger.error(f"[ERROR] ํ†ตํ•ฉ ์ „์ฒ˜๋ฆฌ ์‹คํŒจ: {e}")
519
+ return {
520
+ "success": False,
521
+ "message": str(e)
522
+ }
523
+
524
+
525
+ # ================================================================================
526
+ # Global Instance & Convenience Functions
527
+ # ================================================================================
528
+
529
+
530
+ _global_preprocessor = None
531
+
532
+
533
+ def get_preprocessor(config_path: Optional[str] = None) -> UnifiedImagePreprocessor:
534
+ """์ „์—ญ ์ „์ฒ˜๋ฆฌ๊ธฐ ์ธ์Šคํ„ด์Šค ๋ฐ˜ํ™˜"""
535
+ global _global_preprocessor
536
+ if _global_preprocessor is None:
537
+ _global_preprocessor = UnifiedImagePreprocessor(config_path)
538
+ return _global_preprocessor
539
+
540
+
541
+ def preprocess_image_unified(
542
+ input_path: str,
543
+ output_swin_path: str,
544
+ output_ocr_path: str,
545
+ margin: Optional[int] = None,
546
+ use_rubbing: bool = False
547
+ ) -> Dict:
548
+ """
549
+ ํŽธ์˜ ํ•จ์ˆ˜: ํ†ตํ•ฉ ์ „์ฒ˜๋ฆฌ
550
+
551
+ Args:
552
+ input_path (str): ์ž…๋ ฅ ์ด๋ฏธ์ง€ ๊ฒฝ๋กœ
553
+ output_swin_path (str): Swin ์ถœ๋ ฅ ๊ฒฝ๋กœ
554
+ output_ocr_path (str): OCR ์ถœ๋ ฅ ๊ฒฝ๋กœ
555
+ margin (int, optional): ์—ฌ๋ฐฑ
556
+ use_rubbing (bool): ํƒ๋ณธ ๋ชจ๋“œ
557
+
558
+ Returns:
559
+ dict: ์ฒ˜๋ฆฌ ๊ฒฐ๊ณผ
560
+ """
561
+ prep = get_preprocessor()
562
+ return prep.preprocess_unified(
563
+ input_path,
564
+ output_swin_path,
565
+ output_ocr_path,
566
+ margin,
567
+ use_rubbing
568
+ )
569
+
570
+
571
+ # ================================================================================
572
+ # Usage Example
573
+ # ================================================================================
574
+
575
+
576
+ if __name__ == "__main__":
577
+ """
578
+ Test example
579
+ """
580
+ logger.info("=" * 80)
581
+ logger.info("[TEST] Unified Image Preprocessor v1.0.0 - ํ…Œ์ŠคํŠธ ์‹œ์ž‘")
582
+ logger.info("=" * 80)
583
+
584
+ try:
585
+ prep = UnifiedImagePreprocessor()
586
+
587
+ result = prep.preprocess_unified(
588
+ "test_input.jpg",
589
+ "test_swin.jpg",
590
+ "test_ocr.png"
591
+ )
592
+
593
+ if result["success"]:
594
+ logger.info("[TEST] ํ†ตํ•ฉ ์ „์ฒ˜๋ฆฌ ์„ฑ๊ณต!")
595
+ logger.info(f"[TEST] Swin: {result['swin']['output_path']}")
596
+ logger.info(f"[TEST] OCR: {result['ocr']['output_path']}")
597
+ logger.info(f"[TEST] Swin ๋ฐ์€๋ฐฐ๊ฒฝ: {'Yes' if result['swin']['is_bright_bg'] else 'No'}")
598
+ logger.info(f"[TEST] OCR ํฐ๋ฐฐ๊ฒฝ: {'Yes' if result['ocr']['is_white_bg'] else 'No'}")
599
+ else:
600
+ logger.error(f"[TEST] ์‹คํŒจ: {result['message']}")
601
+
602
+ except Exception as e:
603
+ logger.error(f"[TEST] ์˜ˆ์™ธ: {e}")
604
+
605
+ logger.info("=" * 80)
dong_ocr.py ADDED
@@ -0,0 +1,349 @@
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Standalone OCR script
4
+ Hanja OCR and damaged-region detection based on a Google Vision API + HRCenterNet ensemble
5
+
6
+ Changes:
7
+ 1. Added logic that detects X-coordinate jumps to split the output into columns automatically
8
+ 2. Applied Safe Crop (floor min / ceil max) so [MASK] boxes and others lose no fractional-pixel area
9
+ -> applied not only to the visualization but to the JSON result data itself, removing fractional coordinates
10
+ """
11
+
12
+ import os
13
+ import sys
14
+ import json
15
+ import logging
16
+ import cv2
17
+ import math
18
+ import numpy as np
19
+ from pathlib import Path
20
+ from dotenv import load_dotenv
21
+
22
+ # Add the script's directory to the Python path
23
+ current_dir = os.path.dirname(os.path.abspath(__file__))
24
+ if current_dir not in sys.path:
25
+ sys.path.insert(0, current_dir)
26
+
27
+ # Load environment variables
28
+ load_dotenv()
29
+
30
+ # Logging setup
31
+ logging.basicConfig(
32
+ level=logging.INFO,
33
+ format='%(asctime)s - [%(levelname)s] %(message)s'
34
+ )
35
+ logger = logging.getLogger("DONG_OCR")
36
+
37
+ # Import the OCR engine and preprocessing modules
38
+ try:
39
+ from ai_modules.ocr_engine import get_ocr_engine
40
+ from ai_modules.preprocessor_unified import preprocess_image_unified
41
+ except ImportError as e:
42
+ logger.error(f"โŒ ๋ชจ๋“ˆ import ์‹คํŒจ: {e}")
43
+ sys.exit(1)
44
+
45
+
46
+ def format_ocr_results(raw_results, image_filename):
47
+ """
48
+ Convert raw OCR results into the requested JSON format.
49
+ Change: Safe Crop (floor/ceil) is also applied to the saved JSON, converting coordinates to integers.
50
+ """
51
+ formatted_list = []
52
+
53
+ if raw_results is None:
54
+ raw_results = []
55
+
56
+ if not raw_results:
57
+ return {"image": image_filename, "results": []}
58
+
59
+ order_counter = 0
60
+ for idx, item in enumerate(raw_results):
61
+ if not isinstance(item, dict): continue
62
+
63
+ min_x, min_y, max_x, max_y = 0.0, 0.0, 0.0, 0.0
64
+
65
+ # 1. If a 'box' list is already present
66
+ if 'box' in item and isinstance(item['box'], list) and len(item['box']) == 4:
67
+ try:
68
+ min_x, min_y, max_x, max_y = map(float, item['box'])
69
+ except (TypeError, ValueError): pass  # malformed box values fall through to the key-based path
70
+
71
+ # 2. If 'box' is missing (or failed to parse), use the individual coordinate keys
72
+ if min_x == 0 and max_x == 0:
73
+ mx = item.get('min_x')
74
+ my = item.get('min_y')
75
+ Mx = item.get('max_x')
76
+ My = item.get('max_y')
77
+
78
+ if mx is None: mx = item.get('x', 0)
79
+ if my is None: my = item.get('y', 0)
80
+ if Mx is None:
81
+ Mx = item.get('x2')
82
+ if Mx is None:
83
+ width = item.get('width', 0)
84
+ Mx = mx + width if width > 0 else 0
85
+ if My is None:
86
+ My = item.get('y2')
87
+ if My is None:
88
+ height = item.get('height', 0)
89
+ My = my + height if height > 0 else 0
90
+
91
+ try:
92
+ min_x, min_y, max_x, max_y = float(mx), float(my), float(Mx), float(My)
93
+ except (TypeError, ValueError): continue
94
+
95
+ if min_x == 0 and min_y == 0 and max_x == 0 and max_y == 0:
96
+ width = item.get('width', 0)
97
+ height = item.get('height', 0)
98
+ if width > 0 and height > 0:
99
+ cx, cy = item.get('center_x', width/2), item.get('center_y', height/2)
100
+ min_x, min_y = cx - width/2, cy - height/2
101
+ max_x, max_y = cx + width/2, cy + height/2
102
+ else: continue
103
+
104
+ if max_x <= min_x or max_y <= min_y: continue
105
+
106
+ # === [์ถ”๊ฐ€๋จ] JSON ๋ฐ์ดํ„ฐ ์ž์ฒด์— Safe Crop ์ ์šฉ (์†Œ์ˆ˜์  ์ œ๊ฑฐ) ===
107
+ # min ์ขŒํ‘œ๋Š” ๋‚ด๋ฆผ(floor), max ์ขŒํ‘œ๋Š” ์˜ฌ๋ฆผ(ceil)ํ•˜์—ฌ ์˜์—ญ ํ™•๋ณด ํ›„ ์ •์ˆ˜ ๋ณ€ํ™˜
108
+ min_x = int(math.floor(min_x))
109
+ min_y = int(math.floor(min_y))
110
+ max_x = int(math.ceil(max_x))
111
+ max_y = int(math.ceil(max_y))
112
+
113
+ # Prevent negative coordinates (clamp to 0)
114
+ min_x = max(0, min_x)
115
+ min_y = max(0, min_y)
116
+ # ==========================================================
117
+
118
+ new_item = {
119
+ "order": order_counter,
120
+ "text": item.get('text', ''),
121
+ "type": item.get('type', 'TEXT'),
122
+ "box": [min_x, min_y, max_x, max_y],
123
+ "confidence": float(item.get('confidence', 0.0)),
124
+ "source": item.get('source', 'Unknown')
125
+ }
126
+ formatted_list.append(new_item)
127
+ order_counter += 1
128
+
129
+ return {"image": image_filename, "results": formatted_list}
130
+
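+ # Illustrative output shape (values made up). Safe Crop floors the min and ceils the max
+ # coordinates, e.g. a raw box [12.3, 40.9, 80.1, 95.2] becomes [12, 40, 81, 96]:
+ # {
+ #   "image": "page.jpg",
+ #   "results": [
+ #     {"order": 0, "text": "ๅคฉ", "type": "TEXT", "box": [12, 40, 81, 96],
+ #      "confidence": 0.93, "source": "Google"}
+ #   ]
+ # }
+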
131
+
132
+ def draw_bboxes(image_path, results, output_path):
133
+ """์ด๋ฏธ์ง€์— Bounding Box ๊ทธ๋ฆฌ๊ธฐ (Safe Crop ์ ์šฉ)"""
134
+ try:
135
+ img_array = np.fromfile(image_path, np.uint8)
136
+ img = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
137
+ if img is None:
138
+ img = cv2.imread(image_path)
139
+ if img is None: return
140
+
141
+ box_count = 0
142
+ colors = {
143
+ 'Google': (0, 255, 0), 'Custom': (255, 0, 255),
144
+ 'MASK1': (255, 0, 0), 'MASK2': (0, 0, 255), 'Default': (0, 255, 255)
145
+ }
146
+
147
+ for item in results:
148
+ box = item.get('box', [])
149
+ if len(box) != 4: continue
150
+ try:
151
+ # format_ocr_results already returns integers, but repeat the conversion here
152
+ # for safety (kept so the code still handles floats if they slip through)
153
+ x1 = int(math.floor(float(box[0])))
154
+ y1 = int(math.floor(float(box[1])))
155
+ x2 = int(math.ceil(float(box[2])))
156
+ y2 = int(math.ceil(float(box[3])))
157
+ except (TypeError, ValueError): continue
158
+
159
+ h, w = img.shape[:2]
160
+ # Clip to the image bounds
161
+ x1 = max(0, min(x1, w-1))
162
+ y1 = max(0, min(y1, h-1))
163
+ x2 = max(x1+1, min(x2, w))
164
+ y2 = max(y1+1, min(y2, h))
165
+
166
+ text = item.get('text', '')
167
+ source = item.get('source', '')
168
+ itype = item.get('type', 'TEXT')
169
+
170
+ if 'MASK1' in itype or '[MASK1]' in text: color = colors['MASK1']
171
+ elif 'MASK2' in itype or '[MASK2]' in text: color = colors['MASK2']
172
+ elif source in colors: color = colors[source]
173
+ else: color = colors['Default']
174
+
175
+ cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)
176
+
177
+ if itype == 'TEXT' and len(text) <= 2:
178
+ cv2.putText(img, text, (x1, y1-5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
179
+ elif 'MASK' in itype:
180
+ label = '[M1]' if itype == 'MASK1' else '[M2]'
181
+ cv2.putText(img, label, (x1, y1-5), cv2.FONT_HERSHEY_SIMPLEX, 0.4, color, 1)
182
+
183
+ box_count += 1
184
+
185
+ ext = os.path.splitext(output_path)[1].lower()
186
+ params = [int(cv2.IMWRITE_JPEG_QUALITY), 95] if ext in ['.jpg', '.jpeg'] else [int(cv2.IMWRITE_PNG_COMPRESSION), 3]
187
+ result, encoded_img = cv2.imencode(ext, img, params)
188
+ if result:
189
+ with open(output_path, mode='wb') as f: encoded_img.tofile(f)
190
+ logger.info(f"๐Ÿ–ผ๏ธ B-Box ์ด๋ฏธ์ง€ ์ €์žฅ๋จ: {output_path} ({box_count}๊ฐœ ๋ฐ•์Šค)")
191
+
192
+ except Exception as e:
193
+ logger.error(f"โŒ ์‹œ๊ฐํ™” ์ค‘ ์˜ค๋ฅ˜: {e}")
194
+
195
+
196
+ def run_ocr(image_path, use_preprocessing=True):
197
+ """OCR ์‹คํ–‰, ๊ฒฐ๊ณผ ์ถœ๋ ฅ ๋ฐ ์ €์žฅ"""
198
+ if not os.path.exists(image_path):
199
+ logger.error(f"โŒ ์ด๋ฏธ์ง€ ์—†์Œ: {image_path}")
200
+ return False
201
+
202
+ logger.info(f"๐Ÿš€ OCR ๋ถ„์„ ์‹œ์ž‘: {image_path}")
203
+
204
+ try:
205
+ # 1. Preprocessing
206
+ ocr_image_path = image_path
207
+ preprocess_result = {'success': False}
208
+ if use_preprocessing:
209
+ logger.info("๐Ÿ“ธ ์ด๋ฏธ์ง€ ์ „์ฒ˜๋ฆฌ ์ค‘...")
210
+ base_dir = os.path.dirname(os.path.abspath(image_path))
211
+ base_name = os.path.splitext(os.path.basename(image_path))[0]
212
+ swin_path = os.path.join(base_dir, f"{base_name}_swin_temp.jpg")
213
+ ocr_preprocessed_path = os.path.join(base_dir, f"{base_name}_ocr_temp.png")
214
+
215
+ preprocess_result = preprocess_image_unified(
216
+ input_path=image_path, output_swin_path=swin_path,
217
+ output_ocr_path=ocr_preprocessed_path, use_rubbing=True
218
+ )
219
+ if preprocess_result.get('success'):
220
+ ocr_image_path = ocr_preprocessed_path
221
+ logger.info(f"โœ… ์ „์ฒ˜๋ฆฌ ์™„๋ฃŒ: {ocr_preprocessed_path}")
222
+ else:
223
+ logger.warning(f"โš ๏ธ ์ „์ฒ˜๋ฆฌ ์‹คํŒจ: {preprocess_result.get('message')}")
224
+
225
+ # 2. Run the engine
226
+ engine = get_ocr_engine()
227
+ logger.info("โœ… OCR ์—”์ง„ ๋กœ๋“œ ์™„๋ฃŒ")
228
+
229
+ try:
230
+ raw_result = engine.run_ocr(ocr_image_path)
231
+ except Exception as e:
232
+ logger.error(f"โŒ OCR ์‹คํ–‰ ์˜ˆ์™ธ: {e}")
233
+ return False
234
+
235
+ if not raw_result: return False
236
+
237
+ is_success = raw_result.get('success', False)
238
+ if not is_success and 'results' in raw_result and isinstance(raw_result['results'], list):
239
+ is_success = True
240
+
241
+ if not is_success:
242
+ logger.error(f"โŒ OCR ์‹คํŒจ: {raw_result.get('error')}")
243
+ return False
244
+
245
+ logger.info("\n" + "="*60)
246
+ logger.info("โœ… OCR ๋ถ„์„ ์™„๋ฃŒ")
247
+
248
+ # 3. ๋ฐ์ดํ„ฐ ํฌ๋งทํŒ…
249
+ formatted_result = format_ocr_results(raw_result.get('results', []), os.path.basename(image_path))
250
+ results_list = formatted_result.get('results', [])
251
+
252
+ # 4. [Column-split output] Compute columns from the coordinates and print them
253
+ logger.info("\n" + "๐Ÿ“œ [ ์ธ์‹๋œ ํ…์ŠคํŠธ ๊ฒฐ๊ณผ (์ž๋™ ์—ด ๊ตฌ๋ถ„) ] " + "-"*25)
254
+
255
+ if not results_list:
256
+ logger.info(" (๊ฒฐ๊ณผ ์—†์Œ)")
257
+ else:
258
+ columns = []
259
+ current_col_text = []
260
+
261
+ # X center of the first character
262
+ first_box = results_list[0]['box']
263
+ prev_cx = (first_box[0] + first_box[2]) / 2
264
+
265
+ for item in results_list:
266
+ box = item['box']
267
+ curr_cx = (box[0] + box[2]) / 2
268
+
269
+ # Extract the text (MASK handling)
270
+ text = item.get('text', '')
271
+ if item.get('type') in ['MASK1', 'MASK2']:
272
+ text = f"[{item.get('type')}]"
273
+
274
+ # === ์—ด ๊ตฌ๋ถ„ ํ•ต์‹ฌ ๋กœ์ง ===
275
+ # ์ด์ „ ๊ธ€์ž์™€ X์ขŒํ‘œ ์ค‘์‹ฌ์ด 50ํ”ฝ์…€ ์ด์ƒ ์ฐจ์ด๋‚˜๋ฉด ์ƒˆ๋กœ์šด ์—ด๋กœ ๊ฐ„์ฃผ
276
+ # (์ผ๋ฐ˜์ ์œผ๋กœ ์„ธ๋กœ์“ฐ๊ธฐ์—์„œ ์ค„๋ฐ”๊ฟˆ ์‹œ X์ขŒํ‘œ๊ฐ€ ํฌ๊ฒŒ ๋ณ€ํ•จ)
277
+ if abs(curr_cx - prev_cx) > 50:
278
+ if current_col_text:
279
+ columns.append("".join(current_col_text))
280
+ current_col_text = []
281
+ prev_cx = curr_cx # ์ƒˆ๋กœ์šด ์—ด์˜ ๊ธฐ์ค€์œผ๋กœ ๊ฐฑ์‹ 
282
+
283
+ current_col_text.append(text)
284
+ # ๊ฐ™์€ ์—ด ๋‚ด์—์„œ๋Š” ๋ฏธ์„ธํ•œ X ํ”๋“ค๋ฆผ์ด ์žˆ์„ ์ˆ˜ ์žˆ์œผ๋ฏ€๋กœ prev_cx๋ฅผ ๊ณ„์† ๊ฐฑ์‹ ํ•˜์ง€ ์•Š๊ณ 
285
+ # ํ•ด๋‹น ์—ด์˜ '๋Œ€ํ‘œ' X๊ฐ’์„ ์œ ์ง€ํ•˜๊ฑฐ๋‚˜, ํ˜น์€ ๊ธ€์ž๋งˆ๋‹ค ๊ฐฑ์‹ ํ•  ์ˆ˜ ์žˆ์Œ.
286
+ # ์—ฌ๊ธฐ์„œ๋Š” ๊ธ€์ž๊ฐ€ ๋น„์Šค๋“ฌํ•  ์ˆ˜ ์žˆ์œผ๋ฏ€๋กœ ๋งค๋ฒˆ ๊ฐฑ์‹ ํ•˜๋Š” ๋ฐฉ์‹์„ ์”€
287
+ prev_cx = curr_cx
288
+
289
+ # ๋งˆ์ง€๋ง‰ ์—ด ์ถ”๊ฐ€
290
+ if current_col_text:
291
+ columns.append("".join(current_col_text))
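+
+ # Worked example (hypothetical numbers): with the 50 px threshold, X centers
+ # [310, 308, 305, 240, 238] split after the third character (|240 - 305| = 65 > 50),
+ # yielding the two columns [310, 308, 305] and [240, 238]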
+
+ # Print one line per detected column
+ for idx, col_text in enumerate(columns, 1):
+ logger.info(f" [Col {idx:02d}] {col_text}")
+
+ logger.info("-" * 60 + "\n")
+
+ # 5. Save the results as JSON
+ json_path = os.path.splitext(image_path)[0] + "_ocr_result.json"
+ with open(json_path, 'w', encoding='utf-8') as f:
+ json.dump(formatted_result, f, ensure_ascii=False, indent=2)
+ logger.info(f"💾 JSON result saved: {json_path}")
+
+ # 6. Save the bounding-box visualization
+ output_img_path = os.path.splitext(image_path)[0] + "_bbox.jpg"
+ bbox_image_path = ocr_image_path if use_preprocessing and preprocess_result.get('success') else image_path
+ draw_bboxes(bbox_image_path, results_list, output_img_path)
+
+ # 7. Statistics (counted separately by source and by type)
+ counts = {'Google': 0, 'Custom': 0, 'MASK1': 0, 'MASK2': 0, 'TEXT': 0}
+ for r in results_list:
+ if r['source'] in counts:
+ counts[r['source']] += 1
+ if r['type'] in counts:
+ counts[r['type']] += 1
+
+ logger.info("📊 Final statistics")
+ logger.info(f" - 🟢 Google: {counts['Google']}")
+ logger.info(f" - 🟣 Custom: {counts['Custom']}")
+ logger.info(f" - 🔵 MASK1: {counts['MASK1']}")
+ logger.info(f" - 🔴 MASK2: {counts['MASK2']}")
+ logger.info(f" - 📝 TEXT: {counts['TEXT']}")
+ logger.info("="*60)
+
+ return True
+
+ except Exception as e:
+ logger.error(f"❌ Error: {e}", exc_info=True)
+ return False
+
+
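For reference, a minimal sketch of the per-item shape that run_ocr assumes for entries in results_list, listing only the fields actually read above (the authoritative schema comes from format_ocr_results, which lives elsewhere in the module):

    # Hypothetical example item, for illustration only.
    example_item = {
        "text": "天",                # recognized character (empty for masks)
        "box": [120, 40, 168, 92],   # [x1, y1, x2, y2] in pixels
        "type": "TEXT",              # or "MASK1" / "MASK2"
        "source": "Google",          # or "Custom"
    }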
+ def main():
+ if len(sys.argv) < 2:
+ print("Usage: python dong_ocr.py <image>")
+ sys.exit(1)
+
+ if not os.getenv('OCR_WEIGHTS_BASE_PATH') or not os.getenv('GOOGLE_CREDENTIALS_JSON'):
+ logger.error("❌ Required environment variables are not set (OCR_WEIGHTS_BASE_PATH, GOOGLE_CREDENTIALS_JSON)")
+ sys.exit(1)
+
+ if run_ocr(sys.argv[1]):
+ logger.info("✅ Done!")
+ sys.exit(0)
+ else:
+ logger.error("❌ Job failed")
+ sys.exit(1)
+
+
+ if __name__ == "__main__":
+ main()
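As a quick smoke test, something like the following exercises the pipeline end to end (a minimal sketch: the image and credential paths are placeholders, and the script is assumed to be saved as dong_ocr.py, matching the usage string above):

    import os

    # Both variables are checked in main(); the values here are placeholders.
    os.environ["OCR_WEIGHTS_BASE_PATH"] = "./weights"
    os.environ["GOOGLE_CREDENTIALS_JSON"] = "/path/to/service_account.json"

    from dong_ocr import run_ocr

    # Writes <image>_ocr_result.json and <image>_bbox.jpg next to the input.
    ok = run_ocr("sample_epitaph.png", use_preprocessing=True)
    print("success" if ok else "failed")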
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ opencv-python
+ numpy
+ torch
+ torchvision
+ python-dotenv
+ Pillow
+ google-cloud-vision
+ huggingface-hub>=0.34.0,<1.0
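Installing these with pip install -r requirements.txt in a fresh virtual environment should reproduce the runtime; note that torch and torchvision are unpinned, so a GPU setup may require selecting a CUDA-matched build from the PyTorch index.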
weights/best.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:31db0eec96515dc820475df31245d7fe51ffcc56c76dc70df4e2bf83ff21d7e6
+ size 115004284
weights/best_5000.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d3180e120cb82a5dc2474448a3b61ea72b53e2507a2dc367bda76dc222a35ec6
+ size 62505977
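Both checkpoint files are Git LFS pointers: a clone without LFS fetches only these stubs, while git lfs pull retrieves the actual binaries (~115 MB and ~63 MB). Since huggingface-hub is already a dependency, they can also be fetched programmatically; a sketch, with the repo id as a placeholder for wherever this repository is hosted:

    from huggingface_hub import hf_hub_download

    # "your-org/epitext-ocr" is a hypothetical repo id; LFS files are
    # resolved transparently, so this returns a path to the real checkpoint.
    ckpt_path = hf_hub_download(
        repo_id="your-org/epitext-ocr",
        filename="weights/best.pth",
    )
    print(ckpt_path)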