v1
Browse files
- app.py +2 -2
- trol/arch_internlm2/modeling_internlm2.py +5 -3
- trol/arch_phi3/modeling_phi3.py +5 -3
- trol/load_trol.py +7 -1
app.py
CHANGED
|
@@ -18,8 +18,8 @@ from transformers import TextIteratorStreamer
|
|
| 18 |
from torchvision.transforms.functional import pil_to_tensor
|
| 19 |
|
| 20 |
# flash attention
|
| 21 |
-
|
| 22 |
-
|
| 23 |
|
| 24 |
# accel
|
| 25 |
accel = Accelerator()
|
|
|
|
| 18 |
from torchvision.transforms.functional import pil_to_tensor
|
| 19 |
|
| 20 |
# flash attention
|
| 21 |
+
import os
import subprocess

# Best-effort install of flash-attn at startup; the result is deliberately not
# checked, matching the original behavior.
# The FLASH_ATTENTION_SKIP_CUDA_BUILD flag is layered on top of the current
# environment: passing a bare dict as `env=` replaces the child's whole
# environment, which drops PATH, HOME, proxy and pip configuration variables.
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
    shell=True,
)
|
| 23 |
|
| 24 |
# accel
|
| 25 |
accel = Accelerator()
|
trol/arch_internlm2/modeling_internlm2.py
CHANGED
|
@@ -867,13 +867,15 @@ class InternLM2Model(InternLM2PreTrainedModel):
|
|
| 867 |
self.norm = InternLM2RMSNorm(
|
| 868 |
config.hidden_size, eps=config.rms_norm_eps)
|
| 869 |
|
| 870 |
-
self.trol_gating = nn.ModuleList([nn.Linear(self.config.hidden_size, 1)]*self.config.num_hidden_layers)
|
| 871 |
-
self.trol_function = lambda x, idx: 0.5*F.tanh(self.trol_gating[idx](x))+0.5
|
| 872 |
-
|
| 873 |
self.gradient_checkpointing = False
|
| 874 |
# Initialize weights and apply final processing
|
| 875 |
self.post_init()
|
| 876 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 877 |
def get_input_embeddings(self):
|
| 878 |
return self.tok_embeddings
|
| 879 |
|
|
|
|
| 867 |
self.norm = InternLM2RMSNorm(
|
| 868 |
config.hidden_size, eps=config.rms_norm_eps)
|
| 869 |
|
|
|
|
|
|
|
|
|
|
| 870 |
self.gradient_checkpointing = False
|
| 871 |
# Initialize weights and apply final processing
|
| 872 |
self.post_init()
|
| 873 |
|
| 874 |
+
def initialize_trol_gating(self):
    """Create the per-layer gating heads and the gate function.

    Sets ``self.trol_gating`` to an ``nn.ModuleList`` holding one
    ``nn.Linear(config.hidden_size, 1)`` per hidden layer
    (``config.num_hidden_layers`` entries), each moved to CUDA, and sets
    ``self.trol_function(x, idx)`` to ``0.5 * tanh(trol_gating[idx](x)) + 0.5``,
    i.e. a gate value between 0 and 1.

    Requires a CUDA device, because every head is created with ``.cuda()``.
    """
    hidden_size = self.config.hidden_size
    num_layers = self.config.num_hidden_layers

    # Build a distinct Linear per layer. `[layer] * n` would repeat a single
    # module object, so every index would alias the same weights and the
    # per-index entries of a loaded state dict would overwrite one another.
    self.trol_gating = nn.ModuleList(
        [nn.Linear(hidden_size, 1).cuda() for _ in range(num_layers)]
    )

    # Kept as an attribute (not a method) so existing call sites that use
    # `self.trol_function(x, idx)` keep working unchanged.
    self.trol_function = lambda x, idx: 0.5*F.tanh(self.trol_gating[idx](x))+0.5
