orrzohar committed on
Commit
24580dc
·
1 Parent(s): f7127d5
Files changed (4) hide show
  1. builders.py +1 -3
  2. config.json +2 -1
  3. modeling_blip3o_qwen.py +19 -3
  4. vision_tower.py +92 -130
builders.py CHANGED
@@ -7,7 +7,7 @@ from typing import Any
7
  import torch.nn as nn
8
 
9
  from .diffusion_auto import AutoDiffusionModel, DiffusionConfig
10
- from .vision_tower import AutoEvaClipVisionTower, default_device, resolve_eva_repo
11
 
12
  logger = logging.getLogger(__name__)
13
  _PKG_ROOT = Path(__file__).resolve().parent
@@ -79,9 +79,7 @@ def build_down_projector(config, delay_load: bool = False, **kwargs):
79
  def build_gen_vision_tower(config, delay_load: bool = False, **kwargs):
80
  """Instantiate the EVA-CLIP tower purely from HF Hub assets."""
81
 
82
- repo_id = resolve_eva_repo(config)
83
  tower = AutoEvaClipVisionTower(
84
- repo_id,
85
  config=config,
86
  torch_dtype=kwargs.get("torch_dtype"),
87
  device=default_device(kwargs.get("device")),
 
7
  import torch.nn as nn
8
 
9
  from .diffusion_auto import AutoDiffusionModel, DiffusionConfig
10
+ from .vision_tower import AutoEvaClipVisionTower, default_device
11
 
12
  logger = logging.getLogger(__name__)
13
  _PKG_ROOT = Path(__file__).resolve().parent
 
79
  def build_gen_vision_tower(config, delay_load: bool = False, **kwargs):
80
  """Instantiate the EVA-CLIP tower purely from HF Hub assets."""
81
 
 
82
  tower = AutoEvaClipVisionTower(
 
83
  config=config,
84
  torch_dtype=kwargs.get("torch_dtype"),
85
  device=default_device(kwargs.get("device")),
config.json CHANGED
@@ -78,8 +78,9 @@
78
  },
79
  "vision_end_token_id": 151653,
80
  "vision_start_token_id": 151652,
 
81
  "vision_token_id": 151654,
82
- "vision_tower_pretrained": null,
83
  "vocab_size": 151668,
84
  "auto_map": {
85
  "AutoConfig": "modeling_blip3o_qwen.blip3oQwenConfig",
 
78
  },
79
  "vision_end_token_id": 151653,
80
  "vision_start_token_id": 151652,
81
+ "eva_image_size": 448,
82
  "vision_token_id": 151654,
83
+ "vision_tower_pretrained": "model_zoo/EVA-CLIP-E14-Plus",
84
  "vocab_size": 151668,
85
  "auto_map": {
86
  "AutoConfig": "modeling_blip3o_qwen.blip3oQwenConfig",
modeling_blip3o_qwen.py CHANGED
@@ -22,6 +22,9 @@ from transformers import (
22
  Qwen2_5_VLModel,
23
  )
24
  from transformers.generation.utils import GenerateOutput
 
 
 
25
  from transformers.modeling_outputs import CausalLMOutputWithPast
26
 
27
  from .builders import build_dit, build_gen_vision_tower, build_down_projector
@@ -37,6 +40,12 @@ IGNORE_INDEX = -100
37
  IMAGE_TOKEN_IDX = 151667
38
 
39
 
 
 
 
 
 
 
40
  class blip3oMetaModel:
41
  def __init__(self, config):
42
  super(blip3oMetaModel, self).__init__(config)
@@ -439,12 +448,13 @@ class blip3oQwenForCausalLM(Qwen2_5_VLForConditionalGeneration, blip3oMetaForCau
439
  use_cache: Optional[bool] = None,
440
  output_attentions: Optional[bool] = None,
441
  output_hidden_states: Optional[bool] = None,
442
- gen_image: Optional[torch.FloatTensor] = None,
443
  pixel_values: Optional[torch.Tensor] = None,
444
  image_grid_thw: Optional[torch.Tensor] = None,
445
  return_dict: Optional[bool] = None,
446
  cache_position: Optional[torch.LongTensor] = None
447
  ) -> Union[Tuple, CausalLMOutputWithPast]:
 
448
 
449
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
450
  output_hidden_states = (
@@ -491,6 +501,8 @@ class blip3oQwenForCausalLM(Qwen2_5_VLForConditionalGeneration, blip3oMetaForCau
491
  logits = logits.float()
492
 
493
  total_loss = None
 
 
494
  if labels is not None:
495
  # Shift so that tokens < n predict n
496
  shift_logits = logits[..., :-1, :].contiguous()
@@ -564,14 +576,18 @@ class blip3oQwenForCausalLM(Qwen2_5_VLForConditionalGeneration, blip3oMetaForCau
564
  text_weight = getattr(self.config, "text_loss_weight", 1.0)
565
  img_weight = getattr(self.config, "img_loss_weight", 1.0)
566
  total_loss = text_weight * text_loss + img_weight * img_loss
567
- print(f"text loss {text_loss} | img loss {img_loss}")
 
 
568
 
569
- return CausalLMOutputWithPast(
570
  loss=total_loss,
571
  logits=logits,
572
  past_key_values=outputs.past_key_values,
573
  hidden_states=outputs.hidden_states,
574
  attentions=outputs.attentions,
 
 
575
  )
576
 
577
 
 
22
  Qwen2_5_VLModel,
23
  )
24
  from transformers.generation.utils import GenerateOutput
25
+ from dataclasses import dataclass
26
+ from typing import Optional
27
+
28
  from transformers.modeling_outputs import CausalLMOutputWithPast
29
 
30
  from .builders import build_dit, build_gen_vision_tower, build_down_projector
 
40
  IMAGE_TOKEN_IDX = 151667
41
 
42
 
43
@dataclass
class Blip3oCausalLMOutput(CausalLMOutputWithPast):
    """Causal-LM output that also carries the two loss components.

    Besides the inherited fields (``loss``, ``logits``, ``past_key_values``,
    ``hidden_states``, ``attentions``), the model's forward pass fills in the
    individual terms that are weighted and summed into ``loss``.
    """

    # Text loss term; stays None when forward() did not compute it.
    text_loss: Optional[torch.FloatTensor] = None
    # Image loss term; stays None when forward() did not compute it.
    img_loss: Optional[torch.FloatTensor] = None
47
+
48
+
49
  class blip3oMetaModel:
50
  def __init__(self, config):
51
  super(blip3oMetaModel, self).__init__(config)
 
448
  use_cache: Optional[bool] = None,
449
  output_attentions: Optional[bool] = None,
450
  output_hidden_states: Optional[bool] = None,
451
+ gen_images: Optional[torch.FloatTensor] = None,
452
  pixel_values: Optional[torch.Tensor] = None,
453
  image_grid_thw: Optional[torch.Tensor] = None,
454
  return_dict: Optional[bool] = None,
455
  cache_position: Optional[torch.LongTensor] = None
456
  ) -> Union[Tuple, CausalLMOutputWithPast]:
457
+ gen_image=gen_images
458
 
459
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
460
  output_hidden_states = (
 
501
  logits = logits.float()
502
 
503
  total_loss = None
504
+ text_loss = None
505
+ img_loss = None
506
  if labels is not None:
507
  # Shift so that tokens < n predict n
508
  shift_logits = logits[..., :-1, :].contiguous()
 
576
  text_weight = getattr(self.config, "text_loss_weight", 1.0)
577
  img_weight = getattr(self.config, "img_loss_weight", 1.0)
578
  total_loss = text_weight * text_loss + img_weight * img_loss
579
+ # cache latest component losses for logging
580
+ self._last_text_loss = float(text_loss.detach().mean().cpu())
581
+ self._last_img_loss = float(img_loss.detach().mean().cpu())
582
 
583
+ return Blip3oCausalLMOutput(
584
  loss=total_loss,
585
  logits=logits,
586
  past_key_values=outputs.past_key_values,
587
  hidden_states=outputs.hidden_states,
588
  attentions=outputs.attentions,
589
+ text_loss=text_loss,
590
+ img_loss=img_loss,
591
  )
592
 
593
 
vision_tower.py CHANGED
@@ -1,103 +1,10 @@
1
  from __future__ import annotations
2
 
3
- import importlib
4
- import os
5
- import sys
6
- from pathlib import Path
7
- from types import SimpleNamespace
8
  from typing import Optional, Union
9
 
10
  import torch
11
  import torch.nn as nn
12
- from huggingface_hub import snapshot_download
13
-
14
- _DEFAULT_EVA_REPO = os.environ.get("BLIP3O_EVA_ID", "orrzohar/EVA-CLIP-E14-Plus")
15
- _EVA_ALIASES = {
16
- "eva-clip-e-14-plus": "orrzohar/EVA-CLIP-E14-Plus",
17
- "eva_clip_e_14_plus": "orrzohar/EVA-CLIP-E14-Plus",
18
- "orrzohar/eva-clip-e14-plus": "orrzohar/EVA-CLIP-E14-Plus",
19
- }
20
-
21
- _PKG_ROOT = Path(__file__).resolve().parent
22
- _EVA_ROOT = (_PKG_ROOT.parent / "EVA-CLIP-E14-Plus").resolve()
23
- _LEGACY_TOWER_CLASS = None
24
-
25
-
26
- def _normalize_candidate(candidate: Optional[str]) -> Optional[str]:
27
- if candidate is None:
28
- return None
29
- candidate = candidate.strip()
30
- return candidate or None
31
-
32
-
33
- def _ensure_code_on_path(path_hint: str) -> Path:
34
- """Make sure the legacy_eva_clip package is importable."""
35
-
36
- path = Path(path_hint)
37
- search_roots: list[Path] = []
38
- if path.exists():
39
- search_roots.append(path if path.is_dir() else path.parent)
40
- if _EVA_ROOT.exists():
41
- search_roots.append(_EVA_ROOT)
42
-
43
- for root in search_roots:
44
- pkg_dir = root / "legacy_eva_clip"
45
- if pkg_dir.exists():
46
- if str(root) not in sys.path:
47
- sys.path.insert(0, str(root))
48
- return root
49
-
50
- repo_id, _, revision = _DEFAULT_EVA_REPO.partition("@")
51
- download_root = Path(snapshot_download(repo_id=repo_id, revision=revision or None))
52
- if str(download_root) not in sys.path:
53
- sys.path.insert(0, str(download_root))
54
- return download_root
55
-
56
-
57
- def _get_legacy_tower_class(code_root: Path):
58
- """Import the legacy EVA tower implementation, caching the class."""
59
-
60
- global _LEGACY_TOWER_CLASS
61
- if _LEGACY_TOWER_CLASS is not None:
62
- return _LEGACY_TOWER_CLASS
63
-
64
- if str(code_root) not in sys.path:
65
- sys.path.insert(0, str(code_root))
66
-
67
- module = importlib.import_module("legacy_eva_clip.eva_clip_encoder")
68
- _LEGACY_TOWER_CLASS = module.EvaClipVisionTower
69
- return _LEGACY_TOWER_CLASS
70
-
71
-
72
- def resolve_eva_repo(config=None, fallback: Optional[str] = None) -> str:
73
- """Return a concrete path (file or directory) containing EVA weights."""
74
-
75
- candidate = _normalize_candidate(
76
- fallback
77
- or os.environ.get("BLIP3O_EVA_ID")
78
- or (getattr(config, "vision_tower_pretrained", None) if config is not None else None)
79
- or (getattr(config, "gen_vision_tower", None) if config is not None else None)
80
- or _DEFAULT_EVA_REPO
81
- )
82
-
83
- if candidate is None:
84
- raise ValueError("Unable to determine EVA checkpoint location.")
85
-
86
- candidate_path = Path(candidate)
87
- potential_paths = [
88
- candidate_path,
89
- (_PKG_ROOT / candidate) if not candidate_path.is_absolute() else None,
90
- _EVA_ROOT if _EVA_ROOT.exists() else None,
91
- ]
92
- for path in potential_paths:
93
- if path and path.exists():
94
- return str(path.resolve())
95
-
96
- alias = _EVA_ALIASES.get(candidate.lower())
97
- repo_spec = alias or candidate
98
- repo_id, sep, revision = repo_spec.partition("@")
99
- download_path = snapshot_download(repo_id=repo_id, revision=revision or None)
100
- return download_path
101
 
102
 
103
  def default_device(spec: Optional[Union[str, torch.device]] = None) -> torch.device:
@@ -109,56 +16,111 @@ def default_device(spec: Optional[Union[str, torch.device]] = None) -> torch.dev
109
 
110
 
111
  class AutoEvaClipVisionTower(nn.Module):
112
- """Wrapper that dynamically loads the EVA tower code + weights from HF."""
113
 
114
  def __init__(
115
  self,
116
- repo_id: Optional[str] = None,
117
- *,
118
- config=None,
119
  torch_dtype: torch.dtype | None = None,
120
  device: Optional[Union[str, torch.device]] = None,
121
  delay_load: bool = False,
122
  ):
123
  super().__init__()
124
 
125
- pretrained_path = resolve_eva_repo(config, repo_id)
126
- code_root = _ensure_code_on_path(pretrained_path)
127
- legacy_cls = _get_legacy_tower_class(code_root)
128
-
129
- tower_name = (
130
- getattr(config, "gen_vision_tower", None)
131
- or getattr(config, "vision_tower_pretrained", None)
132
- or "eva-clip-E-14-plus"
133
- )
134
- self.repo_id = pretrained_path
135
  self.torch_dtype = torch_dtype or torch.bfloat16
136
  self._device = default_device(device)
137
-
138
- args = SimpleNamespace(
139
- vision_tower_pretrained=pretrained_path,
140
- gen_vision_tower=self.repo_id,
141
- mm_vision_tower=self.repo_id,
142
- unfreeze_mm_vision_tower=False,
143
- mm_tunable_parts=[],
 
 
 
144
  )
145
 
146
- self.legacy_tower = legacy_cls(tower_name, args=args, delay_load=delay_load)
147
-
148
- def load_model(self, device_map=None):
149
- result = self.legacy_tower.load_model(device_map=device_map)
150
- if hasattr(self.legacy_tower, "vision_tower"):
151
- self.legacy_tower.vision_tower.to(device=self._device, dtype=self.torch_dtype)
152
- return result
153
 
154
- def forward(self, *args, **kwargs):
155
- return self.legacy_tower(*args, **kwargs)
156
-
157
- def __getattr__(self, item):
158
- if "legacy_tower" in self.__dict__ and hasattr(self.legacy_tower, item):
159
- return getattr(self.legacy_tower, item)
160
- return super().__getattr__(item)
 
161
 
 
 
162
 
163
- __all__ = ["AutoEvaClipVisionTower", "resolve_eva_repo", "default_device"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
 
 
1
  from __future__ import annotations
2
 
 
 
 
 
 
3
  from typing import Optional, Union
4
 
5
  import torch
6
  import torch.nn as nn
7
+ from transformers import AutoConfig, AutoImageProcessor, AutoModel
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
 
10
  def default_device(spec: Optional[Union[str, torch.device]] = None) -> torch.device:
 
16
 
17
 
class AutoEvaClipVisionTower(nn.Module):
    """Plain Hugging Face vision tower wrapper (AutoModel + AutoImageProcessor).

    The checkpoint location is read from ``config.vision_tower_pretrained``.
    The checkpoint's own config is fetched eagerly in ``__init__`` so that the
    geometry properties (``hidden_size``, ``image_size``, ``num_patches``) are
    available before the weights are loaded; the weights and the image
    processor are loaded by :meth:`load_model`.

    NOTE(review): every ``from_pretrained`` call passes
    ``trust_remote_code=True``, which allows transformers to execute Python
    code shipped with the checkpoint. Only point ``vision_tower_pretrained``
    at sources you trust.
    """

    def __init__(
        self,
        config,
        torch_dtype: Optional[torch.dtype] = None,
        device: Optional[Union[str, torch.device]] = None,
        delay_load: bool = False,
    ):
        """Read the checkpoint config and, unless delayed, load the weights.

        Args:
            config: Object exposing a non-None ``vision_tower_pretrained``
                attribute (path or repo id handed to ``from_pretrained``).
            torch_dtype: Dtype for the tower weights; ``torch.bfloat16`` when
                omitted.
            device: Target device, resolved through ``default_device``.
            delay_load: When True, skip :meth:`load_model`; the caller must
                invoke it before the first forward pass.

        Raises:
            ValueError: If ``config.vision_tower_pretrained`` is missing or
                None.
        """
        super().__init__()

        if getattr(config, "vision_tower_pretrained", None) is None:
            raise ValueError("vision_tower_pretrained must be defined in the config.")
        self.repo_id = config.vision_tower_pretrained
        self.torch_dtype = torch_dtype or torch.bfloat16
        self._device = default_device(device)
        self.is_loaded = False

        self.image_processor = None
        self.vision_model = None
        self._hf_config = AutoConfig.from_pretrained(self.repo_id, trust_remote_code=True)
        # The field may be absent *or* explicitly null in the checkpoint
        # config; both cases must yield an empty mapping rather than a
        # TypeError from dict(None).
        raw_vision_cfg = getattr(self._hf_config, "vision_cfg", None)
        self._vision_cfg = dict(raw_vision_cfg) if raw_vision_cfg else {}
        self._hidden_size = (
            self._vision_cfg.get("width")
            or getattr(self._hf_config, "embed_dim", None)
            or getattr(self._hf_config, "hidden_size", None)
        )

        if not delay_load:
            self.load_model(torch_dtype=self.torch_dtype, device=self._device)

    def load_model(
        self,
        *,
        torch_dtype: Optional[torch.dtype] = None,
        device: Optional[Union[str, torch.device]] = None,
    ):
        """Load the image processor and the frozen vision model.

        Calling this on an already loaded tower is a no-op.

        Args:
            torch_dtype: Overrides the dtype chosen at construction time.
            device: Overrides the device chosen at construction time.

        Returns:
            ``self``, so the call can be chained.
        """
        if self.is_loaded:
            return self

        dtype = torch_dtype or self.torch_dtype
        target_device = default_device(device or self._device)

        self.image_processor = AutoImageProcessor.from_pretrained(
            self.repo_id,
            trust_remote_code=True,
        )
        self.vision_model = AutoModel.from_pretrained(
            self.repo_id,
            trust_remote_code=True,
            torch_dtype=dtype,
        )
        self.vision_model.to(target_device)
        # The tower is used as a fixed feature extractor: eval mode, no grads.
        self.vision_model.eval()
        self.vision_model.requires_grad_(False)

        self.torch_dtype = dtype
        self._device = target_device
        self.is_loaded = True
        return self

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Encode ``pixel_values`` and return the model's ``last_hidden_state``.

        The input is moved to the tower's current device and dtype; the
        result is cast back to the dtype the caller passed in.

        Raises:
            RuntimeError: If called before :meth:`load_model`.
            ValueError: If the model output has no ``last_hidden_state``.
        """
        if not self.is_loaded:
            raise RuntimeError("Vision tower used before load_model()")
        # Use the live parameter dtype/device rather than the values cached at
        # load time: the tower may have been re-cast afterwards (e.g. by a
        # parent ``model.to(...)``), and a stale dtype would mismatch the
        # weights.
        inputs = pixel_values.to(device=self.device, dtype=self.dtype)
        outputs = self.vision_model(pixel_values=inputs)
        hidden_states = getattr(outputs, "last_hidden_state", None)
        if hidden_states is None:
            raise ValueError("EVA model did not return last_hidden_state")
        return hidden_states.to(pixel_values.dtype)

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the loaded weights, or the configured dtype before loading."""
        if self.is_loaded:
            return next(self.vision_model.parameters()).dtype
        return self.torch_dtype

    @property
    def device(self) -> torch.device:
        """Device of the loaded weights, or the configured device before loading."""
        if self.is_loaded:
            return next(self.vision_model.parameters()).device
        return self._device

    @property
    def hidden_size(self) -> int:
        """Feature width of the tower.

        Resolution order: the value taken from the checkpoint config at
        construction time, then the loaded model's ``config.hidden_size``,
        then 1024.
        """
        if self._hidden_size is not None:
            return int(self._hidden_size)
        # Look the attribute up with a default: an AttributeError raised inside
        # a property of an nn.Module is re-reported by Module.__getattr__ as a
        # missing ``hidden_size`` attribute, and would skip the fallback below.
        model_cfg = getattr(self.vision_model, "config", None)
        width = getattr(model_cfg, "hidden_size", None)
        if width is not None:
            return int(width)
        return 1024

    @property
    def num_patches(self) -> int:
        """Total number of patches (patches per side, squared)."""
        return self.num_patches_per_side**2

    @property
    def num_patches_per_side(self) -> int:
        """Image size divided by ``vision_cfg['patch_size']`` (default 14)."""
        size = self.image_size
        patch = self._vision_cfg.get("patch_size") or 14
        return size // patch

    @property
    def image_size(self) -> int:
        """Input resolution from ``vision_cfg['image_size']`` (default 448)."""
        return int(self._vision_cfg.get("image_size") or 448)
123
+
124
+
125
+ __all__ = ["AutoEvaClipVisionTower", "default_device"]
126