retailpxdemo / data_models.py
leonge's picture
Updated pydantic models to support lists
e80c4c1
raw
history blame
5.14 kB
"""
Pydantic models used throughout the codebase.
In particular, these are the types that are used as input and output of each step of the pipeline.
"""
import json
from typing import Any, Optional, Union, List
from pydantic import BaseModel, Field
class RedirectInput(BaseModel):
    """Model with two required string fields, ``pipeline_arn`` and ``job_id``.

    NOTE(review): presumably the input of a step that redirects a job to the
    pipeline identified by the ARN -- confirm against the caller.
    """

    pipeline_arn: str
    job_id: str
class NutrientTableQuantity(BaseModel):
    """Quantity made of a unit code, a value and a precision code (all strings)."""

    measurementUnitCode: str
    value: str
    precisionCode: str

    def __str__(self):
        """Return ``"<precisionCode> <value> <measurementUnitCode>"``."""
        return "{} {} {}".format(
            self.precisionCode, self.value, self.measurementUnitCode
        )
class NutrientTableDailyValueIntake(BaseModel):
    """Daily-value entry made of a value and a precision code (both strings)."""

    value: str
    precisionCode: str

    def __str__(self):
        """Return ``"<precisionCode> <value>%"``."""
        return "{} {}%".format(self.precisionCode, self.value)
def s(
    text: Optional[Union[NutrientTableQuantity, NutrientTableDailyValueIntake, str]]
) -> str:
    """
    Render an optional value for display.

    Any falsy argument (``None``, but also an empty string) is returned as
    ``"?"``; every other argument is converted with ``str()``.
    """
    return str(text) if text else "?"
class NutrientTableElement(BaseModel):
    """One entry of a nutrient table.

    Holds an optional nutrient type code, the contained quantity and an
    optional daily-value percentage, next to ``coordinates``, ``probability``
    and ``precisionCode``.
    """

    coordinates: str
    probability: float
    nutrientTypeCode: Optional[str]
    quantityContained: NutrientTableQuantity
    dailyValueIntakePercent: Optional[NutrientTableDailyValueIntake]
    precisionCode: str

    def __str__(self):
        """Return ``"<type> <quantity> (<daily value>)"``; missing parts show as ``?``."""
        nutrient = s(self.nutrientTypeCode)
        quantity = s(self.quantityContained)
        daily_value = s(self.dailyValueIntakePercent)
        return f"{nutrient} {quantity} ({daily_value})"
class NutrientTable(BaseModel):
    """Nutrient table: an optional basis quantity/state plus a list of elements."""

    nutrientBasisQuantityValue: Optional[str]
    nutrientBasisQuantityMeasurementUnitCode: Optional[str]
    # NOTE(review): spelled "preperation" (sic); the name is part of the
    # model's schema, so it is kept as is.
    preperationStateCode: Optional[str]
    values: List[NutrientTableElement]

    def __str__(self):
        """Return a header line followed by one tab-indented line per element.

        Missing header parts are shown as ``?``. The header is always followed
        by ``"\\n\\t"``, even when ``values`` is empty.
        """
        basis_value = s(self.nutrientBasisQuantityValue)
        basis_unit = s(self.nutrientBasisQuantityMeasurementUnitCode)
        state = s(self.preperationStateCode)
        header = f"Nutrients per {basis_value} {basis_unit} ({state})"
        rows = "\n\t".join(map(str, self.values))
        return header + "\n\t" + rows
class Attribute(BaseModel):
    """Generic attribute entry: an ``entity`` name with its ``value``.

    ``value`` is either a plain string or a list of ``NutrientTable`` objects.
    NOTE(review): ``coordinates``, ``probability`` and ``model`` presumably
    describe where the value was found, the extractor's confidence and the
    producing model -- confirm with the step that builds these objects.
    """

    coordinates: str
    entity: str
    probability: float
    value: Union[str, List[NutrientTable]]
    model: str
class AttributeCommunicationChannel(BaseModel):
    """Attribute entry carrying a communication channel code and its value.

    All fields are required; everything except ``probability`` is a string.
    """

    coordinates: str
    probability: float
    model: str
    entity: str
    communicationChannelCode: str
    communicationValue: str
class AttributeAllergen(BaseModel):
    """Attribute entry carrying an allergen type code and a containment level code.

    All fields are required; everything except ``probability`` is a string.
    """

    coordinates: str
    probability: float
    model: str
    entity: str
    allergenTypeCode: str
    levelOfContainmentCode: str
class NetContentAttribute(BaseModel):
    """Attribute entry carrying a net-content ``value`` and its unit code.

    Has the same field names as ``Attribute`` plus ``measurementUnitCode``;
    here ``value`` is always a string.
    """

    coordinates: str
    probability: float
    model: str
    entity: str
    measurementUnitCode: str
    value: str
class AllergensOut(BaseModel):
    """Groups a list of ``AttributeAllergen`` entries under one entity and model."""

    entity: str
    values: List[AttributeAllergen]
    model: str
class CommunicationChannelsOut(BaseModel):
    """Groups a list of ``AttributeCommunicationChannel`` entries under one entity and model."""

    entity: str
    values: List[AttributeCommunicationChannel]
    model: str
