| | |
| | from pathlib import Path |
| |
|
| | import modal |
| |
|
| | |
here = Path(__file__).parent

# Seconds per minute, so timeouts read naturally (e.g. `15 * MINUTES`).
MINUTES = 60

# Container image used for inference.
# The torch requirement is quoted on purpose: `run_commands` executes each
# command through a shell, where an unquoted `torch>=2.6.0` is parsed as
# `torch` followed by a stdout redirect to a file named `=2.6.0`. That would
# silently drop the version floor.
image = modal.Image.debian_slim(python_version="3.12").run_commands(
    "uv pip install --system --compile-bytecode 'torch>=2.6.0' chai_lab==0.6.1 hf_transfer==0.1.8 "
)

# Persistent Volume holding the Chai-1 weights; created on first lookup.
chai_model_volume = modal.Volume.from_name(
    "chai1-models",
    create_if_missing=True,
)
# Mount point of `chai_model_volume` inside function containers.
models_dir = Path("/models/chai1")

image = image.env(
    {
        # NOTE(review): presumably tells chai_lab where its downloaded assets
        # live, i.e. the model Volume mount above -- confirm against chai_lab.
        "CHAI_DOWNLOADS_DIR": str(models_dir),
        # Opts huggingface_hub into the hf_transfer backend installed above.
        "HF_HUB_ENABLE_HF_TRANSFER": "1",
    }
)

# Persistent Volume for prediction outputs, mounted at `preds_dir`.
chai_preds_volume = modal.Volume.from_name("chai1-preds", create_if_missing=True)
preds_dir = Path("/preds")

app = modal.App("Chai1 inference")
| |
|
@app.function(
    timeout=15 * MINUTES,
    gpu="H100",
    volumes={models_dir: chai_model_volume, preds_dir: chai_preds_volume},
    image=image,
)
def chai1_inference(
    fasta_content: str, inference_config: dict, run_id: str
) -> list[tuple[bytes, str]]:
    """Run Chai-1 structure prediction for one FASTA input on a GPU.

    Args:
        fasta_content: FASTA text. Surrounding whitespace is stripped before
            it is written to a temporary file.
        inference_config: Extra keyword arguments forwarded verbatim to
            ``chai_lab.chai1.run_inference``. It must not contain
            ``fasta_file``, ``output_dir`` or ``device``; those are set here,
            and a duplicate keyword raises ``TypeError``. If it contains
            ``num_diffn_samples``, that many result pairs are read back.
        run_id: Name of the subdirectory of the predictions Volume that
            receives this run's outputs.

    Returns:
        One ``(scores_npz_bytes, cif_text)`` pair per diffusion sample, in
        sample-index order.
    """
    import torch
    from chai_lab import chai1

    # Sample count assumed when inference_config does not override it.
    # NOTE(review): presumably mirrors chai_lab's own num_diffn_samples
    # default -- confirm if chai_lab is upgraded.
    N_DIFFUSION_SAMPLES = 5
    n_samples = inference_config.get("num_diffn_samples", N_DIFFUSION_SAMPLES)

    fasta_file = Path("/tmp/inputs.fasta")
    fasta_file.write_text(fasta_content.strip())

    # Write under the mounted predictions Volume so results persist.
    output_dir = preds_dir / run_id

    chai1.run_inference(
        fasta_file=fasta_file,
        output_dir=output_dir,
        device=torch.device("cuda"),
        **inference_config,
    )

    print(
        f"🧬 done, results written to /{output_dir.relative_to(preds_dir)} on remote volume"
    )

    # Read back one scores/structure file pair per sample, using the flat
    # `<kind>.model_idx_<i>` layout in output_dir.
    # NOTE(review): assumes a single trunk sample; confirm the output layout
    # if inference_config sets num_trunk_samples > 1.
    return [
        (
            (output_dir / f"scores.model_idx_{ii}.npz").read_bytes(),
            (output_dir / f"pred.model_idx_{ii}.cif").read_text(),
        )
        for ii in range(n_samples)
    ]
| |
|
@app.function(volumes={models_dir: chai_model_volume})
async def download_inference_dependencies(force=False):
    """Populate the model Volume with the Chai-1 inference assets.

    Every asset not already present under ``models_dir`` (or every asset,
    when ``force`` is true) is fetched concurrently. The model Volume is then
    committed, so other containers see the files.

    Args:
        force: Re-download assets even if a file already exists locally.
    """
    import asyncio

    import aiohttp

    # "depencencies" (sic) is part of the remote asset URL; keep it verbatim.
    asset_root = "https://chaiassets.com/chai1-inference-depencencies/"
    asset_names = (
        "conformers_v1.apkl",
        "models_v2/trunk.pt",
        "models_v2/token_embedder.pt",
        "models_v2/feature_embedding.pt",
        "models_v2/diffusion_module.pt",
        "models_v2/confidence_head.pt",
    )

    # NOTE(review): a browser-like User-Agent is sent; presumably the asset
    # host rejects default client agents -- unconfirmed.
    request_headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3"
    }

    async with aiohttp.ClientSession(headers=request_headers) as session:
        wanted = [
            name
            for name in asset_names
            if force or not (models_dir / name).exists()
        ]
        downloads = []
        for name in wanted:
            print(f"🧬 downloading {name}")
            downloads.append(download_file(session, asset_root + name, models_dir / name))
        await asyncio.gather(*downloads)

    chai_model_volume.commit()
| |
|
| |
|
async def download_file(
    session, url: str, local_path: Path, *, chunk_size: int = 1 << 20
):
    """Stream ``url`` into ``local_path`` without leaving a truncated file.

    The body is written to a ``<name>.partial`` sibling first. It is renamed
    onto ``local_path`` only after the whole response has been read. A failed
    or interrupted download therefore never leaves a file at ``local_path``.
    This matters when a caller skips files that already exist: otherwise a
    corrupt partial file would be skipped forever.

    Args:
        session: An open HTTP client session exposing ``get(url)`` as an
            async context manager (presumably an ``aiohttp.ClientSession``).
        url: Address to fetch.
        local_path: Destination file. Missing parent directories are created.
        chunk_size: Maximum number of bytes read from the response per
            iteration.

    Raises:
        Whatever ``response.raise_for_status()`` raises on an HTTP error
        status (aiohttp's ``ClientResponseError`` for an aiohttp session).
        Nothing is written to disk in that case.
    """
    async with session.get(url) as response:
        response.raise_for_status()
        local_path.parent.mkdir(parents=True, exist_ok=True)
        tmp_path = local_path.with_name(local_path.name + ".partial")
        try:
            with open(tmp_path, "wb") as f:
                while chunk := await response.content.read(chunk_size):
                    f.write(chunk)
        except BaseException:
            # Also covers task cancellation: never leave a stale partial file.
            tmp_path.unlink(missing_ok=True)
            raise
    # Publish the file only once it is complete (same-directory rename).
    tmp_path.replace(local_path)