from typing import Dict

from pydantic import BaseModel
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import Runnable
from tenacity import (
    retry,
    stop_after_attempt,
    wait_exponential,
    retry_if_exception_type,
)

from externals.observability.langfuse import langfuse_handler, langfuse
from services.llms.LLM import model_5mini, model_4omini
from utils.decorator import trace_runtime
from utils.logger import get_logger

logger = get_logger("base_generator")


class MetadataObservability(BaseModel):
    """User and session metadata attached to every Langfuse trace."""

    fullname: str
    task_id: str
    agent: str


class BaseAIGenerator:
    """Wraps a `prompt | llm` chain with retries and Langfuse tracing.

    Args:
        task_name: Task name, used as the Langfuse trace name.
        prompt: ChatPromptTemplate piped into the LLM.
        input_llm: Variables passed to the prompt at invocation time.
        metadata_observability: User/session metadata attached to the trace.
        llm: Chat model runnable; defaults to model_5mini with model_4omini
            as a fallback.
    """

    def __init__(
        self,
        task_name: str,
        prompt: ChatPromptTemplate,
        input_llm: Dict,
        metadata_observability: MetadataObservability,
        # Piping two chat models with `|` would feed one model's output message
        # into the other; `with_fallbacks` expresses the intended "try
        # model_5mini, fall back to model_4omini on error".
        llm: Runnable = model_5mini.with_fallbacks([model_4omini]),
    ):
        self.name = task_name
        self.llm = llm
        self.prompt = prompt
        self.input_llm = input_llm
        self.metadata_observability = metadata_observability

    @retry(
        reraise=True,
        stop=stop_after_attempt(2),  # at most 2 attempts (1 retry)
        wait=wait_exponential(multiplier=1, min=1, max=5),
        retry=retry_if_exception_type(Exception),  # retry on any exception from the LLM
    )
    async def _asafe_invoke(self, chain, input_llm, config):
        """Private helper: async invocation with retries."""
        return await chain.ainvoke(input_llm, config=config)

    @retry(
        reraise=True,
        stop=stop_after_attempt(2),  # at most 2 attempts (1 retry)
        wait=wait_exponential(multiplier=1, min=1, max=5),
        retry=retry_if_exception_type(Exception),  # retry on any exception from the LLM
    )
    def _safe_invoke(self, chain, input_llm, config):
        """Private helper: sync invocation with retries."""
        return chain.invoke(input_llm, config=config)

    @trace_runtime
    async def agenerate(self):
        trace = None
        try:
            chain = self.prompt | self.llm
            config = {"callbacks": [langfuse_handler]}

            # Create the trace directly (no context manager, no end() needed).
            trace = langfuse.trace(
                name=self.name,
                input=self.input_llm,
            )
            trace.update(
                user_id=self.metadata_observability.fullname,
                session_id=self.metadata_observability.task_id,
                metadata=self.metadata_observability.model_dump(),
            )

            output = await self._asafe_invoke(
                chain=chain,
                input_llm=self.input_llm,
                config=config,
            )
            trace.update(output=output)
            return output
        except Exception as e:
            logger.exception("❌ BaseGenerator agenerate error")
            if trace:
                # trace.update() has no status/error parameters; record the
                # failure in the trace metadata instead.
                trace.update(metadata={"status": "error", "error": str(e)})
            return None

    @trace_runtime
    def generate(self):
        """Sync counterpart of agenerate; plain `def` since it wraps the blocking invoke."""
        trace = None
        try:
            chain = self.prompt | self.llm
            config = {"callbacks": [langfuse_handler]}

            trace = langfuse.trace(
                name=self.name,
                input=self.input_llm,
            )
            trace.update(
                user_id=self.metadata_observability.fullname,
                session_id=self.metadata_observability.task_id,
                metadata=self.metadata_observability.model_dump(),
            )

            output = self._safe_invoke(
                chain=chain,
                input_llm=self.input_llm,
                config=config,
            )
            trace.update(output=output)
            return output
        except Exception as e:
            logger.exception("❌ BaseGenerator generate error")
            if trace:                trace.update(metadata={"status": "error", "error": str(e)})
            return None
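

# --- Minimal usage sketch ---------------------------------------------------
# Assumes the Azure deployments behind model_5mini/model_4omini and the
# Langfuse credentials are configured. `summary_prompt` and the input values
# below are hypothetical examples, not part of the real project.
if __name__ == "__main__":
    import asyncio

    summary_prompt = ChatPromptTemplate.from_messages([
        ("system", "Summarize the user's text in one sentence."),
        ("human", "{text}"),
    ])

    generator = BaseAIGenerator(
        task_name="demo_summary",
        prompt=summary_prompt,
        input_llm={"text": "LangChain pipes a prompt template into an LLM call."},
        metadata_observability=MetadataObservability(
            fullname="jane.doe",
            task_id="task-123",
            agent="demo-agent",
        ),
    )

    # agenerate returns the model's message on success, or None once the
    # retries and the fallback model are exhausted.
    result = asyncio.run(generator.agenerate())
    print(result)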