leonge committed on
Commit
9f97a7a
·
1 Parent(s): fadb73b

Added first scan logic

Browse files
app.py CHANGED
@@ -1,7 +1,38 @@
1
  import gradio as gr
 
 
 
2
 
3
- def greet(name):
4
- return "World of Content demo, hello " + name + "!!"
5
 
6
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import json
3
+ from PIL import Image
4
+ from utils import run_ocr, image_to_byte_array
5
 
6
def extract_all(image):
    """Run OCR on an uploaded image and return (attributes JSON, full OCR text).

    Parameters
    ----------
    image : numpy.ndarray
        Image array as delivered by the gradio "image" input component.

    Returns
    -------
    tuple[str, str]
        A pretty-printed JSON string with the extracted product attributes,
        and the OCR full text with newlines collapsed to spaces.
    """
    # Avoid shadowing the incoming array with the PIL object.
    pil_image = Image.fromarray(image)

    image_byte_array = image_to_byte_array(pil_image)
    ocr_text_out, response_json = run_ocr(image_byte_array)
    ocr_full_text = str(ocr_text_out.full_text).replace('\n', ' ')

    # TODO: replace this hard-coded demo payload with a real field-extraction
    # step over ocr_full_text (an extract_fields() implementation).
    # NOTE(review): the "Adress" key misspelling is kept on purpose — it is
    # part of the emitted JSON schema and consumers may depend on it.
    output_dictionary = {
        "GPC": 10000045,
        "BRAND": "Kitkat",
        "FUNCTIONAL NAME": "Chocolate Bar",
        "Weight": "41.5",
        "Unit": "Gr",
        "Contact Information": {"Website": "www.kitkat.com", "Adress": " Nestlé Deutschland AG, 60523 \n Frankfurt am Main, Germany"},
        "Allergen_Statement_NL": "Bevat: MELK, TARWE",
    }

    # json.dumps already returns a str; the original wrapped it (and the
    # return value) in redundant str() calls.
    output_dict_json = json.dumps(output_dictionary, indent=2)
    return output_dict_json, ocr_full_text
27
+
28
+
29
# Output components for the gradio interface.
attributes_tbox = gr.Textbox(label='Attributes')
ocr_output_tbox = gr.Textbox(label='OCR Output')

# NOTE(review): the previous `gr.Button.style("{color: blue}")` called an
# instance method on the Button class with a raw CSS string — it had no
# effect; button colours are styled via the `css` argument below. The unused
# `output_df` dataframe component was removed as well.
gr.Interface(fn=extract_all, inputs="image", outputs=[attributes_tbox, ocr_output_tbox],
             title='World of Content', description="GS1 Global Extractor",
             css="body {background-color: #F5F7FA} .gr-button.gr-button-primary {background-color: #0080fa; color:white; --tw-gradient-from:0} .gr-button.gr-button-secondary {background-color: #172533; color:white; --tw-gradient-from:0} h1 {background-image: url('file=woc-logo-black.1a4c4e90.svg'); background-size:contain; background-repeat:no-repeat; background-position:center; text-indent:-999999999px} .output-markdown p{text-align:center; font-size:24px}").launch()
