# CandidateExplorer / services/base/BaseGenerator.py
# Last change by ishaq101: [KM-383] [CEX] [AI] Deployment AI Engine / BE (commit f3bdba1)
import os
from config.constant import LangfuseConstants
from langfuse.langchain import CallbackHandler
from pydantic import BaseModel
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import AzureChatOpenAI
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type
)
from typing import Dict
# ❌ REMOVED: from externals.observability.langfuse import langfuse_handler, langfuse
from services.llms.LLM import model_5mini, model_4omini
from utils.decorator import trace_runtime
from utils.logger import get_logger
from langfuse import get_client, Langfuse
# Module-level logger used by BaseAIGenerator for Langfuse warnings and generation errors.
logger = get_logger("base generator")
class MetadataObservability(BaseModel):
    """Identifiers that BaseAIGenerator attaches to its Langfuse traces."""

    # Passed as ``langfuse_user_id`` in the LangChain callback metadata.
    fullname: str
    # Used as the Langfuse session id and as the seed for the trace id.
    task_id: str
    # Added as the (single) Langfuse tag for the run.
    agent: str
    # Set as the trace's ``user_id`` via ``span.update_trace``.
    user_id: str
class BaseAIGenerator:
    """Run a ``prompt | llm`` chain once, with retries and best-effort Langfuse tracing.

    Both entry points (``agenerate`` and ``generate``) log the exception and
    return ``None`` on any failure instead of raising.

    Args:
        task_name: Name used for the Langfuse observation and trace.
        prompt: Prompt template that is piped into ``llm``.
        input_llm: Variables passed to the chain's ``invoke``/``ainvoke``.
        metadata_observability: Identifiers attached to the Langfuse trace.
        llm: Chat model (or any LangChain runnable) to use. When omitted,
            ``model_5mini`` is used with ``model_4omini`` as a fallback.
    """

    # Used when LangfuseConstants.HOST is empty or None.
    _DEFAULT_LANGFUSE_HOST = "https://us.cloud.langfuse.com"

    def __init__(
        self,
        task_name: str,
        prompt: ChatPromptTemplate,
        input_llm: Dict,
        metadata_observability: MetadataObservability,
        llm: "AzureChatOpenAI | None" = None,
    ):
        if llm is None:
            # ``model_5mini | model_4omini`` would build a RunnableSequence that
            # feeds 5-mini's reply into 4o-mini as its prompt; the two-model
            # default is presumably meant as primary + fallback.
            llm = model_5mini.with_fallbacks([model_4omini])
        self.metadata_observability = metadata_observability
        self.llm = llm
        self.prompt = prompt
        self.input_llm = input_llm
        self.name = task_name

    @classmethod
    def _export_langfuse_env(cls) -> None:
        """Copy Langfuse credentials from ``LangfuseConstants`` into ``os.environ``.

        Raises:
            TypeError: if a key constant is ``None`` (``os.environ`` only accepts str).
        """
        os.environ["LANGFUSE_PUBLIC_KEY"] = LangfuseConstants.PUBLIC_KEY
        os.environ["LANGFUSE_SECRET_KEY"] = LangfuseConstants.SECRET_KEY
        os.environ["LANGFUSE_HOST"] = LangfuseConstants.HOST or cls._DEFAULT_LANGFUSE_HOST

    def _get_langfuse_client(self):
        """Return a ``Langfuse`` client, or ``None`` if it cannot be configured."""
        try:
            self._export_langfuse_env()
            return Langfuse()
        except Exception as e:
            logger.warning(f"⚠️ Langfuse unavailable, skipping observability: {e}")
            return None

    def _get_langfuse_config(self):
        """Return the LangChain run config carrying the Langfuse callback handler.

        Returns an empty dict (no callbacks, no metadata) if Langfuse cannot be
        configured.
        """
        try:
            self._export_langfuse_env()
            handler = CallbackHandler(update_trace=True)
            return {
                "callbacks": [handler],
                "metadata": {
                    "langfuse_session_id": self.metadata_observability.task_id,
                    # NOTE(review): the callback reports ``fullname`` as the user
                    # id, while the explicit ``update_trace`` uses ``user_id``;
                    # confirm which one Langfuse should show.
                    "langfuse_user_id": self.metadata_observability.fullname,
                    "langfuse_tags": [self.metadata_observability.agent],
                    "langfuse_trace_name": self.name,
                },
            }
        except Exception as e:
            logger.warning(f"⚠️ Langfuse unavailable, skipping observability: {e}")
            return {}

    @retry(
        reraise=True,
        stop=stop_after_attempt(2),
        wait=wait_exponential(multiplier=1, min=1, max=5),
        retry=retry_if_exception_type(Exception)
    )
    async def _asafe_invoke(self, chain, input_llm, config):
        """Await ``chain.ainvoke`` with up to two attempts."""
        return await chain.ainvoke(input_llm, config=config)

    @retry(
        reraise=True,
        stop=stop_after_attempt(2),
        wait=wait_exponential(multiplier=1, min=1, max=5),
        retry=retry_if_exception_type(Exception)
    )
    async def _safe_invoke(self, chain, input_llm, config):
        """Call the synchronous ``chain.invoke`` with up to two attempts.

        Because this is an ``async def``, tenacity wraps it in an async retrier,
        so callers must ``await`` it. The ``invoke`` call itself blocks the event
        loop while the request runs.
        """
        return chain.invoke(input_llm, config=config)

    async def _run_observed(self, invoke):
        """Build the chain and run ``invoke`` on it, traced in Langfuse when possible.

        Args:
            invoke: A retry-wrapped invoker (``_asafe_invoke`` or ``_safe_invoke``),
                awaited as ``invoke(chain=..., input_llm=..., config=...)``.

        Returns:
            The chain's output.
        """
        config = self._get_langfuse_config()
        chain = self.prompt | self.llm
        langfuse_client = self._get_langfuse_client()
        if langfuse_client is None:
            # Observability is best-effort: a missing client must not abort
            # the LLM call itself.
            return await invoke(chain=chain, input_llm=self.input_llm, config=config)

        # create_trace_id is deterministic for a given seed, so calls sharing a
        # task_id land on the same trace.
        trace_id = Langfuse.create_trace_id(seed=self.metadata_observability.task_id)
        with langfuse_client.start_as_current_observation(
            as_type="generation",
            name=self.name,
            metadata=self.metadata_observability,
            input=self.input_llm,
            trace_context={"trace_id": trace_id},
        ) as span:
            span.update_trace(
                name=self.name,
                user_id=self.metadata_observability.user_id,
            )
            output = await invoke(chain=chain, input_llm=self.input_llm, config=config)
            span.update_trace(output=output)
        return output

    @trace_runtime
    async def agenerate(self):
        """Run the chain with ``ainvoke``; return its output, or ``None`` on error."""
        try:
            return await self._run_observed(self._asafe_invoke)
        except Exception:
            logger.exception("❌ BaseGenerator agenerate error")
            return None

    @trace_runtime
    async def generate(self):
        """Run the chain with the synchronous ``invoke``; return its output, or ``None`` on error.

        ``_safe_invoke`` is a coroutine function (tenacity async retry), so it is
        awaited here; without the ``await`` the chain never ran and a coroutine
        object was returned instead of the model output.
        """
        try:
            return await self._run_observed(self._safe_invoke)
        except Exception:
            logger.exception("❌ BaseGenerator generate error")
            return None