Spaces:
Paused
Paused
| import nh3 | |
| from constants import ( | |
| MAX_COMMENT_LENGTH, | |
| MAX_FILE_NAME_LENGTH, | |
| MAX_ID_LENGTH, | |
| MAX_MESSAGE_LENGTH, | |
| MAX_RESPONSE_LENGTH, | |
| ) | |
| from pydantic import BaseModel, Field, field_validator | |
| from typing import Literal, Set | |
# Shared identifier fields mixed into the request models below. Every ID must
# be 1..MAX_ID_LENGTH characters of ASCII letters, digits, '_' or '-'.
class IdentifierBase(BaseModel):
    user_id: str = Field(
        min_length=1,
        max_length=MAX_ID_LENGTH,
        pattern=r"^[a-zA-Z0-9_-]+$",
    )
    # participant_id could just as well live on ProfileBase; the request
    # models in this file inherit both bases, so its placement is cosmetic.
    participant_id: str = Field(
        min_length=1,
        max_length=MAX_ID_LENGTH,
        pattern=r"^[a-zA-Z0-9_-]+$",
    )
    session_id: str = Field(
        min_length=1,
        max_length=MAX_ID_LENGTH,
        pattern=r"^[a-zA-Z0-9_-]+$",
    )
# Self-reported profile fields mixed into the request models below.
class ProfileBase(BaseModel):
    consent: bool
    # NOTE(review): "0-18" and "18-24" both contain age 18 - confirm the
    # intended bucket boundaries before relying on them for analysis.
    age_group: Literal[
        "0-18",
        "18-24",
        "25-34",
        "35-44",
        "45-54",
        "55-64",
        "65+",
    ]
    gender: Literal["M", "F"]
    # Non-empty set of roles; the cap of 5 equals the number of allowed values.
    roles: Set[
        Literal[
            "patient",
            "clinician",
            "computer-scientist",
            "researcher",
            "other",
        ]
    ] = Field(max_length=5, min_length=1)
class ChatRequest(IdentifierBase, ProfileBase):
    """Payload carrying one human message to a selected chat model.

    Inherits the user/participant/session IDs from ``IdentifierBase`` and the
    profile fields from ``ProfileBase``.
    """

    conversation_id: str = Field(
        pattern="^[a-zA-Z0-9_-]+$", min_length=1, max_length=MAX_ID_LENGTH
    )
    model_type: Literal["champ", "openai", "google-conservative", "google-creative"]
    lang: Literal["en", "fr"]
    human_message: str = Field(min_length=1, max_length=MAX_MESSAGE_LENGTH)

    # The decorators are what register this method with pydantic. A bare
    # method on a model is never invoked during validation.
    @field_validator("human_message")
    @classmethod
    def sanitize_human_message(cls, human_message: str) -> str:
        """Sanitise ``human_message`` with ``nh3.clean`` to prevent XSS.

        Runs in pydantic's default "after" mode, i.e. once the length
        constraints have already passed; they are not re-checked on the
        cleaned value.

        NOTE(review): nh3 re-serialises its input as HTML, so characters such
        as '<' or '&' may come back entity-escaped - confirm that is
        acceptable for what downstream code stores or shows.

        Raises:
            ValueError: if the message is empty after cleaning (e.g. it was
                only a ``<script>`` element), which would otherwise slip past
                the field's ``min_length=1``.
        """
        cleaned = nh3.clean(human_message)
        if not cleaned:
            raise ValueError("human_message is empty after HTML sanitization")
        return cleaned
class FeedbackRequest(IdentifierBase, ProfileBase):
    """Rating plus optional free-text comment on one assistant reply.

    ``message_index`` is bounded to 0..10_000; ``reply_content`` carries the
    text of the rated reply.
    """

    message_index: int = Field(ge=0, le=10_000)
    rating: Literal["like", "dislike", "mixed"]
    comment: str = Field(min_length=0, max_length=MAX_COMMENT_LENGTH)
    reply_content: str = Field(min_length=1, max_length=MAX_RESPONSE_LENGTH)

    # The @field_validator decorators below register these methods with
    # pydantic. Without them the methods are never called and nothing is
    # sanitised.
    @field_validator("comment")
    @classmethod
    def sanitize_comment(cls, comment: str) -> str:
        """Sanitise ``comment`` with ``nh3.clean`` to prevent XSS.

        An empty result is allowed here because the field itself permits
        ``min_length=0``.
        """
        return nh3.clean(comment)

    @field_validator("reply_content")
    @classmethod
    def sanitize_reply_content(cls, reply_content: str) -> str:
        """Sanitise ``reply_content`` with ``nh3.clean`` to prevent XSS.

        Raises:
            ValueError: if nothing remains after cleaning, which would
                otherwise bypass the field's ``min_length=1`` (length checks
                run before this "after" validator and are not re-applied).
        """
        cleaned = nh3.clean(reply_content)
        if not cleaned:
            raise ValueError("reply_content is empty after HTML sanitization")
        return cleaned
class CommentRequest(IdentifierBase, ProfileBase):
    """Stand-alone free-text comment from a participant."""

    comment: str = Field(min_length=1, max_length=MAX_COMMENT_LENGTH)

    # @field_validator is required for pydantic to run this method; a bare
    # method on the model is never invoked.
    @field_validator("comment")
    @classmethod
    def sanitize_comment(cls, comment: str) -> str:
        """Sanitise ``comment`` with ``nh3.clean`` to prevent XSS.

        Raises:
            ValueError: if the comment is empty after cleaning, which would
                otherwise bypass the field's ``min_length=1``.
        """
        cleaned = nh3.clean(comment)
        if not cleaned:
            raise ValueError("comment is empty after HTML sanitization")
        return cleaned
class DeleteFileRequest(IdentifierBase, ProfileBase):
    """Request naming a single file to delete."""

    file_name: str = Field(
        # Stem: letters, digits, '_', '-', '(', ')' and spaces; it may not
        # start with a dot or a space. It may be followed by any number of
        # '.'-separated segments of letters, digits, '_', '-' and spaces.
        # Every dot needs at least one non-dot character after it, so '..'
        # cannot occur, and '/' or '\' are never accepted.
        # A literal space is used rather than \s, so tabs, newlines and other
        # whitespace characters are rejected.
        pattern=r"^[a-zA-Z0-9_()-][a-zA-Z0-9 _()-]*(\.[a-zA-Z0-9 _-]+)*$",
        min_length=1,
        max_length=MAX_FILE_NAME_LENGTH,
    )
# Request to clear a conversation. Judging by the field names, it carries the
# session being retired and the session to use afterwards. Both IDs follow the
# same 1..MAX_ID_LENGTH, [a-zA-Z0-9_-] rule as IdentifierBase.
class ClearConversationRequest(BaseModel):
    old_session_id: str = Field(
        min_length=1,
        max_length=MAX_ID_LENGTH,
        pattern=r"^[a-zA-Z0-9_-]+$",
    )
    new_session_id: str = Field(
        min_length=1,
        max_length=MAX_ID_LENGTH,
        pattern=r"^[a-zA-Z0-9_-]+$",
    )
# One turn of a chat transcript.
class ChatMessage(BaseModel):
    # Who produced the turn; limited to these three values.
    role: Literal["user", "assistant", "system"]
    # Turn text; this model applies no length limit or sanitisation.
    content: str