File size: 17,287 Bytes
ad9572d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
"""
PXDesign + Q_theta Classifier Guidance.

Monkey-patches PXDesign's diffusion sampling loop to inject Q_theta selectivity
gradient after each denoising step. This steers the diffusion trajectory toward
binder backbones that are conformationally selective.

The patched diffusion loop:
    x_denoised = denoise_net(x_noisy, t_hat, ...)
    grad = ∇_{x_denoised}[Q(holo,Y) - Q(apo,Y)]   # <-- INJECTED
    x_denoised = x_denoised + scale(t) * grad        # <-- INJECTED
    delta = (x_noisy - x_denoised) / t_hat
    x_l = x_noisy + eta * dt * delta

Usage:
    python code/scripts/pxdesign_guidance/guided_pxdesign.py \
        --input experiments/pxdesign_cam/output/cam_binder.json \
        --qtheta_checkpoint results/checkpoints_cam_v3/best_phase2.pt \
        --ref_holo data/pdbs/cam_holo/3CLN.pdb \
        --ref_apo data/pdbs/cam_apo/1CFD.pdb \
        --guidance_scale 1.0 \
        --N_sample 50 --N_step 400 \
        --gpu 0
"""

import os
import sys
import argparse
import json
import logging
import time
import shutil
from typing import Callable, Optional, Union
from functools import partial

import numpy as np
import torch

logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s')
logger = logging.getLogger(__name__)

# ── Paths ────────────────────────────────────────────────────────────────────
# Directory layout is derived from this script's location: the code directory
# is two levels above the script, and the project root is one level above that.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
_ALLO_CODE_DIR = os.path.abspath(os.path.join(_SCRIPT_DIR, '..', '..'))
_ALLO_ROOT = os.path.abspath(os.path.join(_ALLO_CODE_DIR, '..'))
# Optional location of a PXDesign checkout; empty string when the variable is unset.
_PXDESIGN_DIR = os.environ.get('PXDESIGN_DIR', '')

if _ALLO_CODE_DIR not in sys.path:
    sys.path.insert(0, _ALLO_CODE_DIR)
# Only extend sys.path when PXDESIGN_DIR was actually provided: inserting the
# empty string would put the current working directory first on the import
# path and let unrelated local files shadow real packages.
if _PXDESIGN_DIR and _PXDESIGN_DIR not in sys.path:
    sys.path.insert(0, _PXDESIGN_DIR)


def _resolve_step_scale_eta(step_scale_eta, step_t, T):
    """Return the Euler step scale ``eta`` to use at step ``step_t``.

    Args:
        step_scale_eta: Either a plain number, used as-is, or a dict with the
            keys ``"type"``, ``"min"`` and ``"max"`` describing a schedule
            evaluated on the fraction ``step_t / T``.
        step_t: Zero-based index of the current diffusion step.
        T: Length of the noise schedule.

    Raises:
        ValueError: If the dict names a schedule type that is not supported.
    """
    # Accept ints as well as floats so that e.g. ``step_scale_eta=1`` works
    # instead of failing on dict-style indexing.
    if isinstance(step_scale_eta, (int, float)):
        return step_scale_eta
    if step_scale_eta["type"] == "const":
        assert step_scale_eta["min"] == step_scale_eta["max"]
        return step_scale_eta["min"]

    eta_min, eta_max = step_scale_eta["min"], step_scale_eta["max"]
    schedule = step_scale_eta["type"]
    frac = step_t / T
    if schedule == "linear":
        return eta_min + (eta_max - eta_min) * frac
    if schedule == "poly":
        return eta_min + (eta_max - eta_min) * frac ** 2
    if schedule == "cos":
        return eta_min + 0.5 * (eta_max - eta_min) * (
            1 - np.cos(np.pi * step_t / T))
    if schedule == "piecewise":
        return eta_min if frac < 0.5 else eta_max
    if schedule == "piecewise_65":
        return eta_min if frac < 0.65 else eta_max
    if schedule == "piecewise_70":
        return eta_min if frac < 0.70 else eta_max
    raise ValueError("Unsupported eta schedule!")


def guided_sample_diffusion(
    denoise_net: Callable,
    input_feature_dict: dict,
    s_inputs: torch.Tensor,
    s_trunk: torch.Tensor,
    z_trunk: torch.Tensor,
    noise_schedule: torch.Tensor,
    N_sample: int = 1,
    gamma0: float = 0.8,
    gamma_min: float = 1.0,
    noise_scale_lambda: float = 1.003,
    # NOTE: mutable default kept for interface compatibility; it is only read,
    # never mutated, inside this function.
    step_scale_eta: Union[float, dict] = {"type": "const", "min": 1.5, "max": 1.5},
    diffusion_chunk_size: Optional[int] = None,
    inplace_safe: bool = False,
    attn_chunk_size: Optional[int] = None,
    # Guidance parameters (injected via partial)
    guidance_module=None,
    guidance_scale: float = 1.0,
    guidance_start: float = 0.8,
    guidance_end: float = 0.1,
) -> torch.Tensor:
    """
    Modified PXDesign sample_diffusion with Q_theta classifier guidance.

    Runs the predictor / denoise / Euler-step loop over consecutive pairs of
    ``noise_schedule`` and, after each denoising call, optionally shifts the
    denoised coordinates along the gradient returned by ``guidance_module``.
    The shift is scaled by ``guidance_scale * (1 - progress)``, so it is
    stronger at early (high-noise) steps.

    Args:
        denoise_net: Callable invoked with the keyword arguments ``x_noisy``,
            ``t_hat_noise_level``, ``input_feature_dict``, ``s_inputs``,
            ``s_trunk``, ``z_trunk``, ``chunk_size`` and ``inplace_safe``;
            returns the denoised coordinates.
        input_feature_dict: Feature dict; must contain ``"atom_to_token_idx"``,
            whose last dimension gives the number of atoms.
        s_inputs: Tensor whose leading dims (all but the last two) define the
            batch shape; its device and dtype are used for sampling.
        s_trunk: Passed through to ``denoise_net``.
        z_trunk: Passed through to ``denoise_net``.
        noise_schedule: Sequence of noise levels; step ``i`` goes from
            ``noise_schedule[i]`` to ``noise_schedule[i + 1]``.
        N_sample: Total number of samples to draw.
        gamma0: Noise inflation factor used while the target noise level
            exceeds ``gamma_min``; otherwise 0 is used.
        gamma_min: Threshold on the target noise level for applying ``gamma0``.
        noise_scale_lambda: Multiplier on the noise added in the predictor step.
        step_scale_eta: Euler step scale: a number, or a schedule dict with
            ``"type"``, ``"min"`` and ``"max"``.
        diffusion_chunk_size: If given, samples are drawn in chunks of at most
            this size and concatenated along dim ``-3``.
        inplace_safe: Passed through to ``denoise_net``.
        attn_chunk_size: Passed to ``denoise_net`` as ``chunk_size``.
        guidance_module: Object exposing
            ``compute_guidance_gradient(x, input_feature_dict, t_hat=...,
            sample_idx=...)`` returning ``(grad, margin)``. ``None`` disables
            guidance entirely.
        guidance_scale: Overall multiplier on the guidance shift.
        guidance_start: Upper bound of the window on ``1 - progress`` in which
            guidance is applied.
        guidance_end: Lower bound of that window.

    Returns:
        The final sampled coordinates; chunks (if any) are concatenated along
        dim ``-3``.

    Raises:
        ValueError: If ``step_scale_eta`` names an unsupported schedule type.
    """
    from protenix.model.utils import centre_random_augmentation

    N_atom = input_feature_dict["atom_to_token_idx"].size(-1)
    batch_shape = s_inputs.shape[:-2]
    device = s_inputs.device
    dtype = s_inputs.dtype

    logger.info(f"Guided sampling: scale={guidance_scale}, "
                f"window=[{guidance_end:.1f}, {guidance_start:.1f}]")

    def _chunk_sample_diffusion_guided(chunk_n_sample, inplace_safe):
        x_l = noise_schedule[0] * torch.randn(
            size=(*batch_shape, chunk_n_sample, N_atom, 3),
            device=device, dtype=dtype
        )
        T = len(noise_schedule)
        # Number of steps in this chunk where guidance raised and was skipped.
        n_guidance_failures = 0

        for step_t, (c_tau_last, c_tau) in enumerate(
            zip(noise_schedule[:-1], noise_schedule[1:])
        ):
            # Centre random augmentation
            x_l = (
                centre_random_augmentation(x_input_coords=x_l, N_sample=1)
                .squeeze(dim=-3)
                .to(dtype)
            )

            # Predictor step: add noise
            gamma = float(gamma0) if c_tau > gamma_min else 0
            t_hat = c_tau_last * (gamma + 1)
            delta_noise_level = torch.sqrt(t_hat**2 - c_tau_last**2)
            x_noisy = x_l + noise_scale_lambda * delta_noise_level * torch.randn(
                size=x_l.shape, device=device, dtype=dtype
            )

            # Reshape t_hat for network
            t_hat_tensor = (
                t_hat.reshape((1,) * (len(batch_shape) + 1))
                .expand(*batch_shape, chunk_n_sample)
                .to(dtype)
            )

            # Denoise
            x_denoised = denoise_net(
                x_noisy=x_noisy,
                t_hat_noise_level=t_hat_tensor,
                input_feature_dict=input_feature_dict,
                s_inputs=s_inputs,
                s_trunk=s_trunk,
                z_trunk=z_trunk,
                chunk_size=attn_chunk_size,
                inplace_safe=inplace_safe,
            )

            # ── Q_theta guidance injection ──────────────────────────────
            if guidance_module is not None:
                # Compute progress fraction (0=start/high noise, 1=end/low noise)
                progress = step_t / (T - 1) if T > 1 else 1.0

                # Apply guidance only within the specified window
                if guidance_end <= (1.0 - progress) <= guidance_start:
                    # Handle batch dimensions
                    x_for_grad = x_denoised
                    if x_for_grad.dim() > 3:
                        x_for_grad = x_for_grad.squeeze(0)

                    # Scale: stronger at high noise, weaker near convergence
                    noise_fraction = 1.0 - progress
                    scale = guidance_scale * noise_fraction

                    try:
                        # Compute gradients for at most the first 4 samples
                        n_guide = min(chunk_n_sample, 4)
                        grad_accum = torch.zeros_like(x_for_grad)

                        for si in range(n_guide):
                            grad, margin = guidance_module.compute_guidance_gradient(
                                x_for_grad, input_feature_dict,
                                t_hat=t_hat, sample_idx=si
                            )
                            grad_accum[si] = grad[si] if grad.shape[0] > si else grad[0]

                        # Broadcast the mean guided gradient to remaining samples
                        if n_guide < chunk_n_sample and n_guide > 0:
                            avg_grad = grad_accum[:n_guide].mean(dim=0, keepdim=True)
                            grad_accum[n_guide:] = avg_grad.expand(
                                chunk_n_sample - n_guide, -1, -1)

                        # Normalize gradient to prevent explosion
                        grad_norm = grad_accum.norm(dim=-1, keepdim=True).clamp(min=1e-8)
                        grad_normalized = grad_accum / grad_norm
                        avg_norm = grad_norm.mean().item()

                        # Apply guidance
                        if avg_norm > 1e-6:
                            # Scale by average gradient magnitude to keep step size reasonable
                            x_denoised = x_denoised + scale * avg_norm * grad_normalized

                            if step_t % 50 == 0:
                                logger.info(
                                    f"  Step {step_t}/{T}: margin={margin:.3f}, "
                                    f"grad_norm={avg_norm:.4f}, scale={scale:.3f}")
                    except Exception as e:
                        # Guidance is best-effort: a failure must not abort
                        # sampling, but it must be visible. Report the first
                        # failure loudly (with traceback) and keep later ones
                        # at debug level to avoid flooding the log.
                        n_guidance_failures += 1
                        if n_guidance_failures == 1:
                            logger.warning(
                                f"  Step {step_t}: guidance failed: {e}",
                                exc_info=True)
                        elif step_t % 100 == 0:
                            logger.debug(f"  Step {step_t}: guidance failed: {e}")
            # ── End guidance ────────────────────────────────────────────

            # Euler step
            delta = (x_noisy - x_denoised) / t_hat_tensor[..., None, None]
            dt = c_tau - t_hat_tensor
            eta = _resolve_step_scale_eta(step_scale_eta, step_t, T)
            x_l = x_noisy + eta * dt[..., None, None] * delta

        if n_guidance_failures:
            logger.warning(
                f"Q_theta guidance failed on {n_guidance_failures} step(s) in this "
                f"chunk; those steps were sampled without guidance")

        return x_l

    # Chunked sampling
    if diffusion_chunk_size is None:
        x_l = _chunk_sample_diffusion_guided(N_sample, inplace_safe=inplace_safe)
    else:
        x_l = []
        no_chunks = N_sample // diffusion_chunk_size + (
            N_sample % diffusion_chunk_size != 0)
        for i in range(no_chunks):
            # Every chunk is full-sized except possibly the last one.
            chunk_n_sample = (
                diffusion_chunk_size
                if i < no_chunks - 1
                else N_sample - i * diffusion_chunk_size
            )
            chunk_x_l = _chunk_sample_diffusion_guided(
                chunk_n_sample, inplace_safe=inplace_safe)
            x_l.append(chunk_x_l)
        x_l = torch.cat(x_l, -3)

    return x_l


def run_guided_pxdesign(args):
    """Run PXDesign with Q_theta classifier guidance, then score the designs.

    Steps: build a PXDesign config from ``args``, prepare the input tasks in
    the output directory, replace PXDesign's ``sample_diffusion`` with
    ``guided_sample_diffusion`` (guidance parameters bound), run inference for
    each seed, score every generated structure file whose name contains
    "sample", and write ``guided_scores.json`` / ``guided_summary.json``.

    Args:
        args: Parsed CLI namespace. Reads ``gpu``, ``outdir``, ``input``,
            ``N_sample``, ``N_step``, ``qtheta_checkpoint``, ``ref_holo``,
            ``ref_apo``, ``ref_chain``, ``esm_target``, ``guidance_scale``,
            ``guidance_start`` and ``guidance_end``. Relative ``outdir``,
            checkpoint and reference paths are resolved against the project
            root; ``input`` is passed to PXDesign unchanged.
    """
    if 'CUDA_VISIBLE_DEVICES' not in os.environ:
        os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    else:
        # The pre-existing environment setting wins; make that visible.
        logger.info(
            "CUDA_VISIBLE_DEVICES already set to %r; --gpu %s is not applied",
            os.environ['CUDA_VISIBLE_DEVICES'], args.gpu)

    # Import PXDesign components
    from pxdesign.runner.inference import InferenceRunner
    from pxdesign.utils.infer import (
        get_configs, convert_to_bioassembly_dict, download_inference_cache, derive_seed
    )
    from pxdesign.utils.inputs import process_input_file
    from protenix.config import save_config
    from protenix.utils.seed import seed_everything

    from qtheta_pxdesign import QThetaPXDesignGuidance

    def _under_root(path):
        """Return ``path`` unchanged if absolute, else joined onto the project root."""
        return path if os.path.isabs(path) else os.path.join(_ALLO_ROOT, path)

    # Set up output directory
    outdir = _under_root(args.outdir)
    os.makedirs(outdir, exist_ok=True)

    # Build PXDesign CLI arguments
    pxdesign_argv = [
        '--dump_dir', outdir,
        '--input', args.input,
        '--dtype', 'bf16',
        '--N_sample', str(args.N_sample),
        '--N_step', str(args.N_step),
    ]

    configs = get_configs(pxdesign_argv)
    configs.input_json_path = process_input_file(
        configs.input_json_path, out_dir=outdir)
    download_inference_cache(configs)

    # Convert inputs
    save_config(configs, os.path.join(outdir, "config.yaml"))
    with open(configs.input_json_path, "r") as f:
        orig_inputs = json.load(f)
    for x in orig_inputs:
        convert_to_bioassembly_dict(x, outdir)
    configs.input_json_path = os.path.join(outdir, "input_tasks.json")
    with open(configs.input_json_path, "w") as f:
        json.dump(orig_inputs, f, indent=4)

    # Create runner
    runner = InferenceRunner(configs)

    # Initialize Q_theta guidance
    guidance = QThetaPXDesignGuidance(
        checkpoint=_under_root(args.qtheta_checkpoint),
        ref_holo=_under_root(args.ref_holo),
        ref_apo=_under_root(args.ref_apo),
        ref_chain=args.ref_chain,
        device='cuda:0',  # After CUDA_VISIBLE_DEVICES remapping
        esm_target=args.esm_target,
    )

    # Monkey-patch the sample_diffusion function
    from pxdesign.model import generator as pxdesign_generator
    import pxdesign.model.pxdesign as pxdesign_model

    # Create guided version with guidance params bound
    guided_fn = partial(
        guided_sample_diffusion,
        guidance_module=guidance,
        guidance_scale=args.guidance_scale,
        guidance_start=args.guidance_start,
        guidance_end=args.guidance_end,
    )

    # Patch the module-level function in generator.py
    pxdesign_generator.sample_diffusion = guided_fn

    # CRITICAL: pxdesign.py does `from pxdesign.model.generator import sample_diffusion`
    # which creates a local binding in pxdesign.model.pxdesign namespace.
    # We must patch that local binding too, otherwise the ProtenixDesign.sample_diffusion()
    # method will still call the original unpatched function.
    pxdesign_model.sample_diffusion = guided_fn

    logger.info("PXDesign diffusion loop patched with Q_theta guidance")

    # Run inference
    seeds = [derive_seed(time.time_ns())] if not configs.seeds else configs.seeds
    for seed in seeds:
        logger.info(f"Running guided inference with seed {seed}")
        seed_everything(seed=seed, deterministic=False)
        runner._inference(seed)

    # Score all generated designs
    logger.info("Scoring generated designs...")
    from glob import glob

    pdb_dir = outdir
    pdbs = []
    for ext in ('*.pdb', '*.cif'):
        pdbs.extend(glob(os.path.join(pdb_dir, '**/' + ext), recursive=True))
    pdbs = sorted([p for p in pdbs if 'sample' in os.path.basename(p).lower()])

    results = []
    for i, pdb_path in enumerate(pdbs):
        # Strip only the trailing extension; substring replacement would also
        # remove ".pdb"/".cif" occurring in the middle of a file name.
        design_id = os.path.splitext(os.path.basename(pdb_path))[0]
        result = guidance.score_design(pdb_path)
        if result is None:
            logger.warning(
                f"[{i+1}/{len(pdbs)}] {design_id}: scoring returned no result; skipped")
            continue
        result['design_id'] = design_id
        result['pdb_path'] = pdb_path
        results.append(result)
        logger.info(
            f"[{i+1}/{len(pdbs)}] {design_id}: "
            f"Q+={result['q_holo']:.3f} Q-={result['q_apo']:.3f} "
            f"S={result['margin']:+.3f}")

    if not results:
        # Make the "nothing to report" outcome explicit instead of exiting silently.
        logger.warning(
            f"No designs were scored ({len(pdbs)} candidate file(s) found under "
            f"{outdir}); no score files written")
        return

    # Save results, best margin first
    results.sort(key=lambda x: x['margin'], reverse=True)
    margins = np.array([r['margin'] for r in results])

    summary = {
        'method': 'PXDesign + Classifier Guidance',
        'n_designs': len(results),
        'guidance_scale': args.guidance_scale,
        'guidance_window': [args.guidance_end, args.guidance_start],
        'margin_mean': float(margins.mean()),
        'margin_std': float(margins.std()),
        'frac_positive': float((margins > 0).mean()),
        'q_holo_mean': float(np.mean([r['q_holo'] for r in results])),
        'q_apo_mean': float(np.mean([r['q_apo'] for r in results])),
    }

    with open(os.path.join(outdir, 'guided_scores.json'), 'w') as f:
        json.dump(results, f, indent=2)
    with open(os.path.join(outdir, 'guided_summary.json'), 'w') as f:
        json.dump(summary, f, indent=2)

    logger.info(f"\n{'='*60}")
    logger.info(f"PXDesign + Classifier Guidance Results ({len(results)} designs)")
    logger.info(f"  Margin: {margins.mean():.3f} ± {margins.std():.3f}")
    logger.info(f"  Fraction S > 0: {(margins > 0).mean():.1%}")
    logger.info(f"  Q(holo) mean: {summary['q_holo_mean']:.3f}")
    logger.info(f"{'='*60}")


def main():
    """Parse the command-line options and launch the guided PXDesign run."""
    cli = argparse.ArgumentParser(description='PXDesign + Q_theta Classifier Guidance')

    # (flag, add_argument keyword options), registered in display order.
    option_table = (
        ('--input', {
            'default': 'experiments/pxdesign_cam/output/cam_binder.json',
            'help': 'PXDesign input JSON',
        }),
        ('--qtheta_checkpoint', {
            'default': 'results/checkpoints_cam_v3/best_phase2.pt',
        }),
        ('--ref_holo', {'default': 'data/pdbs/cam_holo/3CLN.pdb'}),
        ('--ref_apo', {'default': 'data/pdbs/cam_apo/1CFD.pdb'}),
        ('--ref_chain', {'default': 'A'}),
        ('--guidance_scale', {
            'type': float,
            'default': 1.0,
            'help': 'Guidance gradient scale',
        }),
        ('--guidance_start', {
            'type': float,
            'default': 0.8,
            'help': 'Start guidance at this noise fraction (high noise)',
        }),
        ('--guidance_end', {
            'type': float,
            'default': 0.1,
            'help': 'Stop guidance at this noise fraction (low noise)',
        }),
        ('--N_sample', {'type': int, 'default': 50}),
        ('--N_step', {'type': int, 'default': 400}),
        ('--gpu', {'type': int, 'default': 0}),
        ('--outdir', {'default': 'results/pxdesign_guided'}),
        ('--esm_target', {
            'default': 'cam',
            'help': 'Subdir under data/esm2_embeddings (e.g., adk, cam)',
        }),
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    parsed = cli.parse_args()
    run_guided_pxdesign(parsed)


# Script entry point.
if __name__ == '__main__':
    main()