ogd4all / generation /analyzer.py
Michael Siebenmann
add fixed LLM and embedding caching
3db4414
from abc import ABC, abstractmethod
from typing import List, Union
from pydantic import BaseModel, Field
from typing_extensions import Annotated, TypedDict
from retrieval.retriever import Metadata
from utils import get_llm_client
# Structured-output schema (pydantic) for one code-generation step: the model's
# reasoning ("thought") plus the Python code it wants to run ("code").
# NOTE(review): the class docstring and the Field descriptions are runtime data,
# not just documentation. pydantic/langchain presumably copy them into the schema
# sent to the LLM, so editing them would likely change model behavior and the LLM
# cache key. Confirm before changing.
class CodeAct(BaseModel):
    """Thought and python code for data analysis"""
    # Required field (``...`` means no default): reasoning preceding the code.
    thought: str = Field(..., description="The thought process for creating the code")
    # Required field: the Python code to run; the description asks for plain text.
    code: str = Field(..., description="The Python code to run as plain text.")
# TypedDict counterpart of CodeAct: same two fields, same descriptions.
# NOTE(review): presumably an alternative schema for structured output. In the
# langchain TypedDict convention, Annotated[type, default, description] with
# ``...`` as the default marks the field as required. Confirm which of
# CodeAct / CodeAction callers actually use. The docstring and descriptions are
# likely sent to the model, so keep them in sync with CodeAct.
class CodeAction(TypedDict):
    """Thought and python code for data analysis"""
    thought: Annotated[str, ..., "The thought process for creating the code"]
    code: Annotated[str, ..., "The Python code to run as plain text."]
