stivenDR14 commited on
Commit Β·
5c8d855
1
Parent(s): ae932ad
feat: Introduce audio captioning and categorization model with ONNX/ExecuTorch hybrid inference and category embedding generation.
Browse files- .gitignore +56 -0
- README.md +212 -3
- audio-caption/effb2_decoder_5sec.pte +3 -0
- audio-caption/effb2_encoder_preprocess-2.onnx +3 -0
- audio-caption/export_decoder_executorch.py +243 -0
- audio-caption/export_encoder_preprocess_onnx.py +201 -0
- audio-caption/generate_caption_hybrid.py +130 -0
- categories.json +54 -0
- pyproject.toml +25 -0
- sentence-transformers-embbedings/category_embeddings.json +0 -0
- sentence-transformers-embbedings/export_sentence_transformers_executorch.py +138 -0
- sentence-transformers-embbedings/generate_category_embeddings.py +101 -0
- sentence-transformers-embbedings/sentence_transformers_minilm.pte +3 -0
- uv.lock +0 -0
.gitignore
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# System files
|
| 2 |
+
.DS_Store
|
| 3 |
+
.DS_Store?
|
| 4 |
+
._*
|
| 5 |
+
.Spotlight-V100
|
| 6 |
+
.Trashes
|
| 7 |
+
ehthumbs.db
|
| 8 |
+
Thumbs.db
|
| 9 |
+
|
| 10 |
+
# Environment variables
|
| 11 |
+
.env
|
| 12 |
+
.env.local
|
| 13 |
+
.env.*.local
|
| 14 |
+
|
| 15 |
+
# Python
|
| 16 |
+
__pycache__/
|
| 17 |
+
*.py[cod]
|
| 18 |
+
*$py.class
|
| 19 |
+
.venv/
|
| 20 |
+
venv/
|
| 21 |
+
ENV/
|
| 22 |
+
env/
|
| 23 |
+
.Python
|
| 24 |
+
build/
|
| 25 |
+
develop-eggs/
|
| 26 |
+
dist/
|
| 27 |
+
downloads/
|
| 28 |
+
eggs/
|
| 29 |
+
.eggs/
|
| 30 |
+
lib/
|
| 31 |
+
lib64/
|
| 32 |
+
parts/
|
| 33 |
+
sdist/
|
| 34 |
+
var/
|
| 35 |
+
wheels/
|
| 36 |
+
share/python-wheels/
|
| 37 |
+
*.egg-info/
|
| 38 |
+
.installed.cfg
|
| 39 |
+
*.egg
|
| 40 |
+
MANIFEST
|
| 41 |
+
_temp/
|
| 42 |
+
|
| 43 |
+
# Testing and Code Quality
|
| 44 |
+
.mypy_cache/
|
| 45 |
+
.pytest_cache/
|
| 46 |
+
.coverage
|
| 47 |
+
htmlcov/
|
| 48 |
+
.tox/
|
| 49 |
+
.nox/
|
| 50 |
+
|
| 51 |
+
# IDEs
|
| 52 |
+
.idea/
|
| 53 |
+
.vscode/
|
| 54 |
+
*.swp
|
| 55 |
+
*.swo
|
| 56 |
+
|
README.md
CHANGED
|
@@ -1,3 +1,212 @@
|
|
| 1 |
-
---
|
| 2 |
-
license: apache-2.0
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: apache-2.0
|
| 3 |
+
tags:
|
| 4 |
+
- audio
|
| 5 |
+
- audio-classification
|
| 6 |
+
- audio-captioning
|
| 7 |
+
- onnx
|
| 8 |
+
- executorch
|
| 9 |
+
- mobile
|
| 10 |
+
- arm
|
| 11 |
+
language:
|
| 12 |
+
- en
|
| 13 |
+
pipeline_tag: audio-classification
|
| 14 |
+
---
|
| 15 |
+
|
| 16 |
+
# Audio Caption and Categorizer Models
|
| 17 |
+
|
| 18 |
+
## Model Description
|
| 19 |
+
|
| 20 |
+
This repository provides **optimized exports** of audio captioning and categorization models for **ARM-based mobile deployment**. The pipeline consists of:
|
| 21 |
+
|
| 22 |
+
1. **Audio Captioning**: Uses [`wsntxxn/effb2-trm-audiocaps-captioning`](https://huggingface.co/wsntxxn/effb2-trm-audiocaps-captioning) (EfficientNet-B2 encoder + Transformer decoder) to generate natural language descriptions of audio events.
|
| 23 |
+
|
| 24 |
+
2. **Audio Categorization**: Uses [`sentence-transformers/all-MiniLM-L6-v2`](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) to match generated captions to predefined sound categories via semantic similarity.
|
| 25 |
+
|
| 26 |
+
### Export Formats
|
| 27 |
+
- **Encoder**: ONNX format with integrated preprocessing (STFT, MelSpectrogram, AmplitudeToDB)
|
| 28 |
+
- **Decoder**: ExecuTorch (`.pte`) format with dynamic quantization for reduced model size
|
| 29 |
+
- **Categorizer**: ExecuTorch (`.pte`) format with quantization
|
| 30 |
+
|
| 31 |
+
### Key Features
|
| 32 |
+
- 5-second audio input at 16kHz
|
| 33 |
+
- Preprocessing baked into ONNX encoder (no external audio processing needed)
|
| 34 |
+
- Optimized for mobile inference with quantization
|
| 35 |
+
- Complete end-to-end pipeline from raw audio to categorized captions
|
| 36 |
+
|
| 37 |
+
## Usage
|
| 38 |
+
|
| 39 |
+
### Quick Start
|
| 40 |
+
|
| 41 |
+
Generate a caption for an audio file:
|
| 42 |
+
|
| 43 |
+
```bash
|
| 44 |
+
# Activate environment
|
| 45 |
+
source .venv/bin/activate
|
| 46 |
+
|
| 47 |
+
# Generate caption
|
| 48 |
+
python audio-caption/generate_caption_hybrid.py --audio sample_audio.wav
|
| 49 |
+
```
|
| 50 |
+
|
| 51 |
+
### Python Example
|
| 52 |
+
|
| 53 |
+
```python
|
| 54 |
+
import onnxruntime as ort
|
| 55 |
+
from executorch.extension.pybindings.portable_lib import _load_for_executorch
|
| 56 |
+
from transformers import AutoTokenizer
|
| 57 |
+
import numpy as np
|
| 58 |
+
|
| 59 |
+
# Load models
|
| 60 |
+
encoder_session = ort.InferenceSession("audio-caption/effb2_encoder_preprocess.onnx")
|
| 61 |
+
decoder = _load_for_executorch("audio-caption/effb2_decoder_5sec.pte")
|
| 62 |
+
tokenizer = AutoTokenizer.from_pretrained("wsntxxn/audiocaps-simple-tokenizer", trust_remote_code=True)
|
| 63 |
+
|
| 64 |
+
# Process audio (16kHz, 5 seconds = 80000 samples)
|
| 65 |
+
audio = np.random.randn(1, 80000).astype(np.float32)
|
| 66 |
+
|
| 67 |
+
# Encode
|
| 68 |
+
attn_emb = encoder_session.run(["attn_emb"], {"audio": audio})[0]
|
| 69 |
+
|
| 70 |
+
# Decode (greedy search)
|
| 71 |
+
generated = [tokenizer.bos_token_id]
|
| 72 |
+
for _ in range(30):
|
| 73 |
+
logits = decoder.forward((
|
| 74 |
+
torch.tensor([generated]),
|
| 75 |
+
torch.tensor(attn_emb),
|
| 76 |
+
torch.tensor([attn_emb.shape[1] - 1])
|
| 77 |
+
))[0]
|
| 78 |
+
next_token = int(torch.argmax(logits[0, -1, :]))
|
| 79 |
+
generated.append(next_token)
|
| 80 |
+
if next_token == tokenizer.eos_token_id:
|
| 81 |
+
break
|
| 82 |
+
|
| 83 |
+
caption = tokenizer.decode(generated, skip_special_tokens=True)
|
| 84 |
+
print(caption)
|
| 85 |
+
```
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
## Training Details
|
| 90 |
+
|
| 91 |
+
### Base Models
|
| 92 |
+
|
| 93 |
+
This repository does **not train models** but exports pre-trained models to optimized formats:
|
| 94 |
+
|
| 95 |
+
| Component | Base Model | Training Dataset | Parameters |
|
| 96 |
+
|-----------|------------|------------------|------------|
|
| 97 |
+
| Audio Encoder | EfficientNet-B2 | AudioCaps | ~7.7M |
|
| 98 |
+
| Caption Decoder | Transformer (2 layers) | AudioCaps | ~4.3M |
|
| 99 |
+
| Categorizer | all-MiniLM-L6-v2 | 1B+ sentence pairs | ~22.7M |
|
| 100 |
+
|
| 101 |
+
### Export Configuration
|
| 102 |
+
|
| 103 |
+
**Audio Captioning**:
|
| 104 |
+
- **Preprocessing**: `n_mels=64`, `n_fft=512`, `hop_length=160`, `win_length=512`
|
| 105 |
+
- **Input**: Raw audio waveform (16kHz, 5 seconds)
|
| 106 |
+
- **Encoder**: ONNX opset 17 with dynamic axes
|
| 107 |
+
- **Decoder**: ExecuTorch with dynamic quantization (int8)
|
| 108 |
+
|
| 109 |
+
**Categorizer**:
|
| 110 |
+
- **Tokenizer**: RoBERTa-based (max length: 128)
|
| 111 |
+
- **Export**: ExecuTorch with dynamic quantization
|
| 112 |
+
- **Categories**: 50+ predefined audio event categories
|
| 113 |
+
|
| 114 |
+
### Quantization Impact
|
| 115 |
+
|
| 116 |
+
| Model | Original Size | Quantized Size | Quality Impact |
|
| 117 |
+
|-------|---------------|----------------|----------------|
|
| 118 |
+
| Decoder | ~17MB | ~15MB | Minimal (<2% caption quality) |
|
| 119 |
+
| Categorizer | ~90MB | ~23MB | Minimal (<1% accuracy) |
|
| 120 |
+
|
| 121 |
+
## Project Structure
|
| 122 |
+
|
| 123 |
+
```
|
| 124 |
+
.
|
| 125 |
+
βββ audio-caption/
|
| 126 |
+
β βββ export_encoder_preprocess_onnx.py # Export ONNX encoder
|
| 127 |
+
β βββ export_decoder_executorch.py # Export ExecuTorch decoder
|
| 128 |
+
β βββ generate_caption_hybrid.py # Inference pipeline
|
| 129 |
+
β βββ effb2_encoder_preprocess.onnx # Exported encoder
|
| 130 |
+
β βββ effb2_decoder_5sec.pte # Exported decoder
|
| 131 |
+
β
|
| 132 |
+
βββ sentence-transformers-embbedings/
|
| 133 |
+
β βββ export_sentence_transformers_executorch.py
|
| 134 |
+
β βββ generate_category_embeddings.py
|
| 135 |
+
β βββ category_embeddings.json
|
| 136 |
+
β
|
| 137 |
+
βββ categories.json # Category definitions
|
| 138 |
+
```
|
| 139 |
+
|
| 140 |
+
## Setup
|
| 141 |
+
|
| 142 |
+
### Prerequisites
|
| 143 |
+
|
| 144 |
+
```bash
|
| 145 |
+
# Install uv package manager
|
| 146 |
+
pip install uv
|
| 147 |
+
|
| 148 |
+
# Create environment
|
| 149 |
+
uv venv
|
| 150 |
+
source .venv/bin/activate
|
| 151 |
+
|
| 152 |
+
# Install dependencies
|
| 153 |
+
uv pip install -r pyproject.toml
|
| 154 |
+
```
|
| 155 |
+
|
| 156 |
+
### Configuration
|
| 157 |
+
|
| 158 |
+
Create a `.env` file:
|
| 159 |
+
|
| 160 |
+
```ini
|
| 161 |
+
# Hugging Face Token (for gated models)
|
| 162 |
+
HF_TOKEN=your_token_here
|
| 163 |
+
|
| 164 |
+
# Optional: Custom cache directory
|
| 165 |
+
# HF_HOME=./.cache/huggingface
|
| 166 |
+
```
|
| 167 |
+
|
| 168 |
+
### Export Models
|
| 169 |
+
|
| 170 |
+
```bash
|
| 171 |
+
# Export audio captioning models
|
| 172 |
+
python audio-caption/export_encoder_preprocess_onnx.py
|
| 173 |
+
python audio-caption/export_decoder_executorch.py
|
| 174 |
+
|
| 175 |
+
# Export categorization model
|
| 176 |
+
python sentence-transformers-embbedings/export_sentence_transformers_executorch.py
|
| 177 |
+
|
| 178 |
+
# Generate category embeddings
|
| 179 |
+
python sentence-transformers-embbedings/generate_category_embeddings.py
|
| 180 |
+
```
|
| 181 |
+
|
| 182 |
+
## License
|
| 183 |
+
|
| 184 |
+
Apache License 2.0
|
| 185 |
+
|
| 186 |
+
## Citations
|
| 187 |
+
|
| 188 |
+
### Audio Captioning Model
|
| 189 |
+
|
| 190 |
+
```bibtex
|
| 191 |
+
@inproceedings{xu2024efficient,
|
| 192 |
+
title={Efficient Audio Captioning with Encoder-Level Knowledge Distillation},
|
| 193 |
+
author={Xu, Xuenan and Liu, Haohe and Wu, Mengyue and Wang, Wenwu and Plumbley, Mark D.},
|
| 194 |
+
booktitle={Interspeech 2024},
|
| 195 |
+
year={2024},
|
| 196 |
+
doi={10.48550/arXiv.2407.14329},
|
| 197 |
+
url={https://arxiv.org/abs/2407.14329}
|
| 198 |
+
}
|
| 199 |
+
```
|
| 200 |
+
|
| 201 |
+
### Sentence Transformer
|
| 202 |
+
|
| 203 |
+
```bibtex
|
| 204 |
+
@inproceedings{reimers-2019-sentence-bert,
|
| 205 |
+
title = "Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks",
|
| 206 |
+
author = "Reimers, Nils and Gurevych, Iryna",
|
| 207 |
+
booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing",
|
| 208 |
+
year = "2019",
|
| 209 |
+
publisher = "Association for Computational Linguistics",
|
| 210 |
+
url = "https://arxiv.org/abs/1908.10084",
|
| 211 |
+
}
|
| 212 |
+
```
|
audio-caption/effb2_decoder_5sec.pte
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:319fbb6363ba11fa13b2e0a2bc7b97cdc8526208cfa79a1cc7a65b6f683a91d0
|
| 3 |
+
size 15144068
|
audio-caption/effb2_encoder_preprocess-2.onnx
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ae814c75c799de5717308ad63672f282619d021f2f394c84aaf264044bb298bf
|
| 3 |
+
size 30925938
|
audio-caption/export_decoder_executorch.py
ADDED
|
@@ -0,0 +1,243 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Export decoder to ExecuTorch .pte format as an alternative to ONNX.
|
| 3 |
+
This might handle dynamic sequence lengths better.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import argparse
|
| 8 |
+
from transformers import AutoModel, AutoTokenizer
|
| 9 |
+
from dotenv import load_dotenv
|
| 10 |
+
|
| 11 |
+
load_dotenv()
|
| 12 |
+
|
| 13 |
+
def main():
|
| 14 |
+
parser = argparse.ArgumentParser()
|
| 15 |
+
parser.add_argument("--model", default="wsntxxn/effb2-trm-audiocaps-captioning")
|
| 16 |
+
parser.add_argument("--out", default="effb2_decoder_step.pte")
|
| 17 |
+
args = parser.parse_args()
|
| 18 |
+
|
| 19 |
+
print(f"Loading model: {args.model}")
|
| 20 |
+
model = AutoModel.from_pretrained(args.model, trust_remote_code=True)
|
| 21 |
+
model.eval()
|
| 22 |
+
|
| 23 |
+
# Get decoder - navigate through the model structure
|
| 24 |
+
# Based on inspection: model.model.model.decoder
|
| 25 |
+
if hasattr(model, "model") and hasattr(model.model, "model") and hasattr(model.model.model, "decoder"):
|
| 26 |
+
decoder = model.model.model.decoder
|
| 27 |
+
encoder = model.model.model.encoder
|
| 28 |
+
print(f"Found decoder at model.model.model.decoder")
|
| 29 |
+
elif hasattr(model, "model") and hasattr(model.model, "decoder"):
|
| 30 |
+
decoder = model.model.decoder
|
| 31 |
+
encoder = model.model.encoder
|
| 32 |
+
print(f"Found decoder at model.model.decoder")
|
| 33 |
+
else:
|
| 34 |
+
# Try to find by iterating
|
| 35 |
+
for name, module in model.named_modules():
|
| 36 |
+
if "decoder" in name.lower() and "TransformerDecoder" in module.__class__.__name__:
|
| 37 |
+
decoder = module
|
| 38 |
+
print(f"Found decoder at {name}")
|
| 39 |
+
break
|
| 40 |
+
else:
|
| 41 |
+
raise RuntimeError("Could not find decoder in model")
|
| 42 |
+
|
| 43 |
+
print(f"Decoder: {decoder.__class__.__name__}")
|
| 44 |
+
|
| 45 |
+
# Wrap decoder similar to ONNX version
|
| 46 |
+
class DecoderStepWrapper(torch.nn.Module):
|
| 47 |
+
def __init__(self, decoder, vocab_size):
|
| 48 |
+
super().__init__()
|
| 49 |
+
self.decoder = decoder
|
| 50 |
+
self.vocab_size = vocab_size
|
| 51 |
+
|
| 52 |
+
def forward(self, word_ids, attn_emb, attn_emb_len):
|
| 53 |
+
"""
|
| 54 |
+
Args:
|
| 55 |
+
word_ids: (batch, seq_len)
|
| 56 |
+
attn_emb: (batch, time, dim)
|
| 57 |
+
attn_emb_len: (batch,)
|
| 58 |
+
Returns:
|
| 59 |
+
logits: (batch, seq_len, vocab_size)
|
| 60 |
+
"""
|
| 61 |
+
import math
|
| 62 |
+
|
| 63 |
+
# Replicate the custom decoder's forward logic
|
| 64 |
+
p_attn_emb = self.decoder.attn_proj(attn_emb)
|
| 65 |
+
p_attn_emb = p_attn_emb.transpose(0, 1) # [time, batch, dim]
|
| 66 |
+
|
| 67 |
+
embed = self.decoder.word_embedding(word_ids)
|
| 68 |
+
emb_dim = getattr(self.decoder, "emb_dim", 256)
|
| 69 |
+
embed = self.decoder.in_dropout(embed) * math.sqrt(emb_dim)
|
| 70 |
+
embed = embed.transpose(0, 1) # [seq, batch, dim]
|
| 71 |
+
embed = self.decoder.pos_encoder(embed)
|
| 72 |
+
|
| 73 |
+
# 5. Masks
|
| 74 |
+
# CRITICAL: Create causal mask without NaN
|
| 75 |
+
# Don't use ones * inf because 0 * inf = NaN!
|
| 76 |
+
seq_len = embed.size(0)
|
| 77 |
+
|
| 78 |
+
# Create causal mask: 0 on and below diagonal, -inf above diagonal
|
| 79 |
+
# Start with zeros, then mask_fill the upper triangle
|
| 80 |
+
tgt_mask = torch.zeros(seq_len, seq_len, device=embed.device, dtype=torch.float32)
|
| 81 |
+
if seq_len > 1:
|
| 82 |
+
tgt_mask = tgt_mask.masked_fill(
|
| 83 |
+
torch.triu(torch.ones(seq_len, seq_len, device=embed.device), diagonal=1).bool(),
|
| 84 |
+
float('-inf')
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
# memory_key_padding_mask
|
| 88 |
+
batch_size = attn_emb.shape[0]
|
| 89 |
+
max_len = attn_emb.shape[1]
|
| 90 |
+
|
| 91 |
+
# Create range [0, 1, ..., max_len-1]
|
| 92 |
+
arange = torch.arange(max_len, device=attn_emb.device).unsqueeze(0).expand(batch_size, -1)
|
| 93 |
+
# Mask is True where arange >= length
|
| 94 |
+
memory_key_padding_mask = arange >= attn_emb_len.unsqueeze(1)
|
| 95 |
+
|
| 96 |
+
# tgt_key_padding_mask (cap_padding_mask)
|
| 97 |
+
# For generation, we assume no padding in word_ids (all valid)
|
| 98 |
+
tgt_key_padding_mask = torch.zeros(word_ids.shape[0], word_ids.shape[1], dtype=torch.bool, device=word_ids.device)
|
| 99 |
+
|
| 100 |
+
# 6. Inner Decoder Call
|
| 101 |
+
# Pass BOTH the mask AND is_causal=True
|
| 102 |
+
# Do NOT call generate_square_subsequent_mask as it might have detection logic
|
| 103 |
+
output = self.decoder.model(
|
| 104 |
+
embed,
|
| 105 |
+
p_attn_emb,
|
| 106 |
+
tgt_mask=tgt_mask, # Static causal mask
|
| 107 |
+
tgt_is_causal=True, # Hint for optimization
|
| 108 |
+
tgt_key_padding_mask=tgt_key_padding_mask,
|
| 109 |
+
memory_key_padding_mask=memory_key_padding_mask
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
output = output.transpose(0, 1) # [batch, seq, dim]
|
| 113 |
+
logits = self.decoder.classifier(output)
|
| 114 |
+
|
| 115 |
+
return logits
|
| 116 |
+
|
| 117 |
+
# Get vocab size
|
| 118 |
+
tokenizer = AutoTokenizer.from_pretrained("wsntxxn/audiocaps-simple-tokenizer", trust_remote_code=True)
|
| 119 |
+
vocab_size = len(tokenizer)
|
| 120 |
+
|
| 121 |
+
# Create wrapper
|
| 122 |
+
wrapper = DecoderStepWrapper(decoder, vocab_size)
|
| 123 |
+
wrapper.eval()
|
| 124 |
+
|
| 125 |
+
# Test with dummy input
|
| 126 |
+
device = torch.device("cpu")
|
| 127 |
+
wrapper = wrapper.to(device)
|
| 128 |
+
|
| 129 |
+
# Get encoder output for attn_emb
|
| 130 |
+
# Use the existing ONNX encoder to avoid HF encoder complications
|
| 131 |
+
print("\nLoading ONNX encoder to get attn_emb...")
|
| 132 |
+
import onnxruntime as ort
|
| 133 |
+
import numpy as np
|
| 134 |
+
|
| 135 |
+
encoder_onnx_path = "audio-caption/effb2_encoder_preprocess.onnx"
|
| 136 |
+
enc_sess = ort.InferenceSession(encoder_onnx_path)
|
| 137 |
+
|
| 138 |
+
# Create exactly 5 seconds of audio (production use case)
|
| 139 |
+
sample_rate = 16000
|
| 140 |
+
dummy_audio_np = np.random.randn(1, sample_rate * 5).astype(np.float32)
|
| 141 |
+
enc_in_name = enc_sess.get_inputs()[0].name
|
| 142 |
+
enc_out_name = enc_sess.get_outputs()[0].name
|
| 143 |
+
|
| 144 |
+
attn_emb_np = enc_sess.run([enc_out_name], {enc_in_name: dummy_audio_np})[0]
|
| 145 |
+
attn_emb = torch.from_numpy(attn_emb_np)
|
| 146 |
+
attn_emb_len = torch.tensor([attn_emb.shape[1] - 1], dtype=torch.int64)
|
| 147 |
+
|
| 148 |
+
print(f"attn_emb shape for 5-sec audio: {attn_emb.shape}")
|
| 149 |
+
|
| 150 |
+
# Try exporting with variable sequence length
|
| 151 |
+
# Start with seq_len=1, then test with seq_len=5
|
| 152 |
+
for seq_len in [1, 5]:
|
| 153 |
+
print(f"\n--- Testing with seq_len={seq_len} ---")
|
| 154 |
+
dummy_input_ids = torch.randint(0, vocab_size, (1, seq_len), dtype=torch.long)
|
| 155 |
+
|
| 156 |
+
with torch.no_grad():
|
| 157 |
+
test_out = wrapper(dummy_input_ids, attn_emb, attn_emb_len)
|
| 158 |
+
print(f"β
Forward pass successful! Output shape: {test_out.shape}")
|
| 159 |
+
|
| 160 |
+
# Now try to export with dynamic shapes using torch.export
|
| 161 |
+
print("\n--- Attempting ExecuTorch Export ---")
|
| 162 |
+
|
| 163 |
+
try:
|
| 164 |
+
from executorch.exir import to_edge
|
| 165 |
+
from torch.export import export, Dim
|
| 166 |
+
|
| 167 |
+
# Define dynamic dimensions following PyTorch's suggestions
|
| 168 |
+
# batch is always 1 for mobile inference (PyTorch detected this)
|
| 169 |
+
# seq can vary from 1 to max_seq_len
|
| 170 |
+
seq = Dim("seq", max=100)
|
| 171 |
+
|
| 172 |
+
dynamic_shapes = {
|
| 173 |
+
"word_ids": {1: seq}, # Only seq dim is dynamic
|
| 174 |
+
"attn_emb": {}, # No dynamic dims (batch=1, time is fixed per audio)
|
| 175 |
+
"attn_emb_len": {}, # Scalar-like
|
| 176 |
+
}
|
| 177 |
+
|
| 178 |
+
# Export with a mid-range example (seq_len=3) to show it's variable
|
| 179 |
+
example_inputs = (
|
| 180 |
+
torch.randint(0, vocab_size, (1, 3), dtype=torch.long),
|
| 181 |
+
attn_emb,
|
| 182 |
+
attn_emb_len
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
print("Exporting with torch.export (seq_len=3 example)...")
|
| 186 |
+
exported_program = export(
|
| 187 |
+
wrapper,
|
| 188 |
+
example_inputs,
|
| 189 |
+
dynamic_shapes=dynamic_shapes
|
| 190 |
+
)
|
| 191 |
+
|
| 192 |
+
print("β
torch.export successful!")
|
| 193 |
+
print("Converting to ExecuTorch edge dialect...")
|
| 194 |
+
|
| 195 |
+
edge_program = to_edge(exported_program)
|
| 196 |
+
print("β
Edge conversion successful!")
|
| 197 |
+
|
| 198 |
+
# Save as .pte
|
| 199 |
+
with open(args.out, 'wb') as f:
|
| 200 |
+
edge_program.to_executorch().write_to_file(f)
|
| 201 |
+
print(f"β
ExecuTorch export done: {args.out}")
|
| 202 |
+
|
| 203 |
+
print("\nπ This .pte model supports dynamic sequence lengths!")
|
| 204 |
+
print(" You can pass (batch, 1), (batch, 2), ..., (batch, 30) at inference")
|
| 205 |
+
|
| 206 |
+
except ImportError:
|
| 207 |
+
print("β ExecuTorch not installed. Install with:")
|
| 208 |
+
print(" pip install executorch")
|
| 209 |
+
except Exception as e:
|
| 210 |
+
print(f"β ExecuTorch export failed: {e}")
|
| 211 |
+
import traceback
|
| 212 |
+
traceback.print_exc()
|
| 213 |
+
print("\nFalling back to regular torch.export (no ExecuTorch)")
|
| 214 |
+
|
| 215 |
+
# Try just torch.export to see if that works
|
| 216 |
+
try:
|
| 217 |
+
from torch.export import export, Dim
|
| 218 |
+
|
| 219 |
+
batch = Dim("batch", min=1, max=4)
|
| 220 |
+
seq = Dim("seq", min=1, max=30)
|
| 221 |
+
time = Dim("time", min=1, max=100)
|
| 222 |
+
|
| 223 |
+
dynamic_shapes = {
|
| 224 |
+
"word_ids": {0: batch, 1: seq},
|
| 225 |
+
"attn_emb": {0: batch, 1: time},
|
| 226 |
+
"attn_emb_len": {0: batch},
|
| 227 |
+
}
|
| 228 |
+
|
| 229 |
+
example_inputs = (
|
| 230 |
+
torch.randint(0, vocab_size, (1, 1), dtype=torch.long),
|
| 231 |
+
attn_emb,
|
| 232 |
+
attn_emb_len
|
| 233 |
+
)
|
| 234 |
+
|
| 235 |
+
exported_program = export(wrapper, example_inputs, dynamic_shapes=dynamic_shapes)
|
| 236 |
+
print("β
torch.export successful (without ExecuTorch conversion)")
|
| 237 |
+
print(" Dynamic shapes are supported in the exported graph")
|
| 238 |
+
|
| 239 |
+
except Exception as e2:
|
| 240 |
+
print(f"β torch.export also failed: {e2}")
|
| 241 |
+
|
| 242 |
+
if __name__ == "__main__":
|
| 243 |
+
main()
|
audio-caption/export_encoder_preprocess_onnx.py
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# export_encoder_proprocess_onnx.py
|
| 2 |
+
import torch
|
| 3 |
+
import torchaudio
|
| 4 |
+
from transformers import AutoModel
|
| 5 |
+
import argparse
|
| 6 |
+
import os
|
| 7 |
+
import onnxruntime_extensions # Ensure extensions are available if needed
|
| 8 |
+
from dotenv import load_dotenv
|
| 9 |
+
|
| 10 |
+
load_dotenv()
|
| 11 |
+
|
| 12 |
+
parser = argparse.ArgumentParser()
|
| 13 |
+
parser.add_argument("--model_id", default="wsntxxn/effb2-trm-audiocaps-captioning")
|
| 14 |
+
parser.add_argument("--out", default="audio-caption/effb2_encoder_preprocess-2.onnx")
|
| 15 |
+
parser.add_argument("--opset", type=int, default=17)
|
| 16 |
+
parser.add_argument("--device", default="cpu")
|
| 17 |
+
args = parser.parse_args()
|
| 18 |
+
|
| 19 |
+
device = torch.device(args.device)
|
| 20 |
+
|
| 21 |
+
print("Loading model (trust_remote_code=True)...")
|
| 22 |
+
model = AutoModel.from_pretrained(args.model_id, trust_remote_code=True).to(device)
|
| 23 |
+
model.eval()
|
| 24 |
+
|
| 25 |
+
# Find the encoder (same logic as original script)
|
| 26 |
+
encoder_wrapper = None
|
| 27 |
+
for candidate in ("audio_encoder", "encoder", "model", "encoder_model"):
|
| 28 |
+
if hasattr(model, candidate):
|
| 29 |
+
encoder_wrapper = getattr(model, candidate)
|
| 30 |
+
break
|
| 31 |
+
if encoder_wrapper is None:
|
| 32 |
+
try:
|
| 33 |
+
encoder_wrapper = model.model.encoder
|
| 34 |
+
except Exception:
|
| 35 |
+
encoder_wrapper = None
|
| 36 |
+
|
| 37 |
+
if encoder_wrapper is None:
|
| 38 |
+
raise RuntimeError("Couldn't find encoder attribute on model.")
|
| 39 |
+
|
| 40 |
+
# Find actual encoder
|
| 41 |
+
actual_encoder = None
|
| 42 |
+
if hasattr(encoder_wrapper, 'model'):
|
| 43 |
+
if hasattr(encoder_wrapper.model, 'encoder'):
|
| 44 |
+
actual_encoder = encoder_wrapper.model.encoder
|
| 45 |
+
elif hasattr(encoder_wrapper.model, 'model') and hasattr(encoder_wrapper.model.model, 'encoder'):
|
| 46 |
+
actual_encoder = encoder_wrapper.model.model.encoder
|
| 47 |
+
|
| 48 |
+
if actual_encoder is None:
|
| 49 |
+
print("Could not find actual encoder, using encoder_wrapper as fallback (might fail if it expects dict)")
|
| 50 |
+
actual_encoder = encoder_wrapper
|
| 51 |
+
|
| 52 |
+
# Custom MelSpectrogram to avoid complex type issues in ONNX export
|
| 53 |
+
class OnnxCompatibleMelSpectrogram(torch.nn.Module):
|
| 54 |
+
def __init__(self, sample_rate=16000, n_fft=512, win_length=512, hop_length=160, n_mels=64):
|
| 55 |
+
super().__init__()
|
| 56 |
+
self.n_fft = n_fft
|
| 57 |
+
self.win_length = win_length
|
| 58 |
+
self.hop_length = hop_length
|
| 59 |
+
|
| 60 |
+
# Create window and mel scale buffers
|
| 61 |
+
window = torch.hann_window(win_length)
|
| 62 |
+
self.register_buffer('window', window)
|
| 63 |
+
|
| 64 |
+
self.mel_scale = torchaudio.transforms.MelScale(
|
| 65 |
+
n_mels=n_mels,
|
| 66 |
+
sample_rate=sample_rate,
|
| 67 |
+
n_stft=n_fft // 2 + 1
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
def forward(self, waveform):
|
| 71 |
+
# Use return_complex=False to get (..., freq, time, 2)
|
| 72 |
+
# This avoids passing complex tensors which some ONNX exporters struggle with
|
| 73 |
+
spec = torch.stft(
|
| 74 |
+
waveform,
|
| 75 |
+
n_fft=self.n_fft,
|
| 76 |
+
hop_length=self.hop_length,
|
| 77 |
+
win_length=self.win_length,
|
| 78 |
+
window=self.window,
|
| 79 |
+
center=True,
|
| 80 |
+
pad_mode='reflect',
|
| 81 |
+
normalized=False,
|
| 82 |
+
onesided=True,
|
| 83 |
+
return_complex=False
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
# Calculate power spectrogram: real^2 + imag^2
|
| 87 |
+
# spec shape: (batch, freq, time, 2)
|
| 88 |
+
power_spec = spec.pow(2).sum(-1) # (batch, freq, time)
|
| 89 |
+
|
| 90 |
+
# Apply Mel Scale
|
| 91 |
+
# MelScale expects (..., freq, time)
|
| 92 |
+
mel_spec = self.mel_scale(power_spec)
|
| 93 |
+
|
| 94 |
+
return mel_spec
|
| 95 |
+
|
| 96 |
+
class PreprocessEncoderWrapper(torch.nn.Module):
|
| 97 |
+
def __init__(self, actual_encoder):
|
| 98 |
+
super().__init__()
|
| 99 |
+
self.actual_encoder = actual_encoder
|
| 100 |
+
|
| 101 |
+
# Extract components
|
| 102 |
+
self.backbone = actual_encoder.backbone if hasattr(actual_encoder, 'backbone') else None
|
| 103 |
+
self.fc = actual_encoder.fc if hasattr(actual_encoder, 'fc') else None
|
| 104 |
+
self.fc_proj = actual_encoder.fc_proj if hasattr(actual_encoder, 'fc_proj') else None
|
| 105 |
+
|
| 106 |
+
if self.backbone is None:
|
| 107 |
+
self.backbone = actual_encoder
|
| 108 |
+
|
| 109 |
+
# Preprocessing settings
|
| 110 |
+
self.mel_transform = OnnxCompatibleMelSpectrogram(
|
| 111 |
+
sample_rate=16000,
|
| 112 |
+
n_fft=512,
|
| 113 |
+
win_length=512,
|
| 114 |
+
hop_length=160,
|
| 115 |
+
n_mels=64
|
| 116 |
+
)
|
| 117 |
+
self.db_transform = torchaudio.transforms.AmplitudeToDB(top_db=120)
|
| 118 |
+
|
| 119 |
+
def forward(self, audio):
|
| 120 |
+
"""
|
| 121 |
+
Args:
|
| 122 |
+
audio: (batch, time) - Raw waveform
|
| 123 |
+
"""
|
| 124 |
+
# 1. Compute Mel Spectrogram
|
| 125 |
+
mel = self.mel_transform(audio)
|
| 126 |
+
|
| 127 |
+
# 2. Amplitude to DB
|
| 128 |
+
mel_db = self.db_transform(mel)
|
| 129 |
+
|
| 130 |
+
# 3. Encoder Forward Pass
|
| 131 |
+
features = self.backbone(mel_db)
|
| 132 |
+
|
| 133 |
+
# Apply pooling/projection
|
| 134 |
+
if self.fc is not None:
|
| 135 |
+
if features.dim() == 4:
|
| 136 |
+
pooled = torch.mean(features, dim=[2, 3])
|
| 137 |
+
elif features.dim() == 3:
|
| 138 |
+
pooled = torch.mean(features, dim=2)
|
| 139 |
+
else:
|
| 140 |
+
pooled = features
|
| 141 |
+
attn_emb = self.fc(pooled).unsqueeze(1)
|
| 142 |
+
elif self.fc_proj is not None:
|
| 143 |
+
if features.dim() == 4:
|
| 144 |
+
pooled = torch.mean(features, dim=[2, 3])
|
| 145 |
+
elif features.dim() == 3:
|
| 146 |
+
pooled = torch.mean(features, dim=2)
|
| 147 |
+
else:
|
| 148 |
+
pooled = features
|
| 149 |
+
attn_emb = self.fc_proj(pooled).unsqueeze(1)
|
| 150 |
+
else:
|
| 151 |
+
if features.dim() == 4:
|
| 152 |
+
attn_emb = torch.mean(features, dim=[2, 3]).unsqueeze(1)
|
| 153 |
+
elif features.dim() == 3:
|
| 154 |
+
attn_emb = features
|
| 155 |
+
else:
|
| 156 |
+
attn_emb = features.unsqueeze(1)
|
| 157 |
+
|
| 158 |
+
return attn_emb
|
| 159 |
+
|
| 160 |
+
print("\nAttempting to export Encoder with Preprocessing...")

# Dummy input: 1 second of mono audio at 16 kHz. The time axis is declared
# dynamic below, so the exported graph accepts other durations as well.
dummy_audio = torch.randn(1, 16000).to(device)

wrapper = PreprocessEncoderWrapper(actual_encoder).to(device)
wrapper.eval()

# Sanity-check the wrapper with a real forward pass before exporting.
with torch.no_grad():
    out = wrapper(dummy_audio)
    print(f"β Wrapper output shape: {out.shape}")

# Export configuration.
# BUG FIX: the original call hard-coded output_names=["attn_emb"] while
# dynamic_axes referenced "encoder_features" -- names that do not match an
# actual input/output are silently ignored by torch.onnx.export, so the
# output's batch/time axes were NOT marked dynamic. Using the variables
# below consistently fixes that (the inference script looks up the output
# name dynamically via get_outputs()[0].name, so callers are unaffected).
export_inputs = (dummy_audio,)
input_names = ["audio"]
output_names = ["encoder_features"]
dynamic_axes = {
    "audio": {0: "batch", 1: "time"},
    "encoder_features": {0: "batch", 1: "time"},
}

print(f"Exporting to {args.out}...")
try:
    torch.onnx.export(
        wrapper,
        export_inputs,
        args.out,
        export_params=True,
        opset_version=args.opset,
        do_constant_folding=True,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes,
        dynamo=False,  # use the legacy TorchScript-based exporter
    )
    print("β Export successful!")
except Exception as e:
    print(f"β Export failed: {e}")
    import traceback
    traceback.print_exc()
|
audio-caption/generate_caption_hybrid.py
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Complete generation pipeline using:
|
| 3 |
+
- ONNX Encoder (with preprocessing): effb2_encoder_preprocess.onnx
|
| 4 |
+
- ExecuTorch Decoder: effb2_decoder_5sec.pte
|
| 5 |
+
|
| 6 |
+
This script demonstrates end-to-end caption generation from 5-second audio.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import onnxruntime as ort
|
| 11 |
+
import torch
|
| 12 |
+
from executorch.extension.pybindings.portable_lib import _load_for_executorch
|
| 13 |
+
from transformers import AutoTokenizer
|
| 14 |
+
import soundfile as sf
|
| 15 |
+
import argparse
|
| 16 |
+
from dotenv import load_dotenv
|
| 17 |
+
|
| 18 |
+
load_dotenv()
|
| 19 |
+
|
| 20 |
+
def load_and_prepare_audio(audio_path, target_duration=5.0, sample_rate=16000):
    """Load an audio file and force it to exactly ``target_duration`` seconds.

    The file is down-mixed to mono, resampled to ``sample_rate`` if its
    native rate differs, then zero-padded or truncated so the returned
    waveform always has ``int(sample_rate * target_duration)`` samples.

    Args:
        audio_path: Path to an audio file readable by soundfile.
        target_duration: Desired clip length in seconds.
        sample_rate: Target sampling rate in Hz.

    Returns:
        A 1-D float32 numpy array of fixed length.
    """
    waveform, file_sr = sf.read(audio_path)

    # Down-mix multi-channel audio to mono by averaging the channels.
    if waveform.ndim > 1:
        waveform = waveform.mean(axis=1)

    # Resample only when the file's rate differs from the target rate.
    if file_sr != sample_rate:
        import librosa  # lazy import: only needed for mismatched rates
        waveform = librosa.resample(waveform, orig_sr=file_sr, target_sr=sample_rate)

    n_target = int(sample_rate * target_duration)
    n_have = len(waveform)

    if n_have < n_target:
        # Too short: right-pad with silence.
        waveform = np.pad(waveform, (0, n_target - n_have), mode='constant')
    elif n_have > n_target:
        # Too long: keep only the first n_target samples.
        waveform = waveform[:n_target]

    return waveform.astype(np.float32)
|
| 44 |
+
|
| 45 |
+
def generate_caption(audio_path, encoder_path, decoder_path, max_length=30):
    """Generate a caption for an audio file.

    Runs the ONNX encoder once on a fixed 5-second clip, then decodes
    greedily with the stateless ExecuTorch decoder, feeding the full token
    history at every step.

    Args:
        audio_path: Path to the input audio file.
        encoder_path: Path to the ONNX encoder (with built-in preprocessing).
        decoder_path: Path to the ExecuTorch (.pte) decoder.
        max_length: Maximum number of tokens to generate.

    Returns:
        The decoded caption string.
    """
    # Load models
    print("Loading models...")
    tokenizer = AutoTokenizer.from_pretrained("wsntxxn/audiocaps-simple-tokenizer", trust_remote_code=True)
    encoder_session = ort.InferenceSession(encoder_path)
    decoder = _load_for_executorch(decoder_path)

    # Load and prepare audio (exactly 5 seconds)
    print(f"Loading audio: {audio_path}")
    audio = load_and_prepare_audio(audio_path, target_duration=5.0)
    audio_batch = audio[np.newaxis, :]  # (1, 80000)
    print(f"Audio shape: {audio_batch.shape} (5.0 seconds)")

    # Run encoder (input/output names are looked up dynamically so the
    # exported graph's naming does not matter here)
    print("\nRunning ONNX encoder...")
    enc_input_name = encoder_session.get_inputs()[0].name
    enc_output_name = encoder_session.get_outputs()[0].name
    attn_emb = encoder_session.run([enc_output_name], {enc_input_name: audio_batch})[0]
    # NOTE(review): the -1 drops the last encoder frame from the attention
    # length -- presumably to match the decoder's export-time assumptions;
    # confirm against the decoder export script.
    attn_emb_len = np.array([attn_emb.shape[1] - 1], dtype=np.int64)

    print(f"Encoder output shape: {attn_emb.shape}")

    # BUG FIX: the original used `tokenizer.bos_token_id if
    # tokenizer.bos_token_id else 1`, which falls back to 1 whenever the
    # real BOS id is 0 (0 is falsy). Compare against None instead so a
    # legitimate id of 0 is honoured. Same fix for EOS below.
    bos_id = tokenizer.bos_token_id if tokenizer.bos_token_id is not None else 1
    eos_id = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 2
    generated = [bos_id]

    # Autoregressive greedy generation with the ExecuTorch decoder
    print(f"\nGenerating caption (max {max_length} tokens)...")
    for step in range(max_length):
        # Stateless decoder: feed the FULL token history each step
        word_ids = np.array([generated], dtype=np.int64)  # (1, current_length)

        logits = decoder.forward((
            torch.from_numpy(word_ids),
            torch.from_numpy(attn_emb).to(torch.float32),
            torch.from_numpy(attn_emb_len)
        ))[0].numpy()  # (1, current_length, vocab_size)

        # Greedy pick from the last position's distribution
        next_token = int(np.argmax(logits[0, -1, :]))
        generated.append(next_token)

        # Stop on EOS token
        if next_token == eos_id:
            break

    # Decode caption
    caption = tokenizer.decode(generated, skip_special_tokens=True)

    print(f"\nβ Generated caption ({len(generated)-1} tokens): {caption}")
    print(f"Token sequence: {generated}")

    return caption
|
| 102 |
+
|
| 103 |
+
def main():
    """CLI entry point: parse arguments and run the captioning pipeline."""
    parser = argparse.ArgumentParser(description="Generate audio caption using ONNX encoder + ExecuTorch decoder")
    parser.add_argument("--audio", default="doorbell.wav", help="Path to audio file")
    parser.add_argument("--encoder", default="audio-caption/effb2_encoder_preprocess.onnx",
                        help="Path to ONNX encoder")
    parser.add_argument("--decoder", default="audio-caption/effb2_decoder_5sec.pte",
                        help="Path to ExecuTorch decoder")
    parser.add_argument("--max-length", type=int, default=30, help="Maximum caption length")
    args = parser.parse_args()

    # Banner around the run so it is easy to spot in the console output.
    banner = "=" * 60
    print(banner)
    print("ONNX Encoder + ExecuTorch Decoder Caption Generation")
    print(banner)

    caption = generate_caption(
        audio_path=args.audio,
        encoder_path=args.encoder,
        decoder_path=args.decoder,
        max_length=args.max_length,
    )

    print("\n" + banner)
    print(f"Final Caption: {caption}")
    print(banner)


if __name__ == "__main__":
    main()
|
categories.json
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"categories": [
|
| 3 |
+
{
|
| 4 |
+
"id": "dog_bark",
|
| 5 |
+
"label": "bark of a dog",
|
| 6 |
+
"description": "dog barking sound, woofing, growling or howling from a canine"
|
| 7 |
+
},
|
| 8 |
+
{
|
| 9 |
+
"id": "doorbell",
|
| 10 |
+
"label": "doorbell ringing",
|
| 11 |
+
"description": "ding, bell or chime sound at a house door entrance"
|
| 12 |
+
},
|
| 13 |
+
{
|
| 14 |
+
"id": "baby_crying",
|
| 15 |
+
"label": "baby crying",
|
| 16 |
+
"description": "infant crying, wailing, sobbing or distressed baby sounds"
|
| 17 |
+
},
|
| 18 |
+
{
|
| 19 |
+
"id": "glass_breaking",
|
| 20 |
+
"label": "glass breaking",
|
| 21 |
+
"description": "sound of glass shattering, breaking or crashing"
|
| 22 |
+
},
|
| 23 |
+
{
|
| 24 |
+
"id": "car_horn",
|
| 25 |
+
"label": "car horn",
|
| 26 |
+
"description": "vehicle horn honking, beeping or car alert sound"
|
| 27 |
+
},
|
| 28 |
+
{
|
| 29 |
+
"id": "alarm_clock",
|
| 30 |
+
"label": "alarm clock",
|
| 31 |
+
"description": "alarm clock ringing, beeping or buzzing wake-up sound"
|
| 32 |
+
},
|
| 33 |
+
{
|
| 34 |
+
"id": "fire_alarm",
|
| 35 |
+
"label": "fire alarm",
|
| 36 |
+
"description": "fire alarm siren, emergency alert or smoke detector beeping"
|
| 37 |
+
},
|
| 38 |
+
{
|
| 39 |
+
"id": "door_closing",
|
| 40 |
+
"label": "window or door closing",
|
| 41 |
+
"description": "sound of door or window shutting, closing or slamming"
|
| 42 |
+
},
|
| 43 |
+
{
|
| 44 |
+
"id": "door_opening",
|
| 45 |
+
"label": "window or door opening",
|
| 46 |
+
"description": "sound of door or window opening, creaking or unlocking"
|
| 47 |
+
},
|
| 48 |
+
{
|
| 49 |
+
"id": "stagger_swipe",
|
| 50 |
+
"label": "staggerer or swipe",
|
| 51 |
+
"description": "staggering footsteps, stumbling or swiping movement sound"
|
| 52 |
+
}
|
| 53 |
+
]
|
| 54 |
+
}
|
pyproject.toml
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[project]
|
| 2 |
+
name = "whisper-audio-captioning-pte"
|
| 3 |
+
version = "0.1.0"
|
| 4 |
+
description = "Export audio captioning and sentence-transformer embedding models to ExecuTorch PTE / ONNX formats"
|
| 5 |
+
requires-python = ">=3.10"
|
| 6 |
+
dependencies = [
|
| 7 |
+
"torch>=2.1.0",
|
| 8 |
+
"transformers>=4.36.0",
|
| 9 |
+
"datasets>=2.14.0",
|
| 10 |
+
"torchaudio>=2.1.0",
|
| 11 |
+
"soundfile>=0.12.1",
|
| 12 |
+
"executorch>=0.3.0",
|
| 13 |
+
"onnxruntime>=1.16.0",
|
| 14 |
+
"librosa>=0.10.0",
|
| 15 |
+
"optimum[exporters]",
|
| 16 |
+
"onnx",
|
| 17 |
+
"efficientnet_pytorch",
|
| 18 |
+
"einops",
|
| 19 |
+
"onnxscript",
|
| 20 |
+
"python-dotenv",
|
| 21 |
+
"onnxruntime-extensions>=0.14.0",
|
| 22 |
+
]
|
| 23 |
+
|
| 24 |
+
[tool.uv]
|
| 25 |
+
package = false
|
sentence-transformers-embbedings/category_embeddings.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
sentence-transformers-embbedings/export_sentence_transformers_executorch.py
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Export Sentence Transformers model to ExecuTorch .pte format.
|
| 4 |
+
This exports 'sentence-transformers/all-MiniLM-L6-v2' compatible with mobile deployment.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from transformers import AutoModel, AutoTokenizer
|
| 10 |
+
from torch.export import export
|
| 11 |
+
from executorch.exir import to_edge, EdgeCompileConfig
|
| 12 |
+
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
|
| 13 |
+
from dotenv import load_dotenv
|
| 14 |
+
|
| 15 |
+
load_dotenv()

print("π Starting Sentence Transformers ExecuTorch Export")

# 1. Load the Hugging Face backbone and its tokenizer, then switch the
# model to eval mode (disables dropout) so the exported graph is
# deterministic.
model_name = "sentence-transformers/all-MiniLM-L6-v2"
print(f"π¦ Loading model: {model_name}")
tokenizer = AutoTokenizer.from_pretrained(model_name)
hf_model = AutoModel.from_pretrained(model_name)
hf_model.eval()
print("β Model loaded")
|
| 26 |
+
|
| 27 |
+
# 2. Create a wrapper that mimics sentence-transformers embedding logic
|
| 28 |
+
class SentenceTransformerWrapper(torch.nn.Module):
    """Produce sentence embeddings from a raw transformer backbone.

    Reproduces the sentence-transformers recipe: masked mean pooling over
    the token embeddings followed by L2 normalization, so the output is
    directly comparable to embeddings produced by the sentence-transformers
    library itself.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids, attention_mask):
        """Compute normalized sentence embeddings.

        Args:
            input_ids: [batch, seq_len] token ids.
            attention_mask: [batch, seq_len] 1 for real tokens, 0 for padding.

        Returns:
            [batch, hidden_dim] L2-normalized sentence embeddings.
        """
        hidden = self.model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state

        # Broadcast the mask over the hidden dimension so padded positions
        # contribute nothing to the pooled sum.
        mask = attention_mask.unsqueeze(-1).expand(hidden.size()).float()

        # Masked mean pooling; clamp the token count so an all-padding row
        # cannot cause a division by zero.
        pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)

        # Unit-length output: dot product then equals cosine similarity.
        return F.normalize(pooled, p=2, dim=1)
|
| 68 |
+
|
| 69 |
+
# 3. Wrap the backbone so it emits pooled, normalized embeddings
print("π§ Wrapping model...")
model_wrapper = SentenceTransformerWrapper(hf_model)

# 4. Build example inputs (fixed 128-token padded sequence; the export
# captures this static shape)
example_text = "This is a test sentence for embedding generation."
inputs = tokenizer(
    example_text,
    max_length=128,
    padding="max_length",
    truncation=True,
    return_tensors="pt",
)
example_args = (inputs["input_ids"], inputs["attention_mask"])
print(f"π Example input shape: {inputs['input_ids'].shape}")

# 5. Smoke-test the wrapper before exporting
print("π§ͺ Testing forward pass...")
with torch.no_grad():
    test_output = model_wrapper(*example_args)
    print(f"β Output shape: {test_output.shape}")
    print(f"β Output norm: {torch.norm(test_output, dim=1).item():.4f} (should be ~1.0)")

# 6. Export to ExecuTorch: capture -> Edge IR -> XNNPACK -> .pte
print("\nπ€ Exporting to ExecuTorch...")
try:
    # Step 1: Capture the computational graph
    print("  1/4 Capturing graph with torch.export...")
    exported_program = export(model_wrapper, example_args, strict=False)
    print("  β Graph captured")

    # Step 2: Lower to Edge IR (IR validity check relaxed for this model)
    print("  2/4 Lowering to Edge IR...")
    edge_program = to_edge(
        exported_program,
        compile_config=EdgeCompileConfig(_check_ir_validity=False),
    )
    print("  β Edge IR created")

    # Step 3: Hand supported subgraphs to the XNNPACK backend
    print("  3/4 Partitioning for XNNPACK (with quantization)...")
    edge_program = edge_program.to_backend(XnnpackPartitioner())
    print("  β XNNPACK partitioning done")

    # Step 4: Serialize to an ExecuTorch program
    print("  4/4 Converting to ExecuTorch program...")
    executorch_program = edge_program.to_executorch()
    print("  β Conversion complete")

    # Persist the .pte artifact and report its size
    output_path = "sentence_transformers_minilm.pte"
    with open(output_path, "wb") as f:
        executorch_program.write_to_file(f)

    import os
    file_size_mb = os.path.getsize(output_path) / (1024 * 1024)

    print(f"\nπ Export successful!")
    print(f"π Saved to: {output_path}")
    print(f"π File size: {file_size_mb:.2f} MB")
    print(f"\nπ‘ Usage: Load this .pte file in your mobile app")
    print(f"   Input: token IDs (int64) and attention mask (int64)")
    print(f"   Output: normalized embeddings (float32, dim=384)")

except Exception as e:
    print(f"\nβ Export failed: {e}")
    import traceback
    traceback.print_exc()
|
sentence-transformers-embbedings/generate_category_embeddings.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Generate category embeddings using the exported sentence-transformers .pte model.
|
| 4 |
+
Reads categories from categories.json and outputs embeddings in the same format
|
| 5 |
+
as embeddings_granite_export/category_embeddings.json
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import json
|
| 9 |
+
import torch
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from transformers import AutoTokenizer
|
| 12 |
+
from executorch.extension.pybindings.portable_lib import _load_for_executorch
|
| 13 |
+
|
| 14 |
+
print("π Generating Category Embeddings with Sentence Transformers")
|
| 15 |
+
|
| 16 |
+
# Configuration
|
| 17 |
+
MODEL_PATH = "sentence-transformers-embbedings/sentence_transformers_minilm.pte"
|
| 18 |
+
CATEGORIES_PATH = "categories.json"
|
| 19 |
+
OUTPUT_PATH = "sentence-transformers-embbedings/category_embeddings.json"
|
| 20 |
+
MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
|
| 21 |
+
|
| 22 |
+
# 1. Load the tokenizer
|
| 23 |
+
print(f"π¦ Loading tokenizer: {MODEL_NAME}")
|
| 24 |
+
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
| 25 |
+
|
| 26 |
+
# 2. Load the .pte model
|
| 27 |
+
print(f"π¦ Loading .pte model: {MODEL_PATH}")
|
| 28 |
+
model = _load_for_executorch(MODEL_PATH)
|
| 29 |
+
print("β Model loaded")
|
| 30 |
+
|
| 31 |
+
# 3. Load categories
|
| 32 |
+
print(f"π Loading categories from: {CATEGORIES_PATH}")
|
| 33 |
+
with open(CATEGORIES_PATH, 'r') as f:
|
| 34 |
+
categories_data = json.load(f)
|
| 35 |
+
|
| 36 |
+
categories = categories_data['categories']
|
| 37 |
+
print(f"β Loaded {len(categories)} categories")
|
| 38 |
+
|
| 39 |
+
# 4. Generate embeddings for each category
|
| 40 |
+
print("\nπ§ Generating embeddings...")
|
| 41 |
+
embeddings_list = []
|
| 42 |
+
updated_categories = []
|
| 43 |
+
|
| 44 |
+
for idx, category in enumerate(categories):
|
| 45 |
+
# Create the text to embed (label + description, matching Granite format)
|
| 46 |
+
text_embedded = f"{category['label']}. {category['description']}"
|
| 47 |
+
|
| 48 |
+
# Tokenize
|
| 49 |
+
inputs = tokenizer(
|
| 50 |
+
text_embedded,
|
| 51 |
+
max_length=128,
|
| 52 |
+
padding="max_length",
|
| 53 |
+
truncation=True,
|
| 54 |
+
return_tensors="pt"
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
# Prepare inputs for ExecuTorch (as lists)
|
| 58 |
+
input_ids = inputs["input_ids"]
|
| 59 |
+
attention_mask = inputs["attention_mask"]
|
| 60 |
+
|
| 61 |
+
# Run inference
|
| 62 |
+
outputs = model.forward((input_ids, attention_mask))
|
| 63 |
+
|
| 64 |
+
# Extract embedding (should be [1, 384])
|
| 65 |
+
embedding_tensor = outputs[0]
|
| 66 |
+
embedding_list = embedding_tensor.squeeze(0).tolist()
|
| 67 |
+
|
| 68 |
+
embeddings_list.append(embedding_list)
|
| 69 |
+
|
| 70 |
+
# Add text_embedded field to category
|
| 71 |
+
category_copy = category.copy()
|
| 72 |
+
category_copy["text_embedded"] = text_embedded
|
| 73 |
+
updated_categories.append(category_copy)
|
| 74 |
+
|
| 75 |
+
print(f" β [{idx+1}/{len(categories)}] {category['id']}: {category['label']}")
|
| 76 |
+
|
| 77 |
+
# 5. Create output JSON in the same format as Granite embeddings
|
| 78 |
+
output_data = {
|
| 79 |
+
"categories": updated_categories,
|
| 80 |
+
"embeddings": embeddings_list,
|
| 81 |
+
"metadata": {
|
| 82 |
+
"model": "sentence-transformers/all-MiniLM-L6-v2",
|
| 83 |
+
"model_file": MODEL_PATH,
|
| 84 |
+
"embedding_dimension": len(embeddings_list[0]),
|
| 85 |
+
"total_categories": len(categories),
|
| 86 |
+
"normalization": "L2",
|
| 87 |
+
"pooling": "mean"
|
| 88 |
+
}
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
# 6. Save to file
|
| 92 |
+
print(f"\nπΎ Saving embeddings to: {OUTPUT_PATH}")
|
| 93 |
+
with open(OUTPUT_PATH, 'w') as f:
|
| 94 |
+
json.dump(output_data, f, indent=2)
|
| 95 |
+
|
| 96 |
+
file_size_kb = Path(OUTPUT_PATH).stat().st_size / 1024
|
| 97 |
+
print(f"β Saved successfully ({file_size_kb:.2f} KB)")
|
| 98 |
+
|
| 99 |
+
print("\nπ Done!")
|
| 100 |
+
print(f"π Generated {len(embeddings_list)} embeddings of dimension {len(embeddings_list[0])}")
|
| 101 |
+
print(f"π Output: {OUTPUT_PATH}")
|
sentence-transformers-embbedings/sentence_transformers_minilm.pte
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:de2fdcf7daf9b592856a5b740108258c589c7b5c26921b51abe197364dd3cabb
|
| 3 |
+
size 90379856
|
uv.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|