| import os |
| import subprocess |
| import traceback |
| import cv2 |
| import io |
|
|
| import open_clip |
|
|
|
|
| from fastapi import FastAPI, UploadFile, File, HTTPException, Form, Request |
| from fastapi.responses import JSONResponse |
|
|
| import numpy as np |
| import onnxruntime as ort |
| from PIL import Image |
| from open_clip import image_transform |
| from transformers import AutoTokenizer |
| from open_clip import image_transform |
| from insightface.model_zoo import get_model |
| from insightface.app.common import Face |
| from starlette.middleware.base import BaseHTTPMiddleware |
|
|
# Application object; the middleware and route handlers in this file attach to it.
app = FastAPI(
    title="Photo AI Feature Extraction API",
    version="1.0.0",
)
|
|
| |
# Model handles shared by the request handlers. They start out as None and
# are assigned by the startup hook; handlers treat None as "not loaded".
detector = recognizer = None
clip_vision_session = clip_text_session = None

# CLIP helpers: the image preprocessing callable and the text tokenizer.
clip_preprocess = clip_tokenizer = None
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
# Secret that callers must present in the Authorization header. Taken from
# the TOKEN environment variable; the literal is only the fallback value.
# NOTE(review): a hard-coded fallback secret is risky outside development —
# confirm that real deployments always set TOKEN.
API_TOKEN = os.environ.get("TOKEN", "my-secret-token")

# Request paths that are served without a token check.
EXCLUDED_PATHS = ["/health"]
|
|
|
|
class TokenAuthMiddleware(BaseHTTPMiddleware):
    """Reject requests that do not carry the configured API token.

    Paths listed in ``EXCLUDED_PATHS`` bypass the check. For every other
    request the ``Authorization`` header must equal ``API_TOKEN``, either
    bare or prefixed with ``"Bearer "``.
    """

    async def dispatch(self, request: Request, call_next):
        """Forward authorized requests and answer 401 for all others.

        Args:
            request: The incoming request.
            call_next: Callable that hands the request to the next handler.

        Returns:
            The downstream response, or a 401 ``JSONResponse`` when the
            token is missing or does not match.
        """
        # Function-scope import so this block does not depend on a change to
        # the file's import section.
        import hmac

        if request.url.path in EXCLUDED_PATHS:
            return await call_next(request)

        def unauthorized():
            return JSONResponse(
                status_code=401,
                content={"detail": "未授权:无效的认证令牌"},
            )

        presented = request.headers.get("Authorization")
        if presented is None:
            # No header at all is always a failure, even if API_TOKEN is empty.
            return unauthorized()

        bearer_prefix = "Bearer "
        if presented.startswith(bearer_prefix):
            presented = presented[len(bearer_prefix):]

        # Compare in constant time so response timing does not reveal how
        # much of the token matched. Both sides are encoded to bytes because
        # compare_digest refuses non-ASCII str; "surrogatepass" keeps the
        # encoding from failing on unusual environment values.
        authorized = hmac.compare_digest(
            presented.encode("utf-8", "surrogatepass"),
            API_TOKEN.encode("utf-8", "surrogatepass"),
        )
        if not authorized:
            return unauthorized()

        return await call_next(request)
