vinithius commited on
Commit
64f585d
·
verified ·
1 Parent(s): 94cf5c4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -136
app.py CHANGED
@@ -1,144 +1,44 @@
1
- # app.py
2
- import os
3
- import io
4
- import asyncio
5
- import numpy as np
6
- from fastapi import FastAPI, File, UploadFile, HTTPException
7
- from fastapi.responses import JSONResponse
8
- from PIL import Image
9
  import torch
 
10
  from transformers import AutoImageProcessor, AutoModel
11
- from typing import Optional
12
-
13
- app = FastAPI(title="DINOv2 Image Embedding API")
14
-
15
- # Configurações — altere MODEL_REPO se quiser outra variante
16
- MODEL_REPO = os.environ.get("MODEL_REPO", "facebook/dinov2-small")
17
- HF_TOKEN = os.environ.get("HF_TOKEN", None) # se repo privado
18
- DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
-
20
- # Limite de inferências concorrentes (ajuste via var de ambiente se quiser)
21
- MAX_CONCURRENT = int(os.environ.get("MAX_CONCURRENT", "4"))
22
- inference_semaphore = asyncio.Semaphore(MAX_CONCURRENT)
23
-
24
- # Globals
25
- model = None
26
- processor = None
27
-
28
- def load_model_and_processor():
29
- """
30
- Carrega AutoImageProcessor e AutoModel do Hugging Face Hub.
31
- Usa AutoModel (gera embeddings via pooler_output ou CLS token).
32
- """
33
- global model, processor
34
- if model is not None and processor is not None:
35
- return
36
-
37
- # Opções de auth
38
- use_auth = True if HF_TOKEN else False
39
- auth = HF_TOKEN if HF_TOKEN else None
40
-
41
- print(f"Loading processor and model from: {MODEL_REPO} (device={DEVICE})")
42
- # Carrega processor (pré-processamento oficial do repositório)
43
- processor = AutoImageProcessor.from_pretrained(MODEL_REPO, use_auth_token=auth) if use_auth else AutoImageProcessor.from_pretrained(MODEL_REPO)
44
- # Carrega o modelo base (sem cabeça de classificação explicita)
45
- model = AutoModel.from_pretrained(MODEL_REPO, use_auth_token=auth) if use_auth else AutoModel.from_pretrained(MODEL_REPO)
46
- model.to(DEVICE)
47
- model.eval()
48
- # imprimir dimensão de saída (útil para debug)
49
- try:
50
- hidden_size = model.config.hidden_size
51
- print(f"Model loaded. hidden_size = {hidden_size}")
52
- except Exception:
53
- print("Model loaded. (no hidden_size in config)")
54
-
55
- def extract_embedding_from_outputs(outputs):
56
- """
57
- Tenta extrair um embedding a partir da saída do AutoModel:
58
- - usa pooler_output se disponível
59
- - senão usa last_hidden_state[:, 0, :] (token CLS)
60
- """
61
- if hasattr(outputs, "pooler_output") and outputs.pooler_output is not None:
62
- emb = outputs.pooler_output
63
- elif hasattr(outputs, "last_hidden_state"):
64
- emb = outputs.last_hidden_state[:, 0, :] # CLS token
65
- else:
66
- # fallback: pegar o primeiro tensor qualquer
67
- if isinstance(outputs, (tuple, list)):
68
- out = outputs[0]
69
- emb = out[:, 0, :]
70
- else:
71
- raise RuntimeError("Não foi possível extrair embedding das saídas do modelo.")
72
- return emb
73
 
74
- def preprocess_with_processor(pil_image: Image.Image):
75
- """
76
- Usa o AutoImageProcessor para pré-processar a PIL image em tensores PyTorch.
77
- Retorna dict com tensores enviados ao device.
78
- """
79
- # processor aceita uma lista de imagens
80
- inputs = processor(images=pil_image, return_tensors="pt")
81
- # mover tensores para device
82
- for k, v in inputs.items():
83
- inputs[k] = v.to(DEVICE)
84
- return inputs
85
 
86
- async def run_inference(pil_image: Image.Image) -> np.ndarray:
87
- """
88
- Executa inferência em thread pool (para não bloquear o loop do FastAPI).
89
- Retorna um vetor numpy 1D (embedding L2-normalizado).
90
- """
91
- loop = asyncio.get_running_loop()
92
- return await loop.run_in_executor(None, _sync_inference, pil_image)
93
 
