x54-729 committed on
Commit ·
daa886c
1
Parent(s): 7bee5c5
update modeling file to newest
Browse files
- modeling_internlm2.py +11 -3
modeling_internlm2.py
CHANGED
|
@@ -13,7 +13,7 @@
|
|
| 13 |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
# See the License for the specific language governing permissions and
|
| 15 |
# limitations under the License.
|
| 16 |
-
""" PyTorch InternLM2 model."""
|
| 17 |
import math
|
| 18 |
import queue
|
| 19 |
import threading
|
|
@@ -59,6 +59,10 @@ try:
|
|
| 59 |
except:
|
| 60 |
pass
|
| 61 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
|
| 63 |
logger = logging.get_logger(__name__)
|
| 64 |
|
|
@@ -1093,7 +1097,11 @@ class InternLM2Model(InternLM2PreTrainedModel):
|
|
| 1093 |
else:
|
| 1094 |
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
|
| 1095 |
if sequence_length != 1:
|
| 1096 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1097 |
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
| 1098 |
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
|
| 1099 |
if attention_mask is not None:
|
|
@@ -1797,4 +1805,4 @@ class InternLM2ForTokenClassification(InternLM2PreTrainedModel):
|
|
| 1797 |
logits=logits,
|
| 1798 |
hidden_states=outputs.hidden_states,
|
| 1799 |
attentions=outputs.attentions,
|
| 1800 |
-
)
|
|
|
|
| 13 |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
# See the License for the specific language governing permissions and
|
| 15 |
# limitations under the License.
|
| 16 |
+
"""PyTorch InternLM2 model."""
|
| 17 |
import math
|
| 18 |
import queue
|
| 19 |
import threading
|
|
|
|
| 59 |
except:
|
| 60 |
pass
|
| 61 |
|
| 62 |
+
try:
|
| 63 |
+
support_bf16_triu = torch.__version__ >= "2.1.0"
|
| 64 |
+
except Exception:
|
| 65 |
+
support_bf16_triu = False
|
| 66 |
|
| 67 |
logger = logging.get_logger(__name__)
|
| 68 |
|
|
|
|
| 1097 |
else:
|
| 1098 |
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
|
| 1099 |
if sequence_length != 1:
|
| 1100 |
+
if support_bf16_triu or dtype == torch.float32:
|
| 1101 |
+
causal_mask = torch.triu(causal_mask, diagonal=1)
|
| 1102 |
+
else:
|
| 1103 |
+
triu_mask = torch.triu(torch.ones(causal_mask.size(), device=device), diagonal=1).bool()
|
| 1104 |
+
causal_mask.masked_fill_(~triu_mask, 0)
|
| 1105 |
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
| 1106 |
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
|
| 1107 |
if attention_mask is not None:
|
|
|
|
| 1805 |
logits=logits,
|
| 1806 |
hidden_states=outputs.hidden_states,
|
| 1807 |
attentions=outputs.attentions,
|
| 1808 |
+
)
|