Spaces:
Build error
Build error
| """ | |
| Pydantic models used throughout the codebase. | |
| In particular, these are the types that are used as input and output of each step of the pipeline. | |
| """ | |
| import json | |
| from typing import Any, Optional, Union | |
| from pydantic import BaseModel, Field | |
class RedirectInput(BaseModel):
    """Input model carrying a pipeline identifier and a job id."""

    # NOTE(review): presumably an AWS ARN (name suggests so) — confirm with callers.
    pipeline_arn: str
    job_id: str
class NutrientTableQuantity(BaseModel):
    """Quantity of a nutrient (used as ``NutrientTableElement.quantityContained``).

    ``str()`` renders it as ``"<precisionCode> <value> <measurementUnitCode>"``.
    """

    measurementUnitCode: str
    value: str
    precisionCode: str

    def __str__(self):
        return "{} {} {}".format(
            self.precisionCode,
            self.value,
            self.measurementUnitCode,
        )
class NutrientTableDailyValueIntake(BaseModel):
    """Daily-value percentage (used as ``NutrientTableElement.dailyValueIntakePercent``).

    ``str()`` renders it as ``"<precisionCode> <value>%"``.
    """

    value: str
    precisionCode: str

    def __str__(self):
        return "{} {}%".format(self.precisionCode, self.value)
def s(
    text: Optional[Union[NutrientTableQuantity, NutrientTableDailyValueIntake, str]]
) -> str:
    """
    Render ``text`` for display, substituting ``"?"`` for missing values.

    Any falsy argument becomes ``"?"``. That covers ``None`` and also the empty
    string. Every other argument is converted with ``str()``.
    """
    return str(text) if text else "?"
class NutrientTableElement(BaseModel):
    """One entry of ``NutrientTable.values``.

    ``str()`` renders ``"<nutrientTypeCode> <quantity> (<daily value>)"``, and
    each missing part is shown as ``"?"``.
    """

    # NOTE(review): encoding of this string (e.g. a bounding box) is not established here.
    coordinates: str
    probability: float
    nutrientTypeCode: Optional[str]
    quantityContained: NutrientTableQuantity
    dailyValueIntakePercent: Optional[NutrientTableDailyValueIntake]
    precisionCode: str

    def __str__(self):
        nutrient = s(self.nutrientTypeCode)
        quantity = s(self.quantityContained)
        daily = s(self.dailyValueIntakePercent)
        return "{} {} ({})".format(nutrient, quantity, daily)
class NutrientTable(BaseModel):
    """A nutrient table: a quantity basis plus its list of entries.

    ``str()`` produces a ``"Nutrients per ..."`` header followed by one
    tab-indented line per entry in ``values``.
    """

    nutrientBasisQuantityValue: Optional[str]
    nutrientBasisQuantityMeasurementUnitCode: Optional[str]
    # Misspelled ("preperation"), but it is the field name used for parsing and
    # serialization, so it is kept as-is.
    preperationStateCode: Optional[str]
    values: list[NutrientTableElement]

    def __str__(self):
        basis = " ".join(
            (
                s(self.nutrientBasisQuantityValue),
                s(self.nutrientBasisQuantityMeasurementUnitCode),
                "(" + s(self.preperationStateCode) + ")",
            )
        )
        header = "Nutrients per " + basis
        body = "\n\t".join(map(str, self.values))
        return header + "\n\t" + body
class Attribute(BaseModel):
    """Generic extracted attribute; one of the ``PipelineOutput.attributes`` union members.

    ``value`` is either a plain string or a list of ``NutrientTable`` objects.
    """

    # NOTE(review): encoding of this string (e.g. a bounding box) is not established here.
    coordinates: str
    entity: str
    probability: float
    value: Union[str, list[NutrientTable]]
    # Presumably the name of the model that produced this attribute — confirm.
    model: str
class AttributeCommunicationChannel(BaseModel):
    """Communication-channel attribute; element type of ``CommunicationChannelsOut.values``."""

    coordinates: str
    probability: float
    model: str
    entity: str
    communicationChannelCode: str
    communicationValue: str
class AttributeAllergen(BaseModel):
    """Allergen attribute; element type of ``AllergensOut.values``."""

    coordinates: str
    probability: float
    model: str
    entity: str
    allergenTypeCode: str
    levelOfContainmentCode: str
class NetContentAttribute(BaseModel):
    """Net-content attribute (a value with its unit); a ``PipelineOutput.attributes`` union member."""

    coordinates: str
    probability: float
    model: str
    entity: str
    measurementUnitCode: str
    # Stored as a string, not a number.
    value: str
class AllergensOut(BaseModel):
    """Groups ``AttributeAllergen`` entries under one entity and model name."""

    entity: str
    values: list[AttributeAllergen]
    model: str
class CommunicationChannelsOut(BaseModel):
    """Groups ``AttributeCommunicationChannel`` entries under one entity and model name."""

    entity: str
    values: list[AttributeCommunicationChannel]
    model: str
