stivenDR14 committed on
Commit 5c8d855 · 1 Parent(s): ae932ad

feat: Introduce audio captioning and categorization model with ONNX/ExecuTorch hybrid inference and category embedding generation.
.gitignore ADDED
@@ -0,0 +1,56 @@
+ # System files
+ .DS_Store
+ .DS_Store?
+ ._*
+ .Spotlight-V100
+ .Trashes
+ ehthumbs.db
+ Thumbs.db
+
+ # Environment variables
+ .env
+ .env.local
+ .env.*.local
+
+ # Python
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ .venv/
+ venv/
+ ENV/
+ env/
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+ _temp/
+
+ # Testing and Code Quality
+ .mypy_cache/
+ .pytest_cache/
+ .coverage
+ htmlcov/
+ .tox/
+ .nox/
+
+ # IDEs
+ .idea/
+ .vscode/
+ *.swp
+ *.swo
+
README.md CHANGED
@@ -1,3 +1,212 @@
- ---
- license: apache-2.0
- ---
+ ---
+ license: apache-2.0
+ tags:
+ - audio
+ - audio-classification
+ - audio-captioning
+ - onnx
+ - executorch
+ - mobile
+ - arm
+ language:
+ - en
+ pipeline_tag: audio-classification
+ ---
+
+ # Audio Caption and Categorizer Models
+
+ ## Model Description
+
+ This repository provides **optimized exports** of audio captioning and categorization models for **ARM-based mobile deployment**. The pipeline consists of:
+
+ 1. **Audio Captioning**: Uses [`wsntxxn/effb2-trm-audiocaps-captioning`](https://huggingface.co/wsntxxn/effb2-trm-audiocaps-captioning) (EfficientNet-B2 encoder + Transformer decoder) to generate natural language descriptions of audio events.
+
+ 2. **Audio Categorization**: Uses [`sentence-transformers/all-MiniLM-L6-v2`](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) to match generated captions to predefined sound categories via semantic similarity.
+
+ ### Export Formats
+ - **Encoder**: ONNX format with integrated preprocessing (STFT, MelSpectrogram, AmplitudeToDB)
+ - **Decoder**: ExecuTorch (`.pte`) format with dynamic quantization for reduced model size
+ - **Categorizer**: ExecuTorch (`.pte`) format with quantization
+
+ ### Key Features
+ - 5-second audio input at 16 kHz
+ - Preprocessing baked into the ONNX encoder (no external audio processing needed)
+ - Optimized for mobile inference with quantization
+ - Complete end-to-end pipeline from raw audio to categorized captions
+
+ ## Usage
+
+ ### Quick Start
+
+ Generate a caption for an audio file:
+
+ ```bash
+ # Activate environment
+ source .venv/bin/activate
+
+ # Generate caption
+ python audio-caption/generate_caption_hybrid.py --audio sample_audio.wav
+ ```
+
+ ### Python Example
+
+ ```python
+ import onnxruntime as ort
+ import torch
+ from executorch.extension.pybindings.portable_lib import _load_for_executorch
+ from transformers import AutoTokenizer
+ import numpy as np
+
+ # Load models
+ encoder_session = ort.InferenceSession("audio-caption/effb2_encoder_preprocess.onnx")
+ decoder = _load_for_executorch("audio-caption/effb2_decoder_5sec.pte")
+ tokenizer = AutoTokenizer.from_pretrained("wsntxxn/audiocaps-simple-tokenizer", trust_remote_code=True)
+
+ # Process audio (16 kHz, 5 seconds = 80000 samples)
+ audio = np.random.randn(1, 80000).astype(np.float32)
+
+ # Encode
+ attn_emb = encoder_session.run(["attn_emb"], {"audio": audio})[0]
+
+ # Decode (greedy search)
+ generated = [tokenizer.bos_token_id]
+ for _ in range(30):
+     logits = decoder.forward((
+         torch.tensor([generated]),
+         torch.tensor(attn_emb),
+         torch.tensor([attn_emb.shape[1] - 1])
+     ))[0]
+     next_token = int(torch.argmax(logits[0, -1, :]))
+     generated.append(next_token)
+     if next_token == tokenizer.eos_token_id:
+         break
+
+ caption = tokenizer.decode(generated, skip_special_tokens=True)
+ print(caption)
+ ```
+
+ ## Training Details
+
+ ### Base Models
+
+ This repository does **not** train models; it exports pre-trained models to optimized formats:
+
+ | Component | Base Model | Training Dataset | Parameters |
+ |-----------|------------|------------------|------------|
+ | Audio Encoder | EfficientNet-B2 | AudioCaps | ~7.7M |
+ | Caption Decoder | Transformer (2 layers) | AudioCaps | ~4.3M |
+ | Categorizer | all-MiniLM-L6-v2 | 1B+ sentence pairs | ~22.7M |
+
+ ### Export Configuration
+
+ **Audio Captioning**:
+ - **Preprocessing**: `n_mels=64`, `n_fft=512`, `hop_length=160`, `win_length=512`
+ - **Input**: Raw audio waveform (16 kHz, 5 seconds)
+ - **Encoder**: ONNX opset 17 with dynamic axes
+ - **Decoder**: ExecuTorch with dynamic quantization (int8)
+
+ **Categorizer**:
+ - **Tokenizer**: RoBERTa-based (max length: 128)
+ - **Export**: ExecuTorch with dynamic quantization
+ - **Categories**: 50+ predefined audio event categories
+
+ ### Quantization Impact
+
+ | Model | Original Size | Quantized Size | Quality Impact |
+ |-------|---------------|----------------|----------------|
+ | Decoder | ~17MB | ~15MB | Minimal (<2% caption quality) |
+ | Categorizer | ~90MB | ~23MB | Minimal (<1% accuracy) |
+
+ ## Project Structure
+
+ ```
+ .
+ ├── audio-caption/
+ │   ├── export_encoder_preprocess_onnx.py   # Export ONNX encoder
+ │   ├── export_decoder_executorch.py        # Export ExecuTorch decoder
+ │   ├── generate_caption_hybrid.py          # Inference pipeline
+ │   ├── effb2_encoder_preprocess.onnx       # Exported encoder
+ │   └── effb2_decoder_5sec.pte              # Exported decoder
+ │
+ ├── sentence-transformers-embbedings/
+ │   ├── export_sentence_transformers_executorch.py
+ │   ├── generate_category_embeddings.py
+ │   └── category_embeddings.json
+ │
+ └── categories.json                         # Category definitions
+ ```
+
+ ## Setup
+
+ ### Prerequisites
+
+ ```bash
+ # Install uv package manager
+ pip install uv
+
+ # Create environment
+ uv venv
+ source .venv/bin/activate
+
+ # Install dependencies
+ uv pip install -r pyproject.toml
+ ```
+
+ ### Configuration
+
+ Create a `.env` file:
+
+ ```ini
+ # Hugging Face Token (for gated models)
+ HF_TOKEN=your_token_here
+
+ # Optional: Custom cache directory
+ # HF_HOME=./.cache/huggingface
+ ```
+
+ ### Export Models
+
+ ```bash
+ # Export audio captioning models
+ python audio-caption/export_encoder_preprocess_onnx.py
+ python audio-caption/export_decoder_executorch.py
+
+ # Export categorization model
+ python sentence-transformers-embbedings/export_sentence_transformers_executorch.py
+
+ # Generate category embeddings
+ python sentence-transformers-embbedings/generate_category_embeddings.py
+ ```
+
+ ## License
+
+ Apache License 2.0
+
+ ## Citations
+
+ ### Audio Captioning Model
+
+ ```bibtex
+ @inproceedings{xu2024efficient,
+   title={Efficient Audio Captioning with Encoder-Level Knowledge Distillation},
+   author={Xu, Xuenan and Liu, Haohe and Wu, Mengyue and Wang, Wenwu and Plumbley, Mark D.},
+   booktitle={Interspeech 2024},
+   year={2024},
+   doi={10.48550/arXiv.2407.14329},
+   url={https://arxiv.org/abs/2407.14329}
+ }
+ ```
+
+ ### Sentence Transformer
+
+ ```bibtex
+ @inproceedings{reimers-2019-sentence-bert,
+   title = "Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks",
+   author = "Reimers, Nils and Gurevych, Iryna",
+   booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing",
+   year = "2019",
+   publisher = "Association for Computational Linguistics",
+   url = "https://arxiv.org/abs/1908.10084",
+ }
+ ```
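The categorization stage described above (matching a caption to a category via MiniLM embeddings) is not shown in the README's code example. Assuming `category_embeddings.json` holds one embedding vector per category id (that file layout is an assumption, as is the helper below), the matching step reduces to a cosine-similarity argmax, sketched here with NumPy:

```python
import numpy as np

def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    # Cosine similarity between two 1-D embedding vectors
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

def categorize(caption_emb: np.ndarray, category_embs: dict) -> tuple:
    # Pick the category whose embedding is most similar to the caption embedding
    scores = {cid: cosine_similarity(caption_emb, emb) for cid, emb in category_embs.items()}
    best = max(scores, key=scores.get)
    return best, scores[best]

# Toy example with 3-dim "embeddings" (real MiniLM vectors are 384-dim)
cats = {
    "dog_bark": np.array([1.0, 0.0, 0.0]),
    "doorbell": np.array([0.0, 1.0, 0.0]),
}
caption = np.array([0.9, 0.1, 0.0])
print(categorize(caption, cats))  # dog_bark scores highest
```

In the actual pipeline the caption embedding would come from the exported MiniLM `.pte` model rather than being hand-written.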
audio-caption/effb2_decoder_5sec.pte ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:319fbb6363ba11fa13b2e0a2bc7b97cdc8526208cfa79a1cc7a65b6f683a91d0
+ size 15144068
audio-caption/effb2_encoder_preprocess-2.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ae814c75c799de5717308ad63672f282619d021f2f394c84aaf264044bb298bf
+ size 30925938
audio-caption/export_decoder_executorch.py ADDED
@@ -0,0 +1,243 @@
+ """
+ Export the decoder to ExecuTorch .pte format as an alternative to ONNX.
+ This may handle dynamic sequence lengths better.
+ """
+
+ import math
+ import torch
+ import argparse
+ from transformers import AutoModel, AutoTokenizer
+ from dotenv import load_dotenv
+
+ load_dotenv()
+
+ def main():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--model", default="wsntxxn/effb2-trm-audiocaps-captioning")
+     parser.add_argument("--out", default="effb2_decoder_step.pte")
+     args = parser.parse_args()
+
+     print(f"Loading model: {args.model}")
+     model = AutoModel.from_pretrained(args.model, trust_remote_code=True)
+     model.eval()
+
+     # Get the decoder - navigate through the model structure.
+     # Based on inspection: model.model.model.decoder
+     if hasattr(model, "model") and hasattr(model.model, "model") and hasattr(model.model.model, "decoder"):
+         decoder = model.model.model.decoder
+         encoder = model.model.model.encoder
+         print("Found decoder at model.model.model.decoder")
+     elif hasattr(model, "model") and hasattr(model.model, "decoder"):
+         decoder = model.model.decoder
+         encoder = model.model.encoder
+         print("Found decoder at model.model.decoder")
+     else:
+         # Try to find it by iterating over named modules
+         for name, module in model.named_modules():
+             if "decoder" in name.lower() and "TransformerDecoder" in module.__class__.__name__:
+                 decoder = module
+                 print(f"Found decoder at {name}")
+                 break
+         else:
+             raise RuntimeError("Could not find decoder in model")
+
+     print(f"Decoder: {decoder.__class__.__name__}")
+
+     # Wrap the decoder, mirroring the ONNX version
+     class DecoderStepWrapper(torch.nn.Module):
+         def __init__(self, decoder, vocab_size):
+             super().__init__()
+             self.decoder = decoder
+             self.vocab_size = vocab_size
+
+         def forward(self, word_ids, attn_emb, attn_emb_len):
+             """
+             Args:
+                 word_ids: (batch, seq_len)
+                 attn_emb: (batch, time, dim)
+                 attn_emb_len: (batch,)
+             Returns:
+                 logits: (batch, seq_len, vocab_size)
+             """
+             # Replicate the custom decoder's forward logic
+             p_attn_emb = self.decoder.attn_proj(attn_emb)
+             p_attn_emb = p_attn_emb.transpose(0, 1)  # [time, batch, dim]
+
+             embed = self.decoder.word_embedding(word_ids)
+             emb_dim = getattr(self.decoder, "emb_dim", 256)
+             embed = self.decoder.in_dropout(embed) * math.sqrt(emb_dim)
+             embed = embed.transpose(0, 1)  # [seq, batch, dim]
+             embed = self.decoder.pos_encoder(embed)
+
+             # Masks
+             # CRITICAL: build the causal mask without NaN.
+             # Don't use ones * inf, because 0 * inf = NaN!
+             seq_len = embed.size(0)
+
+             # Causal mask: 0 on and below the diagonal, -inf above it.
+             # Start with zeros, then masked_fill the upper triangle.
+             tgt_mask = torch.zeros(seq_len, seq_len, device=embed.device, dtype=torch.float32)
+             if seq_len > 1:
+                 tgt_mask = tgt_mask.masked_fill(
+                     torch.triu(torch.ones(seq_len, seq_len, device=embed.device), diagonal=1).bool(),
+                     float('-inf')
+                 )
+
+             # memory_key_padding_mask
+             batch_size = attn_emb.shape[0]
+             max_len = attn_emb.shape[1]
+
+             # Range [0, 1, ..., max_len - 1]
+             arange = torch.arange(max_len, device=attn_emb.device).unsqueeze(0).expand(batch_size, -1)
+             # Mask is True where arange >= length
+             memory_key_padding_mask = arange >= attn_emb_len.unsqueeze(1)
+
+             # tgt_key_padding_mask (cap_padding_mask)
+             # For generation, assume no padding in word_ids (all tokens valid)
+             tgt_key_padding_mask = torch.zeros(word_ids.shape[0], word_ids.shape[1], dtype=torch.bool, device=word_ids.device)
+
+             # Inner decoder call.
+             # Pass BOTH the mask AND tgt_is_causal=True.
+             # Do NOT call generate_square_subsequent_mask, as it might have detection logic.
+             output = self.decoder.model(
+                 embed,
+                 p_attn_emb,
+                 tgt_mask=tgt_mask,  # static causal mask
+                 tgt_is_causal=True,  # hint for optimization
+                 tgt_key_padding_mask=tgt_key_padding_mask,
+                 memory_key_padding_mask=memory_key_padding_mask
+             )
+
+             output = output.transpose(0, 1)  # [batch, seq, dim]
+             logits = self.decoder.classifier(output)
+
+             return logits
+
+     # Get vocab size
+     tokenizer = AutoTokenizer.from_pretrained("wsntxxn/audiocaps-simple-tokenizer", trust_remote_code=True)
+     vocab_size = len(tokenizer)
+
+     # Create wrapper
+     wrapper = DecoderStepWrapper(decoder, vocab_size)
+     wrapper.eval()
+
+     # Test with dummy input
+     device = torch.device("cpu")
+     wrapper = wrapper.to(device)
+
+     # Get encoder output for attn_emb.
+     # Use the existing ONNX encoder to avoid HF encoder complications.
+     print("\nLoading ONNX encoder to get attn_emb...")
+     import onnxruntime as ort
+     import numpy as np
+
+     encoder_onnx_path = "audio-caption/effb2_encoder_preprocess.onnx"
+     enc_sess = ort.InferenceSession(encoder_onnx_path)
+
+     # Create exactly 5 seconds of audio (production use case)
+     sample_rate = 16000
+     dummy_audio_np = np.random.randn(1, sample_rate * 5).astype(np.float32)
+     enc_in_name = enc_sess.get_inputs()[0].name
+     enc_out_name = enc_sess.get_outputs()[0].name
+
+     attn_emb_np = enc_sess.run([enc_out_name], {enc_in_name: dummy_audio_np})[0]
+     attn_emb = torch.from_numpy(attn_emb_np)
+     attn_emb_len = torch.tensor([attn_emb.shape[1] - 1], dtype=torch.int64)
+
+     print(f"attn_emb shape for 5-sec audio: {attn_emb.shape}")
+
+     # Sanity-check the wrapper with variable sequence lengths:
+     # start with seq_len=1, then test with seq_len=5
+     for seq_len in [1, 5]:
+         print(f"\n--- Testing with seq_len={seq_len} ---")
+         dummy_input_ids = torch.randint(0, vocab_size, (1, seq_len), dtype=torch.long)
+
+         with torch.no_grad():
+             test_out = wrapper(dummy_input_ids, attn_emb, attn_emb_len)
+             print(f"✅ Forward pass successful! Output shape: {test_out.shape}")
+
+     # Now try to export with dynamic shapes using torch.export
+     print("\n--- Attempting ExecuTorch Export ---")
+
+     try:
+         from executorch.exir import to_edge
+         from torch.export import export, Dim
+
+         # Define dynamic dimensions following PyTorch's suggestions:
+         # batch is always 1 for mobile inference (PyTorch detected this);
+         # seq can vary from 1 to max_seq_len
+         seq = Dim("seq", max=100)
+
+         dynamic_shapes = {
+             "word_ids": {1: seq},  # only the seq dim is dynamic
+             "attn_emb": {},  # no dynamic dims (batch=1, time is fixed per audio)
+             "attn_emb_len": {},  # scalar-like
+         }
+
+         # Export with a mid-range example (seq_len=3) to show it is variable
+         example_inputs = (
+             torch.randint(0, vocab_size, (1, 3), dtype=torch.long),
+             attn_emb,
+             attn_emb_len
+         )
+
+         print("Exporting with torch.export (seq_len=3 example)...")
+         exported_program = export(
+             wrapper,
+             example_inputs,
+             dynamic_shapes=dynamic_shapes
+         )
+
+         print("✅ torch.export successful!")
+         print("Converting to ExecuTorch edge dialect...")
+
+         edge_program = to_edge(exported_program)
+         print("✅ Edge conversion successful!")
+
+         # Save as .pte
+         with open(args.out, 'wb') as f:
+             edge_program.to_executorch().write_to_file(f)
+         print(f"✅ ExecuTorch export done: {args.out}")
+
+         print("\n📝 This .pte model supports dynamic sequence lengths!")
+         print("   You can pass (batch, 1), (batch, 2), ..., (batch, 30) at inference")
+
+     except ImportError:
+         print("❌ ExecuTorch not installed. Install with:")
+         print("   pip install executorch")
+     except Exception as e:
+         print(f"❌ ExecuTorch export failed: {e}")
+         import traceback
+         traceback.print_exc()
+         print("\nFalling back to plain torch.export (no ExecuTorch)")
+
+         # Try plain torch.export to see whether that works
+         try:
+             from torch.export import export, Dim
+
+             batch = Dim("batch", min=1, max=4)
+             seq = Dim("seq", min=1, max=30)
+             time = Dim("time", min=1, max=100)
+
+             dynamic_shapes = {
+                 "word_ids": {0: batch, 1: seq},
+                 "attn_emb": {0: batch, 1: time},
+                 "attn_emb_len": {0: batch},
+             }
+
+             example_inputs = (
+                 torch.randint(0, vocab_size, (1, 1), dtype=torch.long),
+                 attn_emb,
+                 attn_emb_len
+             )
+
+             exported_program = export(wrapper, example_inputs, dynamic_shapes=dynamic_shapes)
+             print("✅ torch.export successful (without ExecuTorch conversion)")
+             print("   Dynamic shapes are supported in the exported graph")
+
+         except Exception as e2:
+             print(f"❌ torch.export also failed: {e2}")
+
+ if __name__ == "__main__":
+     main()
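The "0 * inf = NaN" comment in the wrapper above is worth seeing in isolation. The same arithmetic holds in NumPy, so this framework-agnostic sketch (not the author's code) shows why the exporter builds the causal mask with a masked fill instead of multiplying a 0/1 mask by infinity:

```python
import numpy as np

n = 4
# Naive approach: scale a 0/1 upper-triangular mask by -inf.
# 0 * inf is undefined, so every allowed position becomes NaN instead of 0.
naive = np.triu(np.ones((n, n)), k=1) * float('-inf')
print(np.isnan(naive).any())  # True

# Safe approach (what the exporter does): start from zeros and
# overwrite only the forbidden (future) positions with -inf.
safe = np.zeros((n, n))
safe[np.triu_indices(n, k=1)] = float('-inf')
print(np.isnan(safe).any())    # False
print(safe[0, 1], safe[1, 0])  # -inf 0.0
```

A NaN anywhere in the attention mask silently propagates through softmax and corrupts every logit, which is much harder to debug after export than before it.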
audio-caption/export_encoder_preprocess_onnx.py ADDED
@@ -0,0 +1,201 @@
+ # export_encoder_preprocess_onnx.py
+ import torch
+ import torchaudio
+ from transformers import AutoModel
+ import argparse
+ import onnxruntime_extensions  # Ensure extensions are available if needed
+ from dotenv import load_dotenv
+
+ load_dotenv()
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model_id", default="wsntxxn/effb2-trm-audiocaps-captioning")
+ parser.add_argument("--out", default="audio-caption/effb2_encoder_preprocess-2.onnx")
+ parser.add_argument("--opset", type=int, default=17)
+ parser.add_argument("--device", default="cpu")
+ args = parser.parse_args()
+
+ device = torch.device(args.device)
+
+ print("Loading model (trust_remote_code=True)...")
+ model = AutoModel.from_pretrained(args.model_id, trust_remote_code=True).to(device)
+ model.eval()
+
+ # Find the encoder (same logic as the original script)
+ encoder_wrapper = None
+ for candidate in ("audio_encoder", "encoder", "model", "encoder_model"):
+     if hasattr(model, candidate):
+         encoder_wrapper = getattr(model, candidate)
+         break
+ if encoder_wrapper is None:
+     try:
+         encoder_wrapper = model.model.encoder
+     except Exception:
+         encoder_wrapper = None
+
+ if encoder_wrapper is None:
+     raise RuntimeError("Couldn't find encoder attribute on model.")
+
+ # Find the actual encoder
+ actual_encoder = None
+ if hasattr(encoder_wrapper, 'model'):
+     if hasattr(encoder_wrapper.model, 'encoder'):
+         actual_encoder = encoder_wrapper.model.encoder
+     elif hasattr(encoder_wrapper.model, 'model') and hasattr(encoder_wrapper.model.model, 'encoder'):
+         actual_encoder = encoder_wrapper.model.model.encoder
+
+ if actual_encoder is None:
+     print("Could not find actual encoder, using encoder_wrapper as fallback (might fail if it expects a dict)")
+     actual_encoder = encoder_wrapper
+
+ # Custom MelSpectrogram to avoid complex-type issues in ONNX export
+ class OnnxCompatibleMelSpectrogram(torch.nn.Module):
+     def __init__(self, sample_rate=16000, n_fft=512, win_length=512, hop_length=160, n_mels=64):
+         super().__init__()
+         self.n_fft = n_fft
+         self.win_length = win_length
+         self.hop_length = hop_length
+
+         # Create window and mel-scale buffers
+         window = torch.hann_window(win_length)
+         self.register_buffer('window', window)
+
+         self.mel_scale = torchaudio.transforms.MelScale(
+             n_mels=n_mels,
+             sample_rate=sample_rate,
+             n_stft=n_fft // 2 + 1
+         )
+
+     def forward(self, waveform):
+         # Use return_complex=False to get (..., freq, time, 2);
+         # this avoids complex tensors, which some ONNX exporters struggle with
+         spec = torch.stft(
+             waveform,
+             n_fft=self.n_fft,
+             hop_length=self.hop_length,
+             win_length=self.win_length,
+             window=self.window,
+             center=True,
+             pad_mode='reflect',
+             normalized=False,
+             onesided=True,
+             return_complex=False
+         )
+
+         # Power spectrogram: real^2 + imag^2
+         # spec shape: (batch, freq, time, 2)
+         power_spec = spec.pow(2).sum(-1)  # (batch, freq, time)
+
+         # Apply mel scale; MelScale expects (..., freq, time)
+         mel_spec = self.mel_scale(power_spec)
+
+         return mel_spec
+
+ class PreprocessEncoderWrapper(torch.nn.Module):
+     def __init__(self, actual_encoder):
+         super().__init__()
+         self.actual_encoder = actual_encoder
+
+         # Extract components
+         self.backbone = actual_encoder.backbone if hasattr(actual_encoder, 'backbone') else None
+         self.fc = actual_encoder.fc if hasattr(actual_encoder, 'fc') else None
+         self.fc_proj = actual_encoder.fc_proj if hasattr(actual_encoder, 'fc_proj') else None
+
+         if self.backbone is None:
+             self.backbone = actual_encoder
+
+         # Preprocessing settings
+         self.mel_transform = OnnxCompatibleMelSpectrogram(
+             sample_rate=16000,
+             n_fft=512,
+             win_length=512,
+             hop_length=160,
+             n_mels=64
+         )
+         self.db_transform = torchaudio.transforms.AmplitudeToDB(top_db=120)
+
+     def forward(self, audio):
+         """
+         Args:
+             audio: (batch, time) - raw waveform
+         """
+         # 1. Compute mel spectrogram
+         mel = self.mel_transform(audio)
+
+         # 2. Amplitude to dB
+         mel_db = self.db_transform(mel)
+
+         # 3. Encoder forward pass
+         features = self.backbone(mel_db)
+
+         # Apply pooling/projection
+         if self.fc is not None:
+             if features.dim() == 4:
+                 pooled = torch.mean(features, dim=[2, 3])
+             elif features.dim() == 3:
+                 pooled = torch.mean(features, dim=2)
+             else:
+                 pooled = features
+             attn_emb = self.fc(pooled).unsqueeze(1)
+         elif self.fc_proj is not None:
+             if features.dim() == 4:
+                 pooled = torch.mean(features, dim=[2, 3])
+             elif features.dim() == 3:
+                 pooled = torch.mean(features, dim=2)
+             else:
+                 pooled = features
+             attn_emb = self.fc_proj(pooled).unsqueeze(1)
+         else:
+             if features.dim() == 4:
+                 attn_emb = torch.mean(features, dim=[2, 3]).unsqueeze(1)
+             elif features.dim() == 3:
+                 attn_emb = features
+             else:
+                 attn_emb = features.unsqueeze(1)
+
+         return attn_emb
+
+ print("\nAttempting to export encoder with preprocessing...")
+
+ # Dummy audio input: 1 second at 16 kHz
+ dummy_audio = torch.randn(1, 16000).to(device)
+
+ wrapper = PreprocessEncoderWrapper(actual_encoder).to(device)
+ wrapper.eval()
+
+ # Test forward pass
+ with torch.no_grad():
+     out = wrapper(dummy_audio)
+     print(f"✓ Wrapper output shape: {out.shape}")
+
+ # Export. Note: the dynamic_axes keys must match the declared
+ # input/output names, so the output is named "attn_emb" in both places.
+ export_inputs = (dummy_audio,)
+ input_names = ["audio"]
+ output_names = ["attn_emb"]
+ dynamic_axes = {
+     "audio": {0: "batch", 1: "time"},
+     "attn_emb": {0: "batch", 1: "time"}
+ }
+
+ print(f"Exporting to {args.out}...")
+ try:
+     torch.onnx.export(
+         wrapper,
+         export_inputs,
+         args.out,
+         export_params=True,
+         opset_version=args.opset,
+         do_constant_folding=True,
+         input_names=input_names,
+         output_names=output_names,
+         dynamic_axes=dynamic_axes,
+         dynamo=False,
+     )
+     print("✅ Export successful!")
+ except Exception as e:
+     print(f"❌ Export failed: {e}")
+     import traceback
+     traceback.print_exc()
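As a sanity check on the preprocessing parameters baked into the encoder above: with `center=True`, `torch.stft` produces `1 + floor(T / hop_length)` frames for `T` input samples. A plain-Python sketch using the export's values (`hop_length=160`, 16 kHz):

```python
def stft_frames(num_samples: int, hop_length: int) -> int:
    # Frame count for a centered STFT (center=True pads n_fft // 2 on each side)
    return 1 + num_samples // hop_length

sample_rate = 16000
hop_length = 160

print(stft_frames(1 * sample_rate, hop_length))  # 101 frames for 1 s of audio
print(stft_frames(5 * sample_rate, hop_length))  # 501 frames for 5 s of audio
```

This is why the time axis of the encoder's mel input depends only on the clip duration, which is what makes the fixed 5-second contract of the decoder export workable.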
audio-caption/generate_caption_hybrid.py ADDED
@@ -0,0 +1,130 @@
+ """
+ Complete generation pipeline using:
+ - ONNX Encoder (with preprocessing): effb2_encoder_preprocess.onnx
+ - ExecuTorch Decoder: effb2_decoder_5sec.pte
+
+ This script demonstrates end-to-end caption generation from 5-second audio.
+ """
+
+ import numpy as np
+ import onnxruntime as ort
+ import torch
+ from executorch.extension.pybindings.portable_lib import _load_for_executorch
+ from transformers import AutoTokenizer
+ import soundfile as sf
+ import argparse
+ from dotenv import load_dotenv
+
+ load_dotenv()
+
+ def load_and_prepare_audio(audio_path, target_duration=5.0, sample_rate=16000):
+     """Load audio and ensure it is exactly target_duration seconds."""
+     audio, sr = sf.read(audio_path)
+
+     # Convert to mono if stereo
+     if audio.ndim > 1:
+         audio = np.mean(audio, axis=1)
+
+     # Resample if needed
+     if sr != sample_rate:
+         import librosa
+         audio = librosa.resample(audio, orig_sr=sr, target_sr=sample_rate)
+
+     target_length = int(sample_rate * target_duration)
+
+     # Pad or trim to exactly target_duration
+     if len(audio) < target_length:
+         # Pad with zeros
+         audio = np.pad(audio, (0, target_length - len(audio)), mode='constant')
+     elif len(audio) > target_length:
+         # Trim
+         audio = audio[:target_length]
+
+     return audio.astype(np.float32)
+
+ def generate_caption(audio_path, encoder_path, decoder_path, max_length=30):
+     """Generate a caption from an audio file."""
+
+     # Load models
+     print("Loading models...")
+     tokenizer = AutoTokenizer.from_pretrained("wsntxxn/audiocaps-simple-tokenizer", trust_remote_code=True)
+     encoder_session = ort.InferenceSession(encoder_path)
+     decoder = _load_for_executorch(decoder_path)
+
+     # Load and prepare audio (exactly 5 seconds)
+     print(f"Loading audio: {audio_path}")
+     audio = load_and_prepare_audio(audio_path, target_duration=5.0)
+     audio_batch = audio[np.newaxis, :]  # (1, 80000)
+     print(f"Audio shape: {audio_batch.shape} (5.0 seconds)")
+
+     # Run encoder
+     print("\nRunning ONNX encoder...")
+     enc_input_name = encoder_session.get_inputs()[0].name
+     enc_output_name = encoder_session.get_outputs()[0].name
+     attn_emb = encoder_session.run([enc_output_name], {enc_input_name: audio_batch})[0]
+     attn_emb_len = np.array([attn_emb.shape[1] - 1], dtype=np.int64)
+
+     print(f"Encoder output shape: {attn_emb.shape}")
+
+     # Initialize generation
+     generated = [tokenizer.bos_token_id if tokenizer.bos_token_id else 1]
+
+     # Autoregressive generation with the ExecuTorch decoder
+     print(f"\nGenerating caption (max {max_length} tokens)...")
+     for step in range(max_length):
+         # Prepare inputs - full history (stateless decoder)
+         word_ids = np.array([generated], dtype=np.int64)  # (1, current_length)
+
+         # Run ExecuTorch decoder
+         logits = decoder.forward((
+             torch.from_numpy(word_ids),
+             torch.from_numpy(attn_emb).to(torch.float32),
+             torch.from_numpy(attn_emb_len)
+         ))[0].numpy()  # (1, current_length, vocab_size)
+
+         # Get the next token from the last position
+         next_token_logits = logits[0, -1, :]
+         next_token = int(np.argmax(next_token_logits))
+
+         generated.append(next_token)
+
+         # Stop if EOS token
+         if next_token == (tokenizer.eos_token_id if tokenizer.eos_token_id else 2):
+             break
+
+     # Decode caption
+     caption = tokenizer.decode(generated, skip_special_tokens=True)
+
+     print(f"\n✅ Generated caption ({len(generated)-1} tokens): {caption}")
+     print(f"Token sequence: {generated}")
+
+     return caption
+
+ def main():
+     parser = argparse.ArgumentParser(description="Generate audio caption using ONNX encoder + ExecuTorch decoder")
+     parser.add_argument("--audio", default="doorbell.wav", help="Path to audio file")
+     parser.add_argument("--encoder", default="audio-caption/effb2_encoder_preprocess.onnx",
+                         help="Path to ONNX encoder")
+     parser.add_argument("--decoder", default="audio-caption/effb2_decoder_5sec.pte",
+                         help="Path to ExecuTorch decoder")
+     parser.add_argument("--max-length", type=int, default=30, help="Maximum caption length")
+
+     args = parser.parse_args()
+
+     print("=" * 60)
+     print("ONNX Encoder + ExecuTorch Decoder Caption Generation")
+     print("=" * 60)
+
+     caption = generate_caption(
+         audio_path=args.audio,
+         encoder_path=args.encoder,
+         decoder_path=args.decoder,
+         max_length=args.max_length
+     )
+
+     print("\n" + "=" * 60)
+     print(f"Final Caption: {caption}")
+     print("=" * 60)
+
+ if __name__ == "__main__":
+     main()
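The pad-or-trim step in `load_and_prepare_audio` is the part of this script that enforces the fixed 5-second contract, and it can be exercised standalone. A minimal NumPy sketch of the same normalization:

```python
import numpy as np

def pad_or_trim(audio: np.ndarray, target_length: int) -> np.ndarray:
    # Zero-pad short clips and truncate long ones,
    # so the model always sees exactly target_length samples
    if len(audio) < target_length:
        return np.pad(audio, (0, target_length - len(audio)), mode='constant')
    return audio[:target_length]

target = 16000 * 5  # 5 seconds at 16 kHz = 80000 samples

short = np.ones(12000, dtype=np.float32)   # 0.75 s clip -> padded
long_ = np.ones(100000, dtype=np.float32)  # 6.25 s clip -> truncated

print(pad_or_trim(short, target).shape)  # (80000,)
print(pad_or_trim(long_, target).shape)  # (80000,)
```

Note the trade-off this implies: anything past the first 5 seconds of a clip is silently discarded, so longer recordings would need to be windowed upstream.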
categories.json ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ {
+   "categories": [
+     {
+       "id": "dog_bark",
+       "label": "bark of a dog",
+       "description": "dog barking sound, woofing, growling or howling from a canine"
+     },
+     {
+       "id": "doorbell",
+       "label": "doorbell ringing",
+       "description": "ding, bell or chime sound at a house door entrance"
+     },
+     {
+       "id": "baby_crying",
+       "label": "baby crying",
+       "description": "infant crying, wailing, sobbing or distressed baby sounds"
+     },
+     {
+       "id": "glass_breaking",
+       "label": "glass breaking",
+       "description": "sound of glass shattering, breaking or crashing"
+     },
+     {
+       "id": "car_horn",
+       "label": "car horn",
+       "description": "vehicle horn honking, beeping or car alert sound"
+     },
+     {
+       "id": "alarm_clock",
+       "label": "alarm clock",
+       "description": "alarm clock ringing, beeping or buzzing wake-up sound"
+     },
+     {
+       "id": "fire_alarm",
+       "label": "fire alarm",
+       "description": "fire alarm siren, emergency alert or smoke detector beeping"
+     },
+     {
+       "id": "door_closing",
+       "label": "window or door closing",
+       "description": "sound of door or window shutting, closing or slamming"
+     },
+     {
+       "id": "door_opening",
+       "label": "window or door opening",
+       "description": "sound of door or window opening, creaking or unlocking"
+     },
+     {
+       "id": "stagger_swipe",
+       "label": "staggerer or swipe",
+       "description": "staggering footsteps, stumbling or swiping movement sound"
+     }
+   ]
+ }
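The schema above is consumed by the embedding generator further down. A minimal loader that validates the layout (not part of the commit; `load_categories` is an illustrative helper name) might look like:

```python
import json

def load_categories(path="categories.json"):
    """Load categories.json and sanity-check the schema shown above."""
    with open(path) as f:
        data = json.load(f)
    categories = data["categories"]
    # Every entry must carry the three keys the pipeline relies on,
    # and ids must be unique so embeddings map back unambiguously.
    ids = [c["id"] for c in categories]
    assert len(ids) == len(set(ids)), "category ids must be unique"
    for c in categories:
        assert {"id", "label", "description"} <= c.keys(), f"incomplete entry: {c}"
    return categories
```

Checks like these catch a malformed category file before embeddings are generated and baked into the mobile app.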
pyproject.toml ADDED
@@ -0,0 +1,25 @@
+ [project]
+ name = "whisper-audio-captioning-pte"
+ version = "0.1.0"
+ description = "Export Whisper audio captioning model to ExecuTorch PTE format"
+ requires-python = ">=3.10"
+ dependencies = [
+     "torch>=2.1.0",
+     "transformers>=4.36.0",
+     "datasets>=2.14.0",
+     "torchaudio>=2.1.0",
+     "soundfile>=0.12.1",
+     "executorch>=0.3.0",
+     "onnxruntime>=1.16.0",
+     "librosa>=0.10.0",
+     "optimum[exporters]",
+     "onnx",
+     "efficientnet_pytorch",
+     "einops",
+     "onnxscript",
+     "python-dotenv",
+     "onnxruntime-extensions>=0.14.0",
+ ]
+ 
+ [tool.uv]
+ package = false
sentence-transformers-embbedings/category_embeddings.json ADDED
The diff for this file is too large to render. See raw diff
 
sentence-transformers-embbedings/export_sentence_transformers_executorch.py ADDED
@@ -0,0 +1,138 @@
+ #!/usr/bin/env python3
+ """
+ Export Sentence Transformers model to ExecuTorch .pte format.
+ This exports 'sentence-transformers/all-MiniLM-L6-v2' for mobile deployment.
+ """
+ 
+ import torch
+ import torch.nn.functional as F
+ from transformers import AutoModel, AutoTokenizer
+ from torch.export import export
+ from executorch.exir import to_edge, EdgeCompileConfig
+ from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
+ from dotenv import load_dotenv
+ 
+ load_dotenv()
+ 
+ print("🚀 Starting Sentence Transformers ExecuTorch Export")
+ 
+ # 1. Load the model
+ model_name = "sentence-transformers/all-MiniLM-L6-v2"
+ print(f"📦 Loading model: {model_name}")
+ hf_model = AutoModel.from_pretrained(model_name)
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ hf_model.eval()
+ print("✓ Model loaded")
+ 
+ # 2. Create a wrapper that mimics sentence-transformers embedding logic
+ class SentenceTransformerWrapper(torch.nn.Module):
+     """
+     Wraps the transformer model to produce sentence embeddings.
+     Performs mean pooling + L2 normalization, matching sentence-transformers behavior.
+     """
+     def __init__(self, model):
+         super().__init__()
+         self.model = model
+ 
+     def forward(self, input_ids, attention_mask):
+         """
+         Args:
+             input_ids: [batch, seq_len]
+             attention_mask: [batch, seq_len]
+ 
+         Returns:
+             embeddings: [batch, hidden_dim] - normalized sentence embeddings
+         """
+         # Forward through transformer
+         outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
+ 
+         # Mean pooling
+         token_embeddings = outputs.last_hidden_state  # [batch, seq_len, hidden]
+ 
+         # Expand attention mask: [batch, seq_len] -> [batch, seq_len, hidden]
+         input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+ 
+         # Sum embeddings where mask is 1
+         sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, dim=1)
+ 
+         # Sum mask values (clamp to avoid division by zero)
+         sum_mask = torch.clamp(input_mask_expanded.sum(dim=1), min=1e-9)
+ 
+         # Compute mean
+         embeddings = sum_embeddings / sum_mask
+ 
+         # L2 normalization
+         embeddings = F.normalize(embeddings, p=2, dim=1)
+ 
+         return embeddings
+ 
+ # 3. Wrap the model
+ print("🔧 Wrapping model...")
+ model_wrapper = SentenceTransformerWrapper(hf_model)
+ 
+ # 4. Create example inputs
+ example_text = "This is a test sentence for embedding generation."
+ inputs = tokenizer(
+     example_text,
+     max_length=128,
+     padding="max_length",
+     truncation=True,
+     return_tensors="pt"
+ )
+ 
+ example_args = (inputs["input_ids"], inputs["attention_mask"])
+ print(f"📋 Example input shape: {inputs['input_ids'].shape}")
+ 
+ # 5. Test forward pass
+ print("🧪 Testing forward pass...")
+ with torch.no_grad():
+     test_output = model_wrapper(*example_args)
+ print(f"✓ Output shape: {test_output.shape}")
+ print(f"✓ Output norm: {torch.norm(test_output, dim=1).item():.4f} (should be ~1.0)")
+ 
+ # 6. Export to ExecuTorch
+ print("\n📤 Exporting to ExecuTorch...")
+ 
+ try:
+     # Step 1: Capture the computational graph
+     print("  1/4 Capturing graph with torch.export...")
+     exported_program = export(model_wrapper, example_args, strict=False)
+     print("  ✓ Graph captured")
+ 
+     # Step 2: Lower to Edge IR
+     print("  2/4 Lowering to Edge IR...")
+     edge_program = to_edge(
+         exported_program,
+         compile_config=EdgeCompileConfig(_check_ir_validity=False)
+     )
+     print("  ✓ Edge IR created")
+ 
+     # Step 3: Delegate supported subgraphs to the XNNPACK backend
+     print("  3/4 Partitioning for XNNPACK...")
+     edge_program = edge_program.to_backend(XnnpackPartitioner())
+     print("  ✓ XNNPACK partitioning done")
+ 
+     # Step 4: Convert to ExecuTorch program
+     print("  4/4 Converting to ExecuTorch program...")
+     executorch_program = edge_program.to_executorch()
+     print("  ✓ Conversion complete")
+ 
+     # Save to file
+     output_path = "sentence_transformers_minilm.pte"
+     with open(output_path, "wb") as f:
+         executorch_program.write_to_file(f)
+ 
+     import os
+     file_size_mb = os.path.getsize(output_path) / (1024 * 1024)
+ 
+     print("\n🎉 Export successful!")
+     print(f"📁 Saved to: {output_path}")
+     print(f"📊 File size: {file_size_mb:.2f} MB")
+     print("\n💡 Usage: load this .pte file in your mobile app")
+     print("   Input: token IDs (int64) and attention mask (int64)")
+     print("   Output: normalized embeddings (float32, dim=384)")
+ 
+ except Exception as e:
+     print(f"\n❌ Export failed: {e}")
+     import traceback
+     traceback.print_exc()
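The masked mean pooling plus L2 normalization performed in `SentenceTransformerWrapper.forward` can be checked in isolation with plain Python lists. This is an illustrative sketch of the same math, independent of torch, not part of the commit:

```python
import math

def mean_pool_normalize(token_embeddings, attention_mask):
    """Masked mean pooling + L2 normalization, mirroring the wrapper.

    token_embeddings: list of [hidden]-sized rows, one per token.
    attention_mask: list of 0/1 flags, one per token.
    """
    hidden = len(token_embeddings[0])
    sums = [0.0] * hidden
    # Clamp the mask sum, matching torch.clamp(..., min=1e-9) in the wrapper.
    count = max(sum(attention_mask), 1e-9)
    for row, m in zip(token_embeddings, attention_mask):
        if m:  # only real (non-padding) tokens contribute
            for j in range(hidden):
                sums[j] += row[j]
    mean = [s / count for s in sums]
    # L2 normalize so cosine similarity reduces to a dot product.
    norm = math.sqrt(sum(v * v for v in mean)) or 1.0
    return [v / norm for v in mean]

emb = mean_pool_normalize([[1.0, 0.0], [3.0, 4.0], [9.0, 9.0]], [1, 1, 0])
# Padding row [9, 9] (mask 0) is ignored: mean of [1,0] and [3,4] is [2,2],
# which normalizes to roughly [0.7071, 0.7071], a unit vector.
```

This is why the exported model's output norm prints as ~1.0 in the forward-pass test above.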
sentence-transformers-embbedings/generate_category_embeddings.py ADDED
@@ -0,0 +1,101 @@
+ #!/usr/bin/env python3
+ """
+ Generate category embeddings using the exported sentence-transformers .pte model.
+ Reads categories from categories.json and writes embeddings in the same format
+ as embeddings_granite_export/category_embeddings.json.
+ """
+ 
+ import json
+ import torch
+ from pathlib import Path
+ from transformers import AutoTokenizer
+ from executorch.extension.pybindings.portable_lib import _load_for_executorch
+ 
+ print("🚀 Generating Category Embeddings with Sentence Transformers")
+ 
+ # Configuration
+ MODEL_PATH = "sentence-transformers-embbedings/sentence_transformers_minilm.pte"
+ CATEGORIES_PATH = "categories.json"
+ OUTPUT_PATH = "sentence-transformers-embbedings/category_embeddings.json"
+ MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
+ 
+ # 1. Load the tokenizer
+ print(f"📦 Loading tokenizer: {MODEL_NAME}")
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+ 
+ # 2. Load the .pte model
+ print(f"📦 Loading .pte model: {MODEL_PATH}")
+ model = _load_for_executorch(MODEL_PATH)
+ print("✓ Model loaded")
+ 
+ # 3. Load categories
+ print(f"📖 Loading categories from: {CATEGORIES_PATH}")
+ with open(CATEGORIES_PATH, "r") as f:
+     categories_data = json.load(f)
+ 
+ categories = categories_data["categories"]
+ print(f"✓ Loaded {len(categories)} categories")
+ 
+ # 4. Generate embeddings for each category
+ print("\n🔧 Generating embeddings...")
+ embeddings_list = []
+ updated_categories = []
+ 
+ for idx, category in enumerate(categories):
+     # Create the text to embed (label + description, matching the Granite format)
+     text_embedded = f"{category['label']}. {category['description']}"
+ 
+     # Tokenize
+     inputs = tokenizer(
+         text_embedded,
+         max_length=128,
+         padding="max_length",
+         truncation=True,
+         return_tensors="pt"
+     )
+ 
+     # Prepare inputs for ExecuTorch
+     input_ids = inputs["input_ids"]
+     attention_mask = inputs["attention_mask"]
+ 
+     # Run inference
+     outputs = model.forward((input_ids, attention_mask))
+ 
+     # Extract embedding (shape [1, 384])
+     embedding_tensor = outputs[0]
+     embedding_list = embedding_tensor.squeeze(0).tolist()
+ 
+     embeddings_list.append(embedding_list)
+ 
+     # Add text_embedded field to the category
+     category_copy = category.copy()
+     category_copy["text_embedded"] = text_embedded
+     updated_categories.append(category_copy)
+ 
+     print(f"  ✓ [{idx+1}/{len(categories)}] {category['id']}: {category['label']}")
+ 
+ # 5. Create output JSON in the same format as the Granite embeddings
+ output_data = {
+     "categories": updated_categories,
+     "embeddings": embeddings_list,
+     "metadata": {
+         "model": "sentence-transformers/all-MiniLM-L6-v2",
+         "model_file": MODEL_PATH,
+         "embedding_dimension": len(embeddings_list[0]),
+         "total_categories": len(categories),
+         "normalization": "L2",
+         "pooling": "mean"
+     }
+ }
+ 
+ # 6. Save to file
+ print(f"\n💾 Saving embeddings to: {OUTPUT_PATH}")
+ with open(OUTPUT_PATH, "w") as f:
+     json.dump(output_data, f, indent=2)
+ 
+ file_size_kb = Path(OUTPUT_PATH).stat().st_size / 1024
+ print(f"✓ Saved successfully ({file_size_kb:.2f} KB)")
+ 
+ print("\n🎉 Done!")
+ print(f"📊 Generated {len(embeddings_list)} embeddings of dimension {len(embeddings_list[0])}")
+ print(f"📁 Output: {OUTPUT_PATH}")
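At runtime, categorizing a caption against the generated `category_embeddings.json` reduces to a nearest-neighbor lookup. Because the embeddings are L2-normalized, cosine similarity is just a dot product. A dependency-free sketch of that consumer side (illustrative; `best_category` is not part of the commit):

```python
def best_category(query_embedding, categories, embeddings):
    """Return (category id, score) of the most similar category.

    Assumes query_embedding and each row of embeddings are
    L2-normalized, so the dot product equals cosine similarity.
    """
    scores = [
        sum(q * e for q, e in zip(query_embedding, emb))
        for emb in embeddings
    ]
    best = max(range(len(scores)), key=scores.__getitem__)
    return categories[best]["id"], scores[best]

# Toy 2-D demo: the query is closest to the first category vector.
cats = [{"id": "dog_bark"}, {"id": "doorbell"}]
vecs = [[1.0, 0.0], [0.0, 1.0]]
best_category([0.9, 0.1], cats, vecs)  # → ("dog_bark", 0.9)
```

On device, the query embedding would come from the same .pte model applied to the generated caption, and the category vectors from the JSON produced by the script above.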
sentence-transformers-embbedings/sentence_transformers_minilm.pte ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:de2fdcf7daf9b592856a5b740108258c589c7b5c26921b51abe197364dd3cabb
+ size 90379856
uv.lock ADDED
The diff for this file is too large to render. See raw diff