OpenLab-NLP committed on
Commit
870a9a7
·
verified ยท
1 Parent(s): 3383069

Update Mo.py

Browse files
Files changed (1) hide show
  1. Mo.py +18 -12
Mo.py CHANGED
@@ -10,7 +10,8 @@ import tensorflow.keras.backend as K
10
  from tensorflow.keras import mixed_precision
11
  policy = mixed_precision.Policy('mixed_float16') # fp16
12
  mixed_precision.set_global_policy(policy)
13
- print("✅ Mixed precision 적용:", policy)
 
14
  print('1')
15
  tf.get_logger().setLevel("ERROR")
16
  SEED = 42
@@ -18,18 +19,23 @@ tf.random.set_seed(SEED)
18
  np.random.seed(SEED)
19
 
20
  # TPU 초기화
21
- gpus = tf.config.list_physical_devices('GPU')
22
- if gpus:
23
- try:
24
- for gpu in gpus:
25
- tf.config.experimental.set_memory_growth(gpu, True)
26
- strategy = tf.distribute.MirroredStrategy(devices=[f"/GPU:{i}" for i in range(len(gpus))])
27
- print(f"✅ GPU {len(gpus)}개 사용: {strategy.num_replicas_in_sync} replicas")
28
- except RuntimeError as e:
29
- print("⚠️ GPU 설정 에러:", e)
30
- else:
31
  strategy = tf.distribute.get_strategy()
32
- print("⚠️ GPU 없음, CPU 사용")
 
 
 
 
 
 
33
 
34
  # =======================
35
  # 1) ํŒŒ์ผ ๋‹ค์šด๋กœ๋“œ
 
10
  from tensorflow.keras import mixed_precision
11
  policy = mixed_precision.Policy('mixed_float16') # fp16
12
  mixed_precision.set_global_policy(policy)
13
+
14
+
15
  print('1')
16
  tf.get_logger().setLevel("ERROR")
17
  SEED = 42
 
19
  np.random.seed(SEED)
20
 
21
  # TPU 초기화
22
+ try:
23
+ resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="local")
24
+ tf.tpu.experimental.initialize_tpu_system(resolver)
25
+ strategy = tf.distribute.TPUStrategy(resolver)
26
+ print("✅ TPU 초기화 완료:", resolver.cluster_spec().as_dict())
27
+ on_tpu = True
28
+
29
+ except Exception as e:
30
+ print("⚠️ TPU 미사용, GPU/CPU로 진행:", e)
 
31
  strategy = tf.distribute.get_strategy()
32
+ on_tpu = False
33
+
34
+ # Mixed precision
35
+ from tensorflow.keras import mixed_precision
36
+ policy = mixed_precision.Policy("mixed_bfloat16" if on_tpu else "float32")
37
+ mixed_precision.set_global_policy(policy)
38
+ print("✅ Mixed precision:", policy)
39
 
40
  # =======================
41
  # 1) ํŒŒ์ผ ๋‹ค์šด๋กœ๋“œ