# NOTE: removed page-scrape residue ("Spaces: / Sleeping / Sleeping") left by extraction.
from pydantic import BaseModel
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import AzureChatOpenAI
from tenacity import (
    retry,
    stop_after_attempt,
    wait_exponential,
    retry_if_exception_type,
)
from typing import Dict
from externals.observability.langfuse import langfuse_handler, langfuse
from services.llms.LLM import model_5mini, model_4omini
from utils.decorator import trace_runtime
from utils.logger import get_logger

logger = get_logger("base generator")
class MetadataObservability(BaseModel):
    """Identity metadata attached to each Langfuse generation trace.

    Consumed by BaseAIGenerator: ``fullname`` becomes the trace ``user_id``,
    ``task_id`` becomes the trace ``session_id``, and the full model dump is
    stored as trace metadata.
    """

    # Mapped to the Langfuse trace user_id.
    fullname: str
    # Mapped to the Langfuse trace session_id.
    task_id: str
    # Name of the agent that issued the generation (metadata only).
    agent: str
class BaseAIGenerator:
    """Run a ``prompt | llm`` chain with Langfuse tracing and tenacity retries.

    Args:
        task_name: Human-readable name used for the Langfuse trace.
        prompt: Prompt template piped into the LLM to form the chain.
        input_llm: Template variables passed to the chain on invocation.
        metadata_observability: User/session/agent metadata attached to
            the trace (see :class:`MetadataObservability`).
        llm: Chat model to invoke. NOTE(review): the default
            ``model_5mini | model_4omini`` pipes one model's output into the
            other at import time; if a fallback was intended, this should be
            ``model_5mini.with_fallbacks([model_4omini])`` — confirm before
            changing, left as-is to preserve behavior.
    """

    def __init__(self,
                 task_name: str,
                 prompt: ChatPromptTemplate,
                 input_llm: Dict,
                 metadata_observability: MetadataObservability,
                 llm: AzureChatOpenAI = model_5mini | model_4omini,
                 ):
        self.name = task_name
        self.llm = llm
        self.prompt = prompt
        self.input_llm = input_llm
        self.metadata_observability = metadata_observability

    # The tenacity imports and the helpers' "for retries" docstrings showed a
    # retry policy was intended but never applied; wire it up here.
    # ``reraise=True`` surfaces the original exception to the callers'
    # ``except`` blocks once attempts are exhausted, preserving the existing
    # log-and-return-None error path.
    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=1, max=10),
        retry=retry_if_exception_type(Exception),
        reraise=True,
    )
    async def _asafe_invoke(self, chain, input_llm, config):
        """Private helper: async chain invocation with retries."""
        return await chain.ainvoke(input_llm, config=config)

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=1, max=10),
        retry=retry_if_exception_type(Exception),
        reraise=True,
    )
    async def _safe_invoke(self, chain, input_llm, config):
        """Private helper: sync chain invocation with retries.

        NOTE(review): this blocks the event loop despite being ``async``
        (it wraps the synchronous ``invoke``); kept for interface
        compatibility — confirm whether callers rely on the sync path.
        """
        return chain.invoke(input_llm, config=config)

    def _start_trace(self):
        """Create the Langfuse trace and attach user/session metadata.

        NOTE(review): ``langfuse.trace(...)`` is the Langfuse v2 SDK API;
        the earlier (now removed) implementation used the v3
        ``start_as_current_observation`` context manager — confirm which
        SDK version is pinned.
        """
        trace = langfuse.trace(
            name=self.name,
            input=self.input_llm,
        )
        trace.update(
            user_id=self.metadata_observability.fullname,
            session_id=self.metadata_observability.task_id,
            metadata=self.metadata_observability.model_dump(),
        )
        return trace

    async def agenerate(self):
        """Invoke the chain asynchronously.

        Returns:
            The LLM output, or ``None`` if invocation failed (the error is
            logged and recorded on the trace).
        """
        trace = None
        try:
            chain = self.prompt | self.llm
            config = {"callbacks": [langfuse_handler]}
            trace = self._start_trace()
            output = await self._asafe_invoke(
                chain=chain,
                input_llm=self.input_llm,
                config=config,
            )
            trace.update(output=output)
            return output
        except Exception as e:
            logger.exception("β BaseGenerator agenerate error")
            if trace:
                trace.update(
                    status="error",
                    error=str(e),
                )
            return None

    async def generate(self):
        """Invoke the chain via the synchronous ``invoke`` path.

        Returns:
            The LLM output, or ``None`` if invocation failed (the error is
            logged and recorded on the trace).
        """
        trace = None
        try:
            chain = self.prompt | self.llm
            config = {"callbacks": [langfuse_handler]}
            trace = self._start_trace()
            # BUG FIX: the original called the async helper WITHOUT ``await``,
            # so ``output`` was an unawaited coroutine object — the chain was
            # never actually invoked and the coroutine got recorded as the
            # trace output.
            output = await self._safe_invoke(
                chain=chain,
                input_llm=self.input_llm,
                config=config,
            )
            trace.update(output=output)
            return output
        except Exception as e:
            logger.exception("β BaseGenerator generate error")
            if trace:
                trace.update(
                    status="error",
                    error=str(e),
                )
            return None