rajendrr commited on
Commit
b0f93cb
·
verified ·
1 Parent(s): 762b017

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +95 -27
app.py CHANGED
@@ -21,17 +21,19 @@ model.to(device)
21
 
22
  # ---- Text cleaning helpers ----
23
  def remove_non_ascii_and_lowercase(text: str) -> str:
24
- # Keep ASCII only, then lowercase
25
  text_ascii = re.sub(r"[^\x00-\x7F]+", "", text or "")
26
  return text_ascii.lower()
27
 
28
  # ---- Embedding helpers ----
29
  def get_embeddings(clean_text: str):
30
  """
 
 
31
  Returns:
32
- tokens_with_special: list[str] tokens including [CLS]/[SEP]
33
- embeddings: np.ndarray shape (seq_len, hidden_size)
34
- sent_embedding: np.ndarray shape (hidden_size,)
35
  """
36
  if not clean_text.strip():
37
  return [], np.zeros((0, 768), dtype=np.float32), np.zeros((768,), dtype=np.float32)
@@ -47,19 +49,16 @@ def get_embeddings(clean_text: str):
47
  enc = {k: v.to(device) for k, v in enc.items()}
48
 
49
  with torch.no_grad():
50
- outputs = model(**enc) # BaseModelOutputWithPoolingAndCrossAttentions
51
  last_hidden = outputs.last_hidden_state # (1, seq_len, hidden)
52
 
53
- # Convert to CPU numpy
54
- last_hidden_np = last_hidden.squeeze(0).detach().cpu().numpy() # (seq_len, hidden)
55
  tokens_with_special = tokenizer.convert_ids_to_tokens(enc["input_ids"][0])
56
 
57
- # Sentence embedding
58
  if POOLING == "cls":
59
  sent_embedding = last_hidden_np[0] # [CLS]
60
  else:
61
- # mean pooling over tokens with attention mask
62
- mask = enc["attention_mask"].squeeze(0).detach().cpu().numpy().astype(bool) # (seq_len,)
63
  if mask.any():
64
  sent_embedding = last_hidden_np[mask].mean(axis=0)
65
  else:
@@ -68,10 +67,7 @@ def get_embeddings(clean_text: str):
68
  return tokens_with_special, last_hidden_np, sent_embedding
69
 
70
  def build_token_df(tokens, embeddings, dims_to_show=DEFAULT_DIMS_TO_SHOW) -> pd.DataFrame:
71
- """
72
- tokens: list[str], embeddings: (seq_len, hidden)
73
- returns a DataFrame with columns: token, dim_0..dim_{dims_to_show-1}
74
- """
75
  if len(tokens) == 0:
76
  return pd.DataFrame(columns=["token"] + [f"dim_{i}" for i in range(dims_to_show)])
77
 
@@ -84,19 +80,14 @@ def build_token_df(tokens, embeddings, dims_to_show=DEFAULT_DIMS_TO_SHOW) -> pd.
84
  return pd.DataFrame(data, columns=cols)
85
 
86
  def save_full_token_csv(tokens, embeddings) -> str:
87
- """
88
- Save full 768-dim token embeddings to a CSV and return file path.
89
- Columns: token, dim_0..dim_767
90
- """
91
  if len(tokens) == 0:
92
  fd, empty_path = tempfile.mkstemp(suffix=".csv")
93
  os.close(fd)
94
  return empty_path
95
 
96
  cols = ["token"] + [f"dim_{i}" for i in range(embeddings.shape[1])]
97
- rows = []
98
- for tok, vec in zip(tokens, embeddings):
99
- rows.append([tok] + list(vec))
100
  df = pd.DataFrame(rows, columns=cols)
101
 
102
  tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".csv")
@@ -104,10 +95,7 @@ def save_full_token_csv(tokens, embeddings) -> str:
104
  return tmp.name
105
 
106
  def save_sentence_csv(sent_embedding) -> str:
