added input validation
Browse files- README.md +0 -3
- app.py +93 -58
- requirements.txt +1 -0
- src/sparky/inference/calc_metrics.py +1 -2
- src/sparky/inference/model.py +17 -0
README.md
CHANGED
|
@@ -44,10 +44,7 @@ To run locally:
|
|
| 44 |
|
| 45 |
```bash
|
| 46 |
uv venv
|
| 47 |
-
# source .venv/bin/activate
|
| 48 |
uv pip install -r requirements-dev.txt
|
| 49 |
-
uv pip compile requirements-dev.txt -o uv.lock
|
| 50 |
-
|
| 51 |
uv run app.py
|
| 52 |
```
|
| 53 |
|
|
|
|
| 44 |
|
| 45 |
```bash
|
| 46 |
uv venv
|
|
|
|
| 47 |
uv pip install -r requirements-dev.txt
|
|
|
|
|
|
|
| 48 |
uv run app.py
|
| 49 |
```
|
| 50 |
|
app.py
CHANGED
|
@@ -1,63 +1,113 @@
|
|
| 1 |
import sys
|
| 2 |
from pathlib import Path
|
|
|
|
|
|
|
| 3 |
sys.path.append(str(Path(__file__).parent / "src"))
|
| 4 |
|
|
|
|
| 5 |
import gradio as gr
|
|
|
|
| 6 |
from sparky.inference import load_model, calc_token_metrics, visualize_batch
|
|
|
|
|
|
|
| 7 |
|
| 8 |
-
# Load model on startup (cached)
|
| 9 |
model, tokenizer = load_model("gpt2")
|
| 10 |
|
| 11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
def analyze_text(
|
| 13 |
text: str,
|
| 14 |
-
line_width: int
|
| 15 |
-
surprisal: bool
|
| 16 |
-
entropy: bool
|
| 17 |
-
s2: bool
|
| 18 |
) -> str:
|
| 19 |
-
|
| 20 |
-
if not any([s2, entropy, surprisal]):
|
| 21 |
-
s2 = True
|
| 22 |
-
|
| 23 |
-
# Build list of metrics to show
|
| 24 |
-
metrics_to_show = []
|
| 25 |
if surprisal:
|
| 26 |
-
metrics_to_show
|
| 27 |
if entropy:
|
| 28 |
-
metrics_to_show
|
| 29 |
if s2:
|
| 30 |
-
metrics_to_show
|
| 31 |
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
svgs = visualize_batch(
|
| 35 |
-
metrics, metrics_to_show=metrics_to_show, line_width=line_width
|
| 36 |
)
|
| 37 |
-
return svgs[0] # Return first (only) SVG
|
| 38 |
|
| 39 |
|
| 40 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
demo = gr.Interface(
|
| 42 |
fn=analyze_text,
|
| 43 |
-
inputs=
|
| 44 |
-
gr.Textbox(
|
| 45 |
-
label="Text to analyze",
|
| 46 |
-
placeholder="Enter some text to analyze its information content...",
|
| 47 |
-
lines=3,
|
| 48 |
-
value="The quick brown fox jumps over the lazy dog.",
|
| 49 |
-
),
|
| 50 |
-
gr.Slider(
|
| 51 |
-
label="Line width (characters)",
|
| 52 |
-
minimum=30,
|
| 53 |
-
maximum=200,
|
| 54 |
-
step=10,
|
| 55 |
-
value=30,
|
| 56 |
-
),
|
| 57 |
-
gr.Checkbox(label="Show surprisal", value=True),
|
| 58 |
-
gr.Checkbox(label="Show entropy", value=True),
|
| 59 |
-
gr.Checkbox(label="Show S₂ (surprise-surprise)"),
|
| 60 |
-
],
|
| 61 |
outputs=gr.HTML(),
|
| 62 |
title="Token Information Content Visualization",
|
| 63 |
description="""
|
|
@@ -65,27 +115,12 @@ demo = gr.Interface(
|
|
| 65 |
|
| 66 |
- **Surprisal**: Actual information content (-log probability) of each token
|
| 67 |
- **Entropy**: Expected information content (uncertainty) at each position
|
| 68 |
-
- **S₂** (surprise-surprise): How much more/less surprising a token is than expected
|
|
|
|
|
|
|
| 69 |
""",
|
| 70 |
-
examples=
|
| 71 |
-
["The quick brown fox jumps over the lazy dog.", 30, True, True, False],
|
| 72 |
-
[
|
| 73 |
-
"In a shocking turn of events, the seemingly impossible task",
|
| 74 |
-
30,
|
| 75 |
-
False,
|
| 76 |
-
False,
|
| 77 |
-
True,
|
| 78 |
-
],
|
| 79 |
-
[
|
| 80 |
-
"In a shocking turn of table, the seemingly impossible task",
|
| 81 |
-
30,
|
| 82 |
-
False,
|
| 83 |
-
False,
|
| 84 |
-
True,
|
| 85 |
-
],
|
| 86 |
-
["A long time ago, in a galaxy far, far away...", 50, True, False, True],
|
| 87 |
-
],
|
| 88 |
)
|
| 89 |
|
| 90 |
if __name__ == "__main__":
|
| 91 |
-
demo.launch(
|
|
|
|
| 1 |
import sys
|
| 2 |
from pathlib import Path
|
| 3 |
+
from functools import lru_cache, wraps
|
| 4 |
+
|
| 5 |
sys.path.append(str(Path(__file__).parent / "src"))
|
| 6 |
|
| 7 |
+
from pydantic import validate_call, Field, ValidationError
|
| 8 |
import gradio as gr
|
| 9 |
+
|
| 10 |
from sparky.inference import load_model, calc_token_metrics, visualize_batch
|
| 11 |
+
from sparky.inference.visualize import MetricType
|
| 12 |
+
|
| 13 |
|
|
|
|
| 14 |
model, tokenizer = load_model("gpt2")
|
| 15 |
|
| 16 |
|
| 17 |
+
def catch_all(fn):
    """Wrap ``fn`` so failures come back as display text instead of propagating.

    Args:
        fn: The callable to guard. Its positional and keyword arguments are
            forwarded unchanged.

    Returns:
        A wrapper (carrying ``fn``'s metadata via ``functools.wraps``) that
        returns ``fn``'s result on success. On a pydantic ``ValidationError``
        it returns one ``"<location>: <message>"`` entry per error, joined
        with ``<br>``. On any other exception it logs the error to stderr and
        returns a generic apology string.
    """

    @wraps(fn)
    def wrapper(*args, **kwargs) -> str:
        try:
            return fn(*args, **kwargs)
        except ValidationError as e:
            # Error locations are tuples that can mix parameter names (str)
            # with sequence indices (int); str.join rejects ints, so every
            # segment is stringified first. Without this, a TypeError raised
            # here would escape the wrapper entirely.
            return "<br>".join(
                f"{', '.join(map(str, err['loc']))}: {err['msg']}"
                for err in e.errors(
                    include_url=False,
                    include_context=False,
                    include_input=True,
                )
            )
        except Exception as e:
            # Top-level boundary: keep the details in the server log and show
            # the user a neutral message.
            print(f"Error processing text: {str(e)}", file=sys.stderr)
            return "Sorry, there was an error processing your text. Please try a different input."

    return wrapper
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@catch_all
@validate_call(validate_return=True)
def _analyze_text(
    text: str = Field(..., min_length=1, max_length=2000),
    line_width: int = Field(..., ge=20, le=200),
    metrics_to_show: tuple[MetricType, ...] = Field(..., min_length=1),
) -> str:
    """Compute token metrics for one text and return its rendered visualization.

    Arguments and the return value are validated by ``validate_call`` against
    the ``Field`` constraints declared in the signature. Because of the
    ``catch_all`` decorator, validation and processing failures are returned
    as message strings rather than raised.

    Args:
        text: Text to analyze (1 to 2000 characters).
        line_width: Line width passed to the renderer (20 to 200).
        metrics_to_show: Non-empty tuple of metric names to render.

    Returns:
        The first element produced by ``visualize_batch`` for the
        single-text batch.
    """
    token_metrics = calc_token_metrics([text], model, tokenizer)
    rendered = visualize_batch(
        token_metrics,
        metrics_to_show=metrics_to_show,
        line_width=line_width,
    )
    # Only one text was submitted, so the first rendering is the only one.
    return rendered[0]
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
@lru_cache(128)
def analyze_text(
    text: str,
    line_width: int,
    surprisal: bool,
    entropy: bool,
    s2: bool,
) -> str:
    """Translate the UI's checkbox flags into metric names and run the analysis.

    Results are memoized (up to 128 distinct argument combinations) by
    ``lru_cache``.

    Args:
        text: Text to analyze.
        line_width: Line width for the visualization.
        surprisal: Whether to include the ``"surprisal"`` metric.
        entropy: Whether to include the ``"entropy"`` metric.
        s2: Whether to include the ``"s2"`` metric.

    Returns:
        Whatever ``_analyze_text`` returns for the selected metrics.
    """
    # Order matters: selected metrics are passed on in this fixed sequence.
    toggles = (
        ("surprisal", surprisal),
        ("entropy", entropy),
        ("s2", s2),
    )
    metrics_to_show = tuple(name for name, enabled in toggles if enabled)

    return _analyze_text(
        text=text,
        line_width=line_width,
        metrics_to_show=metrics_to_show,
    )
|
|
|
|
| 71 |
|
| 72 |
|
| 73 |
+
# Input widgets, created in the order they are handed to the interface.
text_input = gr.Textbox(
    label="text",
    placeholder="Enter some text to analyze...",
    lines=3,
    value="The quick brown fox jumps over the lazy dog.",
)

width_slider = gr.Slider(
    label="line_width",
    minimum=20,
    maximum=120,
    step=10,
    value=30,
)

# One checkbox per metric: (label, initially checked).
metric_toggles = [
    gr.Checkbox(label=toggle_label, value=checked)
    for toggle_label, checked in (
        ("Surprisal", True),
        ("Entropy", True),
        ("S₂", False),
    )
]

inputs = [text_input, width_slider, *metric_toggles]

# Example texts; each row supplies only the text and is padded with None so
# that it has one entry per input component.
examples = [
    ["The quick brown fox jumps over the lazy dog."],
    ["In a shocking turn of events, the seemingly impossible task"],
    ["In a shocking turn of table, the seemingly impossible task"],
    ["A long time ago, in a galaxy far, far away..."],
]
empty_sample = [None] * len(inputs)
examples = [
    sample + [None] * (len(inputs) - len(sample))
    for sample in examples
]
|
| 106 |
+
|
| 107 |
+
# Create interface
|
| 108 |
demo = gr.Interface(
|
| 109 |
fn=analyze_text,
|
| 110 |
+
inputs=inputs,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
outputs=gr.HTML(),
|
| 112 |
title="Token Information Content Visualization",
|
| 113 |
description="""
|
|
|
|
| 115 |
|
| 116 |
- **Surprisal**: Actual information content (-log probability) of each token
|
| 117 |
- **Entropy**: Expected information content (uncertainty) at each position
|
| 118 |
+
- **S₂** (surprise-surprise): How much more/less surprising a token is than expected (surprisal - entropy)
|
| 119 |
+
|
| 120 |
+
Read the paper for more details about this visualization and S₂: [Detecting out of distribution text with surprisal and entropy](https://www.lesswrong.com/posts/Kjo64rSWkFfc3sre5/detecting-out-of-distribution-text-with-surprisal-and#)
|
| 121 |
""",
|
| 122 |
+
examples=examples,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
)
|
| 124 |
|
| 125 |
if __name__ == "__main__":
|
| 126 |
+
demo.launch()
|
requirements.txt
CHANGED
|
@@ -1,3 +1,4 @@
|
|
| 1 |
numpy~=2.2.2
|
| 2 |
torch~=2.5.1
|
| 3 |
transformers~=4.48.1
|
|
|
|
|
|
| 1 |
numpy~=2.2.2
|
| 2 |
torch~=2.5.1
|
| 3 |
transformers~=4.48.1
|
| 4 |
+
pydantic~=2.10.0
|
src/sparky/inference/calc_metrics.py
CHANGED
|
@@ -16,7 +16,6 @@ def calc_token_metrics(
|
|
| 16 |
texts: List[str],
|
| 17 |
model: PreTrainedModel,
|
| 18 |
tokenizer: PreTrainedTokenizer,
|
| 19 |
-
truncation=False,
|
| 20 |
) -> TokenMetrics:
|
| 21 |
"""Calculate per-token metrics for a batch of text sequences using a language model."""
|
| 22 |
if tokenizer.pad_token is None:
|
|
@@ -30,7 +29,7 @@ def calc_token_metrics(
|
|
| 30 |
texts,
|
| 31 |
return_tensors="pt",
|
| 32 |
padding=True,
|
| 33 |
-
truncation=
|
| 34 |
return_length=True,
|
| 35 |
).to(device)
|
| 36 |
|
|
|
|
| 16 |
texts: List[str],
|
| 17 |
model: PreTrainedModel,
|
| 18 |
tokenizer: PreTrainedTokenizer,
|
|
|
|
| 19 |
) -> TokenMetrics:
|
| 20 |
"""Calculate per-token metrics for a batch of text sequences using a language model."""
|
| 21 |
if tokenizer.pad_token is None:
|
|
|
|
| 29 |
texts,
|
| 30 |
return_tensors="pt",
|
| 31 |
padding=True,
|
| 32 |
+
truncation=True,
|
| 33 |
return_length=True,
|
| 34 |
).to(device)
|
| 35 |
|
src/sparky/inference/model.py
CHANGED
|
@@ -1,9 +1,26 @@
|
|
|
|
|
|
|
|
| 1 |
import torch
|
|
|
|
| 2 |
from transformers import GPT2LMHeadModel, GPT2Tokenizer
|
| 3 |
|
| 4 |
|
| 5 |
def load_model(name="gpt2"):
    """Load a GPT-2 language model and its tokenizer by name.

    The model is moved to CUDA when ``torch.cuda.is_available()`` reports a
    device, otherwise it stays on the CPU.

    Args:
        name: Identifier passed to ``from_pretrained`` for both the model and
            the tokenizer.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    target = torch.device(device_name)

    lm = GPT2LMHeadModel.from_pretrained(name).to(target)
    tok = GPT2Tokenizer.from_pretrained(
        name,
        clean_up_tokenization_spaces=True,
    )
    return lm, tok
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
|
| 3 |
import torch
|
| 4 |
+
import transformers
|
| 5 |
from transformers import GPT2LMHeadModel, GPT2Tokenizer
|
| 6 |
|
| 7 |
|
| 8 |
def load_model(name="gpt2"):
    """Load a GPT-2 language model and tokenizer, printing load timings.

    The model is moved to CUDA when ``torch.cuda.is_available()`` reports a
    device, otherwise it stays on the CPU. The tokenizer is configured to
    truncate from the left.

    Args:
        name: Identifier passed to ``from_pretrained`` for both the model and
            the tokenizer.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    print(f"Loading model {name}...")
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    target = torch.device(device_name)

    print("Loading model from pretrained...")
    started = time.perf_counter()
    lm = GPT2LMHeadModel.from_pretrained(name).to(target)
    model_ready = time.perf_counter()
    print(f"Model loaded in {model_ready - started:.1f}s")

    print("Loading tokenizer...")
    tok = GPT2Tokenizer.from_pretrained(
        name,
        clean_up_tokenization_spaces=True,
    )
    tokenizer_ready = time.perf_counter()
    print(f"Tokenizer loaded in {tokenizer_ready - model_ready:.1f}s")

    # Over-long inputs lose tokens from the start, so the end of the text
    # (e.g. an adversarial suffix) is what gets kept.
    tok.truncation_side = "left"

    return lm, tok
|