Spaces:
Build error
Build error
Added first scan logic
Browse files- app.py +35 -4
- base/__init__.py +30 -0
- base/__init__.py.bak +44 -0
- base/data_models.py +236 -0
- base/ocr.py +86 -0
- data_models.py +236 -0
- train_classifiers.ipynb +0 -0
- utils.py +129 -0
- woc-logo-black.1a4c4e90.svg +38 -0
app.py
CHANGED
|
@@ -1,7 +1,38 @@
|
|
| 1 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
-
def
|
| 4 |
-
|
| 5 |
|
| 6 |
-
|
| 7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
+
import json
|
| 3 |
+
from PIL import Image
|
| 4 |
+
from utils import run_ocr, image_to_byte_array
|
| 5 |
|
| 6 |
+
def extract_all(image):
    """Run OCR on an uploaded image and return (attributes_json, full_text).

    Parameters
    ----------
    image : numpy.ndarray
        Image array as delivered by the Gradio "image" input component.

    Returns
    -------
    tuple[str, str]
        A JSON-formatted string of extracted product attributes, and the
        flattened OCR text (newlines replaced by spaces).
    """
    image = Image.fromarray(image)

    image_byte_array = image_to_byte_array(image)
    ocr_text_out, response_json = run_ocr(image_byte_array)
    # Flatten the OCR text onto one line for textbox display.
    ocr_full_text = str(ocr_text_out.full_text).replace('\n', ' ')

    # Extract attributes
    # TODO(review): real field extraction is stubbed out; the dictionary below
    # is a hard-coded placeholder demo payload.
    #output_dictionary, product_description = exctract_fields(ocr_full_text)
    output_dictionary = {
        "GPC": 10000045,
        "BRAND": "Kitkat",
        "FUNCTIONAL NAME": "Chocolate Bar",
        "Weight": "41.5",
        "Unit": "Gr",
        "Contact Information": {"Website": "www.kitkat.com", "Adress": " Nestlé Deutschland AG, 60523 \n Frankfurt am Main, Germany"},
        "Allergen_Statement_NL": "Bevat: MELK, TARWE",
    }

    # json.dumps already returns a str; the former str() wrappers (here and in
    # the return statement) were redundant.
    output_dict_json = json.dumps(output_dictionary, indent=2)
    return output_dict_json, ocr_full_text
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# --- Gradio UI wiring (module-level script) ---

# NOTE(review): `output_df` is never used below — presumably kept for a future
# tabular output view; confirm or remove.
output_df = [gr.Dataframe(label="Data")]

attributes_tbox = gr.Textbox(label='Attributes')
#product_description_tbox = gr.Textbox(label='Product Description')
ocr_output_tbox = gr.Textbox(label='OCR Output')

# NOTE(review): the original code called `gr.Button.style("{color: blue}")` on
# the Button *class* with a CSS string; `style()` expects keyword arguments on
# an instance, so that call had no styling effect and was removed. Button
# colors are handled by the `css` argument below.
gr.Interface(fn=extract_all, inputs="image", outputs=[attributes_tbox, ocr_output_tbox],
             title= 'World of Content', description="GS1 Global Extractor",
             css="body {background-color: #F5F7FA} .gr-button.gr-button-primary {background-color: #0080fa; color:white; --tw-gradient-from:0} .gr-button.gr-button-secondary {background-color: #172533; color:white; --tw-gradient-from:0} h1 {background-image: url('file=woc-logo-black.1a4c4e90.svg'); background-size:contain; background-repeat:no-repeat; background-position:center; text-indent:-999999999px} .output-markdown p{text-align:center; font-size:24px}").launch()
|
base/__init__.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Public re-exports for the `base` package: the shared pydantic data models
# plus the Google Vision OCR response types, so callers can simply write
# `from base import Attribute, OCROut, ...`.
#
# NOTE(review): this previously read `from data_models import ...`, which only
# resolved because a duplicate copy of base/data_models.py sits at the
# repository root. Importing from `base.data_models` keeps the package
# self-contained and consistent with the `base.ocr` import below.
from base.data_models import (  # noqa: F401
    Allergen,
    AllergensOut,
    Attribute,
    AttributeAllergen,
    AttributeCommunicationChannel,
    ClassifiedText,
    CommunicationChannels,
    CommunicationChannelsOut,
    ModelOut,
    ModelOutList,
    NetContent,
    NetContentAttribute,
    NutrientTable,
    NutrientTableDailyValueIntake,
    NutrientTableElement,
    NutrientTableQuantity,
    OCROut,
    OCROutList,
    OCRTableOut,
    OCRTextOut,
    OCRWrapperOut,
    PipelineInput,
    PipelineOutput,
    RedirectInput,
    TextWithLanguage,
    TrainModelOut,
)

