heyongxin233 committed on
Commit
7b303af
·
verified ·
1 Parent(s): 9e0a7dc

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +13 -7
README.md CHANGED
@@ -20,19 +20,25 @@ import torch
20
  from transformers import AutoTokenizer, AutoModel
21
 
22
  model_id = "heyongxin233/DETree"
23
- tgt_layer = 18 # 0 = embeddings; 1..24 = encoder layers (RoBERTa-large)
 
 
24
 
25
  tok = AutoTokenizer.from_pretrained(model_id)
26
- enc = AutoModel.from_pretrained(model_id, output_hidden_states=True)
 
27
 
28
  texts = ["An example sentence.", "Another one."]
29
  batch = tok(texts, padding=True, truncation=True, return_tensors="pt")
 
30
 
31
- with torch.no_grad():
32
  out = enc(**batch)
33
- hs = out.hidden_states[tgt_layer] # (bsz, seq, hidden)
34
- mask = batch["attention_mask"].unsqueeze(-1) # (bsz, seq, 1)
35
  hs = hs.masked_fill(~mask.bool(), float("-inf"))
36
- emb, _ = hs.max(dim=1) # (bsz, hidden) max-pool over tokens
37
  emb = torch.nn.functional.normalize(emb, p=2, dim=-1)
38
- print(emb.shape) # -> (batch_size, 1024)
 
 
 
20
  from transformers import AutoTokenizer, AutoModel
21
 
22
  model_id = "heyongxin233/DETree"
23
+ tgt_layer = 18 # 0=embeddings; 1..24=encoder layers (RoBERTa-large)
24
+ device = "cuda" if torch.cuda.is_available() else "cpu"
25
+ print("Using device:", device)
26
 
27
  tok = AutoTokenizer.from_pretrained(model_id)
28
+ enc = AutoModel.from_pretrained(model_id, output_hidden_states=True).to(device)
29
+ enc.eval()
30
 
31
  texts = ["An example sentence.", "Another one."]
32
  batch = tok(texts, padding=True, truncation=True, return_tensors="pt")
33
+ batch = {k: v.to(device) for k, v in batch.items()}
34
 
35
+ with torch.inference_mode():
36
  out = enc(**batch)
37
+ hs = out.hidden_states[tgt_layer] # (bsz, seq, hidden)
38
+ mask = batch["attention_mask"].unsqueeze(-1) # (bsz, seq, 1)
39
  hs = hs.masked_fill(~mask.bool(), float("-inf"))
40
+ emb, _ = hs.max(dim=1) # max-pool over tokens
41
  emb = torch.nn.functional.normalize(emb, p=2, dim=-1)
42
+
43
+ print(emb.device, emb.shape) # -> (batch_size, 1024)
44
+ ```