|
|
|
|
|
from pathlib import Path |
|
|
import modal |
|
|
|
|
|
|
|
|
here = Path(__file__).parent  # directory containing this script

MINUTES = 60  # seconds per minute, for readable timeout arithmetic

# Container image with Chai-1 and fast Hugging Face downloads baked in.
# NOTE: the torch specifier must be quoted — unquoted, the shell parses
# ">=2.6.0" as output redirection to a file named "=2.6.0" and installs
# an unpinned torch instead.
image = modal.Image.debian_slim(python_version="3.12").run_commands(
    'uv pip install --system --compile-bytecode "torch>=2.6.0" chai_lab==0.6.1 hf_transfer==0.1.8'
)

# Persistent volume caching Chai-1 model weights across containers/runs.
chai_model_volume = modal.Volume.from_name(
    "chai1-models",
    create_if_missing=True,
)
models_dir = Path("/models/chai1")  # mount point for the weights volume

image = image.env(
    {
        "CHAI_DOWNLOADS_DIR": str(models_dir),  # where chai_lab looks for weights
        "HF_HUB_ENABLE_HF_TRANSFER": "1",  # enable accelerated HF downloads
    }
)

# Persistent volume where prediction outputs are written.
chai_preds_volume = modal.Volume.from_name("chai1-preds", create_if_missing=True)
preds_dir = Path("/preds")  # mount point for the predictions volume

# NOTE(review): Modal app names with spaces may be rejected on deploy —
# confirm; a dashed name like "chai1-inference" is the usual convention.
app = modal.App("Chai1 inference")
|
|
|
|
|
@app.function(
    timeout=15 * MINUTES,
    gpu="H100",
    volumes={models_dir: chai_model_volume, preds_dir: chai_preds_volume},
    image=image,
)
def chai1_inference(
    fasta_content: str, inference_config: dict, run_id: str
) -> list[tuple[bytes, str]]:
    """Run Chai-1 structure prediction on an H100 and return the outputs.

    Args:
        fasta_content: contents of the input FASTA describing the entities to fold.
        inference_config: keyword arguments forwarded verbatim to
            ``chai1.run_inference``; if it contains ``num_diffn_samples``,
            that value also controls how many outputs are collected below.
        run_id: unique identifier naming the output subdirectory on the
            predictions volume.

    Returns:
        One ``(scores_npz_bytes, cif_text)`` pair per diffusion sample.
    """
    from pathlib import Path

    import torch
    from chai_lab import chai1

    # Chai-1 writes one scores/.cif pair per diffusion sample (default 5).
    # Honor a caller override so the collection loop stays in sync with
    # what run_inference actually produced.
    n_diffusion_samples = inference_config.get("num_diffn_samples", 5)

    fasta_file = Path("/tmp/inputs.fasta")
    fasta_file.write_text(fasta_content.strip())

    output_dir = preds_dir / run_id

    chai1.run_inference(
        fasta_file=fasta_file,
        output_dir=output_dir,
        device=torch.device("cuda"),
        **inference_config,
    )

    # Persist outputs so other readers of the volume see them, mirroring
    # the explicit commit in download_inference_dependencies.
    chai_preds_volume.commit()

    print(
        f"🧬 done, results written to /{output_dir.relative_to('/preds')} on remote volume"
    )

    results = []
    for ii in range(n_diffusion_samples):
        scores = (output_dir / f"scores.model_idx_{ii}.npz").read_bytes()
        cif = (output_dir / f"pred.model_idx_{ii}.cif").read_text()

        results.append((scores, cif))

    return results
|
|
|
|
|
@app.function(volumes={models_dir: chai_model_volume})
async def download_inference_dependencies(force=False):
    """Fetch the Chai-1 weights and conformer data onto the model volume.

    Downloads happen concurrently; files already present on the volume are
    skipped unless ``force`` is truthy.
    """
    import asyncio

    import aiohttp

    # NOTE: "depencencies" is how the real asset URL is spelled — do not fix.
    base_url = "https://chaiassets.com/chai1-inference-depencencies/"
    inference_dependencies = [
        "conformers_v1.apkl",
        "models_v2/trunk.pt",
        "models_v2/token_embedder.pt",
        "models_v2/feature_embedding.pt",
        "models_v2/diffusion_module.pt",
        "models_v2/confidence_head.pt",
    ]

    # Browser-like UA so the CDN serves the files.
    headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3"
    }

    async with aiohttp.ClientSession(headers=headers) as session:
        downloads = []
        for dep in inference_dependencies:
            target_path = models_dir / dep
            if target_path.exists() and not force:
                continue  # already cached on the volume
            print(f"🧬 downloading {dep}")
            downloads.append(download_file(session, base_url + dep, target_path))

        await asyncio.gather(*downloads)

    # Persist the downloaded files so other containers see them.
    chai_model_volume.commit()
|
|
|
|
|
|
|
|
async def download_file(
    session, url: str, local_path: Path, chunk_size: int = 8192
):
    """Stream ``url`` to ``local_path`` in chunks, creating parent dirs.

    Args:
        session: an ``aiohttp.ClientSession`` (or compatible) used for the GET.
        url: the URL to download.
        local_path: destination file; parent directories are created as needed.
        chunk_size: bytes read per iteration (default 8 KiB), now tunable for
            large model files.

    Raises:
        aiohttp.ClientResponseError: if the server returns an error status.
    """
    async with session.get(url) as response:
        response.raise_for_status()
        local_path.parent.mkdir(parents=True, exist_ok=True)
        with open(local_path, "wb") as f:
            # walrus loop terminates when read() returns an empty chunk (EOF)
            while chunk := await response.content.read(chunk_size):
                f.write(chunk)