Spaces:
Sleeping
Sleeping
| """ | |
| AI Agent for project management using LangGraph. | |
| """ | |
| from typing import TypedDict, Annotated, Sequence, List, Dict, Any | |
| import operator | |
| import os | |
| from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage | |
| from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint | |
| from langchain_google_genai import ChatGoogleGenerativeAI | |
| from langgraph.graph import StateGraph, END | |
| from src.rag import ProjectRAG | |
| class AgentState(TypedDict): | |
| """State for the agent.""" | |
| messages: Annotated[Sequence[BaseMessage], operator.add] | |
| query: str | |
| retrieved_context: List[Dict[str, Any]] | |
| action_items: List[Dict[str, Any]] | |
| blockers: List[Dict[str, Any]] | |
| next_step: str | |
| final_answer: str | |
| class ProjectAgent: | |
| """AI Agent for project management queries.""" | |
| def __init__(self, rag: ProjectRAG, provider: str = "huggingface", model_name: str = None): | |
| """Initialize the agent. | |
| Args: | |
| rag: ProjectRAG instance for retrieval | |
| provider: "huggingface" (free) or "google" (paid) | |
| model_name: Optional model name override | |
| """ | |
| self.rag = rag | |
| self.provider = provider | |
| if provider == "google": | |
| # Use Google Gemini API (paid) | |
| google_api_key = os.getenv("GOOGLE_API_KEY") | |
| if not google_api_key: | |
| raise ValueError("GOOGLE_API_KEY environment variable not set") | |
| self.llm = ChatGoogleGenerativeAI( | |
| model=model_name or "gemini-2.5-flash-lite", | |
| temperature=0.1, | |
| google_api_key=google_api_key, | |
| timeout=60, # 60 second timeout | |
| convert_system_message_to_human=True # Better compatibility | |
| ) | |
| else: | |
| # Use HF Inference API (free tier) | |
| hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGING_FACE_HUB_TOKEN") | |
| if not hf_token: | |
| raise ValueError("HF_TOKEN environment variable not set") | |
| llm = HuggingFaceEndpoint( | |
| repo_id=model_name or "meta-llama/Llama-3.2-3B-Instruct", | |
| temperature=0.1, | |
| max_new_tokens=512, | |
| huggingfacehub_api_token=hf_token, | |
| timeout=60 # 60 second timeout to prevent hanging | |
| ) | |
| self.llm = ChatHuggingFace(llm=llm) | |
| self.graph = self._build_graph() | |
| def _build_graph(self) -> StateGraph: | |
| """Build the agent's state graph.""" | |
| workflow = StateGraph(AgentState) | |
| # Add nodes | |
| workflow.add_node("analyze_query", self.analyze_query) | |
| workflow.add_node("retrieve_context", self.retrieve_context) | |
| workflow.add_node("get_action_items", self.get_action_items) | |
| workflow.add_node("get_blockers", self.get_blockers) | |
| workflow.add_node("generate_answer", self.generate_answer) | |
| # Add edges | |
| workflow.set_entry_point("analyze_query") | |
| workflow.add_edge("analyze_query", "retrieve_context") | |
| workflow.add_conditional_edges( | |
| "retrieve_context", | |
| self.route_after_retrieval, | |
| { | |
| "action_items": "get_action_items", | |
| "blockers": "get_blockers", | |
| "answer": "generate_answer" | |
| } | |
| ) | |
| workflow.add_edge("get_action_items", "generate_answer") | |
| workflow.add_edge("get_blockers", "generate_answer") | |
| workflow.add_edge("generate_answer", END) | |
| return workflow.compile() | |
| def analyze_query(self, state: AgentState) -> AgentState: | |
| """Analyze the user's query to understand intent.""" | |
| query = state["query"] | |
| system_prompt = """You are a query analyzer for a project management assistant. | |
| Analyze queries and determine what information is being requested.""" | |
| analysis_prompt = f"""Analyze this query and determine: | |
| 1. What information is being requested? | |
| 2. Which project (if specified)? | |
| 3. What type of query is this (action items, blockers, status, decisions, general)? | |
| Query: {query} | |
| Respond in this format: | |
| Type: [action_items|blockers|status|decisions|general] | |
| Project: [project name or "all"] | |
| Intent: [brief description] | |
| """ | |
| messages = [ | |
| SystemMessage(content=system_prompt), | |
| HumanMessage(content=analysis_prompt) | |
| ] | |
| response = self.llm.invoke(messages) | |
| state["messages"] = state.get("messages", []) + [ | |
| HumanMessage(content=query), | |
| AIMessage(content=f"Analysis: {response.content}") | |
| ] | |
| return state | |
| def retrieve_context(self, state: AgentState) -> AgentState: | |
| """Retrieve relevant context from the RAG system.""" | |
| query = state["query"] | |
| # Extract project name if mentioned | |
| project_filter = None | |
| projects = self.rag.get_all_projects() | |
| for project in projects: | |
| if project.lower() in query.lower(): | |
| project_filter = project | |
| break | |
| # Search for relevant context | |
| results = self.rag.search(query, n_results=5, project_filter=project_filter) | |
| state["retrieved_context"] = results | |
| return state | |
| def route_after_retrieval(self, state: AgentState) -> str: | |
| """Route to appropriate node based on query type.""" | |
| query = state["query"].lower() | |
| if any(term in query for term in ["action item", "todo", "task", "what's next", "what should"]): | |
| return "action_items" | |
| elif any(term in query for term in ["blocker", "issue", "problem", "stuck"]): | |
| return "blockers" | |
| else: | |
| return "answer" | |
| def get_action_items(self, state: AgentState) -> AgentState: | |
| """Get action items from the RAG system.""" | |
| query = state["query"].lower() | |
| # Extract project name if mentioned | |
| project_filter = None | |
| projects = self.rag.get_all_projects() | |
| for project in projects: | |
| if project.lower() in query: | |
| project_filter = project | |
| break | |
| action_items = self.rag.get_open_action_items(project=project_filter) | |
| state["action_items"] = action_items | |
| return state | |
| def get_blockers(self, state: AgentState) -> AgentState: | |
| """Get blockers from the RAG system.""" | |
| query = state["query"].lower() | |
| # Extract project name if mentioned | |
| project_filter = None | |
| projects = self.rag.get_all_projects() | |
| for project in projects: | |
| if project.lower() in query: | |
| project_filter = project | |
| break | |
| blockers = self.rag.get_blockers(project=project_filter) | |
| state["blockers"] = blockers | |
| return state | |
| def generate_answer(self, state: AgentState) -> AgentState: | |
| """Generate the final answer using retrieved context.""" | |
| query = state["query"] | |
| context = state.get("retrieved_context", []) | |
| action_items = state.get("action_items", []) | |
| blockers = state.get("blockers", []) | |
| # Build context string | |
| context_parts = [] | |
| if context: | |
| context_parts.append("Relevant meeting context:") | |
| for i, result in enumerate(context[:3], 1): | |
| context_parts.append(f"\n[Context {i}]") | |
| context_parts.append(result['content']) | |
| if 'metadata' in result: | |
| meta = result['metadata'] | |
| context_parts.append(f"(From: {meta.get('project', 'Unknown')} - {meta.get('title', 'Unknown')})") | |
| if action_items: | |
| context_parts.append("\nOpen Action Items:") | |
| for item in action_items: | |
| assignee = f" ({item['assignee']})" if item.get('assignee') else "" | |
| deadline = f" by {item['deadline']}" if item.get('deadline') else "" | |
| context_parts.append(f"- {item['task']}{assignee}{deadline}") | |
| if blockers: | |
| context_parts.append("\nCurrent Blockers:") | |
| for blocker in blockers: | |
| context_parts.append(f"- {blocker['blocker']}") | |
| context_str = "\n".join(context_parts) | |
| # Generate answer | |
| system_prompt = """You are a helpful AI assistant that helps users manage their projects. | |
| Use the provided context to answer the user's question accurately and concisely. | |
| Format your response using bullet points for clarity. | |
| For action items, list the task with the assignee in parentheses at the end. | |
| For blockers and risks, list them directly without project names. | |
| Keep responses brief and to the point. Avoid lengthy explanations. | |
| Example format: | |
| ## Next Actions | |
| - Task description (Assignee) by deadline | |
| - Another task (Assignee) | |
| ## Blockers/Risks | |
| - Blocker description | |
| - Another blocker""" | |
| messages = [ | |
| SystemMessage(content=system_prompt), | |
| HumanMessage(content=f"Context:\n{context_str}\n\nQuestion: {query}\n\nAnswer:") | |
| ] | |
| response = self.llm.invoke(messages) | |
| state["final_answer"] = response.content | |
| return state | |
| def query(self, user_query: str) -> str: | |
| """Process a user query and return an answer.""" | |
| initial_state = { | |
| "messages": [], | |
| "query": user_query, | |
| "retrieved_context": [], | |
| "action_items": [], | |
| "blockers": [], | |
| "next_step": "", | |
| "final_answer": "" | |
| } | |
| result = self.graph.invoke(initial_state) | |
| return result["final_answer"] | |
| def stream_query(self, user_query: str): | |
| """Process a user query and stream the answer token by token.""" | |
| # First run analysis and retrieval (non-streaming) | |
| initial_state = { | |
| "messages": [], | |
| "query": user_query, | |
| "retrieved_context": [], | |
| "action_items": [], | |
| "blockers": [], | |
| "next_step": "", | |
| "final_answer": "" | |
| } | |
| # Run through analysis and retrieval nodes | |
| state = self.analyze_query(initial_state) | |
| state = self.retrieve_context(state) | |
| # Determine route and get additional data | |
| route = self.route_after_retrieval(state) | |
| if route == "action_items": | |
| state = self.get_action_items(state) | |
| elif route == "blockers": | |
| state = self.get_blockers(state) | |
| # Now stream the final answer generation | |
| query = state["query"] | |
| context = state.get("retrieved_context", []) | |
| action_items = state.get("action_items", []) | |
| blockers = state.get("blockers", []) | |
| # Build context string | |
| context_parts = [] | |
| if context: | |
| context_parts.append("Relevant meeting context:") | |
| for i, result in enumerate(context[:3], 1): | |
| context_parts.append(f"\n[Context {i}]") | |
| context_parts.append(result['content']) | |
| if 'metadata' in result: | |
| meta = result['metadata'] | |
| context_parts.append(f"(From: {meta.get('project', 'Unknown')} - {meta.get('title', 'Unknown')})") | |
| if action_items: | |
| context_parts.append("\nOpen Action Items:") | |
| for item in action_items: | |
| assignee = f" ({item['assignee']})" if item.get('assignee') else "" | |
| deadline = f" by {item['deadline']}" if item.get('deadline') else "" | |
| context_parts.append(f"- {item['task']}{assignee}{deadline}") | |
| if blockers: | |
| context_parts.append("\nCurrent Blockers:") | |
| for blocker in blockers: | |
| context_parts.append(f"- {blocker['blocker']}") | |
| context_str = "\n".join(context_parts) | |
| # Generate streaming answer | |
| system_prompt = """You are a helpful AI assistant that helps users manage their projects. | |
| Use the provided context to answer the user's question accurately and concisely. | |
| Format your response using bullet points for clarity. | |
| For action items, list the task with the assignee in parentheses at the end. | |
| For blockers and risks, list them directly without project names. | |
| Keep responses brief and to the point. Avoid lengthy explanations. | |
| Example format: | |
| ## Next Actions | |
| - Task description (Assignee) by deadline | |
| - Another task (Assignee) | |
| ## Blockers/Risks | |
| - Blocker description | |
| - Another blocker""" | |
| messages = [ | |
| SystemMessage(content=system_prompt), | |
| HumanMessage(content=f"Context:\n{context_str}\n\nQuestion: {query}\n\nAnswer:") | |
| ] | |
| # Stream tokens | |
| full_response = "" | |
| try: | |
| for chunk in self.llm.stream(messages): | |
| if hasattr(chunk, 'content') and chunk.content: | |
| full_response += chunk.content | |
| yield full_response | |
| except Exception: | |
| # Fallback to non-streaming if streaming not supported | |
| response = self.llm.invoke(messages) | |
| yield response.content |