File size: 11,169 Bytes
f34af6f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 | import os
import hydra
from omegaconf import DictConfig
import torch
import logging
import shutil
from models.proteinflow_clf_wrapperv2 import ProteinFlowModulev2
# Configure the root logger at INFO level for the whole script.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)

# CUDA device index requested by setup_device(), which falls back to cuda:0
# when this index is not among the visible devices. Edit to pick a GPU.
GPU_ID: int = 3
def setup_test_data(example_pdb: str = "6F57_A.pdb") -> str:
    """Return the path of the example PDB used by the sampling tests.

    Only verifies that the file is present; nothing is created or downloaded.

    Args:
        example_pdb: Path to the input structure. Defaults to ``6F57_A.pdb``,
            resolved relative to the current working directory.

    Returns:
        ``example_pdb`` unchanged, once it has been confirmed to be a file.

    Raises:
        FileNotFoundError: If ``example_pdb`` is missing or is not a regular file.
    """
    # isfile (not exists) so a directory with the same name is not accepted.
    if not os.path.isfile(example_pdb):
        logger.error(f"PDB file not found at {example_pdb}")
        raise FileNotFoundError(f"Missing {example_pdb}")
    return example_pdb
def setup_device(gpu_id=None) -> torch.device:
    """Select the torch device the tests run on.

    Falls back to the CPU when CUDA is unavailable, and to ``cuda:0`` when the
    requested index is not a visible CUDA device.

    Args:
        gpu_id: CUDA device index to use (int). ``None`` (the default) uses the
            module-level ``GPU_ID`` setting.

    Returns:
        The selected ``torch.device``. A CUDA device is also made current via
        ``torch.cuda.set_device``.
    """
    if gpu_id is None:
        gpu_id = GPU_ID
    if not torch.cuda.is_available():
        logger.warning("CUDA not available, using CPU")
        return torch.device("cpu")
    # Reject negative indices as well as ones past the last visible device.
    if not 0 <= gpu_id < torch.cuda.device_count():
        logger.warning(f"GPU {gpu_id} not available, using GPU 0")
        gpu_id = 0
    device = torch.device(f"cuda:{gpu_id}")
    torch.cuda.set_device(device)
    # Log the index actually in use; after the fallback it differs from the request.
    logger.info(f"Using GPU {gpu_id}: {torch.cuda.get_device_name(device)}")
    return device
def setup_model(
    cfg: DictConfig,
    device: torch.device,
    checkpoint_path: str = "ckpt/se3-fm/dnmt-full-unconditioned/2024-09-26_00-42-38/epoch29.ckpt",
) -> ProteinFlowModulev2:
    """Build the model from ``cfg`` and load pretrained weights onto ``device``.

    Args:
        cfg: Hydra configuration. It is passed to ``ProteinFlowModulev2``, and its
            ``inference`` section is stored on the model as ``_infer_cfg``.
        device: Device the checkpoint is mapped to. The model and its
            interpolant are also moved to it.
        checkpoint_path: Checkpoint file whose ``"state_dict"`` entry holds the
            model weights. Defaults to the epoch-29 checkpoint under ``ckpt/se3-fm/``.

    Returns:
        The model in eval mode on ``device``.

    Raises:
        ValueError: If ``checkpoint_path`` does not exist.
    """
    logger.info("Initializing model...")
    # Fail fast on a bad path before paying for model construction.
    if not os.path.exists(checkpoint_path):
        raise ValueError(f"Checkpoint not found at {checkpoint_path}")
    model = ProteinFlowModulev2(cfg)

    # torch.load unpickles the file: only point this at trusted checkpoints.
    # NOTE(review): newer torch releases default torch.load to weights_only=True,
    # which may reject checkpoints that pickle non-tensor objects; confirm on upgrade.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    model = model.to(device)
    model.eval()

    # The interpolant gets its device explicitly; presumably .to(device) does
    # not cover it.
    model.interpolant.set_device(device)
    model._infer_cfg = cfg.inference
    logger.info(f"Model and interpolant loaded on {device}")
    return model
def test_central_residues(
    model: ProteinFlowModulev2,
    example_pdb: str,
    output_dir: str,
    *,
    fixed_residues=None,
    num_samples: int = 2,
    temperature: float = 0.8,
):
    """Sample with a block of central residues held fixed, then score each sample.

    Errors are logged with a traceback instead of being raised, so a failure
    here does not stop the caller from running further tests.

    Args:
        model: Loaded model.
        example_pdb: Path to the reference PDB file.
        output_dir: Directory passed to the sampler for its outputs.
        fixed_residues: Residue numbers to keep fixed. Defaults to ``[2, 3, 4]``.
            These were picked for a small 5-residue helix and may not be valid
            residue numbers for other structures; TODO confirm against the PDB in use.
        num_samples: Number of samples to draw.
        temperature: Sampling temperature.
    """
    logger.info("Test 1: Fixing central residues...")
    central_residues = [2, 3, 4] if fixed_residues is None else list(fixed_residues)
    try:
        samples = model.sample_with_fixed_residues(
            pdb_path=example_pdb,
            fixed_residues=central_residues,
            num_samples=num_samples,
            temperature=temperature,
            output_dir=output_dir,
        )
        logger.info(f"Generated samples with fixed central residues: {samples}")
        # Score every generated sample against the reference structure.
        for i, sample_path in enumerate(samples):
            quality_metrics = model.evaluate_structure_quality(
                sample_path,
                reference_pdb_path=example_pdb,
                fixed_residues=central_residues,
            )
            logger.info(f"Sample {i} quality metrics: {quality_metrics}")
    except Exception as e:
        # logger.exception already logs at ERROR level with the traceback.
        logger.exception(f"Error in central residue test: {e}")
