v1
Browse files

- app.py +9 -9
- trol/arch_internlm2/modeling_internlm2.py +1 -1
- trol/arch_internlm2/modeling_trol.py +1 -1
- trol/load_trol.py +21 -7
app.py
CHANGED

@@ -1,5 +1,5 @@
 # A100 Zero GPU
-
+import spaces
 
 # TroL Package
 import torch
@@ -33,10 +33,10 @@ question="What is the troll doing? Provide the detail in the image and imagine w
 model_1_8, tokenizer_1_8 = load_trol(link='TroL-1.8B')
 
 # loading model
-
+model_3_8, tokenizer_3_8 = load_trol(link='TroL-3.8B')
 
 # loading model
-
+model_7, tokenizer_7 = load_trol(link='TroL-7B')
 
 def threading_function(inputs, image_token_number, streamer, device, model, tokenizer, temperature, new_max_token, top_p):
 
@@ -55,7 +55,7 @@ def threading_function(inputs, image_token_number, streamer, device, model, toke
     generation_kwargs.update({'use_cache': True})
     return model.generate(**generation_kwargs)
 
-
+@spaces.GPU
 def bot_streaming(message, history, link, temperature, new_max_token, top_p):
 
     # model selection
@@ -70,9 +70,9 @@ def bot_streaming(message, history, link, temperature, new_max_token, top_p):
         tokenizer = tokenizer_7
 
     # cpu -> gpu
-
-
-
+    for param in model.parameters():
+        if not param.is_cuda:
+            param.data = param.to(accel.device)
 
     # prompt type -> input prompt
     image_token_number = None
@@ -131,11 +131,11 @@ def bot_streaming(message, history, link, temperature, new_max_token, top_p):
     buffer = ""
     for character in response:
         buffer += character
-        time.sleep(0.
+        time.sleep(0.012)
         yield buffer
 
 demo = gr.ChatInterface(fn=bot_streaming,
-                        additional_inputs = [gr.Radio(["1.8B"], label="Size", info="Select one model size", value="
+                        additional_inputs = [gr.Radio(["1.8B", "3.8B", "7B"], label="Size", info="Select one model size", value="7B"), gr.Slider(0, 1, 0.9, label="temperature"), gr.Slider(1, 1024, 128, label="new_max_token"), gr.Slider(0, 1, 0.95, label="top_p")],
                         additional_inputs_accordion="Generation Hyperparameters",
                         theme=gr.themes.Soft(),
                         title="TroL",
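For context on the app.py changes: `import spaces` plus the `@spaces.GPU` decorator is the standard Hugging Face ZeroGPU pattern, in which models sit on CPU while the Space is idle and a GPU is attached only for the duration of the decorated call; the `# cpu -> gpu` loop then moves any parameters still on CPU once the device exists (`accel` is presumably an accelerate-style object exposing `.device`). A minimal sketch of the pattern, using illustrative stand-in names rather than the repo's real ones:

import spaces
import torch

model = torch.nn.Linear(4, 4)   # stand-in model; lives on CPU while the Space is idle

@spaces.GPU                     # ZeroGPU attaches a GPU only while this call runs
def generate(x):
    # mirrors the "# cpu -> gpu" loop in the diff
    for param in model.parameters():
        if not param.is_cuda:
            param.data = param.to("cuda")
    return model(x.to("cuda"))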
trol/arch_internlm2/modeling_internlm2.py
CHANGED

@@ -857,7 +857,7 @@ class InternLM2Model(InternLM2PreTrainedModel):
         self.vocab_size = config.vocab_size
         self.config = config
 
-        self.tok_embeddings = nn.Embedding(config.vocab_size,
+        self.tok_embeddings = nn.Embedding(config.vocab_size+1,
                                            config.hidden_size,
                                            self.padding_idx)
         self.layers = nn.ModuleList([
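The `+1` widens the token-embedding table by one row. The diff does not say why, but a plausible reading is that it reserves an index for the extra image special token that load_trol.py registers below, making id `config.vocab_size` a legal input. A toy check of that effect:

import torch
import torch.nn as nn

vocab_size, hidden_size, padding_idx = 8, 4, 0    # toy values
tok_embeddings = nn.Embedding(vocab_size + 1, hidden_size, padding_idx)

# index `vocab_size` is now a valid id, usable for one added special token
out = tok_embeddings(torch.tensor([vocab_size]))
assert out.shape == (1, hidden_size)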
trol/arch_internlm2/modeling_trol.py
CHANGED

@@ -30,7 +30,7 @@ class TroLForCausalLM(InternLM2PreTrainedModel):
         # Model
         self.model = InternLM2Model(config)
         self.vocab_size = config.vocab_size
-        self.output = nn.Linear(config.hidden_size, config.vocab_size
+        self.output = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
         self.max_length = config.max_length
 
         # Initialize weights and apply final processing
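The old line's tail is truncated in this view, but the new line settles the LM head as a bias-free projection, the usual choice for decoder-only output layers. A toy equivalent:

import torch.nn as nn

hidden_size, vocab_size = 16, 100                 # toy values
output = nn.Linear(hidden_size, vocab_size, bias=False)
assert output.bias is None                        # logits are a pure matrix multiply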
trol/load_trol.py
CHANGED

@@ -1,11 +1,16 @@
 import torch
 import warnings
 from config import *
-from peft import LoraConfig
 from transformers import BitsAndBytesConfig
 
 warnings.filterwarnings(action='ignore')
 
+def setting_trol_config(trol, tok_trol, image_special_token):
+    trol.config.image_token_index = tok_trol.convert_tokens_to_ids(image_special_token)
+    trol.config.ignore_index = -100
+    trol.config.pad_token_id = tok_trol.eos_token_id
+    trol.config.eos_token_id = tok_trol.eos_token_id
+
 def load_trol(link):
 
     """
@@ -16,21 +21,24 @@
         from .arch_internlm2.tokenization_internlm2 import InternLM2Tokenizer as TroLTokenizer
         bits = 4
         path = TROL_1_8B
-
+        image_special_token = "<image>"
+        bit_quant_skip = ["vit", "vision_proj", "ffn", "output", "trol_gating"]
 
     elif link == 'TroL-3.8B':
         from trol.arch_phi3.modeling_trol import TroLForCausalLM
         from transformers import LlamaTokenizerFast as TroLTokenizer
         bits = 8
         path = TROL_3_8B
-
+        image_special_token = "<IMG_CONTEXT>"
+        bit_quant_skip = ["vision_model", "vision_proj", "lm_head", "trol_gating"]
 
     elif link == 'TroL-7B':
         from .arch_internlm2.modeling_trol import TroLForCausalLM
         from .arch_internlm2.tokenization_internlm2 import InternLM2Tokenizer as TroLTokenizer
         bits = 4
         path = TROL_7B
-
+        image_special_token = "<image>"
+        bit_quant_skip = ["vit", "vision_proj", "ffn", "output", "trol_gating"]
     else:
         raise Exception("Unsupported Link")
 
@@ -68,10 +76,16 @@
     except:
         del huggingface_config["attn_implementation"]
         trol = TroLForCausalLM.from_pretrained(path, **huggingface_config)
+    trol.config.llm_config.use_cache = False
+
+    # setting config
+    setting_trol_config(trol, tok_trol, image_special_token)
+
 
-    #
+    # trol gating load
+    from huggingface_hub import hf_hub_download
     try:
-        trol =
+        trol.model.trol_gating.load_state_dict(torch.load(hf_hub_download(repo_id=path, filename="trol_gating.pt")))
     except:
-
+        trol.language_model.model.trol_gating.load_state_dict(torch.load(hf_hub_download(repo_id=path, filename="trol_gating.pt")))
     return trol, tok_trol
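For context on load_trol.py: `bit_quant_skip` names the submodules to exempt from 4- or 8-bit quantization (vision tower, projector, LM head, and the `trol_gating` weights this commit now fetches from the Hub via `hf_hub_download`), and the try/except around the state-dict load covers the two backbones, since the gating module presumably hangs off `trol.model` in the InternLM2 classes above but off `trol.language_model.model` in the Phi-3 variant. A hedged sketch of how `bits` and `bit_quant_skip` plausibly feed the quantization config (the actual wiring lives outside this hunk):

import torch
from transformers import BitsAndBytesConfig

bits = 4
bit_quant_skip = ["vit", "vision_proj", "ffn", "output", "trol_gating"]

# Assumption: load_trol builds something like this further down the file.
quant_config = BitsAndBytesConfig(
    load_in_4bit=(bits == 4),
    load_in_8bit=(bits == 8),
    llm_int8_skip_modules=bit_quant_skip,    # keep these modules un-quantized
    bnb_4bit_compute_dtype=torch.bfloat16,
)

Callers use the loader exactly as app.py does, e.g. model_7, tokenizer_7 = load_trol(link='TroL-7B').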