File size: 2,427 Bytes
2b5181c
 
 
 
 
5c69629
2b5181c
5ec0dd3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2b5181c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5c69629
 
 
5ec0dd3
 
5c69629
5ec0dd3
5c69629
 
2b5181c
 
5c69629
2b5181c
 
 
5ec0dd3
2b5181c
5ec0dd3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import os
import torch
import numpy as np
from PIL import Image
from transformers import ConvNextImageProcessor, ConvNextForImageClassification
from rembg import remove

#FUNÇÃO AUXILIAR DE SEGMENTAÇÃO (REUTILIZÁVEL)
def segment_image(pil_image: Image.Image) -> Image.Image:
    """

    Recebe uma imagem PIL, remove o fundo e coloca fundo preto.

    Retorna a imagem PIL tratada.

    """
    try:
        #1. Remove o fundo (Rembg)
        img_no_bg = remove(pil_image)
        
        #2. Composição em fundo preto
        #Cria uma imagem totalmente preta do mesmo tamanho
        fundo_preto = Image.new("RGB", img_no_bg.size, (0, 0, 0))
        
        #Usa o canal Alpha como máscara
        if img_no_bg.mode == 'RGBA':
            mask = img_no_bg.split()[3]
            fundo_preto.paste(img_no_bg, mask=mask)
            return fundo_preto
        else:
            return img_no_bg.convert("RGB")
            
    except Exception as e:
        print(f"AVISO: Falha na segmentação ({e}). Retornando original.")
        return pil_image.convert("RGB")


#CLASSE EXTRATORA
class FeatureExtractor:
    def __init__(self, device=None):
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Usando dispositivo: {self.device}")

        self.processor = ConvNextImageProcessor.from_pretrained(
            "facebook/convnext-large-224-22k-1k"
        )
        self.model = ConvNextForImageClassification.from_pretrained(
            "facebook/convnext-large-224-22k-1k"
        ).to(self.device)

        self.model.classifier = torch.nn.Identity()
        self.model.eval()

    def extract_convnext(self, image_path: str) -> np.ndarray:
        print(f"Processando imagem: {os.path.basename(image_path)}")
        input_img = Image.open(image_path).convert("RGB")
        
        #1. CHAMA A FUNÇÃO DE SEGMENTAÇÃO
        final_image = segment_image(input_img)

        #2. Passa para o ConvNeXt
        inputs = self.processor(final_image, return_tensors="pt").to(self.device)
        
        with torch.no_grad():
            features = self.model(**inputs).logits
            
        features_np = features.cpu().numpy().flatten()
        return features_np

def process_single_image(image_path: str):
    extractor = FeatureExtractor()
    return extractor.extract_convnext(image_path)