Safetensors
File size: 1,785 Bytes
4527b5f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
"""
Module: model_directory.py

This module provides a centralized directory for managing and accessing different model architectures
used in the TEDDY project. It defines a dictionary of supported models and their configurations,
allowing for easy integration and dynamic loading of models based on their names or paths.

Main Features:
- **model_dict**: A dictionary mapping model names to their corresponding classes, configurations,
  and masking keys. This enables seamless switching between different model architectures.
- **get_architecture**: A utility function that reads `config.json` from a model directory and returns the first entry of its `architectures` list.

Dependencies:
- json: For loading model configuration files.
- os: For handling file paths.
- teddy.models.teddy_g.model: For importing the `TeddyGModel`, `TeddyGConfig`, and `TeddyGModelAnalysis` classes.

Usage:
1. Access a model and its configuration from the `model_dict`:
   ```python
   model_info = model_dict["TeddyGModel"]
   model_cls = model_info["model_cls"]
   config_cls = model_info["config_cls"]
   ```
2. Retrieve the architecture name from a model's configuration file:
   ```python
   architecture = get_architecture(model_name_or_path)
   ```
"""

import json
import os

from teddy.models.teddy_g.model import (
    TeddyGConfig,
    TeddyGModel,
    TeddyGModelAnalysis,
)

# Registry of supported model architectures. Each key is an architecture name
# (presumably matched against the value returned by get_architecture — confirm
# with callers) and each value describes how to build that model:
#   model_cls   - model class for this architecture
#   config_cls  - configuration class paired with the model class
#   masking_key - input key used for masking (presumably the token field that
#                 gets masked during training — TODO confirm)
# Both current entries share TeddyGConfig and "gene_ids"; only the class differs.
model_dict = {
    arch_name: {
        "model_cls": arch_cls,
        "config_cls": TeddyGConfig,
        "masking_key": "gene_ids",
    }
    for arch_name, arch_cls in (
        ("TeddyGModel", TeddyGModel),
        ("TeddyGModelAnalysis", TeddyGModelAnalysis),
    )
}


def get_architecture(model_name_or_path):
    """Return the first architecture name listed in a model's ``config.json``.

    Args:
        model_name_or_path: Directory (str or os.PathLike) that contains a
            ``config.json`` file.

    Returns:
        The first element of the config's ``"architectures"`` list.

    Raises:
        FileNotFoundError: If ``config.json`` does not exist in the directory.
        json.JSONDecodeError: If ``config.json`` is not valid JSON.
        KeyError: If the config has no ``"architectures"`` entry.
        IndexError: If the ``"architectures"`` entry is empty.
    """
    config_path = os.path.join(model_name_or_path, "config.json")
    # JSON text is UTF-8; don't let the platform's locale encoding decide.
    with open(config_path, encoding="utf-8") as f:
        config = json.load(f)

    # Keep the original exception types (KeyError / IndexError) so existing
    # callers' handlers still work, but say which file was at fault.
    try:
        architectures = config["architectures"]
    except KeyError:
        raise KeyError(f"'architectures' not found in {config_path}") from None
    if not architectures:
        raise IndexError(f"'architectures' is empty in {config_path}")
    return architectures[0]