import os
import re
import tempfile

import gradio as gr
import numpy as np
import pandas as pd
import torch
from transformers import BertModel, BertTokenizer

# ---- Configuration ----
MODEL_NAME = "bert-base-uncased"
DEFAULT_DIMS_TO_SHOW = 16  # how many embedding dims to show in the UI tables (full width in CSV)
POOLING = "mean"  # "mean" or "cls"; any other value falls back to mean pooling
MAX_LENGTH = 512  # maximum tokens (incl. [CLS]/[SEP]) passed to the model; longer input is truncated

# ---- Load tokenizer & model once (cached) ----
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
model = BertModel.from_pretrained(MODEL_NAME)
model.eval()  # inference mode
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Read the embedding width from the loaded model instead of hard-coding 768,
# so that changing MODEL_NAME keeps every shape in this file consistent.
HIDDEN_SIZE = model.config.hidden_size

_NON_ASCII_RE = re.compile(r"[^\x00-\x7F]+")


# ---- Text cleaning helpers ----
def remove_non_ascii_and_lowercase(text: str) -> str:
    """Remove non-ASCII characters and lowercase the text.

    ``None`` is treated as an empty string. Non-ASCII characters are dropped,
    not transliterated (e.g. "café" becomes "caf").
    """
    return _NON_ASCII_RE.sub("", text or "").lower()


# ---- Small shared helpers ----
def _dim_columns(n: int) -> list:
    """Return the column names ``dim_0 .. dim_{n-1}``."""
    return [f"dim_{i}" for i in range(n)]


def _clamp_dims(requested, available: int) -> int:
    """Coerce ``requested`` to int and clamp it to ``[1, available]``."""
    return max(1, min(int(requested), int(available)))


def _write_temp_csv(df: pd.DataFrame) -> str:
    """Write ``df`` (without index) to a new temporary ``.csv`` file and return its path.

    The descriptor from ``mkstemp`` is wrapped and closed by the ``with`` block,
    so no open handle outlives this call. The file itself is intentionally not
    deleted because its path is returned to the caller.
    NOTE(review): files accumulate in the temp dir across requests; consider cleanup.
    """
    fd, path = tempfile.mkstemp(suffix=".csv")
    # pandas recommends newline="" for text handles to avoid doubled line endings on Windows.
    with os.fdopen(fd, "w", newline="", encoding="utf-8") as fh:
        df.to_csv(fh, index=False)
    return path


# ---- Embedding helpers ----
def get_embeddings(clean_text: str):
    """
    Generate token and sentence embeddings using BERT.

    Input longer than MAX_LENGTH tokens (including [CLS]/[SEP]) is truncated.
    Empty, whitespace-only, or ``None`` input returns
    ``([], zeros((0, HIDDEN_SIZE)), zeros((HIDDEN_SIZE,)))`` without running the model.

    Returns:
        tokens_with_special (list[str]): tokens including [CLS]/[SEP]
        embeddings (np.ndarray): shape (seq_len, hidden_size)
        sent_embedding (np.ndarray): shape (hidden_size,)
    """
    if not clean_text or not clean_text.strip():
        return (
            [],
            np.zeros((0, HIDDEN_SIZE), dtype=np.float32),
            np.zeros((HIDDEN_SIZE,), dtype=np.float32),
        )

    enc = tokenizer(
        clean_text,
        return_tensors="pt",
        add_special_tokens=True,
        padding=False,
        truncation=True,
        max_length=MAX_LENGTH,
    )
    enc = {k: v.to(device) for k, v in enc.items()}

    with torch.no_grad():
        outputs = model(**enc)

    # (1, seq_len, hidden) -> (seq_len, hidden) on the CPU as a NumPy array.
    last_hidden_np = outputs.last_hidden_state.squeeze(0).cpu().numpy()
    # One device->host transfer of the id list instead of per-element tensor conversion.
    tokens_with_special = tokenizer.convert_ids_to_tokens(enc["input_ids"][0].tolist())

    if POOLING == "cls":
        sent_embedding = last_hidden_np[0]  # hidden state at the [CLS] position
    else:
        # Masked mean over all positions, [CLS]/[SEP] included. With padding=False
        # every mask entry is 1, but the mask keeps pooling correct if padding is enabled.
        mask = enc["attention_mask"].squeeze(0).cpu().numpy().astype(bool)
        if mask.any():
            sent_embedding = last_hidden_np[mask].mean(axis=0)
        else:
            sent_embedding = last_hidden_np.mean(axis=0)

    return tokens_with_special, last_hidden_np, sent_embedding


def build_token_df(tokens, embeddings, dims_to_show=DEFAULT_DIMS_TO_SHOW) -> pd.DataFrame:
    """Create a DataFrame of tokens with the first N embedding dimensions.

    ``dims_to_show`` is coerced to int, because UI slider values may arrive as floats.
    For non-empty input it is clamped to ``[1, embedding width]``. With no tokens,
    an empty frame with the expected columns is returned.
    """
    dims_to_show = int(dims_to_show)
    if len(tokens) == 0:
        return pd.DataFrame(columns=["token"] + _dim_columns(dims_to_show))

    embeddings = np.asarray(embeddings)
    dims_to_show = _clamp_dims(dims_to_show, embeddings.shape[1])
    df = pd.DataFrame(embeddings[:, :dims_to_show], columns=_dim_columns(dims_to_show))
    df.insert(0, "token", list(tokens))
    return df