from base.ocr import *  # noqa: F401,F403
|
base/__init__.py.bak
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# isort: skip_file
|
| 2 |
+
from kedro.extras.datasets.json import JSONDataSet # noqa 401
|
| 3 |
+
from kedro.extras.datasets.pandas import CSVDataSet # noqa 401
|
| 4 |
+
from kedro.extras.datasets.pandas import ParquetDataSet # noqa 401
|
| 5 |
+
from kedro.extras.datasets.pandas import ExcelDataSet as excel_dataset # noqa 401
|
| 6 |
+
from kedro.extras.datasets.pandas import JSONDataSet as pandas_json # noqa 401
|
| 7 |
+
from kedro.extras.datasets.pandas.parquet_dataset import ParquetDataSet # noqa 401
|
| 8 |
+
from kedro.extras.datasets.pillow import ImageDataSet # noqa 401
|
| 9 |
+
from kedro.extras.datasets.text import TextDataSet # noqa 401
|
| 10 |
+
from kedro.io.core import AbstractDataSet as DataSet # noqa 401
|
| 11 |
+
|
| 12 |
+
from certifai.base.data_models import ( # noqa 401, 403
|
| 13 |
+
Allergen,
|
| 14 |
+
AllergensOut,
|
| 15 |
+
Attribute,
|
| 16 |
+
AttributeAllergen,
|
| 17 |
+
AttributeCommunicationChannel,
|
| 18 |
+
ClassifiedText,
|
| 19 |
+
CommunicationChannels,
|
| 20 |
+
CommunicationChannelsOut,
|
| 21 |
+
ModelOut,
|
| 22 |
+
ModelOutList,
|
| 23 |
+
NetContent,
|
| 24 |
+
NetContentAttribute,
|
| 25 |
+
NutrientTable,
|
| 26 |
+
NutrientTableDailyValueIntake,
|
| 27 |
+
NutrientTableElement,
|
| 28 |
+
NutrientTableQuantity,
|
| 29 |
+
OCROut,
|
| 30 |
+
OCROutList,
|
| 31 |
+
OCRTableOut,
|
| 32 |
+
OCRTextOut,
|
| 33 |
+
OCRWrapperOut,
|
| 34 |
+
PipelineInput,
|
| 35 |
+
PipelineOutput,
|
| 36 |
+
RedirectInput,
|
| 37 |
+
TextWithLanguage,
|
| 38 |
+
TrainModelOut,
|
| 39 |
+
)
|
| 40 |
+
from certifai.base.abstract import BaseClassifier # noqa 401, 403
|
| 41 |
+
from certifai.base.s3_helper_functions import * # noqa 401, 403
|
| 42 |
+
from certifai.base.custom_datasets import * # noqa 401, 403
|
| 43 |
+
from certifai.base.ocr import * # noqa 401, 403
|
| 44 |
+
from kedro.extras.datasets.yaml import YAMLDataSet # noqa 401
|
base/data_models.py
ADDED
|
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Pydantic models used throughout the codebase.
|
| 3 |
+
|
| 4 |
+
In particular, these are the types that are used as input and output of each step of the pipeline.
|
| 5 |
+
"""
|
| 6 |
+
import json
|
| 7 |
+
from typing import Any, Optional, Union
|
| 8 |
+
|
| 9 |
+
from pydantic import BaseModel, Field
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class RedirectInput(BaseModel):
|
| 13 |
+
pipeline_arn: str
|
| 14 |
+
job_id: str
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class NutrientTableQuantity(BaseModel):
|
| 18 |
+
measurementUnitCode: str
|
| 19 |
+
value: str
|
| 20 |
+
precisionCode: str
|
| 21 |
+
|
| 22 |
+
def __str__(self):
|
| 23 |
+
return f"{self.precisionCode} {self.value} {self.measurementUnitCode}"
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class NutrientTableDailyValueIntake(BaseModel):
|
| 27 |
+
value: str
|
| 28 |
+
precisionCode: str
|
| 29 |
+
|
| 30 |
+
def __str__(self):
|
| 31 |
+
return f"{self.precisionCode} {self.value}%"
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def s(
    text: Optional[Union[NutrientTableQuantity, NutrientTableDailyValueIntake, str]]
) -> str:
    """
    Returns None as "?", otherwise argument as string.

    Note: any falsy argument (None, empty string) maps to "?".
    """
    return str(text) if text else "?"
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class NutrientTableElement(BaseModel):
|
| 47 |
+
coordinates: str
|
| 48 |
+
probability: float
|
| 49 |
+
nutrientTypeCode: Optional[str]
|
| 50 |
+
quantityContained: NutrientTableQuantity
|
| 51 |
+
dailyValueIntakePercent: Optional[NutrientTableDailyValueIntake]
|
| 52 |
+
precisionCode: str
|
| 53 |
+
|
| 54 |
+
def __str__(self):
|
| 55 |
+
return " ".join(
|
| 56 |
+
[
|
| 57 |
+
s(self.nutrientTypeCode),
|
| 58 |
+
s(self.quantityContained),
|
| 59 |
+
f"({s(self.dailyValueIntakePercent)})",
|
| 60 |
+
]
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class NutrientTable(BaseModel):
|
| 65 |
+
nutrientBasisQuantityValue: Optional[str]
|
| 66 |
+
nutrientBasisQuantityMeasurementUnitCode: Optional[str]
|
| 67 |
+
preperationStateCode: Optional[str]
|
| 68 |
+
values: list[NutrientTableElement]
|
| 69 |
+
|
| 70 |
+
def __str__(self):
|
| 71 |
+
top = "Nutrients per " + " ".join(
|
| 72 |
+
[
|
| 73 |
+
s(self.nutrientBasisQuantityValue),
|
| 74 |
+
s(self.nutrientBasisQuantityMeasurementUnitCode),
|
| 75 |
+
f"({s(self.preperationStateCode)})",
|
| 76 |
+
]
|
| 77 |
+
)
|
| 78 |
+
vals = "\n\t".join([str(v) for v in self.values])
|
| 79 |
+
return f"{top}\n\t{vals}"
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class Attribute(BaseModel):
|
| 83 |
+
coordinates: str
|
| 84 |
+
entity: str
|
| 85 |
+
probability: float
|
| 86 |
+
value: Union[str, list[NutrientTable]]
|
| 87 |
+
model: str
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class AttributeCommunicationChannel(BaseModel):
|
| 91 |
+
coordinates: str
|
| 92 |
+
probability: float
|
| 93 |
+
model: str
|
| 94 |
+
entity: str
|
| 95 |
+
communicationChannelCode: str
|
| 96 |
+
communicationValue: str
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
class AttributeAllergen(BaseModel):
|
| 100 |
+
coordinates: str
|
| 101 |
+
probability: float
|
| 102 |
+
model: str
|
| 103 |
+
entity: str
|
| 104 |
+
allergenTypeCode: str
|
| 105 |
+
levelOfContainmentCode: str
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
class NetContentAttribute(BaseModel):
|
| 109 |
+
coordinates: str
|
| 110 |
+
probability: float
|
| 111 |
+
model: str
|
| 112 |
+
entity: str
|
| 113 |
+
measurementUnitCode: str
|
| 114 |
+
value: str
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
class AllergensOut(BaseModel):
|
| 118 |
+
entity: str
|
| 119 |
+
values: list[AttributeAllergen]
|
| 120 |
+
model: str
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
class CommunicationChannelsOut(BaseModel):
|
| 124 |
+
entity: str
|
| 125 |
+
values: list[AttributeCommunicationChannel]
|
| 126 |
+
model: str
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
class PipelineInput(BaseModel):
|
| 130 |
+
image_key: str
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
class PipelineOutput(BaseModel):
|
| 134 |
+
attributes: list[
|
| 135 |
+
Union[Attribute, CommunicationChannelsOut, AllergensOut, NetContentAttribute]
|
| 136 |
+
]
|
| 137 |
+
job_id: str = Field(alias="job-id")
|
| 138 |
+
text: str
|
| 139 |
+
|
| 140 |
+
class Config:
|
| 141 |
+
allow_population_by_field_name = True
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
class TextWithLanguage(BaseModel):
|
| 145 |
+
text: str
|
| 146 |
+
lang_code: str
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
class OCRTextOut(BaseModel):
|
| 150 |
+
blocks: list[str]
|
| 151 |
+
full_text: str
|
| 152 |
+
sentences: list[TextWithLanguage]
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
class OCRTableOut(BaseModel):
|
| 156 |
+
tables: list[list[list[str]]]
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
class OCROut(BaseModel):
|
| 160 |
+
result: Union[OCRTextOut, OCRTableOut]
|
| 161 |
+
job_id: str
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
class OCROutList(BaseModel):
|
| 165 |
+
__root__: list[OCROut]
|
| 166 |
+
|
| 167 |
+
def __iter__(self):
|
| 168 |
+
return iter(self.__root__)
|
| 169 |
+
|
| 170 |
+
def __getitem__(self, item):
|
| 171 |
+
return self.__root__[item]
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
class OCRWrapperOut(BaseModel):
|
| 175 |
+
blocks: list[str]
|
| 176 |
+
full_text: str
|
| 177 |
+
job_id: str
|
| 178 |
+
sentences: list[TextWithLanguage]
|
| 179 |
+
tables: list[list[list[str]]]
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
class ClassifiedText(BaseModel):
|
| 183 |
+
text: str
|
| 184 |
+
attribute: str
|
| 185 |
+
confidence: float
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
class CommunicationChannels(BaseModel):
|
| 189 |
+
confidence: float
|
| 190 |
+
attribute: str
|
| 191 |
+
communicationChannelCode: str
|
| 192 |
+
communicationValue: str
|
| 193 |
+
text: Optional[str] = ""
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
class Allergen(BaseModel):
|
| 197 |
+
confidence: float
|
| 198 |
+
attribute: str
|
| 199 |
+
allergenTypeCode: str
|
| 200 |
+
levelOfContainmentCode: str
|
| 201 |
+
text: Optional[str] = ""
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
class NetContent(BaseModel):
|
| 205 |
+
confidence: float
|
| 206 |
+
attribute: str
|
| 207 |
+
measurementUnitCode: str
|
| 208 |
+
value: str
|
| 209 |
+
text: Optional[str] = ""
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
class ModelOut(BaseModel):
    """A single classifier model's full output for one OCR job."""

    # Classified text fragments; heterogeneous because each attribute family
    # has its own output shape.
    blocks: list[Union[NetContent, Allergen, CommunicationChannels, ClassifiedText]]
    tables: Optional[list[NutrientTable]]  # parsed nutrient tables, if any
    job_id: str
    model: str  # name of the model that produced this output
    full_text: str

    def toJSON(self):
        # NOTE(review): serializes via __dict__ with a generic `default` hook
        # instead of pydantic's own .json(); presumably kept for the stable
        # sort_keys ordering — confirm output compatibility before replacing
        # with self.json().
        return json.dumps(self, default=lambda o: o.__dict__, sort_keys=True)
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
class ModelOutList(BaseModel):
|
| 224 |
+
__root__: list[ModelOut]
|
| 225 |
+
|
| 226 |
+
def __iter__(self):
|
| 227 |
+
return iter(self.__root__)
|
| 228 |
+
|
| 229 |
+
def __getitem__(self, item):
|
| 230 |
+
return self.__root__[item]
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
class TrainModelOut(BaseModel):
|
| 234 |
+
# To be defined later when we have a list of accepted formats
|
| 235 |
+
model: Optional[Any] = None
|
| 236 |
+
artifacts: Optional[Any] = None
|
base/ocr.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Custom types for dealing with the Google Vision API JSON output.
|
| 3 |
+
"""
|
| 4 |
+
from enum import IntEnum
|
| 5 |
+
from typing import Any, Optional
|
| 6 |
+
|
| 7 |
+
from pydantic import BaseModel
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class BreakType(IntEnum):
|
| 11 |
+
UNKNOWN = 0
|
| 12 |
+
SPACE = 1
|
| 13 |
+
SURE_SPACE = 2
|
| 14 |
+
EOL_SURE_SPACE = 3
|
| 15 |
+
LINE_BREAK = 4
|
| 16 |
+
HYPHEN = 5
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class BlockType(IntEnum):
|
| 20 |
+
UNKNOWN = 0
|
| 21 |
+
TEXT = 1
|
| 22 |
+
TABLE = 2
|
| 23 |
+
PICTURE = 3
|
| 24 |
+
RULER = 4
|
| 25 |
+
BARCODE = 5
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class DetectedBreak(BaseModel):
|
| 29 |
+
type: BreakType
|
| 30 |
+
is_prefix: Optional[bool]
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class DetectedLanguage(BaseModel):
|
| 34 |
+
languageCode: str
|
| 35 |
+
confidence: float
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class TextProperty(BaseModel):
|
| 39 |
+
detectedLanguages: list[DetectedLanguage]
|
| 40 |
+
detectedBreak: Optional[DetectedBreak]
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class Symbol(BaseModel):
|
| 44 |
+
property: Optional[TextProperty]
|
| 45 |
+
boundingBox: Any
|
| 46 |
+
text: str
|
| 47 |
+
confidence: float
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class Word(BaseModel):
|
| 51 |
+
property: Optional[TextProperty]
|
| 52 |
+
boundingBox: Any
|
| 53 |
+
symbols: list[Symbol]
|
| 54 |
+
confidence: float
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class Paragraph(BaseModel):
|
| 58 |
+
property: Optional[TextProperty]
|
| 59 |
+
boundingBox: Any
|
| 60 |
+
words: list[Word]
|
| 61 |
+
confidence: float
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class Block(BaseModel):
|
| 65 |
+
property: Optional[TextProperty]
|
| 66 |
+
boundingBox: Any
|
| 67 |
+
paragraphs: list[Paragraph]
|
| 68 |
+
blockType: BlockType
|
| 69 |
+
confidence: float
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class Page(BaseModel):
|
| 73 |
+
property: Optional[TextProperty]
|
| 74 |
+
width: int
|
| 75 |
+
height: int
|
| 76 |
+
blocks: list[Block]
|
| 77 |
+
confidence: float
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class TextAnnotation(BaseModel):
|
| 81 |
+
pages: list[Page]
|
| 82 |
+
text: str
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
class Output(BaseModel):
|
| 86 |
+
fullTextAnnotation: Optional[TextAnnotation] = None
|
data_models.py
ADDED
|
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Pydantic models used throughout the codebase.
|
| 3 |
+
|
| 4 |
+
In particular, these are the types that are used as input and output of each step of the pipeline.
|
| 5 |
+
"""
|
| 6 |
+
import json
|
| 7 |
+
from typing import Any, Optional, Union
|
| 8 |
+
|
| 9 |
+
from pydantic import BaseModel, Field
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class RedirectInput(BaseModel):
|
| 13 |
+
pipeline_arn: str
|
| 14 |
+
job_id: str
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class NutrientTableQuantity(BaseModel):
|
| 18 |
+
measurementUnitCode: str
|
| 19 |
+
value: str
|
| 20 |
+
precisionCode: str
|
| 21 |
+
|
| 22 |
+
def __str__(self):
|
| 23 |
+
return f"{self.precisionCode} {self.value} {self.measurementUnitCode}"
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class NutrientTableDailyValueIntake(BaseModel):
|
| 27 |
+
value: str
|
| 28 |
+
precisionCode: str
|
| 29 |
+
|
| 30 |
+
def __str__(self):
|
| 31 |
+
return f"{self.precisionCode} {self.value}%"
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def s(
|
| 35 |
+
text: Optional[Union[NutrientTableQuantity, NutrientTableDailyValueIntake, str]]
|
| 36 |
+
) -> str:
|
| 37 |
+
"""
|
| 38 |
+
Returns None as "?", otherwise argument as string.
|
| 39 |
+
"""
|
| 40 |
+
if text:
|
| 41 |
+
return str(text)
|
| 42 |
+
else:
|
| 43 |
+
return "?"
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class NutrientTableElement(BaseModel):
|
| 47 |
+
coordinates: str
|
| 48 |
+
probability: float
|
| 49 |
+
nutrientTypeCode: Optional[str]
|
| 50 |
+
quantityContained: NutrientTableQuantity
|
| 51 |
+
dailyValueIntakePercent: Optional[NutrientTableDailyValueIntake]
|
| 52 |
+
precisionCode: str
|
| 53 |
+
|
| 54 |
+
def __str__(self):
|
| 55 |
+
return " ".join(
|
| 56 |
+
[
|
| 57 |
+
s(self.nutrientTypeCode),
|
| 58 |
+
s(self.quantityContained),
|
| 59 |
+
f"({s(self.dailyValueIntakePercent)})",
|
| 60 |
+
]
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class NutrientTable(BaseModel):
|
| 65 |
+
nutrientBasisQuantityValue: Optional[str]
|
| 66 |
+
nutrientBasisQuantityMeasurementUnitCode: Optional[str]
|
| 67 |
+
preperationStateCode: Optional[str]
|
| 68 |
+
values: list[NutrientTableElement]
|
| 69 |
+
|
| 70 |
+
def __str__(self):
|
| 71 |
+
top = "Nutrients per " + " ".join(
|
| 72 |
+
[
|
| 73 |
+
s(self.nutrientBasisQuantityValue),
|
| 74 |
+
s(self.nutrientBasisQuantityMeasurementUnitCode),
|
| 75 |
+
f"({s(self.preperationStateCode)})",
|
| 76 |
+
]
|
| 77 |
+
)
|
| 78 |
+
vals = "\n\t".join([str(v) for v in self.values])
|
| 79 |
+
return f"{top}\n\t{vals}"
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class Attribute(BaseModel):
|
| 83 |
+
coordinates: str
|
| 84 |
+
entity: str
|
| 85 |
+
probability: float
|
| 86 |
+
value: Union[str, list[NutrientTable]]
|
| 87 |
+
model: str
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class AttributeCommunicationChannel(BaseModel):
|
| 91 |
+
coordinates: str
|
| 92 |
+
probability: float
|
| 93 |
+
model: str
|
| 94 |
+
entity: str
|
| 95 |
+
communicationChannelCode: str
|
| 96 |
+
communicationValue: str
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
class AttributeAllergen(BaseModel):
|
| 100 |
+
coordinates: str
|
| 101 |
+
probability: float
|
| 102 |
+
model: str
|
| 103 |
+
entity: str
|
| 104 |
+
allergenTypeCode: str
|
| 105 |
+
levelOfContainmentCode: str
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
class NetContentAttribute(BaseModel):
|
| 109 |
+
coordinates: str
|
| 110 |
+
probability: float
|
| 111 |
+
model: str
|
| 112 |
+
entity: str
|
| 113 |
+
measurementUnitCode: str
|
| 114 |
+
value: str
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
class AllergensOut(BaseModel):
|
| 118 |
+
entity: str
|
| 119 |
+
values: list[AttributeAllergen]
|
| 120 |
+
model: str
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
class CommunicationChannelsOut(BaseModel):
|
| 124 |
+
entity: str
|
| 125 |
+
values: list[AttributeCommunicationChannel]
|
| 126 |
+
model: str
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
class PipelineInput(BaseModel):
|
| 130 |
+
image_key: str
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
class PipelineOutput(BaseModel):
|
| 134 |
+
attributes: list[
|
| 135 |
+
Union[Attribute, CommunicationChannelsOut, AllergensOut, NetContentAttribute]
|
| 136 |
+
]
|
| 137 |
+
job_id: str = Field(alias="job-id")
|
| 138 |
+
text: str
|
| 139 |
+
|
| 140 |
+
class Config:
|
| 141 |
+
allow_population_by_field_name = True
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
class TextWithLanguage(BaseModel):
|
| 145 |
+
text: str
|
| 146 |
+
lang_code: str
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
class OCRTextOut(BaseModel):
|
| 150 |
+
blocks: list[str]
|
| 151 |
+
full_text: str
|
| 152 |
+
sentences: list[TextWithLanguage]
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
class OCRTableOut(BaseModel):
|
| 156 |
+
tables: list[list[list[str]]]
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
class OCROut(BaseModel):
|
| 160 |
+
result: Union[OCRTextOut, OCRTableOut]
|
| 161 |
+
job_id: str
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
class OCROutList(BaseModel):
|
| 165 |
+
__root__: list[OCROut]
|
| 166 |
+
|
| 167 |
+
def __iter__(self):
|
| 168 |
+
return iter(self.__root__)
|
| 169 |
+
|
| 170 |
+
def __getitem__(self, item):
|
| 171 |
+
return self.__root__[item]
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
class OCRWrapperOut(BaseModel):
|
| 175 |
+
blocks: list[str]
|
| 176 |
+
full_text: str
|
| 177 |
+
job_id: str
|
| 178 |
+
sentences: list[TextWithLanguage]
|
| 179 |
+
tables: list[list[list[str]]]
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
class ClassifiedText(BaseModel):
|
| 183 |
+
text: str
|
| 184 |
+
attribute: str
|
| 185 |
+
confidence: float
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
class CommunicationChannels(BaseModel):
|
| 189 |
+
confidence: float
|
| 190 |
+
attribute: str
|
| 191 |
+
communicationChannelCode: str
|
| 192 |
+
communicationValue: str
|
| 193 |
+
text: Optional[str] = ""
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
class Allergen(BaseModel):
|
| 197 |
+
confidence: float
|
| 198 |
+
attribute: str
|
| 199 |
+
allergenTypeCode: str
|
| 200 |
+
levelOfContainmentCode: str
|
| 201 |
+
text: Optional[str] = ""
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
class NetContent(BaseModel):
|
| 205 |
+
confidence: float
|
| 206 |
+
attribute: str
|
| 207 |
+
measurementUnitCode: str
|
| 208 |
+
value: str
|
| 209 |
+
text: Optional[str] = ""
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
class ModelOut(BaseModel):
|
| 213 |
+
blocks: list[Union[NetContent, Allergen, CommunicationChannels, ClassifiedText]]
|
| 214 |
+
tables: Optional[list[NutrientTable]]
|
| 215 |
+
job_id: str
|
| 216 |
+
model: str
|
| 217 |
+
full_text: str
|
| 218 |
+
|
| 219 |
+
def toJSON(self):
|
| 220 |
+
return json.dumps(self, default=lambda o: o.__dict__, sort_keys=True)
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
class ModelOutList(BaseModel):
|
| 224 |
+
__root__: list[ModelOut]
|
| 225 |
+
|
| 226 |
+
def __iter__(self):
|
| 227 |
+
return iter(self.__root__)
|
| 228 |
+
|
| 229 |
+
def __getitem__(self, item):
|
| 230 |
+
return self.__root__[item]
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
class TrainModelOut(BaseModel):
|
| 234 |
+
# To be defined later when we have a list of accepted formats
|
| 235 |
+
model: Optional[Any] = None
|
| 236 |
+
artifacts: Optional[Any] = None
|
train_classifiers.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
utils.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import io
|
| 2 |
+
import json
|
| 3 |
+
from google.cloud import vision
|
| 4 |
+
from base import OCROut, OCRTextOut, TextWithLanguage, ocr
|
| 5 |
+
from nltk.tokenize.punkt import PunktParameters, PunktSentenceTokenizer
|
| 6 |
+
from langdetect import detect
|
| 7 |
+
|
| 8 |
+
from langdetect.lang_detect_exception import LangDetectException
|
| 9 |
+
from typing import Any, Optional, Union
|
| 10 |
+
import os

# SECURITY(review): a complete GCP service-account private key was hard-coded
# here as a string literal (and therefore lives in version-control history).
# That credential must be considered compromised — revoke/rotate it in the
# GCP console. Credentials are now read from the environment instead of
# being embedded in source code.
gcp_credentials: str = os.environ["GCP_CREDENTIALS"]

creds = json.loads(gcp_credentials)
google_client = vision.ImageAnnotatorClient.from_service_account_info(creds)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def image_to_byte_array(image) -> bytes:
    """Serialize a PIL image to PNG-encoded bytes."""
    with io.BytesIO() as buffer:
        image.save(buffer, format="png")
        return buffer.getvalue()
|
| 21 |
+
|
| 22 |
+
def run_image_ocr(image_bytes: bytes, google_client) -> "tuple[ocr.Output, Any]":
    """OCR for a PNG image.

    Sends raw PNG bytes to the Google Vision text-detection endpoint using the
    provided client.

    Returns a tuple of (parsed ocr.Output model, raw response dict).
    NOTE(review): the original annotation claimed only ocr.Output, but the
    body has always returned both values.
    """
    # Send image to API
    google_response = google_client.text_detection(
        image=vision.Image(content=image_bytes)
    )

    # Convert the protobuf response into a plain dict so it can be validated
    # by the pydantic ocr.Output model.
    google_response_json = json.loads(
        vision.AnnotateImageResponse.to_json(google_response)
    )
    return ocr.Output(**google_response_json), google_response_json
|
| 36 |
+
|
| 37 |
+
def get_block_language(block: ocr.Block) -> str:
    """Returns most confident language in block, or "unk" if there is no language"""
    if not block.property:
        return "unk"
    try:
        # Pick the detected language with the highest confidence score.
        best = max(block.property.detectedLanguages, key=lambda lang: lang.confidence)
        return best.languageCode
    except Exception:
        # Best effort: empty or malformed language metadata falls back to "unk".
        return "unk"
|
| 46 |
+
|
| 47 |
+
def block2text(block: ocr.Block) -> TextWithLanguage:
    """
    Extract the text from a block.

    Walks paragraphs -> words -> symbols, concatenating symbol text and
    inserting a space wherever the OCR engine detected a break.
    """
    lang_code = get_block_language(block)
    pieces = []
    for paragraph in block.paragraphs:
        for word in paragraph.words:
            for symbol in word.symbols:
                pieces.append(symbol.text)
                prop = symbol.property
                if prop and prop.detectedBreak:
                    pieces.append(" ")

    return TextWithLanguage(text="".join(pieces).strip(), lang_code=lang_code)
|
| 62 |
+
|
| 63 |
+
# Takes a list of blocks, and parses the sentences in each block
def get_sentences(blocks: list[TextWithLanguage]) -> list[TextWithLanguage]:
    """
    Split the sentences of the blocks and return a list of sentences.

    Each sentence is language-detected individually; detection failures are
    tagged "unk".
    """
    # Abbreviations (English + Dutch) that must not end a sentence.
    params = PunktParameters()
    params.abbrev_types = {
        "dr",
        "vs",
        "mr",
        "mrs",
        "prof",
        "inc",
        "vit",
        "o.a.",
        "o.b.v.",
        "s.p.",
        "m.u.v.",
        "i.v.m.",
        "a.k.a.",
        "e.g.",
        "m.b.v.",
        "max.",
        "min.",
    }
    tokenizer = PunktSentenceTokenizer(params)

    result = []
    for block in blocks:
        for sentence in tokenizer.tokenize(block.text):
            try:
                lang_code = detect(str(sentence).lower())
                # Because of lack of context in a sentence
                # afrikaans is often recognized
                if lang_code == "af":
                    lang_code = "nl"
            except LangDetectException:
                lang_code = "unk"
            result.append(TextWithLanguage(text=str(sentence), lang_code=lang_code))

    return result
|
| 104 |
+
|
| 105 |
+
def run_ocr(image_bytes: bytes) -> "tuple[OCRTextOut, Any]":
    """Run Google Vision OCR on an image and post-process it into OCRTextOut.

    Returns (ocr_text_out, raw_response_json).
    NOTE(review): the original annotation said Union[OCROut, Any]; the body
    actually returns an OCRTextOut plus the raw response dict.
    """
    # API response is of the type AnnotateImageResponse, see
    # https://cloud.google.com/vision/docs/reference/rest/v1/AnnotateImageResponse
    # for more details.
    ocr_image_annotation, response_json = run_image_ocr(image_bytes, google_client)

    # We assume we will only process pictures of one page,
    # and no documents of more than one page. Hence, we
    # take the first page here.

    # Check if the fullTextAnnotations are filled
    if ocr_image_annotation.fullTextAnnotation:
        ocr_blocks = ocr_image_annotation.fullTextAnnotation.pages[0].blocks
        text = ocr_image_annotation.fullTextAnnotation.text

        blocks = [block2text(block) for block in ocr_blocks]
        block_texts = [block.text for block in blocks]
        sentences = get_sentences(blocks)
    else:
        # No text detected: emit empty placeholders so downstream consumers
        # always receive a well-formed OCRTextOut.
        block_texts = [""]
        text = ""
        sentences = [TextWithLanguage(text="", lang_code="")]
    ocr_text_out = OCRTextOut(blocks=block_texts, full_text=text, sentences=sentences)
    return ocr_text_out, response_json
|
woc-logo-black.1a4c4e90.svg
ADDED
|
|