Spaces:
Sleeping
Sleeping
| from abc import ABC, abstractmethod | |
| from typing import List, Union | |
| from pydantic import BaseModel, Field | |
| from typing_extensions import Annotated, TypedDict | |
| from retrieval.retriever import Metadata | |
| from utils import get_llm_client | |
class CodeAct(BaseModel):
    """Thought and python code for data analysis"""

    # NOTE(review): the class docstring and the field descriptions are kept
    # verbatim on purpose -- structured-output tooling may read them as schema
    # text, so rewording them could change runtime behavior. Confirm before
    # editing.

    # Required: the reasoning that led to the generated code.
    thought: str = Field(
        ...,
        description="The thought process for creating the code",
    )

    # Required: the Python source itself, carried as a plain string.
    code: str = Field(
        ...,
        description="The Python code to run as plain text.",
    )
class CodeAction(TypedDict):
    """Thought and python code for data analysis"""

    # NOTE(review): docstring and Annotated metadata are reproduced verbatim;
    # schema-generating consumers may read them, so treat them as runtime text.

    # Required (the Ellipsis marks "no default"): reasoning behind the code.
    thought: Annotated[
        str,
        ...,
        "The thought process for creating the code",
    ]

    # Required: the Python source, carried as a plain string.
    code: Annotated[
        str,
        ...,
        "The Python code to run as plain text.",
    ]
