ajayinsac committed on
Commit
3cfa391
·
verified ·
1 Parent(s): 9711432

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +107 -0
app.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Gradio app: Clean text (remove non-ASCII, lowercase), tokenize with BERT,
4
+ compute embeddings, and display tokens + per-token vectors.
5
+
6
+ Run locally:
7
+ pip install -r requirements.txt
8
+ python app.py
9
+ """
10
+
11
+ import re
12
+ import numpy as np
13
+ import pandas as pd
14
+ import torch
15
+ import gradio as gr
16
+ from transformers import BertTokenizer, BertModel
17
+
18
# ---- Preprocessing helpers ---------------------------------------------------
# Precompiled pattern matching one-or-more characters outside the 7-bit ASCII
# range (code points above 0x7F); compiled once at module load for reuse.
_ascii_re = re.compile(r"[^\x00-\x7F]+")
20
+
21
def clean_text(s: str) -> str:
    """Normalize raw input: strip non-ASCII, lowercase, collapse whitespace.

    Returns "" for None so callers never need to guard against missing input.
    """
    if s is None:
        return ""
    # encode/decode with errors="ignore" drops every code point above 0x7F
    # in a single C-level pass (same effect as substituting the non-ASCII
    # pattern with "").
    ascii_only = s.encode("ascii", "ignore").decode("ascii")
    # split()/join collapses any run of whitespace to a single space and
    # trims the ends, matching the sub(r"\s+", " ", ...).strip() idiom.
    return " ".join(ascii_only.lower().split())
29
+
30
# ---- Load model/tokenizer once ----------------------------------------------
# Loaded at import time so every Gradio request reuses the same weights.
# NOTE: from_pretrained downloads and caches the checkpoint on first run.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
TOKENIZER = BertTokenizer.from_pretrained("bert-base-uncased")
MODEL = BertModel.from_pretrained("bert-base-uncased")
MODEL.to(DEVICE)
MODEL.eval()  # inference only: disables dropout so outputs are deterministic
36
+
37
# ---- Core function -----------------------------------------------------------
def bert_embed(text: str, max_tokens: int = 48):
    """Clean *text*, run it through BERT, and return UI-ready artifacts.

    Parameters
    ----------
    text : str
        Raw user input; normalized via clean_text() before tokenization.
    max_tokens : int, default 48
        Truncation length, counting the [CLS]/[SEP] special tokens.

    Returns
    -------
    tuple[str, list[str], pd.DataFrame]
        (cleaned text, wordpiece tokens, DataFrame with one row per token
        and 768 embedding columns). All-empty values when the cleaned
        text is empty.
    """
    cleaned = clean_text(text)
    if not cleaned:
        return "", [], pd.DataFrame()

    # FIX: gr.Slider delivers a float, but the tokenizer's max_length must
    # be an int. Coerce, and floor at 2 so [CLS] and [SEP] always fit.
    max_len = max(int(max_tokens), 2)

    # Tokenize (truncate to keep UI snappy)
    enc = TOKENIZER(
        cleaned,
        return_tensors="pt",
        truncation=True,
        max_length=max_len,
        add_special_tokens=True,
    )
    input_ids = enc["input_ids"].to(DEVICE)
    attention_mask = enc["attention_mask"].to(DEVICE)

    # Inference only: skipping the autograd graph keeps memory flat.
    with torch.no_grad():
        outputs = MODEL(input_ids=input_ids, attention_mask=attention_mask)
        # last_hidden_state: [batch=1, seq_len, hidden=768] -> [seq_len, 768]
        last_hidden_state = outputs.last_hidden_state.squeeze(0).cpu().numpy()

    tokens = TOKENIZER.convert_ids_to_tokens(input_ids.squeeze(0).tolist())

    # Build a DataFrame: rows = tokens (duplicates allowed in the index),
    # columns = dim_0..dim_767.
    cols = [f"dim_{i}" for i in range(last_hidden_state.shape[1])]
    df = pd.DataFrame(last_hidden_state, index=tokens, columns=cols)

    return cleaned, tokens, df
72
+
73
+
74
# ---- Gradio UI ---------------------------------------------------------------
with gr.Blocks(title="BERT Tokenizer & Embeddings") as demo:
    gr.Markdown(
        """
        # BERT Tokenizer & Embeddings
        Paste text below. The app will **remove non-ASCII characters**, **lowercase** the text, then use
        **BERT (bert-base-uncased)** to produce tokens and embeddings (last hidden state).
        """
    )

    # Input controls: free text plus a truncation limit for responsiveness.
    with gr.Row():
        text_input = gr.Textbox(label="Input text", lines=6, placeholder="Type or paste text...")
        token_limit = gr.Slider(8, 256, value=48, step=1, label="Max tokens (truncate)")

    # Output panes mirror the three return values of bert_embed.
    with gr.Row():
        cleaned_box = gr.Textbox(label="Cleaned text (ASCII-only, lowercased)")
        token_json = gr.JSON(label="WordPiece tokens")
        embed_table = gr.Dataframe(label="Per-token embeddings (last_hidden_state)", wrap=True)

    transform_btn = gr.Button("Transform with BERT", variant="primary")
    transform_btn.click(
        bert_embed,
        inputs=[text_input, token_limit],
        outputs=[cleaned_box, token_json, embed_table],
    )

    gr.Markdown(
        """
        **Notes**
        - Embeddings are 768-dim vectors from the last hidden state (one row per token).
        - Special tokens like `[CLS]` and `[SEP]` are included.
        - Truncation keeps the UI responsive; increase *Max tokens* if needed.
        """
    )
104
+
105
if __name__ == "__main__":
    # Do not force share=True (some hosts disallow it)
    # Bind to all interfaces on port 7860 (the Hugging Face Spaces default).
    demo.launch(server_name="0.0.0.0", server_port=7860)