---
license: mit
---
# LegNet - Cell Type Specific Models

LegNet model with a separate set of trained weights for each supported cell type.

## Available Cell Types:
- `hepg2` - HepG2 cell line
- `k562` - K562 cell line  
- `wtc11` - WTC11 cell line
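
Each identifier above corresponds to one weight file in the repository. As a minimal sketch, a lookup like the following can normalize a user-supplied name and fail with a clear message for unsupported cell types. The file paths are the ones used in the download example in this README; the helper name `resolve_weight_file` and the case-insensitive matching are assumptions made for illustration and are not part of the repository.

```python
# Weight file paths as listed in the download example of this README.
WEIGHT_FILES = {
    "hepg2": "weights/hepg2_best_model_test1_val2.safetensors",
    "k562": "weights/k562_best_model_test1_val2.safetensors",
    "wtc11": "weights/wtc11_best_model_test1_val2.safetensors",
}


def resolve_weight_file(cell_type: str) -> str:
    """Hypothetical helper: map a cell type name to its weight file path."""
    # Assumption: names are matched case-insensitively, ignoring outer spaces.
    key = cell_type.strip().lower()
    if key not in WEIGHT_FILES:
        supported = ", ".join(sorted(WEIGHT_FILES))
        raise ValueError(
            f"Unknown cell type {cell_type!r}; supported: {supported}"
        )
    return WEIGHT_FILES[key]


print(resolve_weight_file("HepG2"))
```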

## Usage:

```python
from model_loader import load_cell_type_model

# Load model for HepG2
model = load_cell_type_model("hepg2")

# Load model for K562
model = load_cell_type_model("k562")
```
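
Once a model is loaded, DNA sequences have to be converted to numeric input. This README does not document the input format, so the following is only a sketch under stated assumptions: one channel per base in the order A, C, G, T, with unknown bases such as `N` encoded as all zeros. The function name `one_hot_encode` is hypothetical, and the channel order and sequence length expected by the trained weights should be checked against the training code before use.

```python
from typing import List

# Assumed channel order; not confirmed by this README.
CHANNELS = "ACGT"


def one_hot_encode(seq: str) -> List[List[float]]:
    """Return a 4 x len(seq) matrix with one row per base in CHANNELS.

    Bases outside CHANNELS (for example N) become all-zero columns.
    """
    seq = seq.upper()
    return [[1.0 if base == ch else 0.0 for base in seq] for ch in CHANNELS]


encoded = one_hot_encode("ACGTN")
for channel, row in zip(CHANNELS, encoded):
    print(channel, row)
```

With PyTorch installed, `torch.tensor(encoded).unsqueeze(0)` turns this into a tensor of shape `(1, 4, length)`, which matches a four-channel input only if `in_ch` in `config.json` is 4.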

### If you want to download the weights

```python
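import json

import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# NOTE: the LegNet class must be defined or imported before this point;
# this README does not state which module provides it.


def get_device():
    """Return the CUDA device if one is available, otherwise the CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


# Load pre-trained model weights for human LegNet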
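def download_and_load_model(cell_type="k562", repo_id="Ni-os/MPRALegNet", device=None):
    # Fall back to automatic device detection when no device is given
    if device is None:
        device = get_device()

    # Download main config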
    config_path = hf_hub_download(
        repo_id=repo_id,
        filename="config.json"
    )
    
    # Load config
    with open(config_path, 'r') as f:
        config = json.load(f)
    
    # Create model
    model = LegNet(
        in_ch=config["in_ch"],
        stem_ch=config["stem_ch"],
        stem_ks=config["stem_ks"], 
        ef_ks=config["ef_ks"],
        ef_block_sizes=config["ef_block_sizes"],
        pool_sizes=config["pool_sizes"],
        resize_factor=config["resize_factor"],
        activation=torch.nn.SiLU
    ).to(device)
    
    # Determine which weight file to download
    weight_files = {
        "hepg2": "weights/hepg2_best_model_test1_val2.safetensors",
        "k562": "weights/k562_best_model_test1_val2.safetensors", 
        "wtc11": "weights/wtc11_best_model_test1_val2.safetensors"
    }
    
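    # Fail early with a clear message if the cell type is not supported
    key = cell_type.lower()
    if key not in weight_files:
        raise ValueError(
            f"Unknown cell type {cell_type!r}; expected one of {sorted(weight_files)}"
        )

    # Download weights
    weights_path = hf_hub_download(
        repo_id=repo_id,
        filename=weight_files[key]
    )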
    
    # Load weights into model
    state_dict = load_file(weights_path)
    model.load_state_dict(state_dict)
    model.eval()
    print(f"✅ Model for {cell_type} loaded!")
    return model

device = get_device()

print("Loading pre-trained model weights for human LegNet")
model_legnet = download_and_load_model("hepg2", device=device)
```