Update modeling_zhinao.py #2
by neofung · opened

modeling_zhinao.py  CHANGED  (+12 -13)
@@ -748,6 +748,17 @@ class ZhinaoForCausalLM(ZhinaoPreTrainedModel):
 
     def __init__(self, config):
        super().__init__(config)
+        if config.use_flash_attn == "auto":
+            if flash_attn_varlen_func:
+                if config.bf16 or config.fp16:
+                    logger.warn("Try importing flash-attention.")
+                    config.use_flash_attn = True
+                else:
+                    config.use_flash_attn = False
+                    logger.warn("Flash attention will be disabled because it does NOT support fp32.")
+            else:
+                config.use_flash_attn = False
+                logger.warn("Please install FlashAttention first, " "e.g., with pip install flash-attn")
        self.model = ZhinaoModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
@@ -761,19 +772,7 @@ class ZhinaoForCausalLM(ZhinaoPreTrainedModel):
        if config.fp16:
            self.model.half()
            self.lm_head.half()
-            self.linear.half()
-
-        if config.use_flash_attn == "auto":
-            if flash_attn_varlen_func:
-                if config.bf16 or config.fp16:
-                    logger.warn("Try importing flash-attention.")
-                    config.use_flash_attn = True
-                else:
-                    config.use_flash_attn = False
-                    logger.warn("Flash attention will be disabled because it does NOT support fp32.")
-            else:
-                config.use_flash_attn = False
-                logger.warn("Please install FlashAttention first, " "e.g., with pip install flash-attn")
+            self.linear.half()
 
        self.post_init()
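
Net effect of the patch, for review: the use_flash_attn auto-detection block moves from after the fp16 casts to the top of __init__, before self.model = ZhinaoModel(config) is built, and self.linear.half() is removed and re-added unchanged, so the only functional change appears to be the relocation. In the old ordering, config.use_flash_attn could still hold the string "auto" while ZhinaoModel was being constructed; the resolution to a concrete True/False came too late for anything reading the flag at construction time.

Below is a minimal, self-contained sketch of that ordering pattern. Everything in it is a stand-in: SimpleNamespace plays the config, ToyModel/ToyAttention play the real modules, and a None argument simulates an absent flash-attn import. It also assumes, as the new placement suggests but the hunk does not show, that the attention layers consume use_flash_attn when they are constructed.

from types import SimpleNamespace

def resolve_use_flash_attn(config, flash_attn_varlen_func):
    # Same logic as the moved block: collapse the "auto" sentinel to a bool.
    if config.use_flash_attn == "auto":
        if flash_attn_varlen_func:              # flash-attn import succeeded
            if config.bf16 or config.fp16:
                config.use_flash_attn = True    # half precision: enable it
            else:
                config.use_flash_attn = False   # flash-attn has no fp32 kernels
        else:
            config.use_flash_attn = False       # flash-attn not installed

class ToyAttention:
    def __init__(self, config):
        # Stand-in for the real attention layer: it reads the flag at
        # construction time, which is why resolution must happen first.
        assert isinstance(config.use_flash_attn, bool), "flag not resolved yet"
        self.use_flash = config.use_flash_attn

class ToyModel:
    def __init__(self, config):
        self.attn = ToyAttention(config)

config = SimpleNamespace(use_flash_attn="auto", bf16=True, fp16=False)
resolve_use_flash_attn(config, flash_attn_varlen_func=None)  # simulate: not installed
model = ToyModel(config)        # constructed against a concrete False
print(config.use_flash_attn)   # -> False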
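Since modeling_zhinao.py ships with the checkpoint as custom code, the patched file only takes effect when the model is loaded with trust_remote_code=True. A minimal smoke test, with the repo id below as a placeholder rather than anything taken from this PR:

from transformers import AutoModelForCausalLM, AutoTokenizer

repo = "org/zhinao-checkpoint"  # placeholder repo id, not from this PR
tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True)
print(model.config.use_flash_attn)  # resolved in __init__: True or False, never "auto"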