base/__init__.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from data_models import ( # noqa 401, 403
2
+ Allergen,
3
+ AllergensOut,
4
+ Attribute,
5
+ AttributeAllergen,
6
+ AttributeCommunicationChannel,
7
+ ClassifiedText,
8
+ CommunicationChannels,
9
+ CommunicationChannelsOut,
10
+ ModelOut,
11
+ ModelOutList,
12
+ NetContent,
13
+ NetContentAttribute,
14
+ NutrientTable,
15
+ NutrientTableDailyValueIntake,
16
+ NutrientTableElement,
17
+ NutrientTableQuantity,
18
+ OCROut,
19
+ OCROutList,
20
+ OCRTableOut,
21
+ OCRTextOut,
22
+ OCRWrapperOut,
23
+ PipelineInput,
24
+ PipelineOutput,
25
+ RedirectInput,
26
+ TextWithLanguage,
27
+ TrainModelOut,
28
+ )
29
+
30
+ from base.ocr import * # noqa 401, 403
base/__init__.py.bak ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # isort: skip_file
2
+ from kedro.extras.datasets.json import JSONDataSet # noqa 401
3
+ from kedro.extras.datasets.pandas import CSVDataSet # noqa 401
4
+ from kedro.extras.datasets.pandas import ParquetDataSet # noqa 401
5
+ from kedro.extras.datasets.pandas import ExcelDataSet as excel_dataset # noqa 401
6
+ from kedro.extras.datasets.pandas import JSONDataSet as pandas_json # noqa 401
7
+ from kedro.extras.datasets.pandas.parquet_dataset import ParquetDataSet # noqa 401
8
+ from kedro.extras.datasets.pillow import ImageDataSet # noqa 401
9
+ from kedro.extras.datasets.text import TextDataSet # noqa 401
10
+ from kedro.io.core import AbstractDataSet as DataSet # noqa 401
11
+
12
+ from certifai.base.data_models import ( # noqa 401, 403
13
+ Allergen,
14
+ AllergensOut,
15
+ Attribute,
16
+ AttributeAllergen,
17
+ AttributeCommunicationChannel,
18
+ ClassifiedText,
19
+ CommunicationChannels,
20
+ CommunicationChannelsOut,
21
+ ModelOut,
22
+ ModelOutList,
23
+ NetContent,
24
+ NetContentAttribute,
25
+ NutrientTable,
26
+ NutrientTableDailyValueIntake,
27
+ NutrientTableElement,
28
+ NutrientTableQuantity,
29
+ OCROut,
30
+ OCROutList,
31
+ OCRTableOut,
32
+ OCRTextOut,
33
+ OCRWrapperOut,
34
+ PipelineInput,
35
+ PipelineOutput,
36
+ RedirectInput,
37
+ TextWithLanguage,
38
+ TrainModelOut,
39
+ )
40
+ from certifai.base.abstract import BaseClassifier # noqa 401, 403
41
+ from certifai.base.s3_helper_functions import * # noqa 401, 403
42
+ from certifai.base.custom_datasets import * # noqa 401, 403
43
+ from certifai.base.ocr import * # noqa 401, 403
44
+ from kedro.extras.datasets.yaml import YAMLDataSet # noqa 401
base/data_models.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Pydantic models used throughout the codebase.
3
+
4
+ In particular, these are the types that are used as input and output of each step of the pipeline.
5
+ """
6
+ import json
7
+ from typing import Any, Optional, Union
8
+
9
+ from pydantic import BaseModel, Field
10
+
11
+
12
class RedirectInput(BaseModel):
    # Request payload to (re)direct an existing job to a pipeline.
    pipeline_arn: str  # ARN of the pipeline to invoke — presumably AWS; confirm against caller
    job_id: str
15
+
16
+
17
class NutrientTableQuantity(BaseModel):
    # Quantity cell of a nutrient-table row (amount + unit + precision).
    measurementUnitCode: str
    value: str
    precisionCode: str

    def __str__(self):
        # Rendered as "<precision> <value> <unit>".
        return f"{self.precisionCode} {self.value} {self.measurementUnitCode}"
24
+
25
+
26
class NutrientTableDailyValueIntake(BaseModel):
    # Daily-value intake percentage of a nutrient-table row.
    value: str
    precisionCode: str

    def __str__(self):
        # Rendered as "<precision> <value>%".
        return f"{self.precisionCode} {self.value}%"
32
+
33
+
34
def s(
    text: Optional[Union[NutrientTableQuantity, NutrientTableDailyValueIntake, str]]
) -> str:
    """Stringify *text*, substituting "?" for None (and other falsy values)."""
    return str(text) if text else "?"
44
+
45
+
46
class NutrientTableElement(BaseModel):
    # One nutrient row: type code, contained quantity, optional daily-value %.
    coordinates: str  # NOTE(review): serialized location on the image — exact format not shown here; confirm
    probability: float
    nutrientTypeCode: Optional[str]
    quantityContained: NutrientTableQuantity
    dailyValueIntakePercent: Optional[NutrientTableDailyValueIntake]
    precisionCode: str

    def __str__(self):
        # "<type> <quantity> (<daily value>)", with "?" for missing parts (see s()).
        return " ".join(
            [
                s(self.nutrientTypeCode),
                s(self.quantityContained),
                f"({s(self.dailyValueIntakePercent)})",
            ]
        )
62
+
63
+
64
class NutrientTable(BaseModel):
    # A whole nutrient table: per-basis header plus one element per row.
    nutrientBasisQuantityValue: Optional[str]
    nutrientBasisQuantityMeasurementUnitCode: Optional[str]
    # NOTE(review): "preperation" is a misspelling of "preparation", but it is
    # part of the serialized schema — renaming would break consumers.
    preperationStateCode: Optional[str]
    values: list[NutrientTableElement]

    def __str__(self):
        # Header line, then one tab-indented line per nutrient element.
        top = "Nutrients per " + " ".join(
            [
                s(self.nutrientBasisQuantityValue),
                s(self.nutrientBasisQuantityMeasurementUnitCode),
                f"({s(self.preperationStateCode)})",
            ]
        )
        vals = "\n\t".join([str(v) for v in self.values])
        return f"{top}\n\t{vals}"
80
+
81
+
82
class Attribute(BaseModel):
    # A single extracted product attribute located on the image.
    coordinates: str
    entity: str  # attribute/entity label this value was classified as
    probability: float
    value: Union[str, list[NutrientTable]]  # plain text, or parsed nutrient tables
    model: str  # identifier of the model that produced this attribute
88
+
89
+
90
class AttributeCommunicationChannel(BaseModel):
    # A detected communication-channel attribute (code + value) on the image.
    coordinates: str
    probability: float
    model: str
    entity: str
    communicationChannelCode: str
    communicationValue: str
97
+
98
+
99
class AttributeAllergen(BaseModel):
    # A detected allergen attribute (type + containment level) on the image.
    coordinates: str
    probability: float
    model: str
    entity: str
    allergenTypeCode: str
    levelOfContainmentCode: str
106
+
107
+
108
class NetContentAttribute(BaseModel):
    # A detected net-content attribute (value + measurement unit) on the image.
    coordinates: str
    probability: float
    model: str
    entity: str
    measurementUnitCode: str
    value: str
115
+
116
+
117
class AllergensOut(BaseModel):
    # All allergen attributes found by one model, grouped under one entity.
    entity: str
    values: list[AttributeAllergen]
    model: str
121
+
122
+
123
class CommunicationChannelsOut(BaseModel):
    # All communication-channel attributes found by one model, grouped together.
    entity: str
    values: list[AttributeCommunicationChannel]
    model: str
127
+
128
+
129
class PipelineInput(BaseModel):
    # Entry point of the pipeline: the key of the image to process
    # (presumably an object-store key; confirm against the caller).
    image_key: str
131
+
132
+
133
class PipelineOutput(BaseModel):
    # Final pipeline result: every extracted attribute plus the OCR full text.
    attributes: list[
        Union[Attribute, CommunicationChannelsOut, AllergensOut, NetContentAttribute]
    ]
    # Serialized as "job-id"; allow_population_by_field_name below also lets
    # callers construct the model with job_id=... directly.
    job_id: str = Field(alias="job-id")
    text: str

    class Config:
        allow_population_by_field_name = True
142
+
143
+
144
class TextWithLanguage(BaseModel):
    # A piece of text plus its detected language code
    # ("unk" is used elsewhere when detection failed).
    text: str
    lang_code: str
147
+
148
+
149
class OCRTextOut(BaseModel):
    # Text-mode OCR result: per-block texts, the concatenated full text and
    # language-tagged sentences.
    blocks: list[str]
    full_text: str
    sentences: list[TextWithLanguage]
153
+
154
+
155
class OCRTableOut(BaseModel):
    # Table-mode OCR result; nesting is presumably [table][row][cell] — confirm.
    tables: list[list[list[str]]]
157
+
158
+
159
class OCROut(BaseModel):
    # One OCR result (text or table variant) tagged with its job.
    result: Union[OCRTextOut, OCRTableOut]
    job_id: str
162
+
163
+
164
class OCROutList(BaseModel):
    # Pydantic v1 "custom root type": behaves like a plain list of OCROut.
    __root__: list[OCROut]

    def __iter__(self):
        # Iterate directly over the wrapped list.
        return iter(self.__root__)

    def __getitem__(self, item):
        # Index/slice into the wrapped list.
        return self.__root__[item]
172
+
173
+
174
class OCRWrapperOut(BaseModel):
    # Flattened OCR result: the text fields of OCRTextOut plus the tables of
    # OCRTableOut, tagged with the originating job.
    blocks: list[str]
    full_text: str
    job_id: str
    sentences: list[TextWithLanguage]
    tables: list[list[list[str]]]
180
+
181
+
182
class ClassifiedText(BaseModel):
    # A span of OCR text labelled with an attribute and model confidence.
    text: str
    attribute: str
    confidence: float
186
+
187
+
188
class CommunicationChannels(BaseModel):
    # Classifier output: a communication-channel reference found in the text.
    confidence: float
    attribute: str
    communicationChannelCode: str
    communicationValue: str
    text: Optional[str] = ""  # source text span, when available
194
+
195
+
196
class Allergen(BaseModel):
    # Classifier output: an allergen statement found in the text.
    confidence: float
    attribute: str
    allergenTypeCode: str
    levelOfContainmentCode: str
    text: Optional[str] = ""  # source text span, when available
202
+
203
+
204
class NetContent(BaseModel):
    # Classifier output: a net-content (value + unit) statement found in the text.
    confidence: float
    attribute: str
    measurementUnitCode: str
    value: str
    text: Optional[str] = ""  # source text span, when available
210
+
211
+
212
class ModelOut(BaseModel):
    # Output of one extraction-model run over one OCR'd image.
    blocks: list[Union[NetContent, Allergen, CommunicationChannels, ClassifiedText]]
    tables: Optional[list[NutrientTable]]
    job_id: str
    model: str  # identifier of the model that produced this output
    full_text: str

    def toJSON(self):
        # NOTE(review): serializing via o.__dict__ works for pydantic v1
        # models (field values live in __dict__) but bypasses pydantic's own
        # encoders; consider self.json(sort_keys=True) instead — confirm
        # consumers before changing the output shape.
        return json.dumps(self, default=lambda o: o.__dict__, sort_keys=True)
221
+
222
+
223
class ModelOutList(BaseModel):
    # Pydantic v1 "custom root type": behaves like a plain list of ModelOut.
    __root__: list[ModelOut]

    def __iter__(self):
        # Iterate directly over the wrapped list.
        return iter(self.__root__)

    def __getitem__(self, item):
        # Index/slice into the wrapped list.
        return self.__root__[item]
231
+
232
+
233
class TrainModelOut(BaseModel):
    # Output of a training run; both fields are placeholders until the list
    # of accepted model/artifact formats is decided.
    # To be defined later when we have a list of accepted formats
    model: Optional[Any] = None
    artifacts: Optional[Any] = None
base/ocr.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Custom types for dealing with the Google Vision API JSON output.
3
+ """
4
+ from enum import IntEnum
5
+ from typing import Any, Optional
6
+
7
+ from pydantic import BaseModel
8
+
9
+
10
class BreakType(IntEnum):
    """Kind of detected text break after a symbol.

    Numeric values mirror the Vision API's
    ``TextAnnotation.DetectedBreak.BreakType`` enum.
    """
    UNKNOWN = 0
    SPACE = 1
    SURE_SPACE = 2
    EOL_SURE_SPACE = 3
    LINE_BREAK = 4
    HYPHEN = 5
