Babu Pallam committed on
Commit
c08c2bd
·
1 Parent(s): ca8b4c4
gallery_features/rn50_stage3_v3_gallery.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:f923693bf05e01d8c47f6384e0be38924d70bf9426cdf539be7a29097d46058c
3
- size 52369628
 
 
 
 
gallery_features/rn50_zeroshot_gallery.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:df12a27314cf75e99c3c3ade1ecb36f2c673f4a882d3d340bd3435fbd543b4ed
3
- size 26540252
 
 
 
 
gallery_features/vitb16_stage3_v4_gallery.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:10f7b4aa80258ec32724d66d8942713e2d0fc2976034f175084da8b5a29d30df
3
- size 26540252
 
 
 
 
gallery_features/vitb16_zeroshot_gallery.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:a8af24ee16c51086d3397930ee56efd80cd51db4ce1afbd7f8edb926b86eecd2
3
- size 13625564
 
 
 
 
inference.py DELETED
@@ -1,6 +0,0 @@
1
def load_gallery_features(arch="vitb16", stage="stage3"):
    """Load a precomputed gallery feature file from ``gallery_features/``.

    Args:
        arch (str): Model architecture identifier, e.g. "vitb16" or "rn50".
        stage (str): Training stage identifier, e.g. "stage3" or "zeroshot".

    Returns:
        Whatever object was serialized into the ``.pt`` file
        (the result of ``torch.load``).

    Raises:
        FileNotFoundError: If the expected gallery file does not exist.
    """
    # Local imports: the original module used `os` and `torch` without
    # importing them anywhere, so every call raised NameError.
    import os
    import torch

    filename = f"{arch}_{stage}_gallery.pt"
    path = os.path.join("gallery_features", filename)
    if not os.path.exists(path):
        raise FileNotFoundError(f"Gallery file not found: {path}")
    return torch.load(path)
 
 
 
 
 
 
 
model_loader.py DELETED
@@ -1,49 +0,0 @@
1
import os
import torch
import clip
from stylefinder import CLIPModel  # Make sure your model class is here


def load_model(arch="vitb16", stage="stage3", device=None):
    """
    Loads the appropriate StyleFinder model.

    Args:
        arch (str): "vitb16" or "rn50"
        stage (str): "stage3" or "zeroshot"
        device (str): torch device (e.g., "cuda" or "cpu")

    Returns:
        model, preprocess

    Raises:
        ValueError: If ``arch`` or ``stage`` is not a recognised value.
        FileNotFoundError: If the fine-tuned checkpoint file is missing.
    """
    arch = arch.lower()
    stage = stage.lower()
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")

    # Validate up front: previously any unknown arch silently fell back to
    # RN50 in the zero-shot path, and any unknown stage silently loaded
    # stage3 weights — both returned the wrong model without any error.
    if arch not in ("vitb16", "rn50"):
        raise ValueError(f"Unsupported architecture: {arch}")
    if stage not in ("stage3", "zeroshot"):
        raise ValueError(f"Unsupported stage: {stage}")

    if stage == "zeroshot":
        # Zero-shot path: stock OpenAI CLIP weights, no local checkpoint.
        model, preprocess = clip.load("ViT-B/16" if arch == "vitb16" else "RN50", device=device)
        return model, preprocess

    # Load fine-tuned model from local checkpoint, keyed by architecture.
    checkpoint_paths = {
        "vitb16": "vitb16_stage3_v4.pth",
        "rn50": "rn50_stage3_v3.pth",
    }

    ckpt_path = checkpoint_paths[arch]
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")

    # Instantiate and load the fine-tuned model.
    model = CLIPModel(arch=arch)
    # map_location keeps CUDA-saved checkpoints loadable on CPU-only hosts.
    state_dict = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()
    model.to(device)

    # Local import keeps preprocess.py out of module-level imports.
    from preprocess import build_preprocess
    preprocess = build_preprocess()
    return model, preprocess
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
preprocess.py DELETED
@@ -1,10 +0,0 @@
1
# preprocess.py
from torchvision import transforms


def build_preprocess():
    """Build the evaluation image transform pipeline.

    Resizes inputs to 224x224, converts them to tensors, and normalizes
    channels with the CLIP mean/std statistics.
    """
    clip_mean = [0.4815, 0.4578, 0.4082]
    clip_std = [0.2686, 0.2613, 0.2758]
    steps = [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=clip_mean, std=clip_std),
    ]
    return transforms.Compose(steps)