File size: 3,377 Bytes
f00e091
 
 
 
 
 
 
893acaf
 
 
f00e091
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a8c4b27
f00e091
 
 
 
 
a8c4b27
f00e091
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import torch
import json
from model import LegNet  # Import your model

class CellTypeModelLoader:
    """Build LegNet models for individual cell types from a Hugging Face Hub repo.

    Every supported cell type maps to a JSON config file inside the repository;
    that config names the safetensors weights file that is loaded into the model.
    """

    def __init__(self, repo_id="Ni-os/MPRALegNet"):
        """
        Args:
            repo_id (str): Hugging Face Hub repository holding the configs
                and the weights.
        """
        self.repo_id = repo_id
        # Lower-case cell type name -> path of its config file inside the repo.
        self.available_cell_types = {
            "hepg2": "cell_type_configs/hepg2_config.json",
            "k562": "cell_type_configs/k562_config.json",
            "wtc11": "cell_type_configs/wtc11_config.json",
        }

    def get_available_cell_types(self):
        """Returns list of available cell types"""
        return list(self.available_cell_types.keys())

    def get_device(self):
        """Return the CUDA device when one is available, otherwise the CPU device."""
        if torch.cuda.is_available():
            return torch.device("cuda")
        return torch.device("cpu")

    def _load_json_from_hub(self, filename):
        """Download ``filename`` from ``self.repo_id`` and return its parsed JSON."""
        local_path = hf_hub_download(repo_id=self.repo_id, filename=filename)
        with open(local_path, "r", encoding="utf-8") as f:
            return json.load(f)

    def load_model(self, cell_type, model_config=None, device=None):
        """
        Loads model for specified cell type

        Args:
            cell_type (str): one of ['hepg2', 'k562', 'wtc11'] (case-insensitive)
            model_config (dict): optional custom model parameters; when omitted,
                ``config.json`` is downloaded from the repository. Must provide
                the keys ``in_ch``, ``stem_ch``, ``stem_ks``, ``ef_ks``,
                ``ef_block_sizes``, ``pool_sizes`` and ``resize_factor``.
            device: device the model is moved to; when omitted, the result of
                ``get_device()`` is used.

        Returns:
            The ``LegNet`` instance, moved to ``device``, with the cell type
            specific weights loaded.

        Raises:
            ValueError: if ``cell_type`` is not one of the available cell types.
        """
        # Normalise once and validate first, so an unknown cell type fails
        # before any device lookup or network download happens.
        cell_key = cell_type.lower()
        if cell_key not in self.available_cell_types:
            available = self.get_available_cell_types()
            raise ValueError(f"Cell type '{cell_type}' not found. Available: {available}")

        if device is None:
            device = self.get_device()

        # Load main model config
        if model_config is None:
            model_config = self._load_json_from_hub("config.json")

        # Create model
        model = LegNet(
            in_ch=model_config["in_ch"],
            stem_ch=model_config["stem_ch"],
            stem_ks=model_config["stem_ks"],
            ef_ks=model_config["ef_ks"],
            ef_block_sizes=model_config["ef_block_sizes"],
            pool_sizes=model_config["pool_sizes"],
            resize_factor=model_config["resize_factor"],
            activation=torch.nn.SiLU,
        ).to(device)

        # Load cell type specific config; it names the weights file to fetch.
        cell_config = self._load_json_from_hub(self.available_cell_types[cell_key])

        # Load weights
        weights_path = hf_hub_download(
            repo_id=self.repo_id,
            filename=cell_config["weights_file"],
        )

        # Load state_dict
        state_dict = load_file(weights_path)
        model.load_state_dict(state_dict)

        print(f"✅ Loaded model for {cell_config['cell_type']} cell type")
        return model

# Convenience function for easy usage
def load_cell_type_model(cell_type, repo_id="Ni-os/MPRALegNet", **kwargs):
    """
    One-call shortcut: build a loader for ``repo_id`` and load one cell type.

    Any extra keyword arguments are forwarded unchanged to
    ``CellTypeModelLoader.load_model``.

    Example:
        model = load_cell_type_model("hepg2")
        model = load_cell_type_model("k562", repo_id="Ni-os/MPRALegNet")
    """
    return CellTypeModelLoader(repo_id).load_model(cell_type, **kwargs)