import os
import logging
import warnings
import re
import time
# Suppress TensorFlow/Keras warnings.
# NOTE: TF_CPP_MIN_LOG_LEVEL must be set in the environment BEFORE TensorFlow
# is first imported (directly or transitively by a tool/library below), so
# this setup deliberately runs ahead of the third-party imports.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # 2 = hide INFO and WARNING C++ logs
logging.getLogger('tensorflow').setLevel(logging.ERROR)
warnings.filterwarnings('ignore', module='tensorflow')
warnings.filterwarnings('ignore', module='tf_keras')
from typing import TypedDict, Optional, List, Annotated
from langchain_core.messages import HumanMessage, SystemMessage
from langgraph.graph import MessagesState, StateGraph, START, END
from langgraph.graph.message import add_messages
from langgraph.prebuilt import tools_condition
from langgraph.prebuilt import ToolNode
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from custom_tools import get_custom_tools_list
from system_prompt import SYSTEM_PROMPT
from utils import cleanup_answer, extract_text_from_content
import config
from langfuse_tracking import track_agent_execution, track_llm_call
# Suppress BeautifulSoup GuessedAtParserWarning (emitted when BeautifulSoup is
# constructed without an explicit parser). Guarded so the module still works
# when bs4 is not installed.
try:
    from bs4 import GuessedAtParserWarning
    warnings.filterwarnings('ignore', category=GuessedAtParserWarning)
except ImportError:
    pass
class AgentState(TypedDict):
    """State carried through the LangGraph agent graph.

    Flows from the init node through the assistant/tools loop; ``messages``
    accumulates via the ``add_messages`` reducer, the rest are plain fields.
    """
    question: str  # The user question being answered
    messages: Annotated[list, add_messages]  # for LangGraph: merged with the add_messages reducer
    answer: str  # Final answer, set by the assistant node when no tool calls remain
    step_count: int  # Track number of iterations to prevent infinite loops
    file_name: str  # Optional file name for questions that reference files ("" when absent)
