File size: 3,717 Bytes
3cfa391
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
#!/usr/bin/env python3
"""
Gradio app: Clean text (remove non-ASCII, lowercase), tokenize with BERT,
compute embeddings, and display tokens + per-token vectors.

Run locally:
  pip install -r requirements.txt
  python app.py
"""

import re
import numpy as np
import pandas as pd
import torch
import gradio as gr
from transformers import BertTokenizer, BertModel

# ---- Preprocessing helpers ---------------------------------------------------
_ascii_re = re.compile(r"[^\x00-\x7F]+")
_whitespace_re = re.compile(r"\s+")

def clean_text(s: str) -> str:
    """Return *s* reduced to lowercase ASCII with normalized whitespace.

    Steps, in order:

    1. Every run of whitespace -- including non-ASCII whitespace such as
       U+00A0 (no-break space) or U+2003 (em space) -- becomes one ASCII
       space. This runs *before* non-ASCII characters are dropped so that
       words separated only by a Unicode space are not glued together.
    2. All remaining non-ASCII characters are removed.
    3. The text is lowercased, whitespace runs are collapsed again (step 2
       can leave adjacent spaces, e.g. "a é b" -> "a  b"), and leading and
       trailing whitespace is stripped.

    Args:
        s: Input text; ``None`` is treated as empty.

    Returns:
        The cleaned string, or ``""`` when nothing survives cleaning.
    """
    if s is None:
        return ""
    s = _whitespace_re.sub(" ", s)   # unify (Unicode) whitespace first
    s = _ascii_re.sub("", s)         # then drop non-ASCII
    s = s.lower()
    return _whitespace_re.sub(" ", s).strip()

# ---- Load model/tokenizer once ----------------------------------------------
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Tokenizer and weights come from the same checkpoint so the vocabularies match.
TOKENIZER = BertTokenizer.from_pretrained("bert-base-uncased")
# nn.Module.to() returns the module itself, so device placement is chained
# directly onto loading; eval() then puts the model in inference mode.
MODEL = BertModel.from_pretrained("bert-base-uncased").to(DEVICE)
MODEL.eval()

# ---- Core function -----------------------------------------------------------
def bert_embed(text: str, max_tokens: int = 48):
    """
    Clean *text*, tokenize it with the BERT WordPiece tokenizer and run it
    through the model.

    Args:
        text: Raw user input; passed through ``clean_text`` first.
        max_tokens: Truncation length in tokens, *including* the ``[CLS]`` and
            ``[SEP]`` special tokens. It is coerced to ``int`` (UI sliders may
            deliver a float such as ``48.0``) and clamped to
            ``[2, MODEL.config.max_position_embeddings]``.

    Return:
      - cleaned text
      - list of wordpiece tokens
      - DataFrame of embeddings (one row per token, indexed by token, one
        ``dim_<i>`` column per hidden unit -- 768 for bert-base)

    When cleaning leaves nothing, returns ``("", [], pd.DataFrame())``.
    """
    cleaned = clean_text(text)
    if not cleaned:
        return "", [], pd.DataFrame()

    # The tokenizer expects an integer max_length. At least 2 positions are
    # needed for [CLS] + [SEP]. The model cannot encode sequences longer than
    # its position-embedding table.
    limit = max(2, min(int(max_tokens), MODEL.config.max_position_embeddings))

    # Tokenize (truncate to keep UI snappy)
    enc = TOKENIZER(
        cleaned,
        return_tensors="pt",
        truncation=True,
        max_length=limit,
        add_special_tokens=True,
    )
    # Read token strings from the CPU copy of the ids (no device round-trip).
    tokens = TOKENIZER.convert_ids_to_tokens(enc["input_ids"][0].tolist())
    input_ids = enc["input_ids"].to(DEVICE)
    attention_mask = enc["attention_mask"].to(DEVICE)

    with torch.no_grad():
        outputs = MODEL(input_ids=input_ids, attention_mask=attention_mask)
    # last_hidden_state shape: [batch=1, seq_len, hidden]; drop the batch dim.
    hidden = outputs.last_hidden_state.squeeze(0).cpu().numpy()

    # rows = tokens, columns = dim_0..dim_{hidden-1}
    # NOTE(review): token labels live only in the index; confirm the Gradio
    # Dataframe component actually renders the index, else rows are unlabeled.
    cols = [f"dim_{i}" for i in range(hidden.shape[1])]
    df = pd.DataFrame(hidden, index=tokens, columns=cols)

    return cleaned, tokens, df


# ---- Gradio UI ---------------------------------------------------------------
# Build the UI. Inside gr.Blocks, components are laid out in the order they
# are created, so the statement order below defines the page layout.
with gr.Blocks(title="BERT Tokenizer & Embeddings") as demo:
    # Header / usage blurb.
    gr.Markdown(
        """
        # BERT Tokenizer & Embeddings
        Paste text below. The app will **remove non-ASCII characters**, **lowercase** the text, then use
        **BERT (bert-base-uncased)** to produce tokens and embeddings (last hidden state).
        """
    )

    # Inputs: the raw text and the truncation length (bert_embed's max_tokens).
    with gr.Row():
        inp = gr.Textbox(label="Input text", lines=6, placeholder="Type or paste text...")
        max_tok = gr.Slider(8, 256, value=48, step=1, label="Max tokens (truncate)")

    # Outputs: one component per element of bert_embed's
    # (cleaned, tokens, df) return tuple.
    with gr.Row():
        cleaned_out = gr.Textbox(label="Cleaned text (ASCII-only, lowercased)")
    tokens_out = gr.JSON(label="WordPiece tokens")
    df_out = gr.Dataframe(label="Per-token embeddings (last_hidden_state)", wrap=True)

    # Clicking calls bert_embed(inp, max_tok) and routes its three return
    # values to the output components in the same order.
    run_btn = gr.Button("Transform with BERT", variant="primary")
    run_btn.click(bert_embed, inputs=[inp, max_tok], outputs=[cleaned_out, tokens_out, df_out])

    # Footer notes shown under the outputs.
    gr.Markdown(
        """
        **Notes**
        - Embeddings are 768-dim vectors from the last hidden state (one row per token).
        - Special tokens like `[CLS]` and `[SEP]` are included.
        - Truncation keeps the UI responsive; increase *Max tokens* if needed.
        """
    )

if __name__ == "__main__":
    # Listen on every interface at port 7860. `share` is deliberately left at
    # its default instead of being forced on, because some hosts disallow it.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
    )