z0u committed on
Commit
cae066d
·
unverified ·
1 Parent(s): b479caf

added input validation

Browse files
README.md CHANGED
@@ -44,10 +44,7 @@ To run locally:
44
 
45
  ```bash
46
  uv venv
47
- # source .venv/bin/activate
48
  uv pip install -r requirements-dev.txt
49
- uv pip compile requirements-dev.txt -o uv.lock
50
-
51
  uv run app.py
52
  ```
53
 
 
44
 
45
  ```bash
46
  uv venv
 
47
  uv pip install -r requirements-dev.txt
 
 
48
  uv run app.py
49
  ```
50
 
app.py CHANGED
@@ -1,63 +1,113 @@
1
  import sys
2
  from pathlib import Path
 
 
3
  sys.path.append(str(Path(__file__).parent / "src"))
4
 
 
5
  import gradio as gr
 
6
  from sparky.inference import load_model, calc_token_metrics, visualize_batch
 
 
7
 
8
- # Load model on startup (cached)
9
  model, tokenizer = load_model("gpt2")
10
 
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  def analyze_text(
13
  text: str,
14
- line_width: int = 80,
15
- surprisal: bool = False,
16
- entropy: bool = False,
17
- s2: bool = True,
18
  ) -> str:
19
- # Default to S2 if nothing selected
20
- if not any([s2, entropy, surprisal]):
21
- s2 = True
22
-
23
- # Build list of metrics to show
24
- metrics_to_show = []
25
  if surprisal:
26
- metrics_to_show.append("surprisal")
27
  if entropy:
28
- metrics_to_show.append("entropy")
29
  if s2:
30
- metrics_to_show.append("s2")
31
 
32
- # Calculate metrics and generate visualization
33
- metrics = calc_token_metrics([text], model, tokenizer)
34
- svgs = visualize_batch(
35
- metrics, metrics_to_show=metrics_to_show, line_width=line_width
36
  )
37
- return svgs[0] # Return first (only) SVG
38
 
39
 
40
- # Create Gradio interface
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  demo = gr.Interface(
42
  fn=analyze_text,
43
- inputs=[
44
- gr.Textbox(
45
- label="Text to analyze",
46
- placeholder="Enter some text to analyze its information content...",
47
- lines=3,
48
- value="The quick brown fox jumps over the lazy dog.",
49
- ),
50
- gr.Slider(
51
- label="Line width (characters)",
52
- minimum=30,
53
- maximum=200,
54
- step=10,
55
- value=30,
56
- ),
57
- gr.Checkbox(label="Show surprisal", value=True),
58
- gr.Checkbox(label="Show entropy", value=True),
59
- gr.Checkbox(label="Show S₂ (surprise-surprise)"),
60
- ],
61
  outputs=gr.HTML(),
62
  title="Token Information Content Visualization",
63
  description="""
@@ -65,27 +115,12 @@ demo = gr.Interface(
65
 
66
  - **Surprisal**: Actual information content (-log probability) of each token
67
  - **Entropy**: Expected information content (uncertainty) at each position
68
- - **S₂** (surprise-surprise): How much more/less surprising a token is than expected
 
 
69
  """,
70
- examples=[
71
- ["The quick brown fox jumps over the lazy dog.", 30, True, True, False],
72
- [
73
- "In a shocking turn of events, the seemingly impossible task",
74
- 30,
75
- False,
76
- False,
77
- True,
78
- ],
79
- [
80
- "In a shocking turn of table, the seemingly impossible task",
81
- 30,
82
- False,
83
- False,
84
- True,
85
- ],
86
- ["A long time ago, in a galaxy far, far away...", 50, True, False, True],
87
- ],
88
  )
89
 
90
  if __name__ == "__main__":
91
- demo.launch(show_error=True)
 
1
  import sys
2
  from pathlib import Path
3
+ from functools import lru_cache, wraps
4
+
5
  sys.path.append(str(Path(__file__).parent / "src"))
6
 
