import nh3
from constants import (
    MAX_COMMENT_LENGTH,
    MAX_FILE_NAME_LENGTH,
    MAX_ID_LENGTH,
    MAX_MESSAGE_LENGTH,
    MAX_RESPONSE_LENGTH,
)
from pydantic import BaseModel, Field, field_validator
from typing import Literal, Set

# Identifiers are restricted to ASCII letters, digits, "_" and "-".
_ID_PATTERN = "^[a-zA-Z0-9_-]+$"


def _id_field():
    """Return a fresh ``Field`` spec for an identifier field.

    Centralizes the identifier constraints so every model below agrees on
    them: non-empty, at most ``MAX_ID_LENGTH`` characters, and matching
    ``_ID_PATTERN``.
    """
    return Field(pattern=_ID_PATTERN, min_length=1, max_length=MAX_ID_LENGTH)


def _sanitize_html(value: str, *, allow_empty: bool) -> str:
    """Sanitize untrusted text with ``nh3.clean`` to prevent XSS.

    ``nh3.clean`` removes markup outside its default allow-list (and drops the
    content of tags such as ``<script>``), so the result can be shorter than
    the input -- even empty although the raw input passed ``min_length=1``.
    Field validators run after pydantic's length checks, so this is the only
    place that can catch that case.

    Args:
        value: Raw text from the request.
        allow_empty: Whether an empty sanitized result is acceptable.

    Returns:
        The sanitized text.

    Raises:
        ValueError: If the sanitized text is empty and ``allow_empty`` is
            False; pydantic surfaces this as a validation error.
    """
    cleaned = nh3.clean(value)
    if not cleaned and not allow_empty:
        raise ValueError("text is empty after removing disallowed HTML")
    return cleaned


class IdentifierBase(BaseModel):
    """Identifier fields shared by the request models below."""

    user_id: str = _id_field()
    # Participant ID could be in ProfileBase instead. It doesn't really matter.
    participant_id: str = _id_field()
    session_id: str = _id_field()


class ProfileBase(BaseModel):
    """Participant profile fields shared by the request models below."""

    consent: bool
    # NOTE(review): "0-18" and "18-24" both include age 18. The literals are
    # left unchanged since clients presumably send these exact strings.
    age_group: Literal["0-18", "18-24", "25-34", "35-44", "45-54", "55-64", "65+"]
    gender: Literal["M", "F"]
    # At least one role; being a set, duplicates collapse.
    roles: Set[
        Literal["patient", "clinician", "computer-scientist", "researcher", "other"]
    ] = Field(min_length=1, max_length=5)


class ChatRequest(IdentifierBase, ProfileBase):
    """One user chat message plus the model type and language to use."""

    conversation_id: str = _id_field()
    model_type: Literal["champ", "openai", "google-conservative", "google-creative"]
    lang: Literal["en", "fr"]
    human_message: str = Field(min_length=1, max_length=MAX_MESSAGE_LENGTH)

    @field_validator("human_message")
    @classmethod
    def sanitize_human_message(cls, human_message: str) -> str:
        """Sanitize the message; reject it if nothing survives sanitization."""
        return _sanitize_html(human_message, allow_empty=False)


class FeedbackRequest(IdentifierBase, ProfileBase):
    """Rating, possibly-empty comment, and reply text for one message index."""

    message_index: int = Field(ge=0, le=10_000)
    rating: Literal["like", "dislike", "mixed"]
    # The comment may legitimately be empty.
    comment: str = Field(min_length=0, max_length=MAX_COMMENT_LENGTH)
    reply_content: str = Field(min_length=1, max_length=MAX_RESPONSE_LENGTH)

    @field_validator("comment")
    @classmethod
    def sanitize_comment(cls, comment: str) -> str:
        """Sanitize the comment; an empty result is accepted."""
        return _sanitize_html(comment, allow_empty=True)

    @field_validator("reply_content")
    @classmethod
    def sanitize_reply_content(cls, reply_content: str) -> str:
        """Sanitize the reply text; reject it if nothing survives sanitization."""
        return _sanitize_html(reply_content, allow_empty=False)


class CommentRequest(IdentifierBase, ProfileBase):
    """Free-text comment that must be non-empty after sanitization."""

    comment: str = Field(min_length=1, max_length=MAX_COMMENT_LENGTH)

    @field_validator("comment")
    @classmethod
    def sanitize_comment(cls, comment: str) -> str:
        """Sanitize the comment; reject it if nothing survives sanitization."""
        return _sanitize_html(comment, allow_empty=False)


class DeleteFileRequest(IdentifierBase, ProfileBase):
    """Request naming a single file to delete."""

    file_name: str = Field(
        # Base name: starts with a letter, digit, "_", "(", ")" or "-", and may
        # then also contain whitespace. Optional ".suffix" parts may contain
        # letters, digits, whitespace, "_" and "-" (no parentheses). This rules
        # out "..", a leading dot or whitespace, a trailing dot, and path
        # separators.
        pattern=r"^[a-zA-Z0-9_()-][a-zA-Z0-9\s_()-]*(\.[a-zA-Z0-9\s_-]+)*$",
        min_length=1,
        max_length=MAX_FILE_NAME_LENGTH,
    )


class ClearConversationRequest(BaseModel):
    """The old and new session IDs involved in clearing a conversation."""

    old_session_id: str = _id_field()
    new_session_id: str = _id_field()


class ChatMessage(BaseModel):
    """A single message tagged with its role (user, assistant, or system)."""

    role: Literal["user", "assistant", "system"]
    content: str