|
|
""" |
|
|
Synthesis Agent: Compare findings across papers and identify patterns. |
|
|
""" |
|
|
import json
import logging
import os
from typing import Any, Dict, List, Optional

from openai import AzureOpenAI

from rag.retrieval import RAGRetriever
from utils.langfuse_client import observe
from utils.schemas import Analysis, ConsensusPoint, Contradiction, Paper, SynthesisResult
|
|
|
|
|
# Module-wide logging configuration.
# NOTE(review): basicConfig() at import time configures the *root* logger as a
# side effect; if this module is imported by a host application that sets up
# its own logging, this can interfere — consider moving the call to the
# program entry point. Left in place to preserve current behavior.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# Logger scoped to this module's dotted path.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
class SynthesisAgent:
    """Agent for synthesizing findings across multiple papers."""

    def __init__(
        self,
        rag_retriever: RAGRetriever,
        model: Optional[str] = None,
        temperature: float = 0.0,
        timeout: int = 90
    ):
        """
        Initialize Synthesis Agent.

        Args:
            rag_retriever: RAGRetriever instance
            model: Azure OpenAI model deployment name. Defaults to the
                AZURE_OPENAI_DEPLOYMENT_NAME environment variable, read at
                construction time.
            temperature: Temperature for generation (0 for deterministic)
            timeout: Request timeout in seconds (default: 90, longer than analyzer)
        """
        self.rag_retriever = rag_retriever
        # Fix: the original read the env var in the parameter default, which
        # Python evaluates once at class-definition (import) time — changes to
        # the environment made after import were silently ignored. Resolve it
        # here instead, keeping the same effective default for callers.
        self.model = model if model is not None else os.getenv("AZURE_OPENAI_DEPLOYMENT_NAME")
        self.temperature = temperature
        self.timeout = timeout

        # Azure OpenAI client; credentials and endpoint come from the
        # environment. NOTE(review): missing env vars surface only when the
        # first request fails — confirm upstream validation exists.
        self.client = AzureOpenAI(
            api_key=os.getenv("AZURE_OPENAI_API_KEY"),
            api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
            azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
            timeout=timeout
        )
