---
license: mit
language:
- en
library_name: pytorch
pipeline_tag: tabular-regression
tags:
- pytorch
- transformer
- bioinformatics
- negative-binomial
- glm
- statistics
- genomics
- computational-biology
datasets:
- synthetic
metrics:
- mae
- rmse
model-index:
- name: NB-Transformer
  results:
  - task:
      type: tabular-regression
      name: Negative Binomial GLM Parameter Estimation
    dataset:
      type: synthetic
      name: Synthetic NB GLM Data
    metrics:
    - type: mae
      value: 0.152
      name: Log Fold Change MAE
    - type: inference_time
      value: 0.076
      name: Inference Time (ms)
---

# NB-Transformer: Fast Negative Binomial GLM Parameter Estimation

[![Python](https://img.shields.io/badge/python-3.8+-blue.svg)](https://www.python.org/downloads/)
[![PyTorch](https://img.shields.io/badge/PyTorch-1.10+-ee4c2c.svg)](https://pytorch.org/)
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)

**NB-Transformer** is a fast, accurate neural network approach to Negative Binomial GLM parameter estimation, designed as a modern alternative to classical iterative fitting for the statistical analysis of count data. Using transformer-based attention, it delivers a **14.8x speedup** over classical methods while achieving **superior accuracy**.

Paper: [arxiv.org/abs/2508.04111](https://arxiv.org/abs/2508.04111)

## 🚀 Key Features

- **⚡ Ultra-Fast**: 14.8x faster than classical GLM fitting (0.076 ms vs 1.128 ms per test)
- **🎯 More Accurate**: 47% lower error on log fold change estimation
- **🔬 Complete Statistical Inference**: P-values, confidence intervals, and power analysis
- **📊 Robust**: 100% success rate vs 98.7% for classical methods
- **🧠 Transformer Architecture**: Attention-based modeling of variable-length sample sets
- **📦 Easy to Use**: Simple API with a pre-trained model included

## 📈 Performance Benchmarks

Based on comprehensive validation with 1000+ test cases:

| Method | Success Rate | Time (ms) | μ MAE | β MAE | α MAE |
|--------|--------------|-----------|-------|-------|-------|
| **NB-Transformer** | **100.0%** | **0.076** | **0.202** | **0.152** | **0.477** |
| Classical GLM | 98.7% | 1.128 | 0.212 | 0.284 | 0.854 |
| Method of Moments | 100.0% | 0.021 | 0.213 | 0.289 | 0.852 |

**Key Achievements:**

- **47% lower MAE** on β (log fold change), the critical parameter for differential expression
- **44% lower MAE** on α (dispersion), essential for proper statistical inference
- **100% convergence rate** with no numerical instabilities

## 🛠️ Installation

```bash
pip install nb-transformer
```

Or install from source:

```bash
git clone https://huggingface.co/valsv/nb-transformer
cd nb-transformer
pip install -e .
```

## 🎯 Quick Start

### Basic Usage

```python
import numpy as np

from nb_transformer import load_pretrained_model

# Load the pre-trained model (downloads automatically)
model = load_pretrained_model()

# Your data: log10(CPM + 1) transformed counts
control_samples = [2.1, 1.8, 2.3, 2.0]    # 4 control samples
treatment_samples = [1.5, 1.2, 1.7, 1.4]  # 4 treatment samples

# Get NB GLM parameters instantly
params = model.predict_parameters(control_samples, treatment_samples)

print(f"μ̂ (base mean): {params['mu']:.3f}")          # -0.245
print(f"β̂ (log fold change): {params['beta']:.3f}")  # -0.421
print(f"α̂ (log dispersion): {params['alpha']:.3f}")  # -1.832
print(f"Fold change: {np.exp(params['beta']):.2f}x")  # 0.66x (downregulated)
```

### Complete Statistical Analysis

```python
import numpy as np

from nb_transformer import load_pretrained_model
from nb_transformer.inference import compute_nb_glm_inference

# Load model and data
model = load_pretrained_model()
control_counts = np.array([1520, 1280, 1650, 1400])
treatment_counts = np.array([980, 890, 1100, 950])
control_lib_sizes = np.array([1e6, 1.1e6, 0.9e6, 1.05e6])
treatment_lib_sizes = np.array([1e6, 1.0e6, 1.1e6, 0.95e6])

# Transform to log10(CPM + 1)
control_transformed = np.log10(1e4 * control_counts / control_lib_sizes + 1)
treatment_transformed = np.log10(1e4 * treatment_counts / treatment_lib_sizes + 1)

# Get parameters
params = model.predict_parameters(control_transformed, treatment_transformed)

# Complete statistical inference
results = compute_nb_glm_inference(
    params['mu'], params['beta'], params['alpha'],
    control_counts, treatment_counts,
    control_lib_sizes, treatment_lib_sizes
)

print(f"Log fold change: {results['beta']:.3f} ± {results['se_beta']:.3f}")
print(f"P-value: {results['pvalue']:.2e}")
print(f"Significant: {'Yes' if results['pvalue'] < 0.05 else 'No'}")
```

### Quick Demo

```python
from nb_transformer import quick_inference_example

# Run a complete example with sample data
params = quick_inference_example()
```

## 🔬 Validation & Reproducibility

This package includes three comprehensive validation scripts that reproduce all key results.

### 1. Accuracy Validation

Compare parameter estimation accuracy and speed across methods:

```bash
python examples/validate_accuracy.py --n_tests 1000 --output_dir results/
```

**Expected Output:**

- Accuracy comparison plots
- Speed benchmarks
- Parameter estimation metrics
- Success rate analysis

### 2. P-value Calibration Validation

Validate that p-values are properly calibrated under the null hypothesis:

```bash
python examples/validate_calibration.py --n_tests 10000 --output_dir results/
```

**Expected Output:**

- QQ plots for p-value uniformity
- Statistical tests for calibration
- False positive rate analysis
- Calibration assessment report

### 3. Statistical Power Analysis

Evaluate statistical power across experimental designs and effect sizes:

```bash
python examples/validate_power.py --n_tests 1000 --output_dir results/
```

**Expected Output:**

- Power curves by experimental design (3v3, 5v5, 7v7, 9v9)
- Effect size analysis
- Method comparison across designs
- Statistical power benchmarks

## 🧮 Mathematical Foundation

### Model Architecture

NB-Transformer uses a specialized transformer architecture for set-to-set comparison (a rough sketch follows this list):

- **Input**: Two variable-length sets of log-transformed expression values
- **Architecture**: Pair-set transformer with intra-set and cross-set attention
- **Output**: Three parameters (μ, β, α) for the Negative Binomial GLM
- **Training**: 2.5M parameters trained on synthetic data with known ground truth

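For intuition, here is a minimal PyTorch sketch of one intra-set plus cross-set attention block. The class, layer names, and pooling head are illustrative assumptions, not the actual `nb_transformer` internals:

```python
import torch
import torch.nn as nn

class PairSetBlock(nn.Module):
    """Illustrative sketch of one pair-set attention block (not the real internals)."""

    def __init__(self, d_model: int = 128, n_heads: int = 8, dropout: float = 0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.norm_self = nn.LayerNorm(d_model)
        self.norm_cross = nn.LayerNorm(d_model)

    def forward(self, a: torch.Tensor, b: torch.Tensor):
        # Intra-set attention: each condition's samples attend to one another.
        a = self.norm_self(a + self.self_attn(a, a, a)[0])
        b = self.norm_self(b + self.self_attn(b, b, b)[0])
        # Cross-set attention: each set queries the other set's samples.
        a_out = self.norm_cross(a + self.cross_attn(a, b, b)[0])
        b_out = self.norm_cross(b + self.cross_attn(b, a, a)[0])
        return a_out, b_out

# Variable-length sets are natural here: 4 control vs 6 treatment samples.
block = PairSetBlock()
ctrl, trt = block(torch.randn(1, 4, 128), torch.randn(1, 6, 128))
# Stacking such blocks, mean-pooling each set, and regressing (μ, β, α)
# with a small MLP head would complete the picture.
```
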
### Statistical Inference

The model enables complete statistical inference through Fisher information (a minimal numerical sketch follows this list):

1. **Parameter Estimation**: Direct neural network prediction (μ̂, β̂, α̂)
2. **Fisher Weights**: W<sub>i</sub> = m<sub>i</sub>/(1 + φm<sub>i</sub>), where m<sub>i</sub> = ℓ<sub>i</sub>exp(μ̂ + x<sub>i</sub>β̂) and φ = exp(α̂) is the dispersion
3. **Standard Errors**: SE(β̂) = √[(X'WX)<sup>-1</sup>]<sub>ββ</sub>
4. **Wald Statistics**: W = β̂²/SE(β̂)² ~ χ²(1) under H₀: β = 0
5. **P-values**: Proper Type I error control, validated via calibration analysis

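To make steps 2-5 concrete, here is a minimal NumPy/SciPy sketch for a two-group design. It illustrates the math above and is not the packaged `compute_nb_glm_inference`; the function name and signature are assumptions:

```python
import numpy as np
from scipy import stats

def wald_inference(mu_hat, beta_hat, alpha_hat, lib_sizes, groups):
    """Sketch of Wald inference for a two-group NB GLM (illustrative only)."""
    phi = np.exp(alpha_hat)                      # dispersion (step 2)
    x = np.asarray(groups, dtype=float)          # x_i in {0, 1}
    lib = np.asarray(lib_sizes, dtype=float)
    m = lib * np.exp(mu_hat + x * beta_hat)      # fitted means m_i
    w = m / (1.0 + phi * m)                      # Fisher weights W_i
    X = np.column_stack([np.ones_like(x), x])    # design matrix [1, x]
    cov = np.linalg.inv(X.T @ (w[:, None] * X))  # (X'WX)^-1 (step 3)
    se_beta = np.sqrt(cov[1, 1])
    wald = (beta_hat / se_beta) ** 2             # Wald statistic (step 4)
    pvalue = stats.chi2.sf(wald, df=1)           # p-value (step 5)
    return se_beta, pvalue

# Hypothetical usage with estimates like those from the Quick Start:
se, p = wald_inference(-0.245, -0.421, -1.832,
                       lib_sizes=np.full(8, 1e6),
                       groups=[0, 0, 0, 0, 1, 1, 1, 1])
```
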
### Key Innovation

Unlike iterative maximum likelihood estimation, NB-Transformer learns the parameter mapping directly from data patterns, enabling:

- **Instant inference** without convergence issues
- **Robust parameter estimation** across challenging scenarios
- **Full statistical validity** through the Fisher information framework

## 📊 Comprehensive Validation Results

### Accuracy Across Parameter Types

| Parameter | NB-Transformer | Classical GLM | Improvement |
|-----------|----------------|---------------|-------------|
| μ (base mean) | 0.202 MAE | 0.212 MAE | **5% better** |
| β (log fold change) | **0.152 MAE** | 0.284 MAE | **47% better** |
| α (dispersion) | **0.477 MAE** | 0.854 MAE | **44% better** |

### Statistical Power Analysis

Power analysis across experimental designs shows competitive performance; power is estimated empirically, as sketched after the table:

| Design | Effect Size β=1.0 | Effect Size β=2.0 |
|--------|-------------------|-------------------|
| 3v3 samples | 85% power | 99% power |
| 5v5 samples | 92% power | >99% power |
| 7v7 samples | 96% power | >99% power |
| 9v9 samples | 98% power | >99% power |

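Empirical power is simply the fraction of truly differential simulations that reach significance. A sketch of that calculation (the helper name is illustrative, not the validation script's API):

```python
import numpy as np

def empirical_power(pvalues_under_alternative, alpha=0.05):
    """Fraction of simulated non-null tests with p < alpha."""
    return float(np.mean(np.asarray(pvalues_under_alternative) < alpha))

# Hypothetical p-values from simulations with a true effect (e.g. β = 1.0):
power = empirical_power([0.001, 0.03, 0.2, 0.004, 0.01])  # -> 0.8
```
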
### P-value Calibration

Rigorous calibration validation confirms proper statistical inference (the snippet below illustrates the uniformity check):

- **Kolmogorov-Smirnov test**: p = 0.127 (well-calibrated)
- **Anderson-Darling test**: p = 0.089 (well-calibrated)
- **False positive rate**: 5.1% at α = 0.05 (properly controlled)

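As an illustration of the uniformity check performed by `examples/validate_calibration.py` (the p-values below are simulated stand-ins, not real outputs):

```python
import numpy as np
from scipy import stats

# Stand-in for p-values collected from null simulations (β = 0).
rng = np.random.default_rng(0)
null_pvalues = rng.uniform(size=10_000)

# Under a well-calibrated test, null p-values are Uniform(0, 1).
ks_stat, ks_pvalue = stats.kstest(null_pvalues, "uniform")
false_positive_rate = np.mean(null_pvalues < 0.05)  # should be near 0.05
print(f"KS p = {ks_pvalue:.3f}, FPR at α = 0.05: {false_positive_rate:.3f}")
```
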
## 🏗️ Architecture Details

### Model Specifications

- **Model Type**: Pair-set transformer for NB GLM parameter estimation
- **Parameters**: 2.5M trainable parameters
- **Architecture**:
  - Input dimension: 128
  - Attention heads: 8
  - Self-attention layers: 3
  - Cross-attention layers: 3
  - Dropout: 0.1
- **Training**: Synthetic data with online generation
- **Validation Loss**: 0.4628 (v13 checkpoint)

### Input/Output Specification

- **Input**: Two lists of log10(CPM + 1) transformed expression values
- **Output**: Dictionary with keys 'mu', 'beta', 'alpha' (all on the log scale)
- **Sample Size**: Handles 2-20 samples per condition (variable length)
- **Expression Range**: Optimized for typical RNA-seq expression levels

## 🔧 Advanced Usage

### Custom Model Loading

```python
from nb_transformer import load_pretrained_model

# Load the model on a specific device
model = load_pretrained_model(device='cuda')  # or 'cpu', 'mps'

# Load a custom checkpoint
model = load_pretrained_model(checkpoint_path='path/to/custom.ckpt')
```

### Batch Processing

```python
# Process multiple gene comparisons efficiently
from nb_transformer.method_of_moments import estimate_batch_parameters_vectorized

control_sets = [[2.1, 1.8, 2.3], [1.9, 2.2, 1.7]]  # multiple genes
treatment_sets = [[1.5, 1.2, 1.7], [2.1, 2.4, 1.9]]

# Fast batch estimation
results = estimate_batch_parameters_vectorized(control_sets, treatment_sets)
```

### Training Custom Models

```python
from nb_transformer import train_dispersion_transformer, ParameterDistributions

# Define custom parameter distributions
param_dist = ParameterDistributions()
param_dist.mu_params = {'loc': -1.0, 'scale': 2.0}
param_dist.alpha_params = {'mean': -2.0, 'std': 1.0}
param_dist.beta_params = {'prob_de': 0.3, 'std': 1.0}

# Training configuration
config = {
    'model_config': {
        'd_model': 128,
        'n_heads': 8,
        'num_self_layers': 3,
        'num_cross_layers': 3,
        'dropout': 0.1
    },
    'batch_size': 512,
    'max_epochs': 20,
    'examples_per_epoch': 100000,
    'parameter_distributions': param_dist
}

# Train the model
results = train_dispersion_transformer(config)
```

## 📋 Requirements

### Core Dependencies

- Python ≥ 3.8
- PyTorch ≥ 1.10.0
- PyTorch Lightning ≥ 1.8.0
- NumPy ≥ 1.21.0
- SciPy ≥ 1.7.0

### Optional Dependencies

- **Validation**: `statsmodels`, `pandas`, `matplotlib`, `scikit-learn`
- **Visualization**: `plotnine`, `theme-nxn` (custom plotting theme)
- **Development**: `pytest`, `flake8`, `black`, `mypy`

## 🧪 Model Training Details

### Training Data

- **Synthetic Generation**: Online negative binomial data generation (see the sketch after this list)
- **Parameter Distributions**: Based on empirical RNA-seq statistics
- **Sample Sizes**: Variable, 2-10 samples per condition
- **Expression Levels**: Realistic RNA-seq dynamic range
- **Library Sizes**: Log-normal distribution (CV ~30%)

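As a sketch of this online generation, one training example might be drawn as follows. The distribution hyperparameters mirror the `ParameterDistributions` defaults shown earlier; the scales are illustrative and the published training setup may differ:

```python
import numpy as np

rng = np.random.default_rng(42)

def sample_training_example():
    """Draw one synthetic NB GLM training example (illustrative sketch only)."""
    n_ctrl, n_trt = rng.integers(2, 11, size=2)   # 2-10 samples per condition
    mu = rng.normal(loc=-1.0, scale=2.0)          # log base mean
    alpha = rng.normal(loc=-2.0, scale=1.0)       # log dispersion
    # A fraction of genes are differentially expressed (prob_de = 0.3).
    beta = rng.normal(0.0, 1.0) if rng.random() < 0.3 else 0.0
    phi = np.exp(alpha)
    # Library sizes: log-normal with ~30% CV.
    lib = rng.lognormal(mean=np.log(1e6), sigma=0.3, size=n_ctrl + n_trt)
    x = np.r_[np.zeros(n_ctrl), np.ones(n_trt)]
    m = lib * np.exp(mu + x * beta)               # per-sample NB means
    # NB with mean m and dispersion phi: var = m + phi * m^2.
    counts = rng.negative_binomial(n=1.0 / phi, p=1.0 / (1.0 + phi * m))
    inputs = np.log10(1e4 * counts / lib + 1)     # same transform as inference
    return inputs[:n_ctrl], inputs[n_ctrl:], (mu, beta, alpha)
```
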
### Training Process

- **Epochs**: 100
- **Batch Size**: 32
- **Learning Rate**: 1e-4 with a ReduceLROnPlateau scheduler
- **Loss Function**: Multi-task MSE loss with parameter-specific weights (sketched after this list)
- **Validation**: Held-out synthetic data with different parameter seeds

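A minimal sketch of such a multi-task loss; the weights shown are placeholders, not the values used for the released checkpoint:

```python
import torch

def multitask_mse(pred: torch.Tensor, target: torch.Tensor,
                  weights=(1.0, 1.0, 1.0)) -> torch.Tensor:
    """Weighted MSE over (μ, β, α) predictions of shape (batch, 3)."""
    w = torch.as_tensor(weights, dtype=pred.dtype, device=pred.device)
    return ((pred - target) ** 2 * w).mean()
```
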
### Hardware Optimization

- **Apple Silicon**: Optimized for MPS (Metal Performance Shaders)
- **Multi-core CPU**: Efficient multi-worker data generation
- **Memory Usage**: Minimal memory footprint (~100MB model)
- **Inference Speed**: Single-core CPU is sufficient for real-time analysis

## 🤝 Contributing

We welcome contributions! Please see our contributing guidelines:

1. **Bug Reports**: Open issues with detailed reproduction steps
2. **Feature Requests**: Propose new functionality with use cases
3. **Code Contributions**: Fork, develop, and submit pull requests
4. **Validation**: Run the validation scripts to ensure reproducibility
5. **Documentation**: Improve examples and documentation

### Development Setup

```bash
git clone https://huggingface.co/valsv/nb-transformer
cd nb-transformer
pip install -e ".[dev,analysis]"

# Run tests
pytest tests/

# Run validation
python examples/validate_accuracy.py --n_tests 100
```

## 📖 Citation

If you use NB-Transformer in your research, please cite:

```bibtex
@software{svensson2025nbtransformer,
  title={NB-Transformer: Fast Negative Binomial GLM Parameter Estimation using Transformers},
  author={Svensson, Valentine},
  year={2025},
  url={https://huggingface.co/valsv/nb-transformer},
  version={1.0.0}
}
```

## 📚 Related Work

### Transformer Applications in Biology

- **Set-based Learning**: Zaheer et al. (2017). Deep Sets. *NIPS*.
- **Attention Mechanisms**: Vaswani et al. (2017). Attention Is All You Need. *NIPS*.
- **Biological Applications**: Rives et al. (2021). Biological structure and function emerge from scaling unsupervised learning to 250 million protein sequences. *PNAS*.

## ⚖️ License

MIT License - see the [LICENSE](LICENSE) file for details.

## 🏷️ Version History

### v1.0.0 (2025-08-04)

- **Initial release** with the pre-trained v13 model
- **Complete validation suite** (accuracy, calibration, power)
- **Production-ready API** with comprehensive documentation
- **Hugging Face integration** for easy model distribution

---

**🚀 Ready to revolutionize your differential expression analysis? Install NB-Transformer today!**

```bash
pip install nb-transformer
```

For questions, issues, or contributions, visit our [Hugging Face repository](https://huggingface.co/valsv/nb-transformer) or open an issue.