Update Mo.py
Browse files
Mo.py
CHANGED
|
@@ -10,7 +10,8 @@ import tensorflow.keras.backend as K
|
|
| 10 |
from tensorflow.keras import mixed_precision
|
| 11 |
policy = mixed_precision.Policy('mixed_float16') # fp16
|
| 12 |
mixed_precision.set_global_policy(policy)
|
| 13 |
-
|
|
|
|
| 14 |
print('1')
|
| 15 |
tf.get_logger().setLevel("ERROR")
|
| 16 |
SEED = 42
|
|
@@ -18,18 +19,23 @@ tf.random.set_seed(SEED)
|
|
| 18 |
np.random.seed(SEED)
|
| 19 |
|
| 20 |
# TPU ์ด๊ธฐํ
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
else:
|
| 31 |
strategy = tf.distribute.get_strategy()
|
| 32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
|
| 34 |
# =======================
|
| 35 |
# 1) ํ์ผ ๋ค์ด๋ก๋
|
|
|
|
| 10 |
from tensorflow.keras import mixed_precision
|
| 11 |
policy = mixed_precision.Policy('mixed_float16') # fp16
|
| 12 |
mixed_precision.set_global_policy(policy)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
print('1')
|
| 16 |
tf.get_logger().setLevel("ERROR")
|
| 17 |
SEED = 42
|
|
|
|
| 19 |
np.random.seed(SEED)
|
| 20 |
|
| 21 |
# TPU ์ด๊ธฐํ
|
| 22 |
+
try:
|
| 23 |
+
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="local")
|
| 24 |
+
tf.tpu.experimental.initialize_tpu_system(resolver)
|
| 25 |
+
strategy = tf.distribute.TPUStrategy(resolver)
|
| 26 |
+
print("โ
TPU ์ด๊ธฐํ ์๋ฃ:", resolver.cluster_spec().as_dict())
|
| 27 |
+
on_tpu = True
|
| 28 |
+
|
| 29 |
+
except Exception as e:
|
| 30 |
+
print("โ ๏ธ TPU ๋ฏธ์ฌ์ฉ, GPU/CPU๋ก ์งํ:", e)
|
|
|
|
| 31 |
strategy = tf.distribute.get_strategy()
|
| 32 |
+
on_tpu = False
|
| 33 |
+
|
| 34 |
+
# Mixed precision
|
| 35 |
+
from tensorflow.keras import mixed_precision
|
| 36 |
+
policy = mixed_precision.Policy("mixed_bfloat16" if on_tpu else "float32")
|
| 37 |
+
mixed_precision.set_global_policy(policy)
|
| 38 |
+
print("โ
Mixed precision:", policy)
|
| 39 |
|
| 40 |
# =======================
|
| 41 |
# 1) ํ์ผ ๋ค์ด๋ก๋
|