import sys
from pathlib import Path
from functools import lru_cache, wraps


# Make the bundled package under ./src importable when this script is run
# directly (the `sparky` imports below rely on this path entry).
sys.path.append(str(Path(__file__).parent / "src"))


from pydantic import validate_call, Field, ValidationError
import gradio as gr


from sparky.inference import load_model, calc_token_metrics, visualize_batch
from sparky.inference.visualize import MetricType
|
|
|
|
# Load the GPT-2 model and tokenizer once at module import; both are shared
# by every request handled by this process.
model, tokenizer = load_model("gpt2")
|
|
|
|
| |
# Example prompts for the UI. Each row holds only the text column here; rows
# are padded out to the full input-component count further down in this file.
examples = [
    ["A long time ago, in a galaxy somewhat far away..."],
    ["In a shocking turn of table, the seemingly impossible task"],
    ["The quick brown fox jumps over the lazy dog."],
]
|
|
|
|
def catch_all(fn):
    """Decorator that converts exceptions raised by *fn* into user-facing strings.

    Pydantic validation errors are rendered as a ``<br>``-joined list of
    ``location: message`` lines; any other exception is logged to stderr and
    replaced with a generic apology so the UI never shows a raw traceback.
    """

    @wraps(fn)
    def wrapper(*args, **kwargs) -> str:
        try:
            return fn(*args, **kwargs)
        except ValidationError as e:
            # err["loc"] is a tuple of str | int (field names and sequence
            # indices — e.g. an index into metrics_to_show), so each part
            # must be stringified before joining to avoid a TypeError.
            return "<br>".join(
                f"{', '.join(str(part) for part in err['loc'])}: {err['msg']}"
                for err in e.errors(
                    include_url=False,
                    include_context=False,
                    include_input=True,
                )
            )
        except Exception as e:
            print(f"Error processing text: {str(e)}", file=sys.stderr)
            return "Sorry, there was an error processing your text. Please try a different input."

    return wrapper
|
|
|
|
@catch_all
@validate_call(validate_return=True)
def _analyze_text(
    text: str = Field(..., min_length=1, max_length=2000),
    line_width: int = Field(..., ge=20, le=200),
    metrics_to_show: tuple[MetricType, ...] = Field(..., min_length=1),
) -> str:
    """Validate the inputs, run the model over a single text, and return the
    rendered markup for it (first — and only — element of the batch)."""
    batch = calc_token_metrics([text], model, tokenizer)
    rendered = visualize_batch(
        batch, metrics_to_show=metrics_to_show, line_width=line_width
    )
    return rendered[0]
|
|
|
|
@lru_cache(128)
def analyze_text(
    text: str,
    line_width: int,
    surprisal: bool,
    entropy: bool,
    s2: bool,
) -> str:
    """Cached UI entry point: translate the metric checkboxes into metric
    names and delegate to the validated analysis function."""
    flags = {"surprisal": surprisal, "entropy": entropy, "s2": s2}
    metrics_to_show = tuple(name for name, enabled in flags.items() if enabled)

    return _analyze_text(
        text=text, line_width=line_width, metrics_to_show=metrics_to_show
    )
|
|
|
|
| |
# --- Gradio input components ---------------------------------------------

# Free-text prompt to analyze; pre-filled with the first example's text.
text_input = gr.Textbox(
    label="text",
    placeholder="Enter some text to analyze...",
    lines=3,
    value=examples[0][0],
)


# Line width for the rendered output; forwarded to visualize_batch via
# analyze_text as `line_width`.
width_slider = gr.Slider(
    label="line_width",
    minimum=20,
    maximum=120,
    step=5,
    value=30,
)


# One checkbox per metric; the order must match analyze_text's
# (surprisal, entropy, s2) boolean parameters.
metric_toggles = [
    gr.Checkbox(label="Surprisal", value=False),
    gr.Checkbox(label="Entropy", value=False),
    gr.Checkbox(label="S₂", value=True),
]


inputs = [text_input, width_slider, *metric_toggles]


# Pad each example row with None so every row has exactly one value per
# input component.
empty_sample = [None] * len(inputs)
examples = [sample + empty_sample[len(sample) :] for sample in examples]
|
|
| |
# Assemble the app: analyze_text maps the inputs above to an HTML output.
demo = gr.Interface(
    fn=analyze_text,
    inputs=inputs,
    outputs=gr.HTML(),  # analyze_text returns markup, rendered as HTML
    title="Sparky: Token Information Content Visualization",
    description="""
Visualize how predictable each token is according to GPT-2. The metrics shown are:

- **Surprisal**: Actual information content (-log probability) of each token
- **Entropy**: Expected information content (uncertainty) at each position
- **S₂** (surprise-surprise): How much more/less surprising a token is than expected (surprisal - entropy)

Read the paper for more details about this visualization and S₂: [Detecting out of distribution text with surprisal and entropy](https://www.lesswrong.com/posts/Kjo64rSWkFfc3sre5/detecting-out-of-distribution-text-with-surprisal-and#)
""",
    examples=examples,
    cache_examples=False,  # examples require a model forward pass, so skip precomputing
)
|
|
# Start the web server only when executed as a script (not on import).
if __name__ == "__main__":
    demo.launch()
|
|