---
language:
  - en
license: apache-2.0
tags:
  - sentence-transformers
  - sentence-similarity
  - feature-extraction
  - dense
  - generated_from_trainer
  - dataset_size:180
  - loss:MatryoshkaLoss
  - loss:MultipleNegativesRankingLoss
base_model: shubharuidas/codebert-embed-base-dense-retriever
widget:
  - source_sentence: Explain the __init__ logic
    sentences:
      - |-
        async def test_handler_with_async_execution() -> None:
            """Test handler works correctly with async tool execution."""

            @tool
            def async_add(a: int, b: int) -> int:
                """Async add two numbers."""
                return a + b

            def modifying_handler(
                request: ToolCallRequest,
                execute: Callable[[ToolCallRequest], ToolMessage | Command],
            ) -> ToolMessage | Command:
                """Handler that modifies arguments."""
                # Add 10 to both arguments using override method
                modified_call = {
                    **request.tool_call,
                    "args": {
                        **request.tool_call["args"],
                        "a": request.tool_call["args"]["a"] + 10,
                        "b": request.tool_call["args"]["b"] + 10,
                    },
                }
                modified_request = request.override(tool_call=modified_call)
                return execute(modified_request)

            tool_node = ToolNode([async_add], wrap_tool_call=modifying_handler)

            result = await tool_node.ainvoke(
                {
                    "messages": [
                        AIMessage(
                            "adding",
                            tool_calls=[
                                {
                                    "name": "async_add",
                                    "args": {"a": 1, "b": 2},
                                    "id": "call_13",
                                }
                            ],
                        )
                    ]
                },
                config=_create_config_with_runtime(),
            )

            tool_message = result["messages"][-1]
            assert isinstance(tool_message, ToolMessage)
            # Original: 1 + 2 = 3, with modifications: 11 + 12 = 23
            assert tool_message.content == "23"
      - |-
        def __init__(self) -> None:
                self.loads: set[str] = set()
                self.stores: set[str] = set()
      - |-
        class InternalServerError(APIStatusError):
            pass
  - source_sentence: Explain the async _load_checkpoint_tuple logic
    sentences:
      - >-
        def task(__func_or_none__: Callable[P, Awaitable[T]]) ->
        _TaskFunction[P, T]: ...
      - |-
        class State(BaseModel):
                query: str
                inner: InnerObject
                answer: str | None = None
                docs: Annotated[list[str], sorted_add]
      - >-
        async def _load_checkpoint_tuple(self, value: DictRow) ->
        CheckpointTuple:
                """
                Convert a database row into a CheckpointTuple object.

                Args:
                    value: A row from the database containing checkpoint data.

                Returns:
                    CheckpointTuple: A structured representation of the checkpoint,
                    including its configuration, metadata, parent checkpoint (if any),
                    and pending writes.
                """
                return CheckpointTuple(
                    {
                        "configurable": {
                            "thread_id": value["thread_id"],
                            "checkpoint_ns": value["checkpoint_ns"],
                            "checkpoint_id": value["checkpoint_id"],
                        }
                    },
                    {
                        **value["checkpoint"],
                        "channel_values": {
                            **(value["checkpoint"].get("channel_values") or {}),
                            **self._load_blobs(value["channel_values"]),
                        },
                    },
                    value["metadata"],
                    (
                        {
                            "configurable": {
                                "thread_id": value["thread_id"],
                                "checkpoint_ns": value["checkpoint_ns"],
                                "checkpoint_id": value["parent_checkpoint_id"],
                            }
                        }
                        if value["parent_checkpoint_id"]
                        else None
                    ),
                    await asyncio.to_thread(self._load_writes, value["pending_writes"]),
                )
  - source_sentence: Explain the flattened_runs logic
    sentences:
      - |-
        class ChannelWrite(RunnableCallable):
            """Implements the logic for sending writes to CONFIG_KEY_SEND.
            Can be used as a runnable or as a static method to call imperatively."""

            writes: list[ChannelWriteEntry | ChannelWriteTupleEntry | Send]
            """Sequence of write entries or Send objects to write."""

            def __init__(
                self,
                writes: Sequence[ChannelWriteEntry | ChannelWriteTupleEntry | Send],
                *,
                tags: Sequence[str] | None = None,
            ):
                super().__init__(
                    func=self._write,
                    afunc=self._awrite,
                    name=None,
                    tags=tags,
                    trace=False,
                )
                self.writes = cast(
                    list[ChannelWriteEntry | ChannelWriteTupleEntry | Send], writes
                )

            def get_name(self, suffix: str | None = None, *, name: str | None = None) -> str:
                if not name:
                    name = f"ChannelWrite<{','.join(w.channel if isinstance(w, ChannelWriteEntry) else '...' if isinstance(w, ChannelWriteTupleEntry) else w.node for w in self.writes)}>"
                return super().get_name(suffix, name=name)

            def _write(self, input: Any, config: RunnableConfig) -> None:
                writes = [
                    ChannelWriteEntry(write.channel, input, write.skip_none, write.mapper)
                    if isinstance(write, ChannelWriteEntry) and write.value is PASSTHROUGH
                    else ChannelWriteTupleEntry(write.mapper, input)
                    if isinstance(write, ChannelWriteTupleEntry) and write.value is PASSTHROUGH
                    else write
                    for write in self.writes
                ]
                self.do_write(
                    config,
                    writes,
                )
                return input

            async def _awrite(self, input: Any, config: RunnableConfig) -> None:
                writes = [
                    ChannelWriteEntry(write.channel, input, write.skip_none, write.mapper)
                    if isinstance(write, ChannelWriteEntry) and write.value is PASSTHROUGH
                    else ChannelWriteTupleEntry(write.mapper, input)
                    if isinstance(write, ChannelWriteTupleEntry) and write.value is PASSTHROUGH
                    else write
                    for write in self.writes
                ]
                self.do_write(
                    config,
                    writes,
                )
                return input

            @staticmethod
            def do_write(
                config: RunnableConfig,
                writes: Sequence[ChannelWriteEntry | ChannelWriteTupleEntry | Send],
                allow_passthrough: bool = True,
            ) -> None:
                # validate
                for w in writes:
                    if isinstance(w, ChannelWriteEntry):
                        if w.channel == TASKS:
                            raise InvalidUpdateError(
                                "Cannot write to the reserved channel TASKS"
                            )
                        if w.value is PASSTHROUGH and not allow_passthrough:
                            raise InvalidUpdateError("PASSTHROUGH value must be replaced")
                    if isinstance(w, ChannelWriteTupleEntry):
                        if w.value is PASSTHROUGH and not allow_passthrough:
                            raise InvalidUpdateError("PASSTHROUGH value must be replaced")
                # if we want to persist writes found before hitting a ParentCommand
                # can move this to a finally block
                write: TYPE_SEND = config[CONF][CONFIG_KEY_SEND]
                write(_assemble_writes(writes))

            @staticmethod
            def is_writer(runnable: Runnable) -> bool:
                """Used by PregelNode to distinguish between writers and other runnables."""
                return (
                    isinstance(runnable, ChannelWrite)
                    or getattr(runnable, "_is_channel_writer", MISSING) is not MISSING
                )

            @staticmethod
            def get_static_writes(
                runnable: Runnable,
            ) -> Sequence[tuple[str, Any, str | None]] | None:
                """Used to get conditional writes a writer declares for static analysis."""
                if isinstance(runnable, ChannelWrite):
                    return [
                        w
                        for entry in runnable.writes
                        if isinstance(entry, ChannelWriteTupleEntry) and entry.static
                        for w in entry.static
                    ] or None
                elif writes := getattr(runnable, "_is_channel_writer", MISSING):
                    if writes is not MISSING:
                        writes = cast(
                            Sequence[tuple[ChannelWriteEntry | Send, str | None]],
                            writes,
                        )
                        entries = [e for e, _ in writes]
                        labels = [la for _, la in writes]
                        return [(*t, la) for t, la in zip(_assemble_writes(entries), labels)]

            @staticmethod
            def register_writer(
                runnable: R,
                static: Sequence[tuple[ChannelWriteEntry | Send, str | None]] | None = None,
            ) -> R:
                """Used to mark a runnable as a writer, so that it can be detected by is_writer.
                Instances of ChannelWrite are automatically marked as writers.
                Optionally, a list of declared writes can be passed for static analysis."""
                # using object.__setattr__ to work around objects that override __setattr__
                # eg. pydantic models and dataclasses
                object.__setattr__(runnable, "_is_channel_writer", static)
                return runnable
      - >-
        def test_double_interrupt_subgraph(sync_checkpointer:
        BaseCheckpointSaver) -> None:
            class AgentState(TypedDict):
                input: str

            def node_1(state: AgentState):
                result = interrupt("interrupt node 1")
                return {"input": result}

            def node_2(state: AgentState):
                result = interrupt("interrupt node 2")
                return {"input": result}

            subgraph_builder = (
                StateGraph(AgentState)
                .add_node("node_1", node_1)
                .add_node("node_2", node_2)
                .add_edge(START, "node_1")
                .add_edge("node_1", "node_2")
                .add_edge("node_2", END)
            )

            # invoke the sub graph
            subgraph = subgraph_builder.compile(checkpointer=sync_checkpointer)
            thread = {"configurable": {"thread_id": str(uuid.uuid4())}}
            assert [c for c in subgraph.stream({"input": "test"}, thread)] == [
                {
                    "__interrupt__": (
                        Interrupt(
                            value="interrupt node 1",
                            id=AnyStr(),
                        ),
                    )
                },
            ]
            # resume from the first interrupt
            assert [c for c in subgraph.stream(Command(resume="123"), thread)] == [
                {
                    "node_1": {"input": "123"},
                },
                {
                    "__interrupt__": (
                        Interrupt(
                            value="interrupt node 2",
                            id=AnyStr(),
                        ),
                    )
                },
            ]
            # resume from the second interrupt
            assert [c for c in subgraph.stream(Command(resume="123"), thread)] == [
                {
                    "node_2": {"input": "123"},
                },
            ]

            subgraph = subgraph_builder.compile()

            def invoke_sub_agent(state: AgentState):
                return subgraph.invoke(state)

            thread = {"configurable": {"thread_id": str(uuid.uuid4())}}
            parent_agent = (
                StateGraph(AgentState)
                .add_node("invoke_sub_agent", invoke_sub_agent)
                .add_edge(START, "invoke_sub_agent")
                .add_edge("invoke_sub_agent", END)
                .compile(checkpointer=sync_checkpointer)
            )

            assert [c for c in parent_agent.stream({"input": "test"}, thread)] == [
                {
                    "__interrupt__": (
                        Interrupt(
                            value="interrupt node 1",
                            id=AnyStr(),
                        ),
                    )
                },
            ]

            # resume from the first interrupt
            assert [c for c in parent_agent.stream(Command(resume=True), thread)] == [
                {
                    "__interrupt__": (
                        Interrupt(
                            value="interrupt node 2",
                            id=AnyStr(),
                        ),
                    )
                }
            ]

            # resume from 2nd interrupt
            assert [c for c in parent_agent.stream(Command(resume=True), thread)] == [
                {
                    "invoke_sub_agent": {"input": True},
                },
            ]
      - |-
        def flattened_runs(self) -> list[Run]:
                q = [] + self.runs
                result = []
                while q:
                    parent = q.pop()
                    result.append(parent)
                    if parent.child_runs:
                        q.extend(parent.child_runs)
                return result
  - source_sentence: Explain the SubGraphState logic
    sentences:
      - |-
        class Cron(TypedDict):
            """Represents a scheduled task."""

            cron_id: str
            """The ID of the cron."""
            assistant_id: str
            """The ID of the assistant."""
            thread_id: str | None
            """The ID of the thread."""
            on_run_completed: OnCompletionBehavior | None
            """What to do with the thread after the run completes. Only applicable for stateless crons."""
            end_time: datetime | None
            """The end date to stop running the cron."""
            schedule: str
            """The schedule to run, cron format."""
            created_at: datetime
            """The time the cron was created."""
            updated_at: datetime
            """The last time the cron was updated."""
            payload: dict
            """The run payload to use for creating new run."""
            user_id: str | None
            """The user ID of the cron."""
            next_run_date: datetime | None
            """The next run date of the cron."""
            metadata: dict
            """The metadata of the cron."""
      - |-
        class SubGraphState(MessagesState):
                city: str
      - |-
        def task_path_str(tup: str | int | tuple) -> str:
            """Generate a string representation of the task path."""
            return (
                f"~{', '.join(task_path_str(x) for x in tup)}"
                if isinstance(tup, (tuple, list))
                else f"{tup:010d}"
                if isinstance(tup, int)
                else str(tup)
            )
  - source_sentence: Best practices for test_list_namespaces_operations
    sentences:
      - |-
        def test_doubly_nested_graph_state(
            sync_checkpointer: BaseCheckpointSaver,
        ) -> None:
            class State(TypedDict):
                my_key: str

            class ChildState(TypedDict):
                my_key: str

            class GrandChildState(TypedDict):
                my_key: str

            def grandchild_1(state: ChildState):
                return {"my_key": state["my_key"] + " here"}

            def grandchild_2(state: ChildState):
                return {
                    "my_key": state["my_key"] + " and there",
                }

            grandchild = StateGraph(GrandChildState)
            grandchild.add_node("grandchild_1", grandchild_1)
            grandchild.add_node("grandchild_2", grandchild_2)
            grandchild.add_edge("grandchild_1", "grandchild_2")
            grandchild.set_entry_point("grandchild_1")
            grandchild.set_finish_point("grandchild_2")

            child = StateGraph(ChildState)
            child.add_node(
                "child_1",
                grandchild.compile(interrupt_before=["grandchild_2"]),
            )
            child.set_entry_point("child_1")
            child.set_finish_point("child_1")

            def parent_1(state: State):
                return {"my_key": "hi " + state["my_key"]}

            def parent_2(state: State):
                return {"my_key": state["my_key"] + " and back again"}

            graph = StateGraph(State)
            graph.add_node("parent_1", parent_1)
            graph.add_node("child", child.compile())
            graph.add_node("parent_2", parent_2)
            graph.set_entry_point("parent_1")
            graph.add_edge("parent_1", "child")
            graph.add_edge("child", "parent_2")
            graph.set_finish_point("parent_2")

            app = graph.compile(checkpointer=sync_checkpointer)

            # test invoke w/ nested interrupt
            config = {"configurable": {"thread_id": "1"}}
            assert [
                c
                for c in app.stream(
                    {"my_key": "my value"}, config, subgraphs=True, durability="exit"
                )
            ] == [
                ((), {"parent_1": {"my_key": "hi my value"}}),
                (
                    (AnyStr("child:"), AnyStr("child_1:")),
                    {"grandchild_1": {"my_key": "hi my value here"}},
                ),
                ((), {"__interrupt__": ()}),
            ]
            # get state without subgraphs
            outer_state = app.get_state(config)
            assert outer_state == StateSnapshot(
                values={"my_key": "hi my value"},
                tasks=(
                    PregelTask(
                        AnyStr(),
                        "child",
                        (PULL, "child"),
                        state={
                            "configurable": {
                                "thread_id": "1",
                                "checkpoint_ns": AnyStr("child"),
                            }
                        },
                    ),
                ),
                next=("child",),
                config={
                    "configurable": {
                        "thread_id": "1",
                        "checkpoint_ns": "",
                        "checkpoint_id": AnyStr(),
                    }
                },
                metadata={
                    "parents": {},
                    "source": "loop",
                    "step": 1,
                },
                created_at=AnyStr(),
                parent_config=None,
                interrupts=(),
            )
            child_state = app.get_state(outer_state.tasks[0].state)
            assert child_state == StateSnapshot(
                values={"my_key": "hi my value"},
                tasks=(
                    PregelTask(
                        AnyStr(),
                        "child_1",
                        (PULL, "child_1"),
                        state={
                            "configurable": {
                                "thread_id": "1",
                                "checkpoint_ns": AnyStr(),
                            }
                        },
                    ),
                ),
                next=("child_1",),
                config={
                    "configurable": {
                        "thread_id": "1",
                        "checkpoint_ns": AnyStr("child:"),
                        "checkpoint_id": AnyStr(),
                        "checkpoint_map": AnyDict(
                            {
                                "": AnyStr(),
                                AnyStr("child:"): AnyStr(),
                            }
                        ),
                    }
                },
                metadata={
                    "parents": {"": AnyStr()},
                    "source": "loop",
                    "step": 0,
                },
                created_at=AnyStr(),
                parent_config=None,
                interrupts=(),
            )
            grandchild_state = app.get_state(child_state.tasks[0].state)
            assert grandchild_state == StateSnapshot(
                values={"my_key": "hi my value here"},
                tasks=(
                    PregelTask(
                        AnyStr(),
                        "grandchild_2",
                        (PULL, "grandchild_2"),
                    ),
                ),
                next=("grandchild_2",),
                config={
                    "configurable": {
                        "thread_id": "1",
                        "checkpoint_ns": AnyStr(),
                        "checkpoint_id": AnyStr(),
                        "checkpoint_map": AnyDict(
                            {
                                "": AnyStr(),
                                AnyStr("child:"): AnyStr(),
                                AnyStr(re.compile(r"child:.+|child1:")): AnyStr(),
                            }
                        ),
                    }
                },
                metadata={
                    "parents": AnyDict(
                        {
                            "": AnyStr(),
                            AnyStr("child:"): AnyStr(),
                        }
                    ),
                    "source": "loop",
                    "step": 1,
                },
                created_at=AnyStr(),
                parent_config=None,
                interrupts=(),
            )
            # get state with subgraphs
            assert app.get_state(config, subgraphs=True) == StateSnapshot(
                values={"my_key": "hi my value"},
                tasks=(
                    PregelTask(
                        AnyStr(),
                        "child",
                        (PULL, "child"),
                        state=StateSnapshot(
                            values={"my_key": "hi my value"},
                            tasks=(
                                PregelTask(
                                    AnyStr(),
                                    "child_1",
                                    (PULL, "child_1"),
                                    state=StateSnapshot(
                                        values={"my_key": "hi my value here"},
                                        tasks=(
                                            PregelTask(
                                                AnyStr(),
                                                "grandchild_2",
                                                (PULL, "grandchild_2"),
                                            ),
                                        ),
                                        next=("grandchild_2",),
                                        config={
                                            "configurable": {
                                                "thread_id": "1",
                                                "checkpoint_ns": AnyStr(),
                                                "checkpoint_id": AnyStr(),
                                                "checkpoint_map": AnyDict(
                                                    {
                                                        "": AnyStr(),
                                                        AnyStr("child:"): AnyStr(),
                                                        AnyStr(
                                                            re.compile(r"child:.+|child1:")
                                                        ): AnyStr(),
                                                    }
                                                ),
                                            }
                                        },
                                        metadata={
                                            "parents": AnyDict(
                                                {
                                                    "": AnyStr(),
                                                    AnyStr("child:"): AnyStr(),
                                                }
                                            ),
                                            "source": "loop",
                                            "step": 1,
                                        },
                                        created_at=AnyStr(),
                                        parent_config=None,
                                        interrupts=(),
                                    ),
                                ),
                            ),
                            next=("child_1",),
                            config={
                                "configurable": {
                                    "thread_id": "1",
                                    "checkpoint_ns": AnyStr("child:"),
                                    "checkpoint_id": AnyStr(),
                                    "checkpoint_map": AnyDict(
                                        {"": AnyStr(), AnyStr("child:"): AnyStr()}
                                    ),
                                }
                            },
                            metadata={
                                "parents": {"": AnyStr()},
                                "source": "loop",
                                "step": 0,
                            },
                            created_at=AnyStr(),
                            parent_config=None,
                            interrupts=(),
                        ),
                    ),
                ),
                next=("child",),
                config={
                    "configurable": {
                        "thread_id": "1",
                        "checkpoint_ns": "",
                        "checkpoint_id": AnyStr(),
                    }
                },
                metadata={
                    "parents": {},
                    "source": "loop",
                    "step": 1,
                },
                created_at=AnyStr(),
                parent_config=None,
                interrupts=(),
            )
            # # resume
            assert [c for c in app.stream(None, config, subgraphs=True, durability="exit")] == [
                (
                    (AnyStr("child:"), AnyStr("child_1:")),
                    {"grandchild_2": {"my_key": "hi my value here and there"}},
                ),
                ((AnyStr("child:"),), {"child_1": {"my_key": "hi my value here and there"}}),
                ((), {"child": {"my_key": "hi my value here and there"}}),
                ((), {"parent_2": {"my_key": "hi my value here and there and back again"}}),
            ]
            # get state with and without subgraphs
            assert (
                app.get_state(config)
                == app.get_state(config, subgraphs=True)
                == StateSnapshot(
                    values={"my_key": "hi my value here and there and back again"},
                    tasks=(),
                    next=(),
                    config={
                        "configurable": {
                            "thread_id": "1",
                            "checkpoint_ns": "",
                            "checkpoint_id": AnyStr(),
                        }
                    },
                    metadata={
                        "parents": {},
                        "source": "loop",
                        "step": 3,
                    },
                    created_at=AnyStr(),
                    parent_config=(
                        {
                            "configurable": {
                                "thread_id": "1",
                                "checkpoint_ns": "",
                                "checkpoint_id": AnyStr(),
                            }
                        }
                    ),
                    interrupts=(),
                )
            )

            # get outer graph history
            outer_history = list(app.get_state_history(config))
            assert outer_history == [
                StateSnapshot(
                    values={"my_key": "hi my value here and there and back again"},
                    tasks=(),
                    next=(),
                    config={
                        "configurable": {
                            "thread_id": "1",
                            "checkpoint_ns": "",
                            "checkpoint_id": AnyStr(),
                        }
                    },
                    metadata={
                        "parents": {},
                        "source": "loop",
                        "step": 3,
                    },
                    created_at=AnyStr(),
                    parent_config={
                        "configurable": {
                            "thread_id": "1",
                            "checkpoint_ns": "",
                            "checkpoint_id": AnyStr(),
                        }
                    },
                    interrupts=(),
                ),
                StateSnapshot(
                    values={"my_key": "hi my value"},
                    tasks=(
                        PregelTask(
                            AnyStr(),
                            "child",
                            (PULL, "child"),
                            state={
                                "configurable": {
                                    "thread_id": "1",
                                    "checkpoint_ns": AnyStr("child"),
                                }
                            },
                            result=None,
                        ),
                    ),
                    next=("child",),
                    config={
                        "configurable": {
                            "thread_id": "1",
                            "checkpoint_ns": "",
                            "checkpoint_id": AnyStr(),
                        }
                    },
                    metadata={
                        "parents": {},
                        "source": "loop",
                        "step": 1,
                    },
                    created_at=AnyStr(),
                    parent_config=None,
                    interrupts=(),
                ),
            ]
            # get child graph history
            child_history = list(app.get_state_history(outer_history[1].tasks[0].state))
            assert child_history == [
                StateSnapshot(
                    values={"my_key": "hi my value"},
                    next=("child_1",),
                    config={
                        "configurable": {
                            "thread_id": "1",
                            "checkpoint_ns": AnyStr("child:"),
                            "checkpoint_id": AnyStr(),
                            "checkpoint_map": AnyDict(
                                {"": AnyStr(), AnyStr("child:"): AnyStr()}
                            ),
                        }
                    },
                    metadata={
                        "source": "loop",
                        "step": 0,
                        "parents": {"": AnyStr()},
                    },
                    created_at=AnyStr(),
                    parent_config=None,
                    tasks=(
                        PregelTask(
                            id=AnyStr(),
                            name="child_1",
                            path=(PULL, "child_1"),
                            state={
                                "configurable": {
                                    "thread_id": "1",
                                    "checkpoint_ns": AnyStr("child:"),
                                }
                            },
                            result=None,
                        ),
                    ),
                    interrupts=(),
                ),
            ]
            # get grandchild graph history
            grandchild_history = list(app.get_state_history(child_history[0].tasks[0].state))
            assert grandchild_history == [
                StateSnapshot(
                    values={"my_key": "hi my value here"},
                    next=("grandchild_2",),
                    config={
                        "configurable": {
                            "thread_id": "1",
                            "checkpoint_ns": AnyStr(),
                            "checkpoint_id": AnyStr(),
                            "checkpoint_map": AnyDict(
                                {
                                    "": AnyStr(),
                                    AnyStr("child:"): AnyStr(),
                                    AnyStr(re.compile(r"child:.+|child1:")): AnyStr(),
                                }
                            ),
                        }
                    },
                    metadata={
                        "source": "loop",
                        "step": 1,
                        "parents": AnyDict(
                            {
                                "": AnyStr(),
                                AnyStr("child:"): AnyStr(),
                            }
                        ),
                    },
                    created_at=AnyStr(),
                    parent_config=None,
                    tasks=(
                        PregelTask(
                            id=AnyStr(),
                            name="grandchild_2",
                            path=(PULL, "grandchild_2"),
                            result=None,
                        ),
                    ),
                    interrupts=(),
                ),
            ]
      - |-
        def _msgpack_enc(data: Any) -> bytes:
            return ormsgpack.packb(data, default=_msgpack_default, option=_option)
      - |-
        def test_list_namespaces_operations(
            fake_embeddings: CharacterEmbeddings,
        ) -> None:
            """Test list namespaces functionality with various filters."""
            with create_vector_store(
                fake_embeddings, text_fields=["key0", "key1", "key3"]
            ) as store:
                test_pref = str(uuid.uuid4())
                test_namespaces = [
                    (test_pref, "test", "documents", "public", test_pref),
                    (test_pref, "test", "documents", "private", test_pref),
                    (test_pref, "test", "images", "public", test_pref),
                    (test_pref, "test", "images", "private", test_pref),
                    (test_pref, "prod", "documents", "public", test_pref),
                    (test_pref, "prod", "documents", "some", "nesting", "public", test_pref),
                    (test_pref, "prod", "documents", "private", test_pref),
                ]

                # Add test data
                for namespace in test_namespaces:
                    store.put(namespace, "dummy", {"content": "dummy"})

                # Test prefix filtering
                prefix_result = store.list_namespaces(prefix=(test_pref, "test"))
                assert len(prefix_result) == 4
                assert all(ns[1] == "test" for ns in prefix_result)

                # Test specific prefix
                specific_prefix_result = store.list_namespaces(
                    prefix=(test_pref, "test", "documents")
                )
                assert len(specific_prefix_result) == 2
                assert all(ns[1:3] == ("test", "documents") for ns in specific_prefix_result)

                # Test suffix filtering
                suffix_result = store.list_namespaces(suffix=("public", test_pref))
                assert len(suffix_result) == 4
                assert all(ns[-2] == "public" for ns in suffix_result)

                # Test combined prefix and suffix
                prefix_suffix_result = store.list_namespaces(
                    prefix=(test_pref, "test"), suffix=("public", test_pref)
                )
                assert len(prefix_suffix_result) == 2
                assert all(
                    ns[1] == "test" and ns[-2] == "public" for ns in prefix_suffix_result
                )

                # Test wildcard in prefix
                wildcard_prefix_result = store.list_namespaces(
                    prefix=(test_pref, "*", "documents")
                )
                assert len(wildcard_prefix_result) == 5
                assert all(ns[2] == "documents" for ns in wildcard_prefix_result)

                # Test wildcard in suffix
                wildcard_suffix_result = store.list_namespaces(
                    suffix=("*", "public", test_pref)
                )
                assert len(wildcard_suffix_result) == 4
                assert all(ns[-2] == "public" for ns in wildcard_suffix_result)

                wildcard_single = store.list_namespaces(
                    suffix=("some", "*", "public", test_pref)
                )
                assert len(wildcard_single) == 1
                assert wildcard_single[0] == (
                    test_pref,
                    "prod",
                    "documents",
                    "some",
                    "nesting",
                    "public",
                    test_pref,
                )

                # Test max depth
                max_depth_result = store.list_namespaces(max_depth=3)
                assert all(len(ns) <= 3 for ns in max_depth_result)

                max_depth_result = store.list_namespaces(
                    max_depth=4, prefix=(test_pref, "*", "documents")
                )
                assert len(set(res for res in max_depth_result)) == len(max_depth_result) == 5

                # Test pagination
                limit_result = store.list_namespaces(prefix=(test_pref,), limit=3)
                assert len(limit_result) == 3

                offset_result = store.list_namespaces(prefix=(test_pref,), offset=3)
                assert len(offset_result) == len(test_namespaces) - 3

                empty_prefix_result = store.list_namespaces(prefix=(test_pref,))
                assert len(empty_prefix_result) == len(test_namespaces)
                assert set(empty_prefix_result) == set(test_namespaces)

                # Clean up
                for namespace in test_namespaces:
                    store.delete(namespace, "dummy")
pipeline_tag: sentence-similarity
library_name: sentence-transformers
metrics:
  - cosine_accuracy@1
  - cosine_accuracy@3
  - cosine_accuracy@5
  - cosine_accuracy@10
  - cosine_precision@1
  - cosine_precision@3
  - cosine_precision@5
  - cosine_precision@10
  - cosine_recall@1
  - cosine_recall@3
  - cosine_recall@5
  - cosine_recall@10
  - cosine_ndcg@10
  - cosine_mrr@10
  - cosine_map@100
model-index:
  - name: codeBert dense retriever
    results:
      - task:
          type: information-retrieval
          name: Information Retrieval
        dataset:
          name: dim 768
          type: dim_768
        metrics:
          - type: cosine_accuracy@1
            value: 0.9
            name: Cosine Accuracy@1
          - type: cosine_accuracy@3
            value: 0.9
            name: Cosine Accuracy@3
          - type: cosine_accuracy@5
            value: 1
            name: Cosine Accuracy@5
          - type: cosine_accuracy@10
            value: 1
            name: Cosine Accuracy@10
          - type: cosine_precision@1
            value: 0.9
            name: Cosine Precision@1
          - type: cosine_precision@3
            value: 0.29999999999999993
            name: Cosine Precision@3
          - type: cosine_precision@5
            value: 0.20000000000000004
            name: Cosine Precision@5
          - type: cosine_precision@10
            value: 0.10000000000000002
            name: Cosine Precision@10
          - type: cosine_recall@1
            value: 0.9
            name: Cosine Recall@1
          - type: cosine_recall@3
            value: 0.9
            name: Cosine Recall@3
          - type: cosine_recall@5
            value: 1
            name: Cosine Recall@5
          - type: cosine_recall@10
            value: 1
            name: Cosine Recall@10
          - type: cosine_ndcg@10
            value: 0.9408764682653967
            name: Cosine Ndcg@10
          - type: cosine_mrr@10
            value: 0.9225
            name: Cosine Mrr@10
          - type: cosine_map@100
            value: 0.9225
            name: Cosine Map@100
      - task:
          type: information-retrieval
          name: Information Retrieval
        dataset:
          name: dim 512
          type: dim_512
        metrics:
          - type: cosine_accuracy@1
            value: 0.9
            name: Cosine Accuracy@1
          - type: cosine_accuracy@3
            value: 0.9
            name: Cosine Accuracy@3
          - type: cosine_accuracy@5
            value: 1
            name: Cosine Accuracy@5
          - type: cosine_accuracy@10
            value: 1
            name: Cosine Accuracy@10
          - type: cosine_precision@1
            value: 0.9
            name: Cosine Precision@1
          - type: cosine_precision@3
            value: 0.29999999999999993
            name: Cosine Precision@3
          - type: cosine_precision@5
            value: 0.20000000000000004
            name: Cosine Precision@5
          - type: cosine_precision@10
            value: 0.10000000000000002
            name: Cosine Precision@10
          - type: cosine_recall@1
            value: 0.9
            name: Cosine Recall@1
          - type: cosine_recall@3
            value: 0.9
            name: Cosine Recall@3
          - type: cosine_recall@5
            value: 1
            name: Cosine Recall@5
          - type: cosine_recall@10
            value: 1
            name: Cosine Recall@10
          - type: cosine_ndcg@10
            value: 0.9408764682653967
            name: Cosine Ndcg@10
          - type: cosine_mrr@10
            value: 0.9225
            name: Cosine Mrr@10
          - type: cosine_map@100
            value: 0.9225
            name: Cosine Map@100
      - task:
          type: information-retrieval
          name: Information Retrieval
        dataset:
          name: dim 256
          type: dim_256
        metrics:
          - type: cosine_accuracy@1
            value: 0.9
            name: Cosine Accuracy@1
          - type: cosine_accuracy@3
            value: 0.9
            name: Cosine Accuracy@3
          - type: cosine_accuracy@5
            value: 1
            name: Cosine Accuracy@5
          - type: cosine_accuracy@10
            value: 1
            name: Cosine Accuracy@10
          - type: cosine_precision@1
            value: 0.9
            name: Cosine Precision@1
          - type: cosine_precision@3
            value: 0.29999999999999993
            name: Cosine Precision@3
          - type: cosine_precision@5
            value: 0.20000000000000004
            name: Cosine Precision@5
          - type: cosine_precision@10
            value: 0.10000000000000002
            name: Cosine Precision@10
          - type: cosine_recall@1
            value: 0.9
            name: Cosine Recall@1
          - type: cosine_recall@3
            value: 0.9
            name: Cosine Recall@3
          - type: cosine_recall@5
            value: 1
            name: Cosine Recall@5
          - type: cosine_recall@10
            value: 1
            name: Cosine Recall@10
          - type: cosine_ndcg@10
            value: 0.9408764682653967
            name: Cosine Ndcg@10
          - type: cosine_mrr@10
            value: 0.9225
            name: Cosine Mrr@10
          - type: cosine_map@100
            value: 0.9225
            name: Cosine Map@100
      - task:
          type: information-retrieval
          name: Information Retrieval
        dataset:
          name: dim 128
          type: dim_128
        metrics:
          - type: cosine_accuracy@1
            value: 0.85
            name: Cosine Accuracy@1
          - type: cosine_accuracy@3
            value: 0.9
            name: Cosine Accuracy@3
          - type: cosine_accuracy@5
            value: 0.95
            name: Cosine Accuracy@5
          - type: cosine_accuracy@10
            value: 0.95
            name: Cosine Accuracy@10
          - type: cosine_precision@1
            value: 0.85
            name: Cosine Precision@1
          - type: cosine_precision@3
            value: 0.29999999999999993
            name: Cosine Precision@3
          - type: cosine_precision@5
            value: 0.19000000000000003
            name: Cosine Precision@5
          - type: cosine_precision@10
            value: 0.09500000000000001
            name: Cosine Precision@10
          - type: cosine_recall@1
            value: 0.85
            name: Cosine Recall@1
          - type: cosine_recall@3
            value: 0.9
            name: Cosine Recall@3
          - type: cosine_recall@5
            value: 0.95
            name: Cosine Recall@5
          - type: cosine_recall@10
            value: 0.95
            name: Cosine Recall@10
          - type: cosine_ndcg@10
            value: 0.894342640361727
            name: Cosine Ndcg@10
          - type: cosine_mrr@10
            value: 0.8766666666666666
            name: Cosine Mrr@10
          - type: cosine_map@100
            value: 0.8799999999999999
            name: Cosine Map@100
      - task:
          type: information-retrieval
          name: Information Retrieval
        dataset:
          name: dim 64
          type: dim_64
        metrics:
          - type: cosine_accuracy@1
            value: 0.85
            name: Cosine Accuracy@1
          - type: cosine_accuracy@3
            value: 0.9
            name: Cosine Accuracy@3
          - type: cosine_accuracy@5
            value: 0.9
            name: Cosine Accuracy@5
          - type: cosine_accuracy@10
            value: 1
            name: Cosine Accuracy@10
          - type: cosine_precision@1
            value: 0.85
            name: Cosine Precision@1
          - type: cosine_precision@3
            value: 0.29999999999999993
            name: Cosine Precision@3
          - type: cosine_precision@5
            value: 0.18000000000000005
            name: Cosine Precision@5
          - type: cosine_precision@10
            value: 0.10000000000000002
            name: Cosine Precision@10
          - type: cosine_recall@1
            value: 0.85
            name: Cosine Recall@1
          - type: cosine_recall@3
            value: 0.9
            name: Cosine Recall@3
          - type: cosine_recall@5
            value: 0.9
            name: Cosine Recall@5
          - type: cosine_recall@10
            value: 1
            name: Cosine Recall@10
          - type: cosine_ndcg@10
            value: 0.9074399105059531
            name: Cosine Ndcg@10
          - type: cosine_mrr@10
            value: 0.8800595238095237
            name: Cosine Mrr@10
          - type: cosine_map@100
            value: 0.8800595238095237
            name: Cosine Map@100

codeBert dense retriever

This is a sentence-transformers model finetuned from shubharuidas/codebert-embed-base-dense-retriever. It maps sentences & paragraphs to a 768-dimensional dense vector space and can be used for semantic textual similarity, semantic search, paraphrase mining, text classification, clustering, and more.
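All of the retrieval metrics reported below are cosine-based, so it is the angle between embedding vectors, not their magnitude, that drives similarity search. As a minimal sketch of that scoring (plain NumPy, independent of the model; toy 4-dimensional vectors stand in for the real 768-dimensional embeddings):

```python
import numpy as np

def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine similarity between two dense embedding vectors."""
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

# Toy stand-ins for 768-dimensional sentence embeddings.
query = np.array([1.0, 0.0, 1.0, 0.0])
match = np.array([2.0, 0.0, 2.0, 0.0])  # same direction as the query
other = np.array([0.0, 1.0, 0.0, 1.0])  # orthogonal to the query

print(round(cosine_similarity(query, match), 6))  # → 1.0 (same direction)
print(round(cosine_similarity(query, other), 6))  # → 0.0 (orthogonal)
```

Note that `match` scores 1.0 despite being twice as long as `query`: scaling a vector does not change its cosine similarity to anything.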

Model Details

Model Description

  • Model Type: Sentence Transformer
  • Base model: shubharuidas/codebert-embed-base-dense-retriever
  • Maximum Sequence Length: 512 tokens
  • Output Dimensionality: 768 dimensions
  • Similarity Function: Cosine Similarity
  • Language: en
  • License: apache-2.0

Model Sources

Full Model Architecture

SentenceTransformer(
  (0): Transformer({'max_seq_length': 512, 'do_lower_case': False, 'architecture': 'RobertaModel'})
  (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
)
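The Pooling module above uses mean pooling (`pooling_mode_mean_tokens: True`): the transformer's per-token embeddings are averaged into a single sentence vector, counting only non-padding positions. A minimal NumPy sketch of that step (toy shapes; the real model pools 768-dimensional token embeddings):

```python
import numpy as np

# Toy token embeddings for a 5-token sequence (dim 4 stands in for 768).
token_embeddings = np.arange(20, dtype=float).reshape(5, 4)
# attention_mask: 1 for real tokens, 0 for padding.
attention_mask = np.array([1, 1, 1, 0, 0], dtype=float)

# Mean pooling: sum the masked token embeddings, divide by the token count.
mask = attention_mask[:, None]                                   # shape (5, 1)
sentence_embedding = (token_embeddings * mask).sum(axis=0) / mask.sum()

print(sentence_embedding)  # → [4. 5. 6. 7.] (mean of the first three rows)
```

Masking before averaging is what keeps padding tokens from diluting the sentence embedding when inputs in a batch have different lengths.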

Usage

Direct Usage (Sentence Transformers)

First install the Sentence Transformers library:

pip install -U sentence-transformers

Then you can load this model and run inference.

from sentence_transformers import SentenceTransformer

# Download from the 🤗 Hub
model = SentenceTransformer("anaghaj111/codebert-base-code-embed-mrl-langchain-langgraph")
# Run inference
sentences = [
    'Best practices for test_list_namespaces_operations',
    'def test_list_namespaces_operations(\n    fake_embeddings: CharacterEmbeddings,\n) -> None:\n    """Test list namespaces functionality with various filters."""\n    with create_vector_store(\n        fake_embeddings, text_fields=["key0", "key1", "key3"]\n    ) as store:\n        test_pref = str(uuid.uuid4())\n        test_namespaces = [\n            (test_pref, "test", "documents", "public", test_pref),\n            (test_pref, "test", "documents", "private", test_pref),\n            (test_pref, "test", "images", "public", test_pref),\n            (test_pref, "test", "images", "private", test_pref),\n            (test_pref, "prod", "documents", "public", test_pref),\n            (test_pref, "prod", "documents", "some", "nesting", "public", test_pref),\n            (test_pref, "prod", "documents", "private", test_pref),\n        ]\n\n        # Add test data\n        for namespace in test_namespaces:\n            store.put(namespace, "dummy", {"content": "dummy"})\n\n        # Test prefix filtering\n        prefix_result = store.list_namespaces(prefix=(test_pref, "test"))\n        assert len(prefix_result) == 4\n        assert all(ns[1] == "test" for ns in prefix_result)\n\n        # Test specific prefix\n        specific_prefix_result = store.list_namespaces(\n            prefix=(test_pref, "test", "documents")\n        )\n        assert len(specific_prefix_result) == 2\n        assert all(ns[1:3] == ("test", "documents") for ns in specific_prefix_result)\n\n        # Test suffix filtering\n        suffix_result = store.list_namespaces(suffix=("public", test_pref))\n        assert len(suffix_result) == 4\n        assert all(ns[-2] == "public" for ns in suffix_result)\n\n        # Test combined prefix and suffix\n        prefix_suffix_result = store.list_namespaces(\n            prefix=(test_pref, "test"), suffix=("public", test_pref)\n        )\n        assert len(prefix_suffix_result) == 2\n        assert all(\n            ns[1] == "test" and ns[-2] == 
"public" for ns in prefix_suffix_result\n        )\n\n        # Test wildcard in prefix\n        wildcard_prefix_result = store.list_namespaces(\n            prefix=(test_pref, "*", "documents")\n        )\n        assert len(wildcard_prefix_result) == 5\n        assert all(ns[2] == "documents" for ns in wildcard_prefix_result)\n\n        # Test wildcard in suffix\n        wildcard_suffix_result = store.list_namespaces(\n            suffix=("*", "public", test_pref)\n        )\n        assert len(wildcard_suffix_result) == 4\n        assert all(ns[-2] == "public" for ns in wildcard_suffix_result)\n\n        wildcard_single = store.list_namespaces(\n            suffix=("some", "*", "public", test_pref)\n        )\n        assert len(wildcard_single) == 1\n        assert wildcard_single[0] == (\n            test_pref,\n            "prod",\n            "documents",\n            "some",\n            "nesting",\n            "public",\n            test_pref,\n        )\n\n        # Test max depth\n        max_depth_result = store.list_namespaces(max_depth=3)\n        assert all(len(ns) <= 3 for ns in max_depth_result)\n\n        max_depth_result = store.list_namespaces(\n            max_depth=4, prefix=(test_pref, "*", "documents")\n        )\n        assert len(set(res for res in max_depth_result)) == len(max_depth_result) == 5\n\n        # Test pagination\n        limit_result = store.list_namespaces(prefix=(test_pref,), limit=3)\n        assert len(limit_result) == 3\n\n        offset_result = store.list_namespaces(prefix=(test_pref,), offset=3)\n        assert len(offset_result) == len(test_namespaces) - 3\n\n        empty_prefix_result = store.list_namespaces(prefix=(test_pref,))\n        assert len(empty_prefix_result) == len(test_namespaces)\n        assert set(empty_prefix_result) == set(test_namespaces)\n\n        # Clean up\n        for namespace in test_namespaces:\n            store.delete(namespace, "dummy")',
    'def test_doubly_nested_graph_state(\n    sync_checkpointer: BaseCheckpointSaver,\n) -> None:\n    class State(TypedDict):\n        my_key: str\n\n    class ChildState(TypedDict):\n        my_key: str\n\n    class GrandChildState(TypedDict):\n        my_key: str\n\n    def grandchild_1(state: ChildState):\n        return {"my_key": state["my_key"] + " here"}\n\n    def grandchild_2(state: ChildState):\n        return {\n            "my_key": state["my_key"] + " and there",\n        }\n\n    grandchild = StateGraph(GrandChildState)\n    grandchild.add_node("grandchild_1", grandchild_1)\n    grandchild.add_node("grandchild_2", grandchild_2)\n    grandchild.add_edge("grandchild_1", "grandchild_2")\n    grandchild.set_entry_point("grandchild_1")\n    grandchild.set_finish_point("grandchild_2")\n\n    child = StateGraph(ChildState)\n    child.add_node(\n        "child_1",\n        grandchild.compile(interrupt_before=["grandchild_2"]),\n    )\n    child.set_entry_point("child_1")\n    child.set_finish_point("child_1")\n\n    def parent_1(state: State):\n        return {"my_key": "hi " + state["my_key"]}\n\n    def parent_2(state: State):\n        return {"my_key": state["my_key"] + " and back again"}\n\n    graph = StateGraph(State)\n    graph.add_node("parent_1", parent_1)\n    graph.add_node("child", child.compile())\n    graph.add_node("parent_2", parent_2)\n    graph.set_entry_point("parent_1")\n    graph.add_edge("parent_1", "child")\n    graph.add_edge("child", "parent_2")\n    graph.set_finish_point("parent_2")\n\n    app = graph.compile(checkpointer=sync_checkpointer)\n\n    # test invoke w/ nested interrupt\n    config = {"configurable": {"thread_id": "1"}}\n    assert [\n        c\n        for c in app.stream(\n            {"my_key": "my value"}, config, subgraphs=True, durability="exit"\n        )\n    ] == [\n        ((), {"parent_1": {"my_key": "hi my value"}}),\n        (\n            (AnyStr("child:"), AnyStr("child_1:")),\n            {"grandchild_1": 
{"my_key": "hi my value here"}},\n        ),\n        ((), {"__interrupt__": ()}),\n    ]\n    # get state without subgraphs\n    outer_state = app.get_state(config)\n    assert outer_state == StateSnapshot(\n        values={"my_key": "hi my value"},\n        tasks=(\n            PregelTask(\n                AnyStr(),\n                "child",\n                (PULL, "child"),\n                state={\n                    "configurable": {\n                        "thread_id": "1",\n                        "checkpoint_ns": AnyStr("child"),\n                    }\n                },\n            ),\n        ),\n        next=("child",),\n        config={\n            "configurable": {\n                "thread_id": "1",\n                "checkpoint_ns": "",\n                "checkpoint_id": AnyStr(),\n            }\n        },\n        metadata={\n            "parents": {},\n            "source": "loop",\n            "step": 1,\n        },\n        created_at=AnyStr(),\n        parent_config=None,\n        interrupts=(),\n    )\n    child_state = app.get_state(outer_state.tasks[0].state)\n    assert child_state == StateSnapshot(\n        values={"my_key": "hi my value"},\n        tasks=(\n            PregelTask(\n                AnyStr(),\n                "child_1",\n                (PULL, "child_1"),\n                state={\n                    "configurable": {\n                        "thread_id": "1",\n                        "checkpoint_ns": AnyStr(),\n                    }\n                },\n            ),\n        ),\n        next=("child_1",),\n        config={\n            "configurable": {\n                "thread_id": "1",\n                "checkpoint_ns": AnyStr("child:"),\n                "checkpoint_id": AnyStr(),\n                "checkpoint_map": AnyDict(\n                    {\n                        "": AnyStr(),\n                        AnyStr("child:"): AnyStr(),\n                    }\n                ),\n            }\n        },\n        
metadata={\n            "parents": {"": AnyStr()},\n            "source": "loop",\n            "step": 0,\n        },\n        created_at=AnyStr(),\n        parent_config=None,\n        interrupts=(),\n    )\n    grandchild_state = app.get_state(child_state.tasks[0].state)\n    assert grandchild_state == StateSnapshot(\n        values={"my_key": "hi my value here"},\n        tasks=(\n            PregelTask(\n                AnyStr(),\n                "grandchild_2",\n                (PULL, "grandchild_2"),\n            ),\n        ),\n        next=("grandchild_2",),\n        config={\n            "configurable": {\n                "thread_id": "1",\n                "checkpoint_ns": AnyStr(),\n                "checkpoint_id": AnyStr(),\n                "checkpoint_map": AnyDict(\n                    {\n                        "": AnyStr(),\n                        AnyStr("child:"): AnyStr(),\n                        AnyStr(re.compile(r"child:.+|child1:")): AnyStr(),\n                    }\n                ),\n            }\n        },\n        metadata={\n            "parents": AnyDict(\n                {\n                    "": AnyStr(),\n                    AnyStr("child:"): AnyStr(),\n                }\n            ),\n            "source": "loop",\n            "step": 1,\n        },\n        created_at=AnyStr(),\n        parent_config=None,\n        interrupts=(),\n    )\n    # get state with subgraphs\n    assert app.get_state(config, subgraphs=True) == StateSnapshot(\n        values={"my_key": "hi my value"},\n        tasks=(\n            PregelTask(\n                AnyStr(),\n                "child",\n                (PULL, "child"),\n                state=StateSnapshot(\n                    values={"my_key": "hi my value"},\n                    tasks=(\n                        PregelTask(\n                            AnyStr(),\n                            "child_1",\n                            (PULL, "child_1"),\n                            
state=StateSnapshot(\n                                values={"my_key": "hi my value here"},\n                                tasks=(\n                                    PregelTask(\n                                        AnyStr(),\n                                        "grandchild_2",\n                                        (PULL, "grandchild_2"),\n                                    ),\n                                ),\n                                next=("grandchild_2",),\n                                config={\n                                    "configurable": {\n                                        "thread_id": "1",\n                                        "checkpoint_ns": AnyStr(),\n                                        "checkpoint_id": AnyStr(),\n                                        "checkpoint_map": AnyDict(\n                                            {\n                                                "": AnyStr(),\n                                                AnyStr("child:"): AnyStr(),\n                                                AnyStr(\n                                                    re.compile(r"child:.+|child1:")\n                                                ): AnyStr(),\n                                            }\n                                        ),\n                                    }\n                                },\n                                metadata={\n                                    "parents": AnyDict(\n                                        {\n                                            "": AnyStr(),\n                                            AnyStr("child:"): AnyStr(),\n                                        }\n                                    ),\n                                    "source": "loop",\n                                    "step": 1,\n                                },\n                                created_at=AnyStr(),\n                                
parent_config=None,\n                                interrupts=(),\n                            ),\n                        ),\n                    ),\n                    next=("child_1",),\n                    config={\n                        "configurable": {\n                            "thread_id": "1",\n                            "checkpoint_ns": AnyStr("child:"),\n                            "checkpoint_id": AnyStr(),\n                            "checkpoint_map": AnyDict(\n                                {"": AnyStr(), AnyStr("child:"): AnyStr()}\n                            ),\n                        }\n                    },\n                    metadata={\n                        "parents": {"": AnyStr()},\n                        "source": "loop",\n                        "step": 0,\n                    },\n                    created_at=AnyStr(),\n                    parent_config=None,\n                    interrupts=(),\n                ),\n            ),\n        ),\n        next=("child",),\n        config={\n            "configurable": {\n                "thread_id": "1",\n                "checkpoint_ns": "",\n                "checkpoint_id": AnyStr(),\n            }\n        },\n        metadata={\n            "parents": {},\n            "source": "loop",\n            "step": 1,\n        },\n        created_at=AnyStr(),\n        parent_config=None,\n        interrupts=(),\n    )\n    # # resume\n    assert [c for c in app.stream(None, config, subgraphs=True, durability="exit")] == [\n        (\n            (AnyStr("child:"), AnyStr("child_1:")),\n            {"grandchild_2": {"my_key": "hi my value here and there"}},\n        ),\n        ((AnyStr("child:"),), {"child_1": {"my_key": "hi my value here and there"}}),\n        ((), {"child": {"my_key": "hi my value here and there"}}),\n        ((), {"parent_2": {"my_key": "hi my value here and there and back again"}}),\n    ]\n    # get state with and without subgraphs\n    assert (\n        
app.get_state(config)\n        == app.get_state(config, subgraphs=True)\n        == StateSnapshot(\n            values={"my_key": "hi my value here and there and back again"},\n            tasks=(),\n            next=(),\n            config={\n                "configurable": {\n                    "thread_id": "1",\n                    "checkpoint_ns": "",\n                    "checkpoint_id": AnyStr(),\n                }\n            },\n            metadata={\n                "parents": {},\n                "source": "loop",\n                "step": 3,\n            },\n            created_at=AnyStr(),\n            parent_config=(\n                {\n                    "configurable": {\n                        "thread_id": "1",\n                        "checkpoint_ns": "",\n                        "checkpoint_id": AnyStr(),\n                    }\n                }\n            ),\n            interrupts=(),\n        )\n    )\n\n    # get outer graph history\n    outer_history = list(app.get_state_history(config))\n    assert outer_history == [\n        StateSnapshot(\n            values={"my_key": "hi my value here and there and back again"},\n            tasks=(),\n            next=(),\n            config={\n                "configurable": {\n                    "thread_id": "1",\n                    "checkpoint_ns": "",\n                    "checkpoint_id": AnyStr(),\n                }\n            },\n            metadata={\n                "parents": {},\n                "source": "loop",\n                "step": 3,\n            },\n            created_at=AnyStr(),\n            parent_config={\n                "configurable": {\n                    "thread_id": "1",\n                    "checkpoint_ns": "",\n                    "checkpoint_id": AnyStr(),\n                }\n            },\n            interrupts=(),\n        ),\n        StateSnapshot(\n            values={"my_key": "hi my value"},\n            tasks=(\n                PregelTask(\n          
          AnyStr(),\n                    "child",\n                    (PULL, "child"),\n                    state={\n                        "configurable": {\n                            "thread_id": "1",\n                            "checkpoint_ns": AnyStr("child"),\n                        }\n                    },\n                    result=None,\n                ),\n            ),\n            next=("child",),\n            config={\n                "configurable": {\n                    "thread_id": "1",\n                    "checkpoint_ns": "",\n                    "checkpoint_id": AnyStr(),\n                }\n            },\n            metadata={\n                "parents": {},\n                "source": "loop",\n                "step": 1,\n            },\n            created_at=AnyStr(),\n            parent_config=None,\n            interrupts=(),\n        ),\n    ]\n    # get child graph history\n    child_history = list(app.get_state_history(outer_history[1].tasks[0].state))\n    assert child_history == [\n        StateSnapshot(\n            values={"my_key": "hi my value"},\n            next=("child_1",),\n            config={\n                "configurable": {\n                    "thread_id": "1",\n                    "checkpoint_ns": AnyStr("child:"),\n                    "checkpoint_id": AnyStr(),\n                    "checkpoint_map": AnyDict(\n                        {"": AnyStr(), AnyStr("child:"): AnyStr()}\n                    ),\n                }\n            },\n            metadata={\n                "source": "loop",\n                "step": 0,\n                "parents": {"": AnyStr()},\n            },\n            created_at=AnyStr(),\n            parent_config=None,\n            tasks=(\n                PregelTask(\n                    id=AnyStr(),\n                    name="child_1",\n                    path=(PULL, "child_1"),\n                    state={\n                        "configurable": {\n                            
"thread_id": "1",\n                            "checkpoint_ns": AnyStr("child:"),\n                        }\n                    },\n                    result=None,\n                ),\n            ),\n            interrupts=(),\n        ),\n    ]\n    # get grandchild graph history\n    grandchild_history = list(app.get_state_history(child_history[0].tasks[0].state))\n    assert grandchild_history == [\n        StateSnapshot(\n            values={"my_key": "hi my value here"},\n            next=("grandchild_2",),\n            config={\n                "configurable": {\n                    "thread_id": "1",\n                    "checkpoint_ns": AnyStr(),\n                    "checkpoint_id": AnyStr(),\n                    "checkpoint_map": AnyDict(\n                        {\n                            "": AnyStr(),\n                            AnyStr("child:"): AnyStr(),\n                            AnyStr(re.compile(r"child:.+|child1:")): AnyStr(),\n                        }\n                    ),\n                }\n            },\n            metadata={\n                "source": "loop",\n                "step": 1,\n                "parents": AnyDict(\n                    {\n                        "": AnyStr(),\n                        AnyStr("child:"): AnyStr(),\n                    }\n                ),\n            },\n            created_at=AnyStr(),\n            parent_config=None,\n            tasks=(\n                PregelTask(\n                    id=AnyStr(),\n                    name="grandchild_2",\n                    path=(PULL, "grandchild_2"),\n                    result=None,\n                ),\n            ),\n            interrupts=(),\n        ),\n    ]',
]
embeddings = model.encode(sentences)
print(embeddings.shape)
# (3, 768)

# Get the similarity scores for the embeddings
similarities = model.similarity(embeddings, embeddings)
print(similarities)
# tensor([[1.0000, 0.7789, 0.3589],
#         [0.7789, 1.0000, 0.4748],
#         [0.3589, 0.4748, 1.0000]])
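`model.similarity` returns the pairwise cosine similarities between the embeddings, which is why the matrix above is symmetric with 1.0 on the diagonal. As a sketch of what it computes (toy 3-d vectors stand in for the real 768-d embeddings; this is not the model's own code):

```python
import math

def cosine_similarity(u: list[float], v: list[float]) -> float:
    """Dot product of the two vectors divided by the product of their norms."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

# Toy 3-d vectors standing in for the 768-d embeddings above
vectors = [[1.0, 0.0, 1.0], [0.9, 0.1, 0.8], [0.0, 1.0, 0.0]]
sims = [[cosine_similarity(u, v) for v in vectors] for u in vectors]
```

Since SentenceTransformer embeddings for this model are compared with cosine similarity, nearer-duplicate inputs score close to 1.0 and unrelated ones near 0.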

Evaluation

Metrics

Information Retrieval (dim_768)

| Metric              | Value  |
|:--------------------|:-------|
| cosine_accuracy@1   | 0.9    |
| cosine_accuracy@3   | 0.9    |
| cosine_accuracy@5   | 1.0    |
| cosine_accuracy@10  | 1.0    |
| cosine_precision@1  | 0.9    |
| cosine_precision@3  | 0.3    |
| cosine_precision@5  | 0.2    |
| cosine_precision@10 | 0.1    |
| cosine_recall@1     | 0.9    |
| cosine_recall@3     | 0.9    |
| cosine_recall@5     | 1.0    |
| cosine_recall@10    | 1.0    |
| cosine_ndcg@10      | 0.9409 |
| cosine_mrr@10       | 0.9225 |
| cosine_map@100      | 0.9225 |

Information Retrieval (dim_512)

| Metric              | Value  |
|:--------------------|:-------|
| cosine_accuracy@1   | 0.9    |
| cosine_accuracy@3   | 0.9    |
| cosine_accuracy@5   | 1.0    |
| cosine_accuracy@10  | 1.0    |
| cosine_precision@1  | 0.9    |
| cosine_precision@3  | 0.3    |
| cosine_precision@5  | 0.2    |
| cosine_precision@10 | 0.1    |
| cosine_recall@1     | 0.9    |
| cosine_recall@3     | 0.9    |
| cosine_recall@5     | 1.0    |
| cosine_recall@10    | 1.0    |
| cosine_ndcg@10      | 0.9409 |
| cosine_mrr@10       | 0.9225 |
| cosine_map@100      | 0.9225 |

Information Retrieval (dim_256)

| Metric              | Value  |
|:--------------------|:-------|
| cosine_accuracy@1   | 0.9    |
| cosine_accuracy@3   | 0.9    |
| cosine_accuracy@5   | 1.0    |
| cosine_accuracy@10  | 1.0    |
| cosine_precision@1  | 0.9    |
| cosine_precision@3  | 0.3    |
| cosine_precision@5  | 0.2    |
| cosine_precision@10 | 0.1    |
| cosine_recall@1     | 0.9    |
| cosine_recall@3     | 0.9    |
| cosine_recall@5     | 1.0    |
| cosine_recall@10    | 1.0    |
| cosine_ndcg@10      | 0.9409 |
| cosine_mrr@10       | 0.9225 |
| cosine_map@100      | 0.9225 |

Information Retrieval (dim_128)

| Metric              | Value  |
|:--------------------|:-------|
| cosine_accuracy@1   | 0.85   |
| cosine_accuracy@3   | 0.9    |
| cosine_accuracy@5   | 0.95   |
| cosine_accuracy@10  | 0.95   |
| cosine_precision@1  | 0.85   |
| cosine_precision@3  | 0.3    |
| cosine_precision@5  | 0.19   |
| cosine_precision@10 | 0.095  |
| cosine_recall@1     | 0.85   |
| cosine_recall@3     | 0.9    |
| cosine_recall@5     | 0.95   |
| cosine_recall@10    | 0.95   |
| cosine_ndcg@10      | 0.8943 |
| cosine_mrr@10       | 0.8767 |
| cosine_map@100      | 0.88   |

Information Retrieval (dim_64)

| Metric              | Value  |
|:--------------------|:-------|
| cosine_accuracy@1   | 0.85   |
| cosine_accuracy@3   | 0.9    |
| cosine_accuracy@5   | 0.9    |
| cosine_accuracy@10  | 1.0    |
| cosine_precision@1  | 0.85   |
| cosine_precision@3  | 0.3    |
| cosine_precision@5  | 0.18   |
| cosine_precision@10 | 0.1    |
| cosine_recall@1     | 0.85   |
| cosine_recall@3     | 0.9    |
| cosine_recall@5     | 0.9    |
| cosine_recall@10    | 1.0    |
| cosine_ndcg@10      | 0.9074 |
| cosine_mrr@10       | 0.8801 |
| cosine_map@100      | 0.8801 |
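Each anchor in the evaluation set has exactly one relevant positive, which is why precision@k is capped at 1/k (precision@5 reads 0.2 whenever accuracy@5 is 1.0) and recall@k equals accuracy@k. A minimal sketch of those two metrics under that single-relevant-document assumption (the document ids and ranking here are made up for illustration):

```python
def precision_at_k(ranked_ids: list[str], relevant_id: str, k: int) -> float:
    """With one relevant document per query, precision@k is hits_in_top_k / k."""
    hits = 1 if relevant_id in ranked_ids[:k] else 0
    return hits / k

def recall_at_k(ranked_ids: list[str], relevant_id: str, k: int) -> float:
    """With one relevant document, recall@k is 1.0 if it appears in the top k, else 0.0."""
    return 1.0 if relevant_id in ranked_ids[:k] else 0.0

# Hypothetical retrieval result for one query whose relevant document is "d3"
ranked = ["d3", "d7", "d1", "d9", "d2"]
```

Averaged over all queries, these per-query values produce the table entries above, e.g. `precision_at_k(ranked, "d3", 5)` is 0.2 for a query whose positive lands in the top 5.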

Training Details

Training Dataset

Unnamed Dataset

  • Size: 180 training samples
  • Columns: anchor and positive
  • Approximate statistics based on the first 180 samples:

    |         | anchor | positive |
    |:--------|:-------|:---------|
    | type    | string | string |
    | details | min: 6 tokens, mean: 12.34 tokens, max: 117 tokens | min: 14 tokens, mean: 273.18 tokens, max: 512 tokens |
  • Samples:

    anchor: How to implement State?
    positive:
    ```python
    class State(TypedDict):
        messages: Annotated[list[str], operator.add]
    ```

    anchor: Best practices for test_sql_injection_vulnerability
    positive:
    ```python
    def test_sql_injection_vulnerability(store: SqliteStore) -> None:
        """Test that SQL injection via malicious filter keys is prevented."""
        # Add public and private documents
        store.put(("docs",), "public", {"access": "public", "data": "public info"})
        store.put(
            ("docs",), "private", {"access": "private", "data": "secret", "password": "123"}
        )

        # Normal query - returns 1 public document
        normal = store.search(("docs",), filter={"access": "public"})
        assert len(normal) == 1
        assert normal[0].value["access"] == "public"

        # SQL injection attempt via malicious key should raise ValueError
        malicious_key = "access') = 'public' OR '1'='1' OR json_extract(value, '$."

        with pytest.raises(ValueError, match="Invalid filter key"):
            store.search(("docs",), filter={malicious_key: "dummy"})
    ```

    anchor: Example usage of put_writes
    positive:
    ```python
    def put_writes(
        self,
        config: RunnableConfig,
        writes: Sequence[tuple[str, Any]],
        task_id: str,
        task_path: str = "",
    ) -> None:
        """Store intermediate writes linked to a checkpoint.

        This method saves intermediate writes associated with a checkpoint to the Postgres database.

        Args:
            config: Configuration of the related checkpoint.
            writes: List of writes to store.
            task_id: Identifier for the task creating the writes.
        """
        query = (
            self.UPSERT_CHECKPOINT_WRITES_SQL
            if all(w[0] in WRITES_IDX_MAP for w in writes)
            else self.INSERT_CHECKPOINT_WRITES_SQL
        )
        with self._cursor(pipeline=True) as cur:
            cur.executemany(
                query,
                self._dump_writes(
                    config["configurable"]["thread_id"],
                    config["configurable"]["checkpoint_ns"],
                    config["c...
    ```
  • Loss: MatryoshkaLoss with these parameters:
    {
        "loss": "MultipleNegativesRankingLoss",
        "matryoshka_dims": [
            768,
            512,
            256,
            128,
            64
        ],
        "matryoshka_weights": [
            1,
            1,
            1,
            1,
            1
        ],
        "n_dims_per_step": -1
    }
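MatryoshkaLoss trains the model so that prefixes of the 768-dimensional embedding (the 512, 256, 128, and 64-dim prefixes listed in matryoshka_dims) remain useful on their own. At inference you can keep only a prefix and re-normalize before computing cosine similarity. A minimal sketch with a toy vector (`truncate_and_normalize` is a hypothetical helper, not part of this model's code):

```python
import math

def truncate_and_normalize(embedding: list[float], dim: int) -> list[float]:
    """Keep the first `dim` components of a Matryoshka embedding and re-normalize
    to unit length, so cosine similarity reduces to a plain dot product."""
    prefix = embedding[:dim]
    norm = math.sqrt(sum(x * x for x in prefix))
    return [x / norm for x in prefix]

# Toy 6-d stand-in for a 768-d embedding
full = [0.5, -0.25, 0.25, 0.75, 0.1, 0.2]
small = truncate_and_normalize(full, 4)
```

Sentence Transformers also exposes this directly via the `truncate_dim` argument to `SentenceTransformer`, which trims embeddings to the requested size at encode time.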
    

Training Hyperparameters

Non-Default Hyperparameters

  • eval_strategy: epoch
  • per_device_train_batch_size: 4
  • per_device_eval_batch_size: 4
  • gradient_accumulation_steps: 16
  • learning_rate: 2e-05
  • num_train_epochs: 2
  • lr_scheduler_type: cosine
  • warmup_ratio: 0.1
  • fp16: True
  • load_best_model_at_end: True
  • optim: adamw_torch
  • batch_sampler: no_duplicates
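These values pin down the optimizer-step counts reported in the Training Logs: a per-device batch of 4 with 16 gradient-accumulation steps gives an effective batch of 64, so 180 training samples yield 3 gradient updates per epoch. A quick check (plain arithmetic, assuming a single device):

```python
import math

samples = 180          # dataset size (from the card)
per_device_batch = 4   # per_device_train_batch_size
grad_accum = 16        # gradient_accumulation_steps
epochs = 2             # num_train_epochs

effective_batch = per_device_batch * grad_accum          # effective batch size
steps_per_epoch = math.ceil(samples / effective_batch)   # optimizer steps per epoch
total_steps = steps_per_epoch * epochs                   # total optimizer steps
```

This matches the logged steps of 3 and 6 at epochs 1.0 and 2.0 respectively.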

All Hyperparameters

  • overwrite_output_dir: False
  • do_predict: False
  • eval_strategy: epoch
  • prediction_loss_only: True
  • per_device_train_batch_size: 4
  • per_device_eval_batch_size: 4
  • per_gpu_train_batch_size: None
  • per_gpu_eval_batch_size: None
  • gradient_accumulation_steps: 16
  • eval_accumulation_steps: None
  • torch_empty_cache_steps: None
  • learning_rate: 2e-05
  • weight_decay: 0.0
  • adam_beta1: 0.9
  • adam_beta2: 0.999
  • adam_epsilon: 1e-08
  • max_grad_norm: 1.0
  • num_train_epochs: 2
  • max_steps: -1
  • lr_scheduler_type: cosine
  • lr_scheduler_kwargs: {}
  • warmup_ratio: 0.1
  • warmup_steps: 0
  • log_level: passive
  • log_level_replica: warning
  • log_on_each_node: True
  • logging_nan_inf_filter: True
  • save_safetensors: True
  • save_on_each_node: False
  • save_only_model: False
  • restore_callback_states_from_checkpoint: False
  • no_cuda: False
  • use_cpu: False
  • use_mps_device: False
  • seed: 42
  • data_seed: None
  • jit_mode_eval: False
  • bf16: False
  • fp16: True
  • fp16_opt_level: O1
  • half_precision_backend: auto
  • bf16_full_eval: False
  • fp16_full_eval: False
  • tf32: None
  • local_rank: 0
  • ddp_backend: None
  • tpu_num_cores: None
  • tpu_metrics_debug: False
  • debug: []
  • dataloader_drop_last: False
  • dataloader_num_workers: 0
  • dataloader_prefetch_factor: None
  • past_index: -1
  • disable_tqdm: False
  • remove_unused_columns: True
  • label_names: None
  • load_best_model_at_end: True
  • ignore_data_skip: False
  • fsdp: []
  • fsdp_min_num_params: 0
  • fsdp_config: {'min_num_params': 0, 'xla': False, 'xla_fsdp_v2': False, 'xla_fsdp_grad_ckpt': False}
  • fsdp_transformer_layer_cls_to_wrap: None
  • accelerator_config: {'split_batches': False, 'dispatch_batches': None, 'even_batches': True, 'use_seedable_sampler': True, 'non_blocking': False, 'gradient_accumulation_kwargs': None}
  • parallelism_config: None
  • deepspeed: None
  • label_smoothing_factor: 0.0
  • optim: adamw_torch
  • optim_args: None
  • adafactor: False
  • group_by_length: False
  • length_column_name: length
  • project: huggingface
  • trackio_space_id: trackio
  • ddp_find_unused_parameters: None
  • ddp_bucket_cap_mb: None
  • ddp_broadcast_buffers: False
  • dataloader_pin_memory: True
  • dataloader_persistent_workers: False
  • skip_memory_metrics: True
  • use_legacy_prediction_loop: False
  • push_to_hub: False
  • resume_from_checkpoint: None
  • hub_model_id: None
  • hub_strategy: every_save
  • hub_private_repo: None
  • hub_always_push: False
  • hub_revision: None
  • gradient_checkpointing: False
  • gradient_checkpointing_kwargs: None
  • include_inputs_for_metrics: False
  • include_for_metrics: []
  • eval_do_concat_batches: True
  • fp16_backend: auto
  • push_to_hub_model_id: None
  • push_to_hub_organization: None
  • mp_parameters:
  • auto_find_batch_size: False
  • full_determinism: False
  • torchdynamo: None
  • ray_scope: last
  • ddp_timeout: 1800
  • torch_compile: False
  • torch_compile_backend: None
  • torch_compile_mode: None
  • include_tokens_per_second: False
  • include_num_input_tokens_seen: no
  • neftune_noise_alpha: None
  • optim_target_modules: None
  • batch_eval_metrics: False
  • eval_on_start: False
  • use_liger_kernel: False
  • liger_kernel_config: None
  • eval_use_gather_object: False
  • average_tokens_across_devices: True
  • prompts: None
  • batch_sampler: no_duplicates
  • multi_dataset_batch_sampler: proportional
  • router_mapping: {}
  • learning_rate_mapping: {}

Training Logs

| Epoch   | Step  | dim_768_cosine_ndcg@10 | dim_512_cosine_ndcg@10 | dim_256_cosine_ndcg@10 | dim_128_cosine_ndcg@10 | dim_64_cosine_ndcg@10 |
|:-------:|:-----:|:----------------------:|:----------------------:|:----------------------:|:----------------------:|:---------------------:|
| 1.0     | 3     | 0.9409                 | 0.9202                 | 0.9431                 | 0.8412                 | 0.9059                |
| **2.0** | **6** | **0.9409**             | **0.9409**             | **0.9409**             | **0.8943**             | **0.9074**            |
  • The bold row denotes the saved checkpoint.

Framework Versions

  • Python: 3.14.0
  • Sentence Transformers: 5.2.2
  • Transformers: 4.57.3
  • PyTorch: 2.9.1
  • Accelerate: 1.12.0
  • Datasets: 4.5.0
  • Tokenizers: 0.22.2

Citation

BibTeX

Sentence Transformers

@inproceedings{reimers-2019-sentence-bert,
    title = "Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks",
    author = "Reimers, Nils and Gurevych, Iryna",
    booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing",
    month = "11",
    year = "2019",
    publisher = "Association for Computational Linguistics",
    url = "https://arxiv.org/abs/1908.10084",
}

MatryoshkaLoss

@misc{kusupati2024matryoshka,
    title={Matryoshka Representation Learning},
    author={Aditya Kusupati and Gantavya Bhatt and Aniket Rege and Matthew Wallingford and Aditya Sinha and Vivek Ramanujan and William Howard-Snyder and Kaifeng Chen and Sham Kakade and Prateek Jain and Ali Farhadi},
    year={2024},
    eprint={2205.13147},
    archivePrefix={arXiv},
    primaryClass={cs.LG}
}

MultipleNegativesRankingLoss

@misc{henderson2017efficient,
    title={Efficient Natural Language Response Suggestion for Smart Reply},
    author={Matthew Henderson and Rami Al-Rfou and Brian Strope and Yun-hsuan Sung and Laszlo Lukacs and Ruiqi Guo and Sanjiv Kumar and Balint Miklos and Ray Kurzweil},
    year={2017},
    eprint={1705.00652},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}