Cache individual batches

app.py CHANGED

@@ -70,14 +70,19 @@ if len(input_ids) < 2:
 
 @st.cache_data(show_spinner=False)
 @torch.inference_mode()
-def run_context_length_probing(model_name, text, window_len):
-    ...
-    ...
+def get_logits(_model, _inputs, cache_key):
+    del cache_key
+    return _model(**_inputs).logits.to(torch.float16)
+
+@st.cache_data(show_spinner=False)
+@torch.inference_mode()
+def run_context_length_probing(_model, _tokenizer, _inputs, window_len, cache_key):
+    del cache_key
 
     inputs_sliding = get_windows_batched(
-        ...
+        _inputs,
         window_len=window_len,
-        pad_id=...
+        pad_id=_tokenizer.eos_token_id
     ).convert_to_tensors("pt")
 
     logits = []
@@ -88,7 +93,13 @@ def run_context_length_probing(model_name, text, window_len):
     for i in range(0, num_items, batch_size):
         pbar.progress(i / num_items, f"{i}/{num_items}")
         batch = {k: v[i:i + batch_size] for k, v in inputs_sliding.items()}
-        logits.append(...)
+        logits.append(
+            get_logits(
+                _model,
+                batch,
+                cache_key=(model_name, batch["input_ids"].cpu().numpy().tobytes())
+            )
+        )
     logits = torch.cat(logits, dim=0)
     pbar.empty()
 
@@ -108,9 +119,11 @@ def run_context_length_probing(model_name, text, window_len):
     return scores
 
 scores = run_context_length_probing(
-    ...
-    ...
-    ...
+    _model=model,
+    _tokenizer=tokenizer,
+    _inputs=inputs,
+    window_len=window_len,
+    cache_key=(model_name, text),
 )
 tokens = ids_to_readable_tokens(tokenizer, input_ids)
 
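A note on the caching idiom this commit relies on: `st.cache_data` builds its cache key by hashing the function's arguments, and it skips any parameter whose name starts with an underscore. The model, tokenizer, and tensor inputs are therefore passed as `_model`, `_tokenizer`, and `_inputs` (expensive or impossible to hash), while a cheap explicit `cache_key` argument identifies them; `del cache_key` makes clear the key affects only cache lookup, never the computation. A minimal sketch of the same idiom, with illustrative names rather than the app's own:

```python
import streamlit as st
import torch

@st.cache_data(show_spinner=False)
@torch.inference_mode()
def cached_forward(_model, _batch, cache_key):
    # Underscore-prefixed parameters are ignored by st.cache_data's hasher,
    # so unhashable objects (the model, tensors) pass straight through.
    # cache_key is the only hashed argument: a cheap stand-in that uniquely
    # identifies the inputs. Deleting it documents that it is never used.
    del cache_key
    return _model(**_batch).logits.to(torch.float16)
```

Keying each batch on `(model_name, batch["input_ids"].cpu().numpy().tobytes())` makes the key content-derived and hashable, so a rerun whose sliding windows partly overlap a previous run reuses those batches' cached logits instead of recomputing the whole pass; this is what the commit title, "Cache individual batches", refers to. Casting to `torch.float16` also halves the size of each cached entry.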
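`get_windows_batched` is referenced but not defined in this diff. Purely as an illustration, here is one way such a helper could look, assuming it expands a single tokenized sequence into one left-padded window of `window_len` tokens per position (the app's actual implementation may differ):

```python
from transformers import BatchEncoding

def get_windows_batched(inputs, window_len, pad_id):
    # Hypothetical sketch, not taken from the app. Builds one window per
    # token position, left-padded with pad_id so every window is exactly
    # window_len tokens long.
    ids = inputs["input_ids"]
    windows, masks = [], []
    for end in range(1, len(ids) + 1):
        chunk = ids[max(0, end - window_len):end]
        pad = [pad_id] * (window_len - len(chunk))
        windows.append(pad + chunk)
        masks.append([0] * len(pad) + [1] * len(chunk))
    # BatchEncoding supports .convert_to_tensors("pt"), as called above.
    return BatchEncoding({"input_ids": windows, "attention_mask": masks})
```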