server / app.py
ImgSpace's picture
Update app.py
e6b0b9c verified
Raw
History Blame Contribute Delete
9.88 kB
import os
import subprocess
import traceback
import cv2
import io
import open_clip
from fastapi import FastAPI, UploadFile, File, HTTPException, Form, Request
from fastapi.responses import JSONResponse
import numpy as np
import onnxruntime as ort
from PIL import Image
from open_clip import image_transform
from transformers import AutoTokenizer
from open_clip import image_transform
from insightface.model_zoo import get_model
from insightface.app.common import Face
from starlette.middleware.base import BaseHTTPMiddleware
# FastAPI application object that every route and middleware below attaches to.
app = FastAPI(version="1.0.0", title="Photo AI Feature Extraction API")

# Module-wide model handles; all start out as None and are assigned by the
# startup hook further down in this file.
detector = recognizer = None
clip_vision_session = clip_text_session = None

# CLIP-specific image preprocessing transform and text tokenizer.
clip_preprocess = clip_tokenizer = None
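# Check whether the model directory exists
# if not os.path.exists('/code/antelopev2'):
# # Use Git LFS to download the models into /code/models
# # https://huggingface.co/ImgSpace/iLookModels
# # Read an environment variable to decide whether to download from Hugging Face or ModelScope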
# try:
# repo_url = "https://huggingface.co/ImgSpace/iLookModels"
# if os.environ.get("MODEL_SOURCE") == "modelscope":
# repo_url = "https://modelscope.cn/ImgSpace/iLookModels"
# subprocess.run(["git", "lfs", "install"], cwd="/code")
# subprocess.run(["git", "clone", repo_url, "/code/models"])
# subprocess.run(["mv", "/code/models/antelopev2", "/code/"])
# subprocess.run(["mv", "/code/models/clip_model", "/code/"])
# subprocess.run(["rm", "-rf", "/code/models"])
# except Exception as e:
# print("❌ 下载模型失败, 错误如下:\n", traceback.format_exc())
# Shared secret clients must present. Taken from the TOKEN environment
# variable, with a built-in fallback when the variable is unset.
# NOTE(review): the fallback looks like a placeholder value — confirm TOKEN is
# always set in real deployments.
API_TOKEN = os.environ.get("TOKEN", "my-secret-token")

# Request paths that bypass token verification.
EXCLUDED_PATHS = ["/health"]
class TokenAuthMiddleware(BaseHTTPMiddleware):
    """Simple token-verification middleware.

    Every request whose path is not listed in ``EXCLUDED_PATHS`` must carry
    an ``Authorization`` header equal to ``API_TOKEN``, optionally prefixed
    with ``"Bearer "``. Anything else is answered with HTTP 401.
    """

    async def dispatch(self, request: Request, call_next):
        """Authenticate *request* and forward it to the next handler.

        Args:
            request: The incoming request.
            call_next: Callable that passes the request down the stack and
                returns the downstream response.

        Returns:
            The downstream response when the path is excluded or the token
            matches; otherwise a 401 ``JSONResponse``.
        """
        # Function-scope import so this block works without editing the
        # module's import section.
        import hmac

        # Paths that are exempt from authentication go straight through.
        if request.url.path in EXCLUDED_PATHS:
            return await call_next(request)

        supplied = request.headers.get("Authorization")

        # Accept both the raw token and the "Bearer <token>" form.
        if supplied and supplied.startswith("Bearer "):
            supplied = supplied[len("Bearer "):]

        # Compare in constant time so response timing does not reveal how
        # many leading characters of a guessed token were correct. Both sides
        # are encoded to bytes because compare_digest rejects non-ASCII str.
        # A missing header is rejected outright.
        token_ok = supplied is not None and hmac.compare_digest(
            supplied.encode("utf-8"), API_TOKEN.encode("utf-8")
        )
        if not token_ok:
            return JSONResponse(
                status_code=401,
                content={"detail": "未授权:无效的认证令牌"}
            )

        # Token verified: continue processing the request.
        return await call_next(request)
# Register the token-verification middleware on the application.
app.add_middleware(TokenAuthMiddleware)
@app.on_event("startup")
async def startup_event():
    """Load all offline models once, when the application starts.

    Assigns the module-level handles for the face detector, the face
    recognizer, the two CLIP ONNX sessions, the CLIP image preprocessing
    transform and the CLIP tokenizer. Model files are looked up in the
    ``antelopev2`` and ``clip_model`` directories next to this file.

    Any exception raised while loading is printed with its traceback and
    not re-raised, so handles assigned before the failure keep their values
    and the remaining ones stay ``None``.
    """
    global detector, recognizer, clip_vision_session, clip_text_session, clip_preprocess, clip_tokenizer

    here = os.path.dirname(os.path.abspath(__file__))
    antelope_root = os.path.join(here, 'antelopev2')
    clip_root = os.path.join(here, 'clip_model')

    # Cap ONNX Runtime's intra-op CPU threads so that a single request does
    # not occupy every core. These options are passed to the two CLIP
    # sessions created below.
    ort_opts = ort.SessionOptions()
    ort_opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    ort_opts.intra_op_num_threads = 4  # adjust to the actual CPU core count

    print("====== 开始加载离线模型 ======")
    try:
        # 1. InsightFace face detector.
        detector_file = os.path.join(antelope_root, 'scrfd_10g_bnkps.onnx')
        if not os.path.exists(detector_file):
            raise FileNotFoundError(f"Missing: {detector_file}")
        detector = get_model(detector_file, providers=['CPUExecutionProvider'])
        detector.prepare(ctx_id=0, input_size=(640, 640))

        # 2. InsightFace face recognizer.
        recognizer_file = os.path.join(antelope_root, 'glintr100.onnx')
        if not os.path.exists(recognizer_file):
            raise FileNotFoundError(f"Missing: {recognizer_file}")
        recognizer = get_model(recognizer_file, providers=['CPUExecutionProvider'])
        recognizer.prepare(ctx_id=0)

        # 3. CLIP vision and text ONNX sessions; both files must be present.
        vision_file = os.path.join(clip_root, 'vision_model.onnx')
        text_file = os.path.join(clip_root, 'text_model.onnx')
        if not (os.path.exists(vision_file) and os.path.exists(text_file)):
            raise FileNotFoundError("CLIP ONNX 模型或其 .data 配置文件缺失。")
        clip_vision_session = ort.InferenceSession(
            vision_file, ort_opts, providers=['CPUExecutionProvider'])
        clip_text_session = ort.InferenceSession(
            text_file, ort_opts, providers=['CPUExecutionProvider'])

        # 4. Image preprocessing and the offline tokenizer.
        # Inference-mode transform at size 224 (the original author notes
        # this as the standard ViT-B input size).
        clip_preprocess = image_transform(224, is_train=False)

        # The tokenizer is read from a local directory with
        # local_files_only=True, so nothing is fetched from the network.
        tokenizer_root = os.path.join(clip_root, 'tokenizer')
        if not os.path.exists(tokenizer_root):
            raise FileNotFoundError(f"找不到离线分词器目录: {tokenizer_root}")
        clip_tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_root, local_files_only=True)

        print("====== 🎉 所有模型(人脸 + CLIP)加载成功,API 已就绪! ======")
    except Exception:
        # Best-effort startup: report the failure and keep the process alive.
        print("❌ 初始化失败, 错误如下:\n", traceback.format_exc())
@app.get("/health")
def health_check():
    """Report whether every model component has been loaded.

    Returns:
        ``{"status": "healthy"}`` when the face detector, face recognizer,
        both CLIP sessions, the CLIP preprocessing transform and the CLIP
        tokenizer are all set; ``{"status": "unhealthy"}`` otherwise.
    """
    # The preprocessing transform and tokenizer are included because the CLIP
    # endpoints cannot work without them even when the ONNX sessions loaded.
    required = (
        detector,
        recognizer,
        clip_vision_session,
        clip_text_session,
        clip_preprocess,
        clip_tokenizer,
    )
    # Test against None explicitly rather than relying on each object's
    # truthiness.
    models_ready = all(component is not None for component in required)
    return {"status": "healthy" if models_ready else "unhealthy"}
# ==========================================
# Endpoint 1: face detection and recognition (for clustering) only
# ==========================================
@app.post("/api/v1/extract/face")
async def extract_face_only(file: UploadFile = File(...)):
    """Detect every face in an uploaded image and compute its embedding.

    Args:
        file: The uploaded image file.

    Returns:
        A dict with ``face_count`` and ``faces``; each face entry holds
        ``bbox``, ``det_score``, ``kps`` and ``embedding`` (``kps`` and
        ``embedding`` may be ``None``).

    Raises:
        HTTPException: 500 when the face models are not loaded or processing
            fails unexpectedly; 400 when the upload is empty or cannot be
            decoded as an image.
    """
    if detector is None or recognizer is None:
        raise HTTPException(
            status_code=500, detail="Face models are not ready.")
    try:
        contents = await file.read()
        # An empty buffer cannot be an image; reject it as a client error
        # instead of letting the decoder fail.
        if not contents:
            raise HTTPException(status_code=400, detail="Invalid image file.")
        nparr = np.frombuffer(contents, np.uint8)
        img_cv = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
        if img_cv is None:
            raise HTTPException(status_code=400, detail="Invalid image file.")

        bboxes, kpss = detector.detect(img_cv, max_num=0, metric='default')
        face_results = []
        if bboxes is not None and len(bboxes) > 0:
            for idx, bbox in enumerate(bboxes):
                kp = kpss[idx] if kpss is not None else None
                # The first four bbox values are passed as the box and the
                # fifth as the detection score.
                face = Face(bbox=bbox[:4], kps=kp, det_score=bbox[4])
                # The embedding is read back from `face` after this call.
                recognizer.get(img_cv, face)
                face_results.append({
                    "bbox": face.bbox.tolist(),
                    "det_score": float(face.det_score),
                    "kps": face.kps.tolist() if face.kps is not None else None,
                    "embedding": face.embedding.tolist() if face.embedding is not None else None
                })
        return {
            "face_count": len(face_results),
            "faces": face_results
        }
    except HTTPException:
        # Let deliberate client errors (400) through unchanged; the generic
        # handler below would otherwise turn them into 500s.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Face extraction error: {str(e)}") from e
# ==========================================
# Endpoint 2: CLIP whole-image semantic feature extraction only
# ==========================================
@app.post("/api/v1/extract/clip_image")
async def extract_clip_image(file: UploadFile = File(...)):
    """Compute the L2-normalized CLIP embedding of an uploaded image.

    Args:
        file: The uploaded image file.

    Returns:
        A dict with ``image_embedding``, the normalized embedding as a list.

    Raises:
        HTTPException: 500 when the CLIP vision session or preprocessing
            transform is not loaded, or inference fails unexpectedly; 400
            when the upload cannot be opened as an image.
    """
    # Both the ONNX session and the preprocessing transform are needed below.
    if clip_vision_session is None or clip_preprocess is None:
        raise HTTPException(
            status_code=500, detail="CLIP Vision model is not ready.")
    try:
        contents = await file.read()
        try:
            img_pil = Image.open(io.BytesIO(contents)).convert("RGB")
        except OSError as exc:
            # Pillow signals unreadable or truncated images with OSError
            # (UnidentifiedImageError is a subclass): a client error,
            # reported the same way as the face endpoint does.
            raise HTTPException(
                status_code=400, detail="Invalid image file.") from exc

        clip_tensor = clip_preprocess(img_pil).unsqueeze(0).numpy()
        image_embedding = clip_vision_session.run(
            None, {"image": clip_tensor})[0]
        # L2 normalization along the last axis.
        image_embedding = image_embedding / \
            np.linalg.norm(image_embedding, axis=-1, keepdims=True)
        return {
            "image_embedding": image_embedding[0].tolist()
        }
    except HTTPException:
        # Keep the 400 above from being rewrapped as a 500.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"CLIP Image extraction error: {str(e)}") from e
# ==========================================
# Endpoint 3: text feature extraction for text-to-image search (unchanged)
# ==========================================
@app.post("/api/v1/extract/text")
async def extract_text_features(text: str = Form(...)):
    """Compute the L2-normalized CLIP text embedding of a query string.

    Args:
        text: The query text, submitted as a form field.

    Returns:
        A dict with the original ``text`` and ``text_embedding``, the
        normalized embedding as a list.

    Raises:
        HTTPException: 500 when the text session or tokenizer is not loaded,
            or inference fails unexpectedly; 400 when ``text`` is empty or
            whitespace only.
    """
    if clip_text_session is None or clip_tokenizer is None:
        raise HTTPException(
            status_code=500, detail="Text model is not initialized.")

    # Validate outside the try block so this client error is returned as a
    # 400 instead of being rewrapped as a 500 by the generic handler.
    if not text.strip():
        raise HTTPException(
            status_code=400, detail="Text query cannot be empty.")

    try:
        # Tokenize with the local tokenizer, padded/truncated to a fixed
        # length of 77 and returned as NumPy arrays.
        tokenized = clip_tokenizer(
            [text],
            padding='max_length',
            truncation=True,
            max_length=77,
            return_tensors='np'
        )
        text_tokens = tokenized['input_ids']
        # Run the text encoder.
        text_embedding = clip_text_session.run(None, {"text": text_tokens})[0]
        # L2 normalization along the last axis.
        text_embedding = text_embedding / \
            np.linalg.norm(text_embedding, axis=-1, keepdims=True)
        return {
            "text": text,
            "text_embedding": text_embedding[0].tolist()
        }
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Text extraction error: {str(e)}") from e