|
|
import os
|
|
|
import torch
|
|
|
import numpy as np
|
|
|
from PIL import Image
|
|
|
from transformers import ConvNextImageProcessor, ConvNextForImageClassification
|
|
|
from rembg import remove
|
|
|
|
|
|
|
|
|
def segment_image(pil_image: Image.Image) -> Image.Image:
    """Strip the background from a PIL image and put it on a black canvas.

    The image is handed to ``remove``. When the result is in RGBA mode it is
    pasted onto a black RGB canvas of the same size, using its fourth band
    as the paste mask. A result in any other mode is converted straight to
    RGB.

    The operation is best-effort: if any step raises, a warning is printed
    and the untouched input, converted to RGB, is returned instead.

    Args:
        pil_image: Image to segment.

    Returns:
        The processed image in RGB mode, or the RGB-converted input when
        segmentation fails.
    """
    try:
        cutout = remove(pil_image)
        black_canvas = Image.new("RGB", cutout.size, (0, 0, 0))

        if cutout.mode != "RGBA":
            # No alpha band available to composite with.
            return cutout.convert("RGB")

        # Band index 3 of an RGBA image is the alpha channel; used as mask
        # so only the foreground pixels land on the black canvas.
        alpha_band = cutout.split()[3]
        black_canvas.paste(cutout, mask=alpha_band)
        return black_canvas
    except Exception as exc:
        print(f"AVISO: Falha na segmentação ({exc}). Retornando original.")
        return pil_image.convert("RGB")
|
|
|
|
|
|
|
|
|
|
|
|
class FeatureExtractor:
    """Extract ConvNeXt feature vectors from background-segmented images.

    The pretrained classification model is loaded once per instance and its
    ``classifier`` head is replaced by an identity layer, so the ``logits``
    field of the model output carries the classifier's input rather than
    class scores.
    """

    # Checkpoint used when the caller does not pick one.
    DEFAULT_MODEL_NAME = "facebook/convnext-large-224-22k-1k"

    def __init__(self, device=None, model_name=DEFAULT_MODEL_NAME):
        """Load the image processor and the model onto the chosen device.

        Args:
            device: Device to run on. When falsy, "cuda" is used if
                ``torch.cuda.is_available()`` and "cpu" otherwise.
            model_name: Checkpoint identifier passed to ``from_pretrained``
                for both the processor and the model.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Usando dispositivo: {self.device}")

        self.processor = ConvNextImageProcessor.from_pretrained(model_name)
        self.model = ConvNextForImageClassification.from_pretrained(
            model_name
        ).to(self.device)

        # Drop the classification head so the forward pass yields features.
        self.model.classifier = torch.nn.Identity()
        self.model.eval()

    def extract_convnext(self, image_path: str) -> np.ndarray:
        """Return the feature vector of the image stored at ``image_path``.

        The image is loaded as RGB, passed through ``segment_image``,
        preprocessed and run through the model without gradient tracking.

        Args:
            image_path: Path of the image file to process.

        Returns:
            The model output flattened into a 1-D NumPy array.
        """
        print(f"Processando imagem: {os.path.basename(image_path)}")

        # The context manager closes the underlying file deterministically;
        # convert() yields an in-memory image that outlives the handle.
        with Image.open(image_path) as source:
            input_img = source.convert("RGB")

        final_image = segment_image(input_img)

        inputs = self.processor(final_image, return_tensors="pt").to(self.device)

        with torch.no_grad():
            features = self.model(**inputs).logits

        return features.cpu().numpy().flatten()
|
|
|
|
|
|
# Extractor shared by every process_single_image call; created on first use
# so importing this module does not load the model.
_shared_extractor = None


def process_single_image(image_path: str):
    """Extract the feature vector for a single image file.

    A ``FeatureExtractor`` with default settings is created on the first
    call and reused afterwards, so the pretrained weights are not reloaded
    for every image.

    Args:
        image_path: Path of the image file to process.

    Returns:
        Whatever ``FeatureExtractor.extract_convnext`` returns for the
        given path.
    """
    global _shared_extractor
    if _shared_extractor is None:
        _shared_extractor = FeatureExtractor()
    return _shared_extractor.extract_convnext(image_path)