orrzohar commited on
Commit
bc239f6
·
1 Parent(s): 020eac1

maybe working

Browse files
Files changed (1) hide show
  1. legacy_eva_clip/eva_vit.py +53 -48
legacy_eva_clip/eva_vit.py CHANGED
@@ -5,61 +5,20 @@
5
  from math import pi
6
  import os
7
  from pathlib import Path
8
- from typing import Optional
9
  import importlib.util
 
10
 
11
  import torch
12
  from torch import nn
13
  from einops import rearrange, repeat
14
  import logging
15
 
16
- _hub_spec = importlib.util.find_spec("huggingface_hub")
17
- if _hub_spec is not None:
18
- from huggingface_hub import snapshot_download # type: ignore
19
- else: # pragma: no cover - optional dependency
20
- snapshot_download = None
21
 
22
- _DEFAULT_EVA_REPO = "jiuhai/eva_clip_vision_tower"
23
  XFORMERS_AVAILABLE = False # populated later once xops import resolves
24
  _XFORMERS_WARNING_EMITTED = False
25
 
26
 
27
- def _resolve_vision_checkpoint_path(vision_tower_pretrained: Optional[str]) -> str:
28
- """Determine where to load the EVA visual tower weights from."""
29
-
30
- candidate_files: list[Path] = []
31
- if vision_tower_pretrained:
32
- supplied = Path(vision_tower_pretrained)
33
- if supplied.is_file():
34
- candidate_files.append(supplied)
35
- elif supplied.is_dir():
36
- for filename in (
37
- "pytorch_model.bin",
38
- "model.safetensors",
39
- "visual.pth",
40
- "visual.bin",
41
- ):
42
- candidate = supplied / filename
43
- if candidate.exists():
44
- candidate_files.append(candidate)
45
- break
46
-
47
- if candidate_files:
48
- return str(candidate_files[0])
49
-
50
- if snapshot_download is None:
51
- raise FileNotFoundError(
52
- "EVA vision weights not found locally and huggingface_hub is unavailable. "
53
- "Provide --vision-tower-path pointing to a directory containing the checkpoint."
54
- )
55
-
56
- cache_dir = Path(snapshot_download(repo_id=_DEFAULT_EVA_REPO))
57
- default_file = cache_dir / "pytorch_model.bin"
58
- if not default_file.exists():
59
- raise FileNotFoundError(f"Default EVA checkpoint not found at {default_file}.")
60
- return str(default_file)
61
-
62
-
63
  def broadcat(tensors, dim=-1):
64
  num_tensors = len(tensors)
65
  shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
@@ -668,13 +627,54 @@ class EVAVisionTransformer(nn.Module):
668
 
669
 
670
  def load_state_dict(checkpoint_path: str, map_location: str = "cpu", model_key: str = "model|module|state_dict", is_openai: bool = False, skip_list: list = []):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
671
  if is_openai:
672
- model = torch.jit.load(checkpoint_path, map_location="cpu").eval()
673
  state_dict = model.state_dict()
674
  for key in ["input_resolution", "context_length", "vocab_size"]:
675
  state_dict.pop(key, None)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
676
  else:
677
- checkpoint = torch.load(checkpoint_path, map_location=map_location)
678
  for mk in model_key.split("|"):
679
  if isinstance(checkpoint, dict) and mk in checkpoint:
680
  state_dict = checkpoint[mk]
@@ -797,9 +797,14 @@ class EVAEncoderWrapper(nn.Module):
797
  def __init__(self, vision_tower_pretrained, config):
798
  super(EVAEncoderWrapper, self).__init__()
799
  self.config = config
800
- vision_tower_path = _resolve_vision_checkpoint_path(vision_tower_pretrained)
801
- self.config["vision_tower_path"] = vision_tower_path
802
- self.vision_tower_path = vision_tower_path
 
 
 
 
 
803
  self.model = _build_vision_tower(**self.config)
804
 
805
  def forward(self, image, **kwargs):
 
