aimgo commited on
Commit
9b62ecf
·
verified ·
1 Parent(s): d54fd7c

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +22 -8
README.md CHANGED
@@ -39,27 +39,41 @@ import torch
39
  from transformers import AutoModel, AutoTokenizer
40
 
41
  device = "cuda" if torch.cuda.is_available() else "cpu"
42
- model = AutoModel.from_pretrained("aimgo/CaputEmendatoris", trust_remote_code=True, torch_dtype=torch.bfloat16).to(device)
43
- tokenizer = AutoTokenizer.from_pretrained("aimgo/Emendator")
 
 
 
 
 
 
 
 
 
 
44
  model.eval()
45
 
46
  text = "quandoquidcrn natura anirni rnortalis habctur."
 
47
  enc = tokenizer(text, return_tensors="pt").to(device)
48
 
49
- # detect errors at each byte
50
  with torch.no_grad():
51
- probs = model.detect(enc["input_ids"], enc["attention_mask"])
 
 
52
 
53
- # byte probability -> character
54
- byte_probs = probs[0][:-1].cpu().tolist()
55
  char_probs = []
56
  byte_idx = 0
57
  for c in text:
58
  n = len(c.encode("utf-8"))
59
- char_probs.append(max(byte_probs[byte_idx:byte_idx + n]) if byte_idx + n <= len(byte_probs) else 0.0)
 
 
 
60
  byte_idx += n
61
 
62
- output = char_probs
63
  ```
64
 
65
  If you use this in your work, please cite:
 
39
  from transformers import AutoModel, AutoTokenizer
40
 
41
  device = "cuda" if torch.cuda.is_available() else "cpu"
42
+
43
+ model_repo = "aimgo/caputemendatoris"
44
+ tokenizer_repo = "aimgo/Emendator"
45
+
46
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_repo)
47
+
48
+ model = AutoModel.from_pretrained(
49
+ model_repo,
50
+ trust_remote_code=True,  # required: this model ships a custom modeling file
51
+ torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32,
52
+ ).to(device)
53
+
54
  model.eval()
55
 
56
  text = "quandoquidcrn natura anirni rnortalis habctur."
57
+
58
  enc = tokenizer(text, return_tensors="pt").to(device)
59
 
60
+ # run the detector: per-byte error probabilities
61
  with torch.no_grad():
62
+ probs = model.detect(enc["input_ids"], enc.get("attention_mask", None))
63
+
64
+ byte_probs = probs[0][:-1].detach().cpu().tolist()
65
 
 
 
66
  char_probs = []
67
  byte_idx = 0
68
  for c in text:
69
  n = len(c.encode("utf-8"))
70
+ if byte_idx + n <= len(byte_probs):
71
+ char_probs.append(max(byte_probs[byte_idx:byte_idx+n]))
72
+ else:
73
+ char_probs.append(0.0)
74
  byte_idx += n
75
 
76
+ print(char_probs)
77
  ```
78
 
79
  If you use this in your work, please cite: