xTimeCrystal commited on
Commit
b3ff271
·
verified ·
1 Parent(s): 59ce5b2

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +9 -4
README.md CHANGED
@@ -57,7 +57,7 @@ model.load_weights(th_p_, strict=True)
57
  ### Example: Evaluate the model on some text
58
 
59
  ```python
60
- def eval(text_: str, model, config, per_token=False):
61
  text_ = text_.encode('utf-8')
62
 
63
  x_prev_0s, state_prevs, x_prev_1s = (mx.zeros([config['layers'], 1, 1, config['input_dims']], dtype=dtype),
@@ -82,7 +82,12 @@ def eval(text_: str, model, config, per_token=False):
82
  return nn.losses.cross_entropy(logits, mx.roll(txt_btch, -1, axis=1))[:, :-1].mean(), (mx.argmax(logits, axis=-1) == mx.roll(txt_btch, -1, axis=1)).mean()
83
  ```
84
 
85
- ```python
 
 
 
 
 
86
  text_ = '''def to_char(x):
87
  try:
88
  return bytes([x]).decode('utf-8')
@@ -90,11 +95,11 @@ text_ = '''def to_char(x):
90
  return f'{x}'
91
  '''
92
 
93
- print(eval(text_, model, config))
94
  ```
95
 
96
  ```
97
- (array(0.738281, dtype=bfloat16), array(0.77451, dtype=float32)) # (CE Loss, Accuracy of next character)
98
  ```
99
 
100
  ### Example: Visualize the attention maps (beta)
 
57
  ### Example: Evaluate the model on some text
58
 
59
  ```python
60
+ def eval_loss(text_: str, model, config, per_token=False):
61
  text_ = text_.encode('utf-8')
62
 
63
  x_prev_0s, state_prevs, x_prev_1s = (mx.zeros([config['layers'], 1, 1, config['input_dims']], dtype=dtype),
 
82
  return nn.losses.cross_entropy(logits, mx.roll(txt_btch, -1, axis=1))[:, :-1].mean(), (mx.argmax(logits, axis=-1) == mx.roll(txt_btch, -1, axis=1)).mean()
83
  ```
84
 
85
+ The text should show something like '[STX]def to_char(x): ...' since '[STX]' is my start token. Else, add the \x02 character in, NOT the picture version.
86
+
87
+ The STX character should appear **bright red**, the version on the right is the correct one.
88
+ ![image/png](https://cdn-uploads.huggingface.co/production/uploads/66a767dcbe4c3c2683495a8b/rTVMKiioh-uo1Syim3BZ3.png)
89
+
90
+ ```python
91
  text_ = '''def to_char(x):
92
  try:
93
  return bytes([x]).decode('utf-8')
 
95
  return f'{x}'
96
  '''
97
 
98
+ print(eval_loss(text_, model, config)) # returns (CE Loss, Accuracy of next character)
99
  ```
100
 
101
  ```
102
+ (array(0.738281, dtype=bfloat16), array(0.77451, dtype=float32))
103
  ```
104
 
105
  ### Example: Visualize the attention maps (beta)