# llama-nv-embed-reasoning-3b / mteb_llama_nv_embed_reasoning_3b.py
# NOTE(review): the following lines are Hugging Face page residue from the scrape
# (author jiaruic, "Update model name", commit a7b155d verified), kept as comments
# so the module parses as valid Python.
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
MTEB encoder and ModelMeta for nvidia/llama-nv-embed-reasoning-3b.
"""
from mteb.models.model_meta import ModelMeta
from mteb.models.model_implementations.nvidia_models import (
LlamaEmbedNemotron,
LlamaEmbedNemotron_CITATION,
llama_embed_nemotron_evaluated_languages,
llama_embed_nemotron_training_datasets,
)
from mteb.types import PromptType
# Subject phrase for each BRIGHT retrieval task; every task shares the same
# instruction template, so only the subject varies per task.
_BRIGHT_TASK_SUBJECTS = {
    "BrightBiologyRetrieval": "a Biology post",
    "BrightEarthScienceRetrieval": "an Earth Science post",
    "BrightEconomicsRetrieval": "an Economics post",
    "BrightPsychologyRetrieval": "a Psychology post",
    "BrightRoboticsRetrieval": "a Robotics post",
    "BrightStackoverflowRetrieval": "a Stack Overflow post",
    "BrightSustainableLivingRetrieval": "a Sustainable Living post",
    "BrightLeetcodeRetrieval": "a Coding problem",
    "BrightPonyRetrieval": "a Pony question",
    "BrightAopsRetrieval": "a Math problem",
    "BrightTheoremQAQuestionsRetrieval": "a Math problem",
    "BrightTheoremQATheoremsRetrieval": "a Math problem",
}

# Full query-side instruction per BRIGHT task, expanded from the shared template.
BRIGHT_TASK_INSTRUCTIONS = {
    task: f"Given {subject}, retrieve relevant passages."
    for task, subject in _BRIGHT_TASK_SUBJECTS.items()
}

# Prefix prepended to document-side texts on BRIGHT tasks (documents get no
# task instruction, only this marker).
BRIGHT_PASSAGE_PREFIX = "passage: "
class LlamaNvEmbedReasoning(LlamaEmbedNemotron):
    """Llama-based embedder specialised for the BRIGHT reasoning benchmark.

    Overrides instruction selection so BRIGHT tasks use their dedicated query
    prompts, while BRIGHT documents are embedded with a bare ``passage: ``
    prefix instead of a formatted instruction. All other tasks defer to the
    parent ``LlamaEmbedNemotron`` behaviour.
    """

    def __init__(self, model_name: str, revision: str, device: str | None = None, **kwargs) -> None:
        super().__init__(model_name, revision=revision, device=device)
        # Context length defaults to 8192 tokens unless the loader overrides it.
        self.max_seq_length = kwargs.get("max_seq_length", 8192)

    def _get_base_instruction(self, task_metadata, prompt_type: PromptType | None) -> str:
        """Return the BRIGHT-specific instruction, or defer to the parent class."""
        bright_instruction = BRIGHT_TASK_INSTRUCTIONS.get(task_metadata.name)
        if bright_instruction is None:
            # Not a BRIGHT task: use the standard Nemotron instruction logic.
            return super()._get_base_instruction(task_metadata, prompt_type)
        # On BRIGHT tasks only queries carry an instruction; documents get none.
        return "" if prompt_type == PromptType.document else bright_instruction

    def encode(
        self,
        inputs,
        *,
        task_metadata,
        hf_split: str = "",
        hf_subset: str = "",
        prompt_type: PromptType | None = None,
        **kwargs,
    ):
        """Embed *inputs*, giving BRIGHT documents a plain ``passage: `` prefix.

        Non-BRIGHT inputs (and BRIGHT queries) go through the parent's
        task-specific instruction lookup and formatting.
        """
        is_bright_document = (
            task_metadata.name in BRIGHT_TASK_INSTRUCTIONS
            and prompt_type == PromptType.document
        )
        if is_bright_document:
            text_prefix = BRIGHT_PASSAGE_PREFIX
        else:
            task_instruction = self._get_task_specific_instruction(task_metadata, prompt_type)
            text_prefix = self.format_instruction(task_instruction, prompt_type)
        return self._extract_embeddings(inputs, instruction=text_prefix, **kwargs)
# MTEB metadata record for nvidia/llama-nv-embed-reasoning-3b, registering the
# encoder above as the loader for this model name.
LLAMA_NV_EMBED_REASONING_3B_META = ModelMeta(
    loader=LlamaNvEmbedReasoning,
    # Forwarded to LlamaNvEmbedReasoning.__init__ via **kwargs; matches max_tokens below.
    loader_kwargs=dict(max_seq_length=8192),
    name="nvidia/llama-nv-embed-reasoning-3b",
    model_type=["dense"],
    # Language coverage reused from the base llama-embed-nemotron evaluation set.
    languages=llama_embed_nemotron_evaluated_languages,
    open_weights=True,
    # NOTE(review): "main" is an unpinned revision — consider pinning a commit
    # hash for reproducible evaluation results.
    revision="main",
    release_date="2026-02-23",
    n_parameters=3_212_749_824,
    memory_usage_mb=6000,
    embed_dim=3072,
    license="https://huggingface.co/nvidia/llama-nv-embed-reasoning-3b/blob/main/LICENSE",
    max_tokens=8192,
    reference="https://huggingface.co/nvidia/llama-nv-embed-reasoning-3b",
    similarity_fn_name="cosine",
    framework=["PyTorch", "Transformers"],
    use_instructions=True,
    # Training data and citation shared with the base llama-embed-nemotron model.
    training_datasets=llama_embed_nemotron_training_datasets,
    public_training_code=None,
    public_training_data=None,
    citation=LlamaEmbedNemotron_CITATION,
)