107
- """
108
- Save 768-dim sentence embedding to CSV and return file path.
109
- Columns: dim_0..dim_767
110
- """
111
  cols = [f"dim_{i}" for i in range(sent_embedding.shape[0])]
112
  df = pd.DataFrame([sent_embedding], columns=cols)
113
 
@@ -118,6 +106,86 @@ def save_sentence_csv(sent_embedding) -> str:
118
  # ---- Gradio pipeline ----
119
  def run_pipeline(raw_text: str, dims_to_show: int):
120
  """
 
 
121
  Returns:
122
- cleaned_text (str)
123
- sh
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  # ---- Text cleaning helpers ----
23
  def remove_non_ascii_and_lowercase(text: str) -> str:
24
+ """Remove non-ASCII characters and lowercase the text."""
25
  text_ascii = re.sub(r"[^\x00-\x7F]+", "", text or "")
26
  return text_ascii.lower()
27
 
28
  # ---- Embedding helpers ----
29
  def get_embeddings(clean_text: str):
30
  """
31
+ Generate token and sentence embeddings using BERT.
32
+
33
  Returns:
34
+ tokens_with_special (list[str]): tokens including [CLS]/[SEP]
35
+ embeddings (np.ndarray): shape (seq_len, hidden_size)
36
+ sent_embedding (np.ndarray): shape (hidden_size,)
37
  """
38
  if not clean_text.strip():
39
  return [], np.zeros((0, 768), dtype=np.float32), np.zeros((768,), dtype=np.float32)
 
49
  enc = {k: v.to(device) for k, v in enc.items()}
50
 
51
  with torch.no_grad():
52
+ outputs = model(**enc)
53
  last_hidden = outputs.last_hidden_state # (1, seq_len, hidden)
54
 
55
+ last_hidden_np = last_hidden.squeeze(0).detach().cpu().numpy()
 
56
  tokens_with_special = tokenizer.convert_ids_to_tokens(enc["input_ids"][0])
57
 
 
58
  if POOLING == "cls":
59
  sent_embedding = last_hidden_np[0] # [CLS]
60
  else:
61
+ mask = enc["attention_mask"].squeeze(0).detach().cpu().numpy().astype(bool)
 
62
  if mask.any():
63
  sent_embedding = last_hidden_np[mask].mean(axis=0)
64
  else:
 
67
  return tokens_with_special, last_hidden_np, sent_embedding
68
 
69
  def build_token_df(tokens, embeddings, dims_to_show=DEFAULT_DIMS_TO_SHOW) -> pd.DataFrame:
70
+ """Create a DataFrame of tokens with the first N embedding dimensions."""
 
 
 
71
  if len(tokens) == 0:
72
  return pd.DataFrame(columns=["token"] + [f"dim_{i}" for i in range(dims_to_show)])
73
 
 
80
  return pd.DataFrame(data, columns=cols)
81
 
82
  def save_full_token_csv(tokens, embeddings) -> str:
83
+ """Save full 768-dim token embeddings to a CSV and return file path."""
 
 
 
84
  if len(tokens) == 0:
85
  fd, empty_path = tempfile.mkstemp(suffix=".csv")
86
  os.close(fd)
87
  return empty_path
88
 
89
  cols = ["token"] + [f"dim_{i}" for i in range(embeddings.shape[1])]
90
+ rows = [[tok] + list(vec) for tok, vec in zip(tokens, embeddings)]
 
 
91
  df = pd.DataFrame(rows, columns=cols)
92
 
93
  tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".csv")
 
95
  return tmp.name
96
 
97
  def save_sentence_csv(sent_embedding) -> str:
98
+ """Save 768-dim sentence embedding to CSV and return file path."""
 
 
 
99
  cols = [f"dim_{i}" for i in range(sent_embedding.shape[0])]
100
  df = pd.DataFrame([sent_embedding], columns=cols)
101
 
 
106
  # ---- Gradio pipeline ----
107
  def run_pipeline(raw_text: str, dims_to_show: int):
108
  """
