|
|
from typing import Optional, List |
|
|
from backend.utils.sanitization import sanitize_user_input |
|
|
from pydantic import BaseModel, Field, validator |
|
|
from enum import Enum |
|
|
|
|
|
|
|
|
class MessagePartType(str, Enum):
    """Kinds of content a single message part may carry.

    Members are ``str`` subclasses, so they compare equal to their raw
    wire values (e.g. ``MessagePartType.TEXT == 'text'``).
    """

    # A part holding textual content.
    TEXT = 'text'
    # A marker part that carries no text of its own.
    STEP_START = 'step-start'
|
|
|
|
|
class MessageRole(str, Enum):
    """Author roles that a conversation message may have.

    Being a ``str`` subclass, each member equals its raw value
    (e.g. ``MessageRole.USER == 'user'``).
    """

    # Message written by the end user.
    USER = 'user'
    # Message produced by the assistant.
    ASSISTANT = 'assistant'
    # System-level instruction message.
    SYSTEM = 'system'
|
|
|
|
|
class MessagePart(BaseModel):
    """Represents a single part of a message (text etc.).

    Attributes:
        type: The kind of part; stored as its raw string value because of
            ``Config.use_enum_values``.
        text: Text content; required (and non-blank) when ``type`` is
            ``MessagePartType.TEXT``, optional otherwise.
    """

    type: MessagePartType
    text: Optional[str] = Field(None, description="Text content for text-type parts")

    # always=True: without it Pydantic skips this validator whenever ``text``
    # is omitted (left at its default None), so a text part with no text at
    # all would slip through validation.
    @validator('text', always=True)
    def validate_text_content(cls, v, values):
        """Ensure text parts have non-empty text content.

        Raises:
            ValueError: if the part is a text part and ``v`` is None, empty,
                or whitespace-only.
        """
        # ``values`` lacks 'type' if that field itself failed validation; in
        # that case there is nothing to cross-check here.
        if values.get('type') == MessagePartType.TEXT and (not v or not v.strip()):
            raise ValueError("Text parts must have non-empty text content")
        return v

    class Config:
        use_enum_values = True
|
|
|
|
|
|
|
|
class Message(BaseModel):
    """Represents a single message in a conversation.

    Attributes:
        id: Non-blank identifier of this message; surrounding whitespace is
            stripped during validation.
        role: Author role; stored as its raw string value because of
            ``Config.use_enum_values``.
        parts: One or more content parts making up the message.
    """

    # Description fixed: this is the ID of the message itself, not of a part.
    id: str = Field(..., description="Unique identifier for the message")
    role: MessageRole
    parts: List[MessagePart] = Field(..., min_items=1, description="Message content parts")

    @validator('parts')
    def validate_parts_not_empty(cls, v):
        """Ensure message has at least one part.

        ``min_items=1`` on the field already enforces this; the explicit
        check is kept as a defensive second guard.

        Raises:
            ValueError: if ``v`` is empty.
        """
        if not v:
            raise ValueError("Message must contain at least one part")
        return v

    @validator('id')
    def validate_id_format(cls, v):
        """Reject blank IDs and return the ID with whitespace stripped.

        Raises:
            ValueError: if ``v`` is empty or whitespace-only.
        """
        if not v or not v.strip():
            raise ValueError("Message ID cannot be empty")
        return v.strip()

    class Config:
        use_enum_values = True
|
|
|
|
|
class ChatMessage(BaseModel):
    """Represents a complete chat conversation.

    Attributes:
        id: Non-blank chat identifier; surrounding whitespace is stripped.
        messages: At least one message. Message IDs must be unique, and at
            least one message must have the user role.
        trigger: Optional trigger that initiated the chat.
    """

    id: str = Field(..., description="Unique identifier for the chat")
    messages: List[Message] = Field(..., min_items=1, description="List of messages in the conversation")
    trigger: Optional[str] = Field(None, description="Optional trigger that initiated the chat")

    @validator('id')
    def validate_chat_id(cls, v):
        """Reject blank chat IDs and return the ID with whitespace stripped.

        Raises:
            ValueError: if ``v`` is empty or whitespace-only.
        """
        if not v or not v.strip():
            raise ValueError("Chat ID cannot be empty")
        return v.strip()

    @validator('messages')
    def validate_messages_structure(cls, v):
        """Validate the messages list as a whole.

        Raises:
            ValueError: if the list is empty, if any message ID repeats, or
                if no message has the user role.
        """
        if not v:
            raise ValueError("Chat must contain at least one message")

        # IDs were already stripped by Message's own validator.
        seen_ids = set()
        for msg in v:
            if msg.id in seen_ids:
                raise ValueError("All message IDs must be unique within a chat")
            seen_ids.add(msg.id)

        # ``msg.role`` is a plain str (use_enum_values); it still compares
        # equal to the str-based enum member.
        if not any(msg.role == MessageRole.USER for msg in v):
            raise ValueError("Chat must contain at least one user message")
        return v

    def get_latest_message(self) -> Optional[Message]:
        """Return a sanitized copy of the most recent message in the chat.

        Each non-empty text part of the returned message has been passed
        through ``sanitize_user_input``. The sanitized text goes into a deep
        copy, so the stored conversation is not modified. Repeated calls
        therefore never re-sanitize text that was already sanitized.

        Returns:
            A sanitized copy of the last message, or ``None`` if the chat has
            no messages.
        """
        if not self.messages:
            return None
        latest = self.messages[-1].copy(deep=True)
        for part in latest.parts:
            if part.type == MessagePartType.TEXT and part.text:
                part.text = sanitize_user_input(part.text)
        return latest

    class Config:
        use_enum_values = True