class Analyzer(ABC):
    """
    Abstract base class for data analyzers.
    An analyzer is responsible for analyzing given datasets based on user queries.

    Subclasses implement :meth:`analyze`, :meth:`extend_sandbox` and
    :meth:`finalize`. The LLM-cache helpers (:meth:`_is_cached`,
    :meth:`_write_cache`) also read ``self.coding_llm``, which this base class
    never assigns. Presumably subclasses set it to a structured-output runnable
    built from ``self.coding_client`` (TODO confirm). If it is missing, the
    helpers degrade to "not cached" / no-op.
    """

    def __init__(self, groupOwner, metadata_docs: List[Metadata], coding_client=None):
        """
        :param groupOwner: Stored as-is on ``self.groupOwner``; its meaning is
            defined by callers/subclasses.
        :param metadata_docs: Metadata documents of the datasets to analyze.
        :param coding_client: Chat model used for code generation. Defaults to
            ``get_llm_client("gpt-4.1")`` when not provided.
        """
        # Initialize coding client if not provided
        if coding_client is None:
            coding_client = get_llm_client("gpt-4.1")
        self.groupOwner = groupOwner
        self.metadata_docs = metadata_docs
        self.coding_client = coding_client
        # Conversation history; contents are managed by subclasses.
        self.messages = []
        # Token usage counters, accumulated by subclasses.
        self.total_tokens = 0
        self.total_input_tokens = 0
        self.total_output_tokens = 0
        self.total_reasoning_tokens = 0
        self.dataset_idx_offset = 0  # only used for setting dataset paths

    def _bound_llm_kwargs(self) -> dict:
        """
        Return the kwargs bound to the underlying chat model inside ``self.coding_llm``.

        ``with_structured_output`` returns a RunnableSequence. The kwargs bound to
        the LLM (tools / response_format) are part of the llm_string the cache
        uses, so they must be reproduced exactly. Two observed structures,
        depending on include_raw:

        - without include_raw: RunnableSequence([RunnableBinding(llm, **kw), parser])
        - with include_raw: RunnableSequence([RunnableParallel({"raw": RunnableBinding(llm, **kw)}), ...])

        LangSmith metadata keys (``ls_*``) are stripped because they are excluded
        from the actual cache key.

        :return: Filtered kwargs, or ``{}`` if the structure is not recognised.
        :raises AttributeError: If ``self.coding_llm`` has not been set.
        """
        from langchain_core.runnables import RunnableBinding, RunnableParallel, RunnableSequence

        bound_kwargs = {}
        coding_llm = self.coding_llm
        if isinstance(coding_llm, RunnableSequence):
            first = coding_llm.first
            if isinstance(first, RunnableBinding):
                bound_kwargs = first.kwargs
            elif isinstance(first, RunnableParallel):
                # langchain_core stores the parallel branches in the ``steps__``
                # field. Fall back to ``steps`` in case a version exposes that name.
                steps = getattr(first, "steps__", None)
                if steps is None:
                    steps = getattr(first, "steps", {})
                bound_kwargs = next(
                    (step.kwargs for step in steps.values() if isinstance(step, RunnableBinding)),
                    {},
                )
        return {k: v for k, v in bound_kwargs.items() if not k.startswith("ls_")}

    def _cache_key(self, messages):
        """
        Build the cache key for invoking the coding LLM with ``messages``.

        :return: ``(cache, prompt, llm_string)``, or ``None`` when no global LLM
            cache is configured.
        :raises Exception: Propagates any failure (missing ``coding_llm``,
            serialization errors, ...) to the caller, which decides how to degrade.
        """
        from langchain_core.globals import get_llm_cache
        from langchain_core.load import dumps

        cache = get_llm_cache()
        if cache is None:
            return None
        bound_kwargs = self._bound_llm_kwargs()
        # NOTE(review): assumes ``messages`` serializes the same way as the list
        # the chat model builds internally (e.g. BaseMessage objects). If callers
        # pass tuples/dicts, the keys would presumably differ. Confirm.
        prompt = dumps(messages)
        # ``_get_llm_string`` is a private langchain API; re-check it on upgrades.
        llm_string = self.coding_client._get_llm_string(**bound_kwargs)
        return cache, prompt, llm_string

    def _is_cached(self, messages) -> bool:
        """
        Return True if invoking the coding LLM with these messages would be a cache hit.
        Used to decide whether to skip streaming and use invoke() instead.

        Best-effort: any failure is logged at debug level and reported as "not cached".
        """
        import logging

        try:
            key = self._cache_key(messages)
            if key is None:
                return False
            cache, prompt, llm_string = key
            return cache.lookup(prompt, llm_string) is not None
        except Exception:
            logging.getLogger(__name__).debug("LLM cache lookup failed", exc_info=True)
            return False

    def _write_cache(self, messages, thought: str, code: str) -> None:
        """
        Write the streaming result to the LangChain cache so the next identical call
        gets a cache hit (stream() does not write to cache; this bridges the gap).

        Best-effort: failures are logged at debug level and never raised.
        """
        import json
        import logging
        from langchain_core.messages import AIMessage
        from langchain_core.outputs import ChatGeneration

        try:
            key = self._cache_key(messages)
            if key is None:
                return
            cache, prompt, llm_string = key
            # Reconstruct the raw AIMessage content as JSON (what the model would have returned).
            # NOTE(review): this matches structured-output methods that return JSON
            # in the message content (e.g. response_format). If coding_llm uses tool
            # calling, the parser presumably reads tool_calls instead, and this cached
            # entry would not parse. Confirm.
            raw_content = json.dumps({"thought": thought, "code": code})
            cache.update(prompt, llm_string, [ChatGeneration(message=AIMessage(content=raw_content))])
        except Exception:
            # Cache write failure is non-fatal, but keep a trace for diagnosis.
            logging.getLogger(__name__).debug("LLM cache write failed", exc_info=True)

    @abstractmethod
    def analyze(self, query: Union[str, dict]):
        """
        Analyze the given datasets based on the user query.
        :param query: The user query to analyze the datasets for. Can be a string or dict with text and files.
        :return: The analysis result.
        """
        pass

    @abstractmethod
    def extend_sandbox(self, file_names):
        """
        Extend the sandbox with additional datasets.
        :param file_names: The names of the additional datasets.
        """
        pass

    @abstractmethod
    def finalize(self):
        """
        Clean up resources
        """
        pass