ym59 commited on
Commit
91ae831
·
verified ·
1 Parent(s): 31465a4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -6
app.py CHANGED
@@ -211,12 +211,11 @@ def embed_sequence(seq: str) -> np.ndarray:
211
  out = model(**enc, output_hidden_states=True)
212
  hs = out.hidden_states
213
  mask = enc["attention_mask"].unsqueeze(-1).float()
214
- mvecs = []
215
- for li in [8, 10, 11]:
216
- h = hs[li]
217
- mv = (h * mask).sum(1) / mask.sum(1).clamp(min=1e-9)
218
- mvecs.append(mv.squeeze(0).cpu().numpy())
219
- return np.concatenate(mvecs)
220
 
221
  seq = seq.strip()
222
  if len(seq) <= MAX:
 
211
  out = model(**enc, output_hidden_states=True)
212
  hs = out.hidden_states
213
  mask = enc["attention_mask"].unsqueeze(-1).float()
214
+
215
+ # Grab the FINAL layer (-1) instead of hardcoding [8, 10, 11]
216
+ h = hs[-1]
217
+ mv = (h * mask).sum(1) / mask.sum(1).clamp(min=1e-9)
218
+ return mv.squeeze(0).cpu().numpy()
 
219
 
220
  seq = seq.strip()
221
  if len(seq) <= MAX: