| | import nh3
|
| |
|
| | from constants import (
|
| | MAX_COMMENT_LENGTH,
|
| | MAX_FILE_NAME_LENGTH,
|
| | MAX_ID_LENGTH,
|
| | MAX_MESSAGE_LENGTH,
|
| | MAX_RESPONSE_LENGTH,
|
| | )
|
| | from pydantic import BaseModel, Field, field_validator
|
| | from typing import Literal, Set
|
| |
|
| |
|
class IdentifierBase(BaseModel):
    """Identifier fields shared by the request models that inherit this class.

    Each identifier is 1 to ``MAX_ID_LENGTH`` characters long and may only
    contain ASCII letters, digits, underscores and hyphens.
    """

    user_id: str = Field(
        min_length=1,
        max_length=MAX_ID_LENGTH,
        pattern=r"^[a-zA-Z0-9_-]+$",
    )
    participant_id: str = Field(
        min_length=1,
        max_length=MAX_ID_LENGTH,
        pattern=r"^[a-zA-Z0-9_-]+$",
    )
    session_id: str = Field(
        min_length=1,
        max_length=MAX_ID_LENGTH,
        pattern=r"^[a-zA-Z0-9_-]+$",
    )
|
| |
|
| |
|
class ProfileBase(BaseModel):
    """Consent flag plus profile/demographic fields attached to requests."""

    consent: bool
    # NOTE(review): the "0-18" and "18-24" buckets both include age 18 —
    # confirm which bucket an 18-year-old is expected to choose.
    age_group: Literal[
        "0-18",
        "18-24",
        "25-34",
        "35-44",
        "45-54",
        "55-64",
        "65+",
    ]
    gender: Literal["M", "F"]
    # At least one role is required; only the literal values below are valid.
    roles: Set[
        Literal[
            "patient",
            "clinician",
            "computer-scientist",
            "researcher",
            "other",
        ]
    ] = Field(max_length=5, min_length=1)
|
| |
|
| |
|
class ChatRequest(IdentifierBase, ProfileBase):
    """Request model for sending one human message within a conversation.

    Inherits the identifier fields of ``IdentifierBase`` and the profile
    fields of ``ProfileBase``.
    """

    conversation_id: str = Field(
        pattern="^[a-zA-Z0-9_-]+$", min_length=1, max_length=MAX_ID_LENGTH
    )
    model_type: Literal["champ", "openai", "google-conservative", "google-creative"]
    lang: Literal["en", "fr"]
    human_message: str = Field(min_length=1, max_length=MAX_MESSAGE_LENGTH)

    @field_validator("human_message")
    @classmethod
    def sanitize_human_message(cls, human_message: str) -> str:
        """Sanitize ``human_message`` with ``nh3`` to prevent XSS.

        ``nh3.clean`` strips markup it does not allow, so a message made up
        only of such markup can come back as ``""``. This validator runs in
        pydantic's default "after" mode, i.e. after ``min_length=1`` has
        already passed on the raw text, so an empty result is rejected here
        explicitly to keep the non-empty guarantee.

        Returns:
            The sanitized message.

        Raises:
            ValueError: if nothing is left after sanitization (pydantic
                surfaces it as a validation error).
        """
        cleaned = nh3.clean(human_message)
        if not cleaned:
            raise ValueError("human_message is empty after HTML sanitization")
        return cleaned