7
+ from pydantic import validate_call, Field, ValidationError
8
  import gradio as gr
9
+
10
  from sparky.inference import load_model, calc_token_metrics, visualize_batch
11
+ from sparky.inference.visualize import MetricType
12
+
13
 
 
14
  model, tokenizer = load_model("gpt2")
15
 
16
 
17
def catch_all(fn):
    """Decorator: convert exceptions raised by `fn` into user-facing HTML strings.

    Pydantic validation errors are rendered one per line (joined with <br>);
    any other exception is logged to stderr and replaced with a generic
    message so the UI never shows a traceback.
    """

    @wraps(fn)
    def wrapper(*args, **kwargs) -> str:
        try:
            return fn(*args, **kwargs)
        except ValidationError as e:
            # err["loc"] is a tuple of str | int (ints appear for sequence
            # indices), so each element must be coerced to str before joining —
            # str.join raises TypeError on non-string items.
            return "<br>".join(
                f"{', '.join(str(loc) for loc in err['loc'])}: {err['msg']}"
                for err in e.errors(
                    include_url=False,
                    include_context=False,
                    include_input=True,
                )
            )
        except Exception as e:
            # Log for operators; never leak internals to the browser.
            print(f"Error processing text: {str(e)}", file=sys.stderr)
            return "Sorry, there was an error processing your text. Please try a different input."

    return wrapper
36
+
37
+
38
@catch_all
@validate_call(validate_return=True)
def _analyze_text(
    text: str = Field(..., min_length=1, max_length=2000),
    line_width: int = Field(..., ge=20, le=200),
    metrics_to_show: tuple[MetricType, ...] = Field(..., min_length=1),
) -> str:
    """Validated core: compute token metrics for one text and render it as SVG.

    Pydantic enforces the bounds declared in the signature; `catch_all`
    converts any ValidationError into user-facing HTML.
    """
    batch_metrics = calc_token_metrics([text], model, tokenizer)
    rendered = visualize_batch(
        batch_metrics, metrics_to_show=metrics_to_show, line_width=line_width
    )
    # The batch holds exactly one text, so the first SVG is the whole result.
    return rendered[0]
50
+
51
+
52
@lru_cache(128)
def analyze_text(
    text: str,
    line_width: int,
    surprisal: bool,
    entropy: bool,
    s2: bool,
) -> str:
    """Gradio handler: translate checkbox flags into metric names and delegate.

    Results are memoized (up to 128 entries) so identical repeated requests
    skip the model forward pass.
    """
    flag_by_metric = (
        ("surprisal", surprisal),
        ("entropy", entropy),
        ("s2", s2),
    )
    metrics_to_show = tuple(name for name, enabled in flag_by_metric if enabled)
    return _analyze_text(
        text=text, line_width=line_width, metrics_to_show=metrics_to_show
    )
 
71
 
72
 
73
# --- UI components -------------------------------------------------------
text_input = gr.Textbox(
    label="text",
    placeholder="Enter some text to analyze...",
    lines=3,
    value="The quick brown fox jumps over the lazy dog.",
)

width_slider = gr.Slider(
    label="line_width",
    minimum=20,
    maximum=120,
    step=10,
    value=30,
)

# One checkbox per metric; order must match analyze_text's flag parameters.
metric_toggles = [
    gr.Checkbox(label=label, value=checked)
    for label, checked in (("Surprisal", True), ("Entropy", True), ("S₂", False))
]

inputs = [text_input, width_slider, *metric_toggles]

