Spaces:
Paused
Paused
File size: 3,263 Bytes
f95a1f1 40d5e03 eebe76e 40d5e03 f95a1f1 389d6f7 f95a1f1 40d5e03 f95a1f1 eebe76e 8b9e569 eebe76e f95a1f1 40d5e03 f95a1f1 8b9e569 f95a1f1 40d5e03 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 | import nh3
from constants import (
MAX_COMMENT_LENGTH,
MAX_FILE_NAME_LENGTH,
MAX_ID_LENGTH,
MAX_MESSAGE_LENGTH,
MAX_RESPONSE_LENGTH,
)
from pydantic import BaseModel, Field, field_validator
from typing import Literal, Set
class IdentifierBase(BaseModel):
    """Identifier fields shared by the request models below.

    Each identifier is a required, non-empty string made only of ASCII
    letters, digits, underscores and hyphens. It is at most
    ``MAX_ID_LENGTH`` characters long.
    """

    user_id: str = Field(
        min_length=1,
        max_length=MAX_ID_LENGTH,
        pattern=r"^[a-zA-Z0-9_-]+$",
    )
    # participant_id could just as well live on ProfileBase; its placement
    # here is arbitrary.
    participant_id: str = Field(
        min_length=1,
        max_length=MAX_ID_LENGTH,
        pattern=r"^[a-zA-Z0-9_-]+$",
    )
    session_id: str = Field(
        min_length=1,
        max_length=MAX_ID_LENGTH,
        pattern=r"^[a-zA-Z0-9_-]+$",
    )
class ProfileBase(BaseModel):
    """Profile fields mixed into each request model alongside the identifiers."""

    consent: bool
    # NOTE(review): "0-18" and "18-24" both contain age 18, so the buckets
    # overlap. Presumably clients already send these exact strings, so
    # changing them would be a breaking change; confirm before fixing.
    age_group: Literal[
        "0-18", "18-24", "25-34", "35-44", "45-54", "55-64", "65+"
    ]
    gender: Literal["M", "F"]
    # Between one and five roles, chosen from the five literals below.
    roles: Set[
        Literal[
            "patient",
            "clinician",
            "computer-scientist",
            "researcher",
            "other",
        ]
    ] = Field(max_length=5, min_length=1)
class ChatRequest(IdentifierBase, ProfileBase):
    """Chat request: identifiers, profile, routing options and the user's message."""

    conversation_id: str = Field(
        pattern="^[a-zA-Z0-9_-]+$", min_length=1, max_length=MAX_ID_LENGTH
    )
    model_type: Literal["champ", "openai", "google-conservative", "google-creative"]
    lang: Literal["en", "fr"]
    human_message: str = Field(min_length=1, max_length=MAX_MESSAGE_LENGTH)

    @field_validator("human_message")
    @classmethod
    def sanitize_human_message(cls, human_message: str) -> str:
        """Sanitize HTML in the message with nh3 to prevent XSS.

        This is an "after" validator, so the ``Field`` constraints
        (including ``min_length=1``) were already checked on the raw text.
        Markup that nh3 strips completely would otherwise pass that check
        and come out as an empty message, so emptiness is re-checked on the
        sanitized value.

        Raises:
            ValueError: if nothing is left after sanitization (surfaced by
                pydantic as a ValidationError).
        """
        # NOTE(review): nh3 returns HTML, so plain-text characters such as
        # "<" or "&" may come back entity-escaped. Confirm downstream
        # consumers expect that.
        cleaned = nh3.clean(human_message)
        if not cleaned:
            raise ValueError("human_message is empty after HTML sanitization")
        return cleaned
