|
|
--- |
|
|
license: mit |
|
|
--- |
|
|
# LegNet - Cell Type Specific Models |
|
|
|
|
|
LegNet models, with a separate set of pre-trained weights for each supported cell type.
|
|
|
|
|
## Available Cell Types: |
|
|
- `hepg2` - HepG2 cell line |
|
|
- `k562` - K562 cell line |
|
|
- `wtc11` - WTC11 cell line |
|
|
|
|
|
## Usage: |
|
|
|
|
|
```python |
|
|
from model_loader import load_cell_type_model |
|
|
|
|
|
# Each call returns its own model, so keep separate references
# if you need more than one cell type at the same time.

# Load the model for HepG2
hepg2_model = load_cell_type_model("hepg2")

# Load the model for K562
k562_model = load_cell_type_model("k562")

# Load the model for WTC11
wtc11_model = load_cell_type_model("wtc11")
|
|
``` |
|
|
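### Downloading the weights manually

The snippet below downloads `config.json` and the cell-type-specific `.safetensors` weights from the Hugging Face Hub and builds the model directly, without going through `model_loader`. It assumes `torch`, `json`, `hf_hub_download` (from `huggingface_hub`), `load_file` (from `safetensors.torch`) and the `LegNet` model class are already imported.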
|
|
|
|
|
```python |
|
|
def get_device():
    """Return the CUDA device if PyTorch reports a usable GPU, else the CPU device."""
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(device_name)
|
|
|
|
|
# Load pre-trained LegNet model weights for a human cell type
def download_and_load_model(cell_type="k562", repo_id="Ni-os/MPRALegNet", device=None):
    """Download the config and cell-type weights, then return a ready LegNet model.

    Expects ``torch``, ``hf_hub_download`` (huggingface_hub), ``load_file``
    (safetensors.torch) and the ``LegNet`` class to be importable in this scope.

    Args:
        cell_type: ``"hepg2"``, ``"k562"`` or ``"wtc11"``; matched case-insensitively.
        repo_id: Hub repository that holds ``config.json`` and the ``weights/`` files.
        device: Passed unchanged to ``model.to(...)``, e.g. the result of ``get_device()``.

    Returns:
        The LegNet model with the cell-type weights loaded, switched to eval mode.

    Raises:
        KeyError: If ``cell_type`` is not one of the supported cell types.
    """
    import json  # stdlib; imported locally so the snippet runs without extra setup

    weight_files = {
        "hepg2": "weights/hepg2_best_model_test1_val2.safetensors",
        "k562": "weights/k562_best_model_test1_val2.safetensors",
        "wtc11": "weights/wtc11_best_model_test1_val2.safetensors",
    }

    # Validate before any network access or model construction, so a typo
    # fails immediately and tells the caller which values are accepted.
    weights_filename = weight_files.get(cell_type.lower())
    if weights_filename is None:
        raise KeyError(
            f"Unknown cell type {cell_type!r}; expected one of {sorted(weight_files)}"
        )

    # The same config.json is used for every cell type; only the weights differ.
    config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)

    model = LegNet(
        in_ch=config["in_ch"],
        stem_ch=config["stem_ch"],
        stem_ks=config["stem_ks"],
        ef_ks=config["ef_ks"],
        ef_block_sizes=config["ef_block_sizes"],
        pool_sizes=config["pool_sizes"],
        resize_factor=config["resize_factor"],
        activation=torch.nn.SiLU,
    ).to(device)

    weights_path = hf_hub_download(repo_id=repo_id, filename=weights_filename)
    state_dict = load_file(weights_path)
    model.load_state_dict(state_dict)
    model.eval()

    print(f"✅ Model for {cell_type} loaded!")
    return model
|
|
|
|
|
# Use the GPU when one is available, otherwise run on the CPU.
device = get_device()

print("Loading pre-trained model weights for Human Legnet")
model_legnet = download_and_load_model(cell_type="hepg2", device=device)