class LangGraphAgent:
    """LangGraph-based question-answering agent with tool use.

    Builds a small state graph (init -> assistant <-> tools) and drives it
    until the LLM produces a final answer (no tool calls) or the step limit
    is reached. Invoke via ``agent(question, file_name)``.
    """

    # Hard cap on assistant iterations to prevent infinite tool loops.
    # The graph recursion limit must be >= 2x this value because each
    # iteration traverses both the assistant and the tools node.
    MAX_STEPS = 40  # Increased from 25 to handle complex multi-step reasoning

    def __init__(self):
        # Validate API keys early so misconfiguration is visible before use.
        if not os.getenv("GOOGLE_API_KEY"):
            print("WARNING: GOOGLE_API_KEY not found - analyze_youtube_video will fail")
        self.tools = get_custom_tools_list()
        self.llm_client_with_tools = self._create_llm_client()
        self.graph = self._build_graph()

    def _create_llm_client(self, model_provider: str = "google"):
        """Create and return the LLM client with tools bound based on the model provider.

        Args:
            model_provider: ``"google"`` (Gemini via ``config.ACTIVE_AGENT_LLM_MODEL``)
                or ``"huggingface"`` (Llama 3.1 8B Instruct endpoint).

        Returns:
            A chat model with ``self.tools`` bound via ``bind_tools``.

        Raises:
            ValueError: For an unsupported provider. (Previously this fell
                through and returned None, deferring the failure to the first
                ``.invoke()`` call as an opaque AttributeError.)
        """
        if model_provider == "google":
            apikey = os.getenv("GOOGLE_API_KEY")
            return ChatGoogleGenerativeAI(
                model=config.ACTIVE_AGENT_LLM_MODEL,
                temperature=0,
                api_key=apikey,
                timeout=60  # Add timeout to prevent hanging
            ).bind_tools(self.tools)
        elif model_provider == "huggingface":
            LLM_MODEL = "meta-llama/Llama-3.1-8B-Instruct"
            apikey = os.getenv("HUGGINGFACEHUB_API_TOKEN")
            llmObject = HuggingFaceEndpoint(
                repo_id=LLM_MODEL,
                task="text-generation",
                max_new_tokens=512,
                temperature=0.7,
                do_sample=False,
                repetition_penalty=1.03,
                huggingfacehub_api_token=apikey
            )
            return ChatHuggingFace(llm=llmObject).bind_tools(self.tools)
        raise ValueError(f"Unsupported model provider: {model_provider}")

    # Nodes
    def _init_questions(self, state: AgentState):
        """Initialize the messages in the state with system prompt and user question."""
        # Build the question message, including file name if available.
        question_content = state["question"]
        if state.get("file_name"):
            question_content += f'\n\nNote: This question references a file: {state["file_name"]}'
        return {
            "messages": [
                SystemMessage(content=SYSTEM_PROMPT),
                HumanMessage(content=question_content)
            ],
            "step_count": 0  # Initialize step counter
        }

    @staticmethod
    def _extract_content_text(content) -> str:
        """Normalize an LLM response's ``content`` into a plain, stripped string.

        Handles the shapes Gemini can return: a list of content parts (dicts
        with a ``'text'`` key or objects with a ``.text`` attribute), a single
        such dict/object, or anything else (falls back to ``str()``; a plain
        str passes through unchanged via that fallback).
        """
        if isinstance(content, list):
            # Extract text from a list of mixed content parts.
            text_parts = []
            for item in content:
                if isinstance(item, dict) and 'text' in item:
                    text_parts.append(item['text'])
                elif hasattr(item, 'text'):
                    text_parts.append(item.text)
                else:
                    text_parts.append(str(item))
            content = " ".join(text_parts)
        elif isinstance(content, dict) and 'text' in content:
            # Single dict with a 'text' field.
            content = content['text']
        elif hasattr(content, 'text'):
            # Object exposing a text attribute.
            content = content.text
        else:
            # Fallback (also the no-op path for plain strings).
            content = str(content)
        # Clean up any remaining surrounding whitespace.
        return content.strip()

    @track_llm_call(config.ACTIVE_AGENT_LLM_MODEL)
    def _assistant(self, state: AgentState):
        """Assistant node: call the tool-bound LLM, retrying on 504 timeouts.

        Returns a partial state update: the LLM response message, the bumped
        ``step_count``, and — when the model requested no tools — the final
        ``answer`` extracted from the response content.
        """
        # Track and log current step.
        current_step = state.get("step_count", 0) + 1
        print(f"[STEP {current_step}] Calling assistant with {len(state['messages'])} messages")
        # Invoke LLM with tools enabled, with retry logic for 504 errors.
        max_retries = config.MAX_RETRIES
        delay = config.INITIAL_RETRY_DELAY
        for attempt in range(max_retries + 1):
            try:
                response = self.llm_client_with_tools.invoke(state["messages"])
                break  # Success - break out of retry loop
            except Exception as e:
                error_msg = str(e)
                # Only 504 DEADLINE_EXCEEDED errors are considered transient.
                if "504" in error_msg and "DEADLINE_EXCEEDED" in error_msg:
                    if attempt < max_retries:
                        print(f"[RETRY] Attempt {attempt + 1}/{max_retries} failed with 504 DEADLINE_EXCEEDED")
                        print(f"[RETRY] Retrying in {delay:.1f} seconds...")
                        time.sleep(delay)
                        delay *= config.RETRY_BACKOFF_FACTOR  # Exponential backoff
                        continue
                    else:
                        print(f"[RETRY] All {max_retries} retries exhausted for 504 error")
                        print(f"[ERROR] LLM invocation failed after retries: {e}")
                        return {
                            "messages": [],
                            "answer": f"Error: LLM failed after {max_retries} retries - {str(e)[:100]}",
                            "step_count": current_step
                        }
                else:
                    # Not a 504 error - fail immediately without retry.
                    print(f"[ERROR] LLM invocation failed: {e}")
                    return {
                        "messages": [],
                        "answer": f"Error: LLM failed - {str(e)[:100]}",
                        "step_count": current_step
                    }
        # If no tool calls, the model's content IS the final answer.
        if not response.tool_calls:
            print(f"[FINAL ANSWER] Agent produced answer (no tool calls)")
            # Handle list/dict/object content shapes (e.g. mixed Gemini parts).
            content = self._extract_content_text(response.content)
            print(f"[EXTRACTED TEXT] {content[:100]}{'...' if len(content) > 100 else ''}")
            return {
                "messages": [response],
                "answer": content,
                "step_count": current_step
            }
        # Has tool calls, log them and loop back through the tools node.
        print(f"[TOOL CALLS] Agent requesting {len(response.tool_calls)} tool(s):")
        for tc in response.tool_calls:
            print(f" - {tc['name']}")
        return {
            "messages": [response],
            "step_count": current_step
        }

    def _should_continue(self, state: AgentState):
        """Conditional edge: stop at the step cap, else defer to tools_condition."""
        step_count = state.get("step_count", 0)
        # Stop if we've exceeded maximum steps.
        if step_count >= self.MAX_STEPS:
            print(f"[WARNING] Max steps ({self.MAX_STEPS}) reached, forcing termination")
            # Force a final answer if we don't have one.
            # NOTE(review): mutating state inside a conditional edge likely does
            # not persist in LangGraph (edges return a route, not a state
            # update) — __call__ already handles a missing answer; confirm.
            if not state.get("answer"):
                state["answer"] = "Error: Maximum iteration limit reached"
            return END
        # Otherwise use the default tools_condition (routes to "tools" or END).
        return tools_condition(state)

    def _build_graph(self):
        """Build and return the Compiled Graph for the agent."""
        graph = StateGraph(AgentState)
        # Build graph: START -> init -> assistant <-> tools, END via condition.
        graph.add_node("init", self._init_questions)
        graph.add_node("assistant", self._assistant)
        graph.add_node("tools", ToolNode(self.tools))
        graph.add_edge(START, "init")
        graph.add_edge("init", "assistant")
        graph.add_conditional_edges(
            "assistant",
            # Use custom should_continue instead of raw tools_condition so the
            # step cap is enforced.
            self._should_continue,
        )
        graph.add_edge("tools", "assistant")
        # Compile graph
        return graph.compile()

    @track_agent_execution("LangGraph")
    def __call__(self, question: str, file_name: Optional[str] = None) -> str:
        """Invoke the agent graph with the given question and return the final answer.

        Args:
            question: The question to answer.
            file_name: Optional file name if the question references a file.

        Returns:
            The cleaned final answer, or an "Error: ..." string on failure.
        """
        print(f"\n{'='*60}")
        print(f"[LANGGRAPH AGENT START] Question: {question}")
        if file_name:
            print(f"[FILE] {file_name}")
        print(f"{'='*60}")
        start_time = time.time()
        try:
            response = self.graph.invoke(
                {"question": question, "messages": [], "answer": None, "step_count": 0, "file_name": file_name or ""},
                # Must be >= 2x step limit: each step visits assistant + tools.
                config={"recursion_limit": 2 * self.MAX_STEPS}
            )
            elapsed_time = time.time() - start_time
            print(f"[LANGGRAPH AGENT COMPLETE] Time: {elapsed_time:.2f}s")
            print(f"{'='*60}\n")
            answer = response.get("answer")
            # Covers both None and empty-string answers (the previous
            # "or answer is None" clause was redundant).
            if not answer:
                print("[WARNING] Agent completed but returned None as answer")
                return "Error: No answer generated"
            # Use utility function to extract text from various content formats.
            answer = extract_text_from_content(answer)
            # Clean up the answer using utility function (includes stripping).
            answer = cleanup_answer(answer)
            print(f"[FINAL ANSWER] {answer}")
            return answer
        except Exception as e:
            elapsed_time = time.time() - start_time
            print(f"[LANGGRAPH AGENT ERROR] Failed after {elapsed_time:.2f}s: {e}")
            print(f"{'='*60}\n")
            return f"Error: {str(e)[:100]}"
|