| | from pathlib import Path |
| |
|
| | from pydantic import BaseModel, Field |
| | from pydantic_ai import Agent, ModelRetry, RunContext |
| | from pydantic_ai.messages import ( |
| | ModelMessage, |
| | ModelRequest, |
| | ModelResponse, |
| | TextPart, |
| | ToolReturnPart, |
| | ) |
| | from pydantic_ai.models.google import GoogleModel, GoogleModelSettings |
| |
|
| | from app import models |
| | from app.tools import dailymed, literature |
| |
|
| |
|
# Intermediate view of an agent run, as assembled by get_context() from the
# message history. Documented with comments rather than a class docstring so
# the pydantic model definition itself stays untouched.
class Context(BaseModel):
    # Text content of the model's TextPart responses, in message order.
    thoughts: list[str]
    # Items returned by the source-producing tools, keyed by each item's "id".
    sources: dict[str, dict]
| |
|
| |
|
# One statement of the agent's output, as passed to create_response().
# NOTE(review): kept as comments instead of a class docstring because pydantic
# appears to copy class docstrings into the generated JSON schema, which would
# change what the model is shown -- confirm before converting.
class Statement(BaseModel):
    # The statement text itself.
    text: str
    # IDs of supporting sources; create_response() looks each one up in the
    # sources gathered from the run and asks the model to retry when an ID is
    # unknown. None (the default) is treated there as "no sources".
    sources: list[str] | None = Field(
        default=None, description="ID of the sources that support this statement."
    )
| |
|
| |
|
def get_context(messages: list[ModelMessage]) -> Context:
    """Extract model text and citable sources from a run's message history.

    Model responses contribute the content of every ``TextPart`` to
    ``thoughts``. Model requests contribute the items of every
    ``ToolReturnPart`` produced by ``search_medical_literature`` or
    ``find_drug_set_ids`` to ``sources``, keyed by ``item["id"]``; a later
    item with the same id replaces the earlier one.

    Assumes the content of those tool returns is an iterable of mappings that
    each carry an ``"id"`` key -- TODO confirm against the tool definitions.
    """
    # Only returns from these tools are treated as sources.
    source_tools = {"search_medical_literature", "find_drug_set_ids"}

    collected_thoughts: list[str] = []
    collected_sources: dict[str, dict] = {}

    for msg in messages:
        if isinstance(msg, ModelResponse):
            collected_thoughts.extend(
                piece.content for piece in msg.parts if isinstance(piece, TextPart)
            )
            continue
        if not isinstance(msg, ModelRequest):
            continue
        for piece in msg.parts:
            if not isinstance(piece, ToolReturnPart):
                continue
            if piece.tool_name not in source_tools:
                continue
            collected_sources.update((entry["id"], entry) for entry in piece.content)

    return Context(thoughts=collected_thoughts, sources=collected_sources)
| |
|
| |
|
def create_response(ctx: RunContext, output: list[Statement]) -> models.Statements:
    # Output function for the agent: turns the model's statements, which cite
    # sources by ID, into a validated models.Statements whose statements carry
    # the full source records and whose "thoughts" is the model's text joined
    # by blank lines.
    #
    # Raises ModelRetry on the first cited ID that is not among the sources
    # gathered from this run's tool returns.
    #
    # NOTE(review): documented with comments rather than a docstring on
    # purpose -- pydantic-ai appears to pass output-function docstrings to the
    # model as the tool description; confirm before converting.
    run_context = get_context(ctx.messages)

    def lookup(source_id: str) -> dict:
        # Resolve one cited ID, keeping the KeyError as the retry's cause.
        try:
            return run_context.sources[source_id]
        except KeyError as err:
            raise ModelRetry(
                f"Source ID '{source_id}' not found in literature."
            ) from err

    resolved = [
        {
            "text": entry.text,
            "sources": [lookup(source_id) for source_id in entry.sources or []],
        }
        for entry in output
    ]
    payload = {
        "statements": resolved,
        "thoughts": "\n\n".join(run_context.thoughts),
    }
    return models.Statements.model_validate(payload)
| |
|
| |
|
# Gemini model backing the agent.
model = GoogleModel("gemini-2.5-flash-preview-05-20")

# Thinking configuration passed through to the Google model.
# NOTE(review): include_thoughts presumably makes the model return its
# reasoning, which get_context() then picks up as text parts; thinking_budget
# is presumably a token cap -- confirm both against the Gemini documentation.
settings = GoogleModelSettings(
    google_thinking_config={"thinking_budget": 1024, "include_thoughts": True},
)

# The agent itself: create_response post-processes the model's output, and the
# system prompt is loaded at import time from system_instruction.txt next to
# this module.
agent = Agent(
    model=model,
    name="elna",
    model_settings=settings,
    output_type=create_response,
    # Decode explicitly as UTF-8 so the prompt does not depend on the
    # platform's default locale encoding.
    system_prompt=(Path(__file__).parent / "system_instruction.txt").read_text(
        encoding="utf-8"
    ),
    tools=[
        dailymed.find_drug_set_ids,
        dailymed.find_drug_instruction,
        literature.search_medical_literature,
    ],
)
| |
|