Ni-os commited on
Commit
f00e091
·
verified ·
1 Parent(s): 8293c43

Create model_loader.py

Browse files
Files changed (1) hide show
  1. model_loader.py +97 -0
model_loader.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from huggingface_hub import hf_hub_download
2
+ from safetensors.torch import load_file
3
+ import torch
4
+ import json
5
+ from model import LegNet # Import your model
6
+
7
+ class CellTypeModelLoader:
8
+ def __init__(self, repo_id="Ni-os/Human_Legnet"):
9
+ self.repo_id = repo_id
10
+ self.available_cell_types = {
11
+ "hepg2": "cell_type_configs/hepg2_config.json",
12
+ "k562": "cell_type_configs/k562_config.json",
13
+ "wtc11": "cell_type_configs/wtc11_config.json"
14
+ }
15
+
16
+ def get_available_cell_types(self):
17
+ """Returns list of available cell types"""
18
+ return list(self.available_cell_types.keys())
19
+
20
+ def get_device(self):
21
+ """Automatically check available devices"""
22
+ if torch.cuda.is_available():
23
+ return torch.device("cuda")
24
+ else:
25
+ return torch.device("cpu")
26
+
27
+ def load_model(self, cell_type, model_config=None, device = None):
28
+ """
29
+ Loads model for specified cell type
30
+
31
+ Args:
32
+ cell_type (str): one of ['hepg2', 'k562', 'wtc11']
33
+ model_config (dict): optional custom model parameters
34
+ """
35
+
36
+ if device is None:
37
+ device = self.get_device()
38
+
39
+ # Check if cell type is available
40
+ if cell_type.lower() not in self.available_cell_types:
41
+ available = self.get_available_cell_types()
42
+ raise ValueError(f"Cell type '{cell_type}' not found. Available: {available}")
43
+
44
+ # Load main model config
45
+ if model_config is None:
46
+ config_path = hf_hub_download(
47
+ repo_id=self.repo_id,
48
+ filename="config.json"
49
+ )
50
+ with open(config_path, 'r') as f:
51
+ model_config = json.load(f)
52
+
53
+ # Create model
54
+ model = LegNet(
55
+ in_ch=model_config["in_ch"],
56
+ stem_ch=model_config["stem_ch"],
57
+ stem_ks=model_config["stem_ks"],
58
+ ef_ks=model_config["ef_ks"],
59
+ ef_block_sizes=model_config["ef_block_sizes"],
60
+ pool_sizes=model_config["pool_sizes"],
61
+ resize_factor=model_config["resize_factor"],
62
+ activation=torch.nn.SiLU)
63
+ ).to(device)
64
+
65
+ # Load cell type specific config
66
+ cell_config_path = hf_hub_download(
67
+ repo_id=self.repo_id,
68
+ filename=self.available_cell_types[cell_type.lower()]
69
+ )
70
+
71
+ with open(cell_config_path, 'r') as f:
72
+ cell_config = json.load(f)
73
+
74
+ # Load weights
75
+ weights_path = hf_hub_download(
76
+ repo_id=self.repo_id,
77
+ filename=cell_config["weights_file"]
78
+ )
79
+
80
+ # Load state_dict
81
+ state_dict = load_file(weights_path)
82
+ model.load_state_dict(state_dict)
83
+
84
+ print(f"✅ Loaded model for {cell_config['cell_type']} cell type")
85
+ return model
86
+
87
+ # Convenience function for easy usage
88
+ def load_cell_type_model(cell_type, repo_id="Ni-os/Human_Legnet", **kwargs):
89
+ """
90
+ Simple function to load model by cell type
91
+
92
+ Example:
93
+ model = load_cell_type_model("hepg2")
94
+ model = load_cell_type_model("k562", repo_id="Ni-os/Human_Legnet")
95
+ """
96
+ loader = CellTypeModelLoader(repo_id)
97
+ return loader.load_model(cell_type, **kwargs)