pswap committed on
Commit
c1bed2a
·
1 Parent(s): a552320
Files changed (1) hide show
  1. esm_utils.py +1 -3
esm_utils.py CHANGED
@@ -32,8 +32,6 @@ class EsmEmbedding:
32
 
33
  mean_embedding = hidden[1:-1].mean(dim=0) # mean over non-[CLS]/[EOS]
34
  cls_embedding = hidden[0] # CLS token
35
- print("Mean",mean_embedding)
36
- print("CLS",cls_embedding)
37
  return mean_embedding, cls_embedding
38
 
39
  @spaces.GPU(duration=128)
@@ -97,7 +95,7 @@ class EsmEmbedding:
97
  embeddings_stack = torch.stack(embeddings, dim=0).to(torch.float64)
98
  embeddings_stack = torch.nn.functional.normalize(embeddings_stack, p=2, dim=1)
99
  embeddings_mean = torch.mean(embeddings_stack, dim=0)
100
-
101
  return embeddings_mean.cpu().numpy() # Move to CPU and convert to numpy
102
 
103
  # Example usage:
 
32
 
33
  mean_embedding = hidden[1:-1].mean(dim=0) # mean over non-[CLS]/[EOS]
34
  cls_embedding = hidden[0] # CLS token
 
 
35
  return mean_embedding, cls_embedding
36
 
37
  @spaces.GPU(duration=128)
 
95
  embeddings_stack = torch.stack(embeddings, dim=0).to(torch.float64)
96
  embeddings_stack = torch.nn.functional.normalize(embeddings_stack, p=2, dim=1)
97
  embeddings_mean = torch.mean(embeddings_stack, dim=0)
98
+ print("Mean",embeddings_mean)
99
  return embeddings_mean.cpu().numpy() # Move to CPU and convert to numpy
100
 
101
  # Example usage: