---
base_model: nomic-ai/CodeRankEmbed
base_model_relation: quantized
library_name: sentence-transformers
license: mit
tags:
- flash-attention
- code-retrieval
- sentence-transformers
- nomic-bert
- bf16
language:
- en
---

# CodeRankEmbed-flash-attn

A **bf16 quantization of [`nomic-ai/CodeRankEmbed`](https://huggingface.co/nomic-ai/CodeRankEmbed)** with **flash-attention built into a custom `modeling_hf_nomic_bert.py` shipped in this repo.** It is not a finetune. The weights are the original CodeRankEmbed weights cast to bf16 with no further training, and attention runs through the `flash_attn` varlen kernel instead of the original eager `O(seq²)` path.

## Why

`nomic-ai/CodeRankEmbed` loads through `trust_remote_code`, and its attention path is **eager only**. The attention-score activations grow as `batch × heads × seq²`, which causes OOMs at large batch sizes even though the model has only 137M parameters.

`flash_attn`'s varlen path computes the same attention in `O(N)` memory (linear in the number of packed tokens) by packing the unpadded sequences. Batches that OOM on the eager path therefore run comfortably, with **parity embeddings** (no quality change). This repo ships that flash path in the modeling file itself, so no runtime patching or post-load hooks are needed.

## Behavior

- **Loads bf16 by default.** `flash_attn` requires half precision, and the model runs in bf16 in any real serving setup, so the weights are stored in bf16 and `config.json` declares `torch_dtype: bfloat16`. The upstream custom `from_pretrained` silently dropped `torch_dtype` and always loaded fp32. The copy in this repo honors it, so the model loads in bf16 natively, like any standard HF model. Pass `torch_dtype=torch.float32` to load in fp32 (note: the stored weights have bf16 precision, so this widens the dtype but does not restore fp32 precision).
- **CUDA + `flash_attn`** → flash-varlen path (fast, low VRAM). Attention tensors are cast to bf16 internally as a safety net; with the default bf16 load this cast is a no-op.
- **CPU, or no `flash_attn`** → the original eager attention algorithm runs unchanged. Because the model loads in bf16, eager attention also runs in bf16 here: it matches the flash path up to bf16 rounding, just without the memory and throughput gains.

The model loads and encodes on any host; the forward pass selects the path automatically (`_FLASH_AVAILABLE and hidden_states.is_cuda`).

## Usage

Identical to the original. Queries **must** include the task-instruction prefix `"Represent this query for searching relevant code: "`; documents need no prefix.

```python
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("handwoven8588/CodeRankEmbed-flash-attn", trust_remote_code=True)

# Queries carry the task-instruction prefix; code documents are encoded as-is.
queries = ["Represent this query for searching relevant code: Calculate the n-th factorial"]
codes = ["def fact(n):\n    if n < 0:\n        raise ValueError\n    return 1 if n == 0 else n * fact(n - 1)"]

q = model.encode(queries, normalize_embeddings=True)
d = model.encode(codes, normalize_embeddings=True)

# Embeddings are L2-normalized, so the dot product equals cosine similarity.
scores = q @ d.T
```

## Parity & performance

The weights are the original CodeRankEmbed weights cast to bf16, so embeddings match the fp32 original to within bf16 precision. Measured on an RTX 3090 Ti with `flash_attn` 2.8.3, CLS pooling, L2-normalized embeddings, batch size 64:

| metric | `nomic-ai/CodeRankEmbed` (fp32, eager) | this repo (bf16, flash-varlen) |
| --- | --- | --- |
| cosine vs fp32 reference | 1.000000 | **0.9986** |
| peak VRAM (bs=64) | 6.7 GB | **2.1 GB** (≈3.3× less) |
| throughput (bs=64) | 52,000 tok/s | **162,000 tok/s** (≈3.1× faster) |

Both columns use the same 512-snippet corpus (~236k tokens, 20–900 tokens per snippet) at batch size 64. The flash-varlen path also scales to much larger batches (bs=256 fits in ~7.6 GB), where the eager path's `O(seq²)` memory blows up. Parity (cosine >0.997) was reproduced on both an RTX 4060 and the 3090 Ti.

## What changed vs the source repo

1. **Weights**: fp32 → bf16.
   `flash_attn` only accepts half precision, and the model runs in bf16 in any real serving configuration, so the weights are stored in bf16 and (via the load fix below) arrive in bf16. This matches how the model is actually used and removes the need for a post-load dtype cast. Parity-neutral (cosine 0.9986 vs the fp32 original); the smaller download is incidental, not the reason.
2. **`from_pretrained` dtype fix**: the upstream custom `from_pretrained` instantiated the model in fp32 and `load_state_dict`-ed the checkpoint into fp32 parameters, **ignoring `torch_dtype`**. The copy here adds the standard transformers dtype resolution (explicit argument → `config.torch_dtype` → checkpoint dtype), so the model loads in its declared dtype.
3. **Flash-varlen forward**: `NomicBertAttention.forward` gains an `unpad → flash_attn_varlen_qkvpacked_func(causal=False) → repad` branch (CUDA + `flash_attn`); the original eager block is kept as the fallback. `NomicBertModel.forward` skips `get_extended_attention_mask` on the flash path. Rotary embeddings are applied to the dense `[B, S, 3, H, D]` tensor **before** unpadding, and this ordering is the correctness keystone: in the dense layout the `S` index is each token's position within its own sequence, whereas after unpadding all sequences are concatenated and the row index no longer encodes that position.

## License & attribution

MIT, the same license as `nomic-ai/CodeRankEmbed` (see `NOTICE`). The weights (bf16-cast), tokenizer, and the bulk of the modeling file derive directly from `nomic-ai/CodeRankEmbed`. The modeling file in turn derives from Tri Dao's BERT implementation, and `CodeRankEmbed` was trained by the CoRNStack team (Suresh et al., 2025). Please cite their work:

```bibtex
@misc{suresh2025cornstackhighqualitycontrastivedata,
  title  = {CoRNStack: High-Quality Contrastive Data for Text and Code Retrieval},
  author = {Suresh, K N Q and Wang, Xiang and Khan, Saqib and others},
  year   = {2025},
}
```
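
## Sketch: why rotary runs before unpadding

The ordering rule in item 3 of "What changed vs the source repo" is easiest to see on a toy example. The following is a minimal pure-Python sketch, not code from this repo's modeling file. It assumes interleaved pairing of dimensions and a rotary base of 10000 purely for illustration, lets a single vector per token stand in for one head's query or key, and uses a hypothetical `unpad` helper that mimics packing valid tokens and building `cu_seqlens`. It compares rotating on the dense padded batch (positions are per-sequence) with rotating after packing (the packed row index is wrongly used as the position).

```python
import math
from typing import List, Tuple

Vec = List[float]


def apply_rotary(x: Vec, pos: int, base: float = 10000.0) -> Vec:
    """Rotate pairs (x[2i], x[2i+1]) by angle pos * base**(-2i/d).

    Interleaved pairing and base=10000 are illustrative assumptions; the real
    modeling file may pair dimensions differently, but the dependence on the
    token's position is the same.
    """
    d = len(x)
    out = list(x)
    for i in range(d // 2):
        angle = pos * base ** (-2 * i / d)
        c, s = math.cos(angle), math.sin(angle)
        a, b = x[2 * i], x[2 * i + 1]
        out[2 * i] = a * c - b * s
        out[2 * i + 1] = a * s + b * c
    return out


def unpad(batch: List[List[Vec]], lengths: List[int]) -> Tuple[List[Vec], List[int]]:
    """Hypothetical helper: pack valid tokens and build cumulative offsets."""
    packed, cu_seqlens = [], [0]
    for seq, n in zip(batch, lengths):
        packed.extend(seq[:n])
        cu_seqlens.append(cu_seqlens[-1] + n)
    return packed, cu_seqlens


# Toy padded batch: B=2, S=3, head dim 4; the second sequence has one pad slot.
batch = [
    [[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0], [1.0, 1.0, 1.0, 1.0]],
    [[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 0.0, 0.0]],  # last row = pad
]
lengths = [3, 2]

# Correct order: rotate on the dense layout, where the column index is each
# token's position within its own sequence, then unpad.
dense_rotated = [[apply_rotary(tok, pos) for pos, tok in enumerate(seq)] for seq in batch]
correct, cu_seqlens = unpad(dense_rotated, lengths)

# Wrong order: unpad first, then treat the packed row index as the position.
packed, _ = unpad(batch, lengths)
wrong = [apply_rotary(tok, pos) for pos, tok in enumerate(packed)]

print(cu_seqlens)  # [0, 3, 5]
# The first sequence occupies packed rows 0..2, so its positions happen to agree.
print(all(a == b for a, b in zip(correct[:3], wrong[:3])))  # True
# The second sequence is rotated as positions 3, 4 instead of 0, 1.
print(correct[3] == wrong[3])  # False
```

Only the first sequence survives the wrong ordering, because its packed rows coincide with its positions; every later sequence gets offset positions. Passing explicit per-sequence position ids into the packed layout would also work, but rotating on the dense tensor before unpadding avoids needing them.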