Babu Pallam committed on
Commit
a32e344
·
1 Parent(s): e7bd561

Initial push: models + gallery features + inference code

Browse files
README.md CHANGED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # StyleFinder – Fashion Visual Search with CLIP
2
+
3
+ This repository includes two fine-tuned CLIP models for image-based fashion retrieval:
4
+
5
+ | Model | Stage | Rank-1 | mAP |
6
+ |---------------|--------------|--------|-------|
7
+ | ViT-B/16 | Stage 3 v4 | 46.24% | 0.3481|
8
+ | ResNet-50 | Stage 3 v3 | 53.95% | 0.4265|
9
+
10
+ ---
11
+
12
+ ## 🧠 Model Details
13
+
14
+ - **ViT-B/16 (Transformer-based, 512-dim):** Jointly fine-tuned using SupCon + ArcFace + BNNeck.
15
+ - **RN50 (CNN-based, 1024-dim):** Fine-tuned with prompt-structured Stage 3 configuration.
16
+ - Dataset: [DeepFashion – In-shop Clothes Retrieval](https://mmlab.ie.cuhk.edu.hk/projects/DeepFashion/InShopRetrieval.html)
17
+
18
+ ---
19
+
20
+ ## 📦 How to Use
21
+
22
+ ```python
23
+ from model_loader import load_model
24
+ model = load_model("vitb16")  # or "rn50"
+ ```
gallery_features/rn50_stage3_v3_gallery.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f923693bf05e01d8c47f6384e0be38924d70bf9426cdf539be7a29097d46058c
3
+ size 52369628
gallery_features/rn50_zeroshot_gallery.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:df12a27314cf75e99c3c3ade1ecb36f2c673f4a882d3d340bd3435fbd543b4ed
3
+ size 26540252
gallery_features/vitb16_stage3_v4_gallery.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:10f7b4aa80258ec32724d66d8942713e2d0fc2976034f175084da8b5a29d30df
3
+ size 26540252
gallery_features/vitb16_zeroshot_gallery.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a8af24ee16c51086d3397930ee56efd80cd51db4ce1afbd7f8edb926b86eecd2
3
+ size 13625564
inference.py CHANGED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
def load_gallery_features(arch="vitb16", stage="stage3", map_location=None):
    """
    Load precomputed gallery embeddings for an architecture/stage pair.

    Args:
        arch (str): "vitb16" or "rn50".
        stage (str): Training stage. "stage3" is resolved to the versioned
            checkpoint name that actually ships with the repo
            ("stage3_v4" for vitb16, "stage3_v3" for rn50); "zeroshot"
            and explicit versioned names are used as-is.
        map_location: Optional ``torch.load`` map_location (e.g. "cpu") so
            GPU-saved tensors can be loaded on CPU-only machines. Default
            None preserves the original loading behavior.

    Returns:
        The deserialized object stored in ``gallery_features/{arch}_{stage}_gallery.pt``.

    Raises:
        FileNotFoundError: If the gallery file does not exist.
    """
    # Local imports: the original module used os/torch without importing
    # them anywhere, which raised NameError on first call.
    import os
    import torch

    # The shipped gallery files use versioned names (vitb16_stage3_v4_gallery.pt,
    # rn50_stage3_v3_gallery.pt), so the generic "stage3" alias must be mapped
    # to the version that exists for each backbone.
    versioned = {"vitb16": "stage3_v4", "rn50": "stage3_v3"}
    if stage == "stage3" and arch in versioned:
        stage = versioned[arch]

    filename = f"{arch}_{stage}_gallery.pt"
    path = os.path.join("gallery_features", filename)
    if not os.path.exists(path):
        raise FileNotFoundError(f"Gallery file not found: {path}")
    return torch.load(path, map_location=map_location)
model_loader.py CHANGED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import clip
4
+ from stylefinder import CLIPModel # Make sure your model class is here
5
+
6
+
7
def load_model(arch="vitb16", stage="stage3", device=None):
    """
    Load the appropriate StyleFinder model and its preprocessing transform.

    Args:
        arch (str): Backbone architecture, "vitb16" or "rn50".
        stage (str): "stage3" for the fine-tuned local checkpoint, or
            "zeroshot" for vanilla CLIP weights.
        device (str): torch device (e.g. "cuda" or "cpu"); auto-selects
            "cuda" when available if not given.

    Returns:
        tuple: (model, preprocess) — the model in eval mode on ``device``
        and a callable image transform.

    Raises:
        ValueError: If ``arch`` or ``stage`` is not recognized.
        FileNotFoundError: If the fine-tuned checkpoint file is missing.
    """
    arch = arch.lower()
    stage = stage.lower()
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")

    # Fine-tuned checkpoints shipped with the repository.
    checkpoint_paths = {
        "vitb16": "vitb16_stage3_v4.pth",
        "rn50": "rn50_stage3_v3.pth"
    }

    # Validate arch up front: previously an unknown arch with
    # stage="zeroshot" silently fell through to clip.load("RN50").
    if arch not in checkpoint_paths:
        raise ValueError(f"Unsupported architecture: {arch}")

    if stage == "zeroshot":
        # Vanilla CLIP weights straight from the clip package.
        model, preprocess = clip.load("ViT-B/16" if arch == "vitb16" else "RN50", device=device)
        return model, preprocess

    # Previously any non-"zeroshot" stage (e.g. "stage2") silently loaded
    # the stage3 checkpoint; fail loudly on unknown stages instead.
    if stage != "stage3":
        raise ValueError(f"Unsupported stage: {stage}")

    ckpt_path = checkpoint_paths[arch]
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")

    # Instantiate and load the fine-tuned model.
    model = CLIPModel(arch=arch)
    state_dict = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()
    model.to(device)

    # Function-scope import keeps module import light; build_preprocess is
    # defined in the sibling preprocess.py.
    from preprocess import build_preprocess
    preprocess = build_preprocess()
    return model, preprocess
preprocess.py CHANGED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ # preprocess.py
2
+ from torchvision import transforms
3
+
4
def build_preprocess():
    """Return the evaluation-time image transform.

    Resizes to 224x224, converts to a tensor, and normalizes with
    (rounded) CLIP RGB channel statistics.
    """
    clip_mean = [0.4815, 0.4578, 0.4082]
    clip_std = [0.2686, 0.2613, 0.2758]
    steps = [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=clip_mean, std=clip_std),
    ]
    return transforms.Compose(steps)
rn50_stage3_v3.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c073c411953996380c90272d216e685d80785d8c500ae99e5bc1d28449d2d574
3
+ size 408426712
vitb16_stage3_v4.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9be498bcad0c04e5605895982c441a715be768742c2c4ebfe037d8d9f61f2d77
3
+ size 598604166