import io
import logging
from pathlib import Path

import cv2
import numpy as np
from fastapi import APIRouter, FastAPI, File, HTTPException, Request, UploadFile
from fastapi.responses import FileResponse, Response
from fastapi.routing import APIRouter
from starlette.datastructures import UploadFile as StarletteUploadFile

from services.embeddings.embedd import EmbeddingModel
from services.segmentation.segmentation_model import YOLODetector
|
|
router = APIRouter()

# YOLO weights are resolved relative to this source file, so the lookup does
# not depend on the process working directory.
WEIGHTS_PATH = Path(__file__).parent.parent / "services/segmentation/yolov5s.pt"

# Both models are constructed once, at import time, and reused by every
# endpoint in this module.
try:
    # NOTE(review): unlike WEIGHTS_PATH, this path is relative to the process
    # working directory (and names a dot-prefixed file); confirm intended.
    embedding_model = EmbeddingModel(model_path=".embed_model_weights.pth")
    detector = YOLODetector(weights=str(WEIGHTS_PATH), device="cpu")
except Exception as e:
    # Fail at import so the router is never mounted without working models;
    # chain the original error so its traceback is reported as the cause.
    raise RuntimeError(f"Failed to initialize models: {str(e)}") from e
|
|
|
|
@router.post("/image_to_embedding")
async def image_to_embedding(file: UploadFile, use_segmentation: bool = True):
    """
    Generate an embedding from an uploaded image, optionally segmenting it first.

    Args:
        file: Uploaded image file. Its declared content type must start with
            ``image/``.
        use_segmentation: If True, run the detector pipeline
            (``preprocess_image`` -> ``infer`` -> ``process_detections``) and
            embed the image it returns; otherwise decode the raw upload with
            OpenCV and embed that.

    Returns:
        ``{"embedding": [...], "shape": ...}`` on success, or
        ``{"message": "No cat detected in image"}`` when segmentation is
        enabled and ``process_detections`` returns ``None``.

    Raises:
        HTTPException: 400 if the upload is not declared as an image or cannot
            be decoded; 500 for any other failure while processing.
    """
    logger = logging.getLogger(__name__)

    # content_type can be None when the client omits the header; treat that
    # as "not an image" instead of crashing with AttributeError (-> 500).
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")

    try:
        if use_segmentation:
            im0, img = detector.preprocess_image(file)
            detections = detector.infer(img)
            image = detector.process_detections(im0, detections)
            if image is None:
                return {"message": "No cat detected in image"}
        else:
            logger.debug("No segmentation")
            contents = await file.read()
            # imdecode returns None (rather than raising) for undecodable data.
            image = cv2.imdecode(np.frombuffer(contents, np.uint8), cv2.IMREAD_COLOR)
            if image is None:
                raise HTTPException(status_code=400, detail="Could not decode image")

        embedding = embedding_model(image)
        return {
            "embedding": embedding.tolist(),
            "shape": embedding.shape,
        }

    except HTTPException:
        # Client errors raised above must keep their status code instead of
        # being re-wrapped as a generic 500 by the handler below.
        raise
    except Exception as e:
        logger.exception("Error processing image")
        raise HTTPException(
            status_code=500, detail=f"Error processing image: {str(e)}") from e
|
|
|
|
@router.post("/binary_image_to_embedding")
async def upload_image_binary(request: Request, use_segmentation: bool = True):
    """
    Generate an embedding from a raw binary image body, optionally segmenting it.

    Args:
        request: Request whose body holds the encoded image bytes.
        use_segmentation: If True, run the detector pipeline
            (``preprocess_image_binary`` -> ``infer`` -> ``process_detections``)
            on the decoded image and embed its result; when that result is
            ``None`` (or segmentation is disabled) the full decoded image is
            embedded instead.

    Returns:
        ``{"embedding": [...], "shape": ...}``.

    Raises:
        HTTPException: 400 if the body is empty or cannot be decoded as an
            image; 500 for any other failure while processing.
    """
    logger = logging.getLogger(__name__)

    try:
        contents = await request.body()
        if not contents:
            raise HTTPException(status_code=400, detail="No image data received")

        # imdecode returns None (rather than raising) for undecodable data.
        image = cv2.imdecode(np.frombuffer(contents, np.uint8), cv2.IMREAD_COLOR)
        if image is None:
            raise HTTPException(status_code=400, detail="Could not decode image")

        segmented_image = None
        if use_segmentation:
            im0, img = detector.preprocess_image_binary(image)
            detections = detector.infer(img)
            segmented_image = detector.process_detections(im0, detections)

        # Fall back to the whole image when segmentation was skipped or
        # produced nothing.
        if segmented_image is None:
            logger.debug("No segmentation")
            embedding = embedding_model(image)
        else:
            logger.debug("with segmentation")
            embedding = embedding_model(segmented_image)

        return {
            "embedding": embedding.tolist(),
            "shape": embedding.shape,
        }

    except HTTPException:
        # Client errors raised above must keep their status code instead of
        # being re-wrapped as a generic 500 by the handler below.
        raise
    except Exception as e:
        logger.exception("Error processing image")
        raise HTTPException(
            status_code=500, detail=f"Error processing image: {str(e)}") from e
|
|