class PipelineInput(BaseModel):
    """Pipeline input with a single required string field, ``image_key``.

    NOTE(review): presumably the storage key of the image to process --
    confirm against the pipeline entry point.
    """

    image_key: str
class PipelineOutput(BaseModel):
    """Final pipeline result: extracted attributes, the job id and the text.

    ``job_id`` is exposed under the alias ``"job-id"``; because
    ``allow_population_by_field_name`` is enabled, the model can be populated
    with either ``job_id`` or ``"job-id"``.
    """

    # pydantic (v1) tries Union members left to right and keeps the first one
    # that validates. NetContentAttribute has every field of Attribute plus
    # ``measurementUnitCode``, so it must come BEFORE Attribute; otherwise a
    # net-content entry is accepted as a plain Attribute and its unit code is
    # silently dropped. An Attribute can never match NetContentAttribute
    # because it lacks the required ``measurementUnitCode``.
    attributes: List[
        Union[NetContentAttribute, Attribute, CommunicationChannelsOut, AllergensOut]
    ]
    job_id: str = Field(alias="job-id")
    text: str

    class Config:
        # Accept the Python field name ``job_id`` in addition to the alias.
        allow_population_by_field_name = True
class TextWithLanguage(BaseModel):
    """A piece of text paired with a language code (both required strings)."""

    text: str
    lang_code: str
class OCRTextOut(BaseModel):
    """Text-style OCR result: text blocks, the full text and language-tagged sentences."""

    blocks: List[str]
    full_text: str
    sentences: List[TextWithLanguage]
class OCRTableOut(BaseModel):
    """Table-style OCR result.

    ``tables`` is a three-level nested list of strings.
    NOTE(review): presumably tables -> rows -> cell texts -- confirm with the
    OCR step.
    """

    tables: List[List[List[str]]]
class OCROut(BaseModel):
    """One OCR result (either ``OCRTextOut`` or ``OCRTableOut``) tagged with a job id."""

    result: Union[OCRTextOut, OCRTableOut]
    job_id: str
class OCROutList(BaseModel):
    """Custom-root model wrapping a list of ``OCROut`` objects.

    Iteration and indexing are forwarded to the wrapped list.
    """

    __root__: List[OCROut]

    def __iter__(self):
        """Iterate over the wrapped ``OCROut`` items."""
        wrapped = self.__root__
        return iter(wrapped)

    def __getitem__(self, item):
        """Return the element(s) of the wrapped list selected by ``item``."""
        wrapped = self.__root__
        return wrapped[item]
class OCRWrapperOut(BaseModel):
    """Flat OCR result combining the ``OCRTextOut`` and ``OCRTableOut`` fields with a job id."""

    blocks: List[str]
    full_text: str
    job_id: str
    sentences: List[TextWithLanguage]
    tables: List[List[List[str]]]
class ClassifiedText(BaseModel):
    """A text with the attribute it was assigned and a confidence value."""

    text: str
    attribute: str
    confidence: float
class CommunicationChannels(BaseModel):
    """Model-output block describing a communication channel code and value.

    ``text`` is optional and defaults to the empty string.
    """

    confidence: float
    attribute: str
    communicationChannelCode: str
    communicationValue: str
    text: Optional[str] = ""
class Allergen(BaseModel):
    """Model-output block describing an allergen type and containment level.

    ``text`` is optional and defaults to the empty string.
    """

    confidence: float
    attribute: str
    allergenTypeCode: str
    levelOfContainmentCode: str
    text: Optional[str] = ""
class NetContent(BaseModel):
    """Model-output block describing a net-content value and its unit code.

    ``text`` is optional and defaults to the empty string.
    """

    confidence: float
    attribute: str
    measurementUnitCode: str
    value: str
    text: Optional[str] = ""
class ModelOut(BaseModel):
    """Output of one model run: classified blocks, optional nutrient tables,
    the job id, the model name and the full text.
    """

    # ClassifiedText is listed last; the three types before it each require
    # fields that ClassifiedText does not have.
    blocks: List[Union[NetContent, Allergen, CommunicationChannels, ClassifiedText]]
    tables: Optional[List[NutrientTable]]
    job_id: str
    model: str
    full_text: str

    def toJSON(self):
        """Serialize this object to a JSON string with sorted keys.

        Objects the ``json`` module cannot encode natively (this model and
        anything nested in it) are replaced by their ``__dict__``.
        """

        def _attributes_of(obj):
            # Fallback encoder: expose the object's attribute dictionary.
            return obj.__dict__

        return json.dumps(self, sort_keys=True, default=_attributes_of)
class ModelOutList(BaseModel):
    """Custom-root model wrapping a list of ``ModelOut`` objects.

    Iteration and indexing are forwarded to the wrapped list.
    """

    __root__: List[ModelOut]

    def __iter__(self):
        """Iterate over the wrapped ``ModelOut`` items."""
        wrapped = self.__root__
        return iter(wrapped)

    def __getitem__(self, item):
        """Return the element(s) of the wrapped list selected by ``item``."""
        wrapped = self.__root__
        return wrapped[item]
class TrainModelOut(BaseModel):
    """Placeholder result of a training step.

    Both fields accept any value and default to ``None``.
    """

    # To be defined later, once we have a list of accepted formats.
    model: Optional[Any] = None
    artifacts: Optional[Any] = None