# chai-torch
PyTorch (eager CUDA) inference port of Chai-1, Chai Discovery's all-atom biomolecular structure prediction model.
The weights in this repository are derived from Chai Discovery's
released Chai-1 TorchScript checkpoint. The architecture, safetensors
parameter naming, and precision policy track the TorchScript reference
module-for-module, so the bundle loads into chai-torch without any
conversion step at load time.
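
Because the safetensors parameter names mirror the reference module tree, loading needs no key remapping. The checkpoint's names and the model's parameter names should match exactly. The stdlib sketch below shows that strict name check. `check_key_alignment` is a hypothetical helper for illustration and is not part of chai-torch; the parameter names in the example are also illustrative. In PyTorch, `Module.load_state_dict(..., strict=True)` performs the equivalent check and raises on missing or unexpected keys.

```python
def check_key_alignment(checkpoint_keys, model_keys):
    """Return (missing, unexpected) parameter names, both sorted.

    missing:    names the model expects but the checkpoint lacks
    unexpected: names in the checkpoint that the model does not define
    Hypothetical helper; mirrors what a strict state-dict load verifies.
    """
    ckpt = set(checkpoint_keys)
    model = set(model_keys)
    return sorted(model - ckpt), sorted(ckpt - model)


# Illustrative names only; real names come from the safetensors files.
ckpt_names = ["trunk.blocks.0.weight", "confidence_head.proj.weight"]
model_names = ["trunk.blocks.0.weight", "confidence_head.proj.weight"]
missing, unexpected = check_key_alignment(ckpt_names, model_names)
print(missing, unexpected)  # [] []
```

A conversion step would sit between these two name lists as a rename map. When the naming tracks the reference one-to-one, that map is the identity and both lists come back empty.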
## Quick start
```python
from chai_torch import ChaiTorch, featurize_fasta

model = ChaiTorch.from_pretrained("josephjojoe/chai-torch", device="cuda").eval()
ctx = featurize_fasta("input.fasta", output_dir="./out")
result = model.run_inference(ctx, recycles=3, num_samples=5, num_steps=200)

# result.coords: torch.Tensor on CUDA, shape (B, S, A, 3)
# result.confidence: pae_logits, pde_logits, plddt_logits
# result.ranking: aggregate_score, ptm, iptm, per-chain breakdowns, clashes
```
`ChaiTorch.from_pretrained` accepts either a Hugging Face repo id (as
above, resolved via `huggingface_hub`) or a local directory containing
`config.json` plus `model.safetensors` (or sharded safetensors with an
index file). The default `compute_dtype="reference"` matches the
TorchScript reference's mixed-precision policy: bf16 for the trunk and
confidence head, with the diffusion module and other preserved-fp32
parameters kept in fp32. Pass `compute_dtype="float32"` to run the
whole port in fp32.
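
The `compute_dtype` switch can be pictured as a per-parameter dtype assignment. The sketch below is illustrative only and makes several assumptions: the policy is keyed on the top-level module prefixes that appear in the shard names under Files, the trunk and confidence head run in bf16, the diffusion module stays in fp32, and a hypothetical `preserved_fp32` set stands in for the other parameters the reference keeps in fp32. Modules the description does not mention default to fp32 here, which is also an assumption. `policy_dtype` is not a chai-torch function, and dtypes are plain strings so the sketch runs without torch (they would correspond to `torch.bfloat16` and `torch.float32`).

```python
# Hypothetical sketch of a mixed-precision policy keyed on parameter-name prefixes.
# Assumption: top-level module names match the shard names (trunk, diffusion_module, ...).
BF16_MODULES = ("trunk", "confidence_head")


def policy_dtype(param_name, compute_dtype="reference", preserved_fp32=frozenset()):
    """Return "float32" or "bfloat16" for one parameter name."""
    if compute_dtype == "float32":
        return "float32"  # fp32 throughout, no exceptions
    if compute_dtype != "reference":
        raise ValueError(f"unknown compute_dtype: {compute_dtype!r}")
    if param_name in preserved_fp32:
        return "float32"  # explicitly kept in fp32 even inside a bf16 module
    top_level = param_name.split(".", 1)[0]
    if top_level in BF16_MODULES:
        return "bfloat16"
    # diffusion_module and (by assumption) every other module stay fp32.
    return "float32"


print(policy_dtype("trunk.blocks.0.weight"))             # bfloat16
print(policy_dtype("diffusion_module.proj.weight"))      # float32
print(policy_dtype("trunk.blocks.0.weight", "float32"))  # float32
```

Keying on the top-level prefix keeps the policy in one place: the same names that route parameters to shard files also decide their precision.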
## Files
| File | Size | Purpose |
|---|---|---|
| `config.json` | 2.5 KB | `ChaiConfig` dataclass tree (hyperparameters, precision policy) |
| `model.safetensors.index.json` | 205 KB | Sharded weight map |
| `model-feature_embedding.safetensors` | 4.8 MB | Input feature projections (token / pair / atom / MSA / template) |
| `model-bond_loss_input_proj.safetensors` | 2.1 KB | Bond adjacency projection |
| `model-token_embedder.safetensors` | 6.6 MB | Token input atom encoder + pair / single projections |
| `model-trunk.safetensors` | 680 MB | 48-block pairformer + MSA module + template embedder |
| `model-diffusion_module.safetensors` | 512 MB | Conditioning + 16-block diffusion transformer + atom encoder / decoder |
| `model-confidence_head.safetensors` | 59 MB | 4 pairformer blocks + pLDDT / PAE / PDE projections |

Total: ~1.2 GB float32 safetensors, ~316 M parameters.
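
`model.safetensors.index.json` is what lets a loader find each parameter in the per-module shard files listed above. The sketch below assumes the index follows the usual Hugging Face sharded-checkpoint layout, a top-level `"weight_map"` object mapping each parameter name to its shard file name, and groups names by shard so each file is opened once. `group_by_shard` is a hypothetical helper, and the parameter names in the demo index are illustrative.

```python
import json
import tempfile
from collections import defaultdict
from pathlib import Path


def group_by_shard(index_path):
    """Map shard file name -> sorted list of parameter names stored in it.

    Assumes the standard sharded-checkpoint index: {"weight_map": {name: shard_file}}.
    """
    index = json.loads(Path(index_path).read_text())
    groups = defaultdict(list)
    for name, shard in index["weight_map"].items():
        groups[shard].append(name)
    return {shard: sorted(names) for shard, names in groups.items()}


# Tiny stand-in index with illustrative parameter names.
demo = {
    "metadata": {},
    "weight_map": {
        "trunk.blocks.0.weight": "model-trunk.safetensors",
        "trunk.blocks.1.weight": "model-trunk.safetensors",
        "confidence_head.proj.weight": "model-confidence_head.safetensors",
    },
}
with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "model.safetensors.index.json"
    path.write_text(json.dumps(demo))
    for shard, names in group_by_shard(path).items():
        print(shard, names)
```

With the `safetensors` package installed, each shard in the resulting mapping could then be read with `safetensors.torch.load_file` and merged into one state dict.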
## License
Apache-2.0. This model is a derivative work of chai-lab (Apache-2.0). The weights are derived from Chai Discovery's released Chai-1 TorchScript checkpoint, distributed under their terms.
## Citation
```bibtex
@article{Chai-1,
  author = {{Chai Discovery}},
  title  = {Chai-1 Technical Report},
  year   = {2024},
  url    = {https://chaiassets.com/chai-1/paper/technical_report_v1.pdf}
}
```