# Example inputs showing different linguistic patterns
examples = [
    ["The quick brown fox jumps over the lazy dog."],
    ["In a shocking turn of events, the seemingly impossible task"],
    ["In a shocking turn of table, the seemingly impossible task"],
    ["A long time ago, in a galaxy far, far away..."],
]
# Pad each example row with None so its length matches the number of inputs.
empty_sample = [None] * len(inputs)
examples = [sample + empty_sample[len(sample) :] for sample in examples]
106
+
107
+ # Create interface
108
  demo = gr.Interface(
109
  fn=analyze_text,
110
+ inputs=inputs,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  outputs=gr.HTML(),
112
  title="Token Information Content Visualization",
113
  description="""
 
115
 
116
  - **Surprisal**: Actual information content (-log probability) of each token
117
  - **Entropy**: Expected information content (uncertainty) at each position
118
+ - **S₂** (surprise-surprise): How much more/less surprising a token is than expected (surprisal - entropy)
119
+
120
+ Read the paper for more details about this visualization and S₂: [Detecting out of distribution text with surprisal and entropy](https://www.lesswrong.com/posts/Kjo64rSWkFfc3sre5/detecting-out-of-distribution-text-with-surprisal-and#)
121
  """,
122
+ examples=examples,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
  )
124
 
125
  if __name__ == "__main__":
126
+ demo.launch()
requirements.txt CHANGED
@@ -1,3 +1,4 @@
1
  numpy~=2.2.2
2
  torch~=2.5.1
3
  transformers~=4.48.1
 
 
1
  numpy~=2.2.2
2
  torch~=2.5.1
3
  transformers~=4.48.1
4
+ pydantic~=2.10.0
src/sparky/inference/calc_metrics.py CHANGED
@@ -16,7 +16,6 @@ def calc_token_metrics(
16
  texts: List[str],
17
  model: PreTrainedModel,
18
  tokenizer: PreTrainedTokenizer,
19
- truncation=False,
20
  ) -> TokenMetrics:
21
  """Calculate per-token metrics for a batch of text sequences using a language model."""
22
  if tokenizer.pad_token is None:
@@ -30,7 +29,7 @@ def calc_token_metrics(
30
  texts,
31
  return_tensors="pt",
32
  padding=True,
33
- truncation=truncation,
34
  return_length=True,
35
  ).to(device)
36
 
 
16
  texts: List[str],
17
  model: PreTrainedModel,
18
  tokenizer: PreTrainedTokenizer,
 
19
  ) -> TokenMetrics:
20
  """Calculate per-token metrics for a batch of text sequences using a language model."""
21
  if tokenizer.pad_token is None:
 
29
  texts,
30
  return_tensors="pt",
31
  padding=True,
32
+ truncation=True,
33
  return_length=True,
34
  ).to(device)
35
 
src/sparky/inference/model.py CHANGED
@@ -1,9 +1,26 @@
 
 
1
  import torch
 
2
  from transformers import GPT2LMHeadModel, GPT2Tokenizer
3
 
4
 
5
  def load_model(name="gpt2"):
 
6
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
 
7
  model = GPT2LMHeadModel.from_pretrained(name).to(device)
 
 
 
 
8
  tokenizer = GPT2Tokenizer.from_pretrained(name, clean_up_tokenization_spaces=True)
 
 
 
 
 
 
9
  return model, tokenizer
 
1
+ import time
2
+
3
  import torch
4
+ import transformers
5
  from transformers import GPT2LMHeadModel, GPT2Tokenizer
6
 
7
 
8
def load_model(name="gpt2"):
    """Load a GPT-2 model and its tokenizer onto the best available device.

    Prints per-stage timings to stdout as startup diagnostics and returns
    the (model, tokenizer) pair.
    """
    print(f"Loading model {name}...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print("Loading model from pretrained...")
    stage_start = time.perf_counter()
    model = GPT2LMHeadModel.from_pretrained(name).to(device)
    model_done = time.perf_counter()
    print(f"Model loaded in {model_done - stage_start:.1f}s")

    print("Loading tokenizer...")
    tokenizer = GPT2Tokenizer.from_pretrained(name, clean_up_tokenization_spaces=True)
    tokenizer_done = time.perf_counter()
    print(f"Tokenizer loaded in {tokenizer_done - model_done:.1f}s")

    # Truncate from start to preserve adversarial suffixes.
    tokenizer.truncation_side = "left"

    return model, tokenizer