|
|
from collections.abc import AsyncIterable |
|
|
from typing import TYPE_CHECKING, Any |
|
|
|
|
|
from agent_framework import ( |
|
|
AgentRunResponse, |
|
|
AgentRunResponseUpdate, |
|
|
AgentThread, |
|
|
BaseAgent, |
|
|
ChatMessage, |
|
|
Role, |
|
|
) |
|
|
|
|
|
from src.orchestrator import SearchHandlerProtocol |
|
|
from src.utils.models import Citation, Evidence, SearchResult |
|
|
|
|
|
if TYPE_CHECKING: |
|
|
from src.services.embeddings import EmbeddingService |
|
|
|
|
|
|
|
|
class SearchAgent(BaseAgent):
    """Wraps SearchHandler as an AgentProtocol for Magentic.

    The agent takes the most recent user query from the incoming messages,
    runs it through the search handler, records previously unseen evidence in
    the shared evidence store (under the ``"current"`` key), and replies with
    a short Markdown summary of the results.
    """

    def __init__(
        self,
        search_handler: SearchHandlerProtocol,
        evidence_store: dict[str, list[Evidence]],
        embedding_service: "EmbeddingService | None" = None,
    ) -> None:
        """Create the agent.

        Args:
            search_handler: Object whose ``execute`` coroutine performs the search.
            evidence_store: Shared mapping; new evidence is appended to the
                list stored under ``"current"``.
            embedding_service: Optional service used to deduplicate results
                and to look up related evidence. When falsy, results are only
                deduplicated by citation URL.
        """
        super().__init__(
            name="SearchAgent",
            description="Searches PubMed for drug repurposing evidence",
        )
        self._handler = search_handler
        self._evidence_store = evidence_store
        self._embeddings = embedding_service

    @staticmethod
    def _extract_query(
        messages: str | ChatMessage | list[str] | list[ChatMessage] | None,
    ) -> str:
        """Return the query text carried by *messages*, or ``""`` if there is none."""
        if isinstance(messages, list):
            # Walk backwards so the most recent usable entry wins.
            for msg in reversed(messages):
                if isinstance(msg, ChatMessage):
                    if msg.role == Role.USER and msg.text:
                        return msg.text
                elif isinstance(msg, str):
                    return msg
            return ""
        if isinstance(messages, str):
            return messages
        if isinstance(messages, ChatMessage) and messages.text:
            return messages.text
        return ""

    @staticmethod
    def _evidence_from_related(item: dict[str, Any]) -> Evidence:
        """Build an ``Evidence`` object from one similarity-search hit."""
        # ``or`` (rather than a ``.get`` default) also covers keys that are
        # present but explicitly set to None.
        meta = item.get("metadata") or {}
        authors_str = meta.get("authors") or ""
        authors = [name.strip() for name in authors_str.split(",") if name.strip()]
        return Evidence(
            content=item["content"],
            citation=Citation(
                title=meta.get("title", "Related Evidence"),
                url=item["id"],
                # Related hits are always labelled with this fixed source.
                source="pubmed",
                date=meta.get("date", "n.d."),
                authors=authors,
            ),
            # Turn a distance into a relevance score, floored at 0.
            relevance=max(0.0, 1.0 - item.get("distance", 0.5)),
        )

    def _remember(self, candidates: list[Evidence]) -> int:
        """Append unseen evidence to the store and return how many items were added.

        An item counts as seen when its citation URL is already in the store
        or appeared earlier in *candidates*, so one batch cannot insert the
        same URL twice.
        """
        # setdefault avoids a KeyError when the caller hands in an empty store.
        store = self._evidence_store.setdefault("current", [])
        seen_urls = {e.citation.url for e in store}
        added = 0
        for ev in candidates:
            url = ev.citation.url
            if url in seen_urls:
                continue
            seen_urls.add(url)
            store.append(ev)
            added += 1
        return added

    async def run(
        self,
        messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
        *,
        thread: AgentThread | None = None,
        **kwargs: Any,
    ) -> AgentRunResponse:
        """Execute search based on the last user message.

        Args:
            messages: A query string, a single chat message, or a list of
                either. For lists, the last user message with text (or the
                last plain string) is used as the query.
            thread: Accepted for interface compatibility; not used here.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            A response holding one assistant message that summarises up to
            five pieces of evidence. ``additional_properties["evidence"]``
            carries the ``model_dump()`` of every piece of evidence shown.
            When no query can be extracted, a "No query provided" response is
            returned and no search is performed.
        """
        query = self._extract_query(messages)
        if not query:
            return AgentRunResponse(
                messages=[ChatMessage(role=Role.ASSISTANT, text="No query provided")],
                response_id="search-no-query",
            )

        result: SearchResult = await self._handler.execute(query, max_results_per_tool=10)

        if self._embeddings:
            unique_evidence = await self._embeddings.deduplicate(result.evidence)
            related = await self._embeddings.search_similar(query, n_results=5)

            # Skip related hits that point at a URL already in this batch.
            batch_urls = {e.citation.url for e in unique_evidence}
            related_evidence = [
                self._evidence_from_related(item)
                for item in related
                if item["id"] not in batch_urls
            ]
            evidence_to_show: list[Evidence] = unique_evidence + related_evidence
        else:
            evidence_to_show = result.evidence

        total_new = self._remember(evidence_to_show)

        evidence_text = "\n".join(
            f"- [{e.citation.title}]({e.citation.url}): {e.content[:200]}..."
            for e in evidence_to_show[:5]
        )
        response_text = (
            f"Found {result.total_found} sources ({total_new} new added to context):\n\n"
            f"{evidence_text}"
        )

        return AgentRunResponse(
            messages=[ChatMessage(role=Role.ASSISTANT, text=response_text)],
            response_id=f"search-{result.total_found}",
            additional_properties={"evidence": [e.model_dump() for e in evidence_to_show]},
        )

    async def run_stream(
        self,
        messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
        *,
        thread: AgentThread | None = None,
        **kwargs: Any,
    ) -> AsyncIterable[AgentRunResponseUpdate]:
        """Streaming wrapper for search (search itself isn't streaming).

        Runs :meth:`run` to completion with the same arguments and yields a
        single update carrying its messages and response id.
        """
        result = await self.run(messages, thread=thread, **kwargs)

        yield AgentRunResponseUpdate(messages=result.messages, response_id=result.response_id)
|
|
|