soja-api / feature_extractor_single.py
Ju-Am's picture
Add endpoint de debug pra visualizar img segmentada.
5ec0dd3
import os
import torch
import numpy as np
from PIL import Image
from transformers import ConvNextImageProcessor, ConvNextForImageClassification
from rembg import remove
#FUNÇÃO AUXILIAR DE SEGMENTAÇÃO (REUTILIZÁVEL)
def segment_image(pil_image: Image.Image) -> Image.Image:
"""
Recebe uma imagem PIL, remove o fundo e coloca fundo preto.
Retorna a imagem PIL tratada.
"""
try:
#1. Remove o fundo (Rembg)
img_no_bg = remove(pil_image)
#2. Composição em fundo preto
#Cria uma imagem totalmente preta do mesmo tamanho
fundo_preto = Image.new("RGB", img_no_bg.size, (0, 0, 0))
#Usa o canal Alpha como máscara
if img_no_bg.mode == 'RGBA':
mask = img_no_bg.split()[3]
fundo_preto.paste(img_no_bg, mask=mask)
return fundo_preto
else:
return img_no_bg.convert("RGB")
except Exception as e:
print(f"AVISO: Falha na segmentação ({e}). Retornando original.")
return pil_image.convert("RGB")
#CLASSE EXTRATORA
class FeatureExtractor:
def __init__(self, device=None):
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
print(f"Usando dispositivo: {self.device}")
self.processor = ConvNextImageProcessor.from_pretrained(
"facebook/convnext-large-224-22k-1k"
)
self.model = ConvNextForImageClassification.from_pretrained(
"facebook/convnext-large-224-22k-1k"
).to(self.device)
self.model.classifier = torch.nn.Identity()
self.model.eval()
def extract_convnext(self, image_path: str) -> np.ndarray:
print(f"Processando imagem: {os.path.basename(image_path)}")
input_img = Image.open(image_path).convert("RGB")
#1. CHAMA A FUNÇÃO DE SEGMENTAÇÃO
final_image = segment_image(input_img)
#2. Passa para o ConvNeXt
inputs = self.processor(final_image, return_tensors="pt").to(self.device)
with torch.no_grad():
features = self.model(**inputs).logits
features_np = features.cpu().numpy().flatten()
return features_np
def process_single_image(image_path: str):
extractor = FeatureExtractor()
return extractor.extract_convnext(image_path)