bryan7264 commited on
Commit
4f7cbb1
·
verified ·
1 Parent(s): f62d4e5

readme: reference .ptw weights

Browse files
Files changed (1) hide show
  1. README.md +3 -3
README.md CHANGED
@@ -69,8 +69,8 @@ Trained PIVOT checkpoints live under `models/`:
69
  - `models/norman/` - trained on Norman 2019 (CRISPRa K562)
70
  - `models/replogle_k562/` - trained on Replogle 2022 (CRISPRi K562)
71
 
72
- each folder has `model.pt` (state dict), `config.json` (the training config), and
73
- `train_info.json` (history + run info). loading needs the matching preprocessed dataset,
74
  since the perturbation encoder vocabulary comes from the data:
75
 
76
  ```python
@@ -81,7 +81,7 @@ from src.training.train import TrainConfig, make_model
81
  cfg = TrainConfig(**json.load(open("models/norman/config.json")))
82
  data = load_dataset(cfg.dataset)
83
  model = make_model(data, cfg, device="cpu")
84
- model.load_state_dict(torch.load("models/norman/model.pt", map_location="cpu"))
85
  model.eval()
86
  ```
87
 
 
69
  - `models/norman/` - trained on Norman 2019 (CRISPRa K562)
70
  - `models/replogle_k562/` - trained on Replogle 2022 (CRISPRi K562)
71
 
72
+ each folder has `model.ptw` (a plain torch state dict), `config.json` (the training config),
73
+ and `train_info.json` (history + run info). loading needs the matching preprocessed dataset,
74
  since the perturbation encoder vocabulary comes from the data:
75
 
76
  ```python
 
81
  cfg = TrainConfig(**json.load(open("models/norman/config.json")))
82
  data = load_dataset(cfg.dataset)
83
  model = make_model(data, cfg, device="cpu")
84
+ model.load_state_dict(torch.load("models/norman/model.ptw", map_location="cpu"))
85
  model.eval()
86
  ```
87