|
|
|
|
| |
# Install the token check on the application.
app.add_middleware(TokenAuthMiddleware)
|
|
|
|
@app.on_event("startup")
async def startup_event():
    """Load the face and CLIP models from local files when the app starts.

    Assigns the module-level model handles in order: face detector, face
    recognizer, CLIP vision/text sessions, CLIP image preprocessing, CLIP
    tokenizer. Any exception is printed with its traceback and swallowed, so
    handles that were not reached keep their previous value.
    """
    global detector, recognizer, clip_vision_session, clip_text_session, clip_preprocess, clip_tokenizer

    # All model assets are resolved relative to this source file.
    here = os.path.dirname(os.path.abspath(__file__))
    antelope_root = os.path.join(here, 'antelopev2')
    clip_root = os.path.join(here, 'clip_model')

    # ONNX Runtime options used for both CLIP sessions.
    ort_opts = ort.SessionOptions()
    ort_opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    ort_opts.intra_op_num_threads = 4

    print("====== 开始加载离线模型 ======")
    try:
        # Face detector.
        detector_file = os.path.join(antelope_root, 'scrfd_10g_bnkps.onnx')
        if not os.path.exists(detector_file):
            raise FileNotFoundError(f"Missing: {detector_file}")
        detector = get_model(detector_file, providers=['CPUExecutionProvider'])
        detector.prepare(ctx_id=0, input_size=(640, 640))

        # Face recognizer.
        recognizer_file = os.path.join(antelope_root, 'glintr100.onnx')
        if not os.path.exists(recognizer_file):
            raise FileNotFoundError(f"Missing: {recognizer_file}")
        recognizer = get_model(recognizer_file, providers=['CPUExecutionProvider'])
        recognizer.prepare(ctx_id=0)

        # CLIP vision and text encoders; both files must exist before either
        # session is created.
        vision_file = os.path.join(clip_root, 'vision_model.onnx')
        text_file = os.path.join(clip_root, 'text_model.onnx')
        if not (os.path.exists(vision_file) and os.path.exists(text_file)):
            raise FileNotFoundError("CLIP ONNX 模型或其 .data 配置文件缺失。")

        clip_vision_session = ort.InferenceSession(
            vision_file, ort_opts, providers=['CPUExecutionProvider'])
        clip_text_session = ort.InferenceSession(
            text_file, ort_opts, providers=['CPUExecutionProvider'])

        # Inference-mode image transform for 224-pixel inputs.
        clip_preprocess = image_transform(224, is_train=False)

        # Tokenizer is loaded strictly from the local directory.
        tokenizer_root = os.path.join(clip_root, 'tokenizer')
        if not os.path.exists(tokenizer_root):
            raise FileNotFoundError(f"找不到离线分词器目录: {tokenizer_root}")

        clip_tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_root, local_files_only=True)

        print("====== 🎉 所有模型(人脸 + CLIP)加载成功,API 已就绪! ======")
    except Exception:
        # Startup is allowed to continue; readiness is reported elsewhere.
        print("❌ 初始化失败, 错误如下:\n", traceback.format_exc())
|
|
|
|
@app.get("/health")
def health_check():
    """Report whether every component the endpoints rely on has been loaded.

    Returns:
        ``{"status": "healthy"}`` when the face detector, face recognizer,
        both CLIP sessions, the CLIP image preprocessing callable and the
        CLIP tokenizer are all set; ``{"status": "unhealthy"}`` otherwise.
    """
    # The preprocessing callable and the tokenizer are included because the
    # image and text endpoints cannot work without them.
    required = (
        detector,
        recognizer,
        clip_vision_session,
        clip_text_session,
        clip_preprocess,
        clip_tokenizer,
    )
    # Identity test rather than truthiness: a loaded object that defines
    # __len__ or __bool__ must not be mistaken for "missing".
    models_ready = all(component is not None for component in required)
    return {"status": "healthy" if models_ready else "unhealthy"}
|
|
| |
| |
| |
|
|
|
|
@app.post("/api/v1/extract/face")
async def extract_face_only(file: UploadFile = File(...)):
    """Detect faces in an uploaded image and compute an embedding for each.

    Args:
        file: The uploaded image.

    Returns:
        A dict with ``face_count`` and ``faces``; each face entry holds
        ``bbox``, ``det_score``, ``kps`` and ``embedding``.

    Raises:
        HTTPException: 500 when the face models are not loaded, 400 when the
            upload is empty or cannot be decoded as an image, 500 for any
            other failure during extraction.
    """
    if detector is None or recognizer is None:
        raise HTTPException(
            status_code=500, detail="Face models are not ready.")

    try:
        contents = await file.read()

        # An empty upload can never be a valid image; reject it here instead
        # of depending on how the decoder reacts to an empty buffer.
        img_cv = None
        if contents:
            img_cv = cv2.imdecode(
                np.frombuffer(contents, np.uint8), cv2.IMREAD_COLOR)
        if img_cv is None:
            raise HTTPException(status_code=400, detail="Invalid image file.")

        bboxes, kpss = detector.detect(img_cv, max_num=0, metric='default')

        face_results = []
        if bboxes is not None:
            for idx, bbox in enumerate(bboxes):
                kp = kpss[idx] if kpss is not None else None

                # The first four values are the box, the fifth is the score.
                face = Face(bbox=bbox[:4], kps=kp, det_score=bbox[4])
                # Called for its effect on `face`; the embedding is read back
                # from face.embedding below.
                recognizer.get(img_cv, face)

                face_results.append({
                    "bbox": face.bbox.tolist(),
                    "det_score": float(face.det_score),
                    "kps": face.kps.tolist() if face.kps is not None else None,
                    "embedding": face.embedding.tolist() if face.embedding is not None else None
                })

        return {
            "face_count": len(face_results),
            "faces": face_results
        }
    except HTTPException:
        # Keep deliberate HTTP errors (such as the 400 above) intact instead
        # of letting the generic handler rewrite them as 500.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Face extraction error: {str(e)}") from e
