Spaces:
Sleeping
Sleeping
| from openai import OpenAI | |
| from typing import List, Dict, Any, Optional, Tuple | |
| import sys | |
| import os | |
| import traceback | |
| # Add the project root to the path to ensure imports work | |
| sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) | |
| # Import configuration | |
| from src.utils.config import CHAT_MODEL, OPENAI_API_KEY | |
| # Import other modules needed for the agents | |
| from src.models.retriever import Retriever | |
class QueryAnalyzer:
    """
    Agent that asks the chat model for an analysis of the user's query.

    The raw model output is returned together with a small markdown
    "structured analysis" block built by `_extract_structured_analysis`.
    """

    def __init__(self, model: str = CHAT_MODEL):
        """Store the model name and create the OpenAI client."""
        self.model = model
        self.client = OpenAI(api_key=OPENAI_API_KEY)

    def analyze_query(self, query: str) -> Dict[str, Any]:
        """
        Send the query to the chat model and package its analysis.

        Args:
            query: The user's query

        Returns:
            Dictionary with the keys "original_query", "analysis" and
            "structured_analysis". If anything fails, "analysis" holds an
            "Error: ..." message and "structured_analysis" is an empty string.
        """
        instructions = (
            "You are a legal query analyzer. Your task is to analyze the user's query to understand:"
            "\n1. The legal domain and specific legal concepts involved"
            "\n2. What type of legal advice or information they are seeking"
            "\n3. Key entities and relationships relevant to their question"
            "\n4. Any ambiguities that might need clarification"
            "\n\nProvide your analysis in a structured format that our legal research system can use to retrieve relevant information."
        )
        conversation = [
            {"role": "system", "content": instructions},
            {"role": "user", "content": query},
        ]

        try:
            completion = self.client.chat.completions.create(
                model=self.model,
                messages=conversation,
                temperature=0.3,
            )
            raw_analysis = completion.choices[0].message.content.strip()
            return {
                "original_query": query,
                "analysis": raw_analysis,
                "structured_analysis": self._extract_structured_analysis(raw_analysis, query),
            }
        except Exception as err:
            # Best effort: report the failure inside the result instead of raising.
            print(f"Error analyzing query: {err}")
            return {
                "original_query": query,
                "analysis": "Error: " + str(err),
                "structured_analysis": "",
            }

    def _extract_structured_analysis(self, analysis: str, query: str) -> str:
        """
        Build the markdown "structured analysis" block.

        Demo implementation: the domain and key concepts are fixed text and
        `analysis` is not inspected; only `query` is interpolated.
        """
        return (
            "## Query Analysis\n\n"
            "- **Domain**: Legal defense\n"
            f"- **Original Query**: {query}\n"
            "- **Key Concepts**: Legal defense, legal arguments\n"
        )
