Commit
·
d15887b
1
Parent(s):
88b58c6
Create custom.py
Browse files
custom.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""This lobe enables the integration of huggingface pretrained Llama2 Model model plus the expanding embedding layer for additional PAD tokens .
|
| 2 |
+
|
| 3 |
+
Transformer from HuggingFace needs to be installed:
|
| 4 |
+
https://huggingface.co/transformers/installation.html
|
| 5 |
+
|
| 6 |
+
Authors
|
| 7 |
+
* Pooneh Mousavi 2023
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import logging
|
| 11 |
+
from torch import Tensor
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
from speechbrain.lobes.models.huggingface_transformers.llama2 import LLAMA2
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
logger = logging.getLogger(__name__)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class LLAMA2_expanded(LLAMA2):
|
| 21 |
+
"""This lobe enables the integration of HuggingFace pretrained LLAMA2 model.
|
| 22 |
+
Source paper LLAMA2:
|
| 23 |
+
https://arxiv.org/abs/2307.09288
|
| 24 |
+
Transformer from HuggingFace needs to be installed:
|
| 25 |
+
https://huggingface.co/transformers/installation.html
|
| 26 |
+
|
| 27 |
+
The model can be finetuned. It will download automatically the model from
|
| 28 |
+
HuggingFace or use a local path.
|
| 29 |
+
|
| 30 |
+
Arguments
|
| 31 |
+
---------
|
| 32 |
+
source : str
|
| 33 |
+
HuggingFace hub name: e.g "meta-llama/Llama-2-7b-chat-hf"
|
| 34 |
+
save_path : str
|
| 35 |
+
Path (dir) of the downloaded model.
|
| 36 |
+
freeze : bool (default: False)
|
| 37 |
+
If True, the model is frozen. If False, the model will be trained
|
| 38 |
+
alongside with the rest of the pipeline.
|
| 39 |
+
Example
|
| 40 |
+
-------
|
| 41 |
+
>>> model_hub = "meta-llama/Llama-2-7b-chat-hf"
|
| 42 |
+
>>> save_path = "savedir"
|
| 43 |
+
>>> model = LLAMA2(model_hub, save_path)
|
| 44 |
+
>>> tokens = torch.tensor([[1, 1]])
|
| 45 |
+
>>> attention_mask = torch.tensor([[1, 1]])
|
| 46 |
+
>>> outputs = model(tokens, attention_mask)
|
| 47 |
+
"""
|
| 48 |
+
def __init__(
|
| 49 |
+
self, *args, **kwrds
|
| 50 |
+
) -> None:
|
| 51 |
+
super().__init__( *args, **kwrds)
|
| 52 |
+
# Load tokenizer and add special tokens
|
| 53 |
+
# # Add special tokens to the tokenizer and resize model embedding
|
| 54 |
+
# Special tokens
|
| 55 |
+
|
| 56 |
+
self.add_special_tokens_(
|
| 57 |
+
{"pad_token": "<pad>"}
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
def add_special_tokens_(self, attr_to_special_token,) -> None:
|
| 61 |
+
orig_num_tokens = len(self.tokenizer)
|
| 62 |
+
num_added_tokens = self.tokenizer.add_special_tokens(
|
| 63 |
+
attr_to_special_token # type: ignore
|
| 64 |
+
) # doesn't add if they are already there
|
| 65 |
+
if num_added_tokens > 0:
|
| 66 |
+
self.model.resize_token_embeddings(
|
| 67 |
+
new_num_tokens=orig_num_tokens + num_added_tokens
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
|