Update README.md
Browse files
README.md
CHANGED
|
@@ -57,7 +57,7 @@ model.load_weights(th_p_, strict=True)
|
|
| 57 |
### Example: Evaluate the model on some text
|
| 58 |
|
| 59 |
```python
|
| 60 |
-
def
|
| 61 |
text_ = text_.encode('utf-8')
|
| 62 |
|
| 63 |
x_prev_0s, state_prevs, x_prev_1s = (mx.zeros([config['layers'], 1, 1, config['input_dims']], dtype=dtype),
|
|
@@ -82,7 +82,12 @@ def eval(text_: str, model, config, per_token=False):
|
|
| 82 |
return nn.losses.cross_entropy(logits, mx.roll(txt_btch, -1, axis=1))[:, :-1].mean(), (mx.argmax(logits, axis=-1) == mx.roll(txt_btch, -1, axis=1)).mean()
|
| 83 |
```
|
| 84 |
|
| 85 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
text_ = '''def to_char(x):
|
| 87 |
try:
|
| 88 |
return bytes([x]).decode('utf-8')
|
|
@@ -90,11 +95,11 @@ text_ = '''def to_char(x):
|
|
| 90 |
return f'{x}'
|
| 91 |
'''
|
| 92 |
|
| 93 |
-
print(
|
| 94 |
```
|
| 95 |
|
| 96 |
```
|
| 97 |
-
(array(0.738281, dtype=bfloat16), array(0.77451, dtype=float32))
|
| 98 |
```
|
| 99 |
|
| 100 |
### Example: Visualize the attention maps (beta)
|
|
|
|
| 57 |
### Example: Evaluate the model on some text
|
| 58 |
|
| 59 |
```python
|
| 60 |
+
def eval_loss(text_: str, model, config, per_token=False):
|
| 61 |
text_ = text_.encode('utf-8')
|
| 62 |
|
| 63 |
x_prev_0s, state_prevs, x_prev_1s = (mx.zeros([config['layers'], 1, 1, config['input_dims']], dtype=dtype),
|
|
|
|
| 82 |
return nn.losses.cross_entropy(logits, mx.roll(txt_btch, -1, axis=1))[:, :-1].mean(), (mx.argmax(logits, axis=-1) == mx.roll(txt_btch, -1, axis=1)).mean()
|
| 83 |
```
|
| 84 |
|
| 85 |
+
The text should show something like '[STX]def to_char(x): ...' since '[STX]' is my start token. Else, add the \x02 character in, NOT the picture version.
|
| 86 |
+
|
| 87 |
+
The STX character should appear **bright red**, the version on the right is the correct one.
|
| 88 |
+

|
| 89 |
+
|
| 90 |
+
```python
|
| 91 |
text_ = '''def to_char(x):
|
| 92 |
try:
|
| 93 |
return bytes([x]).decode('utf-8')
|
|
|
|
| 95 |
return f'{x}'
|
| 96 |
'''
|
| 97 |
|
| 98 |
+
print(eval_loss(text_, model, config)) # returns (CE Loss, Accuracy of next character)
|
| 99 |
```
|
| 100 |
|
| 101 |
```
|
| 102 |
+
(array(0.738281, dtype=bfloat16), array(0.77451, dtype=float32))
|
| 103 |
```
|
| 104 |
|
| 105 |
### Example: Visualize the attention maps (beta)
|