Patrick von Platen commited on
Commit ·
4023cca
1
Parent(s): 473506b
run forward
Browse files
run_forward_gpt2_large.py
CHANGED
|
@@ -4,12 +4,12 @@ from flax.jax_utils import replicate
|
|
| 4 |
from jax import jit, pmap
|
| 5 |
import numpy as np
|
| 6 |
|
| 7 |
-
model = FlaxGPT2LMHeadModel.from_pretrained("
|
| 8 |
dummy_inputs = np.array(4 * [256 * [1]], dtype=np.int32)
|
| 9 |
|
| 10 |
|
| 11 |
def run_forward(inputs, params):
|
| 12 |
-
return model(inputs, params).logits
|
| 13 |
|
| 14 |
|
| 15 |
jitted_forward = jit(run_forward)
|
|
|
|
| 4 |
from jax import jit, pmap
|
| 5 |
import numpy as np
|
| 6 |
|
| 7 |
+
model = FlaxGPT2LMHeadModel.from_pretrained("gpt2-large")
|
| 8 |
dummy_inputs = np.array(4 * [256 * [1]], dtype=np.int32)
|
| 9 |
|
| 10 |
|
| 11 |
def run_forward(inputs, params):
|
| 12 |
+
return model(inputs, params=params).logits
|
| 13 |
|
| 14 |
|
| 15 |
jitted_forward = jit(run_forward)
|