bobs24 commited on
Commit
11a1df0
·
verified ·
1 Parent(s): 77b3cf0

Upload 3 files

Browse files
model/__pycache__/feature_extractor.cpython-311.pyc ADDED
Binary file (2.99 kB). View file
 
model/__pycache__/feature_extractor.cpython-312.pyc ADDED
Binary file (2.39 kB). View file
 
model/feature_extractor.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision.models as models
3
+ from torchvision import transforms
4
+ from PIL import Image
5
+ import numpy as np
6
+
7
+
8
+ ## ResNet50
9
+ # class FeatureExtractor:
10
+ # def __init__(self):
11
+ # self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
+ # # Load pretrained ResNet50 without the final classification layer
13
+ # resnet = models.resnet50(pretrained=True)
14
+ # # Remove the final fully connected layer (fc)
15
+ # self.model = torch.nn.Sequential(*list(resnet.children())[:-1])
16
+ # self.model.eval().to(self.device)
17
+
18
+ # # Standard ImageNet preprocessing
19
+ # self.transform = transforms.Compose([
20
+ # transforms.Resize(256),
21
+ # transforms.CenterCrop(224),
22
+ # transforms.ToTensor(),
23
+ # transforms.Normalize(
24
+ # mean=[0.485, 0.456, 0.406],
25
+ # std=[0.229, 0.224, 0.225]
26
+ # ),
27
+ # ])
28
+
29
+ # def extract(self, image: Image.Image):
30
+ # image = self.transform(image).unsqueeze(0).to(self.device)
31
+ # with torch.no_grad():
32
+ # features = self.model(image)
33
+ # features = features.squeeze().cpu().numpy()
34
+ # features = features.reshape(-1) # flatten (2048,)
35
+
36
+ # # Normalize to unit vector (important for cosine similarity)
37
+ # norm = np.linalg.norm(features)
38
+ # if norm > 0:
39
+ # features = features / norm
40
+ # return features
41
+
42
+ ## ConvNext-Tiny
43
+ class FeatureExtractor:
44
+ def __init__(self):
45
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
46
+
47
+ # Load pretrained ConvNeXt Tiny
48
+ weights = models.ConvNeXt_Tiny_Weights.IMAGENET1K_V1
49
+ convnext = models.convnext_tiny(weights=weights)
50
+ # Remove classification head (last layer)
51
+ self.model = torch.nn.Sequential(*list(convnext.children())[:-1])
52
+ self.model.eval().to(self.device)
53
+
54
+ # Use official preprocessing transform for ConvNeXt Tiny
55
+ self.transform = weights.transforms()
56
+
57
+ def extract(self, image: Image.Image):
58
+ image = self.transform(image).unsqueeze(0).to(self.device)
59
+ with torch.no_grad():
60
+ features = self.model(image)
61
+ features = features.squeeze().cpu().numpy()
62
+ features = features.reshape(-1) # flatten
63
+
64
+ # Normalize to unit vector (important for cosine similarity)
65
+ norm = np.linalg.norm(features)
66
+ if norm > 0:
67
+ features = features / norm
68
+ return features