| |
|
|
| from dataclasses import dataclass |
| from enum import Enum |
| from typing import Iterator, Sequence |
|
|
| from transformers import PreTrainedTokenizerBase as TokenizerLike |
|
|
|
|
class HarmonyMessageEndType(Enum):
    """Describes how a parsed Harmony message was terminated.

    The numeric values are part of the enum's identity and are kept explicit.
    """

    # No end token was seen before the token stream ran out.
    INCOMPLETE = 0
    # The parser maps both <|end|> and <|return|> to this member.
    END = 1
    # The parser maps <|call|> to this member.
    CALL = 2
|
|
|
|
@dataclass(frozen=True)
class HarmonySequence:
    """An immutable run of token ids forming one section of a Harmony message.

    Attributes are documented as comments below. Note that ``token_ids`` is a
    (mutable) list, so instances compare by value but cannot be hashed.
    """

    # Body tokens of the section, not including the special token that opened it.
    token_ids: list[int]
    # Position of the first body token in the token sequence that was parsed.
    start: int
|
|
|
|
@dataclass(frozen=True)
class HarmonyMessage:
    """An immutable Harmony message as recovered from a token stream.

    ``end`` is required; every section is optional and is ``None`` when the
    corresponding begin token did not appear in the message.
    """

    # How the message was terminated (or INCOMPLETE if the stream ran out).
    end: HarmonyMessageEndType
    # Section opened by <|start|>.
    role: HarmonySequence | None = None
    # Section opened by <|channel|>.
    channel: HarmonySequence | None = None
    # Section opened by <|constrain|>.
    constrain: HarmonySequence | None = None
    # Section opened by <|message|>.
    content: HarmonySequence | None = None
|
|
|
|
class HarmonyMessageParser:
    """A parser that performs lexical analysis to extract Harmony messages.

    Four "begin" special tokens open message sections:
    ``<|start|>`` -> ``role``, ``<|channel|>`` -> ``channel``,
    ``<|constrain|>`` -> ``constrain`` and ``<|message|>`` -> ``content``.
    Three "end" special tokens close a message:
    ``<|end|>`` and ``<|return|>`` -> ``END``, ``<|call|>`` -> ``CALL``.
    """

    def __init__(self, tokenizer: TokenizerLike):
        """
        Args:
            tokenizer: Tokenizer whose ``get_vocab()`` contains all Harmony
                special tokens listed above.

        Raises:
            KeyError: If one of the special tokens is missing from the vocab.
        """
        vocab = tokenizer.get_vocab()
        # Kept separately because reverse_iter_messages scans for it directly.
        self._start_id = vocab["<|start|>"]
        self._begin_map = {
            self._start_id: "role",
            vocab["<|channel|>"]: "channel",
            vocab["<|constrain|>"]: "constrain",
            vocab["<|message|>"]: "content",
        }
        self._end_map = {
            vocab["<|end|>"]: HarmonyMessageEndType.END,
            vocab["<|return|>"]: HarmonyMessageEndType.END,
            vocab["<|call|>"]: HarmonyMessageEndType.CALL,
        }

    def iter_messages(self, token_ids: Sequence[int]) -> Iterator[HarmonyMessage]:
        """
        Parse given token ids into messages.

        Args:
            token_ids: A sequence of token ids to be parsed.

        Yields:
            Detected HarmonyMessages. ``HarmonySequence.start`` values are
            indices into ``token_ids``.
        """
        yield from self._iter_messages(token_ids, offset=0)

    def _iter_messages(
        self, token_ids: Sequence[int], offset: int = 0
    ) -> Iterator[HarmonyMessage]:
        """
        Parse ``token_ids`` into messages, shifting reported positions by ``offset``.

        Args:
            token_ids: A sequence of token ids to be parsed.
            offset: Added to every ``HarmonySequence.start`` so positions can
                refer to a larger sequence that ``token_ids`` was sliced from.

        Yields:
            Detected HarmonyMessages.
        """
        message_dict: dict[str, HarmonySequence] = {}
        section: str | None = None
        text_ids: list[int] = []
        text_start: int | None = None

        for token_position, token_id in enumerate(token_ids, start=offset):
            if token_id in self._begin_map:
                # A begin token closes the section in progress (if any) and
                # opens a new one; its body starts right after this token.
                if section is not None:
                    message_dict[section] = HarmonySequence(
                        token_ids=text_ids,
                        start=text_start,
                    )
                section = self._begin_map[token_id]
                text_ids = []
                text_start = token_position + 1
            elif token_id in self._end_map:
                if section is not None:
                    message_dict[section] = HarmonySequence(
                        token_ids=text_ids,
                        start=text_start,
                    )
                # Yielded even when no section was opened, in which case every
                # section of the message is None.
                yield HarmonyMessage(**message_dict, end=self._end_map[token_id])
                message_dict = {}
                section = None
                text_ids = []
                text_start = None
            elif section is not None:
                # Tokens outside any section (e.g. between an end token and
                # the next begin token) are dropped.
                text_ids.append(token_id)

        # The stream ended inside a message: emit what was collected so far.
        if section is not None:
            message_dict[section] = HarmonySequence(
                token_ids=text_ids,
                start=text_start,
            )
            yield HarmonyMessage(**message_dict, end=HarmonyMessageEndType.INCOMPLETE)

    def get_all_messages(self, token_ids: Sequence[int]) -> list[HarmonyMessage]:
        """
        Parse given token ids into messages.

        Args:
            token_ids: A sequence of token ids to be parsed.

        Returns:
            A list of detected HarmonyMessages.
        """
        return list(self.iter_messages(token_ids))

    def reverse_iter_messages(self, token_ids: Sequence[int]) -> Iterator[HarmonyMessage]:
        """
        Parse given token ids into messages in reverse order.

        Message boundaries are found by scanning backwards for ``<|start|>``
        tokens. For each one, only the first message parsed from the span
        between it and the previously found ``<|start|>`` (or the end of
        ``token_ids``) is yielded. Tokens preceding the first ``<|start|>``
        are not parsed.

        Args:
            token_ids: A sequence of token ids to be parsed.

        Yields:
            Detected HarmonyMessages in reverse order. ``HarmonySequence.start``
            values are indices into ``token_ids``, matching ``iter_messages``.
        """
        end_position = len(token_ids)

        for i in reversed(range(len(token_ids))):
            if token_ids[i] == self._start_id:
                # The segment starts with <|start|>, which opens the "role"
                # section, so at least one message is always yielded and
                # next() cannot raise StopIteration.
                yield next(self._iter_messages(token_ids[i:end_position], offset=i))
                end_position = i
|
|