Fix runtime buffers after load

#36
by err805 - opened
Files changed (2)
  1. hf_moondream.py +0 -7
  2. moondream.py +0 -21
hf_moondream.py CHANGED
@@ -45,13 +45,6 @@ class HfMoondream(PreTrainedModel):
45
  self._is_kv_cache_setup = False
46
  self.post_init()
47
 
48
- @classmethod
49
- def from_pretrained(cls, *args, **kwargs):
50
- output = super().from_pretrained(*args, **kwargs)
51
- model = output[0] if isinstance(output, tuple) else output
52
- model.model._refresh_runtime_buffers()
53
- return output
54
-
55
  def _setup_caches(self):
56
  if not self._is_kv_cache_setup:
57
  self.model._setup_caches()
 
45
  self._is_kv_cache_setup = False
46
  self.post_init()
47
 
 
 
 
 
 
 
 
48
  def _setup_caches(self):
49
  if not self._is_kv_cache_setup:
50
  self.model._setup_caches()
moondream.py CHANGED
@@ -22,7 +22,6 @@ from .region import (
22
  )
23
  from .layers import QuantizedLinear
24
  from .lora import load_adapter, normalize_adapter_id
25
- from .rope import precompute_freqs_cis
26
  from .utils import remove_outlier_points
27
 
28
  ImageEncodingSettings = TypedDict(
@@ -172,26 +171,6 @@ class MoondreamModel(nn.Module):
172
  )
173
  return self._point_gen_indices
174
 
175
- def _refresh_runtime_buffers(self):
176
- attn_mask = torch.tril(
177
- torch.ones(
178
- 1,
179
- 1,
180
- self.config.text.max_context,
181
- self.config.text.max_context,
182
- dtype=torch.bool,
183
- device=self.device,
184
- )
185
- )
186
- patch_w = self.config.vision.crop_size // self.config.vision.enc_patch_size
187
- prefix_attn_len = 1 + patch_w**2
188
- attn_mask[..., :prefix_attn_len, :prefix_attn_len] = 1
189
- self.attn_mask = attn_mask
190
- self.text.freqs_cis = precompute_freqs_cis(
191
- self.config.text.dim // (2 * self.config.text.n_heads),
192
- self.config.text.max_context,
193
- ).to(device=self.device)
194
-
195
  def _setup_caches(self):
196
  c = self.config.text
197
  for b in self.text.blocks:
 
22
  )
23
  from .layers import QuantizedLinear
24
  from .lora import load_adapter, normalize_adapter_id
 
25
  from .utils import remove_outlier_points
26
 
27
  ImageEncodingSettings = TypedDict(
 
171
  )
172
  return self._point_gen_indices
173
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
  def _setup_caches(self):
175
  c = self.config.text
176
  for b in self.text.blocks: