Spaces:
Paused
Paused
| from io import BytesIO | |
| from typing import Annotated, List | |
| from uuid import uuid4, UUID | |
| from PIL import Image | |
| from fastapi import APIRouter, HTTPException | |
| from fastapi.params import File, Query, Path, Depends | |
| from loguru import logger | |
| from app.Models.api_models.search_api_model import AdvancedSearchModel, CombinedSearchModel, SearchBasisEnum | |
| from app.Models.api_response.search_api_response import SearchApiResponse | |
| from app.Models.query_params import SearchPagingParams, FilterParams | |
| from app.Models.search_result import SearchResult | |
| from app.Services.authentication import force_access_token_verify | |
| from app.Services.provider import ServiceProvider | |
| from app.config import config | |
| from app.util.calculate_vectors_cosine import calculate_vectors_cosine | |
# Router that groups the search endpoints under the "Search" tag. When `config.access_protected`
# is truthy, `force_access_token_verify` is attached as a router-level dependency; otherwise
# `dependencies` is passed as None.
search_router = APIRouter(dependencies=([Depends(force_access_token_verify)] if config.access_protected else None),
                          tags=["Search"])

# Module-level handle used by every function in this file; it starts as None.
services: ServiceProvider | None = None  # The service provider will be injected in the webapp initialize
class SearchBasisParams:
    """Dependency object holding the search basis chosen through the ``basis`` query parameter.

    Attributes are limited to ``basis``, which defaults to ``SearchBasisEnum.vision``.

    Raises:
        HTTPException: status 400 when the OCR basis is requested while
            ``config.ocr_search.enable`` is falsy.
    """

    def __init__(self,
                 basis: Annotated[SearchBasisEnum, Query(
                     description="The basis used to search the image.")] = SearchBasisEnum.vision):
        ocr_requested = basis == SearchBasisEnum.ocr
        # The OCR configuration is only consulted when the OCR basis was actually asked for.
        if ocr_requested and not config.ocr_search.enable:
            raise HTTPException(400, "OCR search is not enabled.")
        self.basis = basis
async def result_postprocessing(resp: SearchApiResponse) -> SearchApiResponse:
    """Replace image and thumbnail URLs in ``resp`` with presigned URLs from the active storage.

    When ``config.storage.method.enabled`` is falsy the response is returned untouched.
    Otherwise every item in ``resp.result`` is updated in place and the same response
    object is returned.
    """
    if config.storage.method.enabled:
        for entry in resp.result:
            image = entry.img
            if image.local:
                # Fall back to the suffix of the current URL when no explicit format is recorded.
                extension = image.format or image.url.split('.')[-1]
                image.url = await services.storage_service.active_storage.presign_url(
                    f"{image.id}.{extension}")
            has_thumbnail = image.thumbnail_url is not None
            if has_thumbnail and (image.local or image.local_thumbnail):
                image.thumbnail_url = await services.storage_service.active_storage.presign_url(
                    f"thumbnails/{image.id}.webp")
    return resp
async def textSearch(
        prompt: Annotated[
            str, Path(max_length=100, description="The image prompt text you want to search.")],
        basis: Annotated[SearchBasisParams, Depends(SearchBasisParams)],
        filter_param: Annotated[FilterParams, Depends(FilterParams)],
        paging: Annotated[SearchPagingParams, Depends(SearchPagingParams)],
        exact: Annotated[bool, Query(
            description="If using OCR search, this option will require the ocr text contains **exactly** the "
                        "criteria you have given. This won't take any effect in vision search.")] = False
) -> SearchApiResponse:
    # Search by a text prompt.
    # The prompt is embedded with `get_text_vector` for the vision basis and with
    # `get_bert_vector` for any other basis, then searched against the vector that
    # `vector_name_for_basis` names for the chosen basis.
    logger.info("Text search request received, prompt: {}", prompt)

    if basis.basis == SearchBasisEnum.vision:
        query_vector = services.transformers_service.get_text_vector(prompt)
    else:
        query_vector = services.transformers_service.get_bert_vector(prompt)

    # `exact` only matters for OCR search: the prompt is additionally set as the OCR text filter.
    if exact and basis.basis == SearchBasisEnum.ocr:
        filter_param.ocr_text = prompt

    vector_name = services.db_context.vector_name_for_basis(basis.basis)
    hits = await services.db_context.querySearch(query_vector,
                                                 query_vector_name=vector_name,
                                                 filter_param=filter_param,
                                                 top_k=paging.count,
                                                 skip=paging.skip)
    response = SearchApiResponse(result=hits, message=f"Successfully get {len(hits)} results.", query_id=uuid4())
    return await result_postprocessing(response)
async def imageSearch(
        image: Annotated[bytes, File(max_length=10 * 1024 * 1024, media_type="image/*",
                                     description="The image you want to search.")],
        filter_param: Annotated[FilterParams, Depends(FilterParams)],
        paging: Annotated[SearchPagingParams, Depends(SearchPagingParams)]
) -> SearchApiResponse:
    # Search by an uploaded image.
    # The upload is decoded with Pillow, embedded with `get_image_vector`, and the
    # resulting vector is searched with the given filter and paging.
    # Raises HTTPException(400) when the uploaded bytes cannot be opened as an image.
    try:
        img = Image.open(BytesIO(image))
    except OSError as ex:
        # The upload is untrusted input. Pillow signals unreadable data with
        # UnidentifiedImageError, a subclass of OSError; report it as a client error
        # instead of letting it surface as an unhandled server error.
        raise HTTPException(400, "The uploaded file is not a valid image.") from ex
    logger.info("Image search request received")
    image_vector = services.transformers_service.get_image_vector(img)
    results = await services.db_context.querySearch(image_vector,
                                                    top_k=paging.count,
                                                    skip=paging.skip,
                                                    filter_param=filter_param)
    return await result_postprocessing(
        SearchApiResponse(result=results, message=f"Successfully get {len(results)} results.", query_id=uuid4()))
async def similarWith(
        image_id: Annotated[UUID, Path(description="The id of the image you want to search.")],
        basis: Annotated[SearchBasisParams, Depends(SearchBasisParams)],
        filter_param: Annotated[FilterParams, Depends(FilterParams)],
        paging: Annotated[SearchPagingParams, Depends(SearchPagingParams)]
) -> SearchApiResponse:
    # Similarity search anchored on an existing image id.
    # The id is passed to `querySimilar` as a string, together with the vector name
    # that `vector_name_for_basis` returns for the chosen basis.
    logger.info("Similar search request received, id: {}", image_id)
    vector_name = services.db_context.vector_name_for_basis(basis.basis)
    hits = await services.db_context.querySimilar(search_id=str(image_id),
                                                  top_k=paging.count,
                                                  skip=paging.skip,
                                                  filter_param=filter_param,
                                                  query_vector_name=vector_name)
    response = SearchApiResponse(result=hits, message=f"Successfully get {len(hits)} results.", query_id=uuid4())
    return await result_postprocessing(response)
async def advancedSearch(
        model: AdvancedSearchModel,
        basis: Annotated[SearchBasisParams, Depends(SearchBasisParams)],
        filter_param: Annotated[FilterParams, Depends(FilterParams)],
        paging: Annotated[SearchPagingParams, Depends(SearchPagingParams)]) -> SearchApiResponse:
    # Advanced search: delegates the query described by `model` to
    # `process_advanced_and_combined_search_query` and wraps the hits in a response.
    logger.info("Advanced search request received: {}", model)
    hits = await process_advanced_and_combined_search_query(model, basis, filter_param, paging)
    response = SearchApiResponse(result=hits, message=f"Successfully get {len(hits)} results.", query_id=uuid4())
    return await result_postprocessing(response)
async def combinedSearch(
        model: CombinedSearchModel,
        basis: Annotated[SearchBasisParams, Depends(SearchBasisParams)],
        filter_param: Annotated[FilterParams, Depends(FilterParams)],
        paging: Annotated[SearchPagingParams, Depends(SearchPagingParams)]) -> SearchApiResponse:
    # Combined search: runs the advanced query with the combined-search flag set,
    # re-scores and re-sorts the hits with `calculate_and_sort_by_combined_scores`,
    # then keeps at most `paging.count` of them.
    # Raises HTTPException(400) when `config.ocr_search.enable` is falsy.
    if not config.ocr_search.enable:
        raise HTTPException(400, "You used combined search, but it needs OCR search which is not "
                                 "enabled.")
    logger.info("Combined search request received: {}", model)
    ranked = await process_advanced_and_combined_search_query(model, basis, filter_param, paging, True)
    calculate_and_sort_by_combined_scores(model, basis, ranked)  # re-ranks `ranked` in place
    if len(ranked) > paging.count:
        ranked = ranked[:paging.count]
    response = SearchApiResponse(result=ranked, message=f"Successfully get {len(ranked)} results.",
                                 query_id=uuid4())
    return await result_postprocessing(response)
