Spaces:
Running
Running
Use caching when possible
Browse files
app.py
CHANGED
|
@@ -90,7 +90,7 @@ generation_mode = st.radio(
|
|
| 90 |
horizontal=True, label_visibility="collapsed"
|
| 91 |
) == "Generation mode"
|
| 92 |
st.caption(
|
| 93 |
-
"In standard mode, we analyze the model's predictions on the input text. "
|
| 94 |
"In generation mode, we generate a continuation of the input text (prompt) "
|
| 95 |
"and visualize the contributions of different contexts to each generated token."
|
| 96 |
)
|
|
@@ -128,7 +128,7 @@ with st.empty():
|
|
| 128 |
with st.expander("Generation options", expanded=False):
|
| 129 |
generate_kwargs["max_new_tokens"] = st.slider(
|
| 130 |
"Max. number of generated tokens",
|
| 131 |
-
min_value=8, max_value=min(1024, max_tokens), value=min(128, max_tokens)
|
| 132 |
)
|
| 133 |
col1, col2, col3, col4 = st.columns(4)
|
| 134 |
with col1:
|
|
@@ -222,8 +222,7 @@ def get_logits_processor(temperature, top_p, typical_p, repetition_penalty) -> L
|
|
| 222 |
def generate(model, inputs, metric, window_len, max_new_tokens, **kwargs):
|
| 223 |
assert metric == "NLL loss"
|
| 224 |
start = max(0, inputs["input_ids"].shape[1] - window_len + 1)
|
| 225 |
-
|
| 226 |
-
del inputs_window["labels"]
|
| 227 |
|
| 228 |
logits_warper = get_logits_processor(**kwargs)
|
| 229 |
|
|
@@ -231,13 +230,16 @@ def generate(model, inputs, metric, window_len, max_new_tokens, **kwargs):
|
|
| 231 |
eos_idx = None
|
| 232 |
pbar = st.progress(0)
|
| 233 |
max_steps = max_new_tokens + window_len - 1
|
|
|
|
| 234 |
for i in range(max_steps):
|
| 235 |
pbar.progress(i / max_steps, f"{i}/{max_steps}")
|
| 236 |
-
|
| 237 |
-
|
|
|
|
|
|
|
| 238 |
logprobs_window = logits_window.log_softmax(dim=-1)
|
| 239 |
if eos_idx is None:
|
| 240 |
-
probs_next = logits_warper(
|
| 241 |
next_token = torch.multinomial(probs_next, num_samples=1).item()
|
| 242 |
if next_token == tokenizer.eos_token_id or i >= max_new_tokens - 1:
|
| 243 |
eos_idx = i
|
|
@@ -245,12 +247,13 @@ def generate(model, inputs, metric, window_len, max_new_tokens, **kwargs):
|
|
| 245 |
next_token = tokenizer.eos_token_id
|
| 246 |
new_ids.append(next_token)
|
| 247 |
|
| 248 |
-
|
| 249 |
-
if
|
| 250 |
-
|
|
|
|
| 251 |
if logprobs_window.shape[0] == window_len:
|
| 252 |
logprobs.append(
|
| 253 |
-
logprobs_window[torch.arange(window_len),
|
| 254 |
)
|
| 255 |
|
| 256 |
if eos_idx is not None and i - eos_idx >= window_len - 1:
|
|
|
|
| 90 |
horizontal=True, label_visibility="collapsed"
|
| 91 |
) == "Generation mode"
|
| 92 |
st.caption(
|
| 93 |
+
"In standard mode, we analyze the model's one-step-ahead predictions on the input text. "
|
| 94 |
"In generation mode, we generate a continuation of the input text (prompt) "
|
| 95 |
"and visualize the contributions of different contexts to each generated token."
|
| 96 |
)
|
|
|
|
| 128 |
with st.expander("Generation options", expanded=False):
|
| 129 |
generate_kwargs["max_new_tokens"] = st.slider(
|
| 130 |
"Max. number of generated tokens",
|
| 131 |
+
min_value=8, max_value=min(1024, max_tokens), step=8, value=min(128, max_tokens)
|
| 132 |
)
|
| 133 |
col1, col2, col3, col4 = st.columns(4)
|
| 134 |
with col1:
|
|
|
|
| 222 |
def generate(model, inputs, metric, window_len, max_new_tokens, **kwargs):
|
| 223 |
assert metric == "NLL loss"
|
| 224 |
start = max(0, inputs["input_ids"].shape[1] - window_len + 1)
|
| 225 |
+
input_ids = inputs["input_ids"][:, start:]
|
|
|
|
| 226 |
|
| 227 |
logits_warper = get_logits_processor(**kwargs)
|
| 228 |
|
|
|
|
| 230 |
eos_idx = None
|
| 231 |
pbar = st.progress(0)
|
| 232 |
max_steps = max_new_tokens + window_len - 1
|
| 233 |
+
model_kwargs = dict(use_cache=True)
|
| 234 |
for i in range(max_steps):
|
| 235 |
pbar.progress(i / max_steps, f"{i}/{max_steps}")
|
| 236 |
+
model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
|
| 237 |
+
model_outputs = model(**model_inputs)
|
| 238 |
+
model_kwargs = model._update_model_kwargs_for_generation(model_outputs, model_kwargs, is_encoder_decoder=False)
|
| 239 |
+
logits_window = model_outputs.logits.squeeze(0)
|
| 240 |
logprobs_window = logits_window.log_softmax(dim=-1)
|
| 241 |
if eos_idx is None:
|
| 242 |
+
probs_next = logits_warper(input_ids, logits_window[[-1]]).softmax(dim=-1)
|
| 243 |
next_token = torch.multinomial(probs_next, num_samples=1).item()
|
| 244 |
if next_token == tokenizer.eos_token_id or i >= max_new_tokens - 1:
|
| 245 |
eos_idx = i
|
|
|
|
| 247 |
next_token = tokenizer.eos_token_id
|
| 248 |
new_ids.append(next_token)
|
| 249 |
|
| 250 |
+
input_ids = torch.cat([input_ids, torch.tensor([[next_token]])], dim=1)
|
| 251 |
+
if input_ids.shape[1] > window_len:
|
| 252 |
+
input_ids = input_ids[:, 1:]
|
| 253 |
+
model_kwargs.update(use_cache=False, past_key_values=None)
|
| 254 |
if logprobs_window.shape[0] == window_len:
|
| 255 |
logprobs.append(
|
| 256 |
+
logprobs_window[torch.arange(window_len), input_ids.squeeze(0)]
|
| 257 |
)
|
| 258 |
|
| 259 |
if eos_idx is not None and i - eos_idx >= window_len - 1:
|