|
| |
|
| |
|
class FeedbackRequest(IdentifierBase, ProfileBase):
    """Request model for feedback on a single message.

    The message is identified by ``message_index``. The feedback is a
    ``rating`` plus a possibly empty ``comment``, along with the
    ``reply_content`` it refers to. Inherits the identifier fields of
    ``IdentifierBase`` and the profile fields of ``ProfileBase``.
    """

    # Inclusive bounds: 0 <= message_index <= 10_000.
    message_index: int = Field(ge=0, le=10_000)
    rating: Literal["like", "dislike", "mixed"]
    # min_length=0: an empty comment is explicitly allowed.
    comment: str = Field(min_length=0, max_length=MAX_COMMENT_LENGTH)
    reply_content: str = Field(min_length=1, max_length=MAX_RESPONSE_LENGTH)

    @field_validator("comment")
    @classmethod
    def sanitize_comment(cls, comment: str) -> str:
        """Sanitize ``comment`` with ``nh3`` to prevent XSS.

        An empty result is acceptable because the field allows ``""``.
        """
        return nh3.clean(comment)

    @field_validator("reply_content")
    @classmethod
    def sanitize_reply_content(cls, reply_content: str) -> str:
        """Sanitize ``reply_content`` with ``nh3`` to prevent XSS.

        ``nh3.clean`` strips markup it does not allow, so content made up
        only of such markup can come back as ``""``. This validator runs
        after ``min_length=1`` has already passed on the raw text, so an
        empty result is rejected here explicitly.

        Raises:
            ValueError: if nothing is left after sanitization.
        """
        cleaned = nh3.clean(reply_content)
        if not cleaned:
            raise ValueError("reply_content is empty after HTML sanitization")
        return cleaned
|
| |
|
| |
|
class CommentRequest(IdentifierBase, ProfileBase):
    """Request model carrying a free-text, non-empty ``comment``.

    Inherits the identifier fields of ``IdentifierBase`` and the profile
    fields of ``ProfileBase``.
    """

    comment: str = Field(min_length=1, max_length=MAX_COMMENT_LENGTH)

    @field_validator("comment")
    @classmethod
    def sanitize_comment(cls, comment: str) -> str:
        """Sanitize ``comment`` with ``nh3`` to prevent XSS.

        ``nh3.clean`` strips markup it does not allow, so a comment made up
        only of such markup can come back as ``""``. This validator runs
        after ``min_length=1`` has already passed on the raw text, so an
        empty result is rejected here explicitly.

        Raises:
            ValueError: if nothing is left after sanitization.
        """
        cleaned = nh3.clean(comment)
        if not cleaned:
            raise ValueError("comment is empty after HTML sanitization")
        return cleaned
|
| |
|
| |
|
class DeleteFileRequest(IdentifierBase, ProfileBase):
    """Request model for deleting a file, identified by ``file_name``."""

    # A bare file name only. The allowed character classes contain no "/" or
    # "\", and the first character cannot be ".", so values such as "../x"
    # or ".env" are rejected. Each optional ".ext" segment needs at least
    # one allowed character after the dot, which also rules out "..".
    # NOTE(review): "\s" admits any whitespace (tabs, newlines, ...), not
    # just spaces — confirm that is intended.
    file_name: str = Field(
        max_length=MAX_FILE_NAME_LENGTH,
        min_length=1,
        pattern=r"^[a-zA-Z0-9_()-][a-zA-Z0-9\s_()-]*(\.[a-zA-Z0-9\s_-]+)*$",
    )
|
| |
|
| |
|
class ClearConversationRequest(BaseModel):
    """Request model holding an ``old_session_id`` and a ``new_session_id``.

    Both ids are 1 to ``MAX_ID_LENGTH`` characters of ASCII letters, digits,
    underscores and hyphens. This model does not carry the user/profile
    fields that the other request models inherit.
    """

    old_session_id: str = Field(
        min_length=1,
        max_length=MAX_ID_LENGTH,
        pattern=r"^[a-zA-Z0-9_-]+$",
    )
    new_session_id: str = Field(
        min_length=1,
        max_length=MAX_ID_LENGTH,
        pattern=r"^[a-zA-Z0-9_-]+$",
    )
|
| |
|
| |
|
class ChatMessage(BaseModel):
    """One message of a chat: its author ``role`` and its text ``content``."""

    role: Literal["user", "assistant", "system"]
    # No length limit or HTML sanitization is applied to this field here.
    content: str
|
| |
|