| import os
|
| import torch
|
| import numpy as np
|
| from PIL import Image
|
| from transformers import ConvNextImageProcessor, ConvNextForImageClassification
|
| from rembg import remove
|
|
|
|
|
def segment_image(pil_image: Image.Image) -> Image.Image:
    """Remove the background of a PIL image and place the subject on black.

    The image is passed to ``remove``; when the result is in RGBA mode it is
    pasted onto a black RGB canvas of the same size, using its fourth band
    as the paste mask. A result in any other mode is simply converted to
    RGB.

    Args:
        pil_image: The image to segment.

    Returns:
        An RGB image. If any step raises, a warning is printed and the
        original image converted to RGB is returned instead.
    """
    try:
        cutout = remove(pil_image)
        black_canvas = Image.new("RGB", cutout.size, (0, 0, 0))

        # Without an alpha band there is nothing to composite.
        if cutout.mode != 'RGBA':
            return cutout.convert("RGB")

        # The fourth band of an RGBA image is its alpha channel.
        _red, _green, _blue, alpha = cutout.split()
        black_canvas.paste(cutout, mask=alpha)
        return black_canvas
    except Exception as e:
        # Best-effort: a failed segmentation must not abort the caller.
        print(f"AVISO: Falha na segmentação ({e}). Retornando original.")
        return pil_image.convert("RGB")
|
|
|
|
|
|
|
class FeatureExtractor:
    """Extract ConvNeXt feature vectors from background-segmented images.

    The pretrained model's ``classifier`` module is swapped for
    ``torch.nn.Identity`` so that the ``logits`` field of the model output
    carries the classifier's input unchanged instead of class scores.
    """

    # Checkpoint loaded when the caller does not ask for a different one.
    DEFAULT_MODEL_NAME = "facebook/convnext-large-224-22k-1k"

    def __init__(self, device=None, model_name=None):
        """Load the image processor and model and move the model to a device.

        Args:
            device: Device passed to ``.to()``. When falsy, ``"cuda"`` is
                used if ``torch.cuda.is_available()`` and ``"cpu"`` otherwise.
            model_name: Checkpoint identifier handed to ``from_pretrained``
                for both the processor and the model. When falsy,
                ``DEFAULT_MODEL_NAME`` is used.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Usando dispositivo: {self.device}")

        # One name drives both loads so processor and model cannot drift apart.
        self.model_name = model_name or self.DEFAULT_MODEL_NAME
        self.processor = ConvNextImageProcessor.from_pretrained(self.model_name)
        self.model = ConvNextForImageClassification.from_pretrained(
            self.model_name
        ).to(self.device)

        # Drop the classification head and switch to evaluation mode.
        self.model.classifier = torch.nn.Identity()
        self.model.eval()

    def extract_convnext(self, image_path: str) -> np.ndarray:
        """Return the flattened feature vector for the image at ``image_path``.

        The file is opened, converted to RGB, passed through
        ``segment_image``, preprocessed by the image processor and run
        through the model with gradient tracking disabled.

        Args:
            image_path: Path of the image file to process.

        Returns:
            A one-dimensional NumPy array holding the model output.
        """
        print(f"Processando imagem: {os.path.basename(image_path)}")

        # Close the file explicitly; ``convert`` returns an independent image
        # that stays usable after the ``with`` block ends.
        with Image.open(image_path) as source_img:
            input_img = source_img.convert("RGB")

        final_image = segment_image(input_img)

        inputs = self.processor(final_image, return_tensors="pt").to(self.device)

        with torch.no_grad():
            features = self.model(**inputs).logits

        # Move to host memory before converting to NumPy.
        return features.cpu().numpy().flatten()
|
|
|
def process_single_image(
    image_path: str,
    extractor: "FeatureExtractor | None" = None,
) -> np.ndarray:
    """Extract the feature vector of a single image.

    Args:
        image_path: Path of the image file to process.
        extractor: Object exposing ``extract_convnext(image_path)``. Pass an
            existing ``FeatureExtractor`` to reuse its loaded model across
            calls. When ``None``, a new ``FeatureExtractor`` is built for
            this call.

    Returns:
        Whatever ``extractor.extract_convnext(image_path)`` returns.
    """
    if extractor is None:
        # Building an extractor loads the pretrained model, so callers that
        # process many images should pass one in instead.
        extractor = FeatureExtractor()
    return extractor.extract_convnext(image_path)