Spaces:
Runtime error
Runtime error
| from abc import ABC | |
| from modules.module_rankSents import RankSents | |
| from modules.module_crowsPairs import CrowsPairs | |
| from typing import List, Tuple | |
class Connector(ABC):
    """Text-parsing and error-formatting helpers shared by the connector classes.

    Derived connectors use ``parse_words`` to turn comma-separated user input
    into word lists, and ``process_error`` to format error messages.
    """

    def parse_word(self, word: str) -> str:
        """Return ``word`` lowercased, with surrounding whitespace removed."""
        lowered = word.lower()
        return lowered.strip()

    def parse_words(self, array_in_string: str) -> List[str]:
        """Split a comma-separated string into normalised words.

        Entries that are empty or whitespace-only are skipped. Every remaining
        entry is passed through ``parse_word``. Blank input yields ``[]``.
        """
        text = array_in_string.strip()
        if text == "":
            return []
        pieces = text.split(',')
        return [self.parse_word(piece) for piece in pieces if piece.strip()]

    def process_error(self, err: str) -> str:
        """Wrap a non-empty error message in centred ``<h3>`` HTML.

        A falsy ``err`` (such as ``""`` or ``None``) is returned unchanged.
        The message is inserted verbatim, without HTML escaping.
        """
        if not err:
            return err
        return "<center><h3>" + err + "</h3></center>"
class PhraseBiasExplorerConnector(Connector):
    """Adapter that forwards phrase-ranking requests to a ``RankSents`` explorer.

    NOTE(review): the 3-tuples returned by ``rank_sentence_options`` look like
    they feed three UI output components. Confirm against the caller.
    """

    def __init__(self, **kwargs) -> None:
        """Create the underlying ``RankSents`` explorer.

        Keyword Args:
            language_model: Model object forwarded to ``RankSents``.
            lang: Language identifier forwarded to ``RankSents``.

        Raises:
            KeyError: If ``language_model`` or ``lang`` is missing or ``None``.
                The message names the missing argument(s).
        """
        language_model = kwargs.get('language_model', None)
        lang = kwargs.get('lang', None)
        missing = [
            name
            for name, value in (('language_model', language_model), ('lang', lang))
            if value is None
        ]
        if missing:
            # Same exception type as before, but say what is actually missing.
            raise KeyError(
                "missing required keyword argument(s): " + ", ".join(missing)
            )
        self.phrase_bias_explorer = RankSents(
            language_model=language_model,
            lang=lang
        )

    def rank_sentence_options(
        self,
        sent: str,
        word_list: str,
        banned_word_list: str,
        useArticles: bool,
        usePrepositions: bool,
        useConjunctions: bool
    ) -> Tuple:
        """Normalise ``sent``, validate it, and rank it with ``RankSents``.

        Args:
            sent: Sentence text. Each ``*`` is padded with spaces and runs of
                whitespace are collapsed to single spaces before use.
                NOTE(review): ``*`` presumably marks the slot to be filled.
                Confirm against ``RankSents``.
            word_list: Comma-separated words, parsed with ``parse_words``.
            banned_word_list: Comma-separated words, parsed with
                ``parse_words``.
            useArticles: Forwarded unchanged to ``RankSents.rank``.
            usePrepositions: Forwarded unchanged to ``RankSents.rank``.
            useConjunctions: Forwarded unchanged to ``RankSents.rank``.

        Returns:
            A 3-tuple ``(error, scores, "")``. If ``errorChecking`` reports a
            problem, ``error`` is that message wrapped by ``process_error``
            and ``scores`` is ``""``. Otherwise ``error`` is the falsy value
            ``errorChecking`` returned and ``scores`` is the output of
            ``Label.compute`` applied to the ranking.
        """
        # Give every '*' its own token and collapse whitespace. split() with no
        # arguments also drops leading/trailing whitespace, so no strip() needed.
        sent = " ".join(sent.replace("*", " * ").split())

        err = self.phrase_bias_explorer.errorChecking(sent)
        if err:
            return self.process_error(err), "", ""

        words = self.parse_words(word_list)
        banned_words = self.parse_words(banned_word_list)
        all_plls_scores = self.phrase_bias_explorer.rank(
            sent,
            words,
            banned_words,
            useArticles,
            usePrepositions,
            useConjunctions
        )
        all_plls_scores = self.phrase_bias_explorer.Label.compute(all_plls_scores)
        # err is falsy here, so process_error hands it back unchanged.
        return self.process_error(err), all_plls_scores, ""
class CrowsPairsExplorerConnector(Connector):
    """Adapter that forwards sentence-comparison requests to a ``CrowsPairs`` explorer.

    NOTE(review): the 3-tuples returned by ``compare_sentences`` look like they
    feed three UI output components. Confirm against the caller.
    """

    def __init__(self, **kwargs) -> None:
        """Create the underlying ``CrowsPairs`` explorer.

        Keyword Args:
            language_model: Model object forwarded to ``CrowsPairs``.

        Raises:
            KeyError: If ``language_model`` is missing or ``None``. The message
                names the missing argument.
        """
        language_model = kwargs.get('language_model', None)
        if language_model is None:
            # Same exception type as before, but say what is actually missing.
            raise KeyError("missing required keyword argument(s): language_model")
        self.crows_pairs_explorer = CrowsPairs(
            language_model=language_model
        )

    def compare_sentences(
        self,
        sent0: str,
        sent1: str,
        sent2: str,
        sent3: str,
        sent4: str,
        sent5: str
    ) -> Tuple:
        """Validate six sentences and rank them with ``CrowsPairs``.

        The six arguments are passed to the explorer as one list, in order.

        Returns:
            A 3-tuple ``(error, scores, "")``. If ``errorChecking`` reports a
            problem, ``error`` is that message wrapped by ``process_error``
            and ``scores`` is ``""``. Otherwise ``error`` is the falsy value
            ``errorChecking`` returned and ``scores`` is the output of
            ``Label.compute`` applied to the ranking.
        """
        sent_list = [sent0, sent1, sent2, sent3, sent4, sent5]

        err = self.crows_pairs_explorer.errorChecking(sent_list)
        if err:
            return self.process_error(err), "", ""

        all_plls_scores = self.crows_pairs_explorer.rank(sent_list)
        all_plls_scores = self.crows_pairs_explorer.Label.compute(all_plls_scores)
        # err is falsy here, so process_error hands it back unchanged.
        return self.process_error(err), all_plls_scores, ""