Spaces:
Runtime error
Runtime error
| """Test Tracer classes.""" | |
| from __future__ import annotations | |
| import threading | |
| from datetime import datetime | |
| from typing import List, Optional, Union | |
| import pytest | |
| from freezegun import freeze_time | |
| from langchain.callbacks.tracers.base import ( | |
| BaseTracer, | |
| ChainRun, | |
| LLMRun, | |
| SharedTracer, | |
| ToolRun, | |
| Tracer, | |
| TracerException, | |
| TracerSession, | |
| ) | |
| from langchain.callbacks.tracers.schemas import TracerSessionCreate | |
| from langchain.schema import LLMResult | |
# Id given to sessions created through the fake ``_persist_session`` helper;
# every expected run in these tests carries it as ``session_id``.
TEST_SESSION_ID: int = 2023
def _get_compare_run() -> Union[LLMRun, ChainRun, ToolRun]:
    """Return the run tree expected after ``_perform_nested_run``.

    The tree is a chain run (execution order 1) with two children: a tool run
    (order 2) that wraps an LLM run (order 3), followed by an LLM run (order 4).

    Every ``start_time``/``end_time`` is taken while the clock is frozen at
    2023-01-01. They are therefore all identical, and can equal timestamps a
    tracer records while frozen at that same instant. Callers must freeze the
    clock at 2023-01-01 as well, or the comparison cannot succeed.
    """
    # Without a frozen clock each ``datetime.utcnow()`` below would return a
    # slightly different value, so the expected tree could never equal a
    # recorded one.
    with freeze_time("2023-01-01"):
        return ChainRun(
            id=None,
            error=None,
            start_time=datetime.utcnow(),
            end_time=datetime.utcnow(),
            extra={},
            execution_order=1,
            serialized={},
            inputs={},
            outputs={},
            session_id=TEST_SESSION_ID,
            child_runs=[
                ToolRun(
                    id=None,
                    start_time=datetime.utcnow(),
                    end_time=datetime.utcnow(),
                    extra={},
                    execution_order=2,
                    serialized={},
                    tool_input="test",
                    output="test",
                    action="{}",
                    session_id=TEST_SESSION_ID,
                    error=None,
                    child_runs=[
                        LLMRun(
                            id=None,
                            error=None,
                            start_time=datetime.utcnow(),
                            end_time=datetime.utcnow(),
                            extra={},
                            execution_order=3,
                            serialized={},
                            prompts=[],
                            response=LLMResult([[]]),
                            session_id=TEST_SESSION_ID,
                        )
                    ],
                ),
                LLMRun(
                    id=None,
                    error=None,
                    start_time=datetime.utcnow(),
                    end_time=datetime.utcnow(),
                    extra={},
                    execution_order=4,
                    serialized={},
                    prompts=[],
                    response=LLMResult([[]]),
                    session_id=TEST_SESSION_ID,
                ),
            ],
        )
def _perform_nested_run(tracer: BaseTracer) -> None:
    """Drive ``tracer`` through one nested chain execution.

    Callback nesting: chain start -> (tool start -> LLM start/end -> tool end)
    -> LLM start/end -> chain end.
    """

    def _one_llm_call() -> None:
        # A single LLM call that starts and finishes with an empty result.
        tracer.on_llm_start(serialized={}, prompts=[])
        tracer.on_llm_end(response=LLMResult([[]]))

    tracer.on_chain_start(serialized={}, inputs={})

    # First child of the chain: a tool that itself performs one LLM call.
    tracer.on_tool_start(serialized={}, input_str="test")
    _one_llm_call()
    tracer.on_tool_end("test")

    # Second child of the chain: a direct LLM call.
    _one_llm_call()

    tracer.on_chain_end(outputs={})
def _add_child_run(
    parent_run: Union[ChainRun, ToolRun],
    child_run: Union[LLMRun, ChainRun, ToolRun],
) -> None:
    """Append ``child_run`` to the end of ``parent_run.child_runs`` in place."""
    siblings = parent_run.child_runs
    siblings.append(child_run)
def _generate_id() -> Optional[Union[int, str]]:
    """Return the id for a new run; the fakes never assign one, so ``None``."""
    no_id: Optional[Union[int, str]] = None
    return no_id
def load_session(session_name: str) -> TracerSession:
    """Return a fake session with id 1, named ``session_name``, started now (UTC)."""
    started_at = datetime.utcnow()
    return TracerSession(id=1, name=session_name, start_time=started_at)
def _persist_session(session: TracerSessionCreate) -> TracerSession:
    """Build a ``TracerSession`` from ``session`` with id ``TEST_SESSION_ID``.

    Nothing is actually stored; every field of ``session`` is copied via
    ``session.dict()``.
    """
    fields = session.dict()
    return TracerSession(id=TEST_SESSION_ID, **fields)