class FeedbackRequest(IdentifierBase, ProfileBase):
    """Feedback on one reply.

    Carries the reply's index, a rating, a comment (which may be empty) and
    the reply text itself.
    """

    message_index: int = Field(ge=0, le=10_000)
    rating: Literal["like", "dislike", "mixed"]
    # The comment may be empty (min_length=0), but the field is still required.
    comment: str = Field(min_length=0, max_length=MAX_COMMENT_LENGTH)
    reply_content: str = Field(min_length=1, max_length=MAX_RESPONSE_LENGTH)

    @field_validator("comment")
    @classmethod
    def sanitize_comment(cls, comment: str) -> str:
        """Sanitize HTML in the comment with nh3 to prevent XSS.

        An empty result is acceptable here because the comment may be empty.
        """
        return nh3.clean(comment)

    @field_validator("reply_content")
    @classmethod
    def sanitize_reply_content(cls, reply_content: str) -> str:
        """Sanitize HTML in the reply text with nh3 to prevent XSS.

        ``min_length=1`` was checked on the raw input before this "after"
        validator runs. Content that nh3 strips completely would otherwise
        slip through as an empty string, so it is re-checked here.

        Raises:
            ValueError: if nothing is left after sanitization (surfaced by
                pydantic as a ValidationError).
        """
        cleaned = nh3.clean(reply_content)
        if not cleaned:
            raise ValueError("reply_content is empty after HTML sanitization")
        return cleaned
class CommentRequest(IdentifierBase, ProfileBase):
    """Free-text comment submitted together with identifiers and profile."""

    comment: str = Field(min_length=1, max_length=MAX_COMMENT_LENGTH)

    @field_validator("comment")
    @classmethod
    def sanitize_comment(cls, comment: str) -> str:
        """Sanitize HTML in the comment with nh3 to prevent XSS.

        ``min_length=1`` was checked on the raw text before this "after"
        validator runs. A comment made only of markup that nh3 strips
        completely would otherwise be stored as an empty string, so
        emptiness is re-checked after sanitizing.

        Raises:
            ValueError: if nothing is left after sanitization (surfaced by
                pydantic as a ValidationError).
        """
        cleaned = nh3.clean(comment)
        if not cleaned:
            raise ValueError("comment is empty after HTML sanitization")
        return cleaned
class DeleteFileRequest(IdentifierBase, ProfileBase):
    """Request naming a file to delete, plus identifiers and profile."""

    file_name: str = Field(
        # Base name: the first character is a letter, digit, "_", "(", ")" or
        # "-" (so no leading dot or space). It is followed by any mix of
        # letters, digits, spaces, "_", "(", ")" and "-".
        # Then zero or more ".ext" segments, each a dot followed by at least
        # one letter, digit, space, "_" or "-".
        # Consequences: "..", "/" and "\" can never appear. The only
        # whitespace accepted is the literal space character; tabs, newlines
        # and other whitespace are rejected.
        pattern=r"^[a-zA-Z0-9_()-][a-zA-Z0-9 _()-]*(\.[a-zA-Z0-9 _-]+)*$",
        min_length=1,
        max_length=MAX_FILE_NAME_LENGTH,
    )
class ClearConversationRequest(BaseModel):
    """Carries an old and a new session identifier.

    Presumably the old session's conversation is cleared and work continues
    under the new session; confirm against the endpoint that consumes this.
    Both ids use the same format as the ids in ``IdentifierBase``: 1 to
    ``MAX_ID_LENGTH`` characters drawn from ``[a-zA-Z0-9_-]``.
    """

    old_session_id: str = Field(
        min_length=1,
        max_length=MAX_ID_LENGTH,
        pattern=r"^[a-zA-Z0-9_-]+$",
    )
    new_session_id: str = Field(
        min_length=1,
        max_length=MAX_ID_LENGTH,
        pattern=r"^[a-zA-Z0-9_-]+$",
    )
class ChatMessage(BaseModel):
    """One chat message: who produced it (role) and its text (content)."""

    # Must be exactly one of these three role strings.
    role: Literal["user", "assistant", "system"]
    # Unconstrained text; unlike the request models, no length limit or HTML
    # sanitization is applied here.
    content: str
|