saniaE
refactored code
26ed134
import os
import sys
import logging
import tensorflow as tf
def apply_fixes():
    """Configure TensorFlow so Mask R-CNN code can run on TF 1.15.

    Process-global side effects:
      1. Lowers TensorFlow log verbosity (native-runtime env var, the Python
         TF logger, and AutoGraph).
      2. Switches TensorFlow to TF1 behavior via
         ``tf.compat.v1.disable_v2_behavior()`` when that API exists.
      3. Registers ``tensorflow.python.keras`` and some of its submodules in
         ``sys.modules`` under ``keras`` names, so later ``import keras`` /
         ``import keras.<sub>`` statements resolve to TensorFlow's bundled
         Keras instead of a standalone package.
    """
    # 1. Mute TF logs.
    # NOTE(review): `tensorflow` is already imported at module level before
    # this runs, so native messages emitted before this point are not
    # affected by TF_CPP_MIN_LOG_LEVEL — confirm whether it should be set
    # before the import instead.
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
    tf.get_logger().setLevel(logging.ERROR)
    tf.autograph.set_verbosity(0)

    # 2. Force TF 1.x behavior when the compat.v1 API is available.
    if hasattr(tf, 'compat') and hasattr(tf.compat, 'v1'):
        tf.compat.v1.disable_v2_behavior()

    # 3. Keras bridge for Mask R-CNN.
    try:
        from tensorflow.python.keras import engine as KE
    except ImportError:
        # NOTE(review): unverified whether this generated-API path exposes
        # `engine` on any TF release — confirm the fallback actually works.
        from tensorflow.python.keras.api._v1.keras import engine as KE
    import tensorflow.python.keras as keras

    sys.modules['keras'] = keras
    sys.modules['keras.engine'] = KE
    # Presumably mrcnn/model.py does `import keras.backend as K` (as upstream
    # Matterport code does) — verify. Without an explicit entry, Python would
    # find backend.py via keras.__path__ and execute it again as a separate
    # module named 'keras.backend', duplicating its global state.
    sys.modules['keras.backend'] = keras.backend
    sys.modules['keras.layers'] = keras.layers
    sys.modules['keras.models'] = keras.models
def patch_model_file(model_path="mrcnn/model.py"):
    """Rewrite static-shape code in a Mask R-CNN ``model.py`` for TF 1.15.

    Applies literal string substitutions that replace statically known
    dimensions (``s[1]``, ``probs.shape[0]``) with dynamic ones (``-1``,
    ``tf.shape(probs)[0]``). No replacement text contains the pattern it
    replaces, so calling this repeatedly is harmless.

    Args:
        model_path: Path of the file to patch in place.

    Returns:
        True if the file was rewritten; False if it is not an existing file
        or none of the patterns were found (e.g. it was already patched).
    """
    if not os.path.isfile(model_path):
        print("{} not found; TF 1.15 patches skipped".format(model_path))
        return False

    # newline='' keeps the file's existing line endings on the round trip;
    # an explicit encoding avoids depending on the platform default.
    with open(model_path, 'r', encoding='utf-8', newline='') as f:
        original = f.read()

    # Static-shape expressions -> dynamic-shape equivalents.
    replacements = {
        'KL.Reshape((s[1], num_classes, 4), name="mrcnn_bbox")(x)': 'KL.Reshape((-1, num_classes, 4), name="mrcnn_bbox")(x)',
        'KL.Reshape((s[1], s[2], s[3], num_classes), name="mrcnn_mask")(x)': 'KL.Reshape((-1, s[2], s[3], num_classes), name="mrcnn_mask")(x)',
        'tf.range(probs.shape[0])': 'tf.range(tf.shape(probs)[0])',
    }
    content = original
    for old, new in replacements.items():
        content = content.replace(old, new)

    # Leave the file (and its mtime) untouched when nothing matched.
    if content == original:
        print("No TF 1.15 patch targets found in {} (already patched?)".format(model_path))
        return False

    with open(model_path, 'w', encoding='utf-8', newline='') as f:
        f.write(content)
    print("TF 1.15 patches applied to {}".format(model_path))
    return True