"""
AI Agent for project management using LangGraph.
"""

import operator
import os
from typing import TypedDict, Annotated, Sequence, List, Dict, Any, Optional

from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from langgraph.graph import StateGraph, END

from src.rag import ProjectRAG


class AgentState(TypedDict):
    """Shared state passed between the nodes of the project agent graph.

    Every key is initialised by ``ProjectAgent.query`` before the graph runs.
    """

    # Conversation log. The ``operator.add`` reducer tells LangGraph to
    # concatenate a node's returned list onto the existing value instead of
    # overwriting it, so a node should return only the messages it adds.
    messages: Annotated[Sequence[BaseMessage], operator.add]

    # Raw user question; read by every node.
    query: str

    # Search hits stored by the ``retrieve_context`` node.
    retrieved_context: List[Dict[str, Any]]

    # Open action items stored by the ``get_action_items`` node.
    action_items: List[Dict[str, Any]]

    # Blockers stored by the ``get_blockers`` node.
    blockers: List[Dict[str, Any]]

    # Initialised to "" by ``ProjectAgent.query``; no node in this file reads
    # or updates it.
    next_step: str

    # Text of the LLM response stored by the ``generate_answer`` node.
    final_answer: str


class ProjectAgent:
    """AI agent that answers project-management questions via a LangGraph graph.

    Graph layout (see ``_build_graph``)::

        analyze_query -> retrieve_context -+-> get_action_items -+-> generate_answer -> END
                                           +-> get_blockers -----+
                                           +---------------------+

    Node methods return *partial* state updates (only the keys they change).
    This matters for ``messages``: that key is declared with an
    ``operator.add`` reducer, so returning the whole state from a node would
    append the existing history to itself on every step.
    """

    # Substrings looked up in the lower-cased query by ``route_after_retrieval``.
    # Action-item terms are checked first and therefore win on ties.
    _ACTION_ITEM_TERMS = ("action item", "todo", "task", "what's next", "what should")
    _BLOCKER_TERMS = ("blocker", "issue", "problem", "stuck")

    # Number of hits requested from the RAG search, and how many of them are
    # actually quoted in the answer prompt.
    _SEARCH_RESULTS = 5
    _MAX_CONTEXT_SNIPPETS = 3

    _ANALYSIS_SYSTEM_PROMPT = (
        "You are a query analyzer for a project management assistant.\n"
        "Analyze queries and determine what information is being requested."
    )

    # ``{query}`` is the only replacement field; filled with ``str.format``.
    _ANALYSIS_PROMPT_TEMPLATE = (
        "Analyze this query and determine:\n"
        "1. What information is being requested?\n"
        "2. Which project (if specified)?\n"
        "3. What type of query is this (action items, blockers, status, decisions, general)?\n"
        "\n"
        "Query: {query}\n"
        "\n"
        "Respond in this format:\n"
        "Type: [action_items|blockers|status|decisions|general]\n"
        'Project: [project name or "all"]\n'
        "Intent: [brief description]\n"
    )

    _ANSWER_SYSTEM_PROMPT = (
        "You are a helpful AI assistant that helps users manage their projects.\n"
        "Use the provided context to answer the user's question accurately and concisely.\n"
        "Format your response using bullet points for clarity.\n"
        "For action items, list the task with the assignee in parentheses at the end.\n"
        "For blockers and risks, list them directly without project names.\n"
        "Keep responses brief and to the point. Avoid lengthy explanations.\n"
        "Example format:\n"
        "## Next Actions\n"
        "- Task description (Assignee) by deadline\n"
        "- Another task (Assignee)\n"
        "\n"
        "## Blockers/Risks\n"
        "- Blocker description\n"
        "- Another blocker"
    )

    def __init__(self, rag: ProjectRAG, model_name: str = "meta-llama/Llama-3.2-3B-Instruct"):
        """Initialize the agent.

        Args:
            rag: Retrieval backend. The agent calls ``get_all_projects``,
                ``search``, ``get_open_action_items`` and ``get_blockers`` on it.
            model_name: Hugging Face repo id of the chat model to use.
        """
        self.rag = rag

        # The API token is read from the HF_TOKEN environment variable; when it
        # is unset, ``None`` is passed through to the endpoint unchanged.
        llm = HuggingFaceEndpoint(
            repo_id=model_name,
            temperature=0.1,
            max_new_tokens=512,
            huggingfacehub_api_token=os.getenv("HF_TOKEN"),
        )
        self.llm = ChatHuggingFace(llm=llm)
        self.graph = self._build_graph()

    def _build_graph(self) -> Any:
        """Build and compile the agent's state graph.

        Returns:
            The result of ``StateGraph.compile()`` (a runnable graph), not the
            ``StateGraph`` builder itself.
        """
        workflow = StateGraph(AgentState)

        # Nodes
        workflow.add_node("analyze_query", self.analyze_query)
        workflow.add_node("retrieve_context", self.retrieve_context)
        workflow.add_node("get_action_items", self.get_action_items)
        workflow.add_node("get_blockers", self.get_blockers)
        workflow.add_node("generate_answer", self.generate_answer)

        # Edges
        workflow.set_entry_point("analyze_query")
        workflow.add_edge("analyze_query", "retrieve_context")
        workflow.add_conditional_edges(
            "retrieve_context",
            self.route_after_retrieval,
            {
                "action_items": "get_action_items",
                "blockers": "get_blockers",
                "answer": "generate_answer",
            },
        )
        workflow.add_edge("get_action_items", "generate_answer")
        workflow.add_edge("get_blockers", "generate_answer")
        workflow.add_edge("generate_answer", END)

        return workflow.compile()

    def _detect_project(self, query: str) -> Optional[str]:
        """Return the first known project whose name occurs in ``query``.

        Matching is a case-insensitive substring test against the names
        returned by ``self.rag.get_all_projects()``, in the order returned.
        Empty or ``None`` names are skipped, since an empty string would
        otherwise match every query.

        Returns:
            The project name as stored in the RAG backend, or ``None`` when no
            project is mentioned.
        """
        lowered = query.lower()
        for project in self.rag.get_all_projects():
            if project and project.lower() in lowered:
                return project
        return None

    def analyze_query(self, state: AgentState) -> Dict[str, Any]:
        """Ask the LLM to classify the query and log the exchange.

        The analysis is only recorded in ``messages``; routing is done
        separately by keyword matching in ``route_after_retrieval``.

        Returns:
            ``{"messages": [...]}`` containing only the two new messages (the
            user query and the LLM analysis). The graph's ``operator.add``
            reducer appends them to the existing history.
        """
        query = state["query"]

        response = self.llm.invoke(
            [
                SystemMessage(content=self._ANALYSIS_SYSTEM_PROMPT),
                HumanMessage(content=self._ANALYSIS_PROMPT_TEMPLATE.format(query=query)),
            ]
        )

        return {
            "messages": [
                HumanMessage(content=query),
                AIMessage(content=f"Analysis: {response.content}"),
            ]
        }

    def retrieve_context(self, state: AgentState) -> Dict[str, Any]:
        """Retrieve relevant context from the RAG system.

        The search is restricted to a single project when the query mentions
        one (see ``_detect_project``).

        Returns:
            ``{"retrieved_context": results}`` with the raw search hits.
        """
        query = state["query"]
        results = self.rag.search(
            query,
            n_results=self._SEARCH_RESULTS,
            project_filter=self._detect_project(query),
        )
        return {"retrieved_context": results}

    def route_after_retrieval(self, state: AgentState) -> str:
        """Pick the branch to follow after retrieval, based on query keywords.

        Returns:
            ``"action_items"``, ``"blockers"`` or ``"answer"``; these are the
            keys of the conditional-edge mapping in ``_build_graph``.
        """
        query = state["query"].lower()

        if any(term in query for term in self._ACTION_ITEM_TERMS):
            return "action_items"
        if any(term in query for term in self._BLOCKER_TERMS):
            return "blockers"
        return "answer"

    def get_action_items(self, state: AgentState) -> Dict[str, Any]:
        """Get open action items from the RAG system.

        Returns:
            ``{"action_items": items}``, limited to the project mentioned in
            the query when there is one.
        """
        project = self._detect_project(state["query"])
        return {"action_items": self.rag.get_open_action_items(project=project)}

    def get_blockers(self, state: AgentState) -> Dict[str, Any]:
        """Get blockers from the RAG system.

        Returns:
            ``{"blockers": blockers}``, limited to the project mentioned in
            the query when there is one.
        """
        project = self._detect_project(state["query"])
        return {"blockers": self.rag.get_blockers(project=project)}

    @staticmethod
    def _format_context(
        context: List[Dict[str, Any]],
        action_items: List[Dict[str, Any]],
        blockers: List[Dict[str, Any]],
        max_snippets: int = 3,
    ) -> str:
        """Render retrieved data as the plain-text context block of the prompt.

        Args:
            context: Search hits; each must have a ``"content"`` key and may
                have a ``"metadata"`` dict (``"project"`` / ``"title"``).
            action_items: Items with a ``"task"`` key and optional
                ``"assignee"`` / ``"deadline"``.
            blockers: Items with a ``"blocker"`` key.
            max_snippets: Only the first ``max_snippets`` hits are included.

        Returns:
            The sections joined by newlines; an empty string when all three
            inputs are empty.
        """
        parts: List[str] = []

        if context:
            parts.append("Relevant meeting context:")
            for index, result in enumerate(context[:max_snippets], 1):
                parts.append(f"\n[Context {index}]")
                parts.append(result["content"])
                if "metadata" in result:
                    meta = result["metadata"]
                    parts.append(
                        f"(From: {meta.get('project', 'Unknown')} - {meta.get('title', 'Unknown')})"
                    )

        if action_items:
            parts.append("\nOpen Action Items:")
            for item in action_items:
                assignee = f" ({item['assignee']})" if item.get("assignee") else ""
                deadline = f" by {item['deadline']}" if item.get("deadline") else ""
                parts.append(f"- {item['task']}{assignee}{deadline}")

        if blockers:
            parts.append("\nCurrent Blockers:")
            parts.extend(f"- {blocker['blocker']}" for blocker in blockers)

        return "\n".join(parts)

    def generate_answer(self, state: AgentState) -> Dict[str, Any]:
        """Generate the final answer using the retrieved context.

        Returns:
            ``{"final_answer": text}`` where ``text`` is the content of the
            LLM response.
        """
        query = state["query"]
        context_str = self._format_context(
            state.get("retrieved_context", []),
            state.get("action_items", []),
            state.get("blockers", []),
            max_snippets=self._MAX_CONTEXT_SNIPPETS,
        )

        messages = [
            SystemMessage(content=self._ANSWER_SYSTEM_PROMPT),
            HumanMessage(content=f"Context:\n{context_str}\n\nQuestion: {query}\n\nAnswer:"),
        ]

        response = self.llm.invoke(messages)
        return {"final_answer": response.content}

    def query(self, user_query: str) -> str:
        """Process a user query and return an answer.

        Args:
            user_query: The question to answer.

        Returns:
            The ``final_answer`` produced by the graph run.
        """
        initial_state = {
            "messages": [],
            "query": user_query,
            "retrieved_context": [],
            "action_items": [],
            "blockers": [],
            "next_step": "",
            "final_answer": "",
        }

        result = self.graph.invoke(initial_state)
        return result["final_answer"]