File size: 8,210 Bytes
fed5c73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
from typing import List, Dict, Union, Any, Tuple
import requests
import re
import json
from src.models.base import BaseModel
from src.enum import DucklingDimensionTypes, DucklingLocaleTypes
from src.misc.schemas import PriceExtractionSchema, ProductNamedEntityExtractionSchema


# One character from any of these Unicode ranges marks text as Arabic:
#   U+0600-U+06FF Arabic, U+0750-U+077F Arabic Supplement,
#   U+FB50-U+FC3F (start of) Arabic Presentation Forms-A,
#   U+FE70-U+FEFC Arabic Presentation Forms-B.
ARABIC_TEXT_PATTERN = r"[\u0600-\u06ff\u0750-\u077f\ufb50-\ufc3f\ufe70-\ufefc]"


class DucklingHTTPModel(BaseModel):
    """Price/currency extractor backed by a Duckling HTTP server.

    ``predict`` posts the query to Duckling's ``/parse`` endpoint, folds the
    returned entities into a ``PriceExtractionSchema`` (lower/upper range and
    a normalised currency unit), and stores the query with every recognised
    entity span removed in ``sub_category_extraction``.
    """

    def __init__(
        self,
        duckling_host: str = "localhost",
        duckling_port: str = "8000",
        dim_types: Union[List[DucklingDimensionTypes], None] = None,
    ) -> None:
        """Configure the Duckling endpoint and the dimensions to request.

        Args:
            duckling_host: Host name of the Duckling server.
            duckling_port: Port of the Duckling server.
            dim_types: Duckling dimensions to ask for. ``None`` (the default)
                selects amount-of-money, numeral and ordinal.
        """
        super().__init__()
        self._duckling_url = f"http://{duckling_host}:{duckling_port}/parse"
        if dim_types is None:
            # Fresh list per instance: a list literal as the parameter default
            # would be one object shared by every instance.
            dim_types = [
                DucklingDimensionTypes.AMOUNT_OF_MONEY,
                DucklingDimensionTypes.NUMERAL,
                DucklingDimensionTypes.ORDINAL,
            ]
        self._dim_types = dim_types
        # Canonical currency label -> regex matched (case-insensitively) against
        # the unit string Duckling reports; see _normalize_unit.
        # NOTE(review): ``\b`` cannot match next to "$" or "£" (non-word
        # characters), so those bare-symbol alternatives never match — confirm.
        self.currency_pattern_map = {
            "SAR": r"\b(ريال سعودي|ر\.س|ريال|SAR|saudi riyal|riyal|sr)\b",
            "$": r"\b(دولار امريكي|دولار أمريكي|دولار أمريكى|دولار امريكى|دولار|دولارًا|\$|USD|united states dollar|dollar)\b",
            "AED": r"\b(درهم اماراتي|درهم إماراتي|درهم اماراتى|درهم إماراتى|درهم|د\.إ|AED|emirates dirham|emirate dirham|dirham)\b",
            "EGP": r"\b(جنيه مصري|جنيه مصرى|جنيه|ج\.م|£|egyptian pound|pound|le|LE|EGP)\b",
        }

    def predict(
        self, input_query: str, *args: Any, **kwds: Any
    ) -> ProductNamedEntityExtractionSchema:
        """Extract a price range, currency and residual text from ``input_query``.

        Args:
            input_query: Free-text product query (Arabic or English).

        Returns:
            A ``ProductNamedEntityExtractionSchema``. When Duckling returns no
            entities it is returned as freshly constructed; otherwise
            ``price_extraction`` and ``sub_category_extraction`` are filled in.

        Raises:
            requests.HTTPError: Duckling answered with a non-2xx status.
        """
        extraction_results = ProductNamedEntityExtractionSchema()

        locale_type = self._detect_locale(input_query)
        parsed_entities = self._query_duckling(input_query, locale_type.value)
        if not parsed_entities:
            return extraction_results

        price_extraction_result: PriceExtractionSchema = PriceExtractionSchema()
        # Explicit ranges ("from"/"to") take precedence, so apply them first.
        self._apply_interval_bounds(price_extraction_result, parsed_entities)
        self._apply_single_values(price_extraction_result, parsed_entities)
        price_extraction_result.unit = self._normalize_unit(
            price_extraction_result.unit
        )

        extraction_results.price_extraction = price_extraction_result
        extraction_results.sub_category_extraction = self._strip_entity_text(
            input_query, parsed_entities
        )
        return extraction_results

    @staticmethod
    def _detect_locale(input_query: str) -> DucklingLocaleTypes:
        """Return the Arabic locale if the query contains any Arabic character."""
        # re.search, not re.match: mixed-script queries that begin with Latin
        # text or digits (e.g. "iPhone ارخص من 2000") must still be parsed as Arabic.
        if re.search(ARABIC_TEXT_PATTERN, input_query):
            return DucklingLocaleTypes.AR
        return DucklingLocaleTypes.EN

    def _query_duckling(self, input_query: str, locale: str) -> List[Dict[str, Any]]:
        """POST the query to Duckling and return the decoded JSON entity list."""
        response = requests.post(
            url=self._duckling_url,
            headers={"Content-Type": "application/x-www-form-urlencoded"},
            data=self.payload(input_query=input_query, locale=locale),
            # NOTE(review): requests interprets timeout in seconds, so this
            # allows ~16.7 minutes per call; confirm the intended value.
            timeout=1000,
        )
        # Fail clearly on HTTP errors instead of iterating an error payload.
        response.raise_for_status()
        return response.json()

    @staticmethod
    def _fill_unit(price: PriceExtractionSchema, value: Dict[str, Any]) -> None:
        """Copy ``value["unit"]`` into ``price.unit`` if no real unit is known yet."""
        # Numeral/ordinal values carry no "unit" key; indexing it directly
        # raised KeyError whenever a plain number was processed first.
        if price.unit in ("", "unknown") and "unit" in value:
            price.unit = value["unit"]

    def _apply_interval_bounds(
        self, price: PriceExtractionSchema, parsed_entities: List[Dict[str, Any]]
    ) -> None:
        """Set lower/upper bounds from interval entities ("from"/"to" keys)."""
        for entity in parsed_entities:
            value = entity["value"]
            if "from" in value:
                price.lower_range = value["from"]["value"]
                self._fill_unit(price, value["from"])
            if "to" in value:
                price.upper_range = value["to"]["value"]
                self._fill_unit(price, value["to"])

    def _apply_single_values(
        self, price: PriceExtractionSchema, parsed_entities: List[Dict[str, Any]]
    ) -> None:
        """Fold single-value entities into whichever range bounds are still unset.

        ``-1`` is treated as "bound not set" (presumably the schema default —
        confirm against ``PriceExtractionSchema``).
        """
        for entity in parsed_entities:
            value = entity["value"]
            if "value" not in value:
                continue
            has_lower = price.lower_range != -1
            has_upper = price.upper_range != -1
            if has_lower and has_upper:
                # Both bounds already fixed; later values are ignored.
                continue
            amount = value["value"]
            if not has_lower and not has_upper:
                price.lower_range = amount
                price.upper_range = amount
            else:
                # Exactly one bound is known: widen to cover both numbers.
                known = price.lower_range if has_lower else price.upper_range
                price.upper_range = max(known, amount)
                price.lower_range = min(known, amount)
            self._fill_unit(price, value)

    def _normalize_unit(self, unit: str) -> str:
        """Map Duckling's unit string onto a canonical currency label."""
        if unit == "unknown":
            unit = ""
        # Later patterns are tested against the already-relabelled unit.
        for currency, cur_pattern in self.currency_pattern_map.items():
            if re.search(cur_pattern, unit, re.IGNORECASE):
                unit = currency
        return unit

    @staticmethod
    def _strip_entity_text(
        input_query: str, parsed_entities: List[Dict[str, Any]]
    ) -> str:
        """Remove every entity's matched text from the query and collapse whitespace."""
        remaining = input_query
        for entity in parsed_entities:
            # str.replace removes every occurrence of the body, not only the matched span.
            remaining = remaining.replace(entity["body"], "")
        return " ".join(remaining.split())

    def payload(self, input_query: str, locale: str) -> Dict[str, str]:
        """Build the form fields for Duckling's ``/parse`` endpoint.

        ``dims`` is sent as a JSON-encoded list of the configured dimension values.
        """
        return {
            "text": input_query,
            "locale": locale,
            "dims": json.dumps([dim_type.value for dim_type in self._dim_types]),
        }


if __name__ == "__main__":
    # Manual smoke test: sends each query to the Duckling server configured by
    # the model defaults (http://localhost:8000/parse) and prints the result.
    model = DucklingHTTPModel()
    demo_queries = (
        "يساوي 12 ريال و عشرون ريال",
        "من 12 ريال",
        "أقل 12 ريال",
        "أكثر من 12 ريال",
        "أكثر من 12 إلى 20 ريال",
        "عطر ديور ارخص من ٢٠٠",
    )
    for index, query in enumerate(demo_queries):
        if index:
            print("=====================")
        print(
            json.dumps(
                model(input_query=query).model_dump(),
                indent=3,
                ensure_ascii=False,
            )
        )