File size: 3,263 Bytes
f95a1f1
 
40d5e03
 
 
 
 
eebe76e
40d5e03
f95a1f1
389d6f7
f95a1f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40d5e03
 
 
 
 
 
f95a1f1
 
eebe76e
8b9e569
eebe76e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f95a1f1
40d5e03
f95a1f1
 
 
 
 
 
 
 
 
 
8b9e569
f95a1f1
 
 
 
 
 
 
 
 
 
 
 
40d5e03
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import nh3

from constants import (
    MAX_COMMENT_LENGTH,
    MAX_FILE_NAME_LENGTH,
    MAX_ID_LENGTH,
    MAX_MESSAGE_LENGTH,
    MAX_RESPONSE_LENGTH,
)
from pydantic import BaseModel, Field, field_validator
from typing import Literal, Set


class IdentifierBase(BaseModel):
    """Identifier fields shared by request models.

    Each identifier is a string of 1 to ``MAX_ID_LENGTH`` characters made up
    only of ASCII letters, digits, underscores and hyphens.
    """

    user_id: str = Field(
        min_length=1,
        max_length=MAX_ID_LENGTH,
        pattern="^[a-zA-Z0-9_-]+$",
    )
    # Participant ID could be in ProfileBase instead. It doesn't really matter.
    participant_id: str = Field(
        min_length=1,
        max_length=MAX_ID_LENGTH,
        pattern="^[a-zA-Z0-9_-]+$",
    )
    session_id: str = Field(
        min_length=1,
        max_length=MAX_ID_LENGTH,
        pattern="^[a-zA-Z0-9_-]+$",
    )


class ProfileBase(BaseModel):
    """Profile fields: consent flag, age group, gender and roles."""

    consent: bool
    # NOTE(review): "0-18" and "18-24" both appear to cover age 18 -- confirm
    # the intended boundaries before relying on these buckets.
    age_group: Literal[
        "0-18",
        "18-24",
        "25-34",
        "35-44",
        "45-54",
        "55-64",
        "65+",
    ]
    gender: Literal["M", "F"]
    # A set of one to five roles drawn from the allowed values below.
    roles: Set[
        Literal[
            "patient",
            "clinician",
            "computer-scientist",
            "researcher",
            "other",
        ]
    ] = Field(max_length=5, min_length=1)


class ChatRequest(IdentifierBase, ProfileBase):
    """Chat request: identifiers, profile, target model, language and message.

    ``human_message`` is HTML-sanitized during validation, so the stored value
    may differ from the text that was submitted.
    """

    conversation_id: str = Field(
        pattern="^[a-zA-Z0-9_-]+$", min_length=1, max_length=MAX_ID_LENGTH
    )
    model_type: Literal["champ", "openai", "google-conservative", "google-creative"]
    lang: Literal["en", "fr"]
    human_message: str = Field(min_length=1, max_length=MAX_MESSAGE_LENGTH)

    @field_validator("human_message")
    @classmethod
    def sanitize_human_message(cls, human_message: str) -> str:
        """Sanitize HTML in the message to prevent XSS.

        This validator runs after the field's length constraints, so a
        message consisting only of stripped markup would otherwise slip past
        ``min_length=1`` as an empty string; that case is rejected here.

        Raises:
            ValueError: If nothing is left of the message after sanitization.
        """
        sanitized = nh3.clean(human_message)
        # NOTE(review): sanitization may also lengthen the text (character
        # escaping); confirm whether max_length must hold for the result too.
        if not sanitized:
            raise ValueError("human_message is empty after HTML sanitization")
        return sanitized


class FeedbackRequest(IdentifierBase, ProfileBase):
    """Feedback on one message: its index, a rating, a comment and the reply.

    ``comment`` and ``reply_content`` are HTML-sanitized during validation.
    ``comment`` may be empty; ``reply_content`` may not.
    """

    message_index: int = Field(ge=0, le=10_000)
    rating: Literal["like", "dislike", "mixed"]
    comment: str = Field(min_length=0, max_length=MAX_COMMENT_LENGTH)
    reply_content: str = Field(min_length=1, max_length=MAX_RESPONSE_LENGTH)

    @field_validator("comment")
    @classmethod
    def sanitize_comment(cls, comment: str) -> str:
        """Sanitize HTML in the comment to prevent XSS.

        An empty result is acceptable because the field allows length 0.
        """
        return nh3.clean(comment)

    @field_validator("reply_content")
    @classmethod
    def sanitize_reply_content(cls, reply_content: str) -> str:
        """Sanitize HTML in the reply content to prevent XSS.

        This validator runs after the field's length constraints, so content
        consisting only of stripped markup would otherwise slip past
        ``min_length=1`` as an empty string; that case is rejected here.

        Raises:
            ValueError: If nothing is left of the content after sanitization.
        """
        sanitized = nh3.clean(reply_content)
        if not sanitized:
            raise ValueError("reply_content is empty after HTML sanitization")
        return sanitized


class CommentRequest(IdentifierBase, ProfileBase):
    """Standalone comment submitted with identifiers and profile fields.

    ``comment`` is HTML-sanitized during validation and must be non-empty.
    """

    comment: str = Field(min_length=1, max_length=MAX_COMMENT_LENGTH)

    @field_validator("comment")
    @classmethod
    def sanitize_comment(cls, comment: str) -> str:
        """Sanitize HTML in the comment to prevent XSS.

        This validator runs after the field's length constraints, so a
        comment consisting only of stripped markup would otherwise slip past
        ``min_length=1`` as an empty string; that case is rejected here.

        Raises:
            ValueError: If nothing is left of the comment after sanitization.
        """
        sanitized = nh3.clean(comment)
        if not sanitized:
            raise ValueError("comment is empty after HTML sanitization")
        return sanitized


class DeleteFileRequest(IdentifierBase, ProfileBase):
    """Request naming a file to delete, with identifiers and profile fields."""

    file_name: str = Field(
        min_length=1,
        max_length=MAX_FILE_NAME_LENGTH,
        # Pattern: the first character is a letter, digit, "_", "-", "(" or
        # ")". The rest of the base name may also contain whitespace. Any
        # number of dot-separated segments may follow, each non-empty and
        # made of letters, digits, whitespace, "_" or "-". This rules out a
        # leading dot or whitespace, consecutive dots and a trailing dot.
        # NOTE(review): "\s" matches any whitespace (tabs, newlines), not
        # only spaces -- confirm that this is intended.
        pattern=r"^[a-zA-Z0-9_()-][a-zA-Z0-9\s_()-]*(\.[a-zA-Z0-9\s_-]+)*$",
    )


class ClearConversationRequest(BaseModel):
    """Pair of session identifiers: the old session and the new one.

    Both follow the same rules: 1 to ``MAX_ID_LENGTH`` characters, limited to
    ASCII letters, digits, underscores and hyphens.
    """

    old_session_id: str = Field(
        min_length=1,
        max_length=MAX_ID_LENGTH,
        pattern="^[a-zA-Z0-9_-]+$",
    )
    new_session_id: str = Field(
        min_length=1,
        max_length=MAX_ID_LENGTH,
        pattern="^[a-zA-Z0-9_-]+$",
    )


class ChatMessage(BaseModel):
    """Single chat message: a role plus its text content."""

    # One of the three allowed roles.
    role: Literal[
        "user",
        "assistant",
        "system",
    ]
    # No length or sanitization constraints are declared on the content here.
    content: str