|
|
import tensorflow as tf |
|
|
|
|
|
|
|
|
def _decode_utf8(value):
    """Decode ``bytes`` to ``str`` (UTF-8), descending into nested lists."""
    if isinstance(value, bytes):
        return value.decode("utf-8")
    if isinstance(value, list):
        return [_decode_utf8(item) for item in value]
    return value


def recursive_cast_to_numpy(obj):
    """Recursively replace TensorFlow tensors in ``obj`` with host values.

    Conversion rules:

    * ``tf.Tensor`` of dtype ``tf.string``: a scalar becomes a ``str``.
      A tensor of rank >= 1 becomes a (nested) list of ``str``. Each element
      is decoded as UTF-8 in both cases, so callers always get text.
    * Any other ``tf.Tensor``: the result of ``obj.numpy()``.
    * ``dict``: a new plain ``dict`` with the same keys (keys are not
      converted) and recursively converted values.
    * ``list`` / ``tuple``: a new plain ``list`` / ``tuple`` with
      recursively converted items. Subclasses such as ``OrderedDict`` or
      namedtuples are rebuilt as the plain base type.
    * Anything else is returned unchanged.

    NOTE(review): relies on ``Tensor.numpy()``, which presumably requires
    eager (not graph/symbolic) tensors. Confirm against callers.

    Args:
        obj: A tensor, or an arbitrarily nested structure of dicts, lists
            and tuples that may contain tensors.

    Returns:
        The same structure with every tensor converted as described above.

    Raises:
        UnicodeDecodeError: If a string tensor holds bytes that are not
            valid UTF-8.
    """
    if isinstance(obj, tf.Tensor):
        value = obj.numpy()
        if obj.dtype == tf.string:
            # String tensors hold raw bytes. A scalar's .numpy() is a bytes
            # object, while higher ranks give an object array whose
            # .tolist() yields nested lists of bytes. Decode both shapes so
            # the output type does not depend on the tensor's rank.
            return _decode_utf8(value.tolist() if obj.ndim > 0 else value)
        return value
    if isinstance(obj, dict):
        return {key: recursive_cast_to_numpy(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [recursive_cast_to_numpy(item) for item in obj]
    if isinstance(obj, tuple):
        return tuple(recursive_cast_to_numpy(item) for item in obj)
    return obj
|
|
|