updated per new model code
Browse files- pt_resnet_to_flax.py +20 -24
pt_resnet_to_flax.py
CHANGED
|
@@ -1,23 +1,23 @@
|
|
| 1 |
-
from transformers import ResNetConfig, FlaxResNetForImageClassification, ResNetForImageClassification
|
| 2 |
from flax.traverse_util import flatten_dict, unflatten_dict
|
| 3 |
from flax.core.frozen_dict import unfreeze
|
| 4 |
import re
|
|
|
|
|
|
|
| 5 |
|
| 6 |
-
pt_resnet = ResNetForImageClassification.from_pretrained("microsoft/resnet-50")
|
| 7 |
-
pt_state = pt_resnet.state_dict()
|
| 8 |
|
|
|
|
| 9 |
flax_resnet = FlaxResNetForImageClassification(pt_resnet.config)
|
|
|
|
|
|
|
| 10 |
flax_state = flatten_dict(unfreeze(flax_resnet.params))
|
| 11 |
-
|
|
|
|
| 12 |
new_pt_state = {}
|
| 13 |
-
pt_batch_stats = {}
|
| 14 |
for key, tensor in pt_state.items():
|
| 15 |
key_parts = set(key.split("."))
|
| 16 |
tensor = tensor.numpy()
|
| 17 |
|
| 18 |
-
key = re.sub(r"(?<=[a-zA-Z]).(?=\d)", "_", key)
|
| 19 |
-
|
| 20 |
-
|
| 21 |
if "convolution.weight" in key:
|
| 22 |
key = key.replace("weight", "kernel")
|
| 23 |
tensor = tensor.transpose((2, 3, 1, 0))
|
|
@@ -34,36 +34,32 @@ for key, tensor in pt_state.items():
|
|
| 34 |
key = "params."+key
|
| 35 |
new_pt_state[key] = tensor
|
| 36 |
|
| 37 |
-
elif "
|
| 38 |
-
key = "params.classifier.kernel"
|
| 39 |
new_pt_state[key] = tensor.transpose()
|
| 40 |
|
| 41 |
-
elif "
|
| 42 |
-
key = "params.classifier.bias"
|
| 43 |
new_pt_state[key] = tensor
|
| 44 |
|
| 45 |
elif "normalization.running_mean" in key:
|
| 46 |
key = key.replace("running_mean", "mean")
|
| 47 |
-
|
|
|
|
| 48 |
|
| 49 |
elif "normalization.running_var" in key:
|
| 50 |
key = key.replace("running_var", "var")
|
| 51 |
-
|
|
|
|
| 52 |
|
| 53 |
else:
|
| 54 |
-
|
| 55 |
|
|
|
|
| 56 |
for total_updated, (new_key, new_tensor) in enumerate(new_pt_state.items()):
|
| 57 |
orig_flax_tensor = flax_state.get(tuple(new_key.split(".")))
|
| 58 |
assert orig_flax_tensor is not None
|
| 59 |
-
|
| 60 |
-
if not("classifier" in new_key):
|
| 61 |
-
assert orig_flax_tensor.shape == new_tensor.shape
|
| 62 |
flax_state[tuple(new_key.split("."))] = new_tensor
|
| 63 |
-
|
| 64 |
flax_state = unflatten_dict(flax_state)
|
| 65 |
-
|
| 66 |
-
pt_batch_stats = unflatten_dict({tuple(k.split(".")):v for k,v in pt_batch_stats.items()})
|
| 67 |
-
flax_state["batch_stats"] = pt_batch_stats
|
| 68 |
-
|
| 69 |
-
flax_resnet.save_pretrained("resnet_50_flax", params=flax_state)
|
|
|
|
| 1 |
+
from transformers import ResNetConfig, FlaxResNetForImageClassification, ResNetForImageClassification, FlaxResNetModel, ResNetModel
|
| 2 |
from flax.traverse_util import flatten_dict, unflatten_dict
|
| 3 |
from flax.core.frozen_dict import unfreeze
|
| 4 |
import re
|
| 5 |
+
import jax.numpy as jnp
|
| 6 |
+
import torch
|
| 7 |
|
|
|
|
|
|
|
| 8 |
|
| 9 |
+
pt_resnet = ResNetForImageClassification.from_pretrained("microsoft/resnet-50")
|
| 10 |
flax_resnet = FlaxResNetForImageClassification(pt_resnet.config)
|
| 11 |
+
|
| 12 |
+
pt_state = pt_resnet.state_dict()
|
| 13 |
flax_state = flatten_dict(unfreeze(flax_resnet.params))
|
| 14 |
+
|
| 15 |
+
|
| 16 |
new_pt_state = {}
|
|
|
|
| 17 |
for key, tensor in pt_state.items():
|
| 18 |
key_parts = set(key.split("."))
|
| 19 |
tensor = tensor.numpy()
|
| 20 |
|
|
|
|
|
|
|
|
|
|
| 21 |
if "convolution.weight" in key:
|
| 22 |
key = key.replace("weight", "kernel")
|
| 23 |
tensor = tensor.transpose((2, 3, 1, 0))
|
|
|
|
| 34 |
key = "params."+key
|
| 35 |
new_pt_state[key] = tensor
|
| 36 |
|
| 37 |
+
elif "classifier.1.weight" in key:
|
| 38 |
+
key = "params.classifier.1.kernel"
|
| 39 |
new_pt_state[key] = tensor.transpose()
|
| 40 |
|
| 41 |
+
elif "classifier.1.bias" in key:
|
| 42 |
+
key = "params.classifier.1.bias"
|
| 43 |
new_pt_state[key] = tensor
|
| 44 |
|
| 45 |
elif "normalization.running_mean" in key:
|
| 46 |
key = key.replace("running_mean", "mean")
|
| 47 |
+
key = "batch_stats."+key
|
| 48 |
+
new_pt_state[key] = tensor
|
| 49 |
|
| 50 |
elif "normalization.running_var" in key:
|
| 51 |
key = key.replace("running_var", "var")
|
| 52 |
+
key = "batch_stats."+key
|
| 53 |
+
new_pt_state[key] = tensor
|
| 54 |
|
| 55 |
else:
|
| 56 |
+
continue
|
| 57 |
|
| 58 |
+
|
| 59 |
for total_updated, (new_key, new_tensor) in enumerate(new_pt_state.items()):
|
| 60 |
orig_flax_tensor = flax_state.get(tuple(new_key.split(".")))
|
| 61 |
assert orig_flax_tensor is not None
|
| 62 |
+
assert orig_flax_tensor.shape == new_tensor.shape
|
|
|
|
|
|
|
| 63 |
flax_state[tuple(new_key.split("."))] = new_tensor
|
|
|
|
| 64 |
flax_state = unflatten_dict(flax_state)
|
| 65 |
+
flax_resnet.save_pretrained("resnet_50_flax", params=flax_state)
|
|
|
|
|
|
|
|
|
|
|
|