\section{CFG Dataset Processing and Generation Pipeline}
\label{sec:cfg_pipeline}

The Classifier-Free Guidance (CFG) dataset processing pipeline and generation system bridge training data preparation and inference-time sequence generation. The framework handles multi-modal data integration, label assignment, and end-to-end generation orchestration with a family of ODE integration methods.

\subsection{CFG Dataset Architecture}

The CFG dataset processing system transforms heterogeneous protein sequence data into a unified training format suitable for classifier-free guidance, implementing label assignment strategies and data alignment procedures.

\subsubsection{Multi-Source Data Integration}
\label{sec:multi_source_data}

The dataset integrates sequences from multiple heterogeneous sources with different annotation standards:

\begin{itemize}
\item \textbf{Antimicrobial Peptide Database (APD3)}: Experimentally validated AMPs with MIC values
\item \textbf{UniProt Swiss-Prot}: Reviewed protein sequences serving as negative examples
\item \textbf{Custom Curated Sets}: Manually validated sequences with known activities
\end{itemize}

Each source requires specialized parsing and validation procedures to ensure data quality and consistency.

\subsubsection{Intelligent Label Assignment Strategy}
\label{sec:label_assignment}

The system employs a three-class labeling scheme optimized for CFG training:

\begin{align}
\text{Label}(s) = \begin{cases}
0 & \text{if } s \in \mathcal{S}_{\text{AMP}} \text{ (MIC} < 100~\mu\text{g/mL)} \\
1 & \text{if } s \in \mathcal{S}_{\text{Non-AMP}} \text{ (MIC} \geq 100~\mu\text{g/mL or UniProt)} \\
2 & \text{if randomly masked for unconditional training}
\end{cases} \label{eq:label_assignment}
\end{align}

The label assignment process incorporates several validation steps:

\begin{enumerate}
\item \textbf{Header-Based Classification}: Automatic assignment using sequence identifiers
\item \textbf{Length Filtering}: Sequences must satisfy $2 \leq |s| \leq 50$ amino acids
\item \textbf{Canonical Amino Acid Validation}: Only sequences composed of the 20 standard amino acids are retained
\item \textbf{Duplicate Detection}: Sequence-level deduplication across all sources
\end{enumerate}
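The validation steps above can be sketched in plain Python. This is a minimal illustration, not the project's actual code: the helper names (\texttt{assign\_label}, \texttt{keep\_sequence}, \texttt{build\_records}) are hypothetical, and the header rules are the \texttt{AP}/\texttt{sp} prefixes used in Algorithm~\ref{alg:cfg_dataset}.

```python
CANONICAL_AA = set("ACDEFGHIKLMNPQRSTVWY")  # the 20 standard amino acids

def assign_label(header: str) -> int:
    """Header-based classification: 'AP' prefix -> AMP (0), otherwise Non-AMP (1)."""
    return 0 if header.startswith("AP") else 1

def keep_sequence(seq: str) -> bool:
    """Length filter (2..50 residues) and canonical amino-acid validation."""
    return 2 <= len(seq) <= 50 and set(seq) <= CANONICAL_AA

def build_records(entries):
    """Deduplicate sequences across sources, keeping the first occurrence."""
    seen, records = set(), []
    for header, seq in entries:
        seq = seq.upper()
        if keep_sequence(seq) and seq not in seen:
            seen.add(seq)
            records.append((seq, assign_label(header)))
    return records
```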

\subsubsection{Strategic Masking for CFG Training}
\label{sec:strategic_masking}

The dataset implements random label masking to enable effective classifier-free guidance:

\begin{align}
\text{Mask}_{\text{CFG}}(c, p_{\text{mask}}) = \begin{cases}
c & \text{with probability } (1 - p_{\text{mask}}) \\
2 & \text{with probability } p_{\text{mask}}
\end{cases} \label{eq:cfg_masking_strategy}
\end{align}

where $p_{\text{mask}} = 0.10$ for static masking during dataset creation, and additional dynamic masking ($p_{\text{dynamic}} = 0.15$) occurs during training.
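The static masking step of Eq.~\eqref{eq:cfg_masking_strategy} amounts to replacing a fixed fraction of labels with the mask token. A minimal stdlib-only sketch (the function name \texttt{apply\_cfg\_masking} and the fixed seed are illustrative assumptions):

```python
import random

MASK_LABEL = 2  # label id reserved for the unconditional (masked) class

def apply_cfg_masking(labels, p_mask=0.10, seed=0):
    """Replace a fixed fraction p_mask of labels with MASK_LABEL (static masking)."""
    masked = list(labels)
    n_mask = int(len(masked) * p_mask)
    # sample without replacement, mirroring np.random.choice(..., replace=False)
    for i in random.Random(seed).sample(range(len(masked)), n_mask):
        masked[i] = MASK_LABEL
    return masked
```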

\subsection{Advanced Generation Pipeline}

The generation pipeline orchestrates the complete end-to-end process from noise sampling to final sequence output, incorporating multiple ODE integration methods and quality-control mechanisms.

\subsubsection{Multi-Stage Generation Architecture}
\label{sec:generation_architecture}

The generation process follows a four-stage pipeline:

\begin{align}
\text{Stage 1:} \quad &\mathbf{z}_0 \sim \mathcal{N}(0, \mathbf{I}) \quad \text{(Noise Sampling)} \label{eq:stage1_noise}\\
\text{Stage 2:} \quad &\mathbf{z}_1 = \text{ODESolve}(\mathbf{z}_0, v_\theta, [0,1]) \quad \text{(Flow Integration)} \label{eq:stage2_ode}\\
\text{Stage 3:} \quad &\mathbf{h} = \mathcal{D}(\mathbf{z}_1) \quad \text{(Decompression)} \label{eq:stage3_decomp}\\
\text{Stage 4:} \quad &s = \text{ESM2Decode}(\mathbf{h}) \quad \text{(Sequence Decoding)} \label{eq:stage4_decode}
\end{align}

Each stage incorporates error handling and quality-validation procedures.

\subsubsection{Advanced ODE Integration Methods}
\label{sec:ode_integration}

The system supports multiple numerical integration schemes for solving the flow ODE $\frac{d\mathbf{z}}{dt} = v_\theta(\mathbf{z}, t, c)$:

\textbf{Euler Integration (Fallback Method):}
\begin{align}
\mathbf{z}_{t+\Delta t} = \mathbf{z}_t + \Delta t \cdot v_\theta(\mathbf{z}_t, t, c) \label{eq:euler_integration}
\end{align}

\textbf{Runge-Kutta Methods (torchdiffeq):}
\begin{align}
\mathbf{k}_1 &= v_\theta(\mathbf{z}_t, t, c) \label{eq:rk_k1}\\
\mathbf{k}_2 &= v_\theta(\mathbf{z}_t + \frac{\Delta t}{2}\mathbf{k}_1, t + \frac{\Delta t}{2}, c) \label{eq:rk_k2}\\
\mathbf{k}_3 &= v_\theta(\mathbf{z}_t + \frac{\Delta t}{2}\mathbf{k}_2, t + \frac{\Delta t}{2}, c) \label{eq:rk_k3}\\
\mathbf{k}_4 &= v_\theta(\mathbf{z}_t + \Delta t\mathbf{k}_3, t + \Delta t, c) \label{eq:rk_k4}\\
\mathbf{z}_{t+\Delta t} &= \mathbf{z}_t + \frac{\Delta t}{6}(\mathbf{k}_1 + 2\mathbf{k}_2 + 2\mathbf{k}_3 + \mathbf{k}_4) \label{eq:rk4_integration}
\end{align}

\textbf{Adaptive Methods (DOPRI5):}
The system selects step sizes automatically using adaptive error control:
\begin{align}
\text{error}_t &= \|\mathbf{z}_{t+\Delta t}^{(5)} - \mathbf{z}_{t+\Delta t}^{(4)}\|_2 \label{eq:adaptive_error}\\
\Delta t_{\text{new}} &= \Delta t \cdot \min\left(2, \max\left(0.5, 0.9 \left(\frac{\text{tol}}{\text{error}_t}\right)^{1/5}\right)\right) \label{eq:adaptive_step}
\end{align}
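The step-size controller of Eq.~\eqref{eq:adaptive_step} is small enough to state directly. A sketch (the name \texttt{next\_step\_size} and the zero-error branch are assumptions; the constants match the equation):

```python
def next_step_size(dt, error, tol=1e-5, safety=0.9, order=5):
    """Step-size update of the adaptive controller: grow/shrink clamped to [0.5, 2]."""
    if error == 0.0:
        return dt * 2.0  # error estimate vanished: take the maximal allowed growth
    scale = safety * (tol / error) ** (1.0 / order)
    return dt * min(2.0, max(0.5, scale))
```

At $\text{error} = \text{tol}$ the step shrinks by exactly the safety factor $0.9$; far above tolerance the step halves, far below it doubles.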

\subsubsection{Classifier-Free Guidance Integration}
\label{sec:cfg_integration_generation}

During generation, CFG guidance is applied at each ODE integration step:

\begin{align}
v_{\text{guided}}(\mathbf{z}_t, t, c) &= v_\theta(\mathbf{z}_t, t, \emptyset) + w \cdot (v_\theta(\mathbf{z}_t, t, c) - v_\theta(\mathbf{z}_t, t, \emptyset)) \label{eq:cfg_guided_vector}
\end{align}

This guidance is computed efficiently using a single forward pass with batched conditional and unconditional inputs.
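The batching trick is to duplicate the batch, attach conditional labels to one half and the mask label to the other, call the model once, and recombine per Eq.~\eqref{eq:cfg_guided_vector}. A framework-free sketch with scalar states and a stand-in \texttt{model} callable (the function name \texttt{cfg\_velocity} is illustrative):

```python
def cfg_velocity(model, z_batch, t, cond_label, mask_label=2, w=7.5):
    """One batched forward pass for CFG: v_uncond + w * (v_cond - v_uncond)."""
    n = len(z_batch)
    labels = [cond_label] * n + [mask_label] * n   # conditional half, masked half
    v = model(z_batch + z_batch, t, labels)        # single forward pass
    v_cond, v_uncond = v[:n], v[n:]
    return [vu + w * (vc - vu) for vc, vu in zip(v_cond, v_uncond)]
```

In the tensor setting the same idea uses \texttt{torch.cat} on the batch dimension and a single model call, halving the per-step overhead relative to two separate passes.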

\subsection{Quality Control and Validation Framework}

The pipeline incorporates quality-control mechanisms at every stage to ensure high-fidelity generation.

\subsubsection{Sequence Validation Pipeline}
\label{sec:sequence_validation}

Generated sequences undergo multi-tier validation:

\begin{enumerate}
\item \textbf{Canonical Amino Acid Validation}: $s \in \{A, C, D, E, F, G, H, I, K, L, M, N, P, Q, R, S, T, V, W, Y\}^*$
\item \textbf{Length Constraints}: $L_{\min} \leq |s| \leq L_{\max}$ where $L_{\min} = 5$ and $L_{\max} = 50$
\item \textbf{Complexity Filtering}: Reject sequences with excessive repeats or low complexity
\item \textbf{Biological Plausibility}: Basic physicochemical property validation
\end{enumerate}
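The first three tiers can be expressed compactly; the repeat threshold below (\texttt{max\_repeat = 5}) and the helper names are illustrative assumptions, since the source does not fix a complexity cutoff:

```python
CANONICAL_AA = set("ACDEFGHIKLMNPQRSTVWY")

def max_run(seq):
    """Length of the longest run of a single repeated residue."""
    best = run = 1
    for a, b in zip(seq, seq[1:]):
        run = run + 1 if a == b else 1
        best = max(best, run)
    return best

def is_valid(seq, l_min=5, l_max=50, max_repeat=5):
    """Tiers 1-3: canonical residues, length window, excessive-repeat filter."""
    return (set(seq) <= CANONICAL_AA
            and l_min <= len(seq) <= l_max
            and max_run(seq) <= max_repeat)
```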

\subsubsection{Generation Quality Metrics}
\label{sec:generation_quality}

The system tracks several quality metrics during generation:

\begin{itemize}
\item \textbf{Validity Rate}: Fraction of sequences passing all validation checks
\item \textbf{Diversity Index}: Shannon entropy of the generated sequence distribution
\item \textbf{Novelty Score}: Fraction of sequences not present in the training data
\item \textbf{Conditional Consistency}: Alignment between requested and achieved properties
\end{itemize}
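Two of these metrics have direct closed forms. A sketch, assuming the diversity index is the residue-level Shannon entropy over the pooled generated sequences (the source does not pin down the exact estimator, so treat this as one reasonable reading):

```python
from collections import Counter
from math import log2

def residue_entropy(sequences):
    """Shannon entropy (bits) of the amino-acid distribution pooled over all residues."""
    counts = Counter("".join(sequences))
    total = sum(counts.values())
    return -sum((c / total) * log2(c / total) for c in counts.values())

def novelty_score(generated, training_set):
    """Fraction of generated sequences absent from the training data."""
    train = set(training_set)
    return sum(s not in train for s in generated) / len(generated)
```

A uniform distribution over the 20 canonical residues would give the maximal entropy $\log_2 20 \approx 4.32$ bits per residue.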

\subsection{Batch Processing and Scalability}

The pipeline is designed for efficient large-scale generation with optimized batch processing and memory management.

\subsubsection{Batch Generation Strategy}
\label{sec:batch_generation}

Large-scale generation uses the following batching heuristic:

\begin{align}
\text{BatchSize}_{\text{optimal}} = \min\left(\text{BatchSize}_{\text{max}}, \left\lfloor\frac{\text{GPU\_Memory}}{\text{Model\_Memory} \cdot \text{Sequence\_Length}}\right\rfloor\right) \label{eq:optimal_batch_size}
\end{align}

The system dynamically adjusts batch sizes based on available GPU memory and sequence complexity.
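Eq.~\eqref{eq:optimal_batch_size} translates directly into a one-line heuristic. The lower clamp to $1$ below is an added guard (not in the equation) so a tight memory budget still yields a usable batch:

```python
def optimal_batch_size(gpu_mem_bytes, model_mem_bytes, seq_len, batch_max=512):
    """Memory-bounded batch size per Eq. (optimal_batch_size), clamped to [1, batch_max]."""
    return max(1, min(batch_max, gpu_mem_bytes // (model_mem_bytes * seq_len)))
```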

\subsubsection{Memory-Efficient Processing}
\label{sec:memory_efficient}

Several optimization strategies ensure efficient memory utilization:

\begin{itemize}
\item \textbf{Gradient-Free Inference}: All generation operations use \texttt{torch.no\_grad()}
\item \textbf{Sequential Model Loading}: Models loaded and unloaded as needed to minimize peak memory
\item \textbf{Chunked Processing}: Large batches split into manageable chunks
\item \textbf{Tensor Cleanup}: Explicit memory cleanup after each generation batch
\end{itemize}
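The chunking strategy reduces to splitting a generation request into bounded pieces; each chunk then runs under \texttt{torch.no\_grad()} with \texttt{torch.cuda.empty\_cache()} afterwards. A framework-free sketch of the splitting logic (the generator name \texttt{chunked} is an assumption):

```python
def chunked(n_total, chunk_size):
    """Split a request for n_total samples into (start, size) chunks."""
    for start in range(0, n_total, chunk_size):
        yield start, min(chunk_size, n_total - start)
```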

\subsection{Multi-Scale CFG Generation}

The system supports generation at multiple CFG scales simultaneously, enabling systematic exploration of the conditioning space.

\subsubsection{CFG Scale Scheduling}
\label{sec:cfg_scheduling}

The pipeline supports CFG scale scheduling:

\begin{align}
w(t) = w_{\text{base}} \cdot \text{Schedule}(t) \quad \text{where } \text{Schedule}(t) \in \{\text{constant}, \text{linear}, \text{cosine}\} \label{eq:cfg_scheduling}
\end{align}

Different scheduling strategies enable fine-grained control over generation characteristics.
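One way to realize Eq.~\eqref{eq:cfg_scheduling} is as a lookup of schedule functions. The exact shapes below (linear and cosine as $0 \to 1$ ramps over the integration interval) are assumptions: the source names the schedule families but not their direction or normalization.

```python
from math import cos, pi

SCHEDULES = {
    "constant": lambda t: 1.0,
    "linear":   lambda t: t,                       # assumed 0 -> 1 ramp
    "cosine":   lambda t: 0.5 * (1 - cos(pi * t)), # assumed smooth 0 -> 1 ramp
}

def cfg_scale(t, w_base=7.5, schedule="constant"):
    """w(t) = w_base * Schedule(t) for t in [0, 1]."""
    return w_base * SCHEDULES[schedule](t)
```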

\subsubsection{Comparative Generation Analysis}
\label{sec:comparative_generation}

The system automatically generates sequences at multiple CFG scales for comparative analysis:

\begin{itemize}
\item \textbf{CFG Scale 0.0}: Unconditional generation (maximum diversity)
\item \textbf{CFG Scale 3.0}: Weak conditioning (balanced control and diversity)
\item \textbf{CFG Scale 7.5}: Strong conditioning (a strong default for most applications)
\item \textbf{CFG Scale 15.0}: Very strong conditioning (maximum control)
\end{itemize}

\subsection{Performance Optimization and Benchmarking}

The pipeline incorporates performance monitoring and optimization features.

\subsubsection{Generation Performance Metrics}
\label{sec:generation_performance}

\begin{itemize}
\item \textbf{Throughput}: approximately 1000 sequences per second on an A100 GPU
\item \textbf{Memory Efficiency}: under 8\,GB of GPU memory at batch size 20
\item \textbf{Quality Consistency}: over 95\% valid sequences across all CFG scales
\item \textbf{Diversity Preservation}: Shannon entropy above 4.5 bits across conditions
\end{itemize}

\subsubsection{Optimization Strategies}
\label{sec:optimization_strategies}

Several optimization techniques improve inference performance:

\begin{enumerate}
\item \textbf{Model Compilation}: JIT compilation for a 15--25\% speedup
\item \textbf{Mixed Precision Inference}: FP16 inference where applicable
\item \textbf{Kernel Fusion}: Optimized CUDA kernels for common operations
\item \textbf{Asynchronous Processing}: Overlapped computation and data transfer
\end{enumerate}

\begin{algorithm}[h]
\caption{CFG Dataset Processing Pipeline}
\label{alg:cfg_dataset}
\begin{algorithmic}[1]
\REQUIRE FASTA files $\{\mathcal{F}_1, \mathcal{F}_2, \ldots, \mathcal{F}_n\}$
\REQUIRE Label assignment rules $\mathcal{R}_{\text{label}}$
\REQUIRE Masking probability $p_{\text{mask}} = 0.10$
\ENSURE Processed CFG dataset $\mathcal{D}_{\text{CFG}}$

\STATE \textbf{// Stage 1: Multi-Source Data Parsing}
\STATE $\text{sequences} \leftarrow []$, $\text{labels} \leftarrow []$, $\text{headers} \leftarrow []$

\FOR{$\mathcal{F}_i \in \{\mathcal{F}_1, \mathcal{F}_2, \ldots, \mathcal{F}_n\}$}
\STATE $\text{current\_header} \leftarrow ""$, $\text{current\_sequence} \leftarrow ""$

\FOR{$\text{line} \in \text{ReadFile}(\mathcal{F}_i)$}
\IF{$\text{line.startswith}('>')$}
\IF{$\text{current\_sequence} \neq ""$ and $\text{current\_header} \neq ""$}
\STATE \textbf{// Process previous sequence}
\IF{$2 \leq |\text{current\_sequence}| \leq 50$}
\STATE $\text{canonical\_aa} \leftarrow \{A, C, D, E, F, G, H, I, K, L, M, N, P, Q, R, S, T, V, W, Y\}$
\IF{$\forall aa \in \text{current\_sequence}: aa \in \text{canonical\_aa}$}
\STATE $\text{sequences.append}(\text{current\_sequence.upper}())$
\STATE $\text{headers.append}(\text{current\_header})$
\STATE $\text{label} \leftarrow \text{AssignLabel}(\text{current\_header}, \mathcal{R}_{\text{label}})$
\STATE $\text{labels.append}(\text{label})$
\ENDIF
\ENDIF
\ENDIF
\STATE $\text{current\_header} \leftarrow \text{line}[1:]$ \COMMENT{Remove '>'}
\STATE $\text{current\_sequence} \leftarrow ""$
\ELSE
\STATE $\text{current\_sequence} \leftarrow \text{current\_sequence} + \text{line.strip}()$
\ENDIF
\ENDFOR
\STATE \textbf{// Flush the final record of $\mathcal{F}_i$ using the same checks as above}
\ENDFOR

\STATE \textbf{// Stage 2: Label Assignment and Validation}
\FUNCTION{AssignLabel}{$\text{header}$, $\mathcal{R}_{\text{label}}$}
\IF{$\text{header.startswith}('AP')$}
\RETURN $0$ \COMMENT{AMP class}
\ELSIF{$\text{header.startswith}('sp')$}
\RETURN $1$ \COMMENT{Non-AMP class}
\ELSE
\RETURN $1$ \COMMENT{Default to Non-AMP}
\ENDIF
\ENDFUNCTION

\STATE \textbf{// Stage 3: Strategic CFG Masking}
\STATE $\text{original\_labels} \leftarrow \text{np.array}(\text{labels})$
\STATE $\text{masked\_labels} \leftarrow \text{original\_labels.copy}()$
\STATE $\text{n\_mask} \leftarrow \text{int}(|\text{labels}| \times p_{\text{mask}})$
\STATE $\text{mask\_indices} \leftarrow \text{np.random.choice}(|\text{labels}|, \text{size}=\text{n\_mask}, \text{replace}=\text{False})$
\STATE $\text{masked\_labels}[\text{mask\_indices}] \leftarrow 2$ \COMMENT{2 = mask/unconditional}

\STATE \textbf{// Stage 4: Dataset Construction}
\STATE $\mathcal{D}_{\text{CFG}} \leftarrow \text{CFGFlowDataset}(\text{sequences}, \text{masked\_labels}, \text{headers})$

\STATE \textbf{// Stage 5: Quality Validation}
\STATE $\text{ValidateDataset}(\mathcal{D}_{\text{CFG}})$

\RETURN $\mathcal{D}_{\text{CFG}}$
\end{algorithmic}
\end{algorithm}

\begin{algorithm}[h]
\caption{End-to-End Generation Pipeline}
\label{alg:generation_pipeline}
\begin{algorithmic}[1]
\REQUIRE Trained models: Compressor $\mathcal{C}$, Flow Model $f_\theta$, Decompressor $\mathcal{D}$, Decoder $\text{ESM2Dec}$
\REQUIRE Generation parameters: $n_{\text{samples}}$, $n_{\text{steps}}$, CFG scale $w$, condition $c$
\ENSURE Generated sequences $\mathcal{S} = \{s_1, s_2, \ldots, s_{n_{\text{samples}}}\}$

\STATE \textbf{// Stage 1: Model Loading and Initialization}
\STATE $\mathcal{C} \leftarrow \text{LoadModel}(\text{"final\_compressor\_model.pth"})$
\STATE $\mathcal{D} \leftarrow \text{LoadModel}(\text{"final\_decompressor\_model.pth"})$
\STATE $f_\theta \leftarrow \text{LoadModel}(\text{"amp\_flow\_model\_final\_optimized.pth"})$
\STATE $\text{ESM2Dec} \leftarrow \text{LoadESM2Decoder}()$
\STATE $\text{stats} \leftarrow \text{LoadNormalizationStats}()$

\STATE \textbf{// Stage 2: Determine Optimal Integration Method}
\STATE $\text{ode\_method} \leftarrow \text{SelectODEMethod}()$ \COMMENT{dopri5, rk4, or euler}

\STATE \textbf{// Stage 3: Batch Generation Loop}
\STATE $\text{generated\_sequences} \leftarrow []$
\STATE $\text{batch\_size} \leftarrow \text{ComputeOptimalBatchSize}(n_{\text{samples}})$

\FOR{$\text{batch\_start} = 0$ to $n_{\text{samples}}$ step $\text{batch\_size}$}
\STATE $\text{current\_batch\_size} \leftarrow \min(\text{batch\_size}, n_{\text{samples}} - \text{batch\_start})$

\STATE \textbf{// Stage 3a: Noise Sampling}
\STATE $\mathbf{z}_0 \leftarrow \mathcal{N}(0, \mathbf{I}) \in \mathbb{R}^{\text{current\_batch\_size} \times 25 \times 80}$

\STATE \textbf{// Stage 3b: ODE Integration with CFG}
\IF{$\text{ode\_method} = \text{"dopri5"}$ and $\text{torchdiffeq\_available}$}
\STATE $\mathbf{z}_1 \leftarrow \text{odeint}(\text{CFGODEFunc}, \mathbf{z}_0, [0, 1], \text{method}=\text{"dopri5"})$
\ELSIF{$\text{ode\_method} = \text{"rk4"}$}
\STATE $\mathbf{z}_1 \leftarrow \text{RungeKutta4}(\mathbf{z}_0, \text{CFGODEFunc}, n_{\text{steps}})$
\ELSE
\STATE $\mathbf{z}_1 \leftarrow \text{EulerIntegration}(\mathbf{z}_0, \text{CFGODEFunc}, n_{\text{steps}})$
\ENDIF

\STATE \textbf{// Stage 3c: Decompression}
\WITH{$\text{torch.no\_grad}()$}
\STATE $\mathbf{h} \leftarrow \mathcal{D}(\mathbf{z}_1)$ \COMMENT{80D $\to$ 1280D}
\STATE $\mathbf{h} \leftarrow \text{ApplyInverseNormalization}(\mathbf{h}, \text{stats})$
\ENDWITH

\STATE \textbf{// Stage 3d: Sequence Decoding}
\STATE $\text{batch\_sequences} \leftarrow \text{ESM2Dec.batch\_decode}(\mathbf{h})$

\STATE \textbf{// Stage 3e: Quality Validation}
\STATE $\text{valid\_sequences} \leftarrow \text{ValidateSequences}(\text{batch\_sequences})$
\STATE $\text{generated\_sequences.extend}(\text{valid\_sequences})$

\STATE \textbf{// Memory cleanup}
\STATE $\text{torch.cuda.empty\_cache}()$
\ENDFOR

\STATE \textbf{// Stage 4: Post-Processing and Quality Control}
\STATE $\mathcal{S} \leftarrow \text{PostProcessSequences}(\text{generated\_sequences})$
\STATE $\text{quality\_metrics} \leftarrow \text{ComputeQualityMetrics}(\mathcal{S})$

\RETURN $\mathcal{S}$, $\text{quality\_metrics}$
\end{algorithmic}
\end{algorithm}

\begin{algorithm}[h]
\caption{CFG-Enhanced ODE Function}
\label{alg:cfg_ode_function}
\begin{algorithmic}[1]
\REQUIRE Current state $\mathbf{z}_t \in \mathbb{R}^{B \times L \times D}$
\REQUIRE Time $t \in [0, 1]$
\REQUIRE Condition $c$, CFG scale $w$
\REQUIRE Flow model $f_\theta$
\ENSURE Vector field $\mathbf{v}_{\text{guided}} \in \mathbb{R}^{B \times L \times D}$

\FUNCTION{CFGODEFunc}{$t$, $\mathbf{z}_t$}
\STATE \textbf{// Reshape the solver's flattened state for model compatibility}
\STATE $\mathbf{z}_t \leftarrow \mathbf{z}_t.\text{view}(B, L, D)$ \COMMENT{$B$, $L$, $D$ are fixed at setup}

\STATE \textbf{// Create time tensor}
\STATE $\mathbf{t}_{\text{tensor}} \leftarrow \text{torch.full}((B,), t, \text{device}=\mathbf{z}_t.\text{device})$

\STATE \textbf{// Conditional prediction}
\STATE $\mathbf{c}_{\text{cond}} \leftarrow \text{torch.full}((B,), c, \text{dtype}=\text{torch.long}, \text{device}=\mathbf{z}_t.\text{device})$
\STATE $\mathbf{v}_{\text{cond}} \leftarrow f_\theta(\mathbf{z}_t, \mathbf{t}_{\text{tensor}}, \mathbf{c}_{\text{cond}})$

\STATE \textbf{// Unconditional prediction}
\STATE $\mathbf{c}_{\text{uncond}} \leftarrow \text{torch.full}((B,), 2, \text{dtype}=\text{torch.long}, \text{device}=\mathbf{z}_t.\text{device})$ \COMMENT{2 = mask}
\STATE $\mathbf{v}_{\text{uncond}} \leftarrow f_\theta(\mathbf{z}_t, \mathbf{t}_{\text{tensor}}, \mathbf{c}_{\text{uncond}})$

\STATE \textbf{// Apply classifier-free guidance}
\STATE $\mathbf{v}_{\text{guided}} \leftarrow \mathbf{v}_{\text{uncond}} + w \cdot (\mathbf{v}_{\text{cond}} - \mathbf{v}_{\text{uncond}})$

\STATE \textbf{// Reshape back to flat format for ODE solver}
\STATE $\mathbf{v}_{\text{guided}} \leftarrow \mathbf{v}_{\text{guided}}.\text{view}(-1)$

\RETURN $\mathbf{v}_{\text{guided}}$
\ENDFUNCTION

\STATE \textbf{// Main ODE integration call}
\STATE $\mathbf{v}_{\text{guided}} \leftarrow \text{CFGODEFunc}(t, \mathbf{z}_t)$

\RETURN $\mathbf{v}_{\text{guided}}$
\end{algorithmic}
\end{algorithm}

\begin{algorithm}[h]
\caption{Adaptive ODE Integration Methods}
\label{alg:adaptive_ode}
\begin{algorithmic}[1]
\REQUIRE Initial state $\mathbf{z}_0$, ODE function $f$, time span $[0, 1]$
\REQUIRE Integration parameters: tolerance $\text{tol} = 10^{-5}$, max steps $N_{\max} = 1000$
\ENSURE Final state $\mathbf{z}_1$

\FUNCTION{AdaptiveODEIntegration}{$\mathbf{z}_0$, $f$, $[t_0, t_1]$}
\STATE $\mathbf{z} \leftarrow \mathbf{z}_0$, $t \leftarrow t_0$, $\Delta t \leftarrow 0.01$ \COMMENT{Initial step size}
\STATE $\text{step\_count} \leftarrow 0$

\WHILE{$t < t_1$ and $\text{step\_count} < N_{\max}$}
\STATE \textbf{// Compute 4th and 5th order solutions}
\STATE $\mathbf{k}_1 \leftarrow f(t, \mathbf{z})$
\STATE $\mathbf{k}_2 \leftarrow f(t + \frac{\Delta t}{4}, \mathbf{z} + \frac{\Delta t}{4}\mathbf{k}_1)$
\STATE $\mathbf{k}_3 \leftarrow f(t + \frac{3\Delta t}{8}, \mathbf{z} + \frac{3\Delta t}{32}\mathbf{k}_1 + \frac{9\Delta t}{32}\mathbf{k}_2)$
\STATE $\mathbf{k}_4 \leftarrow f(t + \frac{12\Delta t}{13}, \mathbf{z} + \frac{1932\Delta t}{2197}\mathbf{k}_1 - \frac{7200\Delta t}{2197}\mathbf{k}_2 + \frac{7296\Delta t}{2197}\mathbf{k}_3)$
\STATE $\mathbf{k}_5 \leftarrow f(t + \Delta t, \mathbf{z} + \frac{439\Delta t}{216}\mathbf{k}_1 - 8\Delta t\mathbf{k}_2 + \frac{3680\Delta t}{513}\mathbf{k}_3 - \frac{845\Delta t}{4104}\mathbf{k}_4)$
\STATE $\mathbf{k}_6 \leftarrow f(t + \frac{\Delta t}{2}, \mathbf{z} - \frac{8\Delta t}{27}\mathbf{k}_1 + 2\Delta t\mathbf{k}_2 - \frac{3544\Delta t}{2565}\mathbf{k}_3 + \frac{1859\Delta t}{4104}\mathbf{k}_4 - \frac{11\Delta t}{40}\mathbf{k}_5)$

\STATE \textbf{// 4th order solution}
\STATE $\mathbf{z}_{\text{new}}^{(4)} \leftarrow \mathbf{z} + \Delta t(\frac{25}{216}\mathbf{k}_1 + \frac{1408}{2565}\mathbf{k}_3 + \frac{2197}{4104}\mathbf{k}_4 - \frac{1}{5}\mathbf{k}_5)$

\STATE \textbf{// 5th order solution}
\STATE $\mathbf{z}_{\text{new}}^{(5)} \leftarrow \mathbf{z} + \Delta t(\frac{16}{135}\mathbf{k}_1 + \frac{6656}{12825}\mathbf{k}_3 + \frac{28561}{56430}\mathbf{k}_4 - \frac{9}{50}\mathbf{k}_5 + \frac{2}{55}\mathbf{k}_6)$

\STATE \textbf{// Error estimation and step size adaptation}
\STATE $\text{error} \leftarrow \|\mathbf{z}_{\text{new}}^{(5)} - \mathbf{z}_{\text{new}}^{(4)}\|_2$

\IF{$\text{error} \leq \text{tol}$} \COMMENT{Accept step}
\STATE $\mathbf{z} \leftarrow \mathbf{z}_{\text{new}}^{(5)}$ \COMMENT{Use higher-order solution}
\STATE $t \leftarrow t + \Delta t$
\ENDIF
\STATE $\text{step\_count} \leftarrow \text{step\_count} + 1$ \COMMENT{Count rejected steps too, so the loop terminates}

\STATE \textbf{// Adapt step size}
\STATE $\text{safety\_factor} \leftarrow 0.9$
\STATE $\text{scale} \leftarrow \text{safety\_factor} \cdot \left(\frac{\text{tol}}{\max(\text{error}, \varepsilon)}\right)^{1/5}$ \COMMENT{$\varepsilon > 0$ guards against division by zero}
\STATE $\Delta t \leftarrow \Delta t \cdot \min(2.0, \max(0.5, \text{scale}))$

\STATE \textbf{// Ensure we don't overshoot}
\IF{$t + \Delta t > t_1$}
\STATE $\Delta t \leftarrow t_1 - t$
\ENDIF
\ENDWHILE

\RETURN $\mathbf{z}$
\ENDFUNCTION

\STATE $\mathbf{z}_1 \leftarrow \text{AdaptiveODEIntegration}(\mathbf{z}_0, \text{CFGODEFunc}, [0, 1])$

\RETURN $\mathbf{z}_1$
\end{algorithmic}
\end{algorithm}