17
+
18
+
19
class BlockType(IntEnum):
    """Layout kind of a detected block.

    Numeric values mirror the Vision API's ``Block.BlockType`` enum.
    """
    UNKNOWN = 0
    TEXT = 1
    TABLE = 2
    PICTURE = 3
    RULER = 4
    BARCODE = 5
26
+
27
+
28
class DetectedBreak(BaseModel):
    # A break (space/line break/hyphen) detected around a symbol.
    type: BreakType
    is_prefix: Optional[bool]  # presumably True when the break precedes the element — confirm with API docs
31
+
32
+
33
class DetectedLanguage(BaseModel):
    # A language hypothesis for a piece of text, with its confidence score.
    languageCode: str
    confidence: float
36
+
37
+
38
class TextProperty(BaseModel):
    # Language/break metadata attached to pages, blocks, words and symbols.
    detectedLanguages: list[DetectedLanguage]
    detectedBreak: Optional[DetectedBreak]
41
+
42
+
43
class Symbol(BaseModel):
    # A single recognized glyph/character.
    property: Optional[TextProperty]
    boundingBox: Any  # raw bounding-poly JSON; not modelled further here
    text: str
    confidence: float
48
+
49
+
50
class Word(BaseModel):
    # A word: an ordered list of symbols.
    property: Optional[TextProperty]
    boundingBox: Any  # raw bounding-poly JSON; not modelled further here
    symbols: list[Symbol]
    confidence: float
