"""Pydantic models and typed dicts for the agent service's API payloads.

Covers service metadata, user input, chat messages, feedback, paginated chat
history, and thread listings.
"""

from typing import Any, Literal, NotRequired

from pydantic import BaseModel, Field, SerializeAsAny
from typing_extensions import TypedDict

from schema.models import AllModelEnum, AnthropicModelName, OpenAIModelName


class AgentInfo(BaseModel):
    """Info about an available agent."""

    key: str = Field(
        description="Agent key.",
        examples=["research-assistant"],
    )
    description: str = Field(
        description="Description of the agent.",
        examples=["A research assistant for generating research papers."],
    )
    # NOTE: pydantic copies non-hashable defaults for every instance, so the
    # literal `[]` / `{}` defaults in this module are not shared state. They
    # are kept (rather than `default_factory`) so the generated JSON schema
    # still advertises the default value.
    prompts: list[str] = Field(
        description="List of suggested prompts for the agent.",
        default=[],
    )


class ServiceMetadata(BaseModel):
    """Metadata about the service including available agents and models."""

    agents: list[AgentInfo] = Field(
        description="List of available agents.",
    )
    models: list[AllModelEnum] = Field(
        description="List of available LLMs.",
    )
    default_agent: str = Field(
        description="Default agent used when none is specified.",
        examples=["research-assistant"],
    )
    default_model: AllModelEnum = Field(
        description="Default model used when none is specified.",
    )


class UserInput(BaseModel):
    """Basic user input for the agent."""

    message: str = Field(
        description="User input to the agent.",
        examples=["What is the weather in Tokyo?"],
    )
    model: SerializeAsAny[AllModelEnum] | None = Field(
        title="Model",
        description="LLM Model to use for the agent. Defaults to the default model set in the settings of the service.",
        default=None,
        examples=[OpenAIModelName.GPT_5_NANO, AnthropicModelName.HAIKU_45],
    )
    thread_id: str | None = Field(
        description="Thread ID to persist and continue a multi-turn conversation.",
        default=None,
        examples=["847c6285-8fc9-4560-a83f-4e6285809254"],
    )
    user_id: str | None = Field(
        description="User ID to persist and continue a conversation across multiple threads.",
        default=None,
        examples=["847c6285-8fc9-4560-a83f-4e6285809254"],
    )
    agent_config: dict[str, Any] = Field(
        description="Additional configuration to pass through to the agent",
        default={},
        examples=[{"spicy_level": 0.8}],
    )


class StreamInput(UserInput):
    """User input for streaming the agent's response."""

    stream_tokens: bool = Field(
        description="Whether to stream LLM tokens to the client.",
        default=True,
    )


class ToolCall(TypedDict):
    """Represents a request to call a tool."""

    name: str
    """The name of the tool to be called."""
    args: dict[str, Any]
    """The arguments to the tool call."""
    id: str | None
    """An identifier associated with the tool call."""
    type: NotRequired[Literal["tool_call"]]


class ChatMessage(BaseModel):
    """Message in a chat."""

    type: Literal["human", "ai", "tool", "custom"] = Field(
        description="Role of the message.",
        examples=["human", "ai", "tool", "custom"],
    )
    content: str = Field(
        description="Content of the message.",
        examples=["Hello, world!"],
    )
    tool_calls: list[ToolCall] = Field(
        description="Tool calls in the message.",
        default=[],
    )
    tool_call_id: str | None = Field(
        description="Tool call that this message is responding to.",
        default=None,
        examples=["call_Jja7J89XsjrOLA5r!MEOW!SL"],
    )
    name: str | None = Field(
        description="Tool name for tool messages (type='tool'). Enables UI to show which tool produced the result.",
        default=None,
    )
    run_id: str | None = Field(
        description="Run ID of the message.",
        default=None,
        examples=["847c6285-8fc9-4560-a83f-4e6285809254"],
    )
    response_metadata: dict[str, Any] = Field(
        description="Response metadata. For example: response headers, logprobs, token counts.",
        default={},
    )
    custom_data: dict[str, Any] = Field(
        description="Custom message data.",
        default={},
    )

    def pretty_repr(self) -> str:
        """Get a pretty representation of the message.

        Returns:
            A banner line of the form ``==== <Type> Message ====`` padded with
            ``=`` towards a width of 80 characters, followed by a blank line
            and the message content.
        """
        base_title = self.type.title() + " Message"
        padded = " " + base_title + " "
        # Split the remaining width evenly between the two sides of the title.
        sep_len = (80 - len(padded)) // 2
        sep = "=" * sep_len
        # Integer division drops one column when the padded title has odd
        # length; give that column back to the right-hand separator.
        second_sep = sep + "=" if len(padded) % 2 else sep
        title = f"{sep}{padded}{second_sep}"
        return f"{title}\n\n{self.content}"

    def pretty_print(self) -> None:
        """Print the output of `pretty_repr` to stdout."""
        print(self.pretty_repr())  # noqa: T201


