| | """
|
| | Utility to verify model dimensions across the codebase
|
| | """
|
import importlib.util
import json
import logging
import os
import re
|
| |
|
| |
|
# Configure logging at import time so every check below reports its
# findings as "LEVEL: message".
logging.basicConfig(
    format='%(levelname)s: %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(name=__name__)
|
| |
|
def check_config_json(config_path=None, *, expected_dim=768, expected_heads=12):
    """Check the transformer dimensions declared in config.json.

    Reads the ``TRANSFORMER_CONFIG`` section and compares its ``EMBEDDING_DIM``,
    ``HIDDEN_DIM`` and ``NUM_HEADS`` entries (each defaulting to 0 when absent)
    against the expected values.

    Args:
        config_path: JSON file to inspect. Defaults to ``config.json`` in the
            same directory as this script.
        expected_dim: Required value for both the embedding and hidden size.
        expected_heads: Required number of attention heads.

    Returns:
        True if the section exists and every value matches; False otherwise,
        including when the file is missing, unreadable, not valid JSON, or has
        no ``TRANSFORMER_CONFIG`` section (previously that case returned None).
    """
    if config_path is None:
        config_path = os.path.join(os.path.dirname(__file__), "config.json")

    try:
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
    except (OSError, ValueError) as e:  # ValueError covers JSONDecodeError / bad encoding
        logger.error(f"Error checking config.json: {e}")
        return False

    tc = config.get("TRANSFORMER_CONFIG") if isinstance(config, dict) else None
    if not isinstance(tc, dict):
        # Explicitly fail instead of falling through and returning None.
        logger.warning("config.json has no TRANSFORMER_CONFIG section")
        return False

    emb_dim = tc.get("EMBEDDING_DIM", 0)
    hidden_dim = tc.get("HIDDEN_DIM", 0)
    num_heads = tc.get("NUM_HEADS", 0)

    logger.info(f"config.json dimensions: embedding={emb_dim}, hidden={hidden_dim}, heads={num_heads}")

    if (emb_dim, hidden_dim, num_heads) != (expected_dim, expected_dim, expected_heads):
        logger.warning(
            f"config.json has non-standard dimensions! "
            f"Should be {expected_dim}/{expected_dim}/{expected_heads}"
        )
        return False
    return True
|
| |
|
def _dimension_pattern(key, value, dict_style):
    """Return a regex matching ``key`` being set to the integer ``value``.

    With ``dict_style`` the pattern targets quoted dict-literal entries such as
    ``"embedding_dim": 768`` (single or double quotes); otherwise it targets
    keyword/assignment form such as ``embedding_dim=768``. Word boundaries stop
    names like ``max_hidden_dim`` or numbers like ``2560`` from matching.
    """
    name = re.escape(key)
    if dict_style:
        return rf"[\"']{name}[\"']\s*:\s*{value}\b"
    return rf"\b{name}\s*=\s*{value}\b"


def _check_source_dimensions(path, filename, dict_style, expected=768, legacy=256):
    """Scan a source file for its ``embedding_dim``/``hidden_dim`` settings.

    Args:
        path: File to read.
        filename: Name used in log messages.
        dict_style: True for ``"key": N`` entries, False for ``key=N``.
        expected: Dimension both keys must be set to.
        legacy: Outdated dimension that fails the check if either key uses it.

    Returns:
        True only when neither key is set to ``legacy`` and both keys are set
        to ``expected``; False otherwise (including read errors).
    """
    try:
        with open(path, "r", encoding="utf-8") as f:
            content = f.read()
    except (OSError, ValueError) as e:  # ValueError covers undecodable bytes
        logger.error(f"Error checking {filename}: {e}")
        return False

    def _flags(value):
        # One flag per key: is that key set to ``value`` anywhere in the file?
        return [
            re.search(_dimension_pattern(key, value, dict_style), content) is not None
            for key in ("embedding_dim", "hidden_dim")
        ]

    # Any legacy setting wins over correct ones, as in the original checks.
    if any(_flags(legacy)):
        logger.warning(f"{filename} contains {legacy} dimensions! Update to {expected}")
        return False
    if all(_flags(expected)):
        logger.info(f"{filename} has correct {expected} dimensions")
        return True
    logger.warning(f"Could not determine dimensions in {filename}")
    return False


def check_adapter_layer(path=None):
    """Check dimensions in adapter_layer.py, which uses ``"embedding_dim": N`` entries.

    Args:
        path: File to inspect; defaults to ``adapter_layer.py`` next to this script.

    Returns:
        True if both dimensions are 768 and none are 256, else False.
    """
    if path is None:
        path = os.path.join(os.path.dirname(__file__), "adapter_layer.py")
    return _check_source_dimensions(path, "adapter_layer.py", dict_style=True)


def check_model_manager(path=None):
    """Check dimensions in model_manager.py, which uses ``embedding_dim=N`` arguments.

    Args:
        path: File to inspect; defaults to ``model_manager.py`` next to this script.

    Returns:
        True if both dimensions are 768 and none are 256, else False.
    """
    if path is None:
        path = os.path.join(os.path.dirname(__file__), "model_manager.py")
    return _check_source_dimensions(path, "model_manager.py", dict_style=False)


def check_main_py(path=None):
    """Check dimensions in main.py, which uses ``embedding_dim=N`` arguments.

    Args:
        path: File to inspect; defaults to ``main.py`` next to this script.

    Returns:
        True if both dimensions are 768 and none are 256, else False.
    """
    if path is None:
        path = os.path.join(os.path.dirname(__file__), "main.py")
    return _check_source_dimensions(path, "main.py", dict_style=False)
|
| |
|
def verify_all_dimensions(checks=None):
    """Run every dimension check and print a summary table.

    Args:
        checks: Optional mapping of display name -> zero-argument callable that
            returns a truthy value when that file is correct. Defaults to the
            built-in checks for config.json, adapter_layer.py,
            model_manager.py and main.py.

    Returns:
        True if every check passed, False otherwise (always a bool).
    """
    if checks is None:
        checks = {
            "config.json": check_config_json,
            "adapter_layer.py": check_adapter_layer,
            "model_manager.py": check_model_manager,
            "main.py": check_main_py,
        }

    # Run all checks before printing so their log output precedes the summary.
    # bool() keeps a stray None result from leaking into the return value.
    results = {name: bool(check()) for name, check in checks.items()}

    print("\n=== MODEL DIMENSION VERIFICATION ===")
    for file, correct in results.items():
        # A failed check can mean 256 dims, unknown dims, a missing file or a
        # wrong head count, so the failure label does not claim "256".
        status = "✓ CORRECT (768)" if correct else "✗ NOT VERIFIED (expected 768)"
        print(f"{file:20} : {status}")

    all_correct = all(results.values())

    print("\nOverall Status:", "✓ ALL CORRECT" if all_correct else "✗ NEEDS FIXING")
    print("\nRun this script after making changes to verify all dimensions are set to 768.\n")

    return all_correct
|
| |
|
if __name__ == "__main__":
    # Exit non-zero when any check fails so shells and CI pipelines can act on it.
    raise SystemExit(0 if verify_all_dimensions() else 1)
|
| |
|