| | from huggingface_hub import hf_hub_download |
| | from safetensors.torch import load_file |
| | import torch |
| | import json |
| | from model import LegNet |
| |
|
class CellTypeModelLoader:
    """Download and build cell-type-specific LegNet models from a Hugging Face Hub repo.

    Based on the paths this class requests, the repo holds three kinds of files:
    - a shared ``config.json`` with the LegNet architecture parameters;
    - one JSON config per cell type under ``cell_type_configs/``;
    - the safetensors weight files named by each cell config's ``"weights_file"`` key.
    """

    # Architecture parameters LegNet is constructed from. Any model_config
    # (downloaded or user-supplied) must provide all of them.
    _REQUIRED_CONFIG_KEYS = (
        "in_ch",
        "stem_ch",
        "stem_ks",
        "ef_ks",
        "ef_block_sizes",
        "pool_sizes",
        "resize_factor",
    )

    def __init__(self, repo_id="Ni-os/Human_Legnet"):
        """Create a loader bound to ``repo_id``.

        Args:
            repo_id (str): Hugging Face Hub repository holding configs and weights.
        """
        self.repo_id = repo_id
        # Lower-case cell-type key -> path of that cell type's JSON config in the repo.
        self.available_cell_types = {
            "hepg2": "cell_type_configs/hepg2_config.json",
            "k562": "cell_type_configs/k562_config.json",
            "wtc11": "cell_type_configs/wtc11_config.json",
        }

    def get_available_cell_types(self):
        """Return the list of supported cell-type keys (lower case)."""
        return list(self.available_cell_types.keys())

    def get_device(self):
        """Return ``torch.device('cuda')`` if CUDA is available, else ``torch.device('cpu')``."""
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _load_json(self, filename):
        """Download ``filename`` from ``self.repo_id`` and return its parsed JSON content."""
        path = hf_hub_download(repo_id=self.repo_id, filename=filename)
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)

    def load_model(self, cell_type, model_config=None, device=None):
        """Build a LegNet model and load the pretrained weights for ``cell_type``.

        Args:
            cell_type (str): one of ``get_available_cell_types()``
                (``'hepg2'``, ``'k562'``, ``'wtc11'``), matched case-insensitively.
            model_config (dict, optional): custom LegNet architecture parameters.
                Must contain every key in ``_REQUIRED_CONFIG_KEYS``.
                Defaults to the repo's ``config.json``.
                It must describe the architecture the weights were saved from,
                otherwise ``load_state_dict`` will reject the weights.
            device (torch.device or str, optional): device to place the model on.
                Defaults to ``get_device()``.

        Returns:
            LegNet: the model on ``device`` with the cell-type weights loaded.
            This method does not call ``model.eval()``; do so before inference
            if the architecture has train/eval-dependent layers.

        Raises:
            ValueError: if ``cell_type`` is not a supported cell type.
            KeyError: if ``model_config`` lacks a required architecture key.
        """
        key = cell_type.lower()
        # Validate first so bad input fails before any device or network work.
        if key not in self.available_cell_types:
            available = self.get_available_cell_types()
            raise ValueError(f"Cell type '{cell_type}' not found. Available: {available}")

        if device is None:
            device = self.get_device()

        if model_config is None:
            model_config = self._load_json("config.json")

        # Report every missing key at once instead of failing on the first lookup.
        missing = [k for k in self._REQUIRED_CONFIG_KEYS if k not in model_config]
        if missing:
            raise KeyError(f"model_config is missing required keys: {missing}")

        model = LegNet(
            in_ch=model_config["in_ch"],
            stem_ch=model_config["stem_ch"],
            stem_ks=model_config["stem_ks"],
            ef_ks=model_config["ef_ks"],
            ef_block_sizes=model_config["ef_block_sizes"],
            pool_sizes=model_config["pool_sizes"],
            resize_factor=model_config["resize_factor"],
            activation=torch.nn.SiLU,
        ).to(device)

        cell_config = self._load_json(self.available_cell_types[key])
        weights_path = hf_hub_download(
            repo_id=self.repo_id,
            filename=cell_config["weights_file"],
        )

        # load_file returns CPU tensors by default. load_state_dict copies them
        # into the model's existing parameters, which already live on `device`.
        state_dict = load_file(weights_path)
        model.load_state_dict(state_dict)

        print(f"✅ Loaded model for {cell_config.get('cell_type', key)} cell type")
        return model
| |
|
| | |
def load_cell_type_model(cell_type, repo_id="Ni-os/Human_Legnet", **kwargs):
    """Build a ``CellTypeModelLoader`` for ``repo_id`` and load one cell-type model.

    Args:
        cell_type (str): cell-type key accepted by ``CellTypeModelLoader.load_model``
            (e.g. ``'hepg2'``, ``'k562'``, ``'wtc11'``).
        repo_id (str): Hugging Face Hub repository to download configs and weights from.
        **kwargs: forwarded unchanged to ``CellTypeModelLoader.load_model``
            (``model_config``, ``device``).

    Returns:
        Whatever ``CellTypeModelLoader.load_model`` returns for these arguments.

    Example:
        model = load_cell_type_model("hepg2")
        model = load_cell_type_model("k562", repo_id="Ni-os/Human_Legnet")
    """
    return CellTypeModelLoader(repo_id).load_model(cell_type, **kwargs)