class Feedback(BaseModel):  # type: ignore[no-redef]
    """Feedback for a run, to record to LangSmith."""

    run_id: str = Field(
        description="Run ID to record feedback for.",
        examples=["847c6285-8fc9-4560-a83f-4e6285809254"],
    )
    key: str = Field(
        description="Feedback key.",
        examples=["human-feedback-stars"],
    )
    score: float = Field(
        description="Feedback score.",
        examples=[0.8],
    )
    kwargs: dict[str, Any] = Field(
        description="Additional feedback kwargs, passed to LangSmith.",
        default={},
        examples=[{"comment": "In-line human feedback"}],
    )


class FeedbackResponse(BaseModel):
    """Response for a feedback request; `status` can only be "success"."""

    status: Literal["success"] = "success"


class ChatMessagePreview(BaseModel):
    """Minimal message for preview/list views (type, content snippet, id)."""

    type: Literal["human", "ai", "tool", "custom"] = Field(
        description="Role of the message.",
    )
    content: str = Field(
        description="Content of the message (may be truncated for preview).",
    )
    id: str | None = Field(
        default=None,
        description="Stable id for cursor/linking (e.g. index).",
    )


class ChatHistoryInput(BaseModel):
    """Input for retrieving chat history."""

    user_id: str = Field(
        description="User ID to scope history to the current user.",
        examples=["847c6285-8fc9-4560-a83f-4e6285809254"],
    )
    thread_id: str = Field(
        description="Thread ID to persist and continue a multi-turn conversation.",
        examples=["847c6285-8fc9-4560-a83f-4e6285809254"],
    )
    limit: int = Field(
        default=50,
        ge=1,
        le=200,
        description="Max number of messages to return per page.",
    )
    cursor: str | None = Field(
        default=None,
        description="Opaque cursor for pagination (older messages).",
    )
    view: Literal["full", "preview"] = Field(
        default="full",
        description="full = all fields; preview = type, content (truncated), id only.",
    )


class ChatHistory(BaseModel):
    """Legacy response: messages only (no cursors)."""

    messages: list[ChatMessage]


class ChatHistoryResponse(BaseModel):
    """Paginated chat history with cursors."""

    messages: list[ChatMessage] | list[ChatMessagePreview] = Field(
        default_factory=list,
        description="Messages in this page (full or preview by view).",
    )
    next_cursor: str | None = Field(
        default=None,
        description="Cursor for next page (older messages).",
    )
    prev_cursor: str | None = Field(
        default=None,
        description="Cursor for previous page (newer messages).",
    )


class ThreadSummary(BaseModel):
    """Summary of a conversation thread for listing."""

    thread_id: str = Field(description="Thread ID (logical id, without user prefix).")
    updated_at: str | None = Field(
        default=None,
        description="ISO 8601 timestamp of last update.",
    )
    preview: str | None = Field(
        default=None,
        description="Preview of the first or last message.",
    )


class ThreadListInput(BaseModel):
    """Input for listing threads for a user."""

    user_id: str = Field(
        description="User ID to list threads for.",
        examples=["847c6285-8fc9-4560-a83f-4e6285809254"],
    )
    limit: int = Field(default=20, ge=1, le=100, description="Max number of threads to return.")
    offset: int = Field(default=0, ge=0, description="Number of threads to skip.")
    search: str | None = Field(
        default=None,
        description="Filter threads by thread_id or preview (case-insensitive).",
    )


class ThreadList(BaseModel):
    """List of threads for a user."""

    threads: list[ThreadSummary] = Field(default_factory=list)
    total: int | None = Field(
        default=None,
        description="Total matching threads before pagination.",
    )