chatx / tests /test_app.py
imseldrith's picture
Upload folder using huggingface_hub (#1)
97ef430
import pytest
from fastapi.testclient import TestClient
from fastapi import HTTPException
from typing import Union, Any
from app import ask_gpt4, app, http_exception_handler
from models.query_model import QueryModel
client = TestClient(app)
@pytest.mark.asyncio
@pytest.mark.parametrize("query_params, model, expected_output", [
(
{
"user_input": "What is the capital of France?Please answer with one word only and dont add dot at the end"},
"text-davinci-003",
"Paris"
),
(
{"user_input": "Which is the capital of UK? Please answer with one word only and dont add dot at the end"},
"text-davinci-003",
"London"
),
# Add more test cases here
])
async def test_ask_gpt4(query_params, model, expected_output):
response = client.post(
"/ask_gpt4/",
json={"user_input": query_params["user_input"], "model": model},
)
assert response.status_code == 200
json_response = response.json()
# Check if the response contains a valid answer
assert "response" in json_response or "error" in json_response
# If there's an error, check if it's a known error
if "error" in json_response:
assert json_response["error"] in [
"ChatGPT response does not contain text attribute.",
# Add other known errors here
]
else:
assert json_response["response"] == expected_output
@pytest.mark.asyncio
@pytest.mark.parametrize("status_code, detail, expected_result", [
(404, "Not Found", {"detail": "Not Found", "status_code": 404}),
(500, "Internal Server Error", {"detail": "Internal Server Error", "status_code": 500}),
(401, "Unauthorized", {"detail": "Unauthorized", "status_code": 401}),
])
async def test_http_exception_handler(status_code: int, detail: Union[str, dict], expected_result: Any) -> None:
"""
Function to test http exception handler
:param status_code: int ,status code e.g. 400, 404 etc.
:param detail: str or Dict , detail message
:param expected_result:
:return: None
"""
exc = HTTPException(status_code=status_code, detail=detail)
result = await http_exception_handler(exc)
assert result == expected_result
# @pytest.mark.parametrize(
# "user_input,expected_status_code",
# [
# ("What is the capital of France?", 200),
# ("", 400), # Invalid query
# ],
# )
# def test_ask_gpt4_route(user_input: str, expected_status_code: int):
# query = QueryModel(user_input=user_input)
# response = client.post("/ask_gpt4/", json=query.dict())
#
# assert response.status_code == expected_status_code
#
# if expected_status_code == 200:
# assert "response" in response.json()
# else:
# assert "error" in response.json()