Update app.py
Browse files
app.py
CHANGED
|
@@ -211,12 +211,11 @@ def embed_sequence(seq: str) -> np.ndarray:
|
|
| 211 |
out = model(**enc, output_hidden_states=True)
|
| 212 |
hs = out.hidden_states
|
| 213 |
mask = enc["attention_mask"].unsqueeze(-1).float()
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
return np.concatenate(mvecs)
|
| 220 |
|
| 221 |
seq = seq.strip()
|
| 222 |
if len(seq) <= MAX:
|
|
|
|
| 211 |
out = model(**enc, output_hidden_states=True)
|
| 212 |
hs = out.hidden_states
|
| 213 |
mask = enc["attention_mask"].unsqueeze(-1).float()
|
| 214 |
+
|
| 215 |
+
# Grab the FINAL layer (-1) instead of hardcoding [8, 10, 11]
|
| 216 |
+
h = hs[-1]
|
| 217 |
+
mv = (h * mask).sum(1) / mask.sum(1).clamp(min=1e-9)
|
| 218 |
+
return mv.squeeze(0).cpu().numpy()
|
|
|
|
| 219 |
|
| 220 |
seq = seq.strip()
|
| 221 |
if len(seq) <= MAX:
|