Spaces:
Runtime error
Runtime error
| import regex as re | |
# Sentinel tokens from the ByT5 extra-id vocabulary, used as section markers
# in model inputs/outputs. NOTE(review): the exact role of each marker is
# inferred from its name only — confirm against the callers that consume them.
PROGRAM_SPECIAL_TOKEN="<extra_id_124>"
UTTERANCES_SPECIAL_TOKEN="<extra_id_123>"
GT_PROGRAM_SPECIAL_TOKEN="<extra_id_122>"
def consistent(rx, spec):
    """Check whether regex ``rx`` is consistent with a labeled spec.

    Args:
        rx: regex pattern string.
        spec: iterable of (string, label) pairs, where label is '+'
            (the string must fully match) or '-' (it must not).

    Returns:
        True if ``rx`` agrees with every example, False on the first
        disagreement, and None when a label is invalid, the pattern is
        malformed, or matching times out.
    """
    for s, label in spec:
        if label not in ('+', '-'):
            return None
        try:
            # timeout=1 relies on the third-party `regex` module to guard
            # against catastrophic backtracking on adversarial patterns.
            if re.fullmatch(rx, s, timeout=1):
                if label == '-':
                    return False
            elif label == '+':
                return False
        except re.error:
            return None
        except TimeoutError:
            return None
    return True
def decode(c):
    """Map a ByT5 token id to its textual form.

    ByT5 vocabulary layout: ids 0-2 are special tokens (rendered "<0>".."<2>"),
    ids 3-258 are the 256 raw byte values, and ids >= 259 are the sentinel
    tokens <extra_id_0>, <extra_id_1>, ...

    Bug fix: the original used ``c < 258``, which dropped id 258 (byte 255)
    into the sentinel branch and produced "<extra_id_-1>". The ``c - 259``
    offset below confirms sentinels start at 259, so bytes run through 258.
    """
    if c < 3:
        return f"<{c}>"
    elif c < 259:
        return chr(c - 3)
    else:
        return f"<extra_id_{c - 259}>"
def byt5_decode_batch(outputs, skip_special_tokens=True, skip_position_token=False):
    """Decode batched ByT5 token-id sequences into strings.

    Args:
        outputs: nested list of beams, each beam a list of token-id sequences.
        skip_special_tokens: drop ids < 3 (ByT5 special tokens) before decoding.
        skip_position_token: drop ids > 258 (sentinel/extra-id tokens) as well.

    Returns:
        The same nested structure with each id sequence replaced by its
        decoded string.
    """
    def _keep(token_id):
        # Filtering predicate combining both skip options.
        if skip_special_tokens and token_id < 3:
            return False
        if skip_position_token and token_id > 258:
            return False
        return True

    return [
        [''.join(decode(t) for t in seq if _keep(t)) for seq in beam]
        for beam in outputs
    ]
def get_preprocess_function(tokenizer):
    """Build a dataset-mapping closure around ``tokenizer``.

    The returned function tokenizes ``examples["context"]`` (with None
    entries replaced by a single space) as the model inputs and
    ``examples["target"]`` as the labels, truncating both.
    """
    def preprocess_function(examples):
        contexts = [x if x is not None else ' ' for x in examples["context"]]
        return tokenizer(
            contexts,
            text_target=examples["target"],
            truncation=True,
        )
    return preprocess_function
def get_utterance_processing_functions(label_pos, idx, separator=' '):
    """Return a (utterances_to_string, string_to_utterances) converter pair.

    Args:
        label_pos: 'suffix' places the +/- label after each string; any other
            value places it before.
        idx: when truthy, utterances are delimited by <extra_id_i> sentinel
            tokens instead of ``separator``.
        separator: delimiter used in non-idx mode.

    Returns:
        Two inverse closures: one serializing a spec of (string, label) pairs
        to a single string, one parsing such a string back into pairs.
    """
    if label_pos == "suffix":
        if idx:
            def utterances_to_string(spec):
                return ''.join([f"<extra_id_{i}>{s}{label}" for i, (s, label) in enumerate(spec)])
        else:
            def utterances_to_string(spec):
                return separator.join([f"{s}{label}" for s, label in spec])
    else:
        if idx:
            def utterances_to_string(spec):
                return ''.join([f"<extra_id_{i}>{label}{s}" for i, (s, label) in enumerate(spec)])
        else:
            def utterances_to_string(spec):
                return separator.join([f"{label}{s}" for s, label in spec])
    if label_pos == "suffix":
        if idx:
            def string_to_utterances(string):
                string = re.sub(r'<extra_id_\d+>', ' ', string)
                return [(s[:-1], s[-1]) for s in string.split(' ') if len(s) > 0]
        else:
            def string_to_utterances(string):
                return [(s[:-1], s[-1]) for s in string.split(separator) if len(s) > 0]
    else:
        if idx:
            def string_to_utterances(string):
                # Bug fix: sentinels are the only delimiters in idx mode. The
                # original replaced them with '' and split on `separator`,
                # which glued every utterance together and broke the round
                # trip. Mirror the suffix/idx branch: substitute a space and
                # split on it.
                string = re.sub(r'<extra_id_\d+>', ' ', string)
                return [(s[1:], s[0]) for s in string.split(' ') if len(s) > 0]
        else:
            def string_to_utterances(string):
                return [(s[1:], s[0]) for s in string.split(separator) if len(s) > 0]
    return utterances_to_string, string_to_utterances