|
|
| |
| |
| |
|
|
|
|
@app.post("/api/v1/extract/clip_image")
async def extract_clip_image(file: UploadFile = File(...)):
    """Compute a unit-length CLIP embedding for an uploaded image.

    Args:
        file: The uploaded image.

    Returns:
        A dict with ``image_embedding``, the normalized embedding as a list.

    Raises:
        HTTPException: 500 when the CLIP vision session or the preprocessing
            callable is not loaded, 400 when the upload cannot be opened as
            an image, 500 for any other failure during extraction.
    """
    # Both the session and the preprocessing callable are needed below.
    if clip_vision_session is None or clip_preprocess is None:
        raise HTTPException(
            status_code=500, detail="CLIP Vision model is not ready.")

    try:
        contents = await file.read()

        try:
            img_pil = Image.open(io.BytesIO(contents)).convert("RGB")
        except OSError as e:
            # Pillow reports unreadable or truncated image data with OSError
            # (UnidentifiedImageError is a subclass); that is a client error.
            raise HTTPException(
                status_code=400, detail="Invalid image file.") from e

        # Add a leading batch axis and hand a NumPy array to ONNX Runtime.
        clip_tensor = clip_preprocess(img_pil).unsqueeze(0).numpy()
        image_embedding = clip_vision_session.run(
            None, {"image": clip_tensor})[0]

        # Scale to unit L2 norm along the last axis.
        image_embedding = image_embedding / \
            np.linalg.norm(image_embedding, axis=-1, keepdims=True)

        return {
            "image_embedding": image_embedding[0].tolist()
        }
    except HTTPException:
        # Do not let the generic handler turn the 400 above into a 500.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"CLIP Image extraction error: {str(e)}") from e
|
|
| |
| |
| |
|
|
|
|
@app.post("/api/v1/extract/text")
async def extract_text_features(text: str = Form(...)):
    """Compute a unit-length CLIP embedding for a text query.

    Args:
        text: The query text, submitted as a form field.

    Returns:
        A dict with the original ``text`` and ``text_embedding``, the
        normalized embedding as a list.

    Raises:
        HTTPException: 500 when the text session or tokenizer is not loaded,
            400 when the text is empty or whitespace only, 500 for any
            failure during tokenization or inference.
    """
    if clip_text_session is None or clip_tokenizer is None:
        raise HTTPException(
            status_code=500, detail="Text model is not initialized.")

    # Validate before entering the try block so the 400 is not caught by the
    # generic handler and reported as a 500.
    if not text.strip():
        raise HTTPException(
            status_code=400, detail="Text query cannot be empty.")

    try:
        # Pad or truncate to exactly 77 tokens and return NumPy arrays.
        tokenized = clip_tokenizer(
            [text],
            padding='max_length',
            truncation=True,
            max_length=77,
            return_tensors='np'
        )
        text_tokens = tokenized['input_ids']

        text_embedding = clip_text_session.run(None, {"text": text_tokens})[0]

        # Scale to unit L2 norm along the last axis.
        text_embedding = text_embedding / \
            np.linalg.norm(text_embedding, axis=-1, keepdims=True)

        return {
            "text": text,
            "text_embedding": text_embedding[0].tolist()
        }
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Text extraction error: {str(e)}") from e
|
|