initial gitlab files
Browse files- .dockerignore +17 -0
- .gitignore +4 -0
- Dockerfile +11 -0
- README.md +53 -10
- conftest.py +0 -0
- docker-compose.yml +23 -0
- dynamodb_helper.py +139 -0
- main.py +318 -0
- rag_data/netg_baaibge_chunks_v1.pkl +3 -0
- requirements.txt +25 -0
- static/app.js +162 -0
- static/style.css +227 -0
- templates/index.html +94 -0
- tests/test_dynamodb.py +48 -0
- tests/test_main.py +37 -0
.dockerignore
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Ignore the tests folder
|
| 2 |
+
tests/
|
| 3 |
+
|
| 4 |
+
# Ignore development and version control files
|
| 5 |
+
.git
|
| 6 |
+
.gitignore
|
| 7 |
+
.dockerignore
|
| 8 |
+
.vscode
|
| 9 |
+
.idea
|
| 10 |
+
|
| 11 |
+
# Ignore Python specific files and directories
|
| 12 |
+
__pycache__
|
| 13 |
+
*.pyc
|
| 14 |
+
.venv
|
| 15 |
+
venv
|
| 16 |
+
*.log
|
| 17 |
+
.DS_Store
|
.gitignore
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.DS_Store
|
| 2 |
+
__pycache__/
|
| 3 |
+
.venv/
|
| 4 |
+
venv/
|
Dockerfile
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.11-slim
|
| 2 |
+
|
| 3 |
+
WORKDIR /app
|
| 4 |
+
|
| 5 |
+
COPY requirements.txt .
|
| 6 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 7 |
+
|
| 8 |
+
COPY . .
|
| 9 |
+
|
| 10 |
+
ENV PORT=8000
|
| 11 |
+
CMD uvicorn main:app --host 0.0.0.0 --port $PORT
|
README.md
CHANGED
|
@@ -1,10 +1,53 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
-
|
| 9 |
-
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# MARVIN WebUI Demo
|
| 2 |
+
|
| 3 |
+
A lightweight chat interface powered by the MARVIN model, designed for easy deployment and testing. This project can run locally with Docker Compose or be deployed to platforms like Hugging Face Spaces.
|
| 4 |
+
|
| 5 |
+
## Features
|
| 6 |
+
|
| 7 |
+
- Simple and fast chat interface
|
| 8 |
+
- Python backend with async endpoints
|
| 9 |
+
- Frontend served through a minimal web app
|
| 10 |
+
- Containerized so can start everything with a single command
|
| 11 |
+
|
| 12 |
+
## Requirements
|
| 13 |
+
|
| 14 |
+
- Docker
|
| 15 |
+
- Docker Compose
|
| 16 |
+
|
| 17 |
+
## Local Development
|
| 18 |
+
|
| 19 |
+
### Start the project
|
| 20 |
+
|
| 21 |
+
From the project root:
|
| 22 |
+
|
| 23 |
+
```
|
| 24 |
+
docker compose up --build
|
| 25 |
+
```
|
| 26 |
+
|
| 27 |
+
This starts:
|
| 28 |
+
|
| 29 |
+
1. Backend service
|
| 30 |
+
2. Frontend service
|
| 31 |
+
3. Database service
|
| 32 |
+
|
| 33 |
+
Once everything is ready, open:
|
| 34 |
+
|
| 35 |
+
```
|
| 36 |
+
http://localhost:8000
|
| 37 |
+
```
|
| 38 |
+
|
| 39 |
+
### Stopping the project
|
| 40 |
+
|
| 41 |
+
Use:
|
| 42 |
+
|
| 43 |
+
```
|
| 44 |
+
docker compose down
|
| 45 |
+
```
|
| 46 |
+
|
| 47 |
+
### Rebuilding after code changes
|
| 48 |
+
|
| 49 |
+
Use:
|
| 50 |
+
|
| 51 |
+
```
|
| 52 |
+
docker compose up --build
|
| 53 |
+
```
|
conftest.py
ADDED
|
File without changes
|
docker-compose.yml
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
services:
|
| 2 |
+
api:
|
| 3 |
+
build: .
|
| 4 |
+
container_name: chat_api
|
| 5 |
+
ports:
|
| 6 |
+
- '8000:8000'
|
| 7 |
+
env_file:
|
| 8 |
+
- .env
|
| 9 |
+
# depends_on:
|
| 10 |
+
# - dynamodb-local
|
| 11 |
+
volumes:
|
| 12 |
+
- ./:/app
|
| 13 |
+
command: uvicorn main:app --host 0.0.0.0 --port 8000 --reload
|
| 14 |
+
# dynamodb-local:
|
| 15 |
+
# image: amazon/dynamodb-local
|
| 16 |
+
# ports:
|
| 17 |
+
# - '9000:8000'
|
| 18 |
+
# command: '-jar DynamoDBLocal.jar -inMemory -sharedDb'
|
| 19 |
+
# volumes:
|
| 20 |
+
# - dynamodb_data:/data
|
| 21 |
+
|
| 22 |
+
# volumes:
|
| 23 |
+
# dynamodb_data:
|
dynamodb_helper.py
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import time
|
| 3 |
+
import boto3
|
| 4 |
+
from boto3.dynamodb.types import TypeDeserializer, TypeSerializer
|
| 5 |
+
from botocore.exceptions import ClientError
|
| 6 |
+
from datetime import datetime, timezone
|
| 7 |
+
from uuid import uuid4
|
| 8 |
+
from decimal import Decimal
|
| 9 |
+
|
| 10 |
+
AWS_REGION = os.getenv("AWS_REGION", "us-east-1")
|
| 11 |
+
AWS_ACCESS_KEY = os.getenv("AWS_ACCESS_KEY", None)
|
| 12 |
+
AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY", None)
|
| 13 |
+
DYNAMODB_ENDPOINT = os.getenv("DYNAMODB_ENDPOINT", None)
|
| 14 |
+
DDB_TABLE = os.getenv("DDB_TABLE", "chatbot-conversations")
|
| 15 |
+
USE_LOCAL_DDB = os.getenv("USE_LOCAL_DDB", "false").lower() == "true"
|
| 16 |
+
|
| 17 |
+
def get_dynamodb_client():
|
| 18 |
+
if USE_LOCAL_DDB: # only for local testing with DynamoDB Local
|
| 19 |
+
return boto3.resource(
|
| 20 |
+
"dynamodb",
|
| 21 |
+
endpoint_url=DYNAMODB_ENDPOINT,
|
| 22 |
+
region_name=AWS_REGION,
|
| 23 |
+
aws_access_key_id="fake",
|
| 24 |
+
aws_secret_access_key="fake"
|
| 25 |
+
)
|
| 26 |
+
else: # production AWS DynamoDB
|
| 27 |
+
return boto3.resource("dynamodb",
|
| 28 |
+
region_name=AWS_REGION,
|
| 29 |
+
aws_access_key_id=AWS_ACCESS_KEY,
|
| 30 |
+
aws_secret_access_key=AWS_SECRET_ACCESS_KEY
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
dynamodb = get_dynamodb_client()
|
| 34 |
+
table = None
|
| 35 |
+
|
| 36 |
+
def create_table_if_not_exists(dynamodb):
|
| 37 |
+
global table
|
| 38 |
+
client = dynamodb.meta.client
|
| 39 |
+
|
| 40 |
+
try:
|
| 41 |
+
existing_tables = client.list_tables()["TableNames"]
|
| 42 |
+
except Exception as e:
|
| 43 |
+
print("Cannot list tables:", e)
|
| 44 |
+
return None
|
| 45 |
+
|
| 46 |
+
if DDB_TABLE in existing_tables:
|
| 47 |
+
print(f"Table {DDB_TABLE} already exists.")
|
| 48 |
+
table = dynamodb.Table(DDB_TABLE)
|
| 49 |
+
return table
|
| 50 |
+
|
| 51 |
+
print(f"Creating DynamoDB table {DDB_TABLE}...")
|
| 52 |
+
|
| 53 |
+
try:
|
| 54 |
+
table = dynamodb.create_table(
|
| 55 |
+
TableName=DDB_TABLE,
|
| 56 |
+
KeySchema=[
|
| 57 |
+
{"AttributeName": "PK", "KeyType": "HASH"},
|
| 58 |
+
{"AttributeName": "SK", "KeyType": "RANGE"}
|
| 59 |
+
],
|
| 60 |
+
AttributeDefinitions=[
|
| 61 |
+
{"AttributeName": "PK", "AttributeType": "S"},
|
| 62 |
+
{"AttributeName": "SK", "AttributeType": "S"},
|
| 63 |
+
{"AttributeName": "GSI1_PK", "AttributeType": "S"},
|
| 64 |
+
{"AttributeName": "GSI1_SK", "AttributeType": "S"}
|
| 65 |
+
],
|
| 66 |
+
GlobalSecondaryIndexes=[
|
| 67 |
+
{
|
| 68 |
+
"IndexName": "GSI1",
|
| 69 |
+
"KeySchema": [
|
| 70 |
+
{"AttributeName": "GSI1_PK", "KeyType": "HASH"},
|
| 71 |
+
{"AttributeName": "GSI1_SK", "KeyType": "RANGE"}
|
| 72 |
+
],
|
| 73 |
+
"Projection": {"ProjectionType": "ALL"},
|
| 74 |
+
"ProvisionedThroughput": {
|
| 75 |
+
"ReadCapacityUnits": 5,
|
| 76 |
+
"WriteCapacityUnits": 5
|
| 77 |
+
},
|
| 78 |
+
}
|
| 79 |
+
],
|
| 80 |
+
BillingMode='PAY_PER_REQUEST'
|
| 81 |
+
# ProvisionedThroughput={
|
| 82 |
+
# "ReadCapacityUnits": 5,
|
| 83 |
+
# "WriteCapacityUnits": 5
|
| 84 |
+
# }
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
table.wait_until_exists()
|
| 88 |
+
print(f"Table {DDB_TABLE} created.")
|
| 89 |
+
return table
|
| 90 |
+
|
| 91 |
+
except ClientError as e:
|
| 92 |
+
print("Error creating table:", e.response["Error"]["Message"])
|
| 93 |
+
return None
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def iso_ts():
|
| 97 |
+
# Return the current timestamp in ISO 8601 format
|
| 98 |
+
return datetime.now(timezone.utc).isoformat()
|
| 99 |
+
|
| 100 |
+
table = create_table_if_not_exists(dynamodb)
|
| 101 |
+
|
| 102 |
+
def convert_floats(obj):
|
| 103 |
+
if isinstance(obj, float):
|
| 104 |
+
return Decimal(str(obj))
|
| 105 |
+
elif isinstance(obj, dict):
|
| 106 |
+
return {k: convert_floats(v) for k, v in obj.items()}
|
| 107 |
+
elif isinstance(obj, list):
|
| 108 |
+
return [convert_floats(v) for v in obj]
|
| 109 |
+
else:
|
| 110 |
+
return obj
|
| 111 |
+
|
| 112 |
+
def log_event(user_id, session_id, data):
|
| 113 |
+
"""
|
| 114 |
+
Log conversation data to DynamoDB table.
|
| 115 |
+
:param user_id: ID of the user
|
| 116 |
+
:param session_id: ID of the session
|
| 117 |
+
:param data: Dictionary containing conversation data
|
| 118 |
+
"""
|
| 119 |
+
global table
|
| 120 |
+
if table is None:
|
| 121 |
+
print("Table not initialized. Skipping log.")
|
| 122 |
+
return
|
| 123 |
+
|
| 124 |
+
ts = iso_ts()
|
| 125 |
+
item = {
|
| 126 |
+
"PK": f"SESSION#{session_id}",
|
| 127 |
+
"SK": f"TS#{ts}#{uuid4().hex}",
|
| 128 |
+
'user_id': user_id,
|
| 129 |
+
"GSI1_PK": f"USER#{user_id}",
|
| 130 |
+
"GSI1_SK": f"TS#{ts}",
|
| 131 |
+
'session_id': session_id,
|
| 132 |
+
'timestamp': ts,
|
| 133 |
+
'data': convert_floats(data)
|
| 134 |
+
}
|
| 135 |
+
print(f"Logging conversation: {item}")
|
| 136 |
+
try:
|
| 137 |
+
table.put_item(Item=item)
|
| 138 |
+
except ClientError as e:
|
| 139 |
+
print(f"Error logging conversation: {e.response['Error']['Message']}")
|
main.py
ADDED
|
@@ -0,0 +1,318 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import pickle
|
| 3 |
+
import faiss
|
| 4 |
+
import asyncio
|
| 5 |
+
from contextlib import asynccontextmanager
|
| 6 |
+
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
from typing import List, Literal, Optional
|
| 10 |
+
from datetime import datetime, timezone
|
| 11 |
+
|
| 12 |
+
from dotenv import load_dotenv
|
| 13 |
+
load_dotenv()
|
| 14 |
+
from fastapi import FastAPI, Request
|
| 15 |
+
from fastapi.responses import HTMLResponse, JSONResponse
|
| 16 |
+
from fastapi.staticfiles import StaticFiles
|
| 17 |
+
from fastapi.templating import Jinja2Templates
|
| 18 |
+
|
| 19 |
+
from pydantic import BaseModel
|
| 20 |
+
from dynamodb_helper import log_event
|
| 21 |
+
from fastapi import BackgroundTasks
|
| 22 |
+
|
| 23 |
+
from huggingface_hub import InferenceClient
|
| 24 |
+
|
| 25 |
+
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
|
| 26 |
+
from langchain.agents import create_agent
|
| 27 |
+
from langchain.agents.middleware import dynamic_prompt, ModelRequest
|
| 28 |
+
from langchain_community.docstore.in_memory import InMemoryDocstore
|
| 29 |
+
from langchain_community.vectorstores import FAISS
|
| 30 |
+
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
|
| 31 |
+
|
| 32 |
+
# -------------------- Config --------------------
|
| 33 |
+
BASE_DIR = Path(__file__).resolve().parent
|
| 34 |
+
|
| 35 |
+
MODEL_MAP = {
|
| 36 |
+
"champ": "champ-model/placeholder",
|
| 37 |
+
"openai": "openai/gpt-oss-20b",
|
| 38 |
+
"google": "google/gemma-2-2b-it"
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HF_API_TOKEN")
|
| 42 |
+
if HF_TOKEN is None:
|
| 43 |
+
raise RuntimeError(
|
| 44 |
+
"HF_TOKEN or HF_API_TOKEN is not set. "
|
| 45 |
+
"Go to Space → Settings → Variables & secrets and add one."
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
hf_client = InferenceClient(token=HF_TOKEN)
|
| 49 |
+
|
| 50 |
+
# Max history messages to keep for context
|
| 51 |
+
MAX_HISTORY = 20
|
| 52 |
+
|
| 53 |
+
# -------------------- Prompts --------------------
|
| 54 |
+
DEFAULT_SYSTEM_PROMPT = (
|
| 55 |
+
"Answer clearly and concisely. You are a helpful assistant. If you do not know the answer, just say you don't know. "
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
CHAMP_SYSTEM_PROMPT = (
|
| 59 |
+
"""
|
| 60 |
+
# CONTEXT #
|
| 61 |
+
You are *CHAMP*, a knowledgeable and compassionate pediatrician chatting online with adolescent patients, their families, or their caregivers. Children and adolescents commonly experience infectious illnesses (for example: fever, cough, vomiting, diarrhea). Timely access to credible information can support safe self-management at home and may reduce unnecessary non-emergency ED visits, helping to lower overcrowding and improve the care experience at home.
|
| 62 |
+
|
| 63 |
+
#########
|
| 64 |
+
|
| 65 |
+
# OBJECTIVE #
|
| 66 |
+
Your task is to answer questions about common pediatric infectious diseases asked by the adolescent patient, their family, or their caregiver. Base your answers only on the background material provided. If the relevant information is not clearly present in that material, reply with: "I don't know." Do not invent or guess information.
|
| 67 |
+
|
| 68 |
+
#########
|
| 69 |
+
|
| 70 |
+
# STYLE #
|
| 71 |
+
Provide concise, accurate, and actionable information to help them manage these conditions at home when it is safe to do so. Focus on clear next steps and practical advice that help them make informed decisions. Do not exceed four sentences per response.
|
| 72 |
+
|
| 73 |
+
#########
|
| 74 |
+
|
| 75 |
+
# TONE #
|
| 76 |
+
Maintain a positive, empathetic, and supportive tone throughout, to reduce the questioners worry and help them feel heard. Your responses should feel warm and reassuring, while still reflecting professionalism and seriousness.
|
| 77 |
+
|
| 78 |
+
# AUDIENCE #
|
| 79 |
+
Your audience is adolescent patients, their families, or their caregivers. They are seeking practical advice and concrete actions they can take for disease self-management. Write at approximately a sixth-grade reading level, avoiding medical jargon or explaining it briefly when needed.
|
| 80 |
+
|
| 81 |
+
#########
|
| 82 |
+
|
| 83 |
+
# RESPONSE FORMAT #
|
| 84 |
+
Respond in three to four sentences, as if chatting in a Facebook Messenger conversation. Do not include references, citations, or mention specific document locations in your answer.
|
| 85 |
+
|
| 86 |
+
#############
|
| 87 |
+
|
| 88 |
+
# START ANALYSIS #
|
| 89 |
+
|
| 90 |
+
Here is the user question: {last_query}
|
| 91 |
+
|
| 92 |
+
Here are the materials you must rely on for your answers: {context}
|
| 93 |
+
|
| 94 |
+
Now, step by step, you can start answering the user’s question.
|
| 95 |
+
"""
|
| 96 |
+
|
| 97 |
+
)
|
| 98 |
+
###TODO: And here is the conversation history so far : {history}
|
| 99 |
+
|
| 100 |
+
class ChatMessage(BaseModel):
|
| 101 |
+
role: Literal["user", "assistant", "system"]
|
| 102 |
+
content: str
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
class ChatRequest(BaseModel):
|
| 106 |
+
user_id: str
|
| 107 |
+
session_id: str
|
| 108 |
+
messages: List[ChatMessage]
|
| 109 |
+
temperature: float = 0.7
|
| 110 |
+
model_type: str
|
| 111 |
+
# max_new_tokens: int = 256
|
| 112 |
+
consent: bool = False
|
| 113 |
+
system_prompt: Optional[str] = None
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
# -------------------- Helpers --------------------
|
| 117 |
+
def convert_messages(messages: List[ChatMessage]):
|
| 118 |
+
"""
|
| 119 |
+
Convert our internal message format into OpenAI-style messages.
|
| 120 |
+
"""
|
| 121 |
+
sys = DEFAULT_SYSTEM_PROMPT
|
| 122 |
+
out = [{"role": "system", "content": sys}]
|
| 123 |
+
|
| 124 |
+
for m in messages:
|
| 125 |
+
out.append({"role": m.role, "content": m.content})
|
| 126 |
+
return out
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def convert_messages_langchain(messages: List[ChatMessage]):
|
| 130 |
+
"""
|
| 131 |
+
Convert our internal message format into Langchain-style messages.
|
| 132 |
+
"""
|
| 133 |
+
sys = CHAMP_SYSTEM_PROMPT
|
| 134 |
+
list_chatmessages = [SystemMessage(content = sys)]
|
| 135 |
+
|
| 136 |
+
for m in messages[-MAX_HISTORY:]:
|
| 137 |
+
if m.role == "user":
|
| 138 |
+
list_chatmessages.append(HumanMessage(content=m.content))
|
| 139 |
+
elif m.role == "assistant":
|
| 140 |
+
list_chatmessages.append(AIMessage(content=m.content))
|
| 141 |
+
elif m.role == "system":
|
| 142 |
+
list_chatmessages.append(SystemMessage(content=m.content))
|
| 143 |
+
return list_chatmessages
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def call_llm(req: ChatRequest) -> str:
|
| 147 |
+
if req.model_type == "champ":
|
| 148 |
+
return call_champ(req)
|
| 149 |
+
|
| 150 |
+
MODEL_ID = MODEL_MAP.get(req.model_type, MODEL_MAP["champ"])
|
| 151 |
+
msgs = convert_messages(req.messages)
|
| 152 |
+
|
| 153 |
+
try:
|
| 154 |
+
# Call HuggingFace inference API
|
| 155 |
+
resp = hf_client.chat.completions.create(
|
| 156 |
+
model=MODEL_ID,
|
| 157 |
+
messages=msgs,
|
| 158 |
+
# max_tokens=256,
|
| 159 |
+
temperature=req.temperature,
|
| 160 |
+
)
|
| 161 |
+
# Extract chat reply
|
| 162 |
+
return resp.choices[0].message["content"].strip()
|
| 163 |
+
except Exception as e:
|
| 164 |
+
raise RuntimeError(f"Inference API error: {e}")
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def call_champ(req: ChatRequest) -> str:
|
| 168 |
+
msgs = convert_messages_langchain(req.messages)
|
| 169 |
+
# config = {"configurable": {"thread_id": req.user_id}}
|
| 170 |
+
try:
|
| 171 |
+
result = agent_retrievalbased.invoke(
|
| 172 |
+
{"messages": msgs},
|
| 173 |
+
# config=config,
|
| 174 |
+
)
|
| 175 |
+
return result["messages"][-1].text.strip()
|
| 176 |
+
except Exception as e:
|
| 177 |
+
raise RuntimeError(f"CHAMP model error: {e}")
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
# def log_event(user_id: str, session_id: str, data: dict):
|
| 181 |
+
# record = {
|
| 182 |
+
# "user_id": user_id,
|
| 183 |
+
# "session_id": session_id,
|
| 184 |
+
# "data": data,
|
| 185 |
+
# "timestamp": datetime.now(timezone.utc)
|
| 186 |
+
# }
|
| 187 |
+
# conversations_collection.insert_one(record)
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
# -------------------- CHAMP setup --------------------
|
| 191 |
+
# RAG setup
|
| 192 |
+
|
| 193 |
+
def build_vector_store():
|
| 194 |
+
rag_path = BASE_DIR / "rag_data" / "netg_baaibge_chunks_v1.pkl"
|
| 195 |
+
with open(rag_path, 'rb') as f:
|
| 196 |
+
loaded_documents = pickle.load(f)
|
| 197 |
+
print("Chunks loaded successfully.")
|
| 198 |
+
|
| 199 |
+
device = "cpu" # to be update if need GPU
|
| 200 |
+
|
| 201 |
+
model_embedding_name = "BAAI/bge-large-en-v1.5"
|
| 202 |
+
model_embedding_kwargs = {'device': device, "use_auth_token": HF_TOKEN}
|
| 203 |
+
encode_kwargs = {'normalize_embeddings': True}
|
| 204 |
+
|
| 205 |
+
embeddings = HuggingFaceEmbeddings(
|
| 206 |
+
model_name=model_embedding_name,
|
| 207 |
+
model_kwargs=model_embedding_kwargs,
|
| 208 |
+
encode_kwargs=encode_kwargs,
|
| 209 |
+
)
|
| 210 |
+
|
| 211 |
+
embedding_dim = len(embeddings.embed_query("hello world"))
|
| 212 |
+
index = faiss.IndexFlatL2(embedding_dim)
|
| 213 |
+
|
| 214 |
+
vector_store = FAISS(
|
| 215 |
+
embedding_function=embeddings,
|
| 216 |
+
index=index,
|
| 217 |
+
docstore=InMemoryDocstore(),
|
| 218 |
+
index_to_docstore_id={},
|
| 219 |
+
)
|
| 220 |
+
vector_store.add_documents(documents=loaded_documents)
|
| 221 |
+
return vector_store
|
| 222 |
+
|
| 223 |
+
def make_prompt_with_context(vector_store: FAISS):
|
| 224 |
+
@dynamic_prompt
|
| 225 |
+
def prompt_with_context(request: ModelRequest) -> str:
|
| 226 |
+
last_query = request.state["messages"][-1].text
|
| 227 |
+
retrieved_docs = vector_store.similarity_search(last_query, k = 3)
|
| 228 |
+
|
| 229 |
+
docs_content = "\n\n".join(doc.page_content for doc in retrieved_docs) if retrieved_docs else ""
|
| 230 |
+
|
| 231 |
+
system_message = CHAMP_SYSTEM_PROMPT.format(
|
| 232 |
+
last_query = last_query,
|
| 233 |
+
context = docs_content
|
| 234 |
+
)
|
| 235 |
+
|
| 236 |
+
return system_message
|
| 237 |
+
|
| 238 |
+
return prompt_with_context
|
| 239 |
+
|
| 240 |
+
def build_champ_agent(vector_store: FAISS):
|
| 241 |
+
hf_llm_champ = HuggingFaceEndpoint(
|
| 242 |
+
repo_id = "openai/gpt-oss-20b",
|
| 243 |
+
task = "text-generation",
|
| 244 |
+
max_new_tokens = 1024,
|
| 245 |
+
# temperature = 0.7,
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
model_chat = ChatHuggingFace(llm=hf_llm_champ)
|
| 249 |
+
prompt_middleware = make_prompt_with_context(vector_store)
|
| 250 |
+
agent = create_agent(model_chat, tools=[], middleware=[prompt_middleware]) #checkpointer = InMemorySaver()
|
| 251 |
+
|
| 252 |
+
return agent
|
| 253 |
+
# -------------------- FastAPI setup --------------------
|
| 254 |
+
vector_store: Optional[FAISS] = None
|
| 255 |
+
agent_retrievalbased = None # 给 call_champ 用
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
@asynccontextmanager
|
| 259 |
+
async def lifespan(app: FastAPI):
|
| 260 |
+
global vector_store, agent_retrievalbased
|
| 261 |
+
|
| 262 |
+
loop = asyncio.get_event_loop()
|
| 263 |
+
# 在后台线程执行同步的 build_vector_store
|
| 264 |
+
vector_store = await loop.run_in_executor(
|
| 265 |
+
None, build_vector_store
|
| 266 |
+
)
|
| 267 |
+
agent_retrievalbased = build_champ_agent(vector_store)
|
| 268 |
+
|
| 269 |
+
print("CHAMP RAG + agent initialized.")
|
| 270 |
+
yield
|
| 271 |
+
|
| 272 |
+
app = FastAPI(lifespan=lifespan)
|
| 273 |
+
|
| 274 |
+
app.mount("/static", StaticFiles(directory="static"), name="static")
|
| 275 |
+
templates = Jinja2Templates(directory="templates")
|
| 276 |
+
|
| 277 |
+
# -------------------- Routes --------------------
|
| 278 |
+
|
| 279 |
+
@app.get("/", response_class=HTMLResponse)
|
| 280 |
+
async def home(request: Request):
|
| 281 |
+
return templates.TemplateResponse("index.html", {"request": request})
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
@app.post("/chat")
|
| 285 |
+
async def chat_endpoint(payload: ChatRequest, background_tasks: BackgroundTasks):
|
| 286 |
+
print(f"Received chat request: {payload}")
|
| 287 |
+
if not payload.messages:
|
| 288 |
+
return JSONResponse({"error": "No messages provided"}, status_code=400)
|
| 289 |
+
|
| 290 |
+
try:
|
| 291 |
+
reply = call_llm(payload)
|
| 292 |
+
except Exception as e:
|
| 293 |
+
background_tasks.add_task(
|
| 294 |
+
log_event,
|
| 295 |
+
user_id=payload.user_id,
|
| 296 |
+
session_id=payload.session_id,
|
| 297 |
+
data={
|
| 298 |
+
"error": str(e),
|
| 299 |
+
"model_type": payload.model_type,
|
| 300 |
+
"consent": payload.consent,
|
| 301 |
+
"temperature": payload.temperature,
|
| 302 |
+
"messages": payload.messages[-1].dict() if payload.messages else {},
|
| 303 |
+
}
|
| 304 |
+
)
|
| 305 |
+
return JSONResponse({"error": str(e)}, status_code=500)
|
| 306 |
+
background_tasks.add_task(
|
| 307 |
+
log_event,
|
| 308 |
+
user_id=payload.user_id,
|
| 309 |
+
session_id=payload.session_id,
|
| 310 |
+
data={
|
| 311 |
+
"model_type": payload.model_type,
|
| 312 |
+
"consent": payload.consent,
|
| 313 |
+
"temperature": payload.temperature,
|
| 314 |
+
"messages": payload.messages[-1].dict(),
|
| 315 |
+
"reply": reply,
|
| 316 |
+
}
|
| 317 |
+
)
|
| 318 |
+
return {"reply": reply}
|
rag_data/netg_baaibge_chunks_v1.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7fffd4dbb98d49ee9dc97d1c530c3c3278a624bf19a79854fde0b03af9977c71
|
| 3 |
+
size 205801
|
requirements.txt
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi
|
| 2 |
+
uvicorn[standard]
|
| 3 |
+
|
| 4 |
+
jinja2
|
| 5 |
+
python-multipart
|
| 6 |
+
|
| 7 |
+
requests
|
| 8 |
+
|
| 9 |
+
python-dotenv
|
| 10 |
+
|
| 11 |
+
huggingface_hub
|
| 12 |
+
sentence-transformers
|
| 13 |
+
|
| 14 |
+
pydantic
|
| 15 |
+
pymongo
|
| 16 |
+
|
| 17 |
+
faiss-cpu
|
| 18 |
+
|
| 19 |
+
langchain
|
| 20 |
+
langchain-core
|
| 21 |
+
langchain-community
|
| 22 |
+
langchain-huggingface
|
| 23 |
+
|
| 24 |
+
boto3
|
| 25 |
+
botocore
|
static/app.js
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
const chatWindow = document.getElementById('chatWindow');
|
| 2 |
+
const userInput = document.getElementById('userInput');
|
| 3 |
+
const sendBtn = document.getElementById('sendBtn');
|
| 4 |
+
const statusEl = document.getElementById('status');
|
| 5 |
+
|
| 6 |
+
const systemPresetSelect = document.getElementById('systemPreset');
|
| 7 |
+
const tempSlider = document.getElementById('tempSlider');
|
| 8 |
+
const tempValue = document.getElementById('tempValue');
|
| 9 |
+
// const maxTokensSlider = document.getElementById("maxTokensSlider");
|
| 10 |
+
// const maxTokensValue = document.getElementById("maxTokensValue");
|
| 11 |
+
const clearBtn = document.getElementById('clearBtn');
|
| 12 |
+
|
| 13 |
+
const consentOverlay = document.getElementById('consentOverlay');
|
| 14 |
+
const consentCheckbox = document.getElementById('consentCheckbox');
|
| 15 |
+
const consentBtn = document.getElementById('consentBtn');
|
| 16 |
+
|
| 17 |
+
// Local in-browser chat history
|
| 18 |
+
let messages = [];
|
| 19 |
+
let consentGranted = false;
|
| 20 |
+
let sessionId = 'session-' + crypto.randomUUID(); // Unique session ID, generated once per page load
|
| 21 |
+
document.body.classList.add('no-scroll');
|
| 22 |
+
|
| 23 |
+
function renderMessages() {
|
| 24 |
+
chatWindow.innerHTML = '';
|
| 25 |
+
messages.forEach((m) => {
|
| 26 |
+
const bubble = document.createElement('div');
|
| 27 |
+
bubble.classList.add(
|
| 28 |
+
'msg-bubble',
|
| 29 |
+
m.role === 'user' ? 'user' : 'assistant'
|
| 30 |
+
);
|
| 31 |
+
// convert markdown to HTML safely
|
| 32 |
+
bubble.innerHTML = DOMPurify.sanitize(marked.parse(m.content));
|
| 33 |
+
chatWindow.appendChild(bubble);
|
| 34 |
+
});
|
| 35 |
+
chatWindow.scrollTop = chatWindow.scrollHeight;
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
function updateSlidersUI() {
|
| 39 |
+
tempValue.textContent = tempSlider.value;
|
| 40 |
+
// maxTokensValue.textContent = maxTokensSlider.value;
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
function getMachineId() {
|
| 44 |
+
let machineId = localStorage.getItem('MachineId');
|
| 45 |
+
|
| 46 |
+
if (!machineId) {
|
| 47 |
+
machineId = 'dev-' + crypto.randomUUID();
|
| 48 |
+
localStorage.setItem('MachineId', machineId);
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
return machineId;
|
| 52 |
+
}
|
| 53 |
+
|
| 54 |
+
// ----- Chat -----
|
| 55 |
+
|
| 56 |
+
async function sendMessage() {
|
| 57 |
+
const text = userInput.value.trim();
|
| 58 |
+
if (!text) return;
|
| 59 |
+
|
| 60 |
+
// Add user message locally
|
| 61 |
+
messages.push({ role: 'user', content: text });
|
| 62 |
+
renderMessages();
|
| 63 |
+
userInput.value = '';
|
| 64 |
+
|
| 65 |
+
statusEl.textContent = 'Thinking...';
|
| 66 |
+
statusEl.className = 'status status-info';
|
| 67 |
+
|
| 68 |
+
const temperature = parseFloat(tempSlider.value);
|
| 69 |
+
// const maxTokens = parseInt(maxTokensSlider.value, 10);
|
| 70 |
+
// const systemPrompt = systemPresetSelect.value;
|
| 71 |
+
const modelType = systemPresetSelect.value;
|
| 72 |
+
|
| 73 |
+
const payload = {
|
| 74 |
+
user_id: getMachineId(),
|
| 75 |
+
session_id: sessionId,
|
| 76 |
+
messages: messages.map((m) => ({ role: m.role, content: m.content })),
|
| 77 |
+
temperature,
|
| 78 |
+
// max_new_tokens: maxTokens,
|
| 79 |
+
model_type: modelType,
|
| 80 |
+
consent: consentGranted,
|
| 81 |
+
};
|
| 82 |
+
|
| 83 |
+
try {
|
| 84 |
+
const res = await fetch('/chat', {
|
| 85 |
+
method: 'POST',
|
| 86 |
+
headers: { 'Content-Type': 'application/json' },
|
| 87 |
+
body: JSON.stringify(payload),
|
| 88 |
+
});
|
| 89 |
+
|
| 90 |
+
const data = await res.json();
|
| 91 |
+
|
| 92 |
+
if (!res.ok) {
|
| 93 |
+
statusEl.textContent = data.error || 'Error from server.';
|
| 94 |
+
statusEl.className = 'status status-error';
|
| 95 |
+
return;
|
| 96 |
+
}
|
| 97 |
+
|
| 98 |
+
const reply = data.reply || '(No reply)';
|
| 99 |
+
messages.push({ role: 'assistant', content: reply });
|
| 100 |
+
renderMessages();
|
| 101 |
+
|
| 102 |
+
statusEl.textContent = 'Ready';
|
| 103 |
+
statusEl.className = 'status status-ok';
|
| 104 |
+
} catch (err) {
|
| 105 |
+
console.error(err);
|
| 106 |
+
statusEl.textContent = 'Network error.';
|
| 107 |
+
statusEl.className = 'status status-error';
|
| 108 |
+
}
|
| 109 |
+
}
|
| 110 |
+
|
| 111 |
+
function resetSession() {
|
| 112 |
+
sessionId = 'session-' + crypto.randomUUID();
|
| 113 |
+
}
|
| 114 |
+
|
| 115 |
+
function clearConversation() {
|
| 116 |
+
resetSession();
|
| 117 |
+
messages = [];
|
| 118 |
+
renderMessages();
|
| 119 |
+
statusEl.textContent = 'Conversation cleared. Start a new chat!';
|
| 120 |
+
statusEl.className = 'status status-ok';
|
| 121 |
+
}
|
| 122 |
+
|
| 123 |
+
// ----- Event wiring -----
|
| 124 |
+
|
| 125 |
+
// Consent logic
|
| 126 |
+
|
| 127 |
+
// When the checkbox is toggled, enable or disable the button
|
| 128 |
+
consentCheckbox.addEventListener('change', () => {
|
| 129 |
+
consentBtn.disabled = !consentCheckbox.checked;
|
| 130 |
+
});
|
| 131 |
+
|
| 132 |
+
// Handle the consent acceptance
|
| 133 |
+
consentBtn.addEventListener('click', () => {
|
| 134 |
+
consentOverlay.style.display = 'none'; // Hide overlay
|
| 135 |
+
document.body.classList.remove('no-scroll'); // NEW: re-enable scrolling
|
| 136 |
+
consentGranted = true; // Mark consent as granted
|
| 137 |
+
});
|
| 138 |
+
|
| 139 |
+
sendBtn.addEventListener('click', sendMessage);
|
| 140 |
+
|
| 141 |
+
// Enter to send, Shift+Enter = newline
|
| 142 |
+
userInput.addEventListener('keydown', (e) => {
|
| 143 |
+
if (e.key === 'Enter' && !e.shiftKey) {
|
| 144 |
+
e.preventDefault();
|
| 145 |
+
sendMessage();
|
| 146 |
+
}
|
| 147 |
+
});
|
| 148 |
+
|
| 149 |
+
tempSlider.addEventListener('input', updateSlidersUI);
|
| 150 |
+
// maxTokensSlider.addEventListener("input", updateSlidersUI);
|
| 151 |
+
clearBtn.addEventListener('click', clearConversation);
|
| 152 |
+
|
| 153 |
+
systemPresetSelect.addEventListener('change', () => {
|
| 154 |
+
clearConversation();
|
| 155 |
+
statusEl.textContent = 'Model changed. History cleared.';
|
| 156 |
+
statusEl.className = 'status status-ok';
|
| 157 |
+
});
|
| 158 |
+
|
| 159 |
+
// initial UI state
|
| 160 |
+
updateSlidersUI();
|
| 161 |
+
statusEl.textContent = 'Ready';
|
| 162 |
+
statusEl.className = 'status status-ok';
|
static/style.css
ADDED
|
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/* Dark theme page background */
|
| 2 |
+
body {
|
| 3 |
+
margin: 0;
|
| 4 |
+
font-family: system-ui, -apple-system, BlinkMacSystemFont, 'Segoe UI',
|
| 5 |
+
sans-serif;
|
| 6 |
+
background: #0b1020;
|
| 7 |
+
color: #f5f5f5;
|
| 8 |
+
}
|
| 9 |
+
|
| 10 |
+
/* NEW: prevent scrolling while consent overlay is active */
|
| 11 |
+
body.no-scroll {
|
| 12 |
+
overflow: hidden;
|
| 13 |
+
}
|
| 14 |
+
|
| 15 |
+
.chat-container {
|
| 16 |
+
max-width: 900px;
|
| 17 |
+
margin: 40px auto;
|
| 18 |
+
background: #141b2f;
|
| 19 |
+
border-radius: 16px;
|
| 20 |
+
box-shadow: 0 10px 30px rgba(0, 0, 0, 0.45);
|
| 21 |
+
display: flex;
|
| 22 |
+
flex-direction: column;
|
| 23 |
+
height: 80vh;
|
| 24 |
+
padding: 16px;
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
.chat-header {
|
| 28 |
+
padding: 8px 4px 12px;
|
| 29 |
+
border-bottom: 1px solid #2c3554;
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
.chat-header h1 {
|
| 33 |
+
margin: 0;
|
| 34 |
+
font-size: 1.8rem;
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
.chat-header .subtitle {
|
| 38 |
+
margin: 4px 0 0;
|
| 39 |
+
color: #c0c6e0;
|
| 40 |
+
font-size: 0.95rem;
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
/* Controls bar */
|
| 44 |
+
.controls-bar {
|
| 45 |
+
display: flex;
|
| 46 |
+
flex-wrap: wrap;
|
| 47 |
+
gap: 12px;
|
| 48 |
+
margin-top: 10px;
|
| 49 |
+
padding: 8px 4px;
|
| 50 |
+
border-bottom: 1px solid #2c3554;
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
.control-group {
|
| 54 |
+
display: flex;
|
| 55 |
+
flex-direction: column;
|
| 56 |
+
gap: 4px;
|
| 57 |
+
font-size: 0.85rem;
|
| 58 |
+
color: #d3dbff;
|
| 59 |
+
}
|
| 60 |
+
|
| 61 |
+
.control-group select,
|
| 62 |
+
.control-group input[type='range'] {
|
| 63 |
+
background: #0d1324;
|
| 64 |
+
border-radius: 8px;
|
| 65 |
+
border: 1px solid #2c3554;
|
| 66 |
+
color: #f5f5f5;
|
| 67 |
+
padding: 4px 8px;
|
| 68 |
+
font-size: 0.85rem;
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
.secondary-button {
|
| 72 |
+
align-self: flex-end;
|
| 73 |
+
padding: 6px 12px;
|
| 74 |
+
border-radius: 8px;
|
| 75 |
+
border: 1px solid #2c3554;
|
| 76 |
+
background: #1f2840;
|
| 77 |
+
color: #f5f5f5;
|
| 78 |
+
font-size: 0.85rem;
|
| 79 |
+
cursor: pointer;
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
.secondary-button:hover {
|
| 83 |
+
background: #273256;
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
/* Chat window */
|
| 87 |
+
.chat-window {
|
| 88 |
+
flex: 1;
|
| 89 |
+
margin-top: 10px;
|
| 90 |
+
padding: 10px;
|
| 91 |
+
overflow-y: auto;
|
| 92 |
+
background: #0d1324;
|
| 93 |
+
border-radius: 12px;
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
/* Message bubbles */
|
| 97 |
+
.msg-bubble {
|
| 98 |
+
max-width: 75%;
|
| 99 |
+
padding: 8px 12px;
|
| 100 |
+
margin-bottom: 8px;
|
| 101 |
+
border-radius: 12px;
|
| 102 |
+
font-size: 0.95rem;
|
| 103 |
+
line-height: 1.4;
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
.msg-bubble.user {
|
| 107 |
+
margin-left: auto;
|
| 108 |
+
background: #4c6fff;
|
| 109 |
+
color: #ffffff;
|
| 110 |
+
border-bottom-right-radius: 4px;
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
+
.msg-bubble.assistant {
|
| 114 |
+
margin-right: auto;
|
| 115 |
+
background: #1f2840;
|
| 116 |
+
color: #f5f5f5;
|
| 117 |
+
border-bottom-left-radius: 4px;
|
| 118 |
+
}
|
| 119 |
+
|
| 120 |
+
/* Input area */
|
| 121 |
+
.chat-input-area {
|
| 122 |
+
display: flex;
|
| 123 |
+
gap: 8px;
|
| 124 |
+
margin-top: 12px;
|
| 125 |
+
border-top: 1px solid #2c3554;
|
| 126 |
+
padding-top: 8px;
|
| 127 |
+
}
|
| 128 |
+
|
| 129 |
+
.chat-input-area textarea {
|
| 130 |
+
flex: 1;
|
| 131 |
+
border-radius: 10px;
|
| 132 |
+
border: 1px solid #2c3554;
|
| 133 |
+
background: #0d1324;
|
| 134 |
+
color: #f5f5f5;
|
| 135 |
+
padding: 8px;
|
| 136 |
+
resize: none;
|
| 137 |
+
font-size: 0.95rem;
|
| 138 |
+
}
|
| 139 |
+
|
| 140 |
+
.chat-input-area button {
|
| 141 |
+
padding: 8px 18px;
|
| 142 |
+
border-radius: 10px;
|
| 143 |
+
border: none;
|
| 144 |
+
background: #4c6fff;
|
| 145 |
+
color: white;
|
| 146 |
+
font-weight: 600;
|
| 147 |
+
cursor: pointer;
|
| 148 |
+
}
|
| 149 |
+
|
| 150 |
+
.chat-input-area button:hover {
|
| 151 |
+
background: #3453e6;
|
| 152 |
+
}
|
| 153 |
+
|
| 154 |
+
/* Status text */
|
| 155 |
+
.status {
|
| 156 |
+
margin-top: 6px;
|
| 157 |
+
font-size: 0.85rem;
|
| 158 |
+
}
|
| 159 |
+
|
| 160 |
+
.status-info {
|
| 161 |
+
color: #ffce56;
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
.status-ok {
|
| 165 |
+
color: #8be48b;
|
| 166 |
+
}
|
| 167 |
+
|
| 168 |
+
.status-error {
|
| 169 |
+
color: #ff8080;
|
| 170 |
+
}
|
| 171 |
+
|
| 172 |
+
/* RESPONSIVE DESIGN */
|
| 173 |
+
@media (max-width: 768px) {
|
| 174 |
+
.chat-container {
|
| 175 |
+
margin: 0;
|
| 176 |
+
border-radius: 0;
|
| 177 |
+
height: 100vh;
|
| 178 |
+
}
|
| 179 |
+
|
| 180 |
+
.msg-bubble {
|
| 181 |
+
max-width: 90%;
|
| 182 |
+
}
|
| 183 |
+
|
| 184 |
+
.controls-bar {
|
| 185 |
+
flex-direction: column;
|
| 186 |
+
align-items: flex-start;
|
| 187 |
+
}
|
| 188 |
+
|
| 189 |
+
.secondary-button {
|
| 190 |
+
align-self: flex-start;
|
| 191 |
+
}
|
| 192 |
+
}
|
| 193 |
+
|
| 194 |
+
/* CONSENT OVERLAY FIXED VERSION */
|
| 195 |
+
.consent-overlay {
|
| 196 |
+
position: fixed;
|
| 197 |
+
top: 0;
|
| 198 |
+
left: 0;
|
| 199 |
+
width: 100%;
|
| 200 |
+
height: 100%;
|
| 201 |
+
|
| 202 |
+
background: rgba(0, 0, 0, 0.55); /* CHANGED: darker for visibility */
|
| 203 |
+
backdrop-filter: blur(4px);
|
| 204 |
+
|
| 205 |
+
display: flex;
|
| 206 |
+
align-items: center;
|
| 207 |
+
justify-content: center;
|
| 208 |
+
|
| 209 |
+
z-index: 9999; /* NEW: ensure nothing covers this */
|
| 210 |
+
}
|
| 211 |
+
|
| 212 |
+
/* Dark theme overlay box */
|
| 213 |
+
.consent-box {
|
| 214 |
+
background: #141b2f; /* CHANGED: match theme */
|
| 215 |
+
color: #f5f5f5; /* NEW: readable on dark bg */
|
| 216 |
+
padding: 24px;
|
| 217 |
+
width: 420px;
|
| 218 |
+
border-radius: 12px;
|
| 219 |
+
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.4);
|
| 220 |
+
}
|
| 221 |
+
|
| 222 |
+
.consent-check {
|
| 223 |
+
display: flex;
|
| 224 |
+
align-items: center;
|
| 225 |
+
margin: 16px 0;
|
| 226 |
+
gap: 10px;
|
| 227 |
+
}
|
templates/index.html
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!DOCTYPE html>
|
| 2 |
+
<html lang="en">
|
| 3 |
+
<head>
|
| 4 |
+
<meta charset="UTF-8" />
|
| 5 |
+
<title>CHAMP Chatbot Demo</title>
|
| 6 |
+
<link rel="stylesheet" href="/static/style.css" />
|
| 7 |
+
<!-- Include marked.js for Markdown rendering -->
|
| 8 |
+
<script src="https://cdn.jsdelivr.net/npm/marked/marked.min.js"></script>
|
| 9 |
+
<!-- Include DOMPurify to sanitize HTML -->
|
| 10 |
+
<script src="https://cdn.jsdelivr.net/npm/dompurify@2.4.2/dist/purify.min.js"></script>
|
| 11 |
+
</head>
|
| 12 |
+
<body class="no-scroll">
|
| 13 |
+
<div class="chat-container">
|
| 14 |
+
<!-- Header -->
|
| 15 |
+
<header class="chat-header">
|
| 16 |
+
<h1>CHAMP Chatbot Demo: Model Comparison</h1>
|
| 17 |
+
<p class="subtitle">
|
| 18 |
+
Talk to and compare chatbots powered by different models. Please
|
| 19 |
+
remember to avoid sharing any sensitive or private details during the
|
| 20 |
+
conversation.
|
| 21 |
+
</p>
|
| 22 |
+
</header>
|
| 23 |
+
|
| 24 |
+
<!-- Controls bar -->
|
| 25 |
+
<div class="controls-bar">
|
| 26 |
+
<div class="control-group">
|
| 27 |
+
<label for="systemPreset">Model Selection</label>
|
| 28 |
+
<select id="systemPreset">
|
| 29 |
+
<option value="champ" selected>CHAMP</option>
|
| 30 |
+
<!-- champ is our model -->
|
| 31 |
+
<option value="openai">ChatGPT</option>
|
| 32 |
+
<option value="google">Gemma</option>
|
| 33 |
+
</select>
|
| 34 |
+
</div>
|
| 35 |
+
|
| 36 |
+
<div class="control-group">
|
| 37 |
+
<label for="tempSlider">
|
| 38 |
+
Temperature:
|
| 39 |
+
<span id="tempValue">0.7</span>
|
| 40 |
+
</label>
|
| 41 |
+
<input
|
| 42 |
+
type="range"
|
| 43 |
+
id="tempSlider"
|
| 44 |
+
min="0.1"
|
| 45 |
+
max="1.2"
|
| 46 |
+
step="0.1"
|
| 47 |
+
value="0.7"
|
| 48 |
+
/>
|
| 49 |
+
</div>
|
| 50 |
+
|
| 51 |
+
<button id="clearBtn" class="secondary-button">Clear</button>
|
| 52 |
+
</div>
|
| 53 |
+
|
| 54 |
+
<div id="consentOverlay" class="consent-overlay">
|
| 55 |
+
<div class="consent-box">
|
| 56 |
+
<h2>Before you continue</h2>
|
| 57 |
+
<p>
|
| 58 |
+
By using this demo you agree that your messages will be shared with
|
| 59 |
+
us for processing. Do not provide sensitive or private details.
|
| 60 |
+
</p>
|
| 61 |
+
|
| 62 |
+
<label class="consent-check">
|
| 63 |
+
<input type="checkbox" id="consentCheckbox" />
|
| 64 |
+
I understand and agree
|
| 65 |
+
</label>
|
| 66 |
+
|
| 67 |
+
<button id="consentBtn" class="primary-button" disabled>
|
| 68 |
+
Agree and Continue
|
| 69 |
+
</button>
|
| 70 |
+
</div>
|
| 71 |
+
</div>
|
| 72 |
+
|
| 73 |
+
<!-- Chat window -->
|
| 74 |
+
<main id="chatWindow" class="chat-window">
|
| 75 |
+
<!-- Messages get rendered here by app.js -->
|
| 76 |
+
</main>
|
| 77 |
+
|
| 78 |
+
<!-- Input area -->
|
| 79 |
+
<footer class="chat-input-area">
|
| 80 |
+
<textarea
|
| 81 |
+
id="userInput"
|
| 82 |
+
rows="2"
|
| 83 |
+
placeholder="Type your message and press Enter or click Send..."
|
| 84 |
+
></textarea>
|
| 85 |
+
<button id="sendBtn">Send</button>
|
| 86 |
+
</footer>
|
| 87 |
+
|
| 88 |
+
<!-- Status line -->
|
| 89 |
+
<div id="status" class="status"></div>
|
| 90 |
+
</div>
|
| 91 |
+
|
| 92 |
+
<script src="/static/app.js"></script>
|
| 93 |
+
</body>
|
| 94 |
+
</html>
|
tests/test_dynamodb.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from moto import mock_aws
|
| 2 |
+
import boto3
|
| 3 |
+
import pytest
|
| 4 |
+
from dynamodb_helper import create_table_if_not_exists, DDB_TABLE
|
| 5 |
+
from botocore.exceptions import ClientError
|
| 6 |
+
import os
|
| 7 |
+
|
| 8 |
+
class TestDynamoDBHelper:
|
| 9 |
+
dynamodb = boto3.resource("dynamodb", region_name="ca-central-1")
|
| 10 |
+
@mock_aws
|
| 11 |
+
def test_create_table_if_not_exists(self):
|
| 12 |
+
# Ensure the table does not exist initially
|
| 13 |
+
client = self.dynamodb.meta.client
|
| 14 |
+
existing_tables = client.list_tables()["TableNames"]
|
| 15 |
+
assert DDB_TABLE not in existing_tables
|
| 16 |
+
|
| 17 |
+
# Create the table
|
| 18 |
+
table = create_table_if_not_exists(self.dynamodb)
|
| 19 |
+
assert table is not None
|
| 20 |
+
|
| 21 |
+
# Verify the table now exists
|
| 22 |
+
existing_tables = client.list_tables()["TableNames"]
|
| 23 |
+
assert DDB_TABLE in existing_tables
|
| 24 |
+
|
| 25 |
+
# Attempt to create the table again, should not raise an error
|
| 26 |
+
table = create_table_if_not_exists(self.dynamodb)
|
| 27 |
+
assert table is not None
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@mock_aws
|
| 31 |
+
def test_log_event(self):
|
| 32 |
+
table = create_table_if_not_exists(self.dynamodb)
|
| 33 |
+
|
| 34 |
+
table_resource = self.dynamodb.Table(DDB_TABLE)
|
| 35 |
+
user_id = "user123"
|
| 36 |
+
session_id = "test-session-456"
|
| 37 |
+
data = {"event": "test_event", "value": 26, "float_value": 3.14}
|
| 38 |
+
from dynamodb_helper import log_event
|
| 39 |
+
log_event(user_id, session_id, data)
|
| 40 |
+
response = table_resource.scan()
|
| 41 |
+
assert response["Count"] == 1
|
| 42 |
+
item = response["Items"][0]
|
| 43 |
+
assert item["PK"] == f"SESSION#{session_id}"
|
| 44 |
+
assert item["data"]["event"] == "test_event"
|
| 45 |
+
from decimal import Decimal
|
| 46 |
+
assert item["data"]["value"] == Decimal(26)
|
| 47 |
+
assert item["data"]["float_value"] == Decimal("3.14")
|
| 48 |
+
|
tests/test_main.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from main import app
|
| 2 |
+
from fastapi.testclient import TestClient
|
| 3 |
+
from moto import mock_aws
|
| 4 |
+
from dynamodb_helper import create_table_if_not_exists, DDB_TABLE
|
| 5 |
+
from botocore.exceptions import ClientError
|
| 6 |
+
|
| 7 |
+
class TestMain:
|
| 8 |
+
def test_convert_messages(self):
|
| 9 |
+
from main import convert_messages
|
| 10 |
+
from main import DEFAULT_SYSTEM_PROMPT
|
| 11 |
+
from main import ChatMessage
|
| 12 |
+
messages = [
|
| 13 |
+
ChatMessage(role="user", content="Hello"),
|
| 14 |
+
ChatMessage(role="assistant", content="Hi there!"),
|
| 15 |
+
]
|
| 16 |
+
converted = convert_messages(messages)
|
| 17 |
+
assert converted == [
|
| 18 |
+
{"role": "system", "content": DEFAULT_SYSTEM_PROMPT},
|
| 19 |
+
{"role": "user", "content": "Hello"},
|
| 20 |
+
{"role": "assistant", "content": "Hi there!"},
|
| 21 |
+
]
|
| 22 |
+
|
| 23 |
+
def test_convert_messages_langchain(self):
|
| 24 |
+
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
|
| 25 |
+
from main import convert_messages_langchain, ChatMessage, CHAMP_SYSTEM_PROMPT
|
| 26 |
+
messages = [
|
| 27 |
+
ChatMessage(role="user", content="Hello"),
|
| 28 |
+
ChatMessage(role="assistant", content="Hi there!"),
|
| 29 |
+
]
|
| 30 |
+
converted = convert_messages_langchain(messages)
|
| 31 |
+
# Check types and content only
|
| 32 |
+
expected_types = [SystemMessage, HumanMessage, AIMessage]
|
| 33 |
+
expected_contents = [CHAMP_SYSTEM_PROMPT, "Hello", "Hi there!"]
|
| 34 |
+
|
| 35 |
+
for msg, expected_type, expected_content in zip(converted, expected_types, expected_contents):
|
| 36 |
+
assert isinstance(msg, expected_type)
|
| 37 |
+
assert msg.content == expected_content
|