---
license: mit
---

# LegNet - Cell Type Specific Models

LegNet models with weights trained on different cell types.

## Available Cell Types

- `hepg2` - HepG2 cell line
- `k562` - K562 cell line
- `wtc11` - WTC11 cell line

## Usage

```python
from model_loader import load_cell_type_model

# Load model for HepG2
model = load_cell_type_model("hepg2")

# Load model for K562
model = load_cell_type_model("k562")
```

### Downloading the weights directly

If you prefer to fetch the weights yourself instead of using `model_loader`, the snippet below downloads the model config and the cell-type-specific weights from the Hugging Face Hub and loads them into a `LegNet` model. The `LegNet` class itself is not defined in this snippet, so import it from wherever it lives in your code.

```python
import json

import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# NOTE: LegNet is not defined here; import it from your own code before running.


def get_device():
    """Return the CUDA device if one is available, otherwise the CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


# Cell type -> path of its weight file inside the Hub repository.
WEIGHT_FILES = {
    "hepg2": "weights/hepg2_best_model_test1_val2.safetensors",
    "k562": "weights/k562_best_model_test1_val2.safetensors",
    "wtc11": "weights/wtc11_best_model_test1_val2.safetensors",
}


# Load Pre-Trained Model Weights for Human LegNet
def download_and_load_model(cell_type="k562", repo_id="Ni-os/MPRALegNet", device=None):
    """Download pre-trained human LegNet weights for a cell type and load them.

    Args:
        cell_type: One of the keys of WEIGHT_FILES ("hepg2", "k562", "wtc11");
            matched case-insensitively.
        repo_id: Hugging Face Hub repository that holds config.json and the weights.
        device: torch device to place the model on; auto-detected with
            get_device() when None.

    Returns:
        The LegNet model with the downloaded weights loaded, in eval mode.

    Raises:
        KeyError: If cell_type is not one of the available cell types.
    """
    # Validate before downloading anything so a typo fails fast and clearly.
    key = cell_type.lower()
    if key not in WEIGHT_FILES:
        raise KeyError(
            f"Unknown cell type {cell_type!r}; expected one of {sorted(WEIGHT_FILES)}"
        )
    if device is None:
        device = get_device()

    # Download and read the model config shared by all cell types.
    config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)

    # Build the architecture described by the config.
    model = LegNet(
        in_ch=config["in_ch"],
        stem_ch=config["stem_ch"],
        stem_ks=config["stem_ks"],
        ef_ks=config["ef_ks"],
        ef_block_sizes=config["ef_block_sizes"],
        pool_sizes=config["pool_sizes"],
        resize_factor=config["resize_factor"],
        activation=torch.nn.SiLU,
    ).to(device)

    # Download the cell-type-specific weights and load them into the model.
    weights_path = hf_hub_download(repo_id=repo_id, filename=WEIGHT_FILES[key])
    state_dict = load_file(weights_path)
    model.load_state_dict(state_dict)
    model.eval()

    print(f"✅ Model for {cell_type} loaded!")
    return model


device = get_device()
print("Loading pre-trained model weights for Human Legnet")
model_legnet = download_and_load_model("hepg2", device=device)
```