| | from abc import abstractmethod |
| | from typing import List, Dict, Union, Any |
| | import os |
| | import requests |
| | import re |
| | import json |
| | import fasttext |
| | import torch |
| | from fasttext.FastText import _FastText |
| | from transformers import BertTokenizer |
| | from src.models.base import BaseModel |
| | from src.models.bert_cls import BERTClassifier |
| | from src.enum import ( |
| | DucklingDimensionTypes, |
| | DucklingLocaleTypes, |
| | OperatorClassTypes, |
| | ) |
| | from src.misc.schemas import PriceExtractionSchema, ProductNamedEntityExtractionSchema |
| |
|
# Matches a single character from the Arabic Unicode ranges: Arabic
# (U+0600-U+06FF), Arabic Supplement (U+0750-U+077F), and slices of Arabic
# Presentation Forms-A (U+FB50-U+FC3F) and Forms-B (U+FE70-U+FEFC).
# Used below to choose the Duckling locale for a query.
ARABIC_TEXT_PATTERN = r"[\u0600-\u06ff]|[\u0750-\u077f]|[\ufb50-\ufc3f]|[\ufe70-\ufefc]"
| |
|
| |
|
class BaseOperatorModel(BaseModel):
    """Common base for classifiers that label a query with an ``OperatorClassTypes``.

    Subclasses provide ``load_model`` (how a checkpoint is read) and ``predict``
    (how a query is classified). The loaded model is stored on ``self._model``.
    """

    def __init__(self, model_ckpt_path: str) -> None:
        """Initialise the parent class, then load the checkpoint.

        Args:
            model_ckpt_path: Checkpoint location, passed unchanged to
                ``load_model``.
        """
        super().__init__()
        loaded_model = self.load_model(model_ckpt_path)
        self._model = loaded_model

    def __call__(self, input_query: str, *args: Any, **kwds: Any) -> OperatorClassTypes:
        """Delegate to the parent ``__call__`` with a narrowed return annotation.

        NOTE(review): presumably the parent ``__call__`` dispatches to
        ``predict`` -- confirm in ``src.models.base``.
        """
        result = super().__call__(input_query, *args, **kwds)
        return result

    @abstractmethod
    def predict(self, input_query, *args: Any, **kwds: Any) -> OperatorClassTypes:
        """Classify ``input_query``; concrete subclasses must override this."""
        raise NotImplementedError

    @abstractmethod
    def load_model(self, model_ckpt_path):
        """Load and return the model found at ``model_ckpt_path``; must be overridden."""
        raise NotImplementedError
| |
|
| |
|
class FasttextOperatorModel(BaseOperatorModel):
    """Operator classifier backed by a fastText model loaded from disk."""

    def __init__(
        self,
        model_ckpt_path: str = os.path.join(
            "ckpt", "fasttext_operator", "fasttext_operator_cls.bin"
        ),
        opt_thres: float = 0.2,
    ) -> None:
        """Load the fastText checkpoint and remember the decision threshold.

        Args:
            model_ckpt_path: Path of the fastText ``.bin`` model file.
            opt_thres: Value passed as ``threshold`` to fastText ``predict``;
                when no label clears it the query is classified as ``NONE``.
        """
        super().__init__(model_ckpt_path)
        self.opt_thres = opt_thres

    def load_model(self, model_ckpt_path) -> _FastText:
        """Return the fastText model stored at ``model_ckpt_path``."""
        return fasttext.load_model(model_ckpt_path)

    def predict(self, input_query, *args: Any, **kwds: Any) -> OperatorClassTypes:
        """Classify ``input_query`` and return the matching operator class.

        Args:
            input_query: Raw query text. Newlines are replaced by spaces
                because fastText rejects multi-line input with ``ValueError``.

        Returns:
            The ``OperatorClassTypes`` member whose value equals the top
            predicted label, or ``OperatorClassTypes.NONE`` when no label
            passes the threshold or the label matches no member.
        """
        single_line_query = input_query.replace("\n", " ")
        labels, _probabilities = self._model.predict(
            text=single_line_query, threshold=self.opt_thres, k=1
        )
        if not len(labels):
            return OperatorClassTypes.NONE

        top_label = labels[0]
        # Fall back to NONE instead of leaking the raw label string, so the
        # return value always honours the declared type.
        return next(
            (
                class_type
                for class_type in OperatorClassTypes
                if class_type.value == top_label
            ),
            OperatorClassTypes.NONE,
        )