class Analyzer(ABC):
    """
    Abstract base class for data analyzers.
    An analyzer is responsible for analyzing given datasets based on user queries.

    The base class stores the constructor arguments, keeps running token
    counters, and offers two best-effort helpers around the global LangChain
    LLM cache (``_is_cached`` / ``_write_cache``).

    NOTE(review): the cache helpers read ``self.coding_llm``, which this base
    class never assigns. Presumably subclasses set it to the runnable they
    actually invoke (e.g. ``coding_client.with_structured_output(...)``) --
    confirm against the concrete analyzers. Until it is set, both helpers fall
    back to "cache miss" / "no write".

    NOTE(review): ``analyze``, ``extend_sandbox`` and ``finalize`` are no-op
    hooks rather than ``@abstractmethod``s; that is left unchanged so existing
    subclasses that skip one of them keep working.
    """

    def __init__(self, groupOwner, metadata_docs: List[Metadata], coding_client=None):
        """
        Store the analyzer configuration and reset the bookkeeping state.

        :param groupOwner: Stored unchanged on the instance; not interpreted here.
        :param metadata_docs: Metadata documents, stored unchanged on the instance.
        :param coding_client: LLM client used for code generation. When omitted,
            ``get_llm_client("gpt-4.1")`` is used.
        """
        # Initialize coding client if not provided
        if coding_client is None:
            coding_client = get_llm_client("gpt-4.1")
        self.groupOwner = groupOwner
        self.metadata_docs = metadata_docs
        self.coding_client = coding_client
        self.messages = []
        self.total_tokens = 0
        self.total_input_tokens = 0
        self.total_output_tokens = 0
        self.total_reasoning_tokens = 0
        self.dataset_idx_offset = 0  # only used for setting dataset paths

    def _cache_key(self, messages):
        """
        Build the ``(prompt, llm_string)`` pair used to address the LLM cache.

        Shared by ``_is_cached`` and ``_write_cache`` so that lookups and writes
        can never drift apart.

        :param messages: The messages that are (or were) sent to the coding LLM.
        :return: Tuple ``(prompt, llm_string)``.
        :raises AttributeError: If ``self.coding_llm`` has not been set; callers
            treat any exception as "no cache available".
        """
        from langchain_core.load import dumps
        from langchain_core.runnables import RunnableBinding, RunnableParallel, RunnableSequence

        # with_structured_output returns a RunnableSequence. We need the kwargs
        # bound to the underlying LLM (tools / response_format) to reproduce the
        # exact llm_string the cache uses.
        #
        # Two observed structures depending on include_raw:
        #   without include_raw: RunnableSequence([RunnableBinding(llm, **kw), parser])
        #   with include_raw:    RunnableSequence([RunnableParallel({"raw": RunnableBinding(llm, **kw)}), ...])
        bound_kwargs = {}
        coding_llm = self.coding_llm
        if isinstance(coding_llm, RunnableSequence):
            first = coding_llm.first
            if isinstance(first, RunnableBinding):
                bound_kwargs = first.kwargs
            elif isinstance(first, RunnableParallel):
                # RunnableParallel keeps its mapping in the ``steps__`` field;
                # ``steps`` is probed as a fallback in case a langchain_core
                # version exposes that name instead.
                steps = getattr(first, "steps__", None)
                if steps is None:
                    steps = getattr(first, "steps", {})
                # The first bound step wins, as in the original lookup order.
                bound_kwargs = next(
                    (step.kwargs for step in steps.values() if isinstance(step, RunnableBinding)),
                    {},
                )

        # Strip LangSmith metadata keys -- they are excluded from the actual cache key
        bound_kwargs = {k: v for k, v in bound_kwargs.items() if not k.startswith("ls_")}

        prompt = dumps(messages)
        llm_string = self.coding_client._get_llm_string(**bound_kwargs)
        return prompt, llm_string

    def _is_cached(self, messages) -> bool:
        """
        Return True if invoking the coding LLM with these messages would be a cache hit.
        Used to decide whether to skip streaming and use invoke() instead.

        Best effort: any failure while inspecting the cache is logged at DEBUG
        level and reported as a miss.

        :param messages: The messages that would be sent to the coding LLM.
        :return: True on a cache hit, False otherwise (including on errors).
        """
        import logging
        from langchain_core.globals import get_llm_cache

        try:
            cache = get_llm_cache()
            if cache is None:
                return False
            prompt, llm_string = self._cache_key(messages)
            return cache.lookup(prompt, llm_string) is not None
        except Exception:
            # Deliberately broad: cache inspection relies on LangChain internals
            # and must never break the analysis itself.
            logging.getLogger(__name__).debug(
                "LLM cache lookup failed; treating as a cache miss", exc_info=True
            )
            return False

    def _write_cache(self, messages, thought: str, code: str) -> None:
        """
        Write the streaming result to the LangChain cache so the next identical call
        gets a cache hit (stream() does not write to cache; this bridges the gap).

        Best effort: failures are logged at DEBUG level and otherwise ignored.

        :param messages: The messages that were sent to the coding LLM.
        :param thought: The "thought" part of the streamed result.
        :param code: The "code" part of the streamed result.
        """
        import json
        import logging
        from langchain_core.globals import get_llm_cache
        from langchain_core.messages import AIMessage
        from langchain_core.outputs import ChatGeneration

        try:
            cache = get_llm_cache()
            if cache is None:
                return
            prompt, llm_string = self._cache_key(messages)
            # Reconstruct the raw AIMessage content as JSON (what the model would have returned)
            raw_content = json.dumps({"thought": thought, "code": code})
            cache.update(prompt, llm_string, [ChatGeneration(message=AIMessage(content=raw_content))])
        except Exception:
            # cache write failure is non-fatal
            logging.getLogger(__name__).debug(
                "LLM cache write failed; continuing without caching", exc_info=True
            )

    def analyze(self, query: Union[str, dict]):
        """
        Analyze the given datasets based on the user query.

        The base implementation does nothing and returns None.

        :param query: The user query to analyze the datasets for. Can be a string or dict with text and files.
        :return: The analysis result.
        """

    def extend_sandbox(self, file_names):
        """
        Extend the sandbox with additional datasets.

        The base implementation does nothing.

        :param file_names: The names of the additional datasets.
        """

    def finalize(self):
        """
        Clean up resources

        The base implementation does nothing.
        """