chai-torch

PyTorch (eager CUDA) inference port of Chai-1, Chai Discovery's all-atom biomolecular structure prediction model.

The weights in this repository are derived from Chai Discovery's released Chai-1 TorchScript checkpoint. The architecture, safetensors parameter naming, and precision policy track the TorchScript reference module-for-module, so the bundle loads into chai-torch without any conversion step at load time.
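Because the parameter naming tracks the reference module for module, one quick sanity check before loading is to compare the two sets of parameter names. The sketch below assumes you already have both name lists as strings. For example, they could come from a TorchScript module's `state_dict()` keys and from the `keys()` of a `safetensors.safe_open` handle. `diff_param_names` and `loads_without_conversion` are hypothetical helpers for illustration, not part of chai-torch.

```python
# Hypothetical helpers, not part of chai-torch: compare parameter names from the
# TorchScript reference against the names stored in the safetensors bundle.
# This is a name check only; tensor shapes and dtypes are not compared.
from typing import Iterable, NamedTuple


class NameDiff(NamedTuple):
    missing: list[str]     # present in the reference, absent from the bundle
    unexpected: list[str]  # present in the bundle, absent from the reference


def diff_param_names(reference: Iterable[str], bundle: Iterable[str]) -> NameDiff:
    ref, got = set(reference), set(bundle)
    return NameDiff(sorted(ref - got), sorted(got - ref))


def loads_without_conversion(reference: Iterable[str], bundle: Iterable[str]) -> bool:
    # True only when every name matches one for one, i.e. no renaming is needed.
    diff = diff_param_names(reference, bundle)
    return not diff.missing and not diff.unexpected
```

An empty `missing` and `unexpected` list is what "no conversion step at load time" implies at the naming level.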

Quick start

from chai_torch import ChaiTorch, featurize_fasta

model = ChaiTorch.from_pretrained("josephjojoe/chai-torch", device="cuda").eval()
ctx = featurize_fasta("input.fasta", output_dir="./out")
result = model.run_inference(ctx, recycles=3, num_samples=5, num_steps=200)
# result.coords:     torch.Tensor on CUDA, shape (B, S, A, 3)
# result.confidence: pae_logits, pde_logits, plddt_logits
# result.ranking:    aggregate_score, ptm, iptm, per-chain breakdowns, clashes
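With `num_samples=5`, each input yields several candidate structures. The following is a minimal sketch for ordering them. It assumes the per-sample `aggregate_score` values have already been extracted as one float per sample (for example with `.tolist()` on a 1-D tensor) and that a higher score means a better prediction. `rank_samples` is a hypothetical helper, not part of chai-torch.

```python
# Hypothetical helper, not part of chai-torch. Assumes one aggregate score per
# sample and that higher scores indicate better predictions.
from typing import Sequence


def rank_samples(aggregate_scores: Sequence[float]) -> list[int]:
    """Return sample indices ordered from best to worst (highest score first)."""
    # sorted() stays stable with reverse=True, so tied samples keep their original order.
    return sorted(range(len(aggregate_scores)),
                  key=lambda i: aggregate_scores[i],
                  reverse=True)


print(rank_samples([0.61, 0.83, 0.47, 0.83, 0.70]))  # [1, 3, 4, 0, 2]
```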

ChaiTorch.from_pretrained accepts either a Hugging Face Hub repo id (as in the quick start, fetched via huggingface_hub) or a local directory. The local directory must contain config.json plus either a single model.safetensors file or sharded safetensors files with an index file (model.safetensors.index.json). The default compute_dtype="reference" reproduces the TorchScript reference's mixed-precision policy: bf16 for the trunk and confidence head, and fp32 for the diffusion module and for the other parameters the reference keeps in fp32. Pass compute_dtype="float32" to run the port in fp32 throughout.
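To make the local-directory case concrete, here is a sketch of how the required files could be resolved. It follows the common Hugging Face layout: either a single `model.safetensors`, or a `model.safetensors.index.json` whose `weight_map` maps each parameter name to its shard file. `resolve_weight_files` is a hypothetical helper for illustration and is not chai-torch's actual loader.

```python
# Hypothetical helper, not chai-torch's loader: lists the weight files a local
# checkpoint directory provides, following the Hugging Face single-file or
# sharded-with-index conventions.
import json
from pathlib import Path


def resolve_weight_files(checkpoint_dir: str) -> list[Path]:
    root = Path(checkpoint_dir)
    if not (root / "config.json").is_file():
        raise FileNotFoundError(f"{root} has no config.json")

    # Case 1: a single unsharded weight file.
    single = root / "model.safetensors"
    if single.is_file():
        return [single]

    # Case 2: sharded weights described by an index file.
    index = root / "model.safetensors.index.json"
    if not index.is_file():
        raise FileNotFoundError(f"{root} has neither model.safetensors nor an index file")

    # The index maps each parameter name to the shard that stores it; many
    # parameters share a shard, so deduplicate before checking the files exist.
    weight_map = json.loads(index.read_text())["weight_map"]
    shards = [root / name for name in sorted(set(weight_map.values()))]
    missing = [p.name for p in shards if not p.is_file()]
    if missing:
        raise FileNotFoundError(f"shards listed in the index are missing: {missing}")
    return shards
```

For a Hub repo id, `huggingface_hub.snapshot_download(repo_id="josephjojoe/chai-torch")` downloads the repository and returns its local path, which could then be passed to a helper like this one.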

Files

| File | Size | Purpose |
|---|---|---|
| config.json | 2.5 KB | ChaiConfig dataclass tree (hyperparameters, precision policy) |
| model.safetensors.index.json | 205 KB | Sharded weight map |
| model-feature_embedding.safetensors | 4.8 MB | Input feature projections (token / pair / atom / MSA / template) |
| model-bond_loss_input_proj.safetensors | 2.1 KB | Bond adjacency projection |
| model-token_embedder.safetensors | 6.6 MB | Token input atom encoder + pair / single projections |
| model-trunk.safetensors | 680 MB | 48-block pairformer + MSA module + template embedder |
| model-diffusion_module.safetensors | 512 MB | Conditioning + 16-block diffusion transformer + atom encoder / decoder |
| model-confidence_head.safetensors | 59 MB | 4 pairformer blocks + pLDDT / PAE / PDE projections |

Total: ~1.2 GB float32 safetensors, ~316 M parameters.
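The two totals can be cross-checked with simple arithmetic. Summing the listed shard sizes and dividing by 4 bytes per float32 value should recover roughly the stated parameter count. This sketch assumes the sizes in the Files table use decimal units (1 MB = 10^6 bytes) and ignores the small safetensors headers and the rounding of the listed sizes, so the result is only approximate.

```python
# Back-of-the-envelope check of the Files table. Assumes decimal units
# (1 MB = 1e6 bytes, 1 KB = 1e3 bytes) and float32 storage (4 bytes per value);
# safetensors header bytes are ignored.
SHARD_BYTES = {
    "model-feature_embedding.safetensors": 4.8e6,
    "model-bond_loss_input_proj.safetensors": 2.1e3,
    "model-token_embedder.safetensors": 6.6e6,
    "model-trunk.safetensors": 680e6,
    "model-diffusion_module.safetensors": 512e6,
    "model-confidence_head.safetensors": 59e6,
}
BYTES_PER_FP32 = 4

total_bytes = sum(SHARD_BYTES.values())
implied_params = total_bytes / BYTES_PER_FP32

print(f"total ~ {total_bytes / 1e9:.2f} GB")             # total ~ 1.26 GB
print(f"implied params ~ {implied_params / 1e6:.0f} M")  # implied params ~ 316 M
```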

License

Apache-2.0. This model is a derivative work of chai-lab (Apache-2.0). The weights are derived from Chai Discovery's released Chai-1 TorchScript checkpoint and are distributed under Chai Discovery's terms.

Citation

@article{Chai-1,
  author = {{Chai Discovery}},
  title = {Chai-1 Technical Report},
  year = {2024},
  url = {https://chaiassets.com/chai-1/paper/technical_report_v1.pdf}
}