| |
|
| |
|
class BertOperatorModel(BaseOperatorModel):
    """Operator classifier backed by a fine-tuned ``BERTClassifier``.

    The checkpoint directory must contain ``model_config.json`` (keys
    ``index2class``, ``bert_model_name``, ``device``, ``max_length``) and the
    weights file ``bert_classifier.pth``.
    """

    def __init__(
        self,
        model_ckpt_path: str = os.path.join("ckpt", "bert_operator"),
    ) -> None:
        """Read the model configuration, build the tokenizer and load the weights.

        Args:
            model_ckpt_path: Directory holding the configuration and weights.
        """
        config_path = os.path.join(model_ckpt_path, "model_config.json")
        # Use a context manager so the config file handle is always closed.
        with open(config_path, "r", encoding="utf-8") as config_file:
            model_config: Dict = json.load(config_file)

        self._index2class: List[str] = model_config["index2class"]
        self._bert_model_name = model_config["bert_model_name"]
        self._tokenizer = self.load_tokenizer(bert_model_name=self._bert_model_name)
        self._num_classes = len(self._index2class)
        self._device = model_config["device"]
        self._max_length = model_config["max_length"]
        # The parent constructor calls load_model, which reads the attributes
        # set above, so it has to run last.
        super().__init__(model_ckpt_path)

    def load_model(self, model_ckpt_path) -> BERTClassifier:
        """Build the classifier on the configured device and load its weights.

        Args:
            model_ckpt_path: Directory containing ``bert_classifier.pth``.

        Returns:
            The classifier in evaluation mode with the stored state dict loaded.
        """
        model = BERTClassifier(self._bert_model_name, self._num_classes).to(
            self._device
        )
        model.eval()
        weights_path = os.path.join(model_ckpt_path, "bert_classifier.pth")
        model.load_state_dict(torch.load(weights_path, map_location=self._device))
        return model

    def load_tokenizer(self, bert_model_name: str) -> BertTokenizer:
        """Return the pretrained tokenizer registered under ``bert_model_name``."""
        return BertTokenizer.from_pretrained(bert_model_name)

    def predict(self, input_query: str, *args: Any, **kwds: Any) -> OperatorClassTypes:
        """Classify a single query string.

        Args:
            input_query: Query text; it is tokenized, padded and truncated to
                the configured ``max_length``.

        Returns:
            The ``OperatorClassTypes`` member whose value equals the predicted
            class name, or ``OperatorClassTypes.NONE`` when that name matches
            no member.
        """
        encoding = self._tokenizer(
            input_query,
            return_tensors="pt",
            max_length=self._max_length,
            padding="max_length",
            truncation=True,
        )
        input_ids = encoding["input_ids"].to(self._device)
        attention_mask = encoding["attention_mask"].to(self._device)

        with torch.no_grad():
            outputs = self._model(input_ids=input_ids, attention_mask=attention_mask)
            _, predicted_index = torch.max(outputs, dim=1)

        predicted_name = self._index2class[predicted_index.item()]
        # Fall back to NONE instead of leaking the raw class name, so the
        # return value always honours the declared type.
        return next(
            (
                class_type
                for class_type in OperatorClassTypes
                if class_type.value == predicted_name
            ),
            OperatorClassTypes.NONE,
        )
