"""
Module: model_directory.py
This module provides a central registry of the model architectures used in the TEDDY project.
It maps each supported architecture name to its model class, configuration class, and masking key,
so that a model can be selected dynamically from its architecture name or from the `config.json`
stored in its model directory.
Main Features:
- **model_dict**: A dictionary mapping each architecture name to its model class (`model_cls`),
  configuration class (`config_cls`), and `masking_key`, the input key to which masking is applied
  (`"gene_ids"` for every current entry). This makes it possible to switch between architectures by name.
- **get_architecture**: A utility function that reads `config.json` from a local model directory
  and returns the first entry of its `"architectures"` list.
Dependencies:
- json: For loading model configuration files.
- os: For handling file paths.
- teddy.models.teddy_g.model: For importing the `TeddyGModel`, `TeddyGConfig`, and `TeddyGModelAnalysis` classes.
Usage:
1. Look up a model class and its configuration class in `model_dict`:
```python
model_info = model_dict["TeddyGModel"]
model_cls = model_info["model_cls"]
config_cls = model_info["config_cls"]
```
2. Retrieve the architecture name from the `config.json` in a local model directory:
```python
architecture = get_architecture(model_name_or_path)
```
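3. Load a model dynamically from a local directory. The sketch below is a hypothetical, self-contained
illustration, not the project's loader: `StandInConfig`, `StandInModel`, `registry`, `load_from_directory`,
and the `n_layer` field are invented stand-ins so it runs without the `teddy` package, and it assumes a
config class can be built by passing the `config.json` fields as keyword arguments. The real TEDDY classes
may provide their own loading methods, which this sketch does not rely on:
```python
import json
import os
import tempfile


# Hypothetical stand-ins for TeddyGConfig / TeddyGModel so the sketch runs
# without the teddy package installed.
class StandInConfig:
    def __init__(self, **kwargs):
        # Store every config.json field as an attribute.
        self.__dict__.update(kwargs)


class StandInModel:
    def __init__(self, config):
        self.config = config


# Same shape as model_dict: architecture name -> classes and masking key.
registry = {
    "TeddyGModel": {"model_cls": StandInModel, "config_cls": StandInConfig, "masking_key": "gene_ids"},
}


def get_architecture(model_name_or_path):
    # Mirrors the module's helper: first entry of "architectures" in config.json.
    with open(os.path.join(model_name_or_path, "config.json"), encoding="utf-8") as f:
        return json.load(f)["architectures"][0]


def load_from_directory(model_dir):
    # Hypothetical helper: resolve the architecture name, look it up in the
    # registry, then build the config and the model from config.json.
    info = registry[get_architecture(model_dir)]
    with open(os.path.join(model_dir, "config.json"), encoding="utf-8") as f:
        config = info["config_cls"](**json.load(f))
    return info["model_cls"](config), info["masking_key"]


with tempfile.TemporaryDirectory() as model_dir:
    # Write a minimal config.json; "n_layer" is an invented field.
    with open(os.path.join(model_dir, "config.json"), "w", encoding="utf-8") as f:
        json.dump({"architectures": ["TeddyGModel"], "n_layer": 2}, f)
    model, masking_key = load_from_directory(model_dir)
    print(type(model).__name__, model.config.n_layer, masking_key)
```
In the actual module, `model_dict` plays the role of `registry`, so the lookup step is the same
dictionary access shown in step 1.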
"""

import json
import os

from teddy.models.teddy_g.model import (
    TeddyGConfig,
    TeddyGModel,
    TeddyGModelAnalysis,
)

# Keys are architecture names, intended to match the value returned by get_architecture().
model_dict = {
    "TeddyGModel": {
        "model_cls": TeddyGModel,
        "config_cls": TeddyGConfig,
        "masking_key": "gene_ids",
    },
    "TeddyGModelAnalysis": {
        "model_cls": TeddyGModelAnalysis,
        "config_cls": TeddyGConfig,
        "masking_key": "gene_ids",
    },
}
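

def get_architecture(model_name_or_path):
    """Return the first entry of "architectures" in the local directory's config.json."""
    # The path is opened directly, so it must be a local directory, not a Hub model ID.
    with open(os.path.join(model_name_or_path, "config.json"), encoding="utf-8") as f:
        config = json.load(f)
    return config["architectures"][0]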