CandidateExplorer / services /base /BaseGenerator.py
ishaq101's picture
clean init
478dec6
raw
history blame
6.46 kB
from pydantic import BaseModel
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import AzureChatOpenAI
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type
)
from typing import Dict
from externals.observability.langfuse import langfuse_handler, langfuse
from services.llms.LLM import model_5mini, model_4omini
from utils.decorator import trace_runtime
from utils.logger import get_logger
logger = get_logger("base generator")
class MetadataObservability(BaseModel):
    """Identity metadata attached to Langfuse traces by BaseAIGenerator.

    fullname is recorded as the trace user_id and task_id as the
    session_id; the full model_dump() is stored as trace metadata.
    NOTE(review): `agent` presumably names the calling agent/component —
    it is only passed through model_dump() here; confirm with callers.
    """
    fullname: str
    task_id: str
    agent: str
class BaseAIGenerator:
    """Run a `prompt | llm` LangChain chain with Langfuse tracing and retries.

    Args:
        task_name: Human-readable name used for the Langfuse trace.
        prompt: ChatPromptTemplate piped into the llm.
        input_llm: Dict of template variables fed to the prompt.
        metadata_observability: MetadataObservability supplying the trace
            user_id (fullname), session_id (task_id) and metadata.
        llm: Runnable to invoke. Defaults to ``model_5mini | model_4omini``.
            NOTE(review): ``|`` composes the models sequentially (output of
            the first becomes input of the second); if a fallback was
            intended, ``model_5mini.with_fallbacks([model_4omini])`` is the
            usual form — confirm.
    """

    def __init__(self,
                 task_name: str,
                 prompt: ChatPromptTemplate,
                 input_llm: Dict,
                 metadata_observability: MetadataObservability,
                 llm: AzureChatOpenAI = model_5mini | model_4omini,
                 ):
        self.name = task_name
        self.llm = llm
        self.prompt = prompt
        self.input_llm = input_llm
        self.metadata_observability = metadata_observability

    @retry(
        reraise=True,
        stop=stop_after_attempt(2),  # at most 2 attempts total (1 retry)
        wait=wait_exponential(multiplier=1, min=1, max=5),
        retry=retry_if_exception_type(Exception),  # retry on any exception from LLM
    )
    async def _asafe_invoke(self, chain, input_llm, config):
        """Async invoke with retries; re-raises the last error when exhausted."""
        return await chain.ainvoke(input_llm, config=config)

    @retry(
        reraise=True,
        stop=stop_after_attempt(2),  # at most 2 attempts total (1 retry)
        wait=wait_exponential(multiplier=1, min=1, max=5),
        retry=retry_if_exception_type(Exception),  # retry on any exception from LLM
    )
    async def _safe_invoke(self, chain, input_llm, config):
        """Sync ``chain.invoke`` wrapped in an async retry helper.

        NOTE(review): the blocking ``chain.invoke`` runs on the event loop;
        acceptable for short calls, consider ``asyncio.to_thread`` if it
        stalls other tasks.
        """
        return chain.invoke(input_llm, config=config)

    @trace_runtime
    async def agenerate(self):
        """Asynchronously run the chain; return the output, or None on failure."""
        trace = None
        try:
            chain = self.prompt | self.llm
            config = {"callbacks": [langfuse_handler]}
            # Create the trace directly (no context manager, no end()).
            trace = langfuse.trace(
                name=self.name,
                input=self.input_llm,
            )
            trace.update(
                user_id=self.metadata_observability.fullname,
                session_id=self.metadata_observability.task_id,
                metadata=self.metadata_observability.model_dump(),
            )
            output = await self._asafe_invoke(
                chain=chain,
                input_llm=self.input_llm,
                config=config,
            )
            trace.update(output=output)
            return output
        except Exception as e:
            logger.exception("❌ BaseGenerator agenerate error")
            # Record the failure on the trace if it was created before the error.
            if trace:
                trace.update(
                    status="error",
                    error=str(e),
                )
            return None

    @trace_runtime
    async def generate(self):
        """Run the chain via the sync invoke path; return output, or None on failure."""
        trace = None
        try:
            chain = self.prompt | self.llm
            config = {"callbacks": [langfuse_handler]}
            trace = langfuse.trace(
                name=self.name,
                input=self.input_llm,
            )
            trace.update(
                user_id=self.metadata_observability.fullname,
                session_id=self.metadata_observability.task_id,
                metadata=self.metadata_observability.model_dump(),
            )
            # BUG FIX: _safe_invoke is an async def (tenacity retries it as a
            # coroutine); without `await` the coroutine object itself was
            # returned and recorded as the trace output instead of the result.
            output = await self._safe_invoke(
                chain=chain,
                input_llm=self.input_llm,
                config=config,
            )
            trace.update(output=output)
            return output
        except Exception as e:
            logger.exception("❌ BaseGenerator generate error")
            if trace:
                trace.update(
                    status="error",
                    error=str(e),
                )
            return None