import pytest
from fastapi.testclient import TestClient
from fastapi import HTTPException
from typing import Union, Any
from app import ask_gpt4, app, http_exception_handler
from models.query_model import QueryModel
# Module-level test client wrapping the imported FastAPI ``app``; the tests
# below send their requests through it.
client = TestClient(app)
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "query_params, model, expected_output",
    [
        (
            {
                "user_input": "What is the capital of France?Please answer with one word only and dont add dot at the end"
            },
            "text-davinci-003",
            "Paris",
        ),
        (
            {
                "user_input": "Which is the capital of UK? Please answer with one word only and dont add dot at the end"
            },
            "text-davinci-003",
            "London",
        ),
        # Add more test cases here
    ],
)
async def test_ask_gpt4(query_params, model, expected_output):
    """POST a question to ``/ask_gpt4/`` and validate the JSON reply.

    The request must come back with status 200. The body must then carry
    either a ``response`` key equal to ``expected_output`` or an ``error``
    key whose value is one of the known error messages listed below.

    :param query_params: dict holding the ``user_input`` prompt text.
    :param model: model name sent along with the prompt.
    :param expected_output: exact answer expected under ``response``.
    """
    payload = {"user_input": query_params["user_input"], "model": model}
    response = client.post("/ask_gpt4/", json=payload)
    assert response.status_code == 200

    body = response.json()

    # The body has to report either an answer or an error.
    assert "response" in body or "error" in body

    if "error" not in body:
        assert body["response"] == expected_output
        return

    # Only errors that are already known are tolerated.
    known_errors = (
        "ChatGPT response does not contain text attribute.",
        # Add other known errors here
    )
    assert body["error"] in known_errors
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "status_code, detail, expected_result",
    [
        (404, "Not Found", {"detail": "Not Found", "status_code": 404}),
        (
            500,
            "Internal Server Error",
            {"detail": "Internal Server Error", "status_code": 500},
        ),
        (401, "Unauthorized", {"detail": "Unauthorized", "status_code": 401}),
    ],
)
async def test_http_exception_handler(
    status_code: int,
    detail: Union[str, dict],
    expected_result: Any,
) -> None:
    """Check what ``http_exception_handler`` returns for an ``HTTPException``.

    An ``HTTPException`` is built from the given status code and detail,
    handed to the handler, and the awaited result is compared for equality
    with ``expected_result``.

    :param status_code: HTTP status code given to the exception, e.g. 404.
    :param detail: detail payload (str or dict) given to the exception.
    :param expected_result: value the handler's result must equal.
    :return: None
    """
    handled = await http_exception_handler(
        HTTPException(status_code=status_code, detail=detail)
    )
    assert handled == expected_result
# @pytest.mark.parametrize(
#     "user_input,expected_status_code",
#     [
#         ("What is the capital of France?", 200),
#         ("", 400),  # Invalid query
#     ],
# )
# def test_ask_gpt4_route(user_input: str, expected_status_code: int):
#     query = QueryModel(user_input=user_input)
#     response = client.post("/ask_gpt4/", json=query.dict())
#
#     assert response.status_code == expected_status_code
#
#     if expected_status_code == 200:
#         assert "response" in response.json()
#     else:
#         assert "error" in response.json()