def test_terminal_residues(
    model: ProteinFlowModulev2,
    example_pdb: str,
    output_dir: str,
    *,
    fixed_residues=None,
    num_samples: int = 2,
    temperature: float = 0.8,
):
    """Sample with the chain-terminal residues held fixed, then score each sample.

    Errors are logged with a traceback instead of being raised, so a failure
    here does not stop the caller from running further tests.

    Args:
        model: Loaded model.
        example_pdb: Path to the reference PDB file.
        output_dir: Directory passed to the sampler for its outputs.
        fixed_residues: Residue numbers to keep fixed. Defaults to ``[1, 5]``,
            i.e. the first and last residues of a 5-residue chain. These may not
            be the termini of other structures; TODO confirm against the PDB in use.
        num_samples: Number of samples to draw.
        temperature: Sampling temperature.
    """
    logger.info("Test 2: Fixing terminal residues...")
    terminal_residues = [1, 5] if fixed_residues is None else list(fixed_residues)
    try:
        samples = model.sample_with_fixed_residues(
            pdb_path=example_pdb,
            fixed_residues=terminal_residues,
            num_samples=num_samples,
            temperature=temperature,
            output_dir=output_dir,
        )
        logger.info(f"Generated samples with fixed terminals: {samples}")
        # Score every generated sample against the reference structure.
        for i, sample_path in enumerate(samples):
            quality_metrics = model.evaluate_structure_quality(
                sample_path,
                reference_pdb_path=example_pdb,
                fixed_residues=terminal_residues,
            )
            logger.info(f"Sample {i} quality metrics: {quality_metrics}")
    except Exception as e:
        # logger.exception already logs at ERROR level with the traceback.
        logger.exception(f"Error in terminal residue test: {e}")
def test_temperature_sweep(
    model: ProteinFlowModulev2,
    example_pdb: str,
    output_dir: str,
    *,
    temperatures=None,
    fixed_residues=None,
):
    """Draw one sample per temperature and log its quality metrics.

    Each temperature is tried independently. An error at one temperature is
    logged with a traceback and the sweep continues with the next one.

    Args:
        model: Loaded model.
        example_pdb: Path to the reference PDB file.
        output_dir: Directory passed to the sampler for every temperature.
        temperatures: Sampling temperatures to try. Defaults to ``[0.5, 1.0, 1.5]``.
        fixed_residues: Residue numbers to keep fixed. Defaults to ``[2, 3]``.
            These may not be valid residue numbers for every structure; TODO
            confirm against the PDB in use.
    """
    logger.info("Test 3: Temperature sensitivity analysis...")
    temps = [0.5, 1.0, 1.5] if temperatures is None else list(temperatures)
    residues = [2, 3] if fixed_residues is None else list(fixed_residues)
    # NOTE(review): every temperature shares output_dir; if the sampler uses
    # fixed file names, later runs may overwrite earlier ones. Confirm.
    for temp in temps:
        try:
            samples = model.sample_with_fixed_residues(
                pdb_path=example_pdb,
                fixed_residues=residues,
                num_samples=1,
                temperature=temp,
                output_dir=output_dir,
            )
            logger.info(f"Sample at temperature {temp} saved at: {samples}")
            # Only one sample is requested, so score the first (and only) one.
            quality_metrics = model.evaluate_structure_quality(
                samples[0],
                reference_pdb_path=example_pdb,
                fixed_residues=residues,
            )
            logger.info(f"Quality metrics at temperature {temp}: {quality_metrics}")
        except Exception as e:
            # logger.exception already logs at ERROR level with the traceback.
            logger.exception(f"Error at temperature {temp}: {e}")