5
  from math import pi
6
  import os
7
  from pathlib import Path
 
8
  import importlib.util
9
+ import json
10
 
11
  import torch
12
  from torch import nn
13
  from einops import rearrange, repeat
14
  import logging
15
 
16
+ from safetensors.torch import load_file as load_safetensors
 
 
 
 
17
 
 
18
  XFORMERS_AVAILABLE = False # populated later once xops import resolves
19
  _XFORMERS_WARNING_EMITTED = False
20
 
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  def broadcat(tensors, dim=-1):
23
  num_tensors = len(tensors)
24
  shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
 
627
 
628
 
629
  def load_state_dict(checkpoint_path: str, map_location: str = "cpu", model_key: str = "model|module|state_dict", is_openai: bool = False, skip_list: list = []):
630
+ path = Path(checkpoint_path)
631
+ if path.is_dir():
632
+ index_file = path / "model.safetensors.index.json"
633
+ if index_file.exists():
634
+ path = index_file
635
+ else:
636
+ for filename in (
637
+ "model.safetensors",
638
+ "pytorch_model.bin",
639
+ "visual.pth",
640
+ "visual.bin",
641
+ ):
642
+ candidate = path / filename
643
+ if candidate.exists():
644
+ path = candidate
645
+ break
646
+ if path.is_dir():
647
+ raise FileNotFoundError(f"No EVA checkpoint files found under {checkpoint_path}")
648
  if is_openai:
649
+ model = torch.jit.load(str(path), map_location="cpu").eval()
650
  state_dict = model.state_dict()
651
  for key in ["input_resolution", "context_length", "vocab_size"]:
652
  state_dict.pop(key, None)
653
+ elif path.suffix == ".json" and path.name.endswith(".safetensors.index.json"):
654
+ if load_safetensors is None:
655
+ raise ImportError(
656
+ "safetensors is required to load EVA vision weights from sharded checkpoints. "
657
+ "Install `safetensors` or provide a .bin checkpoint."
658
+ )
659
+ with open(path, "r", encoding="utf-8") as f:
660
+ index_data = json.load(f)
661
+ weight_map = index_data.get("weight_map", {})
662
+ shard_cache: dict[str, dict[str, torch.Tensor]] = {}
663
+ state_dict = {}
664
+ for param_name, shard_name in weight_map.items():
665
+ shard_path = path.with_name(shard_name)
666
+ if shard_name not in shard_cache:
667
+ shard_cache[shard_name] = load_safetensors(str(shard_path), device=map_location)
668
+ state_dict[param_name] = shard_cache[shard_name][param_name]
669
+ elif path.suffix == ".safetensors":
670
+ if load_safetensors is None:
671
+ raise ImportError(
672
+ "safetensors is required to load EVA vision weights from safetensors checkpoints. "
673
+ "Install `safetensors` or provide a .bin checkpoint."
674
+ )
675
+ state_dict = load_safetensors(str(path), device=map_location)
676
  else:
677
+ checkpoint = torch.load(str(path), map_location=map_location)
678
  for mk in model_key.split("|"):
679
  if isinstance(checkpoint, dict) and mk in checkpoint:
680
  state_dict = checkpoint[mk]
 
797
  def __init__(self, vision_tower_pretrained, config):
798
  super(EVAEncoderWrapper, self).__init__()
799
  self.config = config
800
+ if not vision_tower_pretrained:
801
+ raise ValueError("vision_tower_pretrained must be provided.")
802
+ vision_tower_path = Path(vision_tower_pretrained)
803
+ if not vision_tower_path.exists():
804
+ raise FileNotFoundError(f"EVA vision weights not found under {vision_tower_path}")
805
+ resolved_path = str(vision_tower_path.resolve())
806
+ self.config["vision_tower_path"] = resolved_path
807
+ self.vision_tower_path = resolved_path
808
  self.model = _build_vision_tower(**self.config)
809
 
810
  def forward(self, image, **kwargs):