async def randomPick(
        filter_param: Annotated[FilterParams, Depends(FilterParams)],
        paging: Annotated[SearchPagingParams, Depends(SearchPagingParams)],
        seed: Annotated[int | None, Query(
            description="The seed for random pick. This is helpful for generating a reproducible random pick.")] = None,
) -> SearchApiResponse:
    # Random pick: obtains a vector from `get_random_vector(seed)` and runs an
    # ordinary vector search with it, using the given filter and paging.
    logger.info("Random pick request received")
    query_vector = services.transformers_service.get_random_vector(seed)
    picks = await services.db_context.querySearch(query_vector,
                                                  top_k=paging.count,
                                                  skip=paging.skip,
                                                  filter_param=filter_param)
    response = SearchApiResponse(result=picks, message=f"Successfully get {len(picks)} results.", query_id=uuid4())
    return await result_postprocessing(response)
| # @search_router.get("/recall/{query_id}", description="Recall the query with given queryId") | |
| # async def recallQuery(query_id: str): | |
| # raise NotImplementedError() | |
async def process_advanced_and_combined_search_query(model: AdvancedSearchModel,
                                                     basis: SearchBasisParams,
                                                     filter_param: FilterParams,
                                                     paging: SearchPagingParams,
                                                     is_combined_search=False) -> List[SearchResult]:
    """Run the shared query behind the advanced and combined search endpoints.

    Every entry of ``model.criteria`` and ``model.negative_criteria`` is embedded with
    ``get_bert_vector`` for the OCR basis or ``get_text_vector`` for the vision basis,
    and the two vector lists are passed to ``querySimilar`` along with ``model.mode``.

    Args:
        model: Carries the positive criteria, the negative criteria and the query mode.
        basis: Selects the encoder and the vector name to query.
        filter_param: Forwarded to the database query unchanged.
        paging: Supplies ``count`` and ``skip``.
        is_combined_search: When true, vectors are requested with the results and a
            larger ``top_k`` is used than the caller asked for.

    Returns:
        The list returned by ``querySimilar``.

    Raises:
        NotImplementedError: If the basis is neither OCR nor vision.
    """
    if basis.basis == SearchBasisEnum.ocr:
        encode = services.transformers_service.get_bert_vector
    elif basis.basis == SearchBasisEnum.vision:
        encode = services.transformers_service.get_text_vector
    else:  # pragma: no cover
        raise NotImplementedError()

    positive_vectors = [encode(text) for text in model.criteria]
    negative_vectors = [encode(text) for text in model.negative_criteria]

    # In order to ensure the query effect of the combined query, fetch more candidates
    # than requested: three times the page size, but at least 30 and at most 100.
    if is_combined_search:
        query_top_k = min(max(30, paging.count * 3), 100)
    else:
        query_top_k = paging.count

    return await services.db_context.querySimilar(
        query_vector_name=services.db_context.vector_name_for_basis(basis.basis),
        positive_vectors=positive_vectors,
        negative_vectors=negative_vectors,
        mode=model.mode,
        filter_param=filter_param,
        with_vectors=is_combined_search,
        top_k=query_top_k,
        skip=paging.skip)
| def calculate_and_sort_by_combined_scores(model: CombinedSearchModel, | |
| basis: SearchBasisParams, | |
| result: List[SearchResult]) -> None: | |
| # Use a different method to calculate the extra prompt vector based on the basis | |
| match basis.basis: | |
| case SearchBasisEnum.ocr: | |
| extra_prompt_vector = services.transformers_service.get_text_vector(model.extra_prompt) | |
| case SearchBasisEnum.vision: | |
| extra_prompt_vector = services.transformers_service.get_bert_vector(model.extra_prompt) | |
| case _: # pragma: no cover | |
| raise NotImplementedError() | |
| # Calculate combined_similar_score (original score * similar_score) and write to SearchResult.score | |
| for itm in result: | |
| match basis.basis: | |
| case SearchBasisEnum.ocr: | |
| extra_vector = itm.img.image_vector | |
| case SearchBasisEnum.vision: | |
| extra_vector = itm.img.text_contain_vector | |
| case _: # pragma: no cover | |
| raise NotImplementedError() | |
| if extra_vector is not None: | |
| similar_score = calculate_vectors_cosine(extra_vector, extra_prompt_vector) | |
| itm.score = (1 + similar_score) * itm.score | |
| # Finally, sort the result by combined_similar_score | |
| result.sort(key=lambda i: i.score, reverse=True) | |