"""
Module: model_directory.py
This module provides a centralized directory for managing and accessing different model architectures
used in the TEDDY project. It defines a dictionary of supported models and their configurations,
allowing for easy integration and dynamic loading of models based on their names or paths.
Main Features:
- **model_dict**: A dictionary mapping model names to their model classes, configuration classes,
and masking keys. This lets callers switch between model architectures by name alone.
- **get_architecture**: A utility function that returns the architecture name (the first entry of the
`architectures` list) from the `config.json` file in a local model directory.
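The registry pattern behind `model_dict` can be sketched as follows. This is a minimal, hypothetical illustration. `DemoModel`, `DemoConfig` and `register_model` are stand-ins invented for the sketch; the real `model_dict` is a plain dictionary literal. The required-key check is an assumption about what a well-formed entry contains.
```python
# Hypothetical stand-ins for TeddyGConfig / TeddyGModel, so the sketch runs
# without the teddy package installed.
class DemoConfig:
    pass

class DemoModel:
    def __init__(self, config):
        self.config = config

# Assumed shape of every entry, mirroring the keys used in model_dict.
REQUIRED_KEYS = {"model_cls", "config_cls", "masking_key"}
registry = {}

def register_model(name, entry):
    # Hypothetical helper: validates an entry before adding it to the registry.
    missing = REQUIRED_KEYS - set(entry)
    if missing:
        raise ValueError(f"{name} is missing {sorted(missing)}")
    if name in registry:
        raise ValueError(f"{name} is already registered")
    registry[name] = entry

register_model(
    "DemoModel",
    {"model_cls": DemoModel, "config_cls": DemoConfig, "masking_key": "gene_ids"},
)

# Look up by name, then build the config and the model from the stored classes.
entry = registry["DemoModel"]
model = entry["model_cls"](entry["config_cls"]())
```
Because every entry has the same three keys, code that reads `model_cls`, `config_cls` or `masking_key` works unchanged for any registered architecture.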
Dependencies:
- json: For loading model configuration files.
- os: For handling file paths.
- teddy.models.teddy_g.model: For importing the `TeddyGModel`, `TeddyGConfig`, and `TeddyGModelAnalysis` classes.
Usage:
1. Access a model and its configuration from the `model_dict`:
```python
model_info = model_dict["TeddyGModel"]
model_cls = model_info["model_cls"]
config_cls = model_info["config_cls"]
```
2. Retrieve the architecture name from a model's configuration file:
```python
architecture = get_architecture(model_name_or_path)
```
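Combining the two steps gives path-based loading. The sketch below is hypothetical:

- It uses stand-in classes (`DemoModel`, `DemoConfig`) instead of the teddy ones.
- It re-implements `get_architecture` inline so that it runs standalone.
- It builds the config with default arguments. Real code would presumably load the config from the checkpoint instead.

`load_from_path` is an invented helper, not part of this module.
```python
import json
import os
import tempfile

class DemoConfig:  # hypothetical stand-in for TeddyGConfig
    pass

class DemoModel:  # hypothetical stand-in for TeddyGModel
    def __init__(self, config):
        self.config = config

demo_dict = {
    "DemoModel": {"model_cls": DemoModel, "config_cls": DemoConfig, "masking_key": "gene_ids"},
}

def get_architecture(model_name_or_path):
    # Same logic as this module's function: first entry of "architectures".
    with open(os.path.join(model_name_or_path, "config.json"), encoding="utf-8") as f:
        return json.load(f)["architectures"][0]

def load_from_path(path):
    # Hypothetical helper: dispatch on the architecture name stored on disk.
    entry = demo_dict[get_architecture(path)]
    model = entry["model_cls"](entry["config_cls"]())
    return model, entry["masking_key"]

# Write a throwaway checkpoint directory containing only config.json.
with tempfile.TemporaryDirectory() as ckpt:
    with open(os.path.join(ckpt, "config.json"), "w", encoding="utf-8") as f:
        json.dump({"architectures": ["DemoModel"]}, f)
    model, masking_key = load_from_path(ckpt)

print(type(model).__name__, masking_key)  # DemoModel gene_ids
```
An architecture name that is missing from the dictionary raises `KeyError` at the lookup step. That is a natural place to report an unsupported checkpoint.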
"""

import json
import os

from teddy.models.teddy_g.model import (
    TeddyGConfig,
    TeddyGModel,
    TeddyGModelAnalysis,
)
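
# Maps each architecture name (e.g. the value returned by get_architecture())
# to its model class, config class, and masking key.
model_dict = {
    "TeddyGModel": {"model_cls": TeddyGModel, "config_cls": TeddyGConfig, "masking_key": "gene_ids"},
    "TeddyGModelAnalysis": {
        "model_cls": TeddyGModelAnalysis,
        "config_cls": TeddyGConfig,
        "masking_key": "gene_ids",
    },
}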
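

def get_architecture(model_name_or_path):
    """Return the first entry of ``architectures`` in ``<model_name_or_path>/config.json``.

    ``model_name_or_path`` must be a local directory; this function only reads
    a file from disk and does not resolve remote model names.
    """
    with open(os.path.join(model_name_or_path, "config.json"), encoding="utf-8") as f:
        config = json.load(f)
    return config["architectures"][0]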