Ju-Am commited on
Commit
5ec0dd3
·
1 Parent(s): 5c69629

Add endpoint de debug pra visualizar img segmentada.

Browse files
Files changed (2) hide show
  1. feature_extractor_single.py +33 -44
  2. main.py +31 -2
feature_extractor_single.py CHANGED
@@ -2,17 +2,42 @@ import os
2
  import torch
3
  import numpy as np
4
  from PIL import Image
5
- from torchvision import transforms
6
  from transformers import ConvNextImageProcessor, ConvNextForImageClassification
7
  from rembg import remove
8
 
9
- #Classe para extração de features (ConvNeXt)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  class FeatureExtractor:
11
  def __init__(self, device=None):
12
  self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
13
  print(f"Usando dispositivo: {self.device}")
14
 
15
- # Modelo e pré-processador
16
  self.processor = ConvNextImageProcessor.from_pretrained(
17
  "facebook/convnext-large-224-22k-1k"
18
  )
@@ -20,61 +45,25 @@ class FeatureExtractor:
20
  "facebook/convnext-large-224-22k-1k"
21
  ).to(self.device)
22
 
23
- #Remove camada de classificação (ficam só as features)
24
  self.model.classifier = torch.nn.Identity()
25
  self.model.eval()
26
 
27
  def extract_convnext(self, image_path: str) -> np.ndarray:
28
- #1. Abre a imagem original
29
  print(f"Processando imagem: {os.path.basename(image_path)}")
30
  input_img = Image.open(image_path).convert("RGB")
31
 
32
- #2. SEGMENTAÇÃO (remoção do fundo)
33
- #O rembg devolve uma imagem RGBA (com transparência)
34
- try:
35
- img_no_bg = remove(input_img)
36
-
37
- #3. COMPOSIÇÃO EM FUNDO PRETO
38
- #Cria uma imagem totalmente preta do mesmo tamanho
39
- fundo_preto = Image.new("RGB", img_no_bg.size, (0, 0, 0))
40
-
41
- #Usa o canal Alpha (transparência) da imagem recortada como máscara
42
- #Onde for folha, cola a folha. Onde for transparente, mantém o preto.
43
- mask = img_no_bg.split()[3] #Pega o 4º canal (Alpha)
44
- fundo_preto.paste(img_no_bg, mask=mask)
45
-
46
- final_image = fundo_preto
47
- print("Fundo removido e substituído por preto com sucesso.")
48
-
49
- except Exception as e:
50
- print(f"AVISO: Falha na segmentação ({e}). Usando imagem original.")
51
- final_image = input_img
52
 
53
- #4. Passa para o ConvNeXt (que já faz o resize e normalize internamente)
54
  inputs = self.processor(final_image, return_tensors="pt").to(self.device)
55
 
56
  with torch.no_grad():
57
  features = self.model(**inputs).logits
58
 
59
  features_np = features.cpu().numpy().flatten()
60
- print(f"Vetor de características extraído com shape: {features_np.shape}")
61
-
62
  return features_np
63
 
64
- #Função principal chamada pelo main.py
65
- def process_single_image(image_path: str, output_dir: str = "processed"):
66
- """
67
- Pipeline: Segmentação (Rembg) -> Fundo preto -> ConvNeXt
68
- """
69
  extractor = FeatureExtractor()
70
- features = extractor.extract_convnext(image_path)
71
- return features
72
-
73
- #Execução direta para testes locais
74
- if __name__ == "__main__":
75
- #Teste com uma imagem local
76
- image_path = "teste_folha.jpg" #Mudar para um arquivo real se for testar
77
- if os.path.exists(image_path):
78
- process_single_image(image_path)
79
- else:
80
- print("Imagem de teste não encontrada.")
 
2
  import torch
3
  import numpy as np
4
  from PIL import Image
 
5
  from transformers import ConvNextImageProcessor, ConvNextForImageClassification
6
  from rembg import remove
7
 
