|
|
from typing import Annotated |
|
|
|
|
|
from pydantic import BaseModel, Discriminator, Field, Tag, field_serializer, field_validator |
|
|
from typing_extensions import TypedDict |
|
|
|
|
|
from .content_types import CodeContent, ErrorContent, JSONContent, MediaContent, TextContent, ToolContent |
|
|
|
|
|
|
|
|
def _get_type(d: dict | BaseModel) -> str | None:
    """Return the ``type`` discriminator value of a content payload.

    Args:
        d: Either a plain mapping (looked up by the ``"type"`` key) or an
            object (looked up by its ``type`` attribute).

    Returns:
        The value found, or ``None`` when the key/attribute is absent.
    """
    return d.get("type") if isinstance(d, dict) else getattr(d, "type", None)
|
|
|
|
|
|
|
|
|
|
|
# Every supported content model, each labelled with the tag that selects it.
_TaggedContent = (
    Annotated[ToolContent, Tag("tool_use")]
    | Annotated[ErrorContent, Tag("error")]
    | Annotated[TextContent, Tag("text")]
    | Annotated[MediaContent, Tag("media")]
    | Annotated[CodeContent, Tag("code")]
    | Annotated[JSONContent, Tag("json")]
)

# Discriminated union: `_get_type` extracts the `type` value from the raw
# input (dict or model instance) and the result is matched against the tags.
ContentType = Annotated[_TaggedContent, Discriminator(_get_type)]
|
|
|
|
|
|
|
|
class ContentBlock(BaseModel):
    """A block of content that can contain different types of content."""

    title: str
    # Items are validated against the `ContentType` discriminated union.
    contents: list[ContentType]
    allow_markdown: bool = Field(default=True)
    media_url: list[str] | None = None

    def __init__(self, **data) -> None:
        """Validate ``data`` and record every defaulted field as "set".

        After normal pydantic initialisation, each field whose core schema
        carries a ``"default"`` entry is added to ``model_fields_set``.
        NOTE(review): presumably this keeps defaulted fields in
        ``exclude_unset`` dumps -- confirm against callers.
        """
        super().__init__(**data)

        # The model-fields schema may be wrapped in one or more outer schemas
        # (each exposing the inner one under "schema"). Walk down until the
        # "fields" mapping is found instead of assuming a fixed depth, so an
        # unexpected layout can no longer leave `fields` unbound.
        node = self.__pydantic_core_schema__["schema"]
        while "fields" not in node and "schema" in node:
            node = node["schema"]
        fields = node.get("fields", {})

        fields_with_default = (
            name for name, spec in fields.items() if "default" in spec["schema"]
        )
        self.model_fields_set.update(fields_with_default)

    @field_validator("contents", mode="before")
    @classmethod
    def validate_contents(cls, v) -> list[ContentType]:
        """Normalise the raw ``contents`` value before field validation.

        A single model instance is wrapped in a one-element list; any other
        non-dict value is passed through unchanged.

        Raises:
            TypeError: If ``v`` is a bare ``dict`` rather than a list.
        """
        if isinstance(v, dict):
            msg = "Contents must be a list of ContentTypes"
            raise TypeError(msg)
        return [v] if isinstance(v, BaseModel) else v

    @field_serializer("contents")
    def serialize_contents(self, value) -> list[dict]:
        """Serialise each content item by calling its ``model_dump()``."""
        return [v.model_dump() for v in value]
|
|
|
|
|
|
|
|
class ContentBlockDict(TypedDict): |
|
|
title: str |
|
|
contents: list[dict] |
|
|
allow_markdown: bool |
|
|
media_url: list[str] | None |
|
|
|