class ContextAggregator:
    """
    Agent responsible for aggregating and organizing retrieved document chunks.

    Batches of at most ten chunks are laid out verbatim; larger batches are
    grouped by source and each group is summarized by the chat model first.
    """

    def __init__(self, model: str = CHAT_MODEL):
        """Initialize the context aggregator."""
        self.model = model
        self.client = OpenAI(api_key=OPENAI_API_KEY)

    @staticmethod
    def _chunk_text(chunk: Any) -> str:
        """
        Return the text payload of a chunk.

        Looks up 'text', then 'chunk', then falls back to a placeholder.
        A value of None is treated as missing, non-string values are
        stringified, and a chunk that is not a dict is stringified whole, so
        every code path tolerates the same chunk formats.
        """
        if not isinstance(chunk, dict):
            return str(chunk)
        for key in ('text', 'chunk'):
            value = chunk.get(key)
            if value is not None:
                return value if isinstance(value, str) else str(value)
        return "No content available"

    @staticmethod
    def _chunk_source(chunk: Any) -> Any:
        """Return the chunk's 'source', or 'unknown' when it has none."""
        if isinstance(chunk, dict):
            return chunk.get('source', 'unknown')
        return 'unknown'

    def aggregate_context(self, query: str, retrieved_chunks: List[Dict[str, Any]]) -> str:
        """
        Aggregate retrieved chunks into a coherent context.

        Args:
            query: The user's query
            retrieved_chunks: List of retrieved document chunks

        Returns:
            String containing the organized context
        """
        # If small number of chunks, just organize them without summarizing
        if len(retrieved_chunks) <= 10:
            chunk_contents = [
                {
                    'source': self._chunk_source(chunk),
                    'content': self._chunk_text(chunk),
                    'is_summary': False,
                }
                for chunk in retrieved_chunks
            ]
            return self._organize_content(query, chunk_contents)

        # Group chunks by source (dicts keep first-seen order of sources)
        sources: Dict[Any, List[Any]] = {}
        for chunk in retrieved_chunks:
            sources.setdefault(self._chunk_source(chunk), []).append(chunk)

        # One model-written summary per source
        summaries = [
            self._summarize_chunks(source, chunks, query)
            for source, chunks in sources.items()
        ]
        return self._organize_content(query, summaries)

    def _summarize_chunks(self, source: str, chunks: List[Dict[str, Any]], query: str) -> Dict[str, Any]:
        """
        Summarize a group of chunks from the same source.

        Args:
            source: Label of the source the chunks came from
            chunks: The chunks to summarize
            query: The user's query, given to the model for focus

        Returns:
            Dictionary with 'source', 'content', 'is_summary' (always True)
            and 'num_chunks'. If the model call fails, 'content' holds an
            "Error summarizing content: ..." message instead of a summary.
        """
        # _chunk_text never raises on odd chunk shapes, so no fallback loop is needed
        chunks_text = "\n\n".join(self._chunk_text(chunk) for chunk in chunks)

        # Create a prompt for summarization
        system_prompt = (
            "You are a legal document summarizer. Your task is to summarize the provided legal document excerpts "
            "in a way that addresses the user's query. Focus on extracting key information, legal principles, "
            "and arguments relevant to the query while maintaining factual accuracy."
        )

        # Get summary from the LLM
        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": f"Query: {query}\n\nDocument excerpts from {source}:\n\n{chunks_text}"}
                ],
                temperature=0.3
            )
            summary = response.choices[0].message.content.strip()
            return {
                'source': source,
                'content': summary,
                'is_summary': True,
                'num_chunks': len(chunks)
            }
        except Exception as e:
            # Best effort: keep the source in the output with an error note
            print(f"Error summarizing chunks from {source}: {e}")
            return {
                'source': source,
                'content': f"Error summarizing content: {str(e)}",
                'is_summary': True,
                'num_chunks': len(chunks)
            }

    def _organize_content(self, query: str, contents: List[Dict[str, Any]]) -> str:
        """
        Organize content items into a markdown document.

        Items flagged 'is_summary' are listed first under "Summaries of Key
        Sources"; all other items follow under "Additional Relevant Details".
        A section is omitted when it has no items.
        """
        parts = [f"# Relevant Legal Context for: {query}\n\n"]

        # Add summaries first
        summaries = [item for item in contents if item.get('is_summary', False)]
        if summaries:
            parts.append("## Summaries of Key Sources\n\n")
            for summary in summaries:
                parts.append(f"### {summary['source']}\n")
                parts.append(f"{summary['content']}\n\n")

        # Add individual chunks
        individual_chunks = [item for item in contents if not item.get('is_summary', False)]
        if individual_chunks:
            parts.append("## Additional Relevant Details\n\n")
            for chunk in individual_chunks:
                parts.append(f"### From {chunk['source']}\n")
                parts.append(f"{chunk['content']}\n\n")

        return "".join(parts)
class AnswerGenerator:
    """
    Agent that turns a query plus organized context into a final answer
    by calling the chat model.
    """

    def __init__(self, model: str = CHAT_MODEL):
        """Store the model name and create the OpenAI client."""
        self.model = model
        self.client = OpenAI(api_key=OPENAI_API_KEY)

    def generate_answer(self, query: str, context: str) -> str:
        """
        Ask the chat model to answer the query from the supplied context.

        Args:
            query: The user's query
            context: The organized context

        Returns:
            The model's answer with surrounding whitespace stripped, or an
            "Error generating answer: ..." string if the call fails (the
            exception is not propagated).
        """
        guidelines = (
            "You are a legal expert specialized in providing accurate, comprehensive legal analyses based on provided sources. "
            "When answering questions, follow these guidelines:\n"
            "1. Base your answers exclusively on the information provided in the context, without adding external knowledge\n"
            "2. If the context doesn't contain sufficient information to answer confidently, acknowledge the limitations\n"
            "3. Be precise about legal concepts, principles, and precedents mentioned in the sources\n"
            "4. Structure your answer clearly with appropriate headings and sections\n"
            "5. Maintain objectivity and present multiple perspectives when appropriate\n"
            "6. Cite specific sources when referring to key information or arguments"
        )
        conversation = [
            {"role": "system", "content": guidelines},
            {"role": "user", "content": f"Question: {query}\n\nContext:\n\n{context}"},
        ]

        try:
            completion = self.client.chat.completions.create(
                model=self.model,
                messages=conversation,
                temperature=0.3,
            )
            return completion.choices[0].message.content.strip()
        except Exception as err:
            # Failures are reported through the return value, not raised.
            print(f"Error generating answer: {err}")
            return "Error generating answer: " + str(err)
