File size: 2,427 Bytes
2b5181c 5c69629 2b5181c 5ec0dd3 2b5181c 5c69629 5ec0dd3 5c69629 5ec0dd3 5c69629 2b5181c 5c69629 2b5181c 5ec0dd3 2b5181c 5ec0dd3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 |
import os
import torch
import numpy as np
from PIL import Image
from transformers import ConvNextImageProcessor, ConvNextForImageClassification
from rembg import remove
#FUNÇÃO AUXILIAR DE SEGMENTAÇÃO (REUTILIZÁVEL)
def segment_image(pil_image: Image.Image) -> Image.Image:
"""
Recebe uma imagem PIL, remove o fundo e coloca fundo preto.
Retorna a imagem PIL tratada.
"""
try:
#1. Remove o fundo (Rembg)
img_no_bg = remove(pil_image)
#2. Composição em fundo preto
#Cria uma imagem totalmente preta do mesmo tamanho
fundo_preto = Image.new("RGB", img_no_bg.size, (0, 0, 0))
#Usa o canal Alpha como máscara
if img_no_bg.mode == 'RGBA':
mask = img_no_bg.split()[3]
fundo_preto.paste(img_no_bg, mask=mask)
return fundo_preto
else:
return img_no_bg.convert("RGB")
except Exception as e:
print(f"AVISO: Falha na segmentação ({e}). Retornando original.")
return pil_image.convert("RGB")
#CLASSE EXTRATORA
class FeatureExtractor:
def __init__(self, device=None):
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
print(f"Usando dispositivo: {self.device}")
self.processor = ConvNextImageProcessor.from_pretrained(
"facebook/convnext-large-224-22k-1k"
)
self.model = ConvNextForImageClassification.from_pretrained(
"facebook/convnext-large-224-22k-1k"
).to(self.device)
self.model.classifier = torch.nn.Identity()
self.model.eval()
def extract_convnext(self, image_path: str) -> np.ndarray:
print(f"Processando imagem: {os.path.basename(image_path)}")
input_img = Image.open(image_path).convert("RGB")
#1. CHAMA A FUNÇÃO DE SEGMENTAÇÃO
final_image = segment_image(input_img)
#2. Passa para o ConvNeXt
inputs = self.processor(final_image, return_tensors="pt").to(self.device)
with torch.no_grad():
features = self.model(**inputs).logits
features_np = features.cpu().numpy().flatten()
return features_np
def process_single_image(image_path: str):
extractor = FeatureExtractor()
return extractor.extract_convnext(image_path) |