File size: 3,563 Bytes
c59d808
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from typing import Optional, List
from backend.utils.sanitization import sanitize_user_input
from pydantic import BaseModel, Field, validator
from enum import Enum


class MessagePartType(str, Enum):
    """Kinds of content that a single message part may carry.

    The class mixes in ``str``, so each member compares equal to its raw
    string value (for example ``MessagePartType.TEXT == "text"``).
    """

    # Part whose payload lives in the ``text`` field.
    TEXT = "text"
    # Part identified by the literal "step-start"; presumably a step
    # boundary marker with no required text — confirm with the producer.
    STEP_START = "step-start"

class MessageRole(str, Enum):
    """Allowed authorship roles for a chat message.

    Because the class also derives from ``str``, members compare equal to
    their raw values (``MessageRole.USER == "user"``).
    """

    # Values are the raw role strings accepted by ``Message.role``.
    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"
    
class MessagePart(BaseModel):
    """A single piece of a message's content.

    ``type`` selects the kind of part. ``text`` holds the content and must be
    present and non-blank when ``type`` is ``MessagePartType.TEXT``.
    """

    type: MessagePartType
    text: Optional[str] = Field(None, description="Text content for text-type parts")

    # ``always=True`` makes the check run even when ``text`` is omitted and
    # falls back to its ``None`` default. Without it, a text part with no
    # ``text`` key at all skips this validator and is wrongly accepted.
    @validator('text', always=True)
    def validate_text_content(cls, v, values):
        """Ensure text parts have non-empty text content.

        ``type`` is declared before ``text``, so it is available in
        ``values`` whenever it validated successfully. With
        ``use_enum_values`` it is stored as the raw string, which still
        compares equal to the ``str``-based enum member.

        Raises:
            ValueError: if this is a text part and ``text`` is missing,
                empty, or whitespace-only.
        """
        if values.get('type') == MessagePartType.TEXT and (not v or not v.strip()):
            raise ValueError("Text parts must have non-empty text content")
        return v

    class Config:
        # Store enum-typed fields as their plain string values.
        use_enum_values = True
     

class Message(BaseModel):
    """A single message in a conversation: an id, a role, and content parts."""

    id: str = Field(..., description="Unique identifier for the message")
    role: MessageRole
    parts: List[MessagePart] = Field(..., min_items=1, description="Message content parts")

    @validator('parts')
    def validate_parts_not_empty(cls, v):
        """Ensure message has at least one valid part.

        ``min_items=1`` on the field already rejects an empty list before this
        runs. The check is a redundant safeguard in case that constraint is
        ever relaxed.

        Raises:
            ValueError: if ``parts`` is empty.
        """
        if not v:
            raise ValueError("Message must contain at least one part")
        return v

    @validator('id')
    def validate_id_format(cls, v):
        """Ensure the ID is a non-blank string and normalize it.

        Returns:
            The ID with surrounding whitespace stripped.

        Raises:
            ValueError: if the ID is empty or whitespace-only.
        """
        if not v or not v.strip():
            raise ValueError("Message ID cannot be empty")
        return v.strip()

    class Config:
        # Store ``role`` as its plain string value rather than the enum member.
        use_enum_values = True

class ChatMessage(BaseModel):
    """A complete chat conversation: an id, its messages, and an optional trigger."""

    id: str = Field(..., description="Unique identifier for the chat")
    messages: List[Message] = Field(..., min_items=1, description="List of messages in the conversation")
    trigger: Optional[str] = Field(None, description="Optional trigger that initiated the chat")

    @validator('id')
    def validate_chat_id(cls, v):
        """Ensure the chat ID is a non-blank string and normalize it.

        Returns:
            The ID with surrounding whitespace stripped.

        Raises:
            ValueError: if the ID is empty or whitespace-only.
        """
        if not v or not v.strip():
            raise ValueError("Chat ID cannot be empty")
        return v.strip()

    @validator('messages')
    def validate_messages_structure(cls, v):
        """Validate the conversation as a whole.

        Checks that the list is non-empty, that message IDs are unique, and
        that at least one message has the user role.

        Raises:
            ValueError: if any of the above checks fails.
        """
        if not v:
            raise ValueError("Chat must contain at least one message")

        # Check for duplicate message IDs
        message_ids = [msg.id for msg in v]
        if len(message_ids) != len(set(message_ids)):
            raise ValueError("All message IDs must be unique within a chat")

        # ``role`` is stored as a plain string (use_enum_values), which still
        # compares equal to the str-based enum member.
        if not any(msg.role == MessageRole.USER for msg in v):
            raise ValueError("Chat must contain at least one user message")
        return v

    def get_latest_message(self) -> Optional[Message]:
        """Return a sanitized copy of the most recent message in the chat.

        Every non-empty text part of the returned message is passed through
        ``sanitize_user_input``. The message stored on this chat is not
        modified, so repeated calls never sanitize already-sanitized text.

        Returns:
            A deep copy of the last message with sanitized text parts, or
            ``None`` if the chat has no messages.
        """
        if not self.messages:
            return None
        # Sanitize a deep copy. Mutating the stored parts would silently alter
        # the chat's own state and make a second call sanitize the text twice.
        latest = self.messages[-1].copy(deep=True)
        for part in latest.parts:
            if part.type == MessagePartType.TEXT and part.text:
                part.text = sanitize_user_input(part.text)
        return latest

    class Config:
        # Store enum-typed fields as their plain string values.
        use_enum_values = True