class PipelineInput(BaseModel):
    """Pipeline input: the key identifying the image to process."""

    # NOTE(review): presumably an object-storage key (e.g. S3) — confirm with callers.
    image_key: str
class PipelineOutput(BaseModel):
    """Pipeline output: the extracted attributes, the job id and the text.

    On input, ``job_id`` is read from the ``job-id`` key (its alias). Because of
    ``allow_population_by_field_name``, the plain field name ``job_id`` is
    accepted as well.
    """

    # Pydantic v1 validates a Union by trying each member left to right and
    # keeping the first one that succeeds. Extra keys are ignored by default.
    # NetContentAttribute data also satisfies Attribute: it has coordinates,
    # entity, probability, model and a str ``value``. So NetContentAttribute
    # must be tried before Attribute; otherwise it would be coerced into
    # Attribute and ``measurementUnitCode`` would be silently dropped.
    attributes: list[
        Union[NetContentAttribute, Attribute, CommunicationChannelsOut, AllergensOut]
    ]
    job_id: str = Field(alias="job-id")
    text: str

    class Config:
        allow_population_by_field_name = True
class TextWithLanguage(BaseModel):
    """A piece of text paired with a language code."""

    text: str
    # NOTE(review): code format (e.g. ISO 639-1) is not established here — confirm.
    lang_code: str
class OCRTextOut(BaseModel):
    """Text-oriented OCR result: text blocks, the full text and per-sentence language."""

    blocks: list[str]
    full_text: str
    sentences: list[TextWithLanguage]
class OCRTableOut(BaseModel):
    """Table-oriented OCR result."""

    # Three levels of nesting; presumably tables -> rows -> cell strings — confirm.
    tables: list[list[list[str]]]
class OCROut(BaseModel):
    """One OCR result, either text-oriented or table-oriented, tagged with its job id."""

    # Union members are tried in order (OCRTextOut first).
    result: Union[OCRTextOut, OCRTableOut]
    job_id: str
class OCROutList(BaseModel):
    """Custom-root model wrapping a list of ``OCROut``.

    Iteration and indexing delegate to the wrapped list.
    """

    __root__: list[OCROut]

    def __iter__(self):
        wrapped = self.__root__
        return wrapped.__iter__()

    def __getitem__(self, item):
        wrapped = self.__root__
        return wrapped.__getitem__(item)
class OCRWrapperOut(BaseModel):
    """Flat combination of the ``OCRTextOut`` and ``OCRTableOut`` fields plus a job id."""

    blocks: list[str]
    full_text: str
    job_id: str
    sentences: list[TextWithLanguage]
    tables: list[list[list[str]]]
class ClassifiedText(BaseModel):
    """A text snippet with the attribute it was classified as and a confidence.

    The most generic member of the ``ModelOut.blocks`` union.
    """

    text: str
    attribute: str
    confidence: float
class CommunicationChannels(BaseModel):
    """Communication-channel block; a member of the ``ModelOut.blocks`` union."""

    confidence: float
    attribute: str
    communicationChannelCode: str
    communicationValue: str
    # Optional, defaults to the empty string.
    text: Optional[str] = ""
class Allergen(BaseModel):
    """Allergen block; a member of the ``ModelOut.blocks`` union."""

    confidence: float
    attribute: str
    allergenTypeCode: str
    levelOfContainmentCode: str
    # Optional, defaults to the empty string.
    text: Optional[str] = ""
class NetContent(BaseModel):
    """Net-content block (value plus unit); a member of the ``ModelOut.blocks`` union."""

    confidence: float
    attribute: str
    measurementUnitCode: str
    value: str
    # Optional, defaults to the empty string.
    text: Optional[str] = ""
class ModelOut(BaseModel):
    """Output of a model step: classified blocks, optional nutrient tables and text."""

    # Pydantic v1 tries Union members left to right, so the more specific block
    # types are listed before the generic ClassifiedText.
    blocks: list[Union[NetContent, Allergen, CommunicationChannels, ClassifiedText]]
    tables: Optional[list[NutrientTable]]
    job_id: str
    model: str
    full_text: str

    def toJSON(self):
        """Return this model serialized as a JSON string with sorted keys.

        Delegates to pydantic's serializer. The previous
        ``default=lambda o: o.__dict__`` hook dumped the internals of *any*
        non-JSON object, so unexpected values were serialized silently instead
        of raising. For the field types declared here, the output is the same.
        """
        return self.json(sort_keys=True)
class ModelOutList(BaseModel):
    """Custom-root model wrapping a list of ``ModelOut``.

    Iterating or indexing the model is forwarded to the wrapped list.
    """

    __root__: list[ModelOut]

    def __iter__(self):
        outputs = self.__root__
        return outputs.__iter__()

    def __getitem__(self, item):
        outputs = self.__root__
        return outputs.__getitem__(item)
class TrainModelOut(BaseModel):
    """Placeholder output of a training step; both fields are untyped and default to None."""

    # To be defined later when we have a list of accepted formats
    model: Optional[Any] = None
    artifacts: Optional[Any] = None