"""
MTEB encoder and ModelMeta for nvidia/llama-nv-embed-reasoning-3b.
"""
|
|
from mteb.models.model_meta import ModelMeta
from mteb.models.model_implementations.nvidia_models import (
    LlamaEmbedNemotron,
    LlamaEmbedNemotron_CITATION,
    llama_embed_nemotron_evaluated_languages,
    llama_embed_nemotron_training_datasets,
)
from mteb.types import PromptType
|
|
# All three math-flavored BRIGHT tasks share one query instruction.
_MATH_INSTRUCTION = "Given a Math problem, retrieve relevant passages."

# Per-task query-side instructions for the BRIGHT reasoning-retrieval
# benchmark, keyed by MTEB task name.
BRIGHT_TASK_INSTRUCTIONS = {
    "BrightBiologyRetrieval": "Given a Biology post, retrieve relevant passages.",
    "BrightEarthScienceRetrieval": "Given an Earth Science post, retrieve relevant passages.",
    "BrightEconomicsRetrieval": "Given an Economics post, retrieve relevant passages.",
    "BrightPsychologyRetrieval": "Given a Psychology post, retrieve relevant passages.",
    "BrightRoboticsRetrieval": "Given a Robotics post, retrieve relevant passages.",
    "BrightStackoverflowRetrieval": "Given a Stack Overflow post, retrieve relevant passages.",
    "BrightSustainableLivingRetrieval": "Given a Sustainable Living post, retrieve relevant passages.",
    "BrightLeetcodeRetrieval": "Given a Coding problem, retrieve relevant passages.",
    "BrightPonyRetrieval": "Given a Pony question, retrieve relevant passages.",
    "BrightAopsRetrieval": _MATH_INSTRUCTION,
    "BrightTheoremQAQuestionsRetrieval": _MATH_INSTRUCTION,
    "BrightTheoremQATheoremsRetrieval": _MATH_INSTRUCTION,
}

# Fixed prefix used on the document side for BRIGHT tasks instead of a
# task-specific instruction.
BRIGHT_PASSAGE_PREFIX = "passage: "
|
|
class LlamaNvEmbedReasoning(LlamaEmbedNemotron):
    """Llama NV-Embed reasoning encoder specialized for BRIGHT benchmark prompts.

    For BRIGHT tasks, queries get a task-specific instruction while documents
    are encoded with a plain ``"passage: "`` prefix; every other task falls
    back to the base-class prompt handling.
    """

    def __init__(self, model_name: str, revision: str, device: str | None = None, **kwargs) -> None:
        super().__init__(model_name, revision=revision, device=device)
        # Maximum sequence length in tokens; overridable via loader_kwargs.
        # NOTE(review): remaining **kwargs are intentionally not forwarded to
        # the base class — confirm the base __init__ accepts none of them.
        self.max_seq_length = kwargs.get("max_seq_length", 8192)

    def _get_base_instruction(self, task_metadata, prompt_type: PromptType | None) -> str:
        """Return the instruction for a task, with BRIGHT-specific overrides.

        BRIGHT documents get an empty instruction (they are prefixed in
        :meth:`encode` instead); BRIGHT queries get the curated instruction.
        Non-BRIGHT tasks defer to the parent implementation.
        """
        bright_instruction = BRIGHT_TASK_INSTRUCTIONS.get(task_metadata.name)
        if bright_instruction is None:
            return super()._get_base_instruction(task_metadata, prompt_type)
        return "" if prompt_type == PromptType.document else bright_instruction

    def encode(
        self,
        inputs,
        *,
        task_metadata,
        hf_split: str = "",
        hf_subset: str = "",
        prompt_type: PromptType | None = None,
        **kwargs,
    ):
        """Embed ``inputs``, choosing the prompt prefix by task and side.

        BRIGHT documents are prefixed with ``BRIGHT_PASSAGE_PREFIX``; all
        other inputs use the formatted task-specific instruction. The
        remaining kwargs are passed through to ``_extract_embeddings``.
        """
        is_bright_document = (
            prompt_type == PromptType.document
            and task_metadata.name in BRIGHT_TASK_INSTRUCTIONS
        )
        if is_bright_document:
            prefix = BRIGHT_PASSAGE_PREFIX
        else:
            task_instruction = self._get_task_specific_instruction(task_metadata, prompt_type)
            prefix = self.format_instruction(task_instruction, prompt_type)
        return self._extract_embeddings(inputs, instruction=prefix, **kwargs)
|
|
# MTEB registration metadata for nvidia/llama-nv-embed-reasoning-3b.
LLAMA_NV_EMBED_REASONING_3B_META = ModelMeta(
    loader=LlamaNvEmbedReasoning,
    loader_kwargs={"max_seq_length": 8192},
    name="nvidia/llama-nv-embed-reasoning-3b",
    model_type=["dense"],
    languages=llama_embed_nemotron_evaluated_languages,
    open_weights=True,
    revision="main",
    release_date="2026-02-23",
    n_parameters=3_212_749_824,
    memory_usage_mb=6000,
    embed_dim=3072,
    license="https://huggingface.co/nvidia/llama-nv-embed-reasoning-3b/blob/main/LICENSE",
    max_tokens=8192,  # matches the loader's default max_seq_length
    reference="https://huggingface.co/nvidia/llama-nv-embed-reasoning-3b",
    similarity_fn_name="cosine",
    framework=["PyTorch", "Transformers"],
    use_instructions=True,
    training_datasets=llama_embed_nemotron_training_datasets,
    public_training_code=None,
    public_training_data=None,
    citation=LlamaEmbedNemotron_CITATION,
)
|
|