Spaces:
Sleeping
Sleeping
File size: 7,751 Bytes
50fcf88 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 |
from typing import Dict, List, Optional
import logging
from langchain_core.documents import Document
from langchain_google_genai import ChatGoogleGenerativeAI
from configuration.parameters import parameters
logger = logging.getLogger(__name__)
def estimate_tokens(text: str, chars_per_token: int = 4) -> int:
    """
    Return a rough token count for ``text`` derived from its character length.

    Args:
        text: The text to measure.
        chars_per_token: Number of characters counted as one token.

    Returns:
        The character length of ``text`` floor-divided by ``chars_per_token``.
    """
    char_count = len(text)
    # Floor division: a trailing partial token is not counted.
    token_estimate = char_count // chars_per_token
    return token_estimate
class ResearchAgent:
    """
    ResearchAgent generates answers to user questions using Gemini LLM,
    focusing on extracting factual, source-cited information from documents.

    The agent concatenates the text of the top-ranked documents into a single
    context string, builds an instruction prompt around it, and returns the
    model's reply together with the context that was actually sent.
    """

    def __init__(
        self,
        llm: Optional[ChatGoogleGenerativeAI] = None,
        top_k: Optional[int] = None,
        max_context_chars: Optional[int] = None,
        max_output_tokens: Optional[int] = None,
    ) -> None:
        """
        Initialize the research agent with the Gemini model and configuration.

        Args:
            llm: Chat model to use. When omitted, a ``ChatGoogleGenerativeAI``
                instance is built from the project ``parameters``.
            top_k: Maximum number of documents folded into the context.
            max_context_chars: Character budget for the combined context.
            max_output_tokens: Output token limit passed to the default model.

        Any numeric option that is ``None`` (or otherwise falsy, e.g. ``0``)
        falls back to the corresponding value in ``parameters``.
        """
        logger.info("[RESEARCH_AGENT] Initializing...")

        # `or` is deliberate: both None and 0 select the configured default.
        self.top_k = top_k or parameters.RESEARCH_TOP_K
        self.max_context_chars = max_context_chars or parameters.RESEARCH_MAX_CONTEXT_CHARS
        self.max_output_tokens = max_output_tokens or parameters.RESEARCH_MAX_OUTPUT_TOKENS

        self.llm = llm or ChatGoogleGenerativeAI(
            model=parameters.RESEARCH_AGENT_MODEL,
            google_api_key=parameters.GOOGLE_API_KEY,
            temperature=0.2,
            max_output_tokens=self.max_output_tokens,
        )

        logger.info(f"[RESEARCH_AGENT] ✓ Initialized (top_k={self.top_k}, model={parameters.RESEARCH_AGENT_MODEL})")

    @staticmethod
    def _response_text(response: object) -> str:
        """
        Extract plain text from an LLM response.

        Chat message content may be a plain string or a list of content blocks
        (strings or dicts such as ``{"type": "text", "text": ...}``). Text
        blocks are concatenated; non-text blocks are skipped. Objects without
        a ``content`` attribute are converted with ``str()``.
        """
        if not hasattr(response, "content"):
            return str(response)

        content = response.content
        if content is None:
            return ""
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            pieces = []
            for block in content:
                if isinstance(block, str):
                    pieces.append(block)
                elif (
                    isinstance(block, dict)
                    and block.get("type") == "text"
                    and isinstance(block.get("text"), str)
                ):
                    pieces.append(block["text"])
            return "".join(pieces)
        return str(content)

    def sanitize_response(self, response_text: str) -> str:
        """
        Sanitize the LLM's response by stripping unnecessary whitespace.

        Args:
            response_text: Raw answer text.

        Returns:
            ``response_text`` without leading or trailing whitespace.
        """
        return response_text.strip()

    def generate_prompt(
        self,
        question: str,
        context: str,
        feedback: Optional[str] = None,
        previous_answer: Optional[str] = None,
    ) -> str:
        """
        Generate a structured prompt for the LLM to generate a precise and factual answer.
        Includes special instructions for handling tables, charts, and visualizations.

        Args:
            question: The user's question.
            context: Document text the answer must be grounded in.
            feedback: Optional critique of an earlier answer. When given, a
                section asking for an improved answer is added.
            previous_answer: Optional earlier answer the feedback refers to.
                Included only when ``feedback`` is also given, so the model
                can see what it is being asked to correct.

        Returns:
            The complete prompt string.
        """
        # NOTE: the prompt text is kept flush-left on purpose. Any leading
        # whitespace inside the triple-quoted literals would become part of
        # the prompt sent to the model.
        base_prompt = f"""You are an AI assistant designed to provide precise and factual answers based on the given context.
**Instructions:**
- Answer the following question using only the provided context.
- Be clear, concise, and factual.
- Return as much information as you can get from the context.
- Only include claims that are directly supported by the context.
**IMPORTANT - Data Consolidation:**
- If multiple charts, tables, or data sources provide similar information, CONSOLIDATE the data and provide a single, unified answer.
- DO NOT list or compare values from multiple versions of the same charts/tables separately.
- Present only the most relevant or consensus value for each data point, unless there is a clear, significant difference that must be explained.
- If there are minor discrepancies, choose the value that appears most frequently or is best supported by the context, and mention only that value.
**IMPORTANT - Chart and Page Reference:**
- When referencing data from a chart, always indicate the chart's heading or title, and also include the page title if available.
- Do NOT use phrases like "another chart" or "a different chart". Always refer to the chart by its heading/title and the page title if you need to mention the source.
**CRITICAL - Table, Chart, and Visualization Handling:**
- Pay VERY CLOSE attention to any tables in the context (formatted with | characters or markdown table format).
- Tables contain structured data - read them carefully row by row, column by column.
- Extract and cite specific numbers, percentages, scores, and ratings from tables.
- If a numbered table (Table 1, Table 4, etc.) is relevant, explicitly mention it and provide the exact values.
- **Analyze complex charts and visualizations** when present in the context:
- Look for chart descriptions, data points, trends, and patterns
- Extract specific values from line charts, bar charts, pie charts, and scatter plots
- Identify trends, correlations, and relationships shown in visualizations
- Note any zones, quadrants, or regions in complex diagrams
- Reference chart titles, axis labels, and legends when citing data
- Compare multiple visualizations if relevant to the question
"""
        if feedback:
            if previous_answer:
                # Show the model the answer the feedback is about.
                base_prompt += f"""
**Previous Answer (failed verification):**
{previous_answer}
"""
            base_prompt += f"""
**IMPORTANT - Previous Answer Feedback:**
Your previous answer had issues that need to be addressed:
{feedback}
Please generate an improved answer that:
1. Addresses the unsupported claims by finding support in the context tables and charts
2. Fixes any contradictions with the source material
3. Ensures all statements are verifiable from the context
4. Look carefully at ALL tables and visualizations - the data you need may be in a numbered table or chart description
5. Read table data and chart descriptions carefully - each row/data point represents specific information
"""
        base_prompt += f"""
**Question:** {question}
**Context (pay special attention to tables marked with ### Table, chart descriptions, and data visualizations):**
{context}
**Provide your answer below (cite specific table numbers, chart references, and exact values from the tables and visualizations):**
"""
        return base_prompt

    def generate(
        self,
        question: str,
        documents: List[Document],
        feedback: Optional[str] = None,
        previous_answer: Optional[str] = None
    ) -> Dict:
        """
        Generate an initial answer using the provided documents.

        Args:
            question: The user's question
            documents: List of relevant documents
            feedback: Optional feedback from verification agent for re-research
            previous_answer: Optional previous answer that failed verification;
                forwarded to the prompt when ``feedback`` is also given

        Returns:
            Dict with 'draft_answer' and 'context_used'

        Raises:
            RuntimeError: If the LLM call fails.
        """
        logger.info(f"[RESEARCH_AGENT] Generating answer for: {question[:80]}...")
        logger.debug(f"[RESEARCH_AGENT] Documents: {len(documents)}, Feedback: {feedback is not None}")

        if not documents:
            logger.warning("[RESEARCH_AGENT] No documents provided")
            return {
                "draft_answer": "I could not find supporting documents to answer this question.",
                "context_used": ""
            }

        # Combine the top document contents into one string
        context = "\n\n".join(doc.page_content for doc in documents[:self.top_k])

        # Truncate context if too long (plain character cut; may end mid-document)
        if len(context) > self.max_context_chars:
            logger.debug(f"[RESEARCH_AGENT] Context truncated: {len(context)} -> {self.max_context_chars} chars")
            context = context[:self.max_context_chars]

        # Create a prompt for the LLM (with optional feedback / previous answer)
        prompt = self.generate_prompt(question, context, feedback, previous_answer)

        # Only the model call sits inside the try block, so failures in
        # post-processing are not misreported as model errors.
        try:
            response = self.llm.invoke(prompt)
        except Exception as e:
            logger.error(f"[RESEARCH_AGENT] LLM call failed: {e}", exc_info=True)
            raise RuntimeError("Failed to generate answer due to a model error.") from e

        # Message content can be a str or a list of content blocks; flatten it.
        answer = self._response_text(response).strip()
        logger.info("[RESEARCH_AGENT] Answer generated successfully")

        # Sanitize the response; an empty reply becomes an explicit fallback.
        draft_answer = self.sanitize_response(answer) if answer else "I cannot answer this question based on the provided documents."
        logger.debug(f"[RESEARCH_AGENT] Answer length: {len(draft_answer)} chars")

        return {
            "draft_answer": draft_answer,
            "context_used": context
        }
|