def test_classifier_guidance(
    model: ProteinFlowModulev2,
    example_pdb: str,
    output_dir: str,
    cfg: DictConfig,
    *,
    guidance_scales=None,
    fixed_residues=None,
    target_class: int = 1,
):
    """Run classifier-guided sampling at one or more guidance scales.

    Loads the classifier described by ``cfg.inference.classifier``. Then, for
    each guidance scale, it draws one sample with ``fixed_residues`` held fixed,
    passing ``<output_dir>/scale_<scale>`` as the sampler's output directory,
    and logs that sample's quality metrics.

    The test ends early, with a log message, if a config section is missing,
    the checkpoint path is absent, or the classifier fails to load. All other
    errors are logged with a traceback rather than raised.

    Args:
        model: Loaded model.
        example_pdb: Path to the reference PDB file.
        output_dir: Root directory for the per-scale outputs.
        cfg: Hydra configuration expected to contain ``inference.classifier``.
        guidance_scales: Guidance scales to try. Defaults to ``[0.2]``.
        fixed_residues: Residue numbers to keep fixed. Defaults to a set of
            residues chosen to lie within the valid range for 6F57_A.pdb.
        target_class: Passed through to the sampler as ``target_class``.
    """
    logger.info("Test 4: Classifier-guided sampling...")
    scales = [0.2] if guidance_scales is None else list(guidance_scales)
    if fixed_residues is None:
        # Residues within the valid range for 6F57_A.pdb.
        fixed_residues = [639, 640, 641, 642, 643, 644, 645, 646, 647, 664, 665,
                          686, 687, 688, 888, 889, 890, 891, 892, 893, 894, 895]
    else:
        fixed_residues = list(fixed_residues)

    try:
        logger.info("Checking inference config...")
        # Check that each config section exists *before* reading from it.
        if not hasattr(cfg, 'inference'):
            logger.error("No inference config found in cfg")
            return
        logger.info(f"Full inference config: {cfg.inference}")
        if not hasattr(cfg.inference, 'classifier'):
            logger.error("No classifier config found in cfg.inference")
            return
        clf_cfg = cfg.inference.classifier
        logger.info(f"Classifier config: {clf_cfg}")
        if not hasattr(clf_cfg, 'ckpt_path'):
            logger.warning("No classifier checkpoint path found in config")
            return

        ckpt_path = clf_cfg.ckpt_path
        logger.info(f"Loading classifier from: {ckpt_path}")
        if not os.path.exists(ckpt_path):
            logger.error(f"Classifier checkpoint not found at: {ckpt_path}")
            return
        try:
            model.load_classifiers(clf_cfg)
        except Exception as e:
            logger.exception(f"Error loading classifier: {e}")
            return
        logger.info("Classifier loaded successfully")

        for scale in scales:
            logger.info(f"Testing with guidance scale {scale}...")
            try:
                samples = model.sample_with_fixed_residues(
                    pdb_path=example_pdb,
                    fixed_residues=fixed_residues,
                    num_samples=1,
                    temperature=0.8,
                    output_dir=os.path.join(output_dir, f"scale_{scale}"),
                    # NOTE(review): cls_model is presumably populated by
                    # load_classifiers(); confirm.
                    clf_model=model.cls_model,
                    guidance_scale=scale,
                    target_class=target_class,
                )
                for i, sample_path in enumerate(samples):
                    quality_metrics = model.evaluate_structure_quality(
                        sample_path,
                        reference_pdb_path=example_pdb,
                        fixed_residues=fixed_residues,
                    )
                    logger.info(f"Sample {i} with guidance scale {scale} quality metrics: {quality_metrics}")
            except Exception as e:
                # One failing scale should not prevent the remaining scales from running.
                logger.exception(f"Error during sampling with scale {scale}: {e}")
    except Exception as e:
        logger.exception(f"Error in classifier-guided sampling test: {e}")
@hydra.main(version_base=None, config_path="./configs", config_name="inference")
def test_conditional_sampling(cfg: DictConfig):
    """Hydra entry point: load the pretrained model and run the sampling tests.

    Steps, in order:
      1. select the compute device (``setup_device``);
      2. confirm the example PDB exists (``setup_test_data``);
      3. create ``test_outputs/`` plus one subdirectory per test case;
      4. build the model and load its checkpoint (``setup_model``);
      5. run the enabled test cases. Currently only classifier-guided sampling
         is enabled; the other cases are commented out below.
    """
    device = setup_device()

    logger.info("Setting up test environment...")
    example_pdb = setup_test_data()

    output_dir = "test_outputs"
    os.makedirs(output_dir, exist_ok=True)
    # One output directory per test case, keyed by case name.
    case_dirs = {
        case: os.path.join(output_dir, case)
        for case in ("central_fixed", "terminal_fixed", "temperature_sweep", "classifier_guided")
    }
    for case_dir in case_dirs.values():
        os.makedirs(case_dir, exist_ok=True)

    model = setup_model(cfg, device)

    # Disabled cases; uncomment to run them:
    # test_central_residues(model, example_pdb, case_dirs["central_fixed"])
    # test_terminal_residues(model, example_pdb, case_dirs["terminal_fixed"])
    # test_temperature_sweep(model, example_pdb, case_dirs["temperature_sweep"])
    test_classifier_guidance(model, example_pdb, case_dirs["classifier_guided"], cfg)

    logger.info(f"All test outputs saved to {output_dir}")
    logger.info("Testing complete!")
if __name__ == "__main__":
    # @hydra.main supplies ``cfg``, so the entry point is called without arguments.
    test_conditional_sampling()