rianders commited on
Commit
c3db0f5
·
verified ·
1 Parent(s): f7ffea0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -3
app.py CHANGED
@@ -14,9 +14,9 @@ def get_bert_embeddings(words):
14
  for word in words:
15
  inputs = tokenizer(word, return_tensors='pt')
16
  outputs = model(**inputs)
17
- # Use the [CLS] token's embedding
18
- cls_embedding = outputs.last_hidden_state[0][0].detach().numpy()
19
- embeddings.append(cls_embedding)
20
 
21
  if len(embeddings) > 0:
22
  pca = PCA(n_components=3)
@@ -25,6 +25,7 @@ def get_bert_embeddings(words):
25
  return []
26
 
27
 
 
28
  # Plotly plotting function
29
  def plot_interactive_bert_embeddings(embeddings, words):
30
  if len(words) < 4:
 
14
  for word in words:
15
  inputs = tokenizer(word, return_tensors='pt')
16
  outputs = model(**inputs)
17
+ # Calculate mean of embeddings across all tokens in the phrase
18
+ mean_embedding = outputs.last_hidden_state[0].mean(dim=0).detach().numpy()
19
+ embeddings.append(mean_embedding)
20
 
21
  if len(embeddings) > 0:
22
  pca = PCA(n_components=3)
 
25
  return []
26
 
27
 
28
+
29
  # Plotly plotting function
30
  def plot_interactive_bert_embeddings(embeddings, words):
31
  if len(words) < 4: