metadata
language:
- en
license: apache-2.0
tags:
- sentence-transformers
- sentence-similarity
- feature-extraction
- dense
- generated_from_trainer
- dataset_size:180
- loss:MatryoshkaLoss
- loss:MultipleNegativesRankingLoss
base_model: shubharuidas/codebert-embed-base-dense-retriever
widget:
- source_sentence: Explain the __init__ logic
sentences:
- |-
async def test_handler_with_async_execution() -> None:
"""Test handler works correctly with async tool execution."""
@tool
def async_add(a: int, b: int) -> int:
"""Async add two numbers."""
return a + b
def modifying_handler(
request: ToolCallRequest,
execute: Callable[[ToolCallRequest], ToolMessage | Command],
) -> ToolMessage | Command:
"""Handler that modifies arguments."""
# Add 10 to both arguments using override method
modified_call = {
**request.tool_call,
"args": {
**request.tool_call["args"],
"a": request.tool_call["args"]["a"] + 10,
"b": request.tool_call["args"]["b"] + 10,
},
}
modified_request = request.override(tool_call=modified_call)
return execute(modified_request)
tool_node = ToolNode([async_add], wrap_tool_call=modifying_handler)
result = await tool_node.ainvoke(
{
"messages": [
AIMessage(
"adding",
tool_calls=[
{
"name": "async_add",
"args": {"a": 1, "b": 2},
"id": "call_13",
}
],
)
]
},
config=_create_config_with_runtime(),
)
tool_message = result["messages"][-1]
assert isinstance(tool_message, ToolMessage)
# Original: 1 + 2 = 3, with modifications: 11 + 12 = 23
assert tool_message.content == "23"
- |-
def __init__(self) -> None:
self.loads: set[str] = set()
self.stores: set[str] = set()
- |-
class InternalServerError(APIStatusError):
pass
- source_sentence: Explain the async _load_checkpoint_tuple logic
sentences:
- >-
def task(__func_or_none__: Callable[P, Awaitable[T]]) ->
_TaskFunction[P, T]: ...
- |-
class State(BaseModel):
query: str
inner: InnerObject
answer: str | None = None
docs: Annotated[list[str], sorted_add]
- >-
async def _load_checkpoint_tuple(self, value: DictRow) ->
CheckpointTuple:
"""
Convert a database row into a CheckpointTuple object.
Args:
value: A row from the database containing checkpoint data.
Returns:
CheckpointTuple: A structured representation of the checkpoint,
including its configuration, metadata, parent checkpoint (if any),
and pending writes.
"""
return CheckpointTuple(
{
"configurable": {
"thread_id": value["thread_id"],
"checkpoint_ns": value["checkpoint_ns"],
"checkpoint_id": value["checkpoint_id"],
}
},
{
**value["checkpoint"],
"channel_values": {
**(value["checkpoint"].get("channel_values") or {}),
**self._load_blobs(value["channel_values"]),
},
},
value["metadata"],
(
{
"configurable": {
"thread_id": value["thread_id"],
"checkpoint_ns": value["checkpoint_ns"],
"checkpoint_id": value["parent_checkpoint_id"],
}
}
if value["parent_checkpoint_id"]
else None
),
await asyncio.to_thread(self._load_writes, value["pending_writes"]),
)
- source_sentence: Explain the flattened_runs logic
sentences:
- |-
class ChannelWrite(RunnableCallable):
"""Implements the logic for sending writes to CONFIG_KEY_SEND.
Can be used as a runnable or as a static method to call imperatively."""
writes: list[ChannelWriteEntry | ChannelWriteTupleEntry | Send]
"""Sequence of write entries or Send objects to write."""
def __init__(
self,
writes: Sequence[ChannelWriteEntry | ChannelWriteTupleEntry | Send],
*,
tags: Sequence[str] | None = None,
):
super().__init__(
func=self._write,
afunc=self._awrite,
name=None,
tags=tags,
trace=False,
)
self.writes = cast(
list[ChannelWriteEntry | ChannelWriteTupleEntry | Send], writes
)
def get_name(self, suffix: str | None = None, *, name: str | None = None) -> str:
if not name:
name = f"ChannelWrite<{','.join(w.channel if isinstance(w, ChannelWriteEntry) else '...' if isinstance(w, ChannelWriteTupleEntry) else w.node for w in self.writes)}>"
return super().get_name(suffix, name=name)
def _write(self, input: Any, config: RunnableConfig) -> None:
writes = [
ChannelWriteEntry(write.channel, input, write.skip_none, write.mapper)
if isinstance(write, ChannelWriteEntry) and write.value is PASSTHROUGH
else ChannelWriteTupleEntry(write.mapper, input)
if isinstance(write, ChannelWriteTupleEntry) and write.value is PASSTHROUGH
else write
for write in self.writes
]
self.do_write(
config,
writes,
)
return input
async def _awrite(self, input: Any, config: RunnableConfig) -> None:
writes = [
ChannelWriteEntry(write.channel, input, write.skip_none, write.mapper)
if isinstance(write, ChannelWriteEntry) and write.value is PASSTHROUGH
else ChannelWriteTupleEntry(write.mapper, input)
if isinstance(write, ChannelWriteTupleEntry) and write.value is PASSTHROUGH
else write
for write in self.writes
]
self.do_write(
config,
writes,
)
return input
@staticmethod
def do_write(
config: RunnableConfig,
writes: Sequence[ChannelWriteEntry | ChannelWriteTupleEntry | Send],
allow_passthrough: bool = True,
) -> None:
# validate
for w in writes:
if isinstance(w, ChannelWriteEntry):
if w.channel == TASKS:
raise InvalidUpdateError(
"Cannot write to the reserved channel TASKS"
)
if w.value is PASSTHROUGH and not allow_passthrough:
raise InvalidUpdateError("PASSTHROUGH value must be replaced")
if isinstance(w, ChannelWriteTupleEntry):
if w.value is PASSTHROUGH and not allow_passthrough:
raise InvalidUpdateError("PASSTHROUGH value must be replaced")
# if we want to persist writes found before hitting a ParentCommand
# can move this to a finally block
write: TYPE_SEND = config[CONF][CONFIG_KEY_SEND]
write(_assemble_writes(writes))
@staticmethod
def is_writer(runnable: Runnable) -> bool:
"""Used by PregelNode to distinguish between writers and other runnables."""
return (
isinstance(runnable, ChannelWrite)
or getattr(runnable, "_is_channel_writer", MISSING) is not MISSING
)
@staticmethod
def get_static_writes(
runnable: Runnable,
) -> Sequence[tuple[str, Any, str | None]] | None:
"""Used to get conditional writes a writer declares for static analysis."""
if isinstance(runnable, ChannelWrite):
return [
w
for entry in runnable.writes
if isinstance(entry, ChannelWriteTupleEntry) and entry.static
for w in entry.static
] or None
elif writes := getattr(runnable, "_is_channel_writer", MISSING):
if writes is not MISSING:
writes = cast(
Sequence[tuple[ChannelWriteEntry | Send, str | None]],
writes,
)
entries = [e for e, _ in writes]
labels = [la for _, la in writes]
return [(*t, la) for t, la in zip(_assemble_writes(entries), labels)]
@staticmethod
def register_writer(
runnable: R,
static: Sequence[tuple[ChannelWriteEntry | Send, str | None]] | None = None,
) -> R:
"""Used to mark a runnable as a writer, so that it can be detected by is_writer.
Instances of ChannelWrite are automatically marked as writers.
Optionally, a list of declared writes can be passed for static analysis."""
# using object.__setattr__ to work around objects that override __setattr__
# eg. pydantic models and dataclasses
object.__setattr__(runnable, "_is_channel_writer", static)
return runnable
- >-
def test_double_interrupt_subgraph(sync_checkpointer:
BaseCheckpointSaver) -> None:
class AgentState(TypedDict):
input: str
def node_1(state: AgentState):
result = interrupt("interrupt node 1")
return {"input": result}
def node_2(state: AgentState):
result = interrupt("interrupt node 2")
return {"input": result}
subgraph_builder = (
StateGraph(AgentState)
.add_node("node_1", node_1)
.add_node("node_2", node_2)
.add_edge(START, "node_1")
.add_edge("node_1", "node_2")
.add_edge("node_2", END)
)
# invoke the sub graph
subgraph = subgraph_builder.compile(checkpointer=sync_checkpointer)
thread = {"configurable": {"thread_id": str(uuid.uuid4())}}
assert [c for c in subgraph.stream({"input": "test"}, thread)] == [
{
"__interrupt__": (
Interrupt(
value="interrupt node 1",
id=AnyStr(),
),
)
},
]
# resume from the first interrupt
assert [c for c in subgraph.stream(Command(resume="123"), thread)] == [
{
"node_1": {"input": "123"},
},
{
"__interrupt__": (
Interrupt(
value="interrupt node 2",
id=AnyStr(),
),
)
},
]
# resume from the second interrupt
assert [c for c in subgraph.stream(Command(resume="123"), thread)] == [
{
"node_2": {"input": "123"},
},
]
subgraph = subgraph_builder.compile()
def invoke_sub_agent(state: AgentState):
return subgraph.invoke(state)
thread = {"configurable": {"thread_id": str(uuid.uuid4())}}
parent_agent = (
StateGraph(AgentState)
.add_node("invoke_sub_agent", invoke_sub_agent)
.add_edge(START, "invoke_sub_agent")
.add_edge("invoke_sub_agent", END)
.compile(checkpointer=sync_checkpointer)
)
assert [c for c in parent_agent.stream({"input": "test"}, thread)] == [
{
"__interrupt__": (
Interrupt(
value="interrupt node 1",
id=AnyStr(),
),
)
},
]
# resume from the first interrupt
assert [c for c in parent_agent.stream(Command(resume=True), thread)] == [
{
"__interrupt__": (
Interrupt(
value="interrupt node 2",
id=AnyStr(),
),
)
}
]
# resume from 2nd interrupt
assert [c for c in parent_agent.stream(Command(resume=True), thread)] == [
{
"invoke_sub_agent": {"input": True},
},
]
- |-
def flattened_runs(self) -> list[Run]:
q = [] + self.runs
result = []
while q:
parent = q.pop()
result.append(parent)
if parent.child_runs:
q.extend(parent.child_runs)
return result
- source_sentence: Explain the SubGraphState logic
sentences:
- |-
class Cron(TypedDict):
"""Represents a scheduled task."""
cron_id: str
"""The ID of the cron."""
assistant_id: str
"""The ID of the assistant."""
thread_id: str | None
"""The ID of the thread."""
on_run_completed: OnCompletionBehavior | None
"""What to do with the thread after the run completes. Only applicable for stateless crons."""
end_time: datetime | None
"""The end date to stop running the cron."""
schedule: str
"""The schedule to run, cron format."""
created_at: datetime
"""The time the cron was created."""
updated_at: datetime
"""The last time the cron was updated."""
payload: dict
"""The run payload to use for creating new run."""
user_id: str | None
"""The user ID of the cron."""
next_run_date: datetime | None
"""The next run date of the cron."""
metadata: dict
"""The metadata of the cron."""
- |-
class SubGraphState(MessagesState):
city: str
- |-
def task_path_str(tup: str | int | tuple) -> str:
"""Generate a string representation of the task path."""
return (
f"~{', '.join(task_path_str(x) for x in tup)}"
if isinstance(tup, (tuple, list))
else f"{tup:010d}"
if isinstance(tup, int)
else str(tup)
)
- source_sentence: Best practices for test_list_namespaces_operations
sentences:
- |-
def test_doubly_nested_graph_state(
sync_checkpointer: BaseCheckpointSaver,
) -> None:
class State(TypedDict):
my_key: str
class ChildState(TypedDict):
my_key: str
class GrandChildState(TypedDict):
my_key: str
def grandchild_1(state: ChildState):
return {"my_key": state["my_key"] + " here"}
def grandchild_2(state: ChildState):
return {
"my_key": state["my_key"] + " and there",
}
grandchild = StateGraph(GrandChildState)
grandchild.add_node("grandchild_1", grandchild_1)
grandchild.add_node("grandchild_2", grandchild_2)
grandchild.add_edge("grandchild_1", "grandchild_2")
grandchild.set_entry_point("grandchild_1")
grandchild.set_finish_point("grandchild_2")
child = StateGraph(ChildState)
child.add_node(
"child_1",
grandchild.compile(interrupt_before=["grandchild_2"]),
)
child.set_entry_point("child_1")
child.set_finish_point("child_1")
def parent_1(state: State):
return {"my_key": "hi " + state["my_key"]}
def parent_2(state: State):
return {"my_key": state["my_key"] + " and back again"}
graph = StateGraph(State)
graph.add_node("parent_1", parent_1)
graph.add_node("child", child.compile())
graph.add_node("parent_2", parent_2)
graph.set_entry_point("parent_1")
graph.add_edge("parent_1", "child")
graph.add_edge("child", "parent_2")
graph.set_finish_point("parent_2")
app = graph.compile(checkpointer=sync_checkpointer)
# test invoke w/ nested interrupt
config = {"configurable": {"thread_id": "1"}}
assert [
c
for c in app.stream(
{"my_key": "my value"}, config, subgraphs=True, durability="exit"
)
] == [
((), {"parent_1": {"my_key": "hi my value"}}),
(
(AnyStr("child:"), AnyStr("child_1:")),
{"grandchild_1": {"my_key": "hi my value here"}},
),
((), {"__interrupt__": ()}),
]
# get state without subgraphs
outer_state = app.get_state(config)
assert outer_state == StateSnapshot(
values={"my_key": "hi my value"},
tasks=(
PregelTask(
AnyStr(),
"child",
(PULL, "child"),
state={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr("child"),
}
},
),
),
next=("child",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "loop",
"step": 1,
},
created_at=AnyStr(),
parent_config=None,
interrupts=(),
)
child_state = app.get_state(outer_state.tasks[0].state)
assert child_state == StateSnapshot(
values={"my_key": "hi my value"},
tasks=(
PregelTask(
AnyStr(),
"child_1",
(PULL, "child_1"),
state={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr(),
}
},
),
),
next=("child_1",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr("child:"),
"checkpoint_id": AnyStr(),
"checkpoint_map": AnyDict(
{
"": AnyStr(),
AnyStr("child:"): AnyStr(),
}
),
}
},
metadata={
"parents": {"": AnyStr()},
"source": "loop",
"step": 0,
},
created_at=AnyStr(),
parent_config=None,
interrupts=(),
)
grandchild_state = app.get_state(child_state.tasks[0].state)
assert grandchild_state == StateSnapshot(
values={"my_key": "hi my value here"},
tasks=(
PregelTask(
AnyStr(),
"grandchild_2",
(PULL, "grandchild_2"),
),
),
next=("grandchild_2",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr(),
"checkpoint_id": AnyStr(),
"checkpoint_map": AnyDict(
{
"": AnyStr(),
AnyStr("child:"): AnyStr(),
AnyStr(re.compile(r"child:.+|child1:")): AnyStr(),
}
),
}
},
metadata={
"parents": AnyDict(
{
"": AnyStr(),
AnyStr("child:"): AnyStr(),
}
),
"source": "loop",
"step": 1,
},
created_at=AnyStr(),
parent_config=None,
interrupts=(),
)
# get state with subgraphs
assert app.get_state(config, subgraphs=True) == StateSnapshot(
values={"my_key": "hi my value"},
tasks=(
PregelTask(
AnyStr(),
"child",
(PULL, "child"),
state=StateSnapshot(
values={"my_key": "hi my value"},
tasks=(
PregelTask(
AnyStr(),
"child_1",
(PULL, "child_1"),
state=StateSnapshot(
values={"my_key": "hi my value here"},
tasks=(
PregelTask(
AnyStr(),
"grandchild_2",
(PULL, "grandchild_2"),
),
),
next=("grandchild_2",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr(),
"checkpoint_id": AnyStr(),
"checkpoint_map": AnyDict(
{
"": AnyStr(),
AnyStr("child:"): AnyStr(),
AnyStr(
re.compile(r"child:.+|child1:")
): AnyStr(),
}
),
}
},
metadata={
"parents": AnyDict(
{
"": AnyStr(),
AnyStr("child:"): AnyStr(),
}
),
"source": "loop",
"step": 1,
},
created_at=AnyStr(),
parent_config=None,
interrupts=(),
),
),
),
next=("child_1",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr("child:"),
"checkpoint_id": AnyStr(),
"checkpoint_map": AnyDict(
{"": AnyStr(), AnyStr("child:"): AnyStr()}
),
}
},
metadata={
"parents": {"": AnyStr()},
"source": "loop",
"step": 0,
},
created_at=AnyStr(),
parent_config=None,
interrupts=(),
),
),
),
next=("child",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "loop",
"step": 1,
},
created_at=AnyStr(),
parent_config=None,
interrupts=(),
)
# # resume
assert [c for c in app.stream(None, config, subgraphs=True, durability="exit")] == [
(
(AnyStr("child:"), AnyStr("child_1:")),
{"grandchild_2": {"my_key": "hi my value here and there"}},
),
((AnyStr("child:"),), {"child_1": {"my_key": "hi my value here and there"}}),
((), {"child": {"my_key": "hi my value here and there"}}),
((), {"parent_2": {"my_key": "hi my value here and there and back again"}}),
]
# get state with and without subgraphs
assert (
app.get_state(config)
== app.get_state(config, subgraphs=True)
== StateSnapshot(
values={"my_key": "hi my value here and there and back again"},
tasks=(),
next=(),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "loop",
"step": 3,
},
created_at=AnyStr(),
parent_config=(
{
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
}
),
interrupts=(),
)
)
# get outer graph history
outer_history = list(app.get_state_history(config))
assert outer_history == [
StateSnapshot(
values={"my_key": "hi my value here and there and back again"},
tasks=(),
next=(),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "loop",
"step": 3,
},
created_at=AnyStr(),
parent_config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
interrupts=(),
),
StateSnapshot(
values={"my_key": "hi my value"},
tasks=(
PregelTask(
AnyStr(),
"child",
(PULL, "child"),
state={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr("child"),
}
},
result=None,
),
),
next=("child",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": "",
"checkpoint_id": AnyStr(),
}
},
metadata={
"parents": {},
"source": "loop",
"step": 1,
},
created_at=AnyStr(),
parent_config=None,
interrupts=(),
),
]
# get child graph history
child_history = list(app.get_state_history(outer_history[1].tasks[0].state))
assert child_history == [
StateSnapshot(
values={"my_key": "hi my value"},
next=("child_1",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr("child:"),
"checkpoint_id": AnyStr(),
"checkpoint_map": AnyDict(
{"": AnyStr(), AnyStr("child:"): AnyStr()}
),
}
},
metadata={
"source": "loop",
"step": 0,
"parents": {"": AnyStr()},
},
created_at=AnyStr(),
parent_config=None,
tasks=(
PregelTask(
id=AnyStr(),
name="child_1",
path=(PULL, "child_1"),
state={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr("child:"),
}
},
result=None,
),
),
interrupts=(),
),
]
# get grandchild graph history
grandchild_history = list(app.get_state_history(child_history[0].tasks[0].state))
assert grandchild_history == [
StateSnapshot(
values={"my_key": "hi my value here"},
next=("grandchild_2",),
config={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr(),
"checkpoint_id": AnyStr(),
"checkpoint_map": AnyDict(
{
"": AnyStr(),
AnyStr("child:"): AnyStr(),
AnyStr(re.compile(r"child:.+|child1:")): AnyStr(),
}
),
}
},
metadata={
"source": "loop",
"step": 1,
"parents": AnyDict(
{
"": AnyStr(),
AnyStr("child:"): AnyStr(),
}
),
},
created_at=AnyStr(),
parent_config=None,
tasks=(
PregelTask(
id=AnyStr(),
name="grandchild_2",
path=(PULL, "grandchild_2"),
result=None,
),
),
interrupts=(),
),
]
- |-
def _msgpack_enc(data: Any) -> bytes:
return ormsgpack.packb(data, default=_msgpack_default, option=_option)
- |-
def test_list_namespaces_operations(
fake_embeddings: CharacterEmbeddings,
) -> None:
"""Test list namespaces functionality with various filters."""
with create_vector_store(
fake_embeddings, text_fields=["key0", "key1", "key3"]
) as store:
test_pref = str(uuid.uuid4())
test_namespaces = [
(test_pref, "test", "documents", "public", test_pref),
(test_pref, "test", "documents", "private", test_pref),
(test_pref, "test", "images", "public", test_pref),
(test_pref, "test", "images", "private", test_pref),
(test_pref, "prod", "documents", "public", test_pref),
(test_pref, "prod", "documents", "some", "nesting", "public", test_pref),
(test_pref, "prod", "documents", "private", test_pref),
]
# Add test data
for namespace in test_namespaces:
store.put(namespace, "dummy", {"content": "dummy"})
# Test prefix filtering
prefix_result = store.list_namespaces(prefix=(test_pref, "test"))
assert len(prefix_result) == 4
assert all(ns[1] == "test" for ns in prefix_result)
# Test specific prefix
specific_prefix_result = store.list_namespaces(
prefix=(test_pref, "test", "documents")
)
assert len(specific_prefix_result) == 2
assert all(ns[1:3] == ("test", "documents") for ns in specific_prefix_result)
# Test suffix filtering
suffix_result = store.list_namespaces(suffix=("public", test_pref))
assert len(suffix_result) == 4
assert all(ns[-2] == "public" for ns in suffix_result)
# Test combined prefix and suffix
prefix_suffix_result = store.list_namespaces(
prefix=(test_pref, "test"), suffix=("public", test_pref)
)
assert len(prefix_suffix_result) == 2
assert all(
ns[1] == "test" and ns[-2] == "public" for ns in prefix_suffix_result
)
# Test wildcard in prefix
wildcard_prefix_result = store.list_namespaces(
prefix=(test_pref, "*", "documents")
)
assert len(wildcard_prefix_result) == 5
assert all(ns[2] == "documents" for ns in wildcard_prefix_result)
# Test wildcard in suffix
wildcard_suffix_result = store.list_namespaces(
suffix=("*", "public", test_pref)
)
assert len(wildcard_suffix_result) == 4
assert all(ns[-2] == "public" for ns in wildcard_suffix_result)
wildcard_single = store.list_namespaces(
suffix=("some", "*", "public", test_pref)
)
assert len(wildcard_single) == 1
assert wildcard_single[0] == (
test_pref,
"prod",
"documents",
"some",
"nesting",
"public",
test_pref,
)
# Test max depth
max_depth_result = store.list_namespaces(max_depth=3)
assert all(len(ns) <= 3 for ns in max_depth_result)
max_depth_result = store.list_namespaces(
max_depth=4, prefix=(test_pref, "*", "documents")
)
assert len(set(res for res in max_depth_result)) == len(max_depth_result) == 5
# Test pagination
limit_result = store.list_namespaces(prefix=(test_pref,), limit=3)
assert len(limit_result) == 3
offset_result = store.list_namespaces(prefix=(test_pref,), offset=3)
assert len(offset_result) == len(test_namespaces) - 3
empty_prefix_result = store.list_namespaces(prefix=(test_pref,))
assert len(empty_prefix_result) == len(test_namespaces)
assert set(empty_prefix_result) == set(test_namespaces)
# Clean up
for namespace in test_namespaces:
store.delete(namespace, "dummy")
pipeline_tag: sentence-similarity
library_name: sentence-transformers
metrics:
- cosine_accuracy@1
- cosine_accuracy@3
- cosine_accuracy@5
- cosine_accuracy@10
- cosine_precision@1
- cosine_precision@3
- cosine_precision@5
- cosine_precision@10
- cosine_recall@1
- cosine_recall@3
- cosine_recall@5
- cosine_recall@10
- cosine_ndcg@10
- cosine_mrr@10
- cosine_map@100
model-index:
- name: codeBert dense retriever
results:
- task:
type: information-retrieval
name: Information Retrieval
dataset:
name: dim 768
type: dim_768
metrics:
- type: cosine_accuracy@1
value: 0.9
name: Cosine Accuracy@1
- type: cosine_accuracy@3
value: 0.9
name: Cosine Accuracy@3
- type: cosine_accuracy@5
value: 1
name: Cosine Accuracy@5
- type: cosine_accuracy@10
value: 1
name: Cosine Accuracy@10
- type: cosine_precision@1
value: 0.9
name: Cosine Precision@1
- type: cosine_precision@3
value: 0.29999999999999993
name: Cosine Precision@3
- type: cosine_precision@5
value: 0.20000000000000004
name: Cosine Precision@5
- type: cosine_precision@10
value: 0.10000000000000002
name: Cosine Precision@10
- type: cosine_recall@1
value: 0.9
name: Cosine Recall@1
- type: cosine_recall@3
value: 0.9
name: Cosine Recall@3
- type: cosine_recall@5
value: 1
name: Cosine Recall@5
- type: cosine_recall@10
value: 1
name: Cosine Recall@10
- type: cosine_ndcg@10
value: 0.9408764682653967
name: Cosine Ndcg@10
- type: cosine_mrr@10
value: 0.9225
name: Cosine Mrr@10
- type: cosine_map@100
value: 0.9225
name: Cosine Map@100
- task:
type: information-retrieval
name: Information Retrieval
dataset:
name: dim 512
type: dim_512
metrics:
- type: cosine_accuracy@1
value: 0.9
name: Cosine Accuracy@1
- type: cosine_accuracy@3
value: 0.9
name: Cosine Accuracy@3
- type: cosine_accuracy@5
value: 1
name: Cosine Accuracy@5
- type: cosine_accuracy@10
value: 1
name: Cosine Accuracy@10
- type: cosine_precision@1
value: 0.9
name: Cosine Precision@1
- type: cosine_precision@3
value: 0.29999999999999993
name: Cosine Precision@3
- type: cosine_precision@5
value: 0.20000000000000004
name: Cosine Precision@5
- type: cosine_precision@10
value: 0.10000000000000002
name: Cosine Precision@10
- type: cosine_recall@1
value: 0.9
name: Cosine Recall@1
- type: cosine_recall@3
value: 0.9
name: Cosine Recall@3
- type: cosine_recall@5
value: 1
name: Cosine Recall@5
- type: cosine_recall@10
value: 1
name: Cosine Recall@10
- type: cosine_ndcg@10
value: 0.9408764682653967
name: Cosine Ndcg@10
- type: cosine_mrr@10
value: 0.9225
name: Cosine Mrr@10
- type: cosine_map@100
value: 0.9225
name: Cosine Map@100
- task:
type: information-retrieval
name: Information Retrieval
dataset:
name: dim 256
type: dim_256
metrics:
- type: cosine_accuracy@1
value: 0.9
name: Cosine Accuracy@1
- type: cosine_accuracy@3
value: 0.9
name: Cosine Accuracy@3
- type: cosine_accuracy@5
value: 1
name: Cosine Accuracy@5
- type: cosine_accuracy@10
value: 1
name: Cosine Accuracy@10
- type: cosine_precision@1
value: 0.9
name: Cosine Precision@1
- type: cosine_precision@3
value: 0.29999999999999993
name: Cosine Precision@3
- type: cosine_precision@5
value: 0.20000000000000004
name: Cosine Precision@5
- type: cosine_precision@10
value: 0.10000000000000002
name: Cosine Precision@10
- type: cosine_recall@1
value: 0.9
name: Cosine Recall@1
- type: cosine_recall@3
value: 0.9
name: Cosine Recall@3
- type: cosine_recall@5
value: 1
name: Cosine Recall@5
- type: cosine_recall@10
value: 1
name: Cosine Recall@10
- type: cosine_ndcg@10
value: 0.9408764682653967
name: Cosine Ndcg@10
- type: cosine_mrr@10
value: 0.9225
name: Cosine Mrr@10
- type: cosine_map@100
value: 0.9225
name: Cosine Map@100
- task:
type: information-retrieval
name: Information Retrieval
dataset:
name: dim 128
type: dim_128
metrics:
- type: cosine_accuracy@1
value: 0.85
name: Cosine Accuracy@1
- type: cosine_accuracy@3
value: 0.9
name: Cosine Accuracy@3
- type: cosine_accuracy@5
value: 0.95
name: Cosine Accuracy@5
- type: cosine_accuracy@10
value: 0.95
name: Cosine Accuracy@10
- type: cosine_precision@1
value: 0.85
name: Cosine Precision@1
- type: cosine_precision@3
value: 0.29999999999999993
name: Cosine Precision@3
- type: cosine_precision@5
value: 0.19000000000000003
name: Cosine Precision@5
- type: cosine_precision@10
value: 0.09500000000000001
name: Cosine Precision@10
- type: cosine_recall@1
value: 0.85
name: Cosine Recall@1
- type: cosine_recall@3
value: 0.9
name: Cosine Recall@3
- type: cosine_recall@5
value: 0.95
name: Cosine Recall@5
- type: cosine_recall@10
value: 0.95
name: Cosine Recall@10
- type: cosine_ndcg@10
value: 0.894342640361727
name: Cosine Ndcg@10
- type: cosine_mrr@10
value: 0.8766666666666666
name: Cosine Mrr@10
- type: cosine_map@100
value: 0.8799999999999999
name: Cosine Map@100
- task:
type: information-retrieval
name: Information Retrieval
dataset:
name: dim 64
type: dim_64
metrics:
- type: cosine_accuracy@1
value: 0.85
name: Cosine Accuracy@1
- type: cosine_accuracy@3
value: 0.9
name: Cosine Accuracy@3
- type: cosine_accuracy@5
value: 0.9
name: Cosine Accuracy@5
- type: cosine_accuracy@10
value: 1
name: Cosine Accuracy@10
- type: cosine_precision@1
value: 0.85
name: Cosine Precision@1
- type: cosine_precision@3
value: 0.29999999999999993
name: Cosine Precision@3
- type: cosine_precision@5
value: 0.18000000000000005
name: Cosine Precision@5
- type: cosine_precision@10
value: 0.10000000000000002
name: Cosine Precision@10
- type: cosine_recall@1
value: 0.85
name: Cosine Recall@1
- type: cosine_recall@3
value: 0.9
name: Cosine Recall@3
- type: cosine_recall@5
value: 0.9
name: Cosine Recall@5
- type: cosine_recall@10
value: 1
name: Cosine Recall@10
- type: cosine_ndcg@10
value: 0.9074399105059531
name: Cosine Ndcg@10
- type: cosine_mrr@10
value: 0.8800595238095237
name: Cosine Mrr@10
- type: cosine_map@100
value: 0.8800595238095237
name: Cosine Map@100
codeBert dense retriever
This is a sentence-transformers model finetuned from shubharuidas/codebert-embed-base-dense-retriever. It maps sentences & paragraphs to a 768-dimensional dense vector space and can be used for semantic textual similarity, semantic search, paraphrase mining, text classification, clustering, and more.
Model Details
Model Description
- Model Type: Sentence Transformer
- Base model: shubharuidas/codebert-embed-base-dense-retriever
- Maximum Sequence Length: 512 tokens
- Output Dimensionality: 768 dimensions
- Similarity Function: Cosine Similarity
- Language: en
- License: apache-2.0
Model Sources
- Documentation: Sentence Transformers Documentation
- Repository: Sentence Transformers on GitHub
- Hugging Face: Sentence Transformers on Hugging Face
Full Model Architecture
SentenceTransformer(
(0): Transformer({'max_seq_length': 512, 'do_lower_case': False, 'architecture': 'RobertaModel'})
(1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
)
Usage
Direct Usage (Sentence Transformers)
First install the Sentence Transformers library:
pip install -U sentence-transformers
Then you can load this model and run inference.
from sentence_transformers import SentenceTransformer
# Download from the 🤗 Hub
model = SentenceTransformer("anaghaj111/codebert-base-code-embed-mrl-langchain-langgraph")
# Run inference
sentences = [
'Best practices for test_list_namespaces_operations',
'def test_list_namespaces_operations(\n fake_embeddings: CharacterEmbeddings,\n) -> None:\n """Test list namespaces functionality with various filters."""\n with create_vector_store(\n fake_embeddings, text_fields=["key0", "key1", "key3"]\n ) as store:\n test_pref = str(uuid.uuid4())\n test_namespaces = [\n (test_pref, "test", "documents", "public", test_pref),\n (test_pref, "test", "documents", "private", test_pref),\n (test_pref, "test", "images", "public", test_pref),\n (test_pref, "test", "images", "private", test_pref),\n (test_pref, "prod", "documents", "public", test_pref),\n (test_pref, "prod", "documents", "some", "nesting", "public", test_pref),\n (test_pref, "prod", "documents", "private", test_pref),\n ]\n\n # Add test data\n for namespace in test_namespaces:\n store.put(namespace, "dummy", {"content": "dummy"})\n\n # Test prefix filtering\n prefix_result = store.list_namespaces(prefix=(test_pref, "test"))\n assert len(prefix_result) == 4\n assert all(ns[1] == "test" for ns in prefix_result)\n\n # Test specific prefix\n specific_prefix_result = store.list_namespaces(\n prefix=(test_pref, "test", "documents")\n )\n assert len(specific_prefix_result) == 2\n assert all(ns[1:3] == ("test", "documents") for ns in specific_prefix_result)\n\n # Test suffix filtering\n suffix_result = store.list_namespaces(suffix=("public", test_pref))\n assert len(suffix_result) == 4\n assert all(ns[-2] == "public" for ns in suffix_result)\n\n # Test combined prefix and suffix\n prefix_suffix_result = store.list_namespaces(\n prefix=(test_pref, "test"), suffix=("public", test_pref)\n )\n assert len(prefix_suffix_result) == 2\n assert all(\n ns[1] == "test" and ns[-2] == "public" for ns in prefix_suffix_result\n )\n\n # Test wildcard in prefix\n wildcard_prefix_result = store.list_namespaces(\n prefix=(test_pref, "*", "documents")\n )\n assert len(wildcard_prefix_result) == 5\n assert all(ns[2] == "documents" for ns in wildcard_prefix_result)\n\n # Test wildcard in suffix\n 
wildcard_suffix_result = store.list_namespaces(\n suffix=("*", "public", test_pref)\n )\n assert len(wildcard_suffix_result) == 4\n assert all(ns[-2] == "public" for ns in wildcard_suffix_result)\n\n wildcard_single = store.list_namespaces(\n suffix=("some", "*", "public", test_pref)\n )\n assert len(wildcard_single) == 1\n assert wildcard_single[0] == (\n test_pref,\n "prod",\n "documents",\n "some",\n "nesting",\n "public",\n test_pref,\n )\n\n # Test max depth\n max_depth_result = store.list_namespaces(max_depth=3)\n assert all(len(ns) <= 3 for ns in max_depth_result)\n\n max_depth_result = store.list_namespaces(\n max_depth=4, prefix=(test_pref, "*", "documents")\n )\n assert len(set(res for res in max_depth_result)) == len(max_depth_result) == 5\n\n # Test pagination\n limit_result = store.list_namespaces(prefix=(test_pref,), limit=3)\n assert len(limit_result) == 3\n\n offset_result = store.list_namespaces(prefix=(test_pref,), offset=3)\n assert len(offset_result) == len(test_namespaces) - 3\n\n empty_prefix_result = store.list_namespaces(prefix=(test_pref,))\n assert len(empty_prefix_result) == len(test_namespaces)\n assert set(empty_prefix_result) == set(test_namespaces)\n\n # Clean up\n for namespace in test_namespaces:\n store.delete(namespace, "dummy")',
'def test_doubly_nested_graph_state(\n sync_checkpointer: BaseCheckpointSaver,\n) -> None:\n class State(TypedDict):\n my_key: str\n\n class ChildState(TypedDict):\n my_key: str\n\n class GrandChildState(TypedDict):\n my_key: str\n\n def grandchild_1(state: ChildState):\n return {"my_key": state["my_key"] + " here"}\n\n def grandchild_2(state: ChildState):\n return {\n "my_key": state["my_key"] + " and there",\n }\n\n grandchild = StateGraph(GrandChildState)\n grandchild.add_node("grandchild_1", grandchild_1)\n grandchild.add_node("grandchild_2", grandchild_2)\n grandchild.add_edge("grandchild_1", "grandchild_2")\n grandchild.set_entry_point("grandchild_1")\n grandchild.set_finish_point("grandchild_2")\n\n child = StateGraph(ChildState)\n child.add_node(\n "child_1",\n grandchild.compile(interrupt_before=["grandchild_2"]),\n )\n child.set_entry_point("child_1")\n child.set_finish_point("child_1")\n\n def parent_1(state: State):\n return {"my_key": "hi " + state["my_key"]}\n\n def parent_2(state: State):\n return {"my_key": state["my_key"] + " and back again"}\n\n graph = StateGraph(State)\n graph.add_node("parent_1", parent_1)\n graph.add_node("child", child.compile())\n graph.add_node("parent_2", parent_2)\n graph.set_entry_point("parent_1")\n graph.add_edge("parent_1", "child")\n graph.add_edge("child", "parent_2")\n graph.set_finish_point("parent_2")\n\n app = graph.compile(checkpointer=sync_checkpointer)\n\n # test invoke w/ nested interrupt\n config = {"configurable": {"thread_id": "1"}}\n assert [\n c\n for c in app.stream(\n {"my_key": "my value"}, config, subgraphs=True, durability="exit"\n )\n ] == [\n ((), {"parent_1": {"my_key": "hi my value"}}),\n (\n (AnyStr("child:"), AnyStr("child_1:")),\n {"grandchild_1": {"my_key": "hi my value here"}},\n ),\n ((), {"__interrupt__": ()}),\n ]\n # get state without subgraphs\n outer_state = app.get_state(config)\n assert outer_state == StateSnapshot(\n values={"my_key": "hi my value"},\n tasks=(\n PregelTask(\n 
AnyStr(),\n "child",\n (PULL, "child"),\n state={\n "configurable": {\n "thread_id": "1",\n "checkpoint_ns": AnyStr("child"),\n }\n },\n ),\n ),\n next=("child",),\n config={\n "configurable": {\n "thread_id": "1",\n "checkpoint_ns": "",\n "checkpoint_id": AnyStr(),\n }\n },\n metadata={\n "parents": {},\n "source": "loop",\n "step": 1,\n },\n created_at=AnyStr(),\n parent_config=None,\n interrupts=(),\n )\n child_state = app.get_state(outer_state.tasks[0].state)\n assert child_state == StateSnapshot(\n values={"my_key": "hi my value"},\n tasks=(\n PregelTask(\n AnyStr(),\n "child_1",\n (PULL, "child_1"),\n state={\n "configurable": {\n "thread_id": "1",\n "checkpoint_ns": AnyStr(),\n }\n },\n ),\n ),\n next=("child_1",),\n config={\n "configurable": {\n "thread_id": "1",\n "checkpoint_ns": AnyStr("child:"),\n "checkpoint_id": AnyStr(),\n "checkpoint_map": AnyDict(\n {\n "": AnyStr(),\n AnyStr("child:"): AnyStr(),\n }\n ),\n }\n },\n metadata={\n "parents": {"": AnyStr()},\n "source": "loop",\n "step": 0,\n },\n created_at=AnyStr(),\n parent_config=None,\n interrupts=(),\n )\n grandchild_state = app.get_state(child_state.tasks[0].state)\n assert grandchild_state == StateSnapshot(\n values={"my_key": "hi my value here"},\n tasks=(\n PregelTask(\n AnyStr(),\n "grandchild_2",\n (PULL, "grandchild_2"),\n ),\n ),\n next=("grandchild_2",),\n config={\n "configurable": {\n "thread_id": "1",\n "checkpoint_ns": AnyStr(),\n "checkpoint_id": AnyStr(),\n "checkpoint_map": AnyDict(\n {\n "": AnyStr(),\n AnyStr("child:"): AnyStr(),\n AnyStr(re.compile(r"child:.+|child1:")): AnyStr(),\n }\n ),\n }\n },\n metadata={\n "parents": AnyDict(\n {\n "": AnyStr(),\n AnyStr("child:"): AnyStr(),\n }\n ),\n "source": "loop",\n "step": 1,\n },\n created_at=AnyStr(),\n parent_config=None,\n interrupts=(),\n )\n # get state with subgraphs\n assert app.get_state(config, subgraphs=True) == StateSnapshot(\n values={"my_key": "hi my value"},\n tasks=(\n PregelTask(\n AnyStr(),\n "child",\n (PULL, 
"child"),\n state=StateSnapshot(\n values={"my_key": "hi my value"},\n tasks=(\n PregelTask(\n AnyStr(),\n "child_1",\n (PULL, "child_1"),\n state=StateSnapshot(\n values={"my_key": "hi my value here"},\n tasks=(\n PregelTask(\n AnyStr(),\n "grandchild_2",\n (PULL, "grandchild_2"),\n ),\n ),\n next=("grandchild_2",),\n config={\n "configurable": {\n "thread_id": "1",\n "checkpoint_ns": AnyStr(),\n "checkpoint_id": AnyStr(),\n "checkpoint_map": AnyDict(\n {\n "": AnyStr(),\n AnyStr("child:"): AnyStr(),\n AnyStr(\n re.compile(r"child:.+|child1:")\n ): AnyStr(),\n }\n ),\n }\n },\n metadata={\n "parents": AnyDict(\n {\n "": AnyStr(),\n AnyStr("child:"): AnyStr(),\n }\n ),\n "source": "loop",\n "step": 1,\n },\n created_at=AnyStr(),\n parent_config=None,\n interrupts=(),\n ),\n ),\n ),\n next=("child_1",),\n config={\n "configurable": {\n "thread_id": "1",\n "checkpoint_ns": AnyStr("child:"),\n "checkpoint_id": AnyStr(),\n "checkpoint_map": AnyDict(\n {"": AnyStr(), AnyStr("child:"): AnyStr()}\n ),\n }\n },\n metadata={\n "parents": {"": AnyStr()},\n "source": "loop",\n "step": 0,\n },\n created_at=AnyStr(),\n parent_config=None,\n interrupts=(),\n ),\n ),\n ),\n next=("child",),\n config={\n "configurable": {\n "thread_id": "1",\n "checkpoint_ns": "",\n "checkpoint_id": AnyStr(),\n }\n },\n metadata={\n "parents": {},\n "source": "loop",\n "step": 1,\n },\n created_at=AnyStr(),\n parent_config=None,\n interrupts=(),\n )\n # # resume\n assert [c for c in app.stream(None, config, subgraphs=True, durability="exit")] == [\n (\n (AnyStr("child:"), AnyStr("child_1:")),\n {"grandchild_2": {"my_key": "hi my value here and there"}},\n ),\n ((AnyStr("child:"),), {"child_1": {"my_key": "hi my value here and there"}}),\n ((), {"child": {"my_key": "hi my value here and there"}}),\n ((), {"parent_2": {"my_key": "hi my value here and there and back again"}}),\n ]\n # get state with and without subgraphs\n assert (\n app.get_state(config)\n == app.get_state(config, subgraphs=True)\n 
== StateSnapshot(\n values={"my_key": "hi my value here and there and back again"},\n tasks=(),\n next=(),\n config={\n "configurable": {\n "thread_id": "1",\n "checkpoint_ns": "",\n "checkpoint_id": AnyStr(),\n }\n },\n metadata={\n "parents": {},\n "source": "loop",\n "step": 3,\n },\n created_at=AnyStr(),\n parent_config=(\n {\n "configurable": {\n "thread_id": "1",\n "checkpoint_ns": "",\n "checkpoint_id": AnyStr(),\n }\n }\n ),\n interrupts=(),\n )\n )\n\n # get outer graph history\n outer_history = list(app.get_state_history(config))\n assert outer_history == [\n StateSnapshot(\n values={"my_key": "hi my value here and there and back again"},\n tasks=(),\n next=(),\n config={\n "configurable": {\n "thread_id": "1",\n "checkpoint_ns": "",\n "checkpoint_id": AnyStr(),\n }\n },\n metadata={\n "parents": {},\n "source": "loop",\n "step": 3,\n },\n created_at=AnyStr(),\n parent_config={\n "configurable": {\n "thread_id": "1",\n "checkpoint_ns": "",\n "checkpoint_id": AnyStr(),\n }\n },\n interrupts=(),\n ),\n StateSnapshot(\n values={"my_key": "hi my value"},\n tasks=(\n PregelTask(\n AnyStr(),\n "child",\n (PULL, "child"),\n state={\n "configurable": {\n "thread_id": "1",\n "checkpoint_ns": AnyStr("child"),\n }\n },\n result=None,\n ),\n ),\n next=("child",),\n config={\n "configurable": {\n "thread_id": "1",\n "checkpoint_ns": "",\n "checkpoint_id": AnyStr(),\n }\n },\n metadata={\n "parents": {},\n "source": "loop",\n "step": 1,\n },\n created_at=AnyStr(),\n parent_config=None,\n interrupts=(),\n ),\n ]\n # get child graph history\n child_history = list(app.get_state_history(outer_history[1].tasks[0].state))\n assert child_history == [\n StateSnapshot(\n values={"my_key": "hi my value"},\n next=("child_1",),\n config={\n "configurable": {\n "thread_id": "1",\n "checkpoint_ns": AnyStr("child:"),\n "checkpoint_id": AnyStr(),\n "checkpoint_map": AnyDict(\n {"": AnyStr(), AnyStr("child:"): AnyStr()}\n ),\n }\n },\n metadata={\n "source": "loop",\n "step": 0,\n 
"parents": {"": AnyStr()},\n },\n created_at=AnyStr(),\n parent_config=None,\n tasks=(\n PregelTask(\n id=AnyStr(),\n name="child_1",\n path=(PULL, "child_1"),\n state={\n "configurable": {\n "thread_id": "1",\n "checkpoint_ns": AnyStr("child:"),\n }\n },\n result=None,\n ),\n ),\n interrupts=(),\n ),\n ]\n # get grandchild graph history\n grandchild_history = list(app.get_state_history(child_history[0].tasks[0].state))\n assert grandchild_history == [\n StateSnapshot(\n values={"my_key": "hi my value here"},\n next=("grandchild_2",),\n config={\n "configurable": {\n "thread_id": "1",\n "checkpoint_ns": AnyStr(),\n "checkpoint_id": AnyStr(),\n "checkpoint_map": AnyDict(\n {\n "": AnyStr(),\n AnyStr("child:"): AnyStr(),\n AnyStr(re.compile(r"child:.+|child1:")): AnyStr(),\n }\n ),\n }\n },\n metadata={\n "source": "loop",\n "step": 1,\n "parents": AnyDict(\n {\n "": AnyStr(),\n AnyStr("child:"): AnyStr(),\n }\n ),\n },\n created_at=AnyStr(),\n parent_config=None,\n tasks=(\n PregelTask(\n id=AnyStr(),\n name="grandchild_2",\n path=(PULL, "grandchild_2"),\n result=None,\n ),\n ),\n interrupts=(),\n ),\n ]',
]
embeddings = model.encode(sentences)
print(embeddings.shape)
# [3, 768]
# Get the similarity scores for the embeddings
similarities = model.similarity(embeddings, embeddings)
print(similarities)
# tensor([[1.0000, 0.7789, 0.3589],
# [0.7789, 1.0000, 0.4748],
# [0.3589, 0.4748, 1.0000]])
Evaluation
Metrics
Information Retrieval
- Dataset:
dim_768 - Evaluated with
InformationRetrievalEvaluator with these parameters: { "truncate_dim": 768 }
| Metric | Value |
|---|---|
| cosine_accuracy@1 | 0.9 |
| cosine_accuracy@3 | 0.9 |
| cosine_accuracy@5 | 1.0 |
| cosine_accuracy@10 | 1.0 |
| cosine_precision@1 | 0.9 |
| cosine_precision@3 | 0.3 |
| cosine_precision@5 | 0.2 |
| cosine_precision@10 | 0.1 |
| cosine_recall@1 | 0.9 |
| cosine_recall@3 | 0.9 |
| cosine_recall@5 | 1.0 |
| cosine_recall@10 | 1.0 |
| cosine_ndcg@10 | 0.9409 |
| cosine_mrr@10 | 0.9225 |
| cosine_map@100 | 0.9225 |
Information Retrieval
- Dataset:
dim_512 - Evaluated with
InformationRetrievalEvaluator with these parameters: { "truncate_dim": 512 }
| Metric | Value |
|---|---|
| cosine_accuracy@1 | 0.9 |
| cosine_accuracy@3 | 0.9 |
| cosine_accuracy@5 | 1.0 |
| cosine_accuracy@10 | 1.0 |
| cosine_precision@1 | 0.9 |
| cosine_precision@3 | 0.3 |
| cosine_precision@5 | 0.2 |
| cosine_precision@10 | 0.1 |
| cosine_recall@1 | 0.9 |
| cosine_recall@3 | 0.9 |
| cosine_recall@5 | 1.0 |
| cosine_recall@10 | 1.0 |
| cosine_ndcg@10 | 0.9409 |
| cosine_mrr@10 | 0.9225 |
| cosine_map@100 | 0.9225 |
Information Retrieval
- Dataset:
dim_256 - Evaluated with
InformationRetrievalEvaluator with these parameters: { "truncate_dim": 256 }
| Metric | Value |
|---|---|
| cosine_accuracy@1 | 0.9 |
| cosine_accuracy@3 | 0.9 |
| cosine_accuracy@5 | 1.0 |
| cosine_accuracy@10 | 1.0 |
| cosine_precision@1 | 0.9 |
| cosine_precision@3 | 0.3 |
| cosine_precision@5 | 0.2 |
| cosine_precision@10 | 0.1 |
| cosine_recall@1 | 0.9 |
| cosine_recall@3 | 0.9 |
| cosine_recall@5 | 1.0 |
| cosine_recall@10 | 1.0 |
| cosine_ndcg@10 | 0.9409 |
| cosine_mrr@10 | 0.9225 |
| cosine_map@100 | 0.9225 |
Information Retrieval
- Dataset:
dim_128 - Evaluated with
InformationRetrievalEvaluator with these parameters: { "truncate_dim": 128 }
| Metric | Value |
|---|---|
| cosine_accuracy@1 | 0.85 |
| cosine_accuracy@3 | 0.9 |
| cosine_accuracy@5 | 0.95 |
| cosine_accuracy@10 | 0.95 |
| cosine_precision@1 | 0.85 |
| cosine_precision@3 | 0.3 |
| cosine_precision@5 | 0.19 |
| cosine_precision@10 | 0.095 |
| cosine_recall@1 | 0.85 |
| cosine_recall@3 | 0.9 |
| cosine_recall@5 | 0.95 |
| cosine_recall@10 | 0.95 |
| cosine_ndcg@10 | 0.8943 |
| cosine_mrr@10 | 0.8767 |
| cosine_map@100 | 0.88 |
Information Retrieval
- Dataset:
dim_64 - Evaluated with
InformationRetrievalEvaluator with these parameters: { "truncate_dim": 64 }
| Metric | Value |
|---|---|
| cosine_accuracy@1 | 0.85 |
| cosine_accuracy@3 | 0.9 |
| cosine_accuracy@5 | 0.9 |
| cosine_accuracy@10 | 1.0 |
| cosine_precision@1 | 0.85 |
| cosine_precision@3 | 0.3 |
| cosine_precision@5 | 0.18 |
| cosine_precision@10 | 0.1 |
| cosine_recall@1 | 0.85 |
| cosine_recall@3 | 0.9 |
| cosine_recall@5 | 0.9 |
| cosine_recall@10 | 1.0 |
| cosine_ndcg@10 | 0.9074 |
| cosine_mrr@10 | 0.8801 |
| cosine_map@100 | 0.8801 |
Training Details
Training Dataset
Unnamed Dataset
- Size: 180 training samples
- Columns:
anchor and positive
- Approximate statistics based on the first 180 samples:
| | anchor | positive |
|---|---|---|
| type | string | string |
| min tokens | 6 | 14 |
| mean tokens | 12.34 | 273.18 |
| max tokens | 117 | 512 |
- Samples:
anchor: How to implement State?
positive: class State(TypedDict):
messages: Annotated[list[str], operator.add]
anchor: Best practices for test_sql_injection_vulnerability
positive: def test_sql_injection_vulnerability(store: SqliteStore) -> None:
"""Test that SQL injection via malicious filter keys is prevented."""
# Add public and private documents
store.put(("docs",), "public", {"access": "public", "data": "public info"})
store.put(
("docs",), "private", {"access": "private", "data": "secret", "password": "123"}
)
# Normal query - returns 1 public document
normal = store.search(("docs",), filter={"access": "public"})
assert len(normal) == 1
assert normal[0].value["access"] == "public"
# SQL injection attempt via malicious key should raise ValueError
malicious_key = "access') = 'public' OR '1'='1' OR json_extract(value, '$."
with pytest.raises(ValueError, match="Invalid filter key"):
store.search(("docs",), filter={malicious_key: "dummy"})
anchor: Example usage of put_writes
positive: def put_writes(
self,
config: RunnableConfig,
writes: Sequence[tuple[str, Any]],
task_id: str,
task_path: str = "",
) -> None:
"""Store intermediate writes linked to a checkpoint.
This method saves intermediate writes associated with a checkpoint to the Postgres database.
Args:
config: Configuration of the related checkpoint.
writes: List of writes to store.
task_id: Identifier for the task creating the writes.
"""
query = (
self.UPSERT_CHECKPOINT_WRITES_SQL
if all(w[0] in WRITES_IDX_MAP for w in writes)
else self.INSERT_CHECKPOINT_WRITES_SQL
)
with self._cursor(pipeline=True) as cur:
cur.executemany(
query,
self._dump_writes(
config["configurable"]["thread_id"],
config["configurable"]["checkpoint_ns"],
config["c... (truncated)
- Loss:
MatryoshkaLoss with these parameters: { "loss": "MultipleNegativesRankingLoss", "matryoshka_dims": [ 768, 512, 256, 128, 64 ], "matryoshka_weights": [ 1, 1, 1, 1, 1 ], "n_dims_per_step": -1 }
Training Hyperparameters
Non-Default Hyperparameters
- eval_strategy: epoch
- per_device_train_batch_size: 4
- per_device_eval_batch_size: 4
- gradient_accumulation_steps: 16
- learning_rate: 2e-05
- num_train_epochs: 2
- lr_scheduler_type: cosine
- warmup_ratio: 0.1
- fp16: True
- load_best_model_at_end: True
- optim: adamw_torch
- batch_sampler: no_duplicates
All Hyperparameters
Click to expand
- overwrite_output_dir: False
- do_predict: False
- eval_strategy: epoch
- prediction_loss_only: True
- per_device_train_batch_size: 4
- per_device_eval_batch_size: 4
- per_gpu_train_batch_size: None
- per_gpu_eval_batch_size: None
- gradient_accumulation_steps: 16
- eval_accumulation_steps: None
- torch_empty_cache_steps: None
- learning_rate: 2e-05
- weight_decay: 0.0
- adam_beta1: 0.9
- adam_beta2: 0.999
- adam_epsilon: 1e-08
- max_grad_norm: 1.0
- num_train_epochs: 2
- max_steps: -1
- lr_scheduler_type: cosine
- lr_scheduler_kwargs: {}
- warmup_ratio: 0.1
- warmup_steps: 0
- log_level: passive
- log_level_replica: warning
- log_on_each_node: True
- logging_nan_inf_filter: True
- save_safetensors: True
- save_on_each_node: False
- save_only_model: False
- restore_callback_states_from_checkpoint: False
- no_cuda: False
- use_cpu: False
- use_mps_device: False
- seed: 42
- data_seed: None
- jit_mode_eval: False
- bf16: False
- fp16: True
- fp16_opt_level: O1
- half_precision_backend: auto
- bf16_full_eval: False
- fp16_full_eval: False
- tf32: None
- local_rank: 0
- ddp_backend: None
- tpu_num_cores: None
- tpu_metrics_debug: False
- debug: []
- dataloader_drop_last: False
- dataloader_num_workers: 0
- dataloader_prefetch_factor: None
- past_index: -1
- disable_tqdm: False
- remove_unused_columns: True
- label_names: None
- load_best_model_at_end: True
- ignore_data_skip: False
- fsdp: []
- fsdp_min_num_params: 0
- fsdp_config: {'min_num_params': 0, 'xla': False, 'xla_fsdp_v2': False, 'xla_fsdp_grad_ckpt': False}
- fsdp_transformer_layer_cls_to_wrap: None
- accelerator_config: {'split_batches': False, 'dispatch_batches': None, 'even_batches': True, 'use_seedable_sampler': True, 'non_blocking': False, 'gradient_accumulation_kwargs': None}
- parallelism_config: None
- deepspeed: None
- label_smoothing_factor: 0.0
- optim: adamw_torch
- optim_args: None
- adafactor: False
- group_by_length: False
- length_column_name: length
- project: huggingface
- trackio_space_id: trackio
- ddp_find_unused_parameters: None
- ddp_bucket_cap_mb: None
- ddp_broadcast_buffers: False
- dataloader_pin_memory: True
- dataloader_persistent_workers: False
- skip_memory_metrics: True
- use_legacy_prediction_loop: False
- push_to_hub: False
- resume_from_checkpoint: None
- hub_model_id: None
- hub_strategy: every_save
- hub_private_repo: None
- hub_always_push: False
- hub_revision: None
- gradient_checkpointing: False
- gradient_checkpointing_kwargs: None
- include_inputs_for_metrics: False
- include_for_metrics: []
- eval_do_concat_batches: True
- fp16_backend: auto
- push_to_hub_model_id: None
- push_to_hub_organization: None
- mp_parameters: 
- auto_find_batch_size: False
- full_determinism: False
- torchdynamo: None
- ray_scope: last
- ddp_timeout: 1800
- torch_compile: False
- torch_compile_backend: None
- torch_compile_mode: None
- include_tokens_per_second: False
- include_num_input_tokens_seen: no
- neftune_noise_alpha: None
- optim_target_modules: None
- batch_eval_metrics: False
- eval_on_start: False
- use_liger_kernel: False
- liger_kernel_config: None
- eval_use_gather_object: False
- average_tokens_across_devices: True
- prompts: None
- batch_sampler: no_duplicates
- multi_dataset_batch_sampler: proportional
- router_mapping: {}
- learning_rate_mapping: {}
Training Logs
| Epoch | Step | dim_768_cosine_ndcg@10 | dim_512_cosine_ndcg@10 | dim_256_cosine_ndcg@10 | dim_128_cosine_ndcg@10 | dim_64_cosine_ndcg@10 |
|---|---|---|---|---|---|---|
| 1.0 | 3 | 0.9409 | 0.9202 | 0.9431 | 0.8412 | 0.9059 |
| 2.0 | 6 | 0.9409 | 0.9409 | 0.9409 | 0.8943 | 0.9074 |
- The bold row denotes the saved checkpoint.
Framework Versions
- Python: 3.14.0
- Sentence Transformers: 5.2.2
- Transformers: 4.57.3
- PyTorch: 2.9.1
- Accelerate: 1.12.0
- Datasets: 4.5.0
- Tokenizers: 0.22.2
Citation
BibTeX
Sentence Transformers
@inproceedings{reimers-2019-sentence-bert,
title = "Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks",
author = "Reimers, Nils and Gurevych, Iryna",
booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing",
month = "11",
year = "2019",
publisher = "Association for Computational Linguistics",
url = "https://arxiv.org/abs/1908.10084",
}
MatryoshkaLoss
@misc{kusupati2024matryoshka,
title={Matryoshka Representation Learning},
author={Aditya Kusupati and Gantavya Bhatt and Aniket Rege and Matthew Wallingford and Aditya Sinha and Vivek Ramanujan and William Howard-Snyder and Kaifeng Chen and Sham Kakade and Prateek Jain and Ali Farhadi},
year={2024},
eprint={2205.13147},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
MultipleNegativesRankingLoss
@misc{henderson2017efficient,
title={Efficient Natural Language Response Suggestion for Smart Reply},
author={Matthew Henderson and Rami Al-Rfou and Brian Strope and Yun-hsuan Sung and Laszlo Lukacs and Ruiqi Guo and Sanjiv Kumar and Balint Miklos and Ray Kurzweil},
year={2017},
eprint={1705.00652},
archivePrefix={arXiv},
primaryClass={cs.CL}
}