def load_default_session() -> TracerSession:
    """Return a fake session with id 1, named ``"default"``, started now (UTC)."""
    default_name = "default"
    return TracerSession(id=1, name=default_name, start_time=datetime.utcnow())
class FakeTracer(Tracer):
    """In-memory ``Tracer`` double whose persisted runs accumulate in ``runs``.

    The session and id hooks delegate to the module-level helpers of the same
    name, the same helpers ``FakeSharedTracer`` uses.
    """

    def __init__(self) -> None:
        """Initialise the base tracer, then start with no recorded runs."""
        super().__init__()
        self.runs: List[Union[LLMRun, ChainRun, ToolRun]] = []

    # -- session hooks -----------------------------------------------------

    def load_default_session(self) -> TracerSession:
        """Return the fake default session."""
        return load_default_session()

    def load_session(self, session_name: str) -> TracerSession:
        """Return a fake session named ``session_name``."""
        return load_session(session_name)

    def _persist_session(self, session: TracerSessionCreate) -> TracerSession:
        """Return a fake persisted session built from ``session``."""
        # The bare name resolves to the module-level function, not this method.
        return _persist_session(session)

    # -- run hooks ---------------------------------------------------------

    def _generate_id(self) -> Optional[Union[int, str]]:
        """Return the id for a new run, as produced by the module helper."""
        return _generate_id()

    def _add_child_run(
        self,
        parent_run: Union[ChainRun, ToolRun],
        child_run: Union[LLMRun, ChainRun, ToolRun],
    ) -> None:
        """Attach ``child_run`` beneath ``parent_run``."""
        _add_child_run(parent_run, child_run)

    def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
        """Record a run handed over by the base tracer for persistence."""
        self.runs.append(run)
class FakeSharedTracer(SharedTracer):
    """In-memory ``SharedTracer`` double that appends runs under ``self._lock``.

    Call ``remove_runs`` to start from an empty run list.
    """

    # Class-level default, shared by every instance until ``remove_runs``
    # binds a fresh per-instance list. NOTE(review): the tests rely on this
    # plus ``remove_runs()`` for isolation; presumably ``SharedTracer`` hands
    # out one shared instance — confirm.
    runs: List[Union[LLMRun, ChainRun, ToolRun]] = []

    def remove_runs(self) -> None:
        """Forget every recorded run."""
        with self._lock:
            self.runs = []

    def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
        """Record ``run``; holding the lock serialises concurrent appends."""
        with self._lock:
            self.runs.append(run)

    def _add_child_run(
        self,
        parent_run: Union[ChainRun, ToolRun],
        child_run: Union[LLMRun, ChainRun, ToolRun],
    ) -> None:
        """Attach ``child_run`` beneath ``parent_run``."""
        # The bare name resolves to the module-level function, not this method.
        _add_child_run(parent_run, child_run)

    def _generate_id(self) -> Optional[Union[int, str]]:
        """Return the id for a new run, as produced by the module helper."""
        return _generate_id()

    def _persist_session(self, session: TracerSessionCreate) -> TracerSession:
        """Return a fake persisted session built from ``session``."""
        return _persist_session(session)

    def load_session(self, session_name: str) -> TracerSession:
        """Return a fake session named ``session_name``."""
        return load_session(session_name)

    def load_default_session(self) -> TracerSession:
        """Return the fake default session."""
        return load_default_session()
def test_tracer_llm_run() -> None:
    """Test tracer on an LLM run."""
    # The expected run is stamped with ``datetime.utcnow()``. Freezing the
    # clock makes every such call (here and, presumably, inside the tracer)
    # return one fixed instant, so the equality below does not fail on
    # sub-second timestamp drift.
    with freeze_time("2023-01-01"):
        compare_run = LLMRun(
            id=None,
            start_time=datetime.utcnow(),
            end_time=datetime.utcnow(),
            extra={},
            execution_order=1,
            serialized={},
            prompts=[],
            response=LLMResult([[]]),
            session_id=TEST_SESSION_ID,
            error=None,
        )
        tracer = FakeTracer()
        tracer.new_session()
        tracer.on_llm_start(serialized={}, prompts=[])
        tracer.on_llm_end(response=LLMResult([[]]))
    assert tracer.runs == [compare_run]
def test_tracer_llm_run_errors_no_session() -> None:
    """Starting an LLM run before any session exists must raise TracerException."""
    session_less_tracer = FakeTracer()
    with pytest.raises(TracerException):
        session_less_tracer.on_llm_start(serialized={}, prompts=[])
def test_tracer_llm_run_errors_no_start() -> None:
    """Ending an LLM run that was never started must raise TracerException."""
    fake = FakeTracer()
    fake.new_session()
    with pytest.raises(TracerException):
        fake.on_llm_end(response=LLMResult([[]]))
def test_tracer_multiple_llm_runs() -> None:
    """Test the tracer with multiple runs."""
    # Freeze the clock so every ``datetime.utcnow()`` (expected run and,
    # presumably, the tracer's own stamps) yields the same instant; otherwise
    # the ten recorded runs cannot all equal one ``compare_run``.
    with freeze_time("2023-01-01"):
        compare_run = LLMRun(
            id=None,
            start_time=datetime.utcnow(),
            end_time=datetime.utcnow(),
            extra={},
            execution_order=1,
            serialized={},
            prompts=[],
            response=LLMResult([[]]),
            session_id=TEST_SESSION_ID,
            error=None,
        )
        tracer = FakeTracer()
        tracer.new_session()
        num_runs = 10
        for _ in range(num_runs):
            tracer.on_llm_start(serialized={}, prompts=[])
            tracer.on_llm_end(response=LLMResult([[]]))
    assert tracer.runs == [compare_run] * num_runs
def test_tracer_chain_run() -> None:
    """Test tracer on a Chain run."""
    # Freeze the clock so the expected run's ``utcnow()`` stamps and those
    # recorded during the run are one fixed instant.
    with freeze_time("2023-01-01"):
        compare_run = ChainRun(
            id=None,
            start_time=datetime.utcnow(),
            end_time=datetime.utcnow(),
            extra={},
            execution_order=1,
            serialized={},
            inputs={},
            outputs={},
            session_id=TEST_SESSION_ID,
            error=None,
        )
        tracer = FakeTracer()
        tracer.new_session()
        tracer.on_chain_start(serialized={}, inputs={})
        tracer.on_chain_end(outputs={})
    assert tracer.runs == [compare_run]
def test_tracer_tool_run() -> None:
    """Test tracer on a Tool run."""
    # Freeze the clock so the expected run's ``utcnow()`` stamps and those
    # recorded during the run are one fixed instant.
    with freeze_time("2023-01-01"):
        compare_run = ToolRun(
            id=None,
            start_time=datetime.utcnow(),
            end_time=datetime.utcnow(),
            extra={},
            execution_order=1,
            serialized={},
            tool_input="test",
            output="test",
            action="{}",
            session_id=TEST_SESSION_ID,
            error=None,
        )
        tracer = FakeTracer()
        tracer.new_session()
        tracer.on_tool_start(serialized={}, input_str="test")
        tracer.on_tool_end("test")
    assert tracer.runs == [compare_run]
def test_tracer_nested_run() -> None:
    """Test tracer on a nested run."""
    # Both the traced execution and the expected tree must be built under the
    # same frozen instant for their timestamps to compare equal.
    with freeze_time("2023-01-01"):
        tracer = FakeTracer()
        tracer.new_session()
        _perform_nested_run(tracer)
        expected = _get_compare_run()
    assert tracer.runs == [expected]
def test_tracer_llm_run_on_error() -> None:
    """Test tracer on an LLM run with an error."""
    exception = Exception("test")
    # Freeze the clock so the expected run's ``utcnow()`` stamps and those
    # recorded during the run are one fixed instant.
    with freeze_time("2023-01-01"):
        compare_run = LLMRun(
            id=None,
            start_time=datetime.utcnow(),
            end_time=datetime.utcnow(),
            extra={},
            execution_order=1,
            serialized={},
            prompts=[],
            response=None,
            session_id=TEST_SESSION_ID,
            error=repr(exception),
        )
        tracer = FakeTracer()
        tracer.new_session()
        tracer.on_llm_start(serialized={}, prompts=[])
        tracer.on_llm_error(exception)
    assert tracer.runs == [compare_run]
def test_tracer_chain_run_on_error() -> None:
    """Test tracer on a Chain run with an error."""
    exception = Exception("test")
    # Freeze the clock so the expected run's ``utcnow()`` stamps and those
    # recorded during the run are one fixed instant.
    with freeze_time("2023-01-01"):
        compare_run = ChainRun(
            id=None,
            start_time=datetime.utcnow(),
            end_time=datetime.utcnow(),
            extra={},
            execution_order=1,
            serialized={},
            inputs={},
            outputs=None,
            session_id=TEST_SESSION_ID,
            error=repr(exception),
        )
        tracer = FakeTracer()
        tracer.new_session()
        tracer.on_chain_start(serialized={}, inputs={})
        tracer.on_chain_error(exception)
    assert tracer.runs == [compare_run]
def test_tracer_tool_run_on_error() -> None:
    """Test tracer on a Tool run with an error."""
    exception = Exception("test")
    # Freeze the clock so the expected run's ``utcnow()`` stamps and those
    # recorded during the run are one fixed instant.
    with freeze_time("2023-01-01"):
        compare_run = ToolRun(
            id=None,
            start_time=datetime.utcnow(),
            end_time=datetime.utcnow(),
            extra={},
            execution_order=1,
            serialized={},
            tool_input="test",
            output=None,
            action="{}",
            session_id=TEST_SESSION_ID,
            error=repr(exception),
        )
        tracer = FakeTracer()
        tracer.new_session()
        tracer.on_tool_start(serialized={}, input_str="test")
        tracer.on_tool_error(exception)
    assert tracer.runs == [compare_run]
def test_tracer_nested_runs_on_error() -> None:
    """Test tracer on a nested run with an error."""
    exception = Exception("test")
    # Freeze the clock for both the three traced executions and the expected
    # tree so every ``utcnow()`` stamp is the same instant.
    with freeze_time("2023-01-01"):
        tracer = FakeTracer()
        tracer.new_session()
        for _ in range(3):
            # chain -> llm ok, llm ok, tool -> llm error; error then
            # propagates up through the tool and the chain.
            tracer.on_chain_start(serialized={}, inputs={})
            tracer.on_llm_start(serialized={}, prompts=[])
            tracer.on_llm_end(response=LLMResult([[]]))
            tracer.on_llm_start(serialized={}, prompts=[])
            tracer.on_llm_end(response=LLMResult([[]]))
            tracer.on_tool_start(serialized={}, input_str="test")
            tracer.on_llm_start(serialized={}, prompts=[])
            tracer.on_llm_error(exception)
            tracer.on_tool_error(exception)
            tracer.on_chain_error(exception)
        compare_run = ChainRun(
            id=None,
            start_time=datetime.utcnow(),
            end_time=datetime.utcnow(),
            extra={},
            execution_order=1,
            serialized={},
            session_id=TEST_SESSION_ID,
            error=repr(exception),
            inputs={},
            outputs=None,
            child_runs=[
                LLMRun(
                    id=None,
                    start_time=datetime.utcnow(),
                    end_time=datetime.utcnow(),
                    extra={},
                    execution_order=2,
                    serialized={},
                    session_id=TEST_SESSION_ID,
                    error=None,
                    prompts=[],
                    response=LLMResult(generations=[[]], llm_output=None),
                ),
                LLMRun(
                    id=None,
                    start_time=datetime.utcnow(),
                    end_time=datetime.utcnow(),
                    extra={},
                    execution_order=3,
                    serialized={},
                    session_id=TEST_SESSION_ID,
                    error=None,
                    prompts=[],
                    response=LLMResult(generations=[[]], llm_output=None),
                ),
                ToolRun(
                    id=None,
                    start_time=datetime.utcnow(),
                    end_time=datetime.utcnow(),
                    extra={},
                    execution_order=4,
                    serialized={},
                    session_id=TEST_SESSION_ID,
                    error=repr(exception),
                    tool_input="test",
                    output=None,
                    action="{}",
                    child_runs=[
                        LLMRun(
                            id=None,
                            start_time=datetime.utcnow(),
                            end_time=datetime.utcnow(),
                            extra={},
                            execution_order=5,
                            serialized={},
                            session_id=TEST_SESSION_ID,
                            error=repr(exception),
                            prompts=[],
                            response=None,
                        )
                    ],
                    child_llm_runs=[],
                    child_chain_runs=[],
                    child_tool_runs=[],
                ),
            ],
            child_llm_runs=[],
            child_chain_runs=[],
            child_tool_runs=[],
        )
    assert tracer.runs == [compare_run] * 3
def test_shared_tracer_nested_run() -> None:
    """Test shared tracer on a nested run."""
    # Both the traced execution and the expected tree must be built under the
    # same frozen instant for their timestamps to compare equal.
    with freeze_time("2023-01-01"):
        tracer = FakeSharedTracer()
        tracer.new_session()
        # Runs are stored on a shared list, so clear anything left over from
        # earlier tests before tracing.
        tracer.remove_runs()
        _perform_nested_run(tracer)
        expected = _get_compare_run()
    assert tracer.runs == [expected]
def test_shared_tracer_nested_run_multithreaded() -> None:
    """Test shared tracer on a nested run."""
    # freezegun patches ``datetime`` module-wide, so the worker threads see the
    # frozen clock too. The threads are joined and the expected tree is built
    # before leaving the frozen block.
    with freeze_time("2023-01-01"):
        tracer = FakeSharedTracer()
        # Runs are stored on a shared list, so clear anything left over from
        # earlier tests before tracing.
        tracer.remove_runs()
        tracer.new_session()
        threads = []
        num_threads = 10
        for _ in range(num_threads):
            thread = threading.Thread(target=_perform_nested_run, args=(tracer,))
            thread.start()
            threads.append(thread)
        for thread in threads:
            thread.join()
        expected = _get_compare_run()
    assert tracer.runs == [expected] * num_threads