94
- def _sync_inference(pil_image: Image.Image) -> np.ndarray:
95
- """
96
- Função síncrona que faz preprocess, forward e extrai embedding.
97
- """
98
- global model, processor
99
- if model is None or processor is None:
100
- load_model_and_processor()
101
 
102
- inputs = preprocess_with_processor(pil_image)
 
 
 
103
  with torch.no_grad():
104
  outputs = model(**inputs)
105
- emb_tensor = extract_embedding_from_outputs(outputs) # shape (1, dim)
106
- emb = emb_tensor.cpu().numpy().reshape(-1)
107
-
108
- # Normalizar L2
109
- norm = np.linalg.norm(emb)
110
- if norm > 0:
111
- emb = emb / norm
112
- return emb.astype(float)
113
-
114
- @app.post("/embed")
115
- async def embed_image(file: UploadFile = File(...)):
116
- # Proteção básica: tipo e tamanho máximo (ex: 6 MB)
117
- if not file.content_type.startswith("image/"):
118
- raise HTTPException(status_code=400, detail="Envie um arquivo de imagem.")
119
- content = await file.read()
120
- if len(content) > (6 * 1024 * 1024):
121
- raise HTTPException(status_code=413, detail="Arquivo muito grande (max 6MB).")
122
-
123
- try:
124
- pil_img = Image.open(io.BytesIO(content)).convert("RGB")
125
- except Exception:
126
- raise HTTPException(status_code=400, detail="Imagem inválida.")
127
-
128
- # Controle de concorrência
129
- async with inference_semaphore:
130
- try:
131
- emb = await run_inference(pil_img)
132
- except Exception as e:
133
- raise HTTPException(status_code=500, detail=f"Erro durante inferência: {e}")
134
-
135
- # Retorna embedding (lista de floats)
136
- return JSONResponse({"embedding": emb.tolist(), "dim": len(emb)})
137
-
138
- @app.get("/healthz")
139
- async def health():
140
- loaded = model is not None and processor is not None
141
- return {"status": "ok", "model_loaded": loaded, "model_repo": MODEL_REPO}
142
-
143
- # Ao iniciar em runtime, podemos preparar o model (opcional)
144
- # load_model_and_processor()
 
 
 
 
 
 
 
 
 
1
  import torch
2
+ from torch import nn
3
  from transformers import AutoImageProcessor, AutoModel
4
+ import gradio as gr
5
+ import numpy as np
6
+ from PIL import Image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
+ # Nome do modelo no Hugging Face Hub
9
+ MODEL_NAME = "facebook/dinov2-small"
 
 
 
 
 
 
 
 
 
10
 
11
+ # Carregando processador e modelo
12
+ processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
13
+ model = AutoModel.from_pretrained(MODEL_NAME)
 
 
 
 
14
 
15
+ # Projeção para 512D (caso a saída seja >512, reduzimos)
16
+ projection = nn.Linear(model.config.hidden_size, 512)
 
 
 
 
 
17
 
18
+ def get_embedding(image: Image.Image):
19
+ # Preprocessamento
20
+ inputs = processor(images=image, return_tensors="pt")
21
+
22
  with torch.no_grad():
23
  outputs = model(**inputs)
24
+ # Usando o CLS token como embedding da imagem
25
+ last_hidden_state = outputs.last_hidden_state # (batch, seq_len, hidden)
26
+ embedding = last_hidden_state[:, 0] # pegando o [CLS] token
27
+
28
+ # Projeta para 512D
29
+ embedding_512 = projection(embedding)
30
+
31
+ # Converte para lista Python
32
+ return embedding_512.squeeze().tolist()
33
+
34
+ # Cria API com Gradio (sem interface visual, apenas endpoint)
35
+ iface = gr.Interface(
36
+ fn=get_embedding,
37
+ inputs=gr.Image(type="pil"),
38
+ outputs=gr.JSON(),
39
+ live=False,
40
+ api_name="embed" # endpoint em /embed
41
+ )
42
+
43
+ if __name__ == "__main__":
44
+ iface.launch()