|
|
"""Report agent for generating structured research reports.""" |
|
|
|
|
|
from collections.abc import AsyncIterable |
|
|
from typing import TYPE_CHECKING, Any |
|
|
|
|
|
from agent_framework import ( |
|
|
AgentRunResponse, |
|
|
AgentRunResponseUpdate, |
|
|
AgentThread, |
|
|
BaseAgent, |
|
|
ChatMessage, |
|
|
Role, |
|
|
) |
|
|
from pydantic_ai import Agent |
|
|
|
|
|
from src.agent_factory.judges import get_model |
|
|
from src.prompts.report import SYSTEM_PROMPT, format_report_prompt |
|
|
from src.utils.citation_validator import validate_references |
|
|
from src.utils.models import Evidence, ResearchReport |
|
|
|
|
|
if TYPE_CHECKING: |
|
|
from src.services.embeddings import EmbeddingService |
|
|
|
|
|
|
|
|
class ReportAgent(BaseAgent): |
|
|
"""Generates structured scientific reports from evidence and hypotheses.""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
evidence_store: dict[str, Any], |
|
|
embedding_service: "EmbeddingService | None" = None, |
|
|
) -> None: |
|
|
super().__init__( |
|
|
name="ReportAgent", |
|
|
description="Generates structured scientific research reports with citations", |
|
|
) |
|
|
self._evidence_store = evidence_store |
|
|
self._embeddings = embedding_service |
|
|
self._agent: Agent[None, ResearchReport] | None = None |
|
|
|
|
|
def _get_agent(self) -> Agent[None, ResearchReport]: |
|
|
"""Lazy initialization of LLM agent to avoid requiring API keys at import.""" |
|
|
if self._agent is None: |
|
|
self._agent = Agent( |
|
|
model=get_model(), |
|
|
output_type=ResearchReport, |
|
|
system_prompt=SYSTEM_PROMPT, |
|
|
) |
|
|
return self._agent |
|
|
|
|
|
async def run( |
|
|
self, |
|
|
messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, |
|
|
*, |
|
|
thread: AgentThread | None = None, |
|
|
**kwargs: Any, |
|
|
) -> AgentRunResponse: |
|
|
"""Generate research report.""" |
|
|
query = self._extract_query(messages) |
|
|
|
|
|
|
|
|
evidence: list[Evidence] = self._evidence_store.get("current", []) |
|
|
hypotheses = self._evidence_store.get("hypotheses", []) |
|
|
assessment = self._evidence_store.get("last_assessment", {}) |
|
|
|
|
|
if not evidence: |
|
|
return AgentRunResponse( |
|
|
messages=[ |
|
|
ChatMessage( |
|
|
role=Role.ASSISTANT, |
|
|
text="Cannot generate report: No evidence collected.", |
|
|
) |
|
|
], |
|
|
response_id="report-no-evidence", |
|
|
) |
|
|
|
|
|
|
|
|
metadata = { |
|
|
"sources": list(set(e.citation.source for e in evidence)), |
|
|
"iterations": self._evidence_store.get("iteration_count", 0), |
|
|
} |
|
|
|
|
|
|
|
|
prompt = await format_report_prompt( |
|
|
query=query, |
|
|
evidence=evidence, |
|
|
hypotheses=hypotheses, |
|
|
assessment=assessment, |
|
|
metadata=metadata, |
|
|
embeddings=self._embeddings, |
|
|
) |
|
|
|
|
|
result = await self._get_agent().run(prompt) |
|
|
report = result.output |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
report = validate_references(report, evidence) |
|
|
|
|
|
|
|
|
self._evidence_store["final_report"] = report |
|
|
|
|
|
|
|
|
return AgentRunResponse( |
|
|
messages=[ChatMessage(role=Role.ASSISTANT, text=report.to_markdown())], |
|
|
response_id="report-complete", |
|
|
additional_properties={"report": report.model_dump()}, |
|
|
) |
|
|
|
|
|
def _extract_query( |
|
|
self, messages: str | ChatMessage | list[str] | list[ChatMessage] | None |
|
|
) -> str: |
|
|
"""Extract query from messages.""" |
|
|
if isinstance(messages, str): |
|
|
return messages |
|
|
elif isinstance(messages, ChatMessage): |
|
|
return messages.text or "" |
|
|
elif isinstance(messages, list): |
|
|
for msg in reversed(messages): |
|
|
if isinstance(msg, ChatMessage) and msg.role == Role.USER: |
|
|
return msg.text or "" |
|
|
elif isinstance(msg, str): |
|
|
return msg |
|
|
return "" |
|
|
|
|
|
async def run_stream( |
|
|
self, |
|
|
messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, |
|
|
*, |
|
|
thread: AgentThread | None = None, |
|
|
**kwargs: Any, |
|
|
) -> AsyncIterable[AgentRunResponseUpdate]: |
|
|
"""Streaming wrapper.""" |
|
|
result = await self.run(messages, thread=thread, **kwargs) |
|
|
yield AgentRunResponseUpdate( |
|
|
messages=result.messages, |
|
|
response_id=result.response_id, |
|
|
additional_properties=result.additional_properties, |
|
|
) |
|
|
|