vinithius committed on
Commit
5b2a441
·
verified ·
1 Parent(s): 7e71744

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +144 -0
app.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import os
3
+ import io
4
+ import asyncio
5
+ import numpy as np
6
+ from fastapi import FastAPI, File, UploadFile, HTTPException
7
+ from fastapi.responses import JSONResponse
8
+ from PIL import Image
9
+ import torch
10
+ from transformers import AutoImageProcessor, AutoModel
11
+ from typing import Optional
12
+
13
app = FastAPI(title="DINOv2 Image Embedding API")

# Settings — set the MODEL_REPO environment variable to serve another variant.
MODEL_REPO = os.getenv("MODEL_REPO", "facebook/dinov2-small")
# Hub access token; only needed when MODEL_REPO is a private repository.
HF_TOKEN = os.getenv("HF_TOKEN")
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Upper bound on concurrent inferences (tunable via the MAX_CONCURRENT env var).
MAX_CONCURRENT = int(os.getenv("MAX_CONCURRENT", "4"))
inference_semaphore = asyncio.Semaphore(value=MAX_CONCURRENT)

# Module-level singletons, populated by load_model_and_processor().
model = None
processor = None
27
+
28
def load_model_and_processor():
    """
    Load the AutoImageProcessor and AutoModel for MODEL_REPO from the Hub.

    The loaded objects are stored in the module-level ``model`` and
    ``processor`` globals. The call is a no-op when both are already set.
    When HF_TOKEN is set it is forwarded as the Hub access token (needed for
    private repositories); otherwise the default credentials are used.

    AutoModel is the bare backbone (no classification head); embeddings are
    later taken from ``pooler_output`` or the CLS token.
    """
    global model, processor
    if model is not None and processor is not None:
        return

    print(f"Loading processor and model from: {MODEL_REPO} (device={DEVICE})")

    # `token` supersedes the deprecated `use_auth_token` keyword. An empty
    # HF_TOKEN is normalised to None, which is the library default.
    token = HF_TOKEN or None
    # Official preprocessing shipped with the repository.
    loaded_processor = AutoImageProcessor.from_pretrained(MODEL_REPO, token=token)
    loaded_model = AutoModel.from_pretrained(MODEL_REPO, token=token)
    loaded_model.to(DEVICE)
    loaded_model.eval()

    # Publish the globals only after the model is on DEVICE and in eval mode,
    # so a concurrent caller that finds them set never sees a half-prepared
    # model.
    processor = loaded_processor
    model = loaded_model

    # Report the output dimension (handy when debugging).
    try:
        hidden_size = model.config.hidden_size
    except AttributeError:
        print("Model loaded. (no hidden_size in config)")
    else:
        print(f"Model loaded. hidden_size = {hidden_size}")
54
+
55
def extract_embedding_from_outputs(outputs):
    """
    Pick an embedding tensor out of a model forward-pass result.

    Resolution order:
      1. ``outputs.pooler_output`` when the attribute exists and is not None;
      2. ``outputs.last_hidden_state[:, 0, :]`` (the first token, i.e. CLS);
      3. for a plain tuple/list, ``outputs[0][:, 0, :]``.

    Raises:
        RuntimeError: if none of the above applies.
    """
    pooled = getattr(outputs, "pooler_output", None)
    if pooled is not None:
        return pooled

    if hasattr(outputs, "last_hidden_state"):
        # CLS token
        return outputs.last_hidden_state[:, 0, :]

    # Last resort: treat the first element of a raw sequence as the hidden states.
    if not isinstance(outputs, (tuple, list)):
        raise RuntimeError("Não foi possível extrair embedding das saídas do modelo.")
    first = outputs[0]
    return first[:, 0, :]
73
+
74
def preprocess_with_processor(pil_image: Image.Image):
    """
    Run the global ``processor`` on a PIL image, requesting PyTorch tensors.

    Returns the processor's output mapping with every value moved to DEVICE.
    """
    batch = processor(images=pil_image, return_tensors="pt")
    # Snapshot the items first, then overwrite each entry with its on-device copy.
    for key, tensor in list(batch.items()):
        batch[key] = tensor.to(DEVICE)
    return batch
85
+
86
async def run_inference(pil_image: Image.Image) -> np.ndarray:
    """
    Offload ``_sync_inference`` to the event loop's default executor.

    Running the forward pass in the executor keeps the event loop free while
    it runs. Returns whatever ``_sync_inference`` returns for ``pil_image``
    (a 1-D L2-normalized embedding).
    """
    pending = asyncio.get_running_loop().run_in_executor(
        None, _sync_inference, pil_image
    )
    return await pending
93
+
94
def _sync_inference(pil_image: Image.Image) -> np.ndarray:
    """
    Blocking pipeline: preprocess, forward pass, embedding extraction.

    Loads the model and processor first if either global is still unset.
    Returns the embedding flattened to 1-D, divided by its L2 norm when that
    norm is positive, and cast with ``astype(float)``.
    """
    if processor is None or model is None:
        load_model_and_processor()

    model_inputs = preprocess_with_processor(pil_image)
    with torch.no_grad():
        # Expected shape of the extracted tensor: (1, dim).
        emb_tensor = extract_embedding_from_outputs(model(**model_inputs))
        vector = emb_tensor.cpu().numpy().reshape(-1)

    # L2 normalization; a zero vector is returned unchanged.
    length = np.linalg.norm(vector)
    if length > 0:
        vector = vector / length
    return vector.astype(float)
113
+
114
@app.post("/embed")
async def embed_image(file: UploadFile = File(...)):
    """
    Compute the embedding of an uploaded image.

    Responds with JSON ``{"embedding": [...], "dim": <len>}``.

    Raises:
        HTTPException 400: the upload is not declared as an image, or the
            bytes cannot be decoded as one.
        HTTPException 413: the upload is larger than 6 MB.
        HTTPException 500: inference failed.
    """
    max_bytes = 6 * 1024 * 1024

    # Basic protection: declared type and maximum size.
    # content_type may be None (e.g. the upload part carries no Content-Type),
    # so fall back to "" instead of crashing with AttributeError.
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="Envie um arquivo de imagem.")

    # Read at most one byte past the limit: enough to detect an oversized
    # upload without buffering the whole payload in memory first.
    content = await file.read(max_bytes + 1)
    if len(content) > max_bytes:
        raise HTTPException(status_code=413, detail="Arquivo muito grande (max 6MB).")

    try:
        pil_img = Image.open(io.BytesIO(content)).convert("RGB")
    except Exception as exc:
        raise HTTPException(status_code=400, detail="Imagem inválida.") from exc

    # Concurrency control: the semaphore caps simultaneous inferences.
    async with inference_semaphore:
        try:
            emb = await run_inference(pil_img)
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Erro durante inferência: {e}") from e

    # Return the embedding as a list of floats.
    return JSONResponse({"embedding": emb.tolist(), "dim": len(emb)})
137
+
138
@app.get("/healthz")
async def health():
    """Report liveness, whether the model and processor are loaded, and the configured repo."""
    is_loaded = not (model is None or processor is None)
    return {"status": "ok", "model_loaded": is_loaded, "model_repo": MODEL_REPO}
142
+
143
+ # Ao iniciar em runtime, podemos preparar o model (opcional)
144
+ # load_model_and_processor()