ML4RS-Anonymous commited on
Commit
66ee1b2
·
verified ·
1 Parent(s): 855c680

fix: import create_model_from_pretrained

Browse files
Files changed (1) hide show
  1. models/siglip_model.py +9 -9
models/siglip_model.py CHANGED
@@ -40,15 +40,15 @@ class SigLIPModel:
40
  print(f"Warning: Checkpoint not found at {self.ckpt_path}")
41
 
42
  if 'hf' in self.ckpt_path:
43
- self.ckpt_path = hf_hub_download("timm/ViT-SO400M-14-SigLIP-384", "open_clip_pytorch_model.bin")
44
- # self.tokenizer = open_clip.get_tokenizer("hf-hub:timm/ViT-SO400M-14-SigLIP-384")
45
- self.tokenizer = HFTokenizer("hf-hub:timm/ViT-SO400M-14-SigLIP-384")
46
- else:
47
- self.tokenizer = HFTokenizer(self.tokenizer_path)
48
- self.model, _, self.preprocess = open_clip.create_model_and_transforms(
49
- self.model_name,
50
- pretrained=self.ckpt_path
51
- )
52
  self.model = self.model.to(self.device)
53
  self.model.eval()
54
 
 
40
  print(f"Warning: Checkpoint not found at {self.ckpt_path}")
41
 
42
  if 'hf' in self.ckpt_path:
43
+ from open_clip import create_model_from_pretrained, get_tokenizer # works on open-clip-torch>=2.23.0, timm>=0.9.8
44
+
45
+ model, preprocess = create_model_from_pretrained('hf-hub:timm/ViT-SO400M-14-SigLIP-384')
46
+ tokenizer = get_tokenizer('hf-hub:timm/ViT-SO400M-14-SigLIP-384')
47
+
48
+ # model, preprocess = create_model_from_pretrained('hf-hub:timm/ViT-SO400M-14-SigLIP-384')
49
+ # from huggingface_hub import hf_hub_download
50
+ # tokenizer = get_tokenizer('hf-hub:timm/ViT-SO400M-14-SigLIP-384')
51
+
52
  self.model = self.model.to(self.device)
53
  self.model.eval()
54