readme: reference .ptw weights
Browse files
README.md
CHANGED
|
@@ -69,8 +69,8 @@ Trained PIVOT checkpoints live under `models/`:
|
|
| 69 |
- `models/norman/` - trained on Norman 2019 (CRISPRa K562)
|
| 70 |
- `models/replogle_k562/` - trained on Replogle 2022 (CRISPRi K562)
|
| 71 |
|
| 72 |
-
each folder has `model.
|
| 73 |
-
`train_info.json` (history + run info). loading needs the matching preprocessed dataset,
|
| 74 |
since the perturbation encoder vocabulary comes from the data:
|
| 75 |
|
| 76 |
```python
|
|
@@ -81,7 +81,7 @@ from src.training.train import TrainConfig, make_model
|
|
| 81 |
cfg = TrainConfig(**json.load(open("models/norman/config.json")))
|
| 82 |
data = load_dataset(cfg.dataset)
|
| 83 |
model = make_model(data, cfg, device="cpu")
|
| 84 |
-
model.load_state_dict(torch.load("models/norman/model.
|
| 85 |
model.eval()
|
| 86 |
```
|
| 87 |
|
|
|
|
| 69 |
- `models/norman/` - trained on Norman 2019 (CRISPRa K562)
|
| 70 |
- `models/replogle_k562/` - trained on Replogle 2022 (CRISPRi K562)
|
| 71 |
|
| 72 |
+
each folder has `model.ptw` (a plain torch state dict), `config.json` (the training config),
|
| 73 |
+
and `train_info.json` (history + run info). loading needs the matching preprocessed dataset,
|
| 74 |
since the perturbation encoder vocabulary comes from the data:
|
| 75 |
|
| 76 |
```python
|
|
|
|
| 81 |
cfg = TrainConfig(**json.load(open("models/norman/config.json")))
|
| 82 |
data = load_dataset(cfg.dataset)
|
| 83 |
model = make_model(data, cfg, device="cpu")
|
| 84 |
+
model.load_state_dict(torch.load("models/norman/model.ptw", map_location="cpu"))
|
| 85 |
model.eval()
|
| 86 |
```
|
| 87 |
|