Update README.md
Browse files
README.md
CHANGED
|
@@ -20,19 +20,25 @@ import torch
|
|
| 20 |
from transformers import AutoTokenizer, AutoModel
|
| 21 |
|
| 22 |
model_id = "heyongxin233/DETree"
|
| 23 |
-
tgt_layer = 18
|
|
|
|
|
|
|
| 24 |
|
| 25 |
tok = AutoTokenizer.from_pretrained(model_id)
|
| 26 |
-
enc = AutoModel.from_pretrained(model_id, output_hidden_states=True)
|
|
|
|
| 27 |
|
| 28 |
texts = ["An example sentence.", "Another one."]
|
| 29 |
batch = tok(texts, padding=True, truncation=True, return_tensors="pt")
|
|
|
|
| 30 |
|
| 31 |
-
with torch.no_grad():
|
| 32 |
out = enc(**batch)
|
| 33 |
-
hs = out.hidden_states[tgt_layer]
|
| 34 |
-
mask = batch["attention_mask"].unsqueeze(-1)
|
| 35 |
hs = hs.masked_fill(~mask.bool(), float("-inf"))
|
| 36 |
-
emb, _ = hs.max(dim=1)
|
| 37 |
emb = torch.nn.functional.normalize(emb, p=2, dim=-1)
|
| 38 |
-
|
|
|
|
|
|
|
|
|
from transformers import AutoTokenizer, AutoModel

model_id = "heyongxin233/DETree"
layer_idx = 18  # 0=embeddings; 1..24=encoder layers (RoBERTa-large)

# Pick the accelerator if one is present, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Using device:", device)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModel.from_pretrained(model_id, output_hidden_states=True).to(device)
model.eval()

sentences = ["An example sentence.", "Another one."]
inputs = tokenizer(sentences, padding=True, truncation=True, return_tensors="pt")
inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

with torch.inference_mode():
    outputs = model(**inputs)

# Hidden states of the chosen layer: (batch, seq_len, hidden).
hidden = outputs.hidden_states[layer_idx]
# Broadcastable padding mask: (batch, seq_len, 1).
attn_mask = inputs["attention_mask"].unsqueeze(-1)
# Blank out padded positions so they can never win the max-pool below.
hidden = hidden.masked_fill(~attn_mask.bool(), float("-inf"))
# Max-pool over the token dimension, then L2-normalize each embedding.
embeddings, _ = hidden.max(dim=1)
embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=-1)

print(embeddings.device, embeddings.shape)  # -> (batch_size, 1024)
+
```
|