Spaces:
Paused
Paused
| from datetime import datetime | |
| from io import BytesIO | |
| from pathlib import PurePath | |
| from typing import Annotated | |
| from uuid import UUID | |
| from PIL import Image, UnidentifiedImageError | |
| from fastapi import APIRouter, Depends, HTTPException, params, UploadFile, File | |
| from loguru import logger | |
| from app.Models.api_models.admin_api_model import ImageOptUpdateModel, DuplicateValidationModel | |
| from app.Models.api_models.admin_query_params import UploadImageModel | |
| from app.Models.api_response.admin_api_response import ServerInfoResponse, ImageUploadResponse, \ | |
| DuplicateValidationResponse | |
| from app.Models.api_response.base import NekoProtocol | |
| from app.Models.errors import PointDuplicateError | |
| from app.Models.img_data import ImageData | |
| from app.Services.authentication import force_admin_token_verify | |
| from app.Services.provider import ServiceProvider | |
| from app.Services.vector_db_context import PointNotFoundError | |
| from app.config import config | |
| from app.util.generate_uuid import generate_uuid_from_sha1 | |
| from app.util.local_file_utility import VALID_IMAGE_EXTENSIONS | |
| admin_router = APIRouter(dependencies=[Depends(force_admin_token_verify)], tags=["Admin"]) | |
| services: ServiceProvider | None = None | |
| async def delete_image( | |
| image_id: Annotated[UUID, params.Path(description="The id of the image you want to delete.")]) -> NekoProtocol: | |
| try: | |
| point = await services.db_context.retrieve_by_id(str(image_id)) | |
| except PointNotFoundError as ex: | |
| raise HTTPException(404, "Cannot find the image with the given ID.") from ex | |
| await services.db_context.deleteItems([str(point.id)]) | |
| logger.success("Image {} deleted from database.", point.id) | |
| if config.storage.method.enabled: # local image | |
| if point.local: | |
| image_files = [itm[0] async for itm in | |
| services.storage_service.active_storage.list_files("", f"{point.id}.*")] | |
| assert len(image_files) <= 1 | |
| if not image_files: | |
| logger.warning("Image {} is a local image but not found in static folder.", point.id) | |
| else: | |
| await services.storage_service.active_storage.move(image_files[0], f"_deleted/{image_files[0].name}") | |
| logger.success("Image {} removed.", image_files[0].name) | |
| if point.thumbnail_url is not None and (point.local or point.local_thumbnail): | |
| thumbnail_file = PurePath(f"thumbnails/{point.id}.webp") | |
| if await services.storage_service.active_storage.is_exist(thumbnail_file): | |
| await services.storage_service.active_storage.delete(thumbnail_file) | |
| logger.success("Thumbnail {} removed.", thumbnail_file.name) | |
| else: | |
| logger.warning("Thumbnail {} not found.", thumbnail_file.name) | |
| return NekoProtocol(message="Image deleted.") | |
| async def update_image(image_id: Annotated[UUID, params.Path(description="The id of the image you want to delete.")], | |
| model: ImageOptUpdateModel) -> NekoProtocol: | |
| if model.empty(): | |
| raise HTTPException(422, "Nothing to update.") | |
| try: | |
| point = await services.db_context.retrieve_by_id(str(image_id)) | |
| except PointNotFoundError as ex: | |
| raise HTTPException(404, "Cannot find the image with the given ID.") from ex | |
| if model.thumbnail_url is not None: | |
| if point.local or point.local_thumbnail: | |
| raise HTTPException(422, "Cannot change the thumbnail URL of a local image.") | |
| point.thumbnail_url = model.thumbnail_url | |
| if model.url is not None: | |
| if point.local: | |
| raise HTTPException(422, "Cannot change the URL of a local image.") | |
| point.url = model.url | |
| if model.starred is not None: | |
| point.starred = model.starred | |
| if model.categories is not None: | |
| point.categories = model.categories | |
| await services.db_context.updatePayload(point) | |
| logger.success("Image {} updated.", point.id) | |
| return NekoProtocol(message="Image updated.") | |
| IMAGE_MIMES = { | |
| "image/jpeg": "jpeg", | |
| "image/png": "png", | |
| "image/webp": "webp", | |
| "image/gif": "gif", | |
| } | |
| async def upload_image(image_file: Annotated[UploadFile, File(description="The image to be uploaded.")], | |
| model: Annotated[UploadImageModel, Depends()]) -> ImageUploadResponse: | |
| # generate an ID for the image | |
| img_type = None | |
| if image_file.content_type.lower() in IMAGE_MIMES: | |
| img_type = IMAGE_MIMES[image_file.content_type.lower()] | |
| elif image_file.filename: | |
| extension = PurePath(image_file.filename).suffix.lower() | |
| if extension in VALID_IMAGE_EXTENSIONS: | |
| img_type = extension[1:] | |
| if not img_type: | |
| logger.warning("Failed to infer image format of the uploaded image. Content Type: {}, Filename: {}", | |
| image_file.content_type, image_file.filename) | |
| raise HTTPException(415, "Unsupported image format.") | |
| img_bytes = await image_file.read() | |
| try: | |
| img_id = await services.upload_service.assign_image_id(img_bytes) | |
| except PointDuplicateError as ex: | |
| raise HTTPException(409, | |
| f"The uploaded point is already contained in the database! entity id: {ex.entity_id}") \ | |
| from ex | |
| try: | |
| image = Image.open(BytesIO(img_bytes)) | |
| image.verify() | |
| image.close() | |
| except UnidentifiedImageError as ex: | |
| logger.warning("Invalid image file from upload request. id: {}", img_id) | |
| raise HTTPException(422, "Cannot open the image file.") from ex | |
| image_data = ImageData(id=img_id, | |
| url=model.url, | |
| thumbnail_url=model.thumbnail_url, | |
| local=model.local, | |
| categories=model.categories, | |
| starred=model.starred, | |
| format=img_type, | |
| index_date=datetime.now()) | |
| await services.upload_service.queue_upload_image(image_data, img_bytes, model.skip_ocr, model.local_thumbnail) | |
| return ImageUploadResponse(message="OK. Image added to upload queue.", image_id=img_id) | |
| async def server_info() -> ServerInfoResponse: | |
| return ServerInfoResponse(message="Successfully get server information!", | |
| image_count=await services.db_context.get_counts(exact=True), | |
| index_queue_length=services.upload_service.get_queue_size()) | |
| async def duplication_validate(model: DuplicateValidationModel) -> DuplicateValidationResponse: | |
| ids = [generate_uuid_from_sha1(t) for t in model.hashes] | |
| valid_ids = await services.db_context.validate_ids([str(t) for t in ids]) | |
| exists_matrix = [str(t) in valid_ids or t in services.upload_service.uploading_ids for t in ids] | |
| return DuplicateValidationResponse( | |
| exists=exists_matrix, | |
| entity_ids=[(str(t) if exists else None) for (t, exists) in zip(ids, exists_matrix)], | |
| message="Validation completed.") | |