109
+ Process input text, generate embeddings, and prepare preview/CSV outputs.
110
+
111
  Returns:
112
+ cleaned_text (str)
113
+ shape_info (str)
114
+ token_df (DataFrame with first N dims)
115
+ token_csv_path (File)
116
+ sent_df (DataFrame with first N dims)
117
+ sent_csv_path (File)
118
+ """
119
+ cleaned_text = remove_non_ascii_and_lowercase(raw_text or "")
120
+ tokens, token_embeds, sent_embed = get_embeddings(cleaned_text)
121
+
122
+ seq_len = token_embeds.shape[0]
123
+ hidden = token_embeds.shape[1] if seq_len > 0 else 768
124
+ shape_info = (
125
+ f"Tokens (including [CLS]/[SEP]): {seq_len}\n"
126
+ f"Embedding size: {hidden}\n"
127
+ f"Sentence embedding size: {sent_embed.shape[0]}"
128
+ )
129
+
130
+ token_df = build_token_df(tokens, token_embeds, dims_to_show=dims_to_show)
131
+
132
+ dims_to_show = max(1, min(dims_to_show, sent_embed.shape[0]))
133
+ sent_df = pd.DataFrame([list(sent_embed[:dims_to_show])],
134
+ columns=[f"dim_{i}" for i in range(dims_to_show)])
135
+
136
+ token_csv_path = save_full_token_csv(tokens, token_embeds)
137
+ sent_csv_path = save_sentence_csv(sent_embed)
138
+
139
+ return cleaned_text, shape_info, token_df, token_csv_path, sent_df, sent_csv_path
140
+
141
+ # ---- Gradio Interface ----
142
+ with gr.Blocks(title="BERT Token & Embedding Explorer") as demo:
143
+ gr.Markdown(
144
+ """
145
+ # 🧠 BERT Token & Embedding Explorer
146
+ - Cleans your text (removes **non-ASCII** chars, lowercases)
147
+ - Tokenizes with **bert-base-uncased**
148
+ - Shows per-token embeddings (first *N* dims)
149
+ - Exports **full 768-dim** token and sentence embeddings as CSV
150
+ """
151
+ )
152
+
153
+ with gr.Row():
154
+ inp = gr.Textbox(
155
+ label="Enter text",
156
+ placeholder="Type or paste text here…",
157
+ lines=5,
158
+ value="Don't you love 🤗 Transformers? BERT embeddings are neat!"
159
+ )
160
+ with gr.Row():
161
+ dims = gr.Slider(4, 64, value=DEFAULT_DIMS_TO_SHOW, step=1, label="Dimensions to display (preview)")
162
+
163
+ run_btn = gr.Button("Embed with BERT", variant="primary")
164
+
165
+ with gr.Row():
166
+ cleaned_out = gr.Textbox(label="Cleaned text (ASCII-only, lowercased)", interactive=False)
167
+
168
+ shape_info = gr.Textbox(label="Shapes & Info", interactive=False)
169
+
170
+ gr.Markdown("### Token embeddings (preview)")
171
+ token_df = gr.Dataframe(
172
+ label="Tokens with first N embedding dimensions",
173
+ interactive=False,
174
+ )
175
+ token_csv = gr.File(label="Download FULL token embeddings (CSV)")
176
+
177
+ gr.Markdown("### Sentence embedding (preview)")
178
+ sent_df = gr.Dataframe(
179
+ label="First N dimensions of the pooled sentence embedding",
180
+ interactive=False,
181
+ )
182
+ sent_csv = gr.File(label="Download FULL sentence embedding (CSV)")
183
+
184
+ run_btn.click(
185
+ fn=run_pipeline,
186
+ inputs=[inp, dims],
187
+ outputs=[cleaned_out, shape_info, token_df, token_csv, sent_df, sent_csv]
188
+ )
189
+
190
+ if __name__ == "__main__":
191
+ demo.launch()