metadata
language:
- en
license: apache-2.0
tags:
- sentence-transformers
- sentence-similarity
- feature-extraction
- dense
- generated_from_trainer
- dataset_size:900
- loss:MatryoshkaLoss
- loss:MultipleNegativesRankingLoss
base_model: microsoft/codebert-base
widget:
- source_sentence: Best practices for _invocation_params
sentences:
- >-
def after_model(self, state: StateT, runtime: Runtime[ContextT]) ->
dict[str, Any] | None:
"""Logic to run after the model is called.
Args:
state: The current agent state.
runtime: The runtime context.
Returns:
Agent state updates to apply after model call.
"""
- |-
def _get_trace_callbacks(
project_name: str | None = None,
example_id: str | UUID | None = None,
callback_manager: CallbackManager | AsyncCallbackManager | None = None,
) -> Callbacks:
if _tracing_v2_is_enabled():
project_name_ = project_name or _get_tracer_project()
tracer = tracing_v2_callback_var.get() or LangChainTracer(
project_name=project_name_,
example_id=example_id,
)
if callback_manager is None:
cb = cast("Callbacks", [tracer])
else:
if not any(
isinstance(handler, LangChainTracer)
for handler in callback_manager.handlers
):
callback_manager.add_handler(tracer)
# If it already has a LangChainTracer, we don't need to add another one.
# this would likely mess up the trace hierarchy.
cb = callback_manager
else:
cb = None
return cb
- |-
def _invocation_params(self) -> dict[str, Any]:
params: dict = {"model": self.model, **self.model_kwargs}
if self.dimensions is not None:
params["dimensions"] = self.dimensions
return params
- source_sentence: How does _approximate_token_counter work in Python?
sentences:
- |-
def _approximate_token_counter(messages: Sequence[BaseMessage]) -> int:
"""Wrapper for `count_tokens_approximately` that matches expected signature."""
return count_tokens_approximately(messages)
- |-
def remove_request_headers(request: Any) -> Any:
for k in request.headers:
request.headers[k] = "**REDACTED**"
return request
- |-
def get_format_instructions(self) -> str:
"""Returns formatting instructions for the given output parser."""
return self.format_instructions
- source_sentence: How to implement _create_thread_and_run?
sentences:
- |-
async def on_retriever_end(
self, documents: Sequence[Document], **kwargs: Any
) -> None:
"""Run when the retriever ends running.
Args:
documents: The retrieved documents.
**kwargs: Additional keyword arguments.
"""
if not self.handlers:
return
await ahandle_event(
self.handlers,
"on_retriever_end",
"ignore_retriever",
documents,
run_id=self.run_id,
parent_run_id=self.parent_run_id,
tags=self.tags,
**kwargs,
)
- |-
def _create_thread_and_run(self, input_dict: dict, thread: dict) -> Any:
params = {
k: v
for k, v in input_dict.items()
if k
in (
"instructions",
"model",
"tools",
"parallel_tool_calls",
"top_p",
"temperature",
"max_completion_tokens",
"max_prompt_tokens",
"run_metadata",
)
}
return self.client.beta.threads.create_and_run(
assistant_id=self.assistant_id,
thread=thread,
**params,
)
- |-
def test_pandas_output_parser_col_no_array() -> None:
with pytest.raises(OutputParserException):
parser.parse("column:num_legs")
- source_sentence: Explain the get_token_ids logic
sentences:
- |-
def _runnable(inputs: dict[str, Any]) -> str:
if inputs["text"] == "foo":
return "first"
if "exception" not in inputs:
msg = "missing exception"
raise ValueError(msg)
if inputs["text"] == "bar":
return "second"
if isinstance(inputs["exception"], ValueError):
raise RuntimeError # noqa: TRY004
return "third"
- |-
def validate_params(cls, values: dict) -> dict:
"""Validate similarity parameters."""
if values["k"] is None and values["similarity_threshold"] is None:
msg = "Must specify one of `k` or `similarity_threshold`."
raise ValueError(msg)
return values
- |-
def get_token_ids(self, text: str) -> list[int]:
"""Return the ordered IDs of the tokens in a text.
Args:
text: The string input to tokenize.
Returns:
A list of IDs corresponding to the tokens in the text, in order they occur
in the text.
"""
if self.custom_get_token_ids is not None:
return self.custom_get_token_ids(text)
return _get_token_ids_default_method(text)
- source_sentence: How does __init__ work in Python?
sentences:
- |-
def test_loading_few_shot_prompt_from_json() -> None:
"""Test loading few shot prompt from json."""
with change_directory(EXAMPLE_DIR):
prompt = load_prompt("few_shot_prompt.json")
expected_prompt = FewShotPromptTemplate(
input_variables=["adjective"],
prefix="Write antonyms for the following words.",
example_prompt=PromptTemplate(
input_variables=["input", "output"],
template="Input: {input}\nOutput: {output}",
),
examples=[
{"input": "happy", "output": "sad"},
{"input": "tall", "output": "short"},
],
suffix="Input: {adjective}\nOutput:",
)
assert prompt == expected_prompt
- |-
def __init__(
self,
encoding_name: str = "gpt2",
model_name: str | None = None,
allowed_special: Literal["all"] | AbstractSet[str] = set(),
disallowed_special: Literal["all"] | Collection[str] = "all",
**kwargs: Any,
) -> None:
"""Create a new `TextSplitter`.
Args:
encoding_name: The name of the tiktoken encoding to use.
model_name: The name of the model to use. If provided, this will
override the `encoding_name`.
allowed_special: Special tokens that are allowed during encoding.
disallowed_special: Special tokens that are disallowed during encoding.
Raises:
ImportError: If the tiktoken package is not installed.
"""
super().__init__(**kwargs)
if not _HAS_TIKTOKEN:
msg = (
"Could not import tiktoken python package. "
"This is needed in order to for TokenTextSplitter. "
"Please install it with `pip install tiktoken`."
)
raise ImportError(msg)
if model_name is not None:
enc = tiktoken.encoding_for_model(model_name)
else:
enc = tiktoken.get_encoding(encoding_name)
self._tokenizer = enc
self._allowed_special = allowed_special
self._disallowed_special = disallowed_special
- |-
def test_fixed_message_response_when_docs_found() -> None:
fixed_resp = "I don't know"
answer = "I know the answer!"
llm = FakeListLLM(responses=[answer])
retriever = SequentialRetriever(
sequential_responses=[[Document(page_content=answer)]],
)
memory = ConversationBufferMemory(
k=1,
output_key="answer",
memory_key="chat_history",
return_messages=True,
)
qa_chain = ConversationalRetrievalChain.from_llm(
llm=llm,
memory=memory,
retriever=retriever,
return_source_documents=True,
rephrase_question=False,
response_if_no_docs_found=fixed_resp,
verbose=True,
)
got = qa_chain("What is the answer?")
assert got["chat_history"][1].content == answer
assert got["answer"] == answer
pipeline_tag: sentence-similarity
library_name: sentence-transformers
metrics:
- cosine_accuracy@1
- cosine_accuracy@3
- cosine_accuracy@5
- cosine_accuracy@10
- cosine_precision@1
- cosine_precision@3
- cosine_precision@5
- cosine_precision@10
- cosine_recall@1
- cosine_recall@3
- cosine_recall@5
- cosine_recall@10
- cosine_ndcg@10
- cosine_mrr@10
- cosine_map@100
model-index:
- name: codeBert Base
results:
- task:
type: information-retrieval
name: Information Retrieval
dataset:
name: dim 768
type: dim_768
metrics:
- type: cosine_accuracy@1
value: 0.83
name: Cosine Accuracy@1
- type: cosine_accuracy@3
value: 0.85
name: Cosine Accuracy@3
- type: cosine_accuracy@5
value: 0.86
name: Cosine Accuracy@5
- type: cosine_accuracy@10
value: 0.94
name: Cosine Accuracy@10
- type: cosine_precision@1
value: 0.83
name: Cosine Precision@1
- type: cosine_precision@3
value: 0.83
name: Cosine Precision@3
- type: cosine_precision@5
value: 0.83
name: Cosine Precision@5
- type: cosine_precision@10
value: 0.45299999999999996
name: Cosine Precision@10
- type: cosine_recall@1
value: 0.16599999999999998
name: Cosine Recall@1
- type: cosine_recall@3
value: 0.498
name: Cosine Recall@3
- type: cosine_recall@5
value: 0.83
name: Cosine Recall@5
- type: cosine_recall@10
value: 0.9059999999999999
name: Cosine Recall@10
- type: cosine_ndcg@10
value: 0.8712089918828809
name: Cosine Ndcg@10
- type: cosine_mrr@10
value: 0.8532738095238095
name: Cosine Mrr@10
- type: cosine_map@100
value: 0.861635686929646
name: Cosine Map@100
- task:
type: information-retrieval
name: Information Retrieval
dataset:
name: dim 512
type: dim_512
metrics:
- type: cosine_accuracy@1
value: 0.85
name: Cosine Accuracy@1
- type: cosine_accuracy@3
value: 0.86
name: Cosine Accuracy@3
- type: cosine_accuracy@5
value: 0.87
name: Cosine Accuracy@5
- type: cosine_accuracy@10
value: 0.95
name: Cosine Accuracy@10
- type: cosine_precision@1
value: 0.85
name: Cosine Precision@1
- type: cosine_precision@3
value: 0.84
name: Cosine Precision@3
- type: cosine_precision@5
value: 0.8419999999999999
name: Cosine Precision@5
- type: cosine_precision@10
value: 0.45299999999999996
name: Cosine Precision@10
- type: cosine_recall@1
value: 0.16999999999999996
name: Cosine Recall@1
- type: cosine_recall@3
value: 0.504
name: Cosine Recall@3
- type: cosine_recall@5
value: 0.8419999999999999
name: Cosine Recall@5
- type: cosine_recall@10
value: 0.9059999999999999
name: Cosine Recall@10
- type: cosine_ndcg@10
value: 0.8775797199885595
name: Cosine Ndcg@10
- type: cosine_mrr@10
value: 0.8699404761904762
name: Cosine Mrr@10
- type: cosine_map@100
value: 0.8692738075020783
name: Cosine Map@100
- task:
type: information-retrieval
name: Information Retrieval
dataset:
name: dim 256
type: dim_256
metrics:
- type: cosine_accuracy@1
value: 0.86
name: Cosine Accuracy@1
- type: cosine_accuracy@3
value: 0.89
name: Cosine Accuracy@3
- type: cosine_accuracy@5
value: 0.9
name: Cosine Accuracy@5
- type: cosine_accuracy@10
value: 0.93
name: Cosine Accuracy@10
- type: cosine_precision@1
value: 0.86
name: Cosine Precision@1
- type: cosine_precision@3
value: 0.85
name: Cosine Precision@3
- type: cosine_precision@5
value: 0.85
name: Cosine Precision@5
- type: cosine_precision@10
value: 0.45
name: Cosine Precision@10
- type: cosine_recall@1
value: 0.17199999999999996
name: Cosine Recall@1
- type: cosine_recall@3
value: 0.51
name: Cosine Recall@3
- type: cosine_recall@5
value: 0.85
name: Cosine Recall@5
- type: cosine_recall@10
value: 0.9
name: Cosine Recall@10
- type: cosine_ndcg@10
value: 0.8789938349894767
name: Cosine Ndcg@10
- type: cosine_mrr@10
value: 0.8805952380952381
name: Cosine Mrr@10
- type: cosine_map@100
value: 0.8726611807317667
name: Cosine Map@100
- task:
type: information-retrieval
name: Information Retrieval
dataset:
name: dim 128
type: dim_128
metrics:
- type: cosine_accuracy@1
value: 0.84
name: Cosine Accuracy@1
- type: cosine_accuracy@3
value: 0.87
name: Cosine Accuracy@3
- type: cosine_accuracy@5
value: 0.88
name: Cosine Accuracy@5
- type: cosine_accuracy@10
value: 0.93
name: Cosine Accuracy@10
- type: cosine_precision@1
value: 0.84
name: Cosine Precision@1
- type: cosine_precision@3
value: 0.8366666666666667
name: Cosine Precision@3
- type: cosine_precision@5
value: 0.8419999999999999
name: Cosine Precision@5
- type: cosine_precision@10
value: 0.455
name: Cosine Precision@10
- type: cosine_recall@1
value: 0.16799999999999998
name: Cosine Recall@1
- type: cosine_recall@3
value: 0.502
name: Cosine Recall@3
- type: cosine_recall@5
value: 0.8419999999999999
name: Cosine Recall@5
- type: cosine_recall@10
value: 0.91
name: Cosine Recall@10
- type: cosine_ndcg@10
value: 0.8777095006932575
name: Cosine Ndcg@10
- type: cosine_mrr@10
value: 0.8630000000000001
name: Cosine Mrr@10
- type: cosine_map@100
value: 0.8661619081282643
name: Cosine Map@100
- task:
type: information-retrieval
name: Information Retrieval
dataset:
name: dim 64
type: dim_64
metrics:
- type: cosine_accuracy@1
value: 0.78
name: Cosine Accuracy@1
- type: cosine_accuracy@3
value: 0.81
name: Cosine Accuracy@3
- type: cosine_accuracy@5
value: 0.81
name: Cosine Accuracy@5
- type: cosine_accuracy@10
value: 0.93
name: Cosine Accuracy@10
- type: cosine_precision@1
value: 0.78
name: Cosine Precision@1
- type: cosine_precision@3
value: 0.7866666666666667
name: Cosine Precision@3
- type: cosine_precision@5
value: 0.7859999999999999
name: Cosine Precision@5
- type: cosine_precision@10
value: 0.44799999999999995
name: Cosine Precision@10
- type: cosine_recall@1
value: 0.15599999999999997
name: Cosine Recall@1
- type: cosine_recall@3
value: 0.472
name: Cosine Recall@3
- type: cosine_recall@5
value: 0.7859999999999999
name: Cosine Recall@5
- type: cosine_recall@10
value: 0.8959999999999999
name: Cosine Recall@10
- type: cosine_ndcg@10
value: 0.8445404597381452
name: Cosine Ndcg@10
- type: cosine_mrr@10
value: 0.8120634920634922
name: Cosine Mrr@10
- type: cosine_map@100
value: 0.8308457034802883
name: Cosine Map@100
codeBert Base
This is a sentence-transformers model finetuned from microsoft/codebert-base. It maps sentences & paragraphs to a 768-dimensional dense vector space and can be used for semantic textual similarity, semantic search, paraphrase mining, text classification, clustering, and more.
Model Details
Model Description
- Model Type: Sentence Transformer
- Base model: microsoft/codebert-base
- Maximum Sequence Length: 512 tokens
- Output Dimensionality: 768 dimensions
- Similarity Function: Cosine Similarity
- Language: en
- License: apache-2.0
Model Sources
- Documentation: Sentence Transformers Documentation
- Repository: Sentence Transformers on GitHub
- Hugging Face: Sentence Transformers on Hugging Face
Full Model Architecture
SentenceTransformer(
(0): Transformer({'max_seq_length': 512, 'do_lower_case': False, 'architecture': 'RobertaModel'})
(1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
)
Usage
Direct Usage (Sentence Transformers)
First install the Sentence Transformers library:
pip install -U sentence-transformers
Then you can load this model and run inference.
from sentence_transformers import SentenceTransformer
# Download from the 🤗 Hub
model = SentenceTransformer("killdollar/codebert-embed-base-dense-retriever")
# Run inference
sentences = [
'How does __init__ work in Python?',
'def __init__(\n self,\n encoding_name: str = "gpt2",\n model_name: str | None = None,\n allowed_special: Literal["all"] | AbstractSet[str] = set(),\n disallowed_special: Literal["all"] | Collection[str] = "all",\n **kwargs: Any,\n ) -> None:\n """Create a new `TextSplitter`.\n\n Args:\n encoding_name: The name of the tiktoken encoding to use.\n model_name: The name of the model to use. If provided, this will\n override the `encoding_name`.\n allowed_special: Special tokens that are allowed during encoding.\n disallowed_special: Special tokens that are disallowed during encoding.\n\n Raises:\n ImportError: If the tiktoken package is not installed.\n """\n super().__init__(**kwargs)\n if not _HAS_TIKTOKEN:\n msg = (\n "Could not import tiktoken python package. "\n "This is needed in order to for TokenTextSplitter. "\n "Please install it with `pip install tiktoken`."\n )\n raise ImportError(msg)\n\n if model_name is not None:\n enc = tiktoken.encoding_for_model(model_name)\n else:\n enc = tiktoken.get_encoding(encoding_name)\n self._tokenizer = enc\n self._allowed_special = allowed_special\n self._disallowed_special = disallowed_special',
'def test_fixed_message_response_when_docs_found() -> None:\n fixed_resp = "I don\'t know"\n answer = "I know the answer!"\n llm = FakeListLLM(responses=[answer])\n retriever = SequentialRetriever(\n sequential_responses=[[Document(page_content=answer)]],\n )\n memory = ConversationBufferMemory(\n k=1,\n output_key="answer",\n memory_key="chat_history",\n return_messages=True,\n )\n qa_chain = ConversationalRetrievalChain.from_llm(\n llm=llm,\n memory=memory,\n retriever=retriever,\n return_source_documents=True,\n rephrase_question=False,\n response_if_no_docs_found=fixed_resp,\n verbose=True,\n )\n got = qa_chain("What is the answer?")\n assert got["chat_history"][1].content == answer\n assert got["answer"] == answer',
]
embeddings = model.encode(sentences)
print(embeddings.shape)
# [3, 768]
# Get the similarity scores for the embeddings
similarities = model.similarity(embeddings, embeddings)
print(similarities)
# tensor([[1.0000, 0.7336, 0.0979],
# [0.7336, 1.0000, 0.1742],
# [0.0979, 0.1742, 1.0000]])
Evaluation
Metrics
Information Retrieval
- Dataset: `dim_768`
- Evaluated with `InformationRetrievalEvaluator` with these parameters: `{ "truncate_dim": 768 }`
| Metric | Value |
|---|---|
| cosine_accuracy@1 | 0.83 |
| cosine_accuracy@3 | 0.85 |
| cosine_accuracy@5 | 0.86 |
| cosine_accuracy@10 | 0.94 |
| cosine_precision@1 | 0.83 |
| cosine_precision@3 | 0.83 |
| cosine_precision@5 | 0.83 |
| cosine_precision@10 | 0.453 |
| cosine_recall@1 | 0.166 |
| cosine_recall@3 | 0.498 |
| cosine_recall@5 | 0.83 |
| cosine_recall@10 | 0.906 |
| cosine_ndcg@10 | 0.8712 |
| cosine_mrr@10 | 0.8533 |
| cosine_map@100 | 0.8616 |
Information Retrieval
- Dataset: `dim_512`
- Evaluated with `InformationRetrievalEvaluator` with these parameters: `{ "truncate_dim": 512 }`
| Metric | Value |
|---|---|
| cosine_accuracy@1 | 0.85 |
| cosine_accuracy@3 | 0.86 |
| cosine_accuracy@5 | 0.87 |
| cosine_accuracy@10 | 0.95 |
| cosine_precision@1 | 0.85 |
| cosine_precision@3 | 0.84 |
| cosine_precision@5 | 0.842 |
| cosine_precision@10 | 0.453 |
| cosine_recall@1 | 0.17 |
| cosine_recall@3 | 0.504 |
| cosine_recall@5 | 0.842 |
| cosine_recall@10 | 0.906 |
| cosine_ndcg@10 | 0.8776 |
| cosine_mrr@10 | 0.8699 |
| cosine_map@100 | 0.8693 |
Information Retrieval
- Dataset: `dim_256`
- Evaluated with `InformationRetrievalEvaluator` with these parameters: `{ "truncate_dim": 256 }`
| Metric | Value |
|---|---|
| cosine_accuracy@1 | 0.86 |
| cosine_accuracy@3 | 0.89 |
| cosine_accuracy@5 | 0.9 |
| cosine_accuracy@10 | 0.93 |
| cosine_precision@1 | 0.86 |
| cosine_precision@3 | 0.85 |
| cosine_precision@5 | 0.85 |
| cosine_precision@10 | 0.45 |
| cosine_recall@1 | 0.172 |
| cosine_recall@3 | 0.51 |
| cosine_recall@5 | 0.85 |
| cosine_recall@10 | 0.9 |
| cosine_ndcg@10 | 0.879 |
| cosine_mrr@10 | 0.8806 |
| cosine_map@100 | 0.8727 |
Information Retrieval
- Dataset: `dim_128`
- Evaluated with `InformationRetrievalEvaluator` with these parameters: `{ "truncate_dim": 128 }`
| Metric | Value |
|---|---|
| cosine_accuracy@1 | 0.84 |
| cosine_accuracy@3 | 0.87 |
| cosine_accuracy@5 | 0.88 |
| cosine_accuracy@10 | 0.93 |
| cosine_precision@1 | 0.84 |
| cosine_precision@3 | 0.8367 |
| cosine_precision@5 | 0.842 |
| cosine_precision@10 | 0.455 |
| cosine_recall@1 | 0.168 |
| cosine_recall@3 | 0.502 |
| cosine_recall@5 | 0.842 |
| cosine_recall@10 | 0.91 |
| cosine_ndcg@10 | 0.8777 |
| cosine_mrr@10 | 0.863 |
| cosine_map@100 | 0.8662 |
Information Retrieval
- Dataset: `dim_64`
- Evaluated with `InformationRetrievalEvaluator` with these parameters: `{ "truncate_dim": 64 }`
| Metric | Value |
|---|---|
| cosine_accuracy@1 | 0.78 |
| cosine_accuracy@3 | 0.81 |
| cosine_accuracy@5 | 0.81 |
| cosine_accuracy@10 | 0.93 |
| cosine_precision@1 | 0.78 |
| cosine_precision@3 | 0.7867 |
| cosine_precision@5 | 0.786 |
| cosine_precision@10 | 0.448 |
| cosine_recall@1 | 0.156 |
| cosine_recall@3 | 0.472 |
| cosine_recall@5 | 0.786 |
| cosine_recall@10 | 0.896 |
| cosine_ndcg@10 | 0.8445 |
| cosine_mrr@10 | 0.8121 |
| cosine_map@100 | 0.8308 |
Training Details
Training Dataset
Unnamed Dataset
- Size: 900 training samples
- Columns: `anchor` and `positive`
- Approximate statistics based on the first 900 samples:

  |         | anchor | positive |
  |:--------|:-------|:---------|
  | type    | string | string   |
  | details | min: 6 tokens<br>mean: 13.15 tokens<br>max: 42 tokens | min: 25 tokens<br>mean: 239.87 tokens<br>max: 512 tokens |
- Samples (each sample is an `anchor` query paired with its `positive` code snippet):

  1. **anchor:** `Explain the test_qdrant_similarity_search_with_relevance_scores logic`

     **positive:**
     ```python
     def test_qdrant_similarity_search_with_relevance_scores(
         batch_size: int,
         content_payload_key: str,
         metadata_payload_key: str,
         vector_name: str | None,
     ) -> None:
         """Test end to end construction and search."""
         texts = ["foo", "bar", "baz"]
         docsearch = Qdrant.from_texts(
             texts,
             ConsistentFakeEmbeddings(),
             location=":memory:",
             content_payload_key=content_payload_key,
             metadata_payload_key=metadata_payload_key,
             batch_size=batch_size,
             vector_name=vector_name,
         )
         output = docsearch.similarity_search_with_relevance_scores("foo", k=3)
         assert all(
             (score <= 1 or np.isclose(score, 1)) and score >= 0 for _, score in output
         )
     ```

  2. **anchor:** `How to implement LangChainPendingDeprecationWarning?`

     **positive:**
     ```python
     class LangChainPendingDeprecationWarning(PendingDeprecationWarning):
         """A class for issuing deprecation warnings for LangChain users."""
     ```

  3. **anchor:** `Example usage of random_name`

     **positive:**
     ```python
     def random_name() -> str:
         """Generate a random name."""
         adjective = random.choice(adjectives)  # noqa: S311
         noun = random.choice(nouns)  # noqa: S311
         number = random.randint(1, 100)  # noqa: S311
         return f"{adjective}-{noun}-{number}"
     ```

- Loss: `MatryoshkaLoss` with these parameters:
  ```json
  {
      "loss": "MultipleNegativesRankingLoss",
      "matryoshka_dims": [768, 512, 256, 128, 64],
      "matryoshka_weights": [1, 1, 1, 1, 1],
      "n_dims_per_step": -1
  }
  ```
Training Hyperparameters
Non-Default Hyperparameters
- `eval_strategy`: epoch
- `per_device_train_batch_size`: 4
- `per_device_eval_batch_size`: 4
- `gradient_accumulation_steps`: 16
- `learning_rate`: 2e-05
- `num_train_epochs`: 4
- `lr_scheduler_type`: cosine
- `warmup_ratio`: 0.1
- `fp16`: True
- `load_best_model_at_end`: True
- `optim`: adamw_torch
- `batch_sampler`: no_duplicates
All Hyperparameters
Click to expand
- `overwrite_output_dir`: False
- `do_predict`: False
- `eval_strategy`: epoch
- `prediction_loss_only`: True
- `per_device_train_batch_size`: 4
- `per_device_eval_batch_size`: 4
- `per_gpu_train_batch_size`: None
- `per_gpu_eval_batch_size`: None
- `gradient_accumulation_steps`: 16
- `eval_accumulation_steps`: None
- `torch_empty_cache_steps`: None
- `learning_rate`: 2e-05
- `weight_decay`: 0.0
- `adam_beta1`: 0.9
- `adam_beta2`: 0.999
- `adam_epsilon`: 1e-08
- `max_grad_norm`: 1.0
- `num_train_epochs`: 4
- `max_steps`: -1
- `lr_scheduler_type`: cosine
- `lr_scheduler_kwargs`: {}
- `warmup_ratio`: 0.1
- `warmup_steps`: 0
- `log_level`: passive
- `log_level_replica`: warning
- `log_on_each_node`: True
- `logging_nan_inf_filter`: True
- `save_safetensors`: True
- `save_on_each_node`: False
- `save_only_model`: False
- `restore_callback_states_from_checkpoint`: False
- `no_cuda`: False
- `use_cpu`: False
- `use_mps_device`: False
- `seed`: 42
- `data_seed`: None
- `jit_mode_eval`: False
- `bf16`: False
- `fp16`: True
- `fp16_opt_level`: O1
- `half_precision_backend`: auto
- `bf16_full_eval`: False
- `fp16_full_eval`: False
- `tf32`: None
- `local_rank`: 0
- `ddp_backend`: None
- `tpu_num_cores`: None
- `tpu_metrics_debug`: False
- `debug`: []
- `dataloader_drop_last`: False
- `dataloader_num_workers`: 0
- `dataloader_prefetch_factor`: None
- `past_index`: -1
- `disable_tqdm`: False
- `remove_unused_columns`: True
- `label_names`: None
- `load_best_model_at_end`: True
- `ignore_data_skip`: False
- `fsdp`: []
- `fsdp_min_num_params`: 0
- `fsdp_config`: {'min_num_params': 0, 'xla': False, 'xla_fsdp_v2': False, 'xla_fsdp_grad_ckpt': False}
- `fsdp_transformer_layer_cls_to_wrap`: None
- `accelerator_config`: {'split_batches': False, 'dispatch_batches': None, 'even_batches': True, 'use_seedable_sampler': True, 'non_blocking': False, 'gradient_accumulation_kwargs': None}
- `parallelism_config`: None
- `deepspeed`: None
- `label_smoothing_factor`: 0.0
- `optim`: adamw_torch
- `optim_args`: None
- `adafactor`: False
- `group_by_length`: False
- `length_column_name`: length
- `project`: huggingface
- `trackio_space_id`: trackio
- `ddp_find_unused_parameters`: None
- `ddp_bucket_cap_mb`: None
- `ddp_broadcast_buffers`: False
- `dataloader_pin_memory`: True
- `dataloader_persistent_workers`: False
- `skip_memory_metrics`: True
- `use_legacy_prediction_loop`: False
- `push_to_hub`: False
- `resume_from_checkpoint`: None
- `hub_model_id`: None
- `hub_strategy`: every_save
- `hub_private_repo`: None
- `hub_always_push`: False
- `hub_revision`: None
- `gradient_checkpointing`: False
- `gradient_checkpointing_kwargs`: None
- `include_inputs_for_metrics`: False
- `include_for_metrics`: []
- `eval_do_concat_batches`: True
- `fp16_backend`: auto
- `push_to_hub_model_id`: None
- `push_to_hub_organization`: None
- `mp_parameters`:
- `auto_find_batch_size`: False
- `full_determinism`: False
- `torchdynamo`: None
- `ray_scope`: last
- `ddp_timeout`: 1800
- `torch_compile`: False
- `torch_compile_backend`: None
- `torch_compile_mode`: None
- `include_tokens_per_second`: False
- `include_num_input_tokens_seen`: no
- `neftune_noise_alpha`: None
- `optim_target_modules`: None
- `batch_eval_metrics`: False
- `eval_on_start`: False
- `use_liger_kernel`: False
- `liger_kernel_config`: None
- `eval_use_gather_object`: False
- `average_tokens_across_devices`: True
- `prompts`: None
- `batch_sampler`: no_duplicates
- `multi_dataset_batch_sampler`: proportional
- `router_mapping`: {}
- `learning_rate_mapping`: {}
Training Logs
| Epoch | Step | Training Loss | dim_768_cosine_ndcg@10 | dim_512_cosine_ndcg@10 | dim_256_cosine_ndcg@10 | dim_128_cosine_ndcg@10 | dim_64_cosine_ndcg@10 |
|---|---|---|---|---|---|---|---|
| 0.7111 | 10 | 6.8447 | - | - | - | - | - |
| 1.0 | 15 | - | 0.1025 | 0.0367 | 0.0548 | 0.0502 | 0.1185 |
| 0.7111 | 10 | 4.8545 | - | - | - | - | - |
| 1.0 | 15 | - | 0.2250 | 0.3047 | 0.2895 | 0.2892 | 0.3178 |
| 0.7111 | 10 | 1.9011 | - | - | - | - | - |
| 1.0 | 15 | - | 0.6530 | 0.6393 | 0.6269 | 0.6631 | 0.6658 |
| 1.3556 | 20 | 0.6349 | - | - | - | - | - |
| 2.0 | 30 | 0.1887 | 0.8480 | 0.8643 | 0.8641 | 0.8532 | 0.7974 |
| 2.7111 | 40 | 0.0959 | - | - | - | - | - |
| 3.0 | 45 | - | 0.8688 | 0.8774 | 0.8754 | 0.8725 | 0.8457 |
| 3.3556 | 50 | 0.0359 | - | - | - | - | - |
| **4.0** | **60** | **0.0515** | **0.8712** | **0.8776** | **0.879** | **0.8777** | **0.8445** |
- The bold row denotes the saved checkpoint.
Framework Versions
- Python: 3.12.12
- Sentence Transformers: 5.2.0
- Transformers: 4.57.3
- PyTorch: 2.9.0+cu126
- Accelerate: 1.12.0
- Datasets: 4.0.0
- Tokenizers: 0.22.2
Citation
BibTeX
Sentence Transformers
@inproceedings{reimers-2019-sentence-bert,
title = "Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks",
author = "Reimers, Nils and Gurevych, Iryna",
booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing",
month = "11",
year = "2019",
publisher = "Association for Computational Linguistics",
url = "https://arxiv.org/abs/1908.10084",
}
MatryoshkaLoss
@misc{kusupati2024matryoshka,
title={Matryoshka Representation Learning},
author={Aditya Kusupati and Gantavya Bhatt and Aniket Rege and Matthew Wallingford and Aditya Sinha and Vivek Ramanujan and William Howard-Snyder and Kaifeng Chen and Sham Kakade and Prateek Jain and Ali Farhadi},
year={2024},
eprint={2205.13147},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
MultipleNegativesRankingLoss
@misc{henderson2017efficient,
title={Efficient Natural Language Response Suggestion for Smart Reply},
author={Matthew Henderson and Rami Al-Rfou and Brian Strope and Yun-hsuan Sung and Laszlo Lukacs and Ruiqi Guo and Sanjiv Kumar and Balint Miklos and Ray Kurzweil},
year={2017},
eprint={1705.00652},
archivePrefix={arXiv},
primaryClass={cs.CL}
}