File size: 1,505 Bytes
16b2a71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import gradio as gr
import torch
from transformers import BertTokenizer, BertModel
import re

# Load the pretrained BERT tokenizer and encoder once at module import.
# NOTE(review): from_pretrained downloads weights on first run — requires
# network access; transformers puts the model in eval() mode by default.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased")

def preprocess_text(text):
    """Strip non-ASCII characters from *text* and lowercase the result.

    Args:
        text: Raw user input string.

    Returns:
        Lowercased copy of *text* with every character outside the
        ASCII range (0x00-0x7F) removed; "" for empty input.
    """
    # BUG FIX: the original pattern r'[^\x80-\uFFFF]+' deleted the ASCII
    # range itself (keeping only non-ASCII), so ordinary English input
    # collapsed to an empty string before tokenization. Keep ASCII and
    # drop everything outside it, matching the stated intent.
    cleaned = re.sub(r'[^\x00-\x7F]+', '', text)
    return cleaned.lower()

def get_bert_embeddings(text):
    """Run *text* through BERT and return per-token embeddings.

    Args:
        text: Raw input string; preprocessed via preprocess_text before
            tokenization.

    Returns:
        dict mapping each token string to its last-hidden-state vector
        as a plain list of floats.
        NOTE(review): repeated tokens collapse to a single entry (last
        occurrence wins) because tokens are used as dict keys — kept
        as-is since format_output relies on the dict interface.
    """
    cleaned_text = preprocess_text(text)
    # truncation=True guards against inputs longer than BERT's 512-token
    # positional limit, which would otherwise raise inside model().
    inputs = tokenizer(cleaned_text, return_tensors="pt", truncation=True)
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        outputs = model(**inputs)
    # last_hidden_state: [1, seq_len, hidden] -> [seq_len, hidden]
    embeddings = outputs.last_hidden_state.squeeze(0)
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"].squeeze(0))

    # Convert each embedding tensor to a list of floats for display.
    return {token: embedding.tolist() for token, embedding in zip(tokens, embeddings)}

def format_output(token_embeddings):
    """Render a token->embedding mapping as human-readable text.

    Args:
        token_embeddings: dict mapping token strings to embedding
            vectors (lists of floats).

    Returns:
        One paragraph per token showing the first five embedding values
        and the total dimensionality; "" for an empty dict.
    """
    # str.join instead of += in a loop: avoids quadratic string
    # concatenation while producing byte-identical output.
    return "".join(
        f"Token: {token}\nEmbedding: {emb[:5]}... ({len(emb)} dims)\n\n"
        for token, emb in token_embeddings.items()
    )

# Wire the embedding pipeline into a simple text-in/text-out Gradio UI.
demo = gr.Interface(
    fn=lambda text: format_output(get_bert_embeddings(text)),
    inputs=gr.Textbox(lines=4, placeholder="Enter text here..."),
    outputs="text",
    title="BERT Token Embeddings Viewer",
    # Description aligned with the intended preprocessing: non-ASCII
    # characters are stripped (the original text claimed the opposite).
    description="Removes non-ASCII characters, lowercases input, and shows BERT tokens with embeddings.",
)

if __name__ == "__main__":
    demo.launch()