|
| 877 |
+
|
| 878 |
+
|
| 879 |
def get_input_embeddings(self):
|
| 880 |
# Accessor: hands back the object stored in `self.tok_embeddings`.
return self.tok_embeddings
|
| 881 |
|
trol/arch_phi3/modeling_phi3.py
CHANGED
|
@@ -1031,13 +1031,15 @@ class Phi3Model(Phi3PreTrainedModel):
|
|
| 1031 |
self._attn_implementation = "flash_attention_2"
|
| 1032 |
self.norm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 1033 |
|
| 1034 |
-
self.trol_gating = nn.ModuleList([nn.Linear(self.config.hidden_size, 1)]*self.config.num_hidden_layers)
|
| 1035 |
-
self.trol_function = lambda x, idx: 0.5*F.tanh(self.trol_gating[idx](x))+0.5
|
| 1036 |
-
|
| 1037 |
self.gradient_checkpointing = False
|
| 1038 |
# Initialize weights and apply final processing
|
| 1039 |
self.post_init()
|
| 1040 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1041 |
def get_input_embeddings(self):
|
| 1042 |
return self.embed_tokens
|
| 1043 |
|
|
|
|
| 1031 |
self._attn_implementation = "flash_attention_2"
|
| 1032 |
self.norm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 1033 |
|
|
|
|
|
|
|
|
|
|
| 1034 |
self.gradient_checkpointing = False
|
| 1035 |
# Initialize weights and apply final processing
|
| 1036 |
self.post_init()
|
| 1037 |
|
| 1038 |
+
def initialize_trol_gating(self):
    """Create the per-layer gating heads and the gate function.

    Sets ``self.trol_gating`` to an ``nn.ModuleList`` holding one
    ``nn.Linear(config.hidden_size, 1)`` per hidden layer
    (``config.num_hidden_layers`` entries), each moved to CUDA, and sets
    ``self.trol_function(x, idx)`` to ``0.5 * tanh(trol_gating[idx](x)) + 0.5``,
    i.e. a gate value between 0 and 1.

    Requires a CUDA device, because every head is created with ``.cuda()``.
    """
    hidden_size = self.config.hidden_size
    num_layers = self.config.num_hidden_layers

    # Build a distinct Linear per layer. `[layer] * n` would repeat a single
    # module object, so every index would alias the same weights and the
    # per-index entries of a loaded state dict would overwrite one another.
    self.trol_gating = nn.ModuleList(
        [nn.Linear(hidden_size, 1).cuda() for _ in range(num_layers)]
    )

    # Kept as an attribute (not a method) so existing call sites that use
    # `self.trol_function(x, idx)` keep working unchanged.
    self.trol_function = lambda x, idx: 0.5*F.tanh(self.trol_gating[idx](x))+0.5
|
| 1041 |
+
|
| 1042 |
+
|
| 1043 |
def get_input_embeddings(self):
|
| 1044 |
# Accessor: hands back the object stored in `self.embed_tokens`.
return self.embed_tokens
|
| 1045 |
|
trol/load_trol.py
CHANGED
|
@@ -81,11 +81,17 @@ def load_trol(link):
|
|
| 81 |
# setting config
|
| 82 |
setting_trol_config(trol, tok_trol, image_special_token)
|
| 83 |
|
| 84 |
-
|
| 85 |
# trol gating load
|
| 86 |
from huggingface_hub import hf_hub_download
|
| 87 |
try:
|
|
|
|
| 88 |
trol.model.trol_gating.load_state_dict(torch.load(hf_hub_download(repo_id=path, filename="trol_gating.pt")))
|
| 89 |
except:
|
|
|
|
| 90 |
trol.language_model.model.trol_gating.load_state_dict(torch.load(hf_hub_download(repo_id=path, filename="trol_gating.pt")))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
return trol, tok_trol
|
|
|
|
| 81 |
# setting config
|
| 82 |
setting_trol_config(trol, tok_trol, image_special_token)
|
| 83 |
|
|
|
|
| 84 |
# trol gating load
|
| 85 |
from huggingface_hub import hf_hub_download
|
| 86 |
try:
|
| 87 |
+
trol.model.initialize_trol_gating()
|
| 88 |
trol.model.trol_gating.load_state_dict(torch.load(hf_hub_download(repo_id=path, filename="trol_gating.pt")))
|
| 89 |
except:
|
| 90 |
+
trol.language_model.model.initialize_trol_gating()
|
| 91 |
trol.language_model.model.trol_gating.load_state_dict(torch.load(hf_hub_download(repo_id=path, filename="trol_gating.pt")))
|
| 92 |
+
|
| 93 |
+
# X -> float16 conversion
|
| 94 |
+
for param in trol.parameters():
|
| 95 |
+
if 'float32' in str(param.dtype).lower() or 'float16' in str(param.dtype).lower():
|
| 96 |
+
param.data = param.data.to(torch.float16)
|
| 97 |
return trol, tok_trol
|