Upload modeling_llama.py

modeling_llama.py  CHANGED  (+128 -72)
@@ -17,7 +17,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import inspect
+
 import bs4
+import loguru
 import math
 from typing import List, Optional, Tuple, Union
 
@@ -32,7 +35,6 @@ from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, DynamicCache, StaticCache
 from transformers.generation import GenerationMixin
 from transformers.modeling_attn_mask_utils import AttentionMaskConverter
-from transformers.modeling_flash_attention_utils import _flash_attention_forward
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
@@ -50,6 +52,19 @@ from transformers.utils import (
     logging,
     replace_return_docstrings,
 )
+try:
+    from flash_attn import flash_attn_func, flash_attn_varlen_func
+    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
+
+    _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
+except ImportError as error:
+    loguru.logger.warning(
+        f"`flash-attention` package not found, consider installing for better performance: {error}."
+    )
+if not _flash_supports_window_size:
+    loguru.logger.warning(
+        "Current `flash-attention` does not support `window_size`. Either upgrade or use `attn_implementation='eager'`."
+    )
 from .configuration_llama import LlamaConfig
 from collections import defaultdict
 from typing import List, Tuple
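The block added above probes the installed flash-attn build for `window_size` support at import time and only warns, rather than failing, when the package is missing. A minimal standalone sketch of the same capability probe, using only the names that appear in the added code plus an illustrative `attn_implementation` fallback that is not part of the diff:

import inspect

try:
    from flash_attn import flash_attn_func
    # Older flash-attn releases do not accept `window_size`, so sliding-window
    # attention has to fall back to ordinary causal attention.
    _flash_supports_window_size = "window_size" in inspect.signature(flash_attn_func).parameters
except ImportError:
    flash_attn_func = None
    _flash_supports_window_size = False

# Illustrative fallback choice, not taken from the diff:
attn_implementation = "flash_attention_2" if flash_attn_func is not None else "eager"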
@@ -97,66 +112,6 @@ class TokenIdNode(Node):
         self.input_ids = kwargs.get('input_ids', [])
         self.prob = kwargs.get('prob', np.float32(0.0))
 
-
-def split_tree(soup: bs4.BeautifulSoup, max_node_words=0) -> List[Tuple[bs4.element.Tag, List[str], bool]]:
-    word_count = len(soup.get_text().split())
-    if word_count > max_node_words:
-        possible_trees = [(soup, [])]
-        target_trees = []  # [(tag, path, is_leaf)]
-        # split the entire dom tee into subtrees, until the length of the subtree is less than max_node_words words
-        # find all possible trees
-        while True:
-            if len(possible_trees) == 0:
-                break
-            tree = possible_trees.pop(0)
-            tag_children = defaultdict(int)
-            bare_word_count = 0
-            # count child tags
-            for child in tree[0].contents:
-                if isinstance(child, bs4.element.Tag):
-                    tag_children[child.name] += 1
-            _tag_children = {k: 0 for k in tag_children.keys()}
-
-            # check if the tree can be split
-            for child in tree[0].contents:
-                if isinstance(child, bs4.element.Tag):
-                    # change child tag with duplicate names
-                    if tag_children[child.name] > 1:
-                        new_name = f"{child.name}{_tag_children[child.name]}"
-                        new_tree = (child, tree[1] + [new_name])
-                        _tag_children[child.name] += 1
-                        child.name = new_name
-                    else:
-                        new_tree = (child, tree[1] + [child.name])
-                    word_count = len(child.get_text().split())
-                    # add node with more than max_node_words words, and recursion depth is less than 64
-                    if word_count > max_node_words and len(new_tree[1]) < 64:
-                        possible_trees.append(new_tree)
-                    else:
-                        target_trees.append((new_tree[0], new_tree[1], True))
-                else:
-                    bare_word_count += len(str(child).split())
-
-            # add leaf node
-            if len(tag_children) == 0:
-                target_trees.append((tree[0], tree[1], True))
-            # add node with more than max_node_words bare words
-            elif bare_word_count > max_node_words:
-                target_trees.append((tree[0], tree[1], False))
-    else:
-        soup_children = [c for c in soup.contents if isinstance(c, bs4.element.Tag)]
-        if len(soup_children) == 1:
-            target_trees = [(soup_children[0], [soup_children[0].name], True)]
-        else:
-            # add an html tag to wrap all children
-            new_soup = bs4.BeautifulSoup("", 'html.parser')
-            new_tag = new_soup.new_tag("html")
-            new_soup.append(new_tag)
-            for child in soup_children:
-                new_tag.append(child)
-            target_trees = [(new_tag, ["html"], True)]
-    return target_trees
-
 logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = "LlamaConfig"
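The `split_tree` helper removed here returns one `(tag, path, is_leaf)` tuple per block, presumably so the caller can now build that block tree itself and hand it to the model via the new `block_tree` argument added further down. A small sketch of what its output looks like; the HTML snippet and the `max_node_words` value are made up for illustration:

import bs4

html = "<html><body><h1>Title</h1><p>" + "word " * 300 + "</p></body></html>"
soup = bs4.BeautifulSoup(html, "html.parser")

# Each entry is (tag, path, is_leaf): the subtree, the tag names from the root
# down to it, and whether it was kept as a leaf rather than split further.
blocks = split_tree(soup, max_node_words=256)
for tag, path, is_leaf in blocks:
    print("<" + "><".join(path) + ">", is_leaf, len(tag.get_text().split()))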
@@ -517,6 +472,107 @@ class LlamaFlashAttention2(LlamaAttention):
         # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
         self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
 
+    # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._flash_attention_forward
+    def _flash_attention_forward(
+        self,
+        query_states,
+        key_states,
+        value_states,
+        attention_mask,
+        query_length,
+        dropout=0.0,
+        softmax_scale=None,
+        use_sliding_windows=False,
+    ):
+        """
+        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
+        first unpad the input, then computes the attention scores and pad the final attention scores.
+
+        Args:
+            query_states (`torch.Tensor`):
+                Input query states to be passed to Flash Attention API
+            key_states (`torch.Tensor`):
+                Input key states to be passed to Flash Attention API
+            value_states (`torch.Tensor`):
+                Input value states to be passed to Flash Attention API
+            attention_mask (`torch.Tensor`):
+                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
+                position of padding tokens and 1 for the position of non-padding tokens.
+            dropout (`float`):
+                Attention dropout
+            softmax_scale (`float`, *optional*):
+                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
+            use_sliding_windows (`bool`, *optional*):
+                Whether to activate sliding window attention.
+        """
+        if not self._flash_attn_uses_top_left_mask:
+            causal = self.is_causal
+        else:
+            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
+            causal = self.is_causal and query_length != 1
+
+        # Contains at least one padding token in the sequence
+        if attention_mask is not None:
+            batch_size = query_states.shape[0]
+            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
+                query_states, key_states, value_states, attention_mask, query_length
+            )
+
+            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
+            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
+
+            if not use_sliding_windows:
+                attn_output_unpad = flash_attn_varlen_func(
+                    query_states,
+                    key_states,
+                    value_states,
+                    cu_seqlens_q=cu_seqlens_q,
+                    cu_seqlens_k=cu_seqlens_k,
+                    max_seqlen_q=max_seqlen_in_batch_q,
+                    max_seqlen_k=max_seqlen_in_batch_k,
+                    dropout_p=dropout,
+                    softmax_scale=softmax_scale,
+                    causal=causal,
+                )
+            else:
+                attn_output_unpad = flash_attn_varlen_func(
+                    query_states,
+                    key_states,
+                    value_states,
+                    cu_seqlens_q=cu_seqlens_q,
+                    cu_seqlens_k=cu_seqlens_k,
+                    max_seqlen_q=max_seqlen_in_batch_q,
+                    max_seqlen_k=max_seqlen_in_batch_k,
+                    dropout_p=dropout,
+                    softmax_scale=softmax_scale,
+                    causal=causal,
+                    window_size=(self.config.sliding_window, self.config.sliding_window),
+                )
+
+            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
+        else:
+            if not use_sliding_windows:
+                attn_output = flash_attn_func(
+                    query_states,
+                    key_states,
+                    value_states,
+                    dropout,
+                    softmax_scale=softmax_scale,
+                    causal=causal,
+                )
+            else:
+                attn_output = flash_attn_func(
+                    query_states,
+                    key_states,
+                    value_states,
+                    dropout,
+                    softmax_scale=softmax_scale,
+                    causal=causal,
+                    window_size=(self.config.sliding_window, self.config.sliding_window),
+                )
+
+        return attn_output
+
     def forward(
         self,
         hidden_states: torch.Tensor,
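In the padded branch above, `self._upad_input` turns the `(batch_size, seq_len)` padding mask into packed tensors plus cumulative sequence lengths before calling `flash_attn_varlen_func`. A minimal sketch of how such cumulative lengths are derived from a mask (standalone torch code, variable names are illustrative):

import torch

attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])  # 0 marks padding

# Valid tokens per sequence, then their cumulative sum with a leading zero:
# the varlen kernel indexes a packed (total_tokens, num_heads, head_dim) tensor
# with these offsets instead of using a padded batch dimension.
seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)
cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
max_seqlen_in_batch = int(seqlens.max())

print(cu_seqlens)            # tensor([0, 3, 8], dtype=torch.int32)
print(max_seqlen_in_batch)   # 5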
@@ -600,17 +656,16 @@ class LlamaFlashAttention2(LlamaAttention):
             key_states = key_states.to(target_dtype)
             value_states = value_states.to(target_dtype)
 
-        attn_output = _flash_attention_forward(
+
+        attn_output = self._flash_attention_forward(
             query_states,
             key_states,
             value_states,
             attention_mask,
             q_len,
-            position_ids=position_ids,
-            dropout=dropout_rate,
-            sliding_window=getattr(self, "sliding_window", None),
-            use_top_left_mask=self._flash_attn_uses_top_left_mask,
-            is_causal=self.is_causal,
+            dropout_rate,
+            None,
+            getattr(self, "sliding_window", None),
         )
 
         attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
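The trailing arguments in the new call are positional; matched against the `_flash_attention_forward` signature added above they bind as follows (comment-only sketch, no behaviour beyond the diff is assumed):

# attn_output = self._flash_attention_forward(
#     query_states, key_states, value_states, attention_mask, q_len,
#     dropout_rate,                           # -> dropout
#     None,                                   # -> softmax_scale (flash-attn then defaults to 1/sqrt(head_dim))
#     getattr(self, "sliding_window", None),  # -> use_sliding_windows (truthy only when a window is set)
# )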
@@ -1752,6 +1807,7 @@ class LlamaForHTMLTreeGeneration(LlamaPreTrainedModel):
                  tokenizer,
                  query: List[str],
                  htmls: List[List[str]],
+                 block_tree: List[Tuple],
                  **kwargs):
         max_seq_length = kwargs.pop("max_seq_length", 131072)
         def apply_html_tree_template(query, htmls):
@@ -1787,11 +1843,11 @@ class LlamaForHTMLTreeGeneration(LlamaPreTrainedModel):
             soup.append(bs4.BeautifulSoup(html, 'html.parser'))
 
             token_id_paths = []
-
-            is_leaf = [p[2] for p in
-
+            _block_tree = block_tree[idx]
+            is_leaf = [p[2] for p in _block_tree]
+            _block_tree = [p[1] for p in _block_tree]
 
-            for path in
+            for path in _block_tree:
                 path_str = "<" + "><".join(path) + ">"
                 token_ids = tokenizer.encode(path_str, add_special_tokens=False)
                 token_id_paths.append(token_ids)
@@ -1849,7 +1905,7 @@ class LlamaForHTMLTreeGeneration(LlamaPreTrainedModel):
 
             res_html_refs.append({
                 "html": str(soup),
-                "paths":
+                "paths": _block_tree,
                 "is_leaf": is_leaf,
                 "path_token_ids": token_id_paths,
                 "node_tree": list(TokenDotExporter(root, nodenamefunc=nodenamefunc))