# faceid/app/routers/embedd.py
import io
from pathlib import Path
import cv2
import numpy as np
from fastapi import APIRouter, FastAPI, File, HTTPException, Request, UploadFile
from fastapi.responses import FileResponse, Response
from fastapi.routing import APIRouter
from services.embeddings.embedd import EmbeddingModel
from services.segmentation.segmentation_model import YOLODetector
from starlette.datastructures import UploadFile as StarletteUploadFile
router = APIRouter()

# YOLO segmentation weights, resolved relative to this file so the app works
# regardless of the process's current working directory.
WEIGHTS_PATH = Path(__file__).parent.parent / \
    "services/segmentation/yolov5s.pt"

try:
    # Embedding model: maps a (segmented) image to a feature vector.
    embedding_model = EmbeddingModel(
        model_path=".embed_model_weights.pth")
    # Detector: locates and crops the subject before embedding.
    detector = YOLODetector(weights=str(WEIGHTS_PATH), device="cpu")
except Exception as e:
    # Fail fast at import time; chain the original exception so the real
    # cause (missing weights, bad device, ...) stays in the traceback.
    raise RuntimeError(f"Failed to initialize models: {str(e)}") from e
@router.post("/image_to_embedding")
async def image_to_embedding(file: UploadFile, use_segmentation: bool = True):
    """
    Generate embedding from input image, optionally with segmentation preprocessing.

    Args:
        file: Uploaded image file
        use_segmentation: Whether to segment the image before generating embedding

    Returns:
        Image embedding vector (with its shape), or a message when no
        subject is detected by the segmentation step.

    Raises:
        HTTPException: 400 for invalid/undecodable input, 500 for
            unexpected processing failures
    """
    try:
        # content_type may be None for some clients; treat that as invalid.
        if not file.content_type or not file.content_type.startswith('image/'):
            raise HTTPException(
                status_code=400, detail="File must be an image")
        if use_segmentation:
            # The detector consumes the UploadFile directly and returns the
            # original image plus the model-ready tensor.
            im0, img = detector.preprocess_image(file)
            detections = detector.infer(img)
            image = detector.process_detections(im0, detections)
            if image is None:
                return {"message": "No cat detected in image"}
        else:
            # No segmentation: decode the raw upload bytes ourselves.
            contents = await file.read()
            nparr = np.frombuffer(contents, np.uint8)
            image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
            if image is None:
                raise HTTPException(
                    status_code=400, detail="Could not decode image")
        embedding = embedding_model(image)
        return {
            "embedding": embedding.tolist(),
            "shape": embedding.shape
        }
    except HTTPException:
        # Re-raise deliberate HTTP errors unchanged; otherwise the generic
        # handler below would rewrap the 400s above as 500s.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Error processing image: {str(e)}")
@router.post("/binary_image_to_embedding")
async def upload_image_binary(request: Request, use_segmentation: bool = True):
    """
    Generate embedding from binary input image, optionally with segmentation preprocessing.

    Args:
        request: Request object containing binary image data
        use_segmentation: Whether to segment the image before generating embedding

    Returns:
        Image embedding vector (with its shape)

    Raises:
        HTTPException: 400 for empty/undecodable payloads, 500 for
            unexpected processing failures
    """
    try:
        # Read the raw binary image from the request body.
        contents = await request.body()
        if not contents:
            raise HTTPException(
                status_code=400, detail="No image data received")
        # Decode the binary payload into an image array.
        nparr = np.frombuffer(contents, np.uint8)
        image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
        if image is None:
            raise HTTPException(
                status_code=400, detail="Could not decode image")
        segmented_image = None
        if use_segmentation:
            im0, img = detector.preprocess_image_binary(image)
            detections = detector.infer(img)
            segmented_image = detector.process_detections(im0, detections)
        # Fall back to the full image when segmentation found nothing
        # (or was disabled).
        target = image if segmented_image is None else segmented_image
        embedding = embedding_model(target)
        return {
            "embedding": embedding.tolist(),
            "shape": embedding.shape
        }
    except HTTPException:
        # Keep deliberate 4xx responses from being rewrapped as 500s by the
        # generic handler below.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Error processing image: {str(e)}")