def save_full_token_csv(tokens, embeddings) -> str:
    """Save full-width token embeddings to a temporary CSV and return its path.

    Columns are ``token, dim_0 .. dim_{H-1}``. With no tokens, the CSV holds
    only the header row.
    """
    embeddings = np.asarray(embeddings)
    width = embeddings.shape[1] if embeddings.ndim == 2 else HIDDEN_SIZE
    if len(tokens) == 0:
        df = pd.DataFrame(columns=["token"] + _dim_columns(width))
    else:
        df = pd.DataFrame(embeddings, columns=_dim_columns(width))
        df.insert(0, "token", list(tokens))
    return _write_temp_csv(df)


def save_sentence_csv(sent_embedding) -> str:
    """Save the sentence embedding as a single-row CSV and return its path."""
    row = np.asarray(sent_embedding).reshape(1, -1)
    df = pd.DataFrame(row, columns=_dim_columns(row.shape[1]))
    return _write_temp_csv(df)


# ---- Gradio pipeline ----
def run_pipeline(raw_text: str, dims_to_show: int):
    """
    Process input text, generate embeddings, and prepare preview/CSV outputs.

    Returns:
        cleaned_text (str)
        shape_info (str)
        token_df (DataFrame with first N dims)
        token_csv_path (File)
        sent_df (DataFrame with first N dims)
        sent_csv_path (File)
    """
    dims_to_show = int(dims_to_show)  # slider values may arrive as floats
    cleaned_text = remove_non_ascii_and_lowercase(raw_text or "")
    tokens, token_embeds, sent_embed = get_embeddings(cleaned_text)

    # get_embeddings always returns a 2-D array, even for empty input.
    seq_len, hidden = token_embeds.shape
    shape_info = (
        f"Tokens (including [CLS]/[SEP]): {seq_len}\n"
        f"Embedding size: {hidden}\n"
        f"Sentence embedding size: {sent_embed.shape[0]}"
    )
    if seq_len >= MAX_LENGTH:
        shape_info += f"\nNote: reached the {MAX_LENGTH}-token limit; any further text was truncated."

    token_df = build_token_df(tokens, token_embeds, dims_to_show=dims_to_show)

    sent_dims = _clamp_dims(dims_to_show, sent_embed.shape[0])
    sent_df = pd.DataFrame(
        np.asarray(sent_embed)[np.newaxis, :sent_dims],
        columns=_dim_columns(sent_dims),
    )

    token_csv_path = save_full_token_csv(tokens, token_embeds)
    sent_csv_path = save_sentence_csv(sent_embed)

    return cleaned_text, shape_info, token_df, token_csv_path, sent_df, sent_csv_path


# ---- Gradio Interface ----
with gr.Blocks(title="BERT Token & Embedding Explorer") as demo:
    gr.Markdown(
        f"""
        # 🧠 BERT Token & Embedding Explorer
        - Cleans your text (removes **non-ASCII** chars, lowercases)
        - Tokenizes with **{MODEL_NAME}**
        - Shows per-token embeddings (first *N* dims)
        - Exports **full {HIDDEN_SIZE}-dim** token and sentence embeddings as CSV
        """
    )

    with gr.Row():
        inp = gr.Textbox(
            label="Enter text",
            placeholder="Type or paste text here…",
            lines=5,
            value="Don't you love 🤗 Transformers? BERT embeddings are neat!",
        )

    with gr.Row():
        dims = gr.Slider(4, 64, value=DEFAULT_DIMS_TO_SHOW, step=1, label="Dimensions to display (preview)")
        run_btn = gr.Button("Embed with BERT", variant="primary")

    with gr.Row():
        cleaned_out = gr.Textbox(label="Cleaned text (ASCII-only, lowercased)", interactive=False)
        shape_info = gr.Textbox(label="Shapes & Info", interactive=False)

    gr.Markdown("### Token embeddings (preview)")
    token_df = gr.Dataframe(
        label="Tokens with first N embedding dimensions",
        interactive=False,
    )
    token_csv = gr.File(label="Download FULL token embeddings (CSV)")

    gr.Markdown("### Sentence embedding (preview)")
    sent_df = gr.Dataframe(
        label="First N dimensions of the pooled sentence embedding",
        interactive=False,
    )
    sent_csv = gr.File(label="Download FULL sentence embedding (CSV)")

    run_btn.click(
        fn=run_pipeline,
        inputs=[inp, dims],
        outputs=[cleaned_out, shape_info, token_df, token_csv, sent_df, sent_csv],
    )

if __name__ == "__main__":
    demo.launch()