class AgentDirector:
    """
    Director that coordinates the various specialized agents to process a query.

    The pipeline is: analyze the query, retrieve chunks, aggregate them into
    a context, generate an answer. Each stage has its own fallback, and any
    error that escapes a stage hands the whole query to `LegalAgent`.
    """

    def __init__(self, model: Optional[str] = None, top_k: int = 200, debug: bool = False):
        """
        Initialize the agent director.

        Args:
            model: The OpenAI chat model to use; CHAT_MODEL when None
            top_k: Number of chunks to retrieve
            debug: Whether to show detailed reasoning steps
        """
        # Ensure model is not None, default to CHAT_MODEL if not provided
        self.model = model if model is not None else CHAT_MODEL
        self.top_k = top_k
        self.debug = debug
        self.retriever = Retriever(top_k=top_k)
        self.query_analyzer = QueryAnalyzer(model=self.model)
        self.context_aggregator = ContextAggregator(model=self.model)
        self.answer_generator = AnswerGenerator(model=self.model)
        # LegalAgent will be imported on demand

    def _debug_print(self, message):
        """Print debug message if debug mode is enabled."""
        if self.debug:
            print(f"\n🧠 AGENT THINKING: {message}")

    @staticmethod
    def _preview(text: str, limit: int) -> str:
        """Return `text` cut to `limit` characters, adding "..." only when something was cut."""
        return text[:limit] + "..." if len(text) > limit else text

    @staticmethod
    def _record_step(results: Dict[str, Any], stage: str, reasoning: str) -> None:
        """Append one reasoning entry to results["reasoning_steps"] (debug mode only)."""
        results["reasoning_steps"].append({
            "stage": stage,
            "reasoning": reasoning
        })

    def process_query(self, query: str) -> Dict[str, Any]:
        """
        Process a user query through the agent pipeline.

        Args:
            query: The user's query

        Returns:
            Dictionary containing the results and intermediate steps. It
            always has "original_query", "model_used", "reasoning_steps"
            (a list in debug mode, otherwise None) and "answer"; "error" is
            added when the pipeline had to fall back after an exception.
        """
        results: Dict[str, Any] = {
            "original_query": query,
            "model_used": self.model,
            "reasoning_steps": [] if self.debug else None
        }
        try:
            # Step 1: Analyze the query
            self._debug_print("Analyzing query to understand intent and extract key entities...")
            print("Analyzing query...")
            query_analysis = self.query_analyzer.analyze_query(query)
            if self.debug:
                # Extract key findings from analysis
                analysis_text = query_analysis.get("analysis", "")
                structured_analysis = query_analysis.get("structured_analysis", "")
                reasoning = "Query analysis complete. I identified these key elements:\n"
                # Add a simplified version of the analysis
                if structured_analysis:
                    reasoning += f"{structured_analysis}\n"
                else:
                    reasoning += f"{analysis_text[:300]}...\n"
                self._record_step(results, "Query Analysis", reasoning)
                self._debug_print(reasoning)
            results["query_analysis"] = query_analysis

            # Step 2: Retrieve relevant chunks
            self._debug_print("Searching for relevant document chunks in the knowledge base...")
            print("Retrieving documents...")
            try:
                retrieved_chunks = self.retriever.retrieve(query, self.top_k)
                results["num_chunks_retrieved"] = len(retrieved_chunks)
            except Exception as e:
                print(f"Error during retrieval: {e}")
                # Try a simpler approach with fewer chunks
                print("Trying with reduced parameters...")
                try:
                    retrieved_chunks = self.retriever.retrieve(query, min(5, self.top_k))
                    results["num_chunks_retrieved"] = len(retrieved_chunks)
                    results["retrieval_fallback_used"] = True
                except Exception as inner_e:
                    print(f"Retrieval completely failed: {inner_e}")
                    # Propagate to the outer handler, which switches to LegalAgent
                    raise
            if self.debug:
                # Analyze the retrieved chunks
                num_chunks = len(retrieved_chunks)
                # Count chunks per source
                source_summary: Dict[Any, int] = {}
                for chunk in retrieved_chunks:
                    source = chunk.get('source', 'unknown')
                    source_summary[source] = source_summary.get(source, 0) + 1
                # Build the reasoning text
                sources_text = ", ".join(f"{src} ({count})" for src, count in source_summary.items())
                reasoning = f"Retrieved {num_chunks} relevant chunks from sources: {sources_text}\n"
                if num_chunks > 0:
                    # Add preview of top chunks
                    reasoning += "\nTop results preview:\n"
                    for position, chunk in enumerate(retrieved_chunks[:3], start=1):
                        chunk_text = chunk.get('text', chunk.get('chunk', 'No content available'))
                        reasoning += f"{position}. {self._preview(chunk_text, 100)}\n"
                self._record_step(results, "Document Retrieval", reasoning)
                self._debug_print(reasoning)

            # Step 3: Aggregate and organize context
            self._debug_print("Organizing and structuring retrieved information...")
            print("Aggregating context...")
            try:
                aggregated_context = self.context_aggregator.aggregate_context(query, retrieved_chunks)
                results["context_length"] = len(aggregated_context)
            except Exception as e:
                print(f"Error during context aggregation: {e}")
                # Use a simple fallback context
                print("Using simple context aggregation as fallback...")
                aggregated_context = self.retriever.get_formatted_context(retrieved_chunks)
                results["context_length"] = len(aggregated_context)
                results["context_fallback_used"] = True
            if self.debug:
                # Analyze the context
                context_preview = self._preview(aggregated_context, 200)
                word_count = len(aggregated_context.split())
                reasoning = f"Organized {word_count} words of context information for answer generation.\n"
                reasoning += f"Context preview: {context_preview}\n"
                self._record_step(results, "Context Organization", reasoning)
                self._debug_print(reasoning)

            # Step 4: Generate the answer
            self._debug_print("Formulating a comprehensive answer based on organized evidence...")
            print("Generating answer...")
            try:
                answer = self.answer_generator.generate_answer(query, aggregated_context)
            except Exception as e:
                # NOTE(review): AnswerGenerator.generate_answer as defined in this
                # file catches its own exceptions and returns an error string, so
                # this handler only runs if a different generator is plugged in.
                # Confirm whether error strings should also trigger the fallback.
                print(f"Error during answer generation: {e}")
                # Use legal agent as fallback
                print("Using legal agent for answer generation as fallback...")
                # Import legal agent here to avoid circular dependencies
                try:
                    from src.agents.legal_agent import LegalAgent
                    # Create an instance of LegalAgent
                    legal_agent = LegalAgent(model=self.model)
                    legal_answer = legal_agent.answer_query(query, 5)  # Use just 5 chunks for fallback
                    answer = legal_answer.get("answer", "Failed to generate an answer.")
                    results["answer_fallback_used"] = True
                except Exception as legal_error:
                    print(f"Error using legal agent fallback: {legal_error}")
                    answer = "Unable to generate an answer due to technical difficulties."
                    results["answer_fallback_used"] = False
            results["answer"] = answer

            # Add sources to results (first five chunks, duplicates kept)
            try:
                results["sources"] = [chunk.get('source', 'unknown') for chunk in retrieved_chunks[:5]]
            except Exception as e:
                print(f"Error extracting sources: {e}")
                results["sources"] = ["Source information unavailable"]
            if self.debug:
                # Only mark the preview with "..." when the answer was actually cut,
                # matching the chunk and context previews above.
                answer_preview = self._preview(answer, 150) if answer else "No answer generated"
                reasoning = "Answer generated based on the organized context.\n"
                reasoning += f"Preview: {answer_preview}\n"
                self._record_step(results, "Answer Generation", reasoning)
                self._debug_print("Answer generation complete.")
            return results
        except Exception as e:
            error_details = traceback.format_exc()
            print(f"Error in agent pipeline, falling back to standard legal agent: {e}")
            print(f"Detailed error: {error_details}")
            if self.debug:
                reasoning = f"Encountered an error: {str(e)}\n"
                reasoning += "Falling back to standard legal agent."
                self._record_step(results, "Error Recovery", reasoning)
                self._debug_print(reasoning)
            # Fall back to the standard legal agent
            try:
                print("Using fallback legal agent...")
                # Import legal agent here to avoid circular dependencies
                try:
                    from src.agents.legal_agent import LegalAgent
                    # Create an instance of LegalAgent
                    legal_agent = LegalAgent(model=self.model)
                    legal_agent_result = legal_agent.answer_query(query, self.top_k)
                    results["error"] = str(e)
                    results["answer"] = legal_agent_result.get("answer", "No answer available from the fallback agent.")
                    if "sources" in legal_agent_result:
                        results["sources"] = legal_agent_result["sources"]
                    return results
                except Exception as import_error:
                    print(f"Error importing legal agent: {import_error}")
                    raise
            except Exception as fallback_error:
                # Even the fallback failed, return a simple response
                error_msg = f"Main error: {e}\nFallback error: {fallback_error}"
                print(f"Fallback agent also failed: {fallback_error}")
                results["error"] = error_msg
                results["answer"] = "I apologize, but I'm having technical difficulties processing your query. Please try again later."
                return results