|
|
|
|
|
""" |
|
|
Gradio app: clean text (remove non-ASCII characters, lowercase), tokenize with
BERT, compute embeddings, and display the tokens plus per-token vectors.

Run locally:
    pip install -r requirements.txt
    python app.py
|
|
""" |
|
|
|
|
|
import re |
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
import torch |
|
|
import gradio as gr |
|
|
from transformers import BertTokenizer, BertModel |
|
|
|
|
|
|
|
|
# Matches runs of characters outside the 7-bit ASCII range.
_ascii_re = re.compile(r"[^\x00-\x7F]+")


def clean_text(s: str) -> str:
    """Normalize text to lowercase ASCII separated by single spaces.

    Whitespace is normalized *before* non-ASCII characters are dropped, so
    Unicode-only whitespace (for example a no-break space pasted from a web
    page) still separates the words around it instead of being deleted and
    gluing them together.

    Args:
        s: Raw input text. ``None`` is treated as empty input.

    Returns:
        The cleaned string, or ``""`` when nothing is left after cleaning.
    """
    if s is None:
        return ""

    # str.split() with no argument splits on any Unicode whitespace, so this
    # turns every whitespace run into one ASCII space and trims both ends.
    s = " ".join(s.split())

    # Drop non-ASCII first, then lowercase: the result of lower() on pure
    # ASCII text is guaranteed to stay ASCII.
    s = _ascii_re.sub("", s).lower()

    # Removing a stand-alone non-ASCII word ("a é b") leaves two adjacent
    # spaces behind, so collapse once more.
    return " ".join(s.split())
|
|
|
|
|
|
|
|
# Checkpoint name shared by the tokenizer and the encoder.
_CHECKPOINT = "bert-base-uncased"

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Both objects are created once at import time and reused by every request.
TOKENIZER = BertTokenizer.from_pretrained(_CHECKPOINT)

# Load the weights, move them to DEVICE, and switch the module to eval mode.
MODEL = BertModel.from_pretrained(_CHECKPOINT).to(DEVICE).eval()
|
|
|
|
|
|
|
|
def bert_embed(text: str, max_tokens: int = 48):
    """Clean *text*, tokenize it with BERT, and return per-token embeddings.

    Args:
        text: Raw user text. It is passed through ``clean_text`` first.
        max_tokens: Upper bound on the sequence length, special tokens
            included; longer inputs are truncated. The value is converted
            with ``int`` (UI sliders may deliver a float) and capped at the
            model's ``max_position_embeddings``, because the encoder has no
            position embeddings beyond that length.

    Returns:
        A 3-tuple ``(cleaned, tokens, df)``:
          - cleaned text
          - list of wordpiece tokens
          - DataFrame of embeddings: one row per token (the token is the
            row label), one ``dim_<i>`` column per hidden dimension
        When the cleaned text is empty the result is
        ``("", [], <empty DataFrame>)`` and the model is not called.
    """
    cleaned = clean_text(text)
    if not cleaned:
        return "", [], pd.DataFrame()

    # Never ask the tokenizer for more tokens than the encoder can position.
    max_len = min(int(max_tokens), MODEL.config.max_position_embeddings)

    enc = TOKENIZER(
        cleaned,
        return_tensors="pt",
        truncation=True,
        max_length=max_len,
        add_special_tokens=True,
    )
    input_ids = enc["input_ids"].to(DEVICE)
    attention_mask = enc["attention_mask"].to(DEVICE)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = MODEL(input_ids=input_ids, attention_mask=attention_mask)

    # A single string was encoded, so the batch axis has length 1; take its
    # only element to get a (tokens, hidden) array and a flat id list.
    last_hidden_state = outputs.last_hidden_state[0].cpu().numpy()
    tokens = TOKENIZER.convert_ids_to_tokens(input_ids[0].tolist())

    cols = [f"dim_{i}" for i in range(last_hidden_state.shape[1])]
    df = pd.DataFrame(last_hidden_state, index=tokens, columns=cols)

    return cleaned, tokens, df
|
|
|
|
|
|
|
|
|
|
|
def _embed_for_display(text, max_tokens):
    """Run ``bert_embed`` and expose the token labels as a regular column.

    ``bert_embed`` keeps the tokens in the DataFrame index, but the Gradio
    ``Dataframe`` component builds its table from a frame's columns and
    values only, so without this step the embedding rows are shown without
    any indication of which token they belong to.

    Returns the same ``(cleaned, tokens, df)`` tuple as ``bert_embed``; for a
    non-empty frame, ``df`` gains a leading ``token`` column.
    """
    cleaned, tokens, df = bert_embed(text, max_tokens)
    if not df.empty:
        df = df.rename_axis("token").reset_index()
    return cleaned, tokens, df


# Page layout. Components appear on the page in the order they are created.
with gr.Blocks(title="BERT Tokenizer & Embeddings") as demo:
    gr.Markdown(
        """
        # BERT Tokenizer & Embeddings
        Paste text below. The app will **remove non-ASCII characters**, **lowercase** the text, then use
        **BERT (bert-base-uncased)** to produce tokens and embeddings (last hidden state).
        """
    )

    # Inputs: the raw text and the truncation length.
    with gr.Row():
        inp = gr.Textbox(label="Input text", lines=6, placeholder="Type or paste text...")
        max_tok = gr.Slider(8, 256, value=48, step=1, label="Max tokens (truncate)")

    # Outputs.
    # NOTE(review): the original indentation was lost; all three outputs are
    # assumed to sit in this row because no blank line separated them.
    # Confirm against the intended layout.
    with gr.Row():
        cleaned_out = gr.Textbox(label="Cleaned text (ASCII-only, lowercased)")
        tokens_out = gr.JSON(label="WordPiece tokens")
        df_out = gr.Dataframe(label="Per-token embeddings (last_hidden_state)", wrap=True)

    run_btn = gr.Button("Transform with BERT", variant="primary")
    run_btn.click(
        _embed_for_display,
        inputs=[inp, max_tok],
        outputs=[cleaned_out, tokens_out, df_out],
    )

    gr.Markdown(
        """
        **Notes**
        - Embeddings are 768-dim vectors from the last hidden state (one row per token).
        - Special tokens like `[CLS]` and `[SEP]` are included.
        - Truncation keeps the UI responsive; increase *Max tokens* if needed.
        """
    )
|
|
|
|
|
if __name__ == "__main__":
    # Serve on every network interface at port 7860 when run as a script.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    demo.launch(**launch_options)
|
|
|