8
+ #FUNÇÃO AUXILIAR DE SEGMENTAÇÃO (REUTILIZÁVEL)
9
+ def segment_image(pil_image: Image.Image) -> Image.Image:
10
+ """
11
+ Recebe uma imagem PIL, remove o fundo e coloca fundo preto.
12
+ Retorna a imagem PIL tratada.
13
+ """
14
+ try:
15
+ #1. Remove o fundo (Rembg)
16
+ img_no_bg = remove(pil_image)
17
+
18
+ #2. Composição em fundo preto
19
+ #Cria uma imagem totalmente preta do mesmo tamanho
20
+ fundo_preto = Image.new("RGB", img_no_bg.size, (0, 0, 0))
21
+
22
+ #Usa o canal Alpha como máscara
23
+ if img_no_bg.mode == 'RGBA':
24
+ mask = img_no_bg.split()[3]
25
+ fundo_preto.paste(img_no_bg, mask=mask)
26
+ return fundo_preto
27
+ else:
28
+ return img_no_bg.convert("RGB")
29
+
30
+ except Exception as e:
31
+ print(f"AVISO: Falha na segmentação ({e}). Retornando original.")
32
+ return pil_image.convert("RGB")
33
+
34
+
35
+ #CLASSE EXTRATORA
36
  class FeatureExtractor:
37
  def __init__(self, device=None):
38
  self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
39
  print(f"Usando dispositivo: {self.device}")
40
 
 
41
  self.processor = ConvNextImageProcessor.from_pretrained(
42
  "facebook/convnext-large-224-22k-1k"
43
  )
 
45
  "facebook/convnext-large-224-22k-1k"
46
  ).to(self.device)
47
 
 
48
  self.model.classifier = torch.nn.Identity()
49
  self.model.eval()
50
 
51
  def extract_convnext(self, image_path: str) -> np.ndarray:
 
52
  print(f"Processando imagem: {os.path.basename(image_path)}")
53
  input_img = Image.open(image_path).convert("RGB")
54
 
55
+ #1. CHAMA A FUNÇÃO DE SEGMENTAÇÃO
56
+ final_image = segment_image(input_img)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
+ #2. Passa para o ConvNeXt
59
  inputs = self.processor(final_image, return_tensors="pt").to(self.device)
60
 
61
  with torch.no_grad():
62
  features = self.model(**inputs).logits
63
 
64
  features_np = features.cpu().numpy().flatten()
 
 
65
  return features_np
66
 
67
+ def process_single_image(image_path: str):
 
 
 
 
68
  extractor = FeatureExtractor()
69
+ return extractor.extract_convnext(image_path)
 
 
 
 
 
 
 
 
 
 
main.py CHANGED
@@ -5,9 +5,12 @@ import json
5
  import numpy as np
6
  from fastapi import FastAPI, File, UploadFile, Depends, HTTPException, status
7
  from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
8
- from feature_extractor_single import process_single_image
9
  from datetime import datetime
10
  import unicodedata
 
 
 
 
11
 
12
  def normalize_string(s: str) -> str:
13
  """
@@ -164,4 +167,30 @@ async def extract_features(file: UploadFile = File(...), token: str = Depends(ve
164
  os.remove(temp_path)
165
 
166
  features_list = features_array.tolist()
167
- return {"features": features_list}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  import numpy as np
6
  from fastapi import FastAPI, File, UploadFile, Depends, HTTPException, status
7
  from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
 
8
  from datetime import datetime
9
  import unicodedata
10
+ import io
11
+ from PIL import Image
12
+ from starlette.responses import StreamingResponse
13
+ from feature_extractor_single import process_single_image, segment_image
14
 
15
  def normalize_string(s: str) -> str:
16
  """
 
167
  os.remove(temp_path)
168
 
169
  features_list = features_array.tolist()
170
+ return {"features": features_list}
171
+
172
+ @app.post("/debug/view_segmentation/")
173
+ async def view_segmentation(file: UploadFile = File(...), token: str = Depends(verify_token)):
174
+ """
175
+ Endpoint de debug.
176
+ Retorna a imagem processada (fundo preto) para verificação visual.
177
+ Útil para saber o que o modelo está "enxergando".
178
+ """
179
+ try:
180
+ #1. Lê a imagem da memória (sem salvar no disco pra ser rápido)
181
+ contents = await file.read()
182
+ pil_image = Image.open(io.BytesIO(contents)).convert("RGB")
183
+
184
+ #2. Aplica a mesma lógica de segmentação do modelo
185
+ processed_image = segment_image(pil_image)
186
+
187
+ #3. Salva a imagem processada em um buffer de memória (bytes)
188
+ img_byte_arr = io.BytesIO()
189
+ processed_image.save(img_byte_arr, format='JPEG', quality=95)
190
+ img_byte_arr.seek(0)
191
+
192
+ #4. Retorna como uma stream de imagem (O navegador/Swagger exibe isso!)
193
+ return StreamingResponse(img_byte_arr, media_type="image/jpeg")
194
+
195
+ except Exception as e:
196
+ raise HTTPException(status_code=500, detail=f"Erro ao processar imagem: {e}")