"""Load cell-type-specific LegNet models whose files are hosted on the Hugging Face Hub."""

import json
from typing import Optional

import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

from model import LegNet  # Import your model class


class CellTypeModelLoader:
    """Download configs and weights from a Hub repository and build a LegNet.

    Attributes:
        repo_id: Hub repository the files are downloaded from.
        available_cell_types: Maps a lower-case cell type name to the path,
            inside the repository, of that cell type's JSON config.
    """

    def __init__(self, repo_id="Ni-os/MPRALegNet"):
        # The default used to carry a trailing space ("Ni-os/MPRALegNet "),
        # which did not match the id used by ``load_cell_type_model`` below.
        self.repo_id = repo_id
        self.available_cell_types = {
            "hepg2": "cell_type_configs/hepg2_config.json",
            "k562": "cell_type_configs/k562_config.json",
            "wtc11": "cell_type_configs/wtc11_config.json",
        }

    def get_available_cell_types(self):
        """Returns list of available cell types"""
        return list(self.available_cell_types)

    def get_device(self):
        """Return the CUDA device when one is available, otherwise the CPU."""
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _load_hub_json(self, filename):
        """Download ``filename`` from ``self.repo_id`` and return its parsed JSON."""
        local_path = hf_hub_download(repo_id=self.repo_id, filename=filename)
        with open(local_path, "r", encoding="utf-8") as f:
            return json.load(f)

    def load_model(self, cell_type: str, model_config: Optional[dict] = None, device=None):
        """
        Loads model for specified cell type

        Args:
            cell_type (str): one of ['hepg2', 'k562', 'wtc11'] (case-insensitive)
            model_config (dict): optional custom model parameters. When omitted,
                ``config.json`` is downloaded from the repository. It must
                provide the keys ``in_ch``, ``stem_ch``, ``stem_ks``, ``ef_ks``,
                ``ef_block_sizes``, ``pool_sizes`` and ``resize_factor``.
            device: device the model is moved to; defaults to the result of
                ``get_device()``.

        Returns:
            The ``LegNet`` instance, moved to ``device``, with the cell type's
            weights loaded. This method does not call ``.eval()`` on it.

        Raises:
            ValueError: if ``cell_type`` is not one of the available cell types.
        """
        if device is None:
            device = self.get_device()

        # Normalise once so the membership check and the lookup cannot disagree.
        cell_key = cell_type.lower()
        if cell_key not in self.available_cell_types:
            available = self.get_available_cell_types()
            raise ValueError(f"Cell type '{cell_type}' not found. Available: {available}")

        # Load main model config
        if model_config is None:
            model_config = self._load_hub_json("config.json")

        # Create model
        model = LegNet(
            in_ch=model_config["in_ch"],
            stem_ch=model_config["stem_ch"],
            stem_ks=model_config["stem_ks"],
            ef_ks=model_config["ef_ks"],
            ef_block_sizes=model_config["ef_block_sizes"],
            pool_sizes=model_config["pool_sizes"],
            resize_factor=model_config["resize_factor"],
            activation=torch.nn.SiLU,
        ).to(device)

        # Load cell type specific config; it names the weights file to fetch.
        cell_config = self._load_hub_json(self.available_cell_types[cell_key])

        # Load weights
        weights_path = hf_hub_download(
            repo_id=self.repo_id,
            filename=cell_config["weights_file"],
        )

        # Load state_dict
        state_dict = load_file(weights_path)
        model.load_state_dict(state_dict)

        print(f"✅ Loaded model for {cell_config['cell_type']} cell type")
        return model


# Convenience function for easy usage
def load_cell_type_model(cell_type, repo_id="Ni-os/MPRALegNet", **kwargs):
    """
    Simple function to load model by cell type

    Remaining keyword arguments (``model_config``, ``device``) are forwarded
    to ``CellTypeModelLoader.load_model``.

    Example:
        model = load_cell_type_model("hepg2")
        model = load_cell_type_model("k562", repo_id="Ni-os/MPRALegNet")
    """
    loader = CellTypeModelLoader(repo_id)
    return loader.load_model(cell_type, **kwargs)