Spaces:
Paused
Paused
| import itertools | |
| import uuid | |
| import pytest_asyncio | |
| from .conftest import check_local_dir_empty | |
| from ..assets import assets_path | |
# Asset file names grouped by class label. Every file is looked up under
# ``assets_path / 'test_images'`` when it is uploaded or used as a query.
test_images = {
    'bsn': ['bsn_0.jpg', 'bsn_1.jpg', 'bsn_2.jpg'],
    'cat': ['cat_0.jpg', 'cat_1.jpg'],
    'cg': ['cg_0.jpg', 'cg_1.png'],
}
# Module scope so the uploads are shared by every test in this file.
# NOTE(review): this requires test_client and wait_for_background_task to be
# module- or session-scoped; confirm in conftest.
@pytest_asyncio.fixture(scope="module")
async def img_ids(test_client, wait_for_background_task):
    """Upload every asset in ``test_images`` and yield the server-assigned ids.

    Setup: each file under ``assets_path / 'test_images'`` is POSTed to
    ``/admin/upload`` with ``local=True``. The fixture then awaits
    ``wait_for_background_task`` with the total number of uploads before
    yielding.

    Yields:
        A dict mapping each class label in ``test_images`` to the list of
        ``image_id`` values returned by the server, in upload order.

    Teardown: every uploaded image is deleted via ``/admin/delete/{id}``,
    then ``check_local_dir_empty()`` is called.
    """
    ids = {}
    for img_cls, item_images in test_images.items():
        ids[img_cls] = []
        for image in item_images:
            print(f'upload image {image}...')
            with open(assets_path / 'test_images' / image, 'rb') as f:
                resp = test_client.post('/admin/upload',
                                        files={'image_file': f},
                                        params={'local': True})
            assert resp.status_code == 200, resp.text
            ids[img_cls].append(resp.json()['image_id'])
    print('Waiting for images to be processed...')
    await wait_for_background_task(sum(len(v) for v in ids.values()))
    yield ids
    # Cleanup: try every delete before asserting, so one failing delete
    # does not leave the remaining images behind.
    failed = []
    for img_id in itertools.chain.from_iterable(ids.values()):
        resp = test_client.delete(f"/admin/delete/{img_id}")
        if resp.status_code != 200:
            failed.append((img_id, resp.status_code))
    assert not failed, f"failed to delete images: {failed}"
    check_local_dir_empty()
def test_search_text(test_client, img_ids):
    """A plain text query should rank one of the 'cg' uploads first."""
    # NOTE(review): '+' in a URL path segment is normally not decoded to a
    # space, so the server likely receives "hatsune+miku"; confirm intent.
    response = test_client.get('/search/text/hatsune+miku')
    assert response.status_code == 200
    top_id = response.json()['result'][0]['img']['id']
    assert top_id in img_ids['cg']
def test_search_image(test_client, img_ids):
    """Querying with a 'cat' asset image should return a 'cat' upload first."""
    query_path = assets_path / 'test_images' / test_images['cat'][0]
    with open(query_path, 'rb') as query_file:
        response = test_client.post('/search/image', files={'image': query_file})
    assert response.status_code == 200
    best_match = response.json()['result'][0]['img']['id']
    assert best_match in img_ids['cat']
def test_search_similar(test_client, img_ids):
    """Similarity search seeded by the first 'bsn' upload should stay within 'bsn'."""
    seed_id = img_ids['bsn'][0]
    response = test_client.get(f"/search/similar/{seed_id}")
    assert response.status_code == 200
    assert response.json()['result'][0]['img']['id'] in img_ids['bsn']
def test_search_advanced(test_client, img_ids):
    """Positive and negative criteria should steer the top hit to a 'bsn' upload."""
    payload = {
        'criteria': ['white background', 'grayscale image'],
        'negative_criteria': ['cat', 'hatsune miku'],
    }
    response = test_client.post("/search/advanced", json=payload)
    assert response.status_code == 200
    assert response.json()['result'][0]['img']['id'] in img_ids['bsn']
def test_search_combined(test_client, img_ids):
    """Combined search should rank ``img_ids['cg'][1]`` first in both variants.

    The first request sends no ``basis`` query parameter. The second sends
    ``basis=ocr`` and swaps which spelling of the name is in ``criteria``
    versus ``extra_prompt``.
    """
    cases = (
        ('/search/combined', {'criteria': ['hatsune miku'],
                              'negative_criteria': ['grayscale image', 'cat'],
                              'extra_prompt': 'hatsunemiku'}),
        ('/search/combined?basis=ocr', {'criteria': ['hatsunemiku'],
                                        'extra_prompt': 'hatsune miku'}),
    )
    for url, body in cases:
        response = test_client.post(url, json=body)
        assert response.status_code == 200
        assert response.json()['result'][0]['img']['id'] == img_ids['cg'][1]
def test_search_filters(test_client, img_ids):
    """Category and starred filters should each surface the tagged image first.

    The first 'bsn' upload is given category 'bsn' and starred. A text search
    for "cat" is then repeated under each filter, and that image must rank
    first both times. The tags are not reverted afterwards.
    """
    target_id = img_ids['bsn'][0]
    update = test_client.put(f"/admin/update_opt/{target_id}",
                             json={'categories': ['bsn'], 'starred': True})
    assert update.status_code == 200
    for filter_params in ({'categories': 'bsn'}, {'starred': True}):
        response = test_client.get("/search/text/cat", params=filter_params)
        assert response.status_code == 200
        assert response.json()['result'][0]['img']['id'] == target_id
def test_images_query_by_id(test_client, img_ids):
    """Fetching an uploaded image by id should echo that same id back."""
    wanted = img_ids['bsn'][0]
    response = test_client.get(f"/images/id/{wanted}")
    assert response.status_code == 200
    assert response.json()['img']['id'] == wanted
def test_images_query_not_exist(test_client, img_ids):
    """A freshly generated random UUID should yield 404 from ``/images/id/``."""
    # img_ids is not read here; requesting it still runs the upload fixture.
    missing_id = uuid.uuid4()
    assert test_client.get(f"/images/id/{missing_id}").status_code == 404
def test_images_query_scroll(test_client, img_ids):
    """Scroll the image listing and check offset-based paging.

    1. ``GET /images/?count=50`` must return only ids uploaded by the fixture.
    2. Re-querying with ``prev_offset_id`` set to the middle item of that page
       must return a page whose first item is that same image.
    3. A random ``prev_offset_id`` matching no image must yield 404.
    """
    resp = test_client.get("/images/", params={'count': 50})
    assert resp.status_code == 200
    resp_imgs = resp.json()['images']
    # Fail with a clear message rather than an IndexError when picking the
    # middle item below.
    assert resp_imgs, "image listing returned no images"
    all_images_id = set(itertools.chain.from_iterable(img_ids.values()))
    for item in resp_imgs:
        assert item['id'] in all_images_id
    middle_id = resp_imgs[len(resp_imgs) // 2]['id']
    paging_test = test_client.get('/images', params={'prev_offset_id': middle_id})
    assert paging_test.status_code == 200
    paged_imgs = paging_test.json()['images']
    assert paged_imgs, "paged listing returned no images"
    assert paged_imgs[0]['id'] == middle_id
    no_exist_test = test_client.get('/images',
                                    params={'prev_offset_id': str(uuid.uuid4())})
    assert no_exist_test.status_code == 404