Better use of memory; limit window size and number of tokens

app.py CHANGED
@@ -54,11 +54,25 @@ if not compact_layout:
 
 model_name = st.selectbox("Model", ["distilgpt2", "gpt2", "EleutherAI/gpt-neo-125m"])
 metric_name = st.selectbox("Metric", ["KL divergence", "Cross entropy"], index=1)
+
+tokenizer = st.cache_resource(AutoTokenizer.from_pretrained, show_spinner=False)(model_name, use_fast=False)
+
+# Make sure the logprobs do not use up more than ~6 GB of memory
+MAX_MEM = 6e9 / (torch.finfo(torch.float16).bits / 8)
+# Select window lengths such that we are allowed to fill the whole window without running out of memory
+# (otherwise the window length is irrelevant)
+window_len_options = [
+    w for w in [8, 16, 32, 64, 128, 256, 512, 1024]
+    if w == 8 or w * (2 * w) * tokenizer.vocab_size <= MAX_MEM
+]
 window_len = st.select_slider(
     r"Window size ($c_\text{max}$)",
-    options=[8, 16, 32, 64, 128, 256, 512, 1024],
-    value=512
+    options=window_len_options,
+    value=min(128, window_len_options[-1])
 )
+# Now figure out how many tokens we are allowed to use:
+# window_len * (num_tokens + window_len) * vocab_size <= MAX_MEM
+max_tokens = int(MAX_MEM / (tokenizer.vocab_size * window_len) - window_len)
 
 DEFAULT_TEXT = """
 We present context length probing, a novel explanation technique for causal
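To see what this budget buys in practice, here is a small worked example (not part of the commit; it assumes a GPT-2-sized vocabulary of 50257, which applies to all three models in the selectbox):

import torch

# MAX_MEM counts float16 *elements*, not bytes: a 6 GB budget at 2 bytes each.
MAX_MEM = 6e9 / (torch.finfo(torch.float16).bits / 8)  # = 3e9 elements
vocab_size = 50257

# A window length w qualifies only if a fully filled window fits in memory,
# i.e. w * (2 * w) * vocab_size <= MAX_MEM (num_tokens = window_len).
options = [
    w for w in [8, 16, 32, 64, 128, 256, 512, 1024]
    if w == 8 or w * (2 * w) * vocab_size <= MAX_MEM
]
print(options)  # [8, 16, 32, 64, 128]

# For window_len = 128, the token budget follows from
# window_len * (num_tokens + window_len) * vocab_size <= MAX_MEM:
window_len = 128
max_tokens = int(MAX_MEM / (vocab_size * window_len) - window_len)
print(max_tokens)  # 338

So with the default window of 128, the app caps the input at roughly 338 tokens, and the 256/512/1024 options disappear from the slider entirely.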
@@ -71,31 +85,38 @@ dependencies.
 """.replace("\n", " ").strip()
 
 text = st.text_area(
-    "Input text",
+    f"Input text (≤{max_tokens} tokens)",
     DEFAULT_TEXT,
 )
 
+inputs = tokenizer([text])
+[input_ids] = inputs["input_ids"]
+
+if len(input_ids) < 2:
+    st.error("Please enter at least 2 tokens.", icon="🚨")
+    st.stop()
+if len(input_ids) > max_tokens:
+    st.error(
+        f"Your input has {len(input_ids)} tokens. Please enter at most {max_tokens} tokens "
+        f"or try reducing the window size.",
+        icon="🚨"
+    )
+    st.stop()
+
 if metric_name == "KL divergence":
     st.error("KL divergence is not supported yet. Stay tuned!", icon="😭")
     st.stop()
 
 with st.spinner("Loading model…"):
-    tokenizer = st.cache_resource(AutoTokenizer.from_pretrained, show_spinner=False)(model_name)
     model = st.cache_resource(AutoModelForCausalLM.from_pretrained, show_spinner=False)(model_name)
 
-inputs = tokenizer([text])
-[input_ids] = inputs["input_ids"]
 window_len = min(window_len, len(input_ids))
 
-if len(input_ids) < 2:
-    st.error("Please enter at least 2 tokens.", icon="🚨")
-    st.stop()
-
 @st.cache_data(show_spinner=False)
 @torch.inference_mode()
-def get_logits(_model, _inputs, cache_key):
+def get_logprobs(_model, _inputs, cache_key):
     del cache_key
-    return _model(**_inputs).logits.to(torch.float16)
+    return _model(**_inputs).logits.log_softmax(dim=-1).to(torch.float16)
 
 @st.cache_data(show_spinner=False)
 @torch.inference_mode()
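A note on the caching pattern in this hunk: st.cache_data skips hashing any parameter whose name starts with an underscore, so the unhashable model and inputs are excluded from the cache key, and the hashable cache_key argument stands in for them. It is deleted immediately because only its hash matters. A minimal sketch of the same pattern (square and its arguments are illustrative, not from app.py):

import streamlit as st
import torch

@st.cache_data(show_spinner=False)
def square(_x, cache_key):
    # Streamlit does not hash parameters starting with "_", so the tensor
    # itself is ignored; `cache_key` carries its identity for the lookup.
    del cache_key
    return _x ** 2

x = torch.arange(3)
y = square(x, cache_key=x.numpy().tobytes())  # recomputed only when the bytes change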
@@ -108,7 +129,7 @@ def run_context_length_probing(_model, _tokenizer, _inputs, window_len, cache_ke
         pad_id=_tokenizer.eos_token_id
     ).convert_to_tensors("pt")
 
-    logits = []
+    logprobs = []
     with st.spinner("Running model…"):
         batch_size = 8
         num_items = len(inputs_sliding["input_ids"])
@@ -116,27 +137,26 @@
         for i in range(0, num_items, batch_size):
             pbar.progress(i / num_items, f"{i}/{num_items}")
             batch = {k: v[i:i + batch_size] for k, v in inputs_sliding.items()}
-            logits.append(
-                get_logits(
+            logprobs.append(
+                get_logprobs(
                     _model,
                     batch,
                     cache_key=(model_name, batch["input_ids"].cpu().numpy().tobytes())
                 )
             )
-        logits = torch.cat(logits, dim=0)
+        logprobs = torch.cat(logprobs, dim=0)
         pbar.empty()
 
     with st.spinner("Computing scores…"):
-        logits = logits.permute(1, 0, 2)
-        logits = F.pad(logits, (0, 0, 0, window_len, 0, 0), value=torch.nan)
-        logits = logits.view(-1, logits.shape[-1])[:-window_len]
-        logits = logits.view(window_len, len(input_ids) + window_len - 2, logits.shape[-1])
+        logprobs = logprobs.permute(1, 0, 2)
+        logprobs = F.pad(logprobs, (0, 0, 0, window_len, 0, 0), value=torch.nan)
+        logprobs = logprobs.view(-1, logprobs.shape[-1])[:-window_len]
+        logprobs = logprobs.view(window_len, len(input_ids) + window_len - 2, logprobs.shape[-1])
 
-        scores = logits.log_softmax(dim=-1)
-        scores = scores[:, torch.arange(len(input_ids[1:])), input_ids[1:]]
+        scores = logprobs[:, torch.arange(len(input_ids[1:])), input_ids[1:]]
         scores = scores.diff(dim=0).transpose(0, 1)
         scores = scores.nan_to_num()
-        scores /= scores.abs().max(dim=1, keepdim=True).values + 1e-
+        scores /= scores.abs().max(dim=1, keepdim=True).values + 1e-6
         scores = scores.to(torch.float16)
 
     return scores
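The pad-and-reshape block in this hunk aligns the sliding windows: after the final view, each row is shifted one position further right than the previous one, so every column collects the predictions for the same target token under increasing context lengths, and the subsequent diff(dim=0) measures how the prediction changes as each extra context token becomes available. A toy sketch of the trick (shapes are illustrative; the vocabulary axis is collapsed to size 1):

import torch
import torch.nn.functional as F

num_windows, window_len = 3, 3
# Window w holds the values 3w, 3w+1, 3w+2 at its three offsets.
x = torch.arange(9.).view(num_windows, window_len, 1)

x = x.permute(1, 0, 2)                                 # (window_len, num_windows, 1)
x = F.pad(x, (0, 0, 0, window_len, 0, 0), value=torch.nan)  # pad the window axis on the right
x = x.view(-1, 1)[:-window_len]                        # flatten, drop the trailing padding
x = x.view(window_len, num_windows + window_len - 1, 1)

print(x[..., 0])
# tensor([[0., 3., 6., nan, nan],
#         [nan, 1., 4., 7., nan],
#         [nan, nan, 2., 5., 8.]])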