| |
|
| |
|
class DucklingHTTPOperatorModel(BaseModel):
    """Turn a product query into a price range, a currency unit and leftover text.

    An operator classifier decides how numbers bound the price
    (``LE`` / ``GE`` / ``EQ`` / ``GELE``), a Duckling HTTP server supplies the
    parsed amounts, and regular expressions provide a fallback for plain
    numbers and currency names. A range bound equal to ``-1`` is treated as
    "not set" throughout (presumably the ``PriceExtractionSchema`` default --
    confirm against the schema).
    """

    _UNSET_BOUND = -1
    # Unit values that may still be overwritten by a better candidate.
    _UNRESOLVED_UNITS = ("", "unknown")
    _NUMBERS_PATTERN = r"[+-]?\d+\.\d+|\d+"

    def __init__(
        self,
        duckling_host: str = "localhost",
        duckling_port: str = "8000",
        dim_types: Union[List[DucklingDimensionTypes], None] = None,
        operator_model: Union[BaseOperatorModel, None] = None,
        request_timeout: float = 10.0,
    ) -> None:
        """Configure the Duckling endpoint and the operator classifier.

        Args:
            duckling_host: Host name of the Duckling server.
            duckling_port: Port of the Duckling server.
            dim_types: Duckling dimensions to request. ``None`` selects
                amount-of-money, numeral and ordinal.
            operator_model: Callable classifier returning an
                ``OperatorClassTypes``; required before ``predict`` is used.
            request_timeout: Seconds to wait for Duckling before
                ``requests`` raises a timeout error.
        """
        super().__init__()
        self._duckling_url = f"http://{duckling_host}:{duckling_port}/parse"
        # Build the default per instance; a list literal in the signature
        # would be shared (and mutable) across every instance.
        if dim_types is None:
            dim_types = [
                DucklingDimensionTypes.AMOUNT_OF_MONEY,
                DucklingDimensionTypes.NUMERAL,
                DucklingDimensionTypes.ORDINAL,
            ]
        self._dim_types = dim_types
        self.operator_model = operator_model
        self._request_timeout = request_timeout
        # Canonical currency code -> regex of the spellings that map to it.
        self.currency_pattern_map = {
            "SAR": r"\b(ريال سعودي|ر\.س|ريال|SAR|saudi riyal|riyal|sr)\b",
            "$": r"\b(دولار امريكي|دولار أمريكي|دولار أمريكى|دولار امريكى|دولار|دولارًا|\$|USD|united states dollar|dollar)\b",
            "AED": r"\b(درهم اماراتي|درهم إماراتي|درهم اماراتى|درهم إماراتى|درهم|د\.إ|AED|emirates dirham|emirate dirham|dirham)\b",
            "EGP": r"\b(جنيه مصري|جنيه مصرى|جنيه|ج\.م|£|egyptian pound|pound|le|LE|EGP)\b",
        }

    def predict(
        self, input_query, *args: Any, **kwds: Any
    ) -> ProductNamedEntityExtractionSchema:
        """Extract the price constraint and the remaining product text.

        Args:
            input_query: Raw user query.

        Returns:
            A ``ProductNamedEntityExtractionSchema`` whose ``price_extraction``
            carries the detected bounds and unit and whose
            ``sub_category_extraction`` is the query with price fragments
            removed and whitespace collapsed. When the operator classifier
            answers ``NONE`` the query is returned unchanged (apart from the
            spelling normalisation) and Duckling is not contacted.

        Raises:
            TypeError: If no ``operator_model`` was configured.
        """
        if self.operator_model is None:
            raise TypeError(
                "DucklingHTTPOperatorModel.predict() needs an operator_model, "
                "but none was provided at construction time"
            )

        # Normalise spelling variants of the same Arabic word so the
        # downstream models only ever see one form.
        input_query = input_query.replace("اغلي", "أغلى").replace("أغلي", "أغلى")

        extraction_result = ProductNamedEntityExtractionSchema()
        price = PriceExtractionSchema()
        extraction_result.sub_category_extraction = input_query
        extraction_result.price_extraction = price

        operator: OperatorClassTypes = self.operator_model(input_query=input_query)
        if operator == OperatorClassTypes.NONE:
            return extraction_result

        entities = self._query_duckling(input_query)
        if entities:
            self._apply_interval_entities(entities, extraction_result, price)
            self._apply_value_entities(
                entities, input_query, operator, extraction_result, price
            )
            if price.unit == "unknown":
                price.unit = ""
        else:
            self._apply_regex_fallback(input_query, operator, extraction_result, price)

        self._resolve_currency(input_query, extraction_result, price)
        return extraction_result

    def payload(self, input_query: str, locale: str) -> Dict[str, Union[List, str]]:
        """Build the form fields for a Duckling ``/parse`` request.

        Args:
            input_query: Text to parse.
            locale: Locale identifier to send to Duckling.

        Returns:
            A dict with ``text``, ``locale`` and ``dims``; ``dims`` is the
            JSON-encoded list of the configured dimension values.
        """
        return {
            "text": input_query,
            "locale": locale,
            "dims": json.dumps([dim_type.value for dim_type in self._dim_types]),
        }

    @staticmethod
    def _strip_fragment(
        extraction_result: ProductNamedEntityExtractionSchema, fragment: str
    ) -> None:
        """Remove every occurrence of ``fragment`` from the sub-category text."""
        extraction_result.sub_category_extraction = (
            extraction_result.sub_category_extraction.replace(fragment, "")
        )

    def _query_duckling(self, input_query: str) -> List[Dict[str, Any]]:
        """POST the query to Duckling and return the decoded JSON response."""
        # search, not match: Arabic anywhere in the query selects the Arabic
        # locale, not only when the very first character is Arabic.
        if re.search(ARABIC_TEXT_PATTERN, input_query):
            locale_type = DucklingLocaleTypes.AR
        else:
            locale_type = DucklingLocaleTypes.EN

        response = requests.post(
            url=self._duckling_url,
            headers={"Content-Type": "application/x-www-form-urlencoded"},
            data=self.payload(input_query=input_query, locale=locale_type.value),
            # Without a timeout an unresponsive server blocks the caller forever.
            timeout=self._request_timeout,
        )
        return response.json()

    def _apply_interval_entities(
        self,
        entities: List[Dict[str, Any]],
        extraction_result: ProductNamedEntityExtractionSchema,
        price: PriceExtractionSchema,
    ) -> None:
        """Copy the ``from`` / ``to`` bounds of interval entities into ``price``."""
        for entity in entities:
            value = entity["value"]
            is_interval = False

            if "from" in value:
                is_interval = True
                lower_bound = value["from"]
                price.lower_range = lower_bound["value"]
                if price.unit in self._UNRESOLVED_UNITS:
                    price.unit = lower_bound.get("unit", price.unit)

            if "to" in value:
                is_interval = True
                upper_bound = value["to"]
                price.upper_range = upper_bound["value"]
                if price.unit in self._UNRESOLVED_UNITS:
                    price.unit = upper_bound.get("unit", price.unit)

            # NOTE(review): the matched text is stripped for one-sided and
            # two-sided intervals alike -- confirm this is the intended scope.
            if is_interval:
                self._strip_fragment(extraction_result, entity["body"])

    def _apply_value_entities(
        self,
        entities: List[Dict[str, Any]],
        input_query: str,
        operator: OperatorClassTypes,
        extraction_result: ProductNamedEntityExtractionSchema,
        price: PriceExtractionSchema,
    ) -> None:
        """Fold single-value entities into ``price`` according to ``operator``."""
        unset = self._UNSET_BOUND
        for entity in entities:
            value = entity["value"]
            if "value" not in value:
                continue

            lower_is_set = price.lower_range != unset
            upper_is_set = price.upper_range != unset
            if lower_is_set and upper_is_set:
                # The range is complete: leave this entity and its text alone.
                continue

            amount = value["value"]
            if not lower_is_set and not upper_is_set:
                self._seed_range(amount, input_query, operator, extraction_result, price)
            elif lower_is_set:
                known = price.lower_range
                if operator == OperatorClassTypes.EQ:
                    price.upper_range = amount
                elif operator == OperatorClassTypes.LE:
                    # The classifier says "at most": the known bound is a ceiling.
                    price.upper_range = known
                    price.lower_range = unset
                elif operator == OperatorClassTypes.GELE:
                    price.upper_range = max(known, amount)
                    price.lower_range = min(known, amount)
            else:
                known = price.upper_range
                if operator == OperatorClassTypes.EQ:
                    price.lower_range = amount
                elif operator == OperatorClassTypes.GE:
                    # Mirror of the LE case above: the classifier says "at
                    # least", so the known bound is a floor. This used to test
                    # LE, which turned "at most X" into "at least X".
                    # NOTE(review): confirm against labelled queries.
                    price.lower_range = known
                    price.upper_range = unset
                elif operator == OperatorClassTypes.GELE:
                    price.upper_range = max(known, amount)
                    price.lower_range = min(known, amount)

            self._strip_fragment(extraction_result, entity["body"])
            if price.unit in self._UNRESOLVED_UNITS:
                # Entities without a currency carry no "unit" key; keep the
                # current unit rather than raising KeyError.
                price.unit = value.get("unit", price.unit)

    def _seed_range(
        self,
        amount: Any,
        input_query: str,
        operator: OperatorClassTypes,
        extraction_result: ProductNamedEntityExtractionSchema,
        price: PriceExtractionSchema,
    ) -> None:
        """Set the first bound(s) of an empty range from one parsed amount."""
        if operator == OperatorClassTypes.LE:
            price.upper_range = amount
        elif operator == OperatorClassTypes.EQ:
            price.lower_range = amount
            price.upper_range = amount
        elif operator == OperatorClassTypes.GE:
            price.lower_range = amount
        elif operator == OperatorClassTypes.GELE:
            # A two-sided range needs a second number: take the first number
            # in the query that differs from the parsed amount.
            partner = None
            for number_text in re.findall(self._NUMBERS_PATTERN, input_query):
                if float(number_text) != float(amount):
                    self._strip_fragment(extraction_result, number_text)
                    partner = float(number_text)
                    break

            # Compare with None explicitly: a partner of 0.0 is a valid bound.
            if partner is None:
                price.lower_range = amount
                price.upper_range = amount
            else:
                price.lower_range = min(amount, partner)
                price.upper_range = max(amount, partner)

    def _apply_regex_fallback(
        self,
        input_query: str,
        operator: OperatorClassTypes,
        extraction_result: ProductNamedEntityExtractionSchema,
        price: PriceExtractionSchema,
    ) -> None:
        """Derive the range from plain numbers when Duckling returned nothing."""
        number_texts = re.findall(self._NUMBERS_PATTERN, input_query)
        numbers = [float(number_text) for number_text in number_texts]
        for number_text in number_texts:
            self._strip_fragment(extraction_result, number_text)

        if not numbers:
            return
        if operator == OperatorClassTypes.LE:
            price.upper_range = max(numbers)
        elif operator == OperatorClassTypes.EQ:
            price.upper_range = numbers[0]
            price.lower_range = numbers[0]
        elif operator == OperatorClassTypes.GE:
            price.lower_range = min(numbers)
        elif operator == OperatorClassTypes.GELE:
            price.upper_range = max(numbers)
            price.lower_range = min(numbers)

    def _resolve_currency(
        self,
        input_query: str,
        extraction_result: ProductNamedEntityExtractionSchema,
        price: PriceExtractionSchema,
    ) -> None:
        """Set the canonical currency code and tidy the sub-category text."""
        # A currency spelled out in the query overrides the current unit, and
        # its text is removed from the sub-category.
        for currency, cur_pattern in self.currency_pattern_map.items():
            currency_matches = re.findall(cur_pattern, input_query, re.IGNORECASE)
            if currency_matches:
                price.unit = currency
                self._strip_fragment(extraction_result, currency_matches[0])
                break

        extraction_result.sub_category_extraction = " ".join(
            extraction_result.sub_category_extraction.split()
        )

        # Map whatever unit is set at this point onto a canonical code.
        for currency, cur_pattern in self.currency_pattern_map.items():
            if re.findall(cur_pattern, price.unit, re.IGNORECASE):
                price.unit = currency
                break
| |
|
| |
|
if __name__ == "__main__":
    # Manual smoke check: build the model with its defaults, then print every
    # currency mention that a broad currency regex finds in a sample text.
    model = DucklingHTTPOperatorModel()

    currency_regex = re.compile(
        r"\b(ريال سعودي|ر\.س|جنيه مصري|جنيه مصرى|دولار امريكي|دولار أمريكي|دولار أمريكى|دولار امريكى|درهم اماراتي|درهم إماراتي|درهم اماراتى|درهم إماراتى|ريال|دولار|درهم|دينار|يورو|جنيه|ليرة|د\.إ|ج\.م|\$|€|£|USD|AED|SAR|EUR|GBP|usd|aed|sar|eur|saudi riyal|egyptian pound|riyal|sr|pound|le|LE|emirates dirham|emirate dirham|dirham|united states dollar|dollar)\b",
        re.IGNORECASE,
    )

    sample_text = """
    سعر المنتج هو 120 ريال سعودي.
    التكلفة الكلية 50 دولار أمريكي.
    The laptop costs $500 USD.
    سعر الهاتف هو 250 درهم إماراتي.
    سعر الساعة 99.99€.
    """

    print(currency_regex.findall(sample_text))
| |
|