|
|
|
|
|
def _create_synthesis_prompt( |
|
|
self, |
|
|
papers: List[Paper], |
|
|
analyses: List[Analysis], |
|
|
query: str |
|
|
) -> str: |
|
|
"""Create prompt for synthesis.""" |
|
|
|
|
|
paper_summaries = [] |
|
|
for paper, analysis in zip(papers, analyses): |
|
|
summary = f""" |
|
|
Paper ID: {paper.arxiv_id} |
|
|
Title: {paper.title} |
|
|
Authors: {", ".join(paper.authors)} |
|
|
|
|
|
Analysis: |
|
|
- Methodology: {analysis.methodology} |
|
|
- Key Findings: {", ".join(analysis.key_findings)} |
|
|
- Conclusions: {analysis.conclusions} |
|
|
- Contributions: {", ".join(analysis.main_contributions)} |
|
|
- Limitations: {", ".join(analysis.limitations)} |
|
|
""" |
|
|
paper_summaries.append(summary) |
|
|
|
|
|
prompt = f"""You are a research synthesis expert. Analyze the following papers in relation to the user's research question. |
|
|
|
|
|
Research Question: {query} |
|
|
|
|
|
Papers Analyzed: |
|
|
{"=" * 80} |
|
|
{chr(10).join(paper_summaries)} |
|
|
{"=" * 80} |
|
|
|
|
|
Synthesize these findings and provide: |
|
|
1. Consensus points - areas where papers agree |
|
|
2. Contradictions - areas where papers disagree |
|
|
3. Research gaps - what's missing or needs further investigation |
|
|
4. Executive summary addressing the research question |
|
|
|
|
|
Provide your synthesis in the following JSON format: |
|
|
{{ |
|
|
"consensus_points": [ |
|
|
{{ |
|
|
"statement": "Clear consensus statement", |
|
|
"supporting_papers": ["arxiv_id1", "arxiv_id2"], |
|
|
"citations": ["Specific evidence from papers"], |
|
|
"confidence": 0.0-1.0 |
|
|
}} |
|
|
], |
|
|
"contradictions": [ |
|
|
{{ |
|
|
"topic": "Topic of disagreement", |
|
|
"viewpoint_a": "First viewpoint", |
|
|
"papers_a": ["arxiv_id1"], |
|
|
"viewpoint_b": "Second viewpoint", |
|
|
"papers_b": ["arxiv_id2"], |
|
|
"citations": ["Evidence for both sides"], |
|
|
"confidence": 0.0-1.0 |
|
|
}} |
|
|
], |
|
|
"research_gaps": [ |
|
|
"Gap 1: What's missing", |
|
|
"Gap 2: What needs further research" |
|
|
], |
|
|
"summary": "Executive summary addressing the research question with synthesis of all findings", |
|
|
"confidence_score": 0.0-1.0 |
|
|
}} |
|
|
|
|
|
CRITICAL JSON FORMATTING RULES: |
|
|
- Ground all statements in the provided analyses |
|
|
- Be specific about which papers support which claims |
|
|
- Identify both agreements and disagreements |
|
|
- Provide confidence scores based on consistency and evidence strength |
|
|
- For ALL array fields (citations, supporting_papers, papers_a, papers_b, research_gaps): |
|
|
* MUST be flat arrays of strings ONLY: ["item1", "item2"] |
|
|
* NEVER nest arrays: [[], "text"] or [["nested"]] are INVALID |
|
|
* NEVER include null, empty strings, or non-string values |
|
|
* Each array element must be a non-empty string |
|
|
""" |
|
|
return prompt |
|
|
|
|
|
def _normalize_synthesis_response(self, data: dict) -> dict: |
|
|
""" |
|
|
Normalize synthesis LLM response to ensure all list fields contain only strings. |
|
|
|
|
|
Handles nested lists, None values, and mixed types in: |
|
|
- consensus_points[].citations |
|
|
- consensus_points[].supporting_papers |
|
|
- contradictions[].citations |
|
|
- contradictions[].papers_a |
|
|
- contradictions[].papers_b |
|
|
- research_gaps |
|
|
|
|
|
Args: |
|
|
data: Raw synthesis data dictionary from LLM |
|
|
|
|
|
Returns: |
|
|
Normalized dictionary with correct types for all fields |
|
|
""" |
|
|
def flatten_and_clean(value): |
|
|
"""Recursively flatten nested lists and clean values.""" |
|
|
if isinstance(value, str): |
|
|
return [value.strip()] if value.strip() else [] |
|
|
elif isinstance(value, list): |
|
|
cleaned = [] |
|
|
for item in value: |
|
|
if isinstance(item, str): |
|
|
if item.strip(): |
|
|
cleaned.append(item.strip()) |
|
|
elif isinstance(item, list): |
|
|
cleaned.extend(flatten_and_clean(item)) |
|
|
elif item is not None and str(item).strip(): |
|
|
cleaned.append(str(item).strip()) |
|
|
return cleaned |
|
|
elif value is not None: |
|
|
str_value = str(value).strip() |
|
|
return [str_value] if str_value else [] |
|
|
else: |
|
|
return [] |
|
|
|
|
|
|
|
|
if "research_gaps" in data: |
|
|
data["research_gaps"] = flatten_and_clean(data["research_gaps"]) |
|
|
else: |
|
|
data["research_gaps"] = [] |
|
|
|
|
|
|
|
|
if "consensus_points" in data and isinstance(data["consensus_points"], list): |
|
|
for cp in data["consensus_points"]: |
|
|
if isinstance(cp, dict): |
|
|
cp["citations"] = flatten_and_clean(cp.get("citations", [])) |
|
|
cp["supporting_papers"] = flatten_and_clean(cp.get("supporting_papers", [])) |
|
|
|
|
|
|
|
|
if "contradictions" in data and isinstance(data["contradictions"], list): |
|
|
for contr in data["contradictions"]: |
|
|
if isinstance(contr, dict): |
|
|
contr["citations"] = flatten_and_clean(contr.get("citations", [])) |
|
|
contr["papers_a"] = flatten_and_clean(contr.get("papers_a", [])) |
|
|
contr["papers_b"] = flatten_and_clean(contr.get("papers_b", [])) |
|
|
|
|
|
logger.debug("Synthesis response normalized successfully") |
|
|
return data |
|
|
|
|
|
def synthesize( |
|
|
self, |
|
|
papers: List[Paper], |
|
|
analyses: List[Analysis], |
|
|
query: str, |
|
|
state: Dict[str, Any] |
|
|
) -> SynthesisResult: |
|
|
""" |
|
|
Synthesize findings across papers. |
|
|
|
|
|
Args: |
|
|
papers: List of Paper objects |
|
|
analyses: List of Analysis objects |
|
|
query: Original research question |
|
|
state: Agent state for token tracking |
|
|
|
|
|
Returns: |
|
|
SynthesisResult object |
|
|
""" |
|
|
try: |
|
|
logger.info(f"Synthesizing {len(papers)} papers") |
|
|
|
|
|
|
|
|
prompt = self._create_synthesis_prompt(papers, analyses, query) |
|
|
|
|
|
|
|
|
response = self.client.chat.completions.create( |
|
|
model=self.model, |
|
|
messages=[ |
|
|
{"role": "system", "content": "You are a research synthesis expert. Provide accurate, grounded synthesis based only on the provided analyses."}, |
|
|
{"role": "user", "content": prompt} |
|
|
], |
|
|
temperature=self.temperature, |
|
|
max_tokens=2500, |
|
|
response_format={"type": "json_object"} |
|
|
) |
|
|
|
|
|
|
|
|
if hasattr(response, 'usage') and response.usage: |
|
|
prompt_tokens = response.usage.prompt_tokens |
|
|
completion_tokens = response.usage.completion_tokens |
|
|
state["token_usage"]["input_tokens"] += prompt_tokens |
|
|
state["token_usage"]["output_tokens"] += completion_tokens |
|
|
logger.info(f"Synthesis token usage: {prompt_tokens} input, {completion_tokens} output") |
|
|
|
|
|
|
|
|
synthesis_data = json.loads(response.choices[0].message.content) |
|
|
|
|
|
|
|
|
synthesis_data = self._normalize_synthesis_response(synthesis_data) |
|
|
|
|
|
|
|
|
consensus_points = [ |
|
|
ConsensusPoint(**cp) for cp in synthesis_data.get("consensus_points", []) |
|
|
] |
|
|
|
|
|
contradictions = [ |
|
|
Contradiction(**c) for c in synthesis_data.get("contradictions", []) |
|
|
] |
|
|
|
|
|
|
|
|
synthesis = SynthesisResult( |
|
|
consensus_points=consensus_points, |
|
|
contradictions=contradictions, |
|
|
research_gaps=synthesis_data.get("research_gaps", []), |
|
|
summary=synthesis_data.get("summary", ""), |
|
|
confidence_score=synthesis_data.get("confidence_score", 0.5), |
|
|
papers_analyzed=[p.arxiv_id for p in papers] |
|
|
) |
|
|
|
|
|
logger.info(f"Synthesis completed with confidence {synthesis.confidence_score:.2f}") |
|
|
return synthesis |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"Error during synthesis: {str(e)}") |
|
|
|
|
|
return SynthesisResult( |
|
|
consensus_points=[], |
|
|
contradictions=[], |
|
|
research_gaps=["Synthesis failed - unable to identify gaps"], |
|
|
summary="Synthesis failed due to an error", |
|
|
confidence_score=0.0, |
|
|
papers_analyzed=[p.arxiv_id for p in papers] |
|
|
) |
|
|
|
|
|
@observe(name="synthesis_agent_run", as_type="generation") |
|
|
def run(self, state: Dict[str, Any]) -> Dict[str, Any]: |
|
|
""" |
|
|
Execute synthesis agent. |
|
|
|
|
|
Args: |
|
|
state: Current agent state |
|
|
|
|
|
Returns: |
|
|
Updated state with synthesis |
|
|
""" |
|
|
try: |
|
|
logger.info("=== Synthesis Agent Started ===") |
|
|
|
|
|
papers = state.get("papers", []) |
|
|
analyses = state.get("analyses", []) |
|
|
query = state.get("query", "") |
|
|
|
|
|
if not papers or not analyses: |
|
|
error_msg = "No papers or analyses available for synthesis" |
|
|
logger.error(error_msg) |
|
|
state["errors"].append(error_msg) |
|
|
return state |
|
|
|
|
|
if len(papers) != len(analyses): |
|
|
error_msg = f"Mismatch: {len(papers)} papers but {len(analyses)} analyses" |
|
|
logger.warning(error_msg) |
|
|
|
|
|
min_len = min(len(papers), len(analyses)) |
|
|
papers = papers[:min_len] |
|
|
analyses = analyses[:min_len] |
|
|
|
|
|
|
|
|
synthesis = self.synthesize(papers, analyses, query, state) |
|
|
state["synthesis"] = synthesis |
|
|
|
|
|
logger.info("=== Synthesis Agent Completed ===") |
|
|
return state |
|
|
|
|
|
except Exception as e: |
|
|
error_msg = f"Synthesis Agent error: {str(e)}" |
|
|
logger.error(error_msg) |
|
|
state["errors"].append(error_msg) |
|
|
return state |
|
|
|