55
+
56
+
57
class Paragraph(BaseModel):
    # A paragraph: an ordered list of words.
    property: Optional[TextProperty]
    boundingBox: Any  # raw bounding-poly JSON; not modelled further here
    words: list[Word]
    confidence: float
62
+
63
+
64
class Block(BaseModel):
    # A layout block on a page: paragraphs plus the block's layout kind.
    property: Optional[TextProperty]
    boundingBox: Any  # raw bounding-poly JSON; not modelled further here
    paragraphs: list[Paragraph]
    blockType: BlockType
    confidence: float
70
+
71
+
72
class Page(BaseModel):
    # One page of the document/image, with pixel dimensions and its blocks.
    property: Optional[TextProperty]
    width: int
    height: int
    blocks: list[Block]
    confidence: float
78
+
79
+
80
class TextAnnotation(BaseModel):
    # Full structured OCR result: all pages plus the concatenated text.
    pages: list[Page]
    text: str
83
+
84
+
85
class Output(BaseModel):
    # Top-level wrapper for a Vision API response; fullTextAnnotation stays
    # None when the response carries no full-text annotation.
    fullTextAnnotation: Optional[TextAnnotation] = None
data_models.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Pydantic models used throughout the codebase.
3
+
4
+ In particular, these are the types that are used as input and output of each step of the pipeline.
5
+ """
6
+ import json
7
+ from typing import Any, Optional, Union
8
+
9
+ from pydantic import BaseModel, Field
10
+
11
+
12
+ class RedirectInput(BaseModel):
13
+ pipeline_arn: str
14
+ job_id: str
15
+
16
+
17
+ class NutrientTableQuantity(BaseModel):
18
+ measurementUnitCode: str
19
+ value: str
20
+ precisionCode: str
21
+
22
+ def __str__(self):
23
+ return f"{self.precisionCode} {self.value} {self.measurementUnitCode}"
24
+
25
+
26
+ class NutrientTableDailyValueIntake(BaseModel):
27
+ value: str
28
+ precisionCode: str
29
+
30
+ def __str__(self):
31
+ return f"{self.precisionCode} {self.value}%"
32
+
33
+
34
def s(
    text: Optional[Union[NutrientTableQuantity, NutrientTableDailyValueIntake, str]]
) -> str:
    """Stringify *text*, substituting "?" for None (and other falsy values)."""
    return str(text) if text else "?"
44
+
45
+
46
+ class NutrientTableElement(BaseModel):
47
+ coordinates: str
48
+ probability: float
49
+ nutrientTypeCode: Optional[str]
50
+ quantityContained: NutrientTableQuantity
51
+ dailyValueIntakePercent: Optional[NutrientTableDailyValueIntake]
52
+ precisionCode: str
53
+
54
+ def __str__(self):
55
+ return " ".join(
56
+ [
57
+ s(self.nutrientTypeCode),
58
+ s(self.quantityContained),
59
+ f"({s(self.dailyValueIntakePercent)})",
60
+ ]
61
+ )
62
+
63
+
64
+ class NutrientTable(BaseModel):
65
+ nutrientBasisQuantityValue: Optional[str]
66
+ nutrientBasisQuantityMeasurementUnitCode: Optional[str]
67
+ preperationStateCode: Optional[str]
68
+ values: list[NutrientTableElement]
69
+
70
+ def __str__(self):
71
+ top = "Nutrients per " + " ".join(
72
+ [
73
+ s(self.nutrientBasisQuantityValue),
74
+ s(self.nutrientBasisQuantityMeasurementUnitCode),
75
+ f"({s(self.preperationStateCode)})",
76
+ ]
77
+ )
78
+ vals = "\n\t".join([str(v) for v in self.values])
79
+ return f"{top}\n\t{vals}"
80
+
81
+
82
+ class Attribute(BaseModel):
83
+ coordinates: str
84
+ entity: str
85
+ probability: float
86
+ value: Union[str, list[NutrientTable]]
87
+ model: str
88
+
89
+
90
+ class AttributeCommunicationChannel(BaseModel):
91
+ coordinates: str
92
+ probability: float
93
+ model: str
94
+ entity: str
95
+ communicationChannelCode: str
96
+ communicationValue: str
97
+
98
+
99
+ class AttributeAllergen(BaseModel):
100
+ coordinates: str
101
+ probability: float
102
+ model: str
103
+ entity: str
104
+ allergenTypeCode: str
105
+ levelOfContainmentCode: str
106
+
107
+
108
+ class NetContentAttribute(BaseModel):
109
+ coordinates: str
110
+ probability: float
111
+ model: str
112
+ entity: str
113
+ measurementUnitCode: str
114
+ value: str
115
+
116
+
117
+ class AllergensOut(BaseModel):
118
+ entity: str
119
+ values: list[AttributeAllergen]
120
+ model: str
121
+
122
+
123
+ class CommunicationChannelsOut(BaseModel):
124
+ entity: str
125
+ values: list[AttributeCommunicationChannel]
126
+ model: str
127
+
128
+
129
+ class PipelineInput(BaseModel):
130
+ image_key: str
131
+
132
+
133
+ class PipelineOutput(BaseModel):
134
+ attributes: list[
135
+ Union[Attribute, CommunicationChannelsOut, AllergensOut, NetContentAttribute]
136
+ ]
137
+ job_id: str = Field(alias="job-id")
138
+ text: str
139
+
140
+ class Config:
141
+ allow_population_by_field_name = True
142
+
143
+
144
+ class TextWithLanguage(BaseModel):
145
+ text: str
146
+ lang_code: str
147
+
148
+
149
+ class OCRTextOut(BaseModel):
150
+ blocks: list[str]
151
+ full_text: str
152
+ sentences: list[TextWithLanguage]
153
+
154
+
155
+ class OCRTableOut(BaseModel):
156
+ tables: list[list[list[str]]]
157
+
158
+
159
+ class OCROut(BaseModel):
160
+ result: Union[OCRTextOut, OCRTableOut]
161
+ job_id: str
162
+
163
+
164
+ class OCROutList(BaseModel):
165
+ __root__: list[OCROut]
166
+
167
+ def __iter__(self):
168
+ return iter(self.__root__)
169
+
170
+ def __getitem__(self, item):
171
+ return self.__root__[item]
172
+
173
+
174
+ class OCRWrapperOut(BaseModel):
175
+ blocks: list[str]
176
+ full_text: str
177
+ job_id: str
178
+ sentences: list[TextWithLanguage]
179
+ tables: list[list[list[str]]]
180
+
181
+
182
+ class ClassifiedText(BaseModel):
183
+ text: str
184
+ attribute: str
185
+ confidence: float
186
+
187
+
188
+ class CommunicationChannels(BaseModel):
189
+ confidence: float
190
+ attribute: str
191
+ communicationChannelCode: str
192
+ communicationValue: str
193
+ text: Optional[str] = ""
194
+
195
+
196
+ class Allergen(BaseModel):
197
+ confidence: float
198
+ attribute: str
199
+ allergenTypeCode: str
200
+ levelOfContainmentCode: str
201
+ text: Optional[str] = ""
202
+
203
+
204
+ class NetContent(BaseModel):
205
+ confidence: float
206
+ attribute: str
207
+ measurementUnitCode: str
208
+ value: str
209
+ text: Optional[str] = ""
210
+
211
+
212
+ class ModelOut(BaseModel):
213
+ blocks: list[Union[NetContent, Allergen, CommunicationChannels, ClassifiedText]]
214
+ tables: Optional[list[NutrientTable]]
215
+ job_id: str
216
+ model: str
217
+ full_text: str
218
+
219
+ def toJSON(self):
220
+ return json.dumps(self, default=lambda o: o.__dict__, sort_keys=True)
221
+
222
+
223
+ class ModelOutList(BaseModel):
224
+ __root__: list[ModelOut]
225
+
226
+ def __iter__(self):
227
+ return iter(self.__root__)
228
+
229
+ def __getitem__(self, item):
230
+ return self.__root__[item]
231
+
232
+
233
+ class TrainModelOut(BaseModel):
234
+ # To be defined later when we have a list of accepted formats
235
+ model: Optional[Any] = None
236
+ artifacts: Optional[Any] = None
train_classifiers.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
utils.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import io
import json
import os
from typing import Any, Optional, Union

from google.cloud import vision
from langdetect import detect
from langdetect.lang_detect_exception import LangDetectException
from nltk.tokenize.punkt import PunktParameters, PunktSentenceTokenizer

from base import OCROut, OCRTextOut, TextWithLanguage, ocr
10
# SECURITY: a full GCP service-account JSON — including its private key — used
# to be hard-coded in this file. Credentials committed to source control must
# be considered compromised: the old key has to be revoked and rotated in the
# GCP console. The credentials are now read from the environment instead.
gcp_credentials: str = os.environ.get("GCP_CREDENTIALS_JSON", "")

# Build the Vision API client once at import time so every OCR call reuses it.
if gcp_credentials:
    creds = json.loads(gcp_credentials)
    google_client = vision.ImageAnnotatorClient.from_service_account_info(creds)
else:
    # Fall back to Application Default Credentials
    # (e.g. GOOGLE_APPLICATION_CREDENTIALS pointing at a key file).
    creds = None
    google_client = vision.ImageAnnotatorClient()
14
+
15
+
16
+ def image_to_byte_array(image) -> bytes:
17
+ imgByteArr = io.BytesIO()
18
+ image.save(imgByteArr, format="png")
19
+ image_buffer: bytes = imgByteArr.getvalue()
20
+ return image_buffer
21
+
22
def run_image_ocr(image_bytes: bytes, google_client) -> "tuple[ocr.Output, dict]":
    """Run Google Vision text detection on a PNG image.

    Returns both the parsed ``ocr.Output`` model and the raw response JSON.
    (The previous ``-> ocr.Output`` annotation was wrong: the function has
    always returned a 2-tuple.)
    """
    # Send the image to the Vision API through the shared, pre-built client.
    google_response = google_client.text_detection(
        image=vision.Image(content=image_bytes)
    )

    # Convert the protobuf response into a plain dict so it can both be
    # validated by the pydantic ocr.Output model and returned verbatim.
    google_response_json = json.loads(
        vision.AnnotateImageResponse.to_json(google_response)
    )
    return ocr.Output(**google_response_json), google_response_json
36
+
37
def get_block_language(block: "ocr.Block") -> str:
    """Return the most confident language code detected in *block*.

    Returns "unk" when the block has no text property or no detected
    languages at all.
    """
    # The original wrapped the max() call in a broad `except Exception`,
    # which silently turned any real bug (e.g. a typo'd attribute) into
    # "unk"; an explicit empty-guard covers the only expected failure mode.
    if not block.property or not block.property.detectedLanguages:
        return "unk"
    best = max(block.property.detectedLanguages, key=lambda lang: lang.confidence)
    return best.languageCode
46
+
47
def block2text(block: ocr.Block) -> TextWithLanguage:
    """Flatten one OCR block into plain text tagged with its dominant language."""
    pieces = []
    for paragraph in block.paragraphs:
        for word in paragraph.words:
            for symbol in word.symbols:
                pieces.append(symbol.text)
                # The Vision API marks breaks on the symbol; render every
                # detected break as a single space.
                if symbol.property and symbol.property.detectedBreak:
                    pieces.append(" ")

    return TextWithLanguage(
        text="".join(pieces).strip(),
        lang_code=get_block_language(block),
    )
62
+
63
def get_sentences(blocks: list[TextWithLanguage]) -> list[TextWithLanguage]:
    """Split each block's text into sentences, tagging each with a language.

    A Punkt tokenizer is primed with common (largely Dutch) abbreviations so
    that e.g. "o.a." does not terminate a sentence.
    """
    punkt_param = PunktParameters()
    punkt_param.abbrev_types = {
        "dr", "vs", "mr", "mrs", "prof", "inc", "vit",
        "o.a.", "o.b.v.", "s.p.", "m.u.v.", "i.v.m.",
        "a.k.a.", "e.g.", "m.b.v.", "max.", "min.",
    }
    sentence_splitter = PunktSentenceTokenizer(punkt_param)

    tagged = []
    for block in blocks:
        for sentence in sentence_splitter.tokenize(block.text):
            sentence_text = str(sentence)
            try:
                lang_code = detect(sentence_text.lower())
            except LangDetectException:
                lang_code = "unk"
            else:
                # With so little context langdetect often mistakes Dutch for
                # Afrikaans; treat Afrikaans as Dutch here.
                if lang_code == "af":
                    lang_code = "nl"
            tagged.append(TextWithLanguage(text=sentence_text, lang_code=lang_code))

    return tagged
104
+
105
def run_ocr(image_bytes: bytes) -> "tuple[OCRTextOut, Any]":
    """OCR *image_bytes* and return (OCRTextOut, raw Vision response JSON).

    (The previous ``-> Union[OCROut, Any]`` annotation was wrong: the
    function always returns a 2-tuple whose first element is OCRTextOut.)
    """
    # API response is of the type AnnotateImageResponse, see
    # https://cloud.google.com/vision/docs/reference/rest/v1/AnnotateImageResponse
    # for more details.

    ocr_image_annotation, response_json = run_image_ocr(image_bytes, google_client)

    # We assume we will only process pictures of one page,
    # and no documents of more than one page. Hence, we
    # take the first page here.

    # Check if the fullTextAnnotation is filled (it stays None when the
    # response carried no full-text annotation).
    if ocr_image_annotation.fullTextAnnotation:
        ocr_blocks = ocr_image_annotation.fullTextAnnotation.pages[0].blocks
        text = ocr_image_annotation.fullTextAnnotation.text

        blocks = [block2text(block) for block in ocr_blocks]
        block_texts = [block.text for block in blocks]
        sentences = get_sentences(blocks)
    else:
        # No text detected: return empty-but-valid placeholders.
        block_texts = [""]
        text = ""
        sentences = [TextWithLanguage(text="", lang_code="")]
    ocr_text_out = OCRTextOut(blocks=block_texts, full_text=text, sentences=sentences)
    return ocr_text_out, response_json
woc-logo-black.1a4c4e90.svg ADDED