johnmalek312 committed on
Commit d3f5f07 · 1 Parent(s): e082327

clear gitignore cache

.gitignore DELETED
@@ -1,14 +0,0 @@
- model.safetensors
-
- # python and vs code stuff
- __pycache__/
- .vscode/
- .env
- .venv/
- .DS_Store
- .env.local
- .env.development.local
- .env.test.local
- .env.production.local
- # venv
- venv/
 
example.png DELETED

Git LFS Details

  • SHA256: a2c550ca6de1dc8589cb6eccf1e24fc455c5d46b49597066f3027a9d831a8aa5
  • Pointer size: 131 Bytes
  • Size of remote file: 747 kB
moondream2/config.json DELETED
@@ -1,13 +0,0 @@
- {
-   "architectures": [
-     "HfMoondream"
-   ],
-   "auto_map": {
-     "AutoConfig": "hf_moondream.HfConfig",
-     "AutoModelForCausalLM": "hf_moondream.HfMoondream"
-   },
-   "config": {},
-   "model_type": "moondream1",
-   "torch_dtype": "float16",
-   "transformers_version": "4.44.0"
- }
 
moondream2/config.py DELETED
@@ -1,86 +0,0 @@
- from dataclasses import dataclass, field
- from typing import Dict, List, Optional
-
-
- @dataclass(frozen=True)
- class TextConfig:
-     dim: int = 2048
-     ff_dim: int = 8192
-     n_layers: int = 24
-     vocab_size: int = 51200
-     max_context: int = 2048
-     n_heads: int = 32
-     n_kv_heads: int = 32
-     prefix_attn: int = 730
-
-
- @dataclass(frozen=True)
- class VisionConfig:
-     enc_dim: int = 1152
-     enc_patch_size: int = 14
-     enc_n_layers: int = 27
-     enc_ff_dim: int = 4304
-     enc_n_heads: int = 16
-     proj_out_dim: int = 2048
-     crop_size: int = 378
-     in_channels: int = 3
-     max_crops: int = 12
-     overlap_margin: int = 4
-     proj_inner_dim: int = 8192
-
-
- @dataclass(frozen=True)
- class RegionConfig:
-     dim: int = 2048
-     coord_feat_dim: int = 256
-     coord_out_dim: int = 1024
-     size_feat_dim: int = 512
-     size_out_dim: int = 2048
-     inner_dim: int = 8192
-
-
- @dataclass(frozen=True)
- class TokenizerConfig:
-     bos_id: int = 50256
-     eos_id: int = 50256
-     templates: Dict[str, Optional[Dict[str, List[int]]]] = field(
-         default_factory=lambda: {
-             "caption": {
-                 "short": [198, 198, 16438, 8305, 25],
-                 "normal": [198, 198, 24334, 1159, 25],
-                 "long": [198, 198, 14617, 8305, 25],
-             },
-             "query": {"prefix": [198, 198, 24361, 25], "suffix": [198, 198, 33706, 25]},
-             "detect": {"prefix": [198, 198, 47504, 25], "suffix": [628]},
-             "point": {"prefix": [198, 198, 12727, 25], "suffix": [628]},
-         }
-     )
-
-
- @dataclass(frozen=True)
- class MoondreamConfig:
-     text: TextConfig = TextConfig()
-     vision: VisionConfig = VisionConfig()
-     region: RegionConfig = RegionConfig()
-     tokenizer: TokenizerConfig = TokenizerConfig()
-
-     @classmethod
-     def from_dict(cls, config_dict: dict):
-         text_config = TextConfig(**config_dict.get("text", {}))
-         vision_config = VisionConfig(**config_dict.get("vision", {}))
-         region_config = RegionConfig(**config_dict.get("region", {}))
-         tokenizer_config = TokenizerConfig(**config_dict.get("tokenizer", {}))
-         return cls(
-             text=text_config,
-             vision=vision_config,
-             region=region_config,
-             tokenizer=tokenizer_config,
-         )
-
-     def to_dict(self):
-         return {
-             "text": self.text.__dict__,
-             "vision": self.vision.__dict__,
-             "region": self.region.__dict__,
-             "tokenizer": self.tokenizer.__dict__,
-         }
 
moondream2/hf_moondream.py DELETED
@@ -1,132 +0,0 @@
- from transformers import PreTrainedModel, PretrainedConfig
-
- from .config import MoondreamConfig
- from .moondream import MoondreamModel
-
- # Files sometimes don't get loaded without these...
- from .image_crops import *
- from .vision import *
- from .text import *
- from .region import *
- from .utils import *
-
-
- def extract_question(text):
-     prefix = "<image>\n\nQuestion: "
-     suffix = "\n\nAnswer:"
-
-     if text.startswith(prefix) and text.endswith(suffix):
-         return text[len(prefix) : -len(suffix)]
-     else:
-         return None
-
-
- class HfConfig(PretrainedConfig):
-     _auto_class = "AutoConfig"
-     model_type = "moondream1"
-
-     def __init__(self, **kwargs):
-         super().__init__(**kwargs)
-         self.config = {}
-
-
- class HfMoondream(PreTrainedModel):
-     _supports_gradient_checkpointing = True
-     _auto_class = "AutoModelForCausalLM"
-     config_class = HfConfig
-
-     def __init__(self, config):
-         super().__init__(config)
-         self.model = MoondreamModel(
-             MoondreamConfig.from_dict(config.config), setup_caches=False
-         )
-         self.model._setup_caches()
-
-     @property
-     def encode_image(self):
-         return self.model.encode_image
-
-     @property
-     def query(self):
-         return self.model.query
-
-     @property
-     def caption(self):
-         return self.model.caption
-
-     @property
-     def detect(self):
-         return self.model.detect
-
-     @property
-     def point(self):
-         return self.model.point
-
-     @property
-     def detect_gaze(self):
-         return self.model.detect_gaze
-
-     def answer_question(
-         self,
-         image_embeds,
-         question,
-         tokenizer=None,
-         chat_history="",
-         result_queue=None,
-         max_new_tokens=256,
-         **kwargs
-     ):
-         answer = self.query(image_embeds, question)["answer"].strip()
-
-         if result_queue is not None:
-             result_queue.put(answer)
-         return answer
-
-     def batch_answer(self, images, prompts, tokenizer=None, **kwargs):
-         answers = []
-         for image, prompt in zip(images, prompts):
-             answers.append(self.query(image, prompt)["answer"].strip())
-         return answers
-
-     def _unsupported_exception(self):
-         raise NotImplementedError(
-             "This method is not supported in the latest version of moondream. "
-             "Consider upgrading to the updated API spec, or alternately pin "
-             "to 'revision=2024-08-26'."
-         )
-
-     def generate(self, image_embeds, prompt, tokenizer, max_new_tokens=128, **kwargs):
-         """
-         Function definition remains unchanged for backwards compatibility.
-         Be aware that tokenizer, max_new_tokens, and kwargs are ignored.
-         """
-         prompt_extracted = extract_question(prompt)
-         if prompt_extracted is not None:
-             answer = self.model.query(
-                 image=image_embeds, question=prompt_extracted, stream=False
-             )["answer"]
-         else:
-             image_embeds = self.encode_image(image_embeds)
-             prompt_tokens = torch.tensor(
-                 [self.model.tokenizer.encode(prompt).ids],
-                 device=self.device,
-             )
-
-             def generator():
-                 for token in self.model._generate_text(
-                     prompt_tokens,
-                     image_embeds.kv_cache,
-                     image_embeds.pos,
-                     max_new_tokens,
-                 ):
-                     yield token
-
-             answer = "".join(list(generator()))
-
-         return [answer]
-
-     def get_input_embeddings(self):
-         return super().get_input_embeddings()
-
-     def input_embeds(self, *args, **kwargs):
-         self._unsupported_exception()
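
The HfConfig/HfMoondream classes above are what the "auto_map" entries in config.json point at. A minimal loading sketch, assuming a repo revision that still contains these files (the repo id and revision below are placeholders, not values taken from this commit):

    from PIL import Image
    from transformers import AutoModelForCausalLM

    # trust_remote_code=True is needed so transformers will import hf_moondream.py
    model = AutoModelForCausalLM.from_pretrained(
        "some-user/moondream2",        # placeholder repo id
        revision="earlier-revision",   # placeholder; a commit that still has the files
        trust_remote_code=True,
    )
    answer = model.query(Image.open("example.png"), "Describe this image.")["answer"]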
 
moondream2/image_crops.py DELETED
@@ -1,208 +0,0 @@
- import math
- import numpy as np
- import torch
- import pyvips
-
- from typing import TypedDict
-
-
- def select_tiling(
-     height: int, width: int, crop_size: int, max_crops: int
- ) -> tuple[int, int]:
-     """
-     Determine the optimal number of tiles to cover an image with overlapping crops.
-     """
-     if height <= crop_size or width <= crop_size:
-         return (1, 1)
-
-     # Minimum required tiles in each dimension
-     min_h = math.ceil(height / crop_size)
-     min_w = math.ceil(width / crop_size)
-
-     # If minimum required tiles exceed max_crops, return proportional distribution
-     if min_h * min_w > max_crops:
-         ratio = math.sqrt(max_crops / (min_h * min_w))
-         return (max(1, math.floor(min_h * ratio)), max(1, math.floor(min_w * ratio)))
-
-     # Perfect aspect-ratio tiles that satisfy max_crops
-     h_tiles = math.floor(math.sqrt(max_crops * height / width))
-     w_tiles = math.floor(math.sqrt(max_crops * width / height))
-
-     # Ensure we meet minimum tile requirements
-     h_tiles = max(h_tiles, min_h)
-     w_tiles = max(w_tiles, min_w)
-
-     # If we exceeded max_crops, scale down the larger dimension
-     if h_tiles * w_tiles > max_crops:
-         if w_tiles > h_tiles:
-             w_tiles = math.floor(max_crops / h_tiles)
-         else:
-             h_tiles = math.floor(max_crops / w_tiles)
-
-     return (max(1, h_tiles), max(1, w_tiles))
-
-
- class OverlapCropOutput(TypedDict):
-     crops: np.ndarray
-     tiling: tuple[int, int]
-
-
- def overlap_crop_image(
-     image: np.ndarray,
-     overlap_margin: int,
-     max_crops: int,
-     base_size: tuple[int, int] = (378, 378),
-     patch_size: int = 14,
- ) -> OverlapCropOutput:
-     """
-     Process an image using an overlap-and-resize cropping strategy with margin handling.
-
-     This function takes an input image and creates multiple overlapping crops with
-     consistent margins. It produces:
-     1. A single global crop resized to base_size
-     2. Multiple overlapping local crops that maintain high resolution details
-     3. A patch ordering matrix that tracks correspondence between crops
-
-     The overlap strategy ensures:
-     - Smooth transitions between adjacent crops
-     - No loss of information at crop boundaries
-     - Proper handling of features that cross crop boundaries
-     - Consistent patch indexing across the full image
-
-     Args:
-         image (np.ndarray): Input image as numpy array with shape (H,W,C)
-         base_size (tuple[int,int]): Target size for crops, default (378,378)
-         patch_size (int): Size of patches in pixels, default 14
-         overlap_margin (int): Margin size in patch units, default 4
-         max_crops (int): Maximum number of crops allowed, default 12
-
-     Returns:
-         OverlapCropOutput: Dictionary containing:
-             - crops: A numpy array containing the global crop of the full image (index 0)
-               followed by the overlapping cropped regions (indices 1+)
-             - tiling: Tuple of (height,width) tile counts
-     """
-     original_h, original_w = image.shape[:2]
-
-     # Convert margin from patch units to pixels
-     margin_pixels = patch_size * overlap_margin
-     total_margin_pixels = margin_pixels * 2  # Both sides
-
-     # Calculate crop parameters
-     crop_patches = base_size[0] // patch_size  # patches per crop dimension
-     crop_window_patches = crop_patches - (2 * overlap_margin)  # usable patches
-     crop_window_size = crop_window_patches * patch_size  # usable size in pixels
-
-     # Determine tiling
-     tiling = select_tiling(
-         original_h - total_margin_pixels,
-         original_w - total_margin_pixels,
-         crop_window_size,
-         max_crops,
-     )
-
-     # Pre-allocate crops.
-     n_crops = tiling[0] * tiling[1] + 1  # 1 = global crop
-     crops = np.zeros(
-         (n_crops, base_size[0], base_size[1], image.shape[2]), dtype=np.uint8
-     )
-
-     # Resize image to fit tiling
-     target_size = (
-         tiling[0] * crop_window_size + total_margin_pixels,
-         tiling[1] * crop_window_size + total_margin_pixels,
-     )
-
-     # Convert to vips for resizing
-     vips_image = pyvips.Image.new_from_array(image)
-     scale_x = target_size[1] / image.shape[1]
-     scale_y = target_size[0] / image.shape[0]
-     resized = vips_image.resize(scale_x, vscale=scale_y)
-     image = resized.numpy()
-
-     # Create global crop
-     scale_x = base_size[1] / vips_image.width
-     scale_y = base_size[0] / vips_image.height
-     global_vips = vips_image.resize(scale_x, vscale=scale_y)
-     crops[0] = global_vips.numpy()
-
-     for i in range(tiling[0]):
-         for j in range(tiling[1]):
-             # Calculate crop coordinates
-             y0 = i * crop_window_size
-             x0 = j * crop_window_size
-
-             # Extract crop with padding if needed
-             y_end = min(y0 + base_size[0], image.shape[0])
-             x_end = min(x0 + base_size[1], image.shape[1])
-
-             crop_region = image[y0:y_end, x0:x_end]
-             crops[
-                 1 + i * tiling[1] + j, : crop_region.shape[0], : crop_region.shape[1]
-             ] = crop_region
-
-     return {"crops": crops, "tiling": tiling}
-
-
- def reconstruct_from_crops(
-     crops: torch.Tensor,
-     tiling: tuple[int, int],
-     overlap_margin: int,
-     patch_size: int = 14,
- ) -> torch.Tensor:
-     """
-     Reconstruct the original image from overlapping crops into a single seamless image.
-
-     Takes a list of overlapping image crops along with their positional metadata and
-     reconstructs them into a single coherent image by carefully stitching together
-     non-overlapping regions. Handles both numpy arrays and PyTorch tensors.
-
-     Args:
-         crops: List of image crops as numpy arrays or PyTorch tensors with shape
-             (H,W,C)
-         tiling: Tuple of (height,width) indicating crop grid layout
-         patch_size: Size in pixels of each patch, default 14
-         overlap_margin: Number of overlapping patches on each edge, default 4
-
-     Returns:
-         Reconstructed image as numpy array or PyTorch tensor matching input type,
-         with shape (H,W,C) where H,W are the original image dimensions
-     """
-     tiling_h, tiling_w = tiling
-     crop_height, crop_width = crops[0].shape[:2]
-     margin_pixels = overlap_margin * patch_size
-
-     # Calculate output size (only adding margins once)
-     output_h = (crop_height - 2 * margin_pixels) * tiling_h + 2 * margin_pixels
-     output_w = (crop_width - 2 * margin_pixels) * tiling_w + 2 * margin_pixels
-
-     reconstructed = torch.zeros(
-         (output_h, output_w, crops[0].shape[2]),
-         device=crops[0].device,
-         dtype=crops[0].dtype,
-     )
-
-     for i, crop in enumerate(crops):
-         tile_y = i // tiling_w
-         tile_x = i % tiling_w
-
-         # For each tile, determine which part to keep
-         # Keep left margin only for first column
-         x_start = 0 if tile_x == 0 else margin_pixels
-         # Keep right margin only for last column
-         x_end = crop_width if tile_x == tiling_w - 1 else crop_width - margin_pixels
-         # Keep top margin only for first row
-         y_start = 0 if tile_y == 0 else margin_pixels
-         # Keep bottom margin only for last row
-         y_end = crop_height if tile_y == tiling_h - 1 else crop_height - margin_pixels
-
-         # Calculate where this piece belongs in the output
-         out_x = tile_x * (crop_width - 2 * margin_pixels)
-         out_y = tile_y * (crop_height - 2 * margin_pixels)
-
-         # Place the piece
-         reconstructed[
-             out_y + y_start : out_y + y_end, out_x + x_start : out_x + x_end
-         ] = crop[y_start:y_end, x_start:x_end]
-
-     return reconstructed
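
A minimal round-trip sketch for the two functions above, using the defaults from this file (the input size is arbitrary):

    import numpy as np
    import torch

    image = np.zeros((800, 1000, 3), dtype=np.uint8)   # arbitrary H, W, C test input
    out = overlap_crop_image(image, overlap_margin=4, max_crops=12)
    crops, tiling = out["crops"], out["tiling"]        # crops[0] is the global crop

    # Stitch the local crops (indices 1+) back together, dropping the overlap margins
    local = torch.from_numpy(crops[1:])
    stitched = reconstruct_from_crops(local, tiling, overlap_margin=4, patch_size=14)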
 
moondream2/layers.py DELETED
@@ -1,63 +0,0 @@
- from dataclasses import dataclass
- from typing import Literal
-
- import torch
- from torch.nn import functional as F
-
-
- def gelu_approx(x):
-     return F.gelu(x, approximate="tanh")
-
-
- @dataclass
- class LinearWeights:
-     weight: torch.Tensor
-     bias: torch.Tensor
-
-
- def linear(x: torch.Tensor, w: LinearWeights) -> torch.Tensor:
-     return F.linear(x, w.weight, w.bias)
-
-
- @dataclass
- class LayerNormWeights:
-     weight: torch.Tensor
-     bias: torch.Tensor
-
-
- def layer_norm(x: torch.Tensor, w: LayerNormWeights) -> torch.Tensor:
-     return F.layer_norm(x, w.bias.shape, w.weight, w.bias)
-
-
- @dataclass
- class MLPWeights:
-     fc1: LinearWeights
-     fc2: LinearWeights
-     act: Literal["gelu_approx"] = "gelu_approx"
-
-
- def mlp(x: torch.Tensor, w: MLPWeights) -> torch.Tensor:
-     x = w.fc1(x)
-     x = gelu_approx(x)
-     x = w.fc2(x)
-     return x
-
-
- @dataclass
- class AttentionWeights:
-     qkv: LinearWeights
-     proj: LinearWeights
-
-
- def attn(x: torch.Tensor, w: AttentionWeights, n_heads: int) -> torch.Tensor:
-     bsz, q_len, d_model = x.shape
-     head_dim = d_model // n_heads
-
-     q, k, v = [
-         t.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
-         for t in linear(x, w.qkv).chunk(3, dim=-1)
-     ]
-     out = F.scaled_dot_product_attention(q, k, v)
-     out = out.transpose(1, 2).reshape(bsz, q_len, d_model)
-     out = linear(out, w.proj)
-     return out
 
moondream2/moondream.py DELETED
@@ -1,476 +0,0 @@
- import torch
- import torch.nn as nn
- import random
-
- from typing import Literal, Tuple, TypedDict, Union, Dict, Any, Optional, List
- from PIL import Image
- from dataclasses import dataclass
- from tokenizers import Tokenizer
-
- from .config import MoondreamConfig
- from .image_crops import reconstruct_from_crops
- from .vision import vision_encoder, vision_projection, prepare_crops, build_vision_model
- from .text import build_text_model, text_encoder, lm_head, text_decoder
- from .region import decode_coordinate, encode_coordinate, decode_size, encode_size
- from .utils import remove_outlier_points
- import os
-
- TextSamplingSettings = TypedDict(
-     "TextSamplingSettings",
-     {
-         "max_tokens": int,
-         "temperature": float,
-         "top_p": float,
-     },
-     total=False,
- )
-
- ObjectSamplingSettings = TypedDict(
-     "ObjectSamplingSettings",
-     {"max_objects": int},
-     total=False,
- )
-
- DEFAULT_MAX_TOKENS = 768
- DEFAULT_TEMPERATURE = 0.5
- DEFAULT_TOP_P = 0.3
- DEFAULT_MAX_OBJECTS = 50
-
-
- @dataclass(frozen=True)
- class EncodedImage:
-     pos: int
-     caches: List[Tuple[torch.Tensor, torch.Tensor]]
-
-
- class KVCache(nn.Module):
-
-     def __init__(self, n_heads, n_kv_heads, max_context, dim, device, dtype):
-         super().__init__()
-         cache_shape = (1, n_kv_heads, max_context, dim // n_heads)
-         self.register_buffer(
-             "k_cache", torch.zeros(*cache_shape, device=device, dtype=dtype)
-         )
-         self.register_buffer(
-             "v_cache", torch.zeros(*cache_shape, device=device, dtype=dtype)
-         )
-
-     def update(self, pos_ids, k, v):
-         kout, vout = self.k_cache, self.v_cache
-         kout[:, :, pos_ids, :] = k
-         vout[:, :, pos_ids, :] = v
-         return kout, vout
-
-
- class MoondreamModel(nn.Module):
-     def __init__(self, config: MoondreamConfig, dtype=torch.float16, setup_caches=True):
-         super().__init__()
-         self.config = config
-         current_dir = os.path.dirname(os.path.abspath(__file__))
-         self.tokenizer = Tokenizer.from_file(os.path.join(current_dir, "tokenizer.json"))
-         self.vision = build_vision_model(config.vision, dtype)
-         self.text = build_text_model(config.text, dtype)
-
-         # Region Model
-         self.region = nn.ModuleDict(
-             {
-                 "coord_encoder": nn.Linear(
-                     config.region.coord_feat_dim, config.region.dim, dtype=dtype
-                 ),
-                 "coord_decoder": nn.ModuleDict(
-                     {
-                         "fc1": nn.Linear(
-                             config.region.dim, config.region.inner_dim, dtype=dtype
-                         ),
-                         "fc2": nn.Linear(
-                             config.region.inner_dim,
-                             config.region.coord_out_dim,
-                             dtype=dtype,
-                         ),
-                     }
-                 ),
-                 "size_encoder": nn.Linear(
-                     config.region.size_feat_dim, config.region.dim, dtype=dtype
-                 ),
-                 "size_decoder": nn.ModuleDict(
-                     {
-                         "fc1": nn.Linear(
-                             config.region.dim, config.region.inner_dim, dtype=dtype
-                         ),
-                         "fc2": nn.Linear(
-                             config.region.inner_dim,
-                             config.region.size_out_dim,
-                             dtype=dtype,
-                         ),
-                     }
-                 ),
-             }
-         )
-         self.region.coord_features = nn.Parameter(
-             torch.empty(config.region.coord_feat_dim // 2, 1, dtype=dtype).T
-         )
-         self.region.size_features = nn.Parameter(
-             torch.empty(config.region.size_feat_dim // 2, 2, dtype=dtype).T
-         )
-
-         attn_mask = torch.tril(
-             torch.ones(
-                 1, 1, config.text.max_context, config.text.max_context, dtype=torch.bool
-             )
-         )
-         patch_w = config.vision.crop_size // config.vision.enc_patch_size
-         prefix_attn_len = 1 + patch_w**2
-         attn_mask[..., :prefix_attn_len, :prefix_attn_len] = 1
-         self.register_buffer("attn_mask", attn_mask, persistent=False)
-
-         # Initialize KV caches.
-         if setup_caches:
-             self._setup_caches()
-
-     def _setup_caches(self):
-         c = self.config.text
-         for b in self.text.blocks:
-             b.kv_cache = KVCache(
-                 c.n_heads,
-                 c.n_kv_heads,
-                 c.max_context,
-                 c.dim,
-                 device=self.device,
-                 dtype=self.vision.pos_emb.dtype,
-             )
-     @property
-     def device(self):
-         return self.vision.pos_emb.device
-
-     def _vis_enc(self, x: torch.Tensor):
-         return vision_encoder(x, self.vision, self.config.vision)
-
-     def _vis_proj(self, g: torch.Tensor, r: torch.Tensor):
-         return vision_projection(g, r, self.vision, self.config.vision)
-
-     def _prefill(self, x: torch.Tensor, attn_mask: torch.Tensor, pos_ids: torch.Tensor):
-         return text_decoder(x, self.text, attn_mask, pos_ids, self.config.text)
-
-     def _decode_one_tok(
-         self, x: torch.Tensor, attn_mask: torch.Tensor, pos_ids: torch.Tensor
-     ):
-         hidden = text_decoder(x, self.text, attn_mask, pos_ids, self.config.text)
-         logits = lm_head(hidden, self.text)
-         return logits, hidden
-
-     def compile(self):
-         # TODO: vision_projection is not being compiled
-         self._vis_enc = torch.compile(self._vis_enc, fullgraph=True)
-         self._prefill = torch.compile(self._prefill, fullgraph=True)
-         self._decode_one_tok = torch.compile(
-             self._decode_one_tok, fullgraph=True, mode="reduce-overhead"
-         )
-
-     def _run_vision_encoder(self, image: Image.Image) -> torch.Tensor:
-         all_crops, tiling = prepare_crops(image, self.config.vision, device=self.device)
-         torch._dynamo.mark_dynamic(all_crops, 0)
-
-         outputs = self._vis_enc(all_crops)
-
-         global_features = outputs[0]
-         local_features = outputs[1:].view(
-             -1,
-             self.config.vision.enc_n_layers,
-             self.config.vision.enc_n_layers,
-             self.config.vision.enc_dim,
-         )
-
-         reconstructed = reconstruct_from_crops(
-             local_features,
-             tiling,
-             patch_size=1,
-             overlap_margin=self.config.vision.overlap_margin,
-         )
-
-         return self._vis_proj(global_features, reconstructed)
-
-     def encode_image(self, image: Union[Image.Image, EncodedImage]) -> EncodedImage:
-         if isinstance(image, EncodedImage):
-             return image
-         elif not isinstance(image, Image.Image):
-             raise ValueError("image must be a PIL Image or EncodedImage")
-
-         # Run through text model in addition to the vision encoder, to minimize
-         # re-computation if multiple queries are performed on this image.
-         with torch.inference_mode():
-
-             bos = torch.tensor([[self.config.tokenizer.bos_id]], device=self.device)
-
-             img_emb = self._run_vision_encoder(image)
-             bos_emb = text_encoder(
-                 bos,
-                 self.text,
-             )
-             inputs_embeds = torch.cat([bos_emb, img_emb[None]], dim=1)
-             mask = self.attn_mask[:, :, 0 : inputs_embeds.size(1), :]
-             pos_ids = torch.arange(inputs_embeds.size(1), dtype=torch.long, device=self.device)
-             self._prefill(inputs_embeds, mask, pos_ids)
-
-         return EncodedImage(
-             pos=inputs_embeds.size(1),
-             caches=[
-                 (
-                     b.kv_cache.k_cache[:, :, : inputs_embeds.size(1), :].clone(),
-                     b.kv_cache.v_cache[:, :, : inputs_embeds.size(1), :].clone(),
-                 )
-                 for b in self.text.blocks
-             ],
-         )
-
-     def _apply_top_p(self, probs: torch.Tensor, top_p: float):
-         probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
-         probs_sum = torch.cumsum(probs_sort, dim=-1)
-         mask = probs_sum - probs_sort > top_p
-         probs_sort[mask] = 0.0
-         probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
-         next_probs = torch.zeros_like(probs)
-         next_probs.scatter_(dim=-1, index=probs_idx, src=probs_sort)
-         return next_probs
-
-     def _prefill_prompt(
-         self, prompt_tokens: torch.Tensor, pos: int, temperature: float, top_p: float
-     ):
-         with torch.inference_mode():
-             prompt_emb = text_encoder(prompt_tokens, self.text)
-             torch._dynamo.mark_dynamic(prompt_emb, 1)
-             mask = self.attn_mask[:, :, pos : pos + prompt_emb.size(1), :]
-             pos_ids = torch.arange(pos, pos + prompt_emb.size(1), dtype=torch.long, device=self.device)
-             hidden = self._prefill(prompt_emb, mask, pos_ids)
-             logits = lm_head(hidden, self.text)
-
-             if temperature == 0:
-                 next_token = torch.argmax(logits, dim=-1).unsqueeze(1)
-             else:
-                 probs = torch.softmax(logits / temperature, dim=-1)
-                 probs = self._apply_top_p(probs, top_p)
-                 next_token = torch.multinomial(probs, num_samples=1)
-
-             pos = pos + prompt_emb.size(1)
-             return logits, hidden, next_token, pos
-
-     def _generate_text(
-         self,
-         prompt_tokens: torch.Tensor,
-         pos: int,
-         settings: Optional[TextSamplingSettings] = None,
-     ):
-         max_tokens = (
-             settings.get("max_tokens", DEFAULT_MAX_TOKENS)
-             if settings
-             else DEFAULT_MAX_TOKENS
-         )
-         temperature = (
-             settings.get("temperature", DEFAULT_TEMPERATURE)
-             if settings
-             else DEFAULT_TEMPERATURE
-         )
-         top_p = settings.get("top_p", DEFAULT_TOP_P) if settings else DEFAULT_TOP_P
-
-         _, _, next_token, pos = self._prefill_prompt(
-             prompt_tokens, pos, temperature, top_p
-         )
-
-         def generator(next_token, pos):
-             mask = torch.zeros(1, 1, 2048, device=self.device, dtype=torch.bool)
-             mask[:, :, :pos] = 1
-             pos_ids = torch.tensor([pos], device=self.device, dtype=torch.long)
-             generated_tokens = 0
-
-             # For properly handling token streaming with Unicode
-             token_cache = []
-             print_len = 0
-
-             while (
-                 next_token_id := next_token.item()
-             ) != self.config.tokenizer.eos_id and generated_tokens < max_tokens:
-                 # Add token to our cache
-                 token_cache.append(next_token_id)
-
-                 # Decode all tokens collected so far
-                 text = self.tokenizer.decode(token_cache)
-
-                 # After a newline, we flush the cache completely
-                 if text.endswith("\n"):
-                     printable_text = text[print_len:]
-                     token_cache = []
-                     print_len = 0
-                     if printable_text:
-                         yield printable_text
-                 # If the last token is a CJK character, we can safely print it
-                 elif len(text) > 0 and _is_cjk_char(ord(text[-1])):
-                     printable_text = text[print_len:]
-                     print_len += len(printable_text)
-                     if printable_text:
-                         yield printable_text
-                 # Otherwise, only print up to the last space to avoid cutting words
-                 else:
-                     last_space_idx = text.rfind(" ", print_len)
-                     if last_space_idx >= print_len:
-                         printable_text = text[print_len : last_space_idx + 1]
-                         print_len += len(printable_text)
-                         if printable_text:
-                             yield printable_text
-
-                 with torch.inference_mode():
-                     next_emb = text_encoder(next_token, self.text)
-                     mask[:, :, pos], pos_ids[0] = 1, pos
-                     logits, _ = self._decode_one_tok(next_emb, mask, pos_ids)
-                     pos += 1
-
-                     if temperature == 0:
-                         next_token = torch.argmax(logits, dim=-1).unsqueeze(1)  # (1, 1)
-                     else:
-                         probs = torch.softmax(logits / temperature, dim=-1)  # (1, V)
-                         probs = self._apply_top_p(probs, top_p)
-                         next_token = torch.multinomial(probs, num_samples=1)  # (1, 1)
-
-                 generated_tokens += 1
-
-             # Flush any remaining text in the cache
-             if token_cache:
-                 text = self.tokenizer.decode(token_cache)
-                 printable_text = text[print_len:]
-                 if printable_text:
-                     yield printable_text
-
-         return generator(next_token, pos)
-
-     def load_encoded_image(self, encoded_image: EncodedImage):
-         for b, (k, v) in zip(self.text.blocks, encoded_image.caches):
-             b.kv_cache.k_cache[:, :, : k.size(2), :] = k
-             b.kv_cache.v_cache[:, :, : v.size(2), :] = v
-
-     def _generate_points(
-         self,
-         hidden: torch.Tensor,
-         next_token: torch.Tensor,
-         pos: int,
-         include_size: bool = True,
-         max_objects: int = DEFAULT_MAX_OBJECTS,
-     ):
-         out = []
-         mask = torch.zeros(1, 1, 2048, device=self.device, dtype=torch.bool)
-         mask[:, :, :pos] = 1
-         pos_ids = torch.tensor([pos], device=self.device, dtype=torch.long)
-
-         with torch.inference_mode():
-             while (
-                 next_token.item() != self.config.tokenizer.eos_id
-                 and len(out) < max_objects
-             ):
-                 x_logits = decode_coordinate(hidden, self.region)
-                 x_center = torch.argmax(x_logits, dim=-1) / x_logits.size(-1)
-                 next_emb = encode_coordinate(
-                     x_center.to(dtype=x_logits.dtype), self.region
-                 ).unsqueeze(0)
-
-                 # Decode y-coordinate
-                 mask[:, :, pos], pos_ids[0] = 1, pos
-                 _, hidden = self._decode_one_tok(next_emb, mask, pos_ids)
-                 pos += 1
-                 y_logits = decode_coordinate(hidden, self.region)
-                 y_center = torch.argmax(y_logits, dim=-1) / y_logits.size(-1)
-                 next_emb = encode_coordinate(
-                     y_center.to(dtype=y_logits.dtype), self.region
-                 ).unsqueeze(0)
-
-                 # Decode size
-                 if include_size:
-                     mask[:, :, pos], pos_ids[0] = 1, pos
-                     logits, hidden = self._decode_one_tok(next_emb, mask, pos_ids)
-                     pos += 1
-                     size_logits = decode_size(hidden, self.region)
-
-                     # Get bin indices from the logits
-                     w_bin = torch.argmax(size_logits[0], dim=-1)
-                     h_bin = torch.argmax(size_logits[1], dim=-1)
-
-                     # Convert from bin indices to actual size values using the inverse of the log-scale mapping
-                     # Formula: size = 2^((bin / 1023.0) * 10.0 - 10.0)
-                     w = torch.pow(2.0, (w_bin.float() / 1023.0) * 10.0 - 10.0)
-                     h = torch.pow(2.0, (h_bin.float() / 1023.0) * 10.0 - 10.0)
-
-                     next_emb = (
-                         encode_size(
-                             torch.tensor(
-                                 [w, h], device=self.device, dtype=size_logits.dtype
-                             ),
-                             self.region,
-                         )
-                         .unsqueeze(0)
-                         .unsqueeze(0)
-                     )
-
-                     # Add object
-                     out.append(
-                         {
-                             "x_min": x_center.item() - w.item() / 2,
-                             "y_min": y_center.item() - h.item() / 2,
-                             "x_max": x_center.item() + w.item() / 2,
-                             "y_max": y_center.item() + h.item() / 2,
-                         }
-                     )
-                 else:
-                     out.append({"x": x_center.item(), "y": y_center.item()})
-
-                 # Decode next token (x-coordinate, or eos)
-                 mask[:, :, pos], pos_ids[0] = 1, pos
-                 logits, hidden = self._decode_one_tok(next_emb, mask, pos_ids)
-                 pos += 1
-                 next_token = torch.argmax(logits, dim=-1)
-
-         return out
-
-     def point(
-         self,
-         image: Union[Image.Image, EncodedImage],
-         object: str,
-         settings: Optional[ObjectSamplingSettings] = None,
-     ):
-         if self.config.tokenizer.templates["point"] is None:
-             raise NotImplementedError("Model does not support pointing.")
-
-         image = self.encode_image(image)
-         self.load_encoded_image(image)
-
-         prompt_tokens = torch.tensor(
-             [
-                 self.config.tokenizer.templates["point"]["prefix"]
-                 + self.tokenizer.encode(" " + object).ids
-                 + self.config.tokenizer.templates["point"]["suffix"]
-             ],
-             device=self.device,
-         )
-
-         _, hidden, next_token, pos = self._prefill_prompt(
-             prompt_tokens, image.pos, temperature=0, top_p=0
-         )
-         hidden = hidden[:, -1:, :]
-
-         max_objects = (
-             settings.get("max_objects", DEFAULT_MAX_OBJECTS)
-             if settings
-             else DEFAULT_MAX_OBJECTS
-         )
-         objects = self._generate_points(
-             hidden, next_token, pos, include_size=False, max_objects=max_objects
-         )
-
-         return {"points": objects}
-
- def _is_cjk_char(cp):
-     """Checks whether CP is the codepoint of a CJK character."""
-     # This defines a "chinese character" as anything in the CJK Unicode block:
-     # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
-     if (
-         (cp >= 0x4E00 and cp <= 0x9FFF)
-         or (cp >= 0x3400 and cp <= 0x4DBF)
-         or (cp >= 0x2F800 and cp <= 0x2FA1F)
-     ):
-         return True
-     return False
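
A minimal end-to-end sketch of how the deleted modules fit together (the weights filename and image path are placeholders; tokenizer.json must sit next to moondream.py):

    from PIL import Image

    config = MoondreamConfig()
    model = MoondreamModel(config)                        # builds vision, text and region modules plus KV caches
    load_weights_into_model("model.safetensors", model)   # from weights.py below
    result = model.point(Image.open("example.png"), "face")
    print(result["points"])                               # list of {"x": ..., "y": ...} in normalized coordinates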
 
moondream2/region.py DELETED
@@ -1,89 +0,0 @@
- import torch
- import torch.nn as nn
- import math
-
- from .layers import linear, mlp
-
-
- def fourier_features(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
-     """
-     Applies Fourier feature mapping to input tensor x using frequency matrix w. This
-     projects inputs through sinusoidal functions to create higher dimensional features
-     that help mitigate spectral bias - the tendency of neural networks to learn
-     low-frequency functions more easily than high-frequency ones. By explicitly
-     mapping inputs to higher frequencies through sin/cos transformations, we enable
-     better learning of fine details and higher frequency patterns.
-
-     Args:
-         x: Input tensor to transform
-         w: Matrix of frequencies for the Fourier features transformation
-
-     Returns:
-         Concatenated cosine and sine transformed features as a tensor
-     """
-     f = 2 * math.pi * x @ w
-     return torch.cat([f.cos(), f.sin()], dim=-1)
-
-
- def encode_coordinate(coord: torch.Tensor, w: nn.Module) -> torch.Tensor:
-     """
-     Takes as input a tensor containing a single float coordinate value (x or y)
-     and encodes it into hidden states for input to the text model.
-
-     Args:
-         coord: Tensor with single float coordinate value
-
-     Returns:
-         Encoded hidden states tensor for input to text model
-     """
-     return linear(fourier_features(coord, w.coord_features), w.coord_encoder)
-
-
- def decode_coordinate(hidden_state: torch.Tensor, w: nn.Module) -> torch.Tensor:
-     """
-     Takes as input the last hidden state from the text model and outputs a single logit
-     representing either an x or y coordinate prediction.
-
-     Args:
-         hidden_state: The final hidden state tensor from the text model.
-
-     Returns:
-         A single logit representing the predicted coordinate value (x or y)
-     """
-     return mlp(hidden_state, w.coord_decoder)
-
-
- def encode_size(size: torch.Tensor, w: nn.Module) -> torch.Tensor:
-     """
-     Takes a tensor containing width and height values and encodes them into
-     hidden states for input to the text model.
-
-     Args:
-         size: Tensor with two floats for width and height
-
-     Returns:
-         Encoded hidden states tensor for input to text model
-     """
-     return linear(fourier_features(size, w.size_features), w.size_encoder)
-
-
- def decode_size(hidden_state: torch.Tensor, w: nn.Module) -> torch.Tensor:
-     """
-     Takes as input the last hidden state from the text model and outputs logits
-     for 1024 bins representing width and height in log-scale.
-
-     The bins are distributed according to the formula:
-     bin = (log2(size) + 10.0) / 10.0 * 1023.0
-     where size values are clamped to be at least 1/1024.
-
-     To convert from bin back to size:
-     size = 2^((bin / 1023.0) * 10.0 - 10.0)
-
-     Args:
-         hidden_state: The final hidden state tensor from the text model.
-
-     Returns:
-         A tensor containing logits for 1024 bins for width and height.
-         Shape is (2, 1024) where the first dimension corresponds to width and height.
-     """
-     return mlp(hidden_state, w.size_decoder).view(2, -1)
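
The log-scale size binning described in the decode_size docstring can be checked directly; a small worked example (the 0.25 width is chosen for illustration):

    import math

    size = 0.25                                            # a box width equal to 25% of the image
    bin_idx = (math.log2(size) + 10.0) / 10.0 * 1023.0     # -> 818.4
    recovered = 2 ** ((bin_idx / 1023.0) * 10.0 - 10.0)    # -> 0.25 again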
 
moondream2/rope.py DELETED
@@ -1,90 +0,0 @@
- # Ethically sourced from https://github.com/xjdr-alt/entropix
-
- import torch
- import time
-
- def precompute_freqs_cis(
-     dim: int,
-     end: int,
-     theta: float = 10000.0,
-     use_scaled: bool = False,
-     dtype: torch.dtype = torch.float32,
- ) -> torch.Tensor:
-     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=dtype)[: (dim // 2)] / dim))
-     t = torch.arange(end, dtype=dtype).unsqueeze(1)
-     freqs = t * freqs.unsqueeze(0)
-     freqs = torch.exp(1j * freqs)
-     return torch.stack([freqs.real, freqs.imag], dim=-1)
-
- def func1(x):
-     #print(x)
-     pass
-
- def func2(x):
-     #print(x)
-     pass
-
- def func3(x):
-     #print(x)
-     pass
-
- def func4(x):
-     #print(x)
-     pass
-
- def func5(x):
-     #print(x)
-     pass
-
- def func6(x):
-     #print(x)
-     pass
-
- def func7(x):
-     #print(x)
-     pass
-
- def func8(x):
-     #print(x)
-     pass
-
- def func9(x):
-     #print(x)
-     pass
-
- def func10(x):
-     #print(x)
-     pass
-
- def func11(x):
-     #print(x)
-     pass
-
- def apply_rotary_emb(
-     x: torch.Tensor,
-     freqs_cis: torch.Tensor,
-     position_ids: torch.Tensor,
-     num_heads: int,
-     rot_dim: int = 32,
-     interleave: bool = False,
- ) -> torch.Tensor:
-     assert rot_dim == freqs_cis.shape[-2] * 2
-     assert num_heads == x.shape[1]
-     x_rot, x_pass = x[..., :rot_dim], x[..., rot_dim:]
-
-     d_q = x_rot.shape[-1] // 2
-     xq_r, xq_i = x_rot[..., :d_q], x_rot[..., d_q:]
-
-     # Get the cosine component from freqs_cis
-     cos_component = freqs_cis[..., 0]
-     # Index with position_ids
-     cos_indexed = cos_component[position_ids, :]
-     # Add two dimensions at the beginning
-     freqs_cos = cos_indexed.unsqueeze(0).unsqueeze(0)
-     freqs_sin = freqs_cis[..., 1][position_ids, :].unsqueeze(0).unsqueeze(0)
-
-     # Complex multiplication: (a + bi) * (c + di) = (ac - bd) + (ad + bc)i
-     xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
-     xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
-     xq_out = torch.stack((xq_out_r, xq_out_i), dim=-1).flatten(-2)
-     return torch.cat([xq_out.to(x.dtype), x_pass], dim=-1)
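
A quick shape check for the two functions above, using the same arguments text.py passes (TextConfig: dim=2048, n_heads=32, max_context=2048):

    import torch

    freqs_cis = precompute_freqs_cis(2048 // (2 * 32), 2048)   # shape (2048, 16, 2)
    x = torch.randn(1, 32, 1, 64)                              # (batch, heads, seq, head_dim)
    pos = torch.tensor([5])
    y = apply_rotary_emb(x, freqs_cis, pos, num_heads=32)      # rotates dims 0..31, passes 32..63 through
    assert y.shape == x.shape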
 
moondream2/text.py DELETED
@@ -1,129 +0,0 @@
- import torch
- import torch.nn as nn
-
- from torch.nn import functional as F
-
- from .layers import layer_norm, mlp
- from .rope import apply_rotary_emb, precompute_freqs_cis
- from .config import TextConfig
-
-
- def text_encoder(input_ids: torch.Tensor, w: nn.Module):
-     return F.embedding(input_ids, w.wte)
-
-
- def attn(
-     x: torch.Tensor,
-     w: nn.Module,
-     freqs_cis: torch.Tensor,
-     kv_cache: nn.Module,
-     attn_mask: torch.Tensor,
-     n_heads: int,
-     position_ids: torch.Tensor,
-     do_apply_rotary_emb: bool = True,
- ):
-     bsz, q_len, d_model = x.shape
-     head_dim = d_model // n_heads
-
-     qkv_out = w.qkv(x)  # shape: (bsz, q_len, (n_heads * 3)*head_dim)
-
-     qkv_reshaped = qkv_out.view(bsz, q_len, 3, n_heads, head_dim)
-
-     # 2. Permute to bring heads before sequence length and QKV to the front
-     # Current: (bsz, q_len, 3, n_heads, head_dim) -> (0, 1, 2, 3, 4)
-     # Target: (3, bsz, n_heads, q_len, head_dim) -> (2, 0, 3, 1, 4)
-     qkv_permuted = qkv_reshaped.permute(2, 0, 3, 1, 4)
-
-     # 3. Unpack/Split along the first dimension (which now separates Q, K, V)
-     q, k, v = qkv_permuted[0], qkv_permuted[1], qkv_permuted[2]
-
-     q = apply_rotary_emb(q, freqs_cis, position_ids, n_heads)
-     k = apply_rotary_emb(k, freqs_cis, position_ids, n_heads)
-
-     if kv_cache is not None:
-         k, v = kv_cache.update(position_ids, k, v)
-
-     out = F.scaled_dot_product_attention(
-         q, k, v, attn_mask=attn_mask
-     )
-     out = out.transpose(1, 2).reshape(bsz, q_len, d_model)
-     out = w.proj(out)
-     return out
-
-
- def text_decoder(
-     x: torch.Tensor,
-     w: nn.Module,
-     attn_mask: torch.Tensor,
-     position_ids: torch.Tensor,
-     config: TextConfig,
- ):
-     for i, block in enumerate(w.blocks):
-         l_in = layer_norm(x, block.ln)
-         l_attn = attn(
-             l_in,
-             block.attn,
-             freqs_cis=w.freqs_cis,
-             kv_cache=block.kv_cache,
-             attn_mask=attn_mask,
-             n_heads=config.n_heads,
-             position_ids=position_ids,
-         )
-         l_mlp = mlp(l_in, block.mlp)
-         x = x + l_attn + l_mlp
-
-     return x
-
-
- def lm_head(hidden_BTC: torch.Tensor, w: nn.Module):
-     hidden_BC = hidden_BTC[:, -1, :]
-     hidden_BC = layer_norm(hidden_BC, w.post_ln)
-     logits = w.lm_head(hidden_BC)
-     return logits
-
-
- def build_text_model(config: TextConfig, dtype: torch.dtype) -> nn.Module:
-     qkv_dim = int(config.dim * 3)
-
-     text = nn.ModuleDict(
-         {
-             "blocks": nn.ModuleList(
-                 [
-                     nn.ModuleDict(
-                         {
-                             "ln": nn.LayerNorm(config.dim, dtype=dtype),
-                             "attn": nn.ModuleDict(
-                                 {
-                                     "qkv": nn.Linear(config.dim, qkv_dim, dtype=dtype),
-                                     "proj": nn.Linear(
-                                         config.dim, config.dim, dtype=dtype
-                                     ),
-                                 }
-                             ),
-                             "mlp": nn.ModuleDict(
-                                 {
-                                     "fc1": nn.Linear(
-                                         config.dim, config.ff_dim, dtype=dtype
-                                     ),
-                                     "fc2": nn.Linear(
-                                         config.ff_dim, config.dim, dtype=dtype
-                                     ),
-                                 }
-                             ),
-                         }
-                     )
-                     for _ in range(config.n_layers)
-                 ]
-             ),
-             "post_ln": nn.LayerNorm(config.dim, dtype=dtype),
-             "lm_head": nn.Linear(config.dim, config.vocab_size, dtype=dtype),
-         }
-     )
-     text.wte = nn.Parameter(torch.empty(config.vocab_size, config.dim, dtype=dtype))
-     text.register_buffer(
-         "freqs_cis",
-         precompute_freqs_cis(config.dim // (2 * config.n_heads), config.max_context),
-         persistent=False,
-     )
-
-     return text
 
moondream2/tokenizer.json DELETED
The diff for this file is too large to render. See raw diff
 
moondream2/utils.py DELETED
@@ -1,55 +0,0 @@
- import numpy as np
-
- import torch.cuda.nvtx
- from functools import wraps
-
- def nvtx_annotate(label=None, color='blue'):
-     def decorator(func):
-         @wraps(func)
-         def wrapper(*args, **kwargs):
-             range_name = label or func.__name__
-             torch.cuda.nvtx.range_push(range_name)
-             result = func(*args, **kwargs)
-             torch.cuda.nvtx.range_pop()
-             return result
-         return wrapper
-     return decorator
-
- def remove_outlier_points(points_tuples, k_nearest=2, threshold=2.0):
-     """
-     Robust outlier detection for list of (x,y) tuples.
-     Only requires numpy.
-
-     Args:
-         points_tuples: list of (x,y) tuples
-         k_nearest: number of neighbors to consider
-         threshold: multiplier for median distance
-
-     Returns:
-         list: filtered list of (x,y) tuples with outliers removed
-         list: list of booleans indicating which points were kept (True = kept)
-     """
-     points = np.array(points_tuples)
-     n_points = len(points)
-
-     # Calculate pairwise distances manually
-     dist_matrix = np.zeros((n_points, n_points))
-     for i in range(n_points):
-         for j in range(i + 1, n_points):
-             # Euclidean distance between points i and j
-             dist = np.sqrt(np.sum((points[i] - points[j]) ** 2))
-             dist_matrix[i, j] = dist
-             dist_matrix[j, i] = dist
-
-     # Get k nearest neighbors' distances
-     k = min(k_nearest, n_points - 1)
-     neighbor_distances = np.partition(dist_matrix, k, axis=1)[:, :k]
-     avg_neighbor_dist = np.mean(neighbor_distances, axis=1)
-
-     # Calculate mask using median distance
-     median_dist = np.median(avg_neighbor_dist)
-     mask = avg_neighbor_dist <= threshold * median_dist
-
-     # Return filtered tuples and mask
-     filtered_tuples = [t for t, m in zip(points_tuples, mask) if m]
-     return filtered_tuples
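
A small usage sketch for remove_outlier_points (the points are chosen so the last one is an obvious outlier):

    pts = [(0.10, 0.10), (0.12, 0.11), (0.11, 0.13), (0.90, 0.95)]
    kept = remove_outlier_points(pts, k_nearest=2, threshold=2.0)
    # kept -> the three clustered points; (0.90, 0.95) is dropped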
 
moondream2/vision.py DELETED
@@ -1,148 +0,0 @@
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- import numpy as np
-
- from typing import Union, Tuple
- from PIL import Image
-
- from .layers import attn, layer_norm, linear, mlp
- from .image_crops import overlap_crop_image
- from .config import VisionConfig
-
- if torch.backends.mps.is_available():
-     # Non-divisible input sizes are not implemented on MPS device yet.
-     # https://github.com/pytorch/pytorch/issues/96056
-     def adaptive_avg_pool2d(input, output_size):
-         return F.adaptive_avg_pool2d(input.to("cpu"), output_size).to("mps")
-
- else:
-     adaptive_avg_pool2d = F.adaptive_avg_pool2d
-
- DeviceLike = Union[str, torch.device, int]
-
- def prepare_crops(
-     image: Image.Image, config: VisionConfig, device: DeviceLike
- ) -> Tuple[torch.Tensor, Tuple[int, int]]:
-     np_image = np.array(image.convert("RGB"))
-     overlap_crops = overlap_crop_image(
-         np_image, max_crops=config.max_crops, overlap_margin=config.overlap_margin
-     )
-     all_crops = overlap_crops["crops"]
-     all_crops = np.transpose(all_crops, (0, 3, 1, 2))
-     all_crops = (
-         torch.from_numpy(all_crops)
-         .to(device=device, dtype=torch.float16)
-         .div_(255.0)
-         .sub_(0.5)
-         .div_(0.5)
-     )
-     return all_crops, overlap_crops["tiling"]
-
-
- def create_patches(x, patch_size):
-     # Original shape: [B, C, H, W]
-     B, C, H, W = x.shape
-     P1 = P2 = patch_size
-
-     # Step 1: Split H and W dimensions into patches
-     # [B, C, H/P1, P1, W/P2, P2]
-     x = x.reshape(B, C, H // P1, P1, W // P2, P2)
-
-     # Step 2: Rearrange dimensions to match target shape
-     # [B, H/P1, W/P2, C, P1, P2]
-     x = x.permute(0, 2, 4, 1, 3, 5)
-
-     # Step 3: Combine dimensions to get final shape
-     # [B, (H/P1)*(W/P2), C*P1*P2]
-     x = x.reshape(B, (H // P1) * (W // P2), C * P1 * P2)
-
-     return x
-
-
- def vision_encoder(input_BCHW: torch.Tensor, w: nn.Module, config: VisionConfig):
-     x = create_patches(input_BCHW, config.enc_patch_size)
-
-     x = linear(x, w.patch_emb)
-     x = x + w.pos_emb
-     for block in w.blocks:
-         x = x + attn(layer_norm(x, block.ln1), block.attn, n_heads=config.enc_n_heads)
-         x = x + mlp(layer_norm(x, block.ln2), block.mlp)
-     x = layer_norm(x, w.post_ln)
-
-     return x
-
-
- def vision_projection(
-     global_features: torch.Tensor,
-     reconstructed: torch.Tensor,
-     w: nn.Module,
-     config: VisionConfig,
- ):
-     reconstructed = reconstructed.permute(2, 0, 1)
-     reconstructed = adaptive_avg_pool2d(
-         reconstructed, output_size=(config.enc_n_layers, config.enc_n_layers)
-     )
-     reconstructed = reconstructed.permute(1, 2, 0).view(729, config.enc_dim)
-     final_features = torch.cat([global_features, reconstructed], dim=-1)
-     return mlp(final_features, w.proj_mlp)
-
-
- def build_vision_model(config: VisionConfig, dtype: torch.dtype):
-     patch_dim = config.enc_patch_size * config.enc_patch_size * config.in_channels
-     grid_size = config.crop_size // config.enc_patch_size
-     num_patches = grid_size * grid_size
-
-     vision = nn.ModuleDict(
-         {
-             "patch_emb": nn.Linear(patch_dim, config.enc_dim, dtype=dtype),
-             "blocks": nn.ModuleList(
-                 [
-                     nn.ModuleDict(
-                         {
-                             "ln1": nn.LayerNorm(config.enc_dim, dtype=dtype),
-                             "attn": nn.ModuleDict(
-                                 {
-                                     "qkv": nn.Linear(
-                                         config.enc_dim, 3 * config.enc_dim, dtype=dtype
-                                     ),
-                                     "proj": nn.Linear(
-                                         config.enc_dim, config.enc_dim, dtype=dtype
-                                     ),
-                                 }
-                             ),
-                             "ln2": nn.LayerNorm(config.enc_dim, dtype=dtype),
-                             "mlp": nn.ModuleDict(
-                                 {
-                                     "fc1": nn.Linear(
-                                         config.enc_dim, config.enc_ff_dim, dtype=dtype
-                                     ),
-                                     "fc2": nn.Linear(
-                                         config.enc_ff_dim, config.enc_dim, dtype=dtype
-                                     ),
-                                 }
-                             ),
-                         }
-                     )
-                     for _ in range(config.enc_n_layers)
-                 ]
-             ),
-             "post_ln": nn.LayerNorm(config.enc_dim, dtype=dtype),
-             "proj_mlp": nn.ModuleDict(
-                 {
-                     "fc1": nn.Linear(
-                         config.enc_dim * 2, config.proj_inner_dim, dtype=dtype
-                     ),
-                     "fc2": nn.Linear(
-                         config.proj_inner_dim, config.proj_out_dim, dtype=dtype
-                     ),
-                 }
-             ),
-         }
-     )
-     vision.pos_emb = nn.Parameter(
-         torch.zeros(1, num_patches, config.enc_dim, dtype=dtype)
-     )
-     return vision
-
-
 
moondream2/weights.py DELETED
@@ -1,292 +0,0 @@
- import safetensors
- import torch
- import torch.nn as nn
-
- from contextlib import contextmanager
- from dataclasses import dataclass
- from typing import Callable, List
-
- from .layers import AttentionWeights, LayerNormWeights, LinearWeights, MLPWeights
-
-
- @dataclass
- class VisionBlock:
-     ln1: LayerNormWeights
-     attn: AttentionWeights
-     ln2: LayerNormWeights
-     mlp: MLPWeights
-
-
- @dataclass
- class VisionModel:
-     patch_emb: LinearWeights
-     pos_emb: torch.Tensor
-     blocks: List[VisionBlock]
-     post_ln: LayerNormWeights
-     proj_mlp: MLPWeights
-
-
- @dataclass
- class TextBlock:
-     ln: LayerNormWeights
-     attn: AttentionWeights
-     mlp: MLPWeights
-
-
- @dataclass
- class TextModel:
-     wte: torch.Tensor
-     blocks: List[TextBlock]
-     post_ln: LayerNormWeights
-     lm_head: LinearWeights
-
-
- @dataclass
- class RegionModel:
-     coord_features: torch.Tensor
-     coord_encoder: LinearWeights
-     coord_decoder: MLPWeights
-     size_features: torch.Tensor
-     size_encoder: LinearWeights
-     size_decoder: MLPWeights
-
-
- @dataclass
- class MoondreamModel:
-     vision: VisionModel
-     text: TextModel
-     region: RegionModel
-
-
- @contextmanager
- def safetensors_open(safetensors_file: str):
-     """
-     Simplify interfacing with safetensors files. Eliminates the need to ignore
-     type errors when using the `safe_open` function.
-     """
-     with safetensors.safe_open(
-         safetensors_file, framework="pt"
-     ) as st:  # pyright: ignore
-
-         def get_tensor(name: str) -> torch.Tensor:
-             return st.get_tensor(name)
-
-         def get_keys() -> List[str]:
-             return st.keys()
-
-         get_tensor.keys = get_keys
-
-         yield get_tensor
-
-
- def _load_weights(get_tensor: Callable[[str], torch.Tensor], model: nn.Module) -> None:
-     """Internal function to load weights using a tensor getter function."""
-     model = model.to(dtype=torch.float16)
-
-     # Vision Model
-     model.vision["patch_emb"].weight.data.copy_(
-         get_tensor("vision_encoder.encoder.model.visual.patch_embed.linear.weight")
-     )
-     model.vision["patch_emb"].bias.data.copy_(
-         get_tensor("vision_encoder.encoder.model.visual.patch_embed.linear.bias")
-     )
-     model.vision.pos_emb.data.copy_(
-         get_tensor("vision_encoder.encoder.model.visual.pos_embed")
-     )
-
-     for i in range(len(model.vision["blocks"])):
-         prefix = f"vision_encoder.encoder.model.visual.blocks.{i}"
-
-         # Layer norms
-         model.vision["blocks"][i]["ln1"].weight.data.copy_(
-             get_tensor(f"{prefix}.norm1.weight")
-         )
-         model.vision["blocks"][i]["ln1"].bias.data.copy_(
-             get_tensor(f"{prefix}.norm1.bias")
-         )
-         model.vision["blocks"][i]["ln2"].weight.data.copy_(
-             get_tensor(f"{prefix}.norm2.weight")
-         )
-         model.vision["blocks"][i]["ln2"].bias.data.copy_(
-             get_tensor(f"{prefix}.norm2.bias")
-         )
-
-         # Attention
-         model.vision["blocks"][i]["attn"]["qkv"].weight.data.copy_(
-             get_tensor(f"{prefix}.attn.qkv.weight")
-         )
-         model.vision["blocks"][i]["attn"]["qkv"].bias.data.copy_(
-             get_tensor(f"{prefix}.attn.qkv.bias")
-         )
-         model.vision["blocks"][i]["attn"]["proj"].weight.data.copy_(
-             get_tensor(f"{prefix}.attn.proj.weight")
-         )
-         model.vision["blocks"][i]["attn"]["proj"].bias.data.copy_(
-             get_tensor(f"{prefix}.attn.proj.bias")
-         )
-
-         # MLP
-         model.vision["blocks"][i]["mlp"]["fc1"].weight.data.copy_(
-             get_tensor(f"{prefix}.mlp.fc1.weight")
-         )
-         model.vision["blocks"][i]["mlp"]["fc1"].bias.data.copy_(
-             get_tensor(f"{prefix}.mlp.fc1.bias")
-         )
-         model.vision["blocks"][i]["mlp"]["fc2"].weight.data.copy_(
-             get_tensor(f"{prefix}.mlp.fc2.weight")
-         )
-         model.vision["blocks"][i]["mlp"]["fc2"].bias.data.copy_(
-             get_tensor(f"{prefix}.mlp.fc2.bias")
-         )
-
-     model.vision["post_ln"].weight.data.copy_(
-         get_tensor("vision_encoder.encoder.model.visual.norm.weight")
-     )
-     model.vision["post_ln"].bias.data.copy_(
-         get_tensor("vision_encoder.encoder.model.visual.norm.bias")
-     )
-
-     model.vision["proj_mlp"]["fc1"].weight.data.copy_(
-         get_tensor("vision_encoder.projection.mlp.fc1.weight")
-     )
-     model.vision["proj_mlp"]["fc1"].bias.data.copy_(
-         get_tensor("vision_encoder.projection.mlp.fc1.bias")
-     )
-     model.vision["proj_mlp"]["fc2"].weight.data.copy_(
-         get_tensor("vision_encoder.projection.mlp.fc2.weight")
-     )
-     model.vision["proj_mlp"]["fc2"].bias.data.copy_(
-         get_tensor("vision_encoder.projection.mlp.fc2.bias")
-     )
-
-     # Text Model
-     model.text.wte.data.copy_(get_tensor("text_model.transformer.embd.wte.weight"))
-
-     for i in range(len(model.text["blocks"])):
-         prefix = f"text_model.transformer.h.{i}"
-
-         # Layer norm
-         model.text["blocks"][i]["ln"].weight.data.copy_(
-             get_tensor(f"{prefix}.ln.weight")
-         )
-         model.text["blocks"][i]["ln"].bias.data.copy_(get_tensor(f"{prefix}.ln.bias"))
-
-         # Attention
-         model.text["blocks"][i]["attn"]["qkv"].weight.data.copy_(
-             get_tensor(f"{prefix}.mixer.Wqkv.weight")
-         )
-         model.text["blocks"][i]["attn"]["qkv"].bias.data.copy_(
-             get_tensor(f"{prefix}.mixer.Wqkv.bias")
-         )
-         model.text["blocks"][i]["attn"]["proj"].weight.data.copy_(
-             get_tensor(f"{prefix}.mixer.out_proj.weight")
-         )
-         model.text["blocks"][i]["attn"]["proj"].bias.data.copy_(
-             get_tensor(f"{prefix}.mixer.out_proj.bias")
-         )
-
-         # MLP
-         model.text["blocks"][i]["mlp"]["fc1"].weight.data.copy_(
-             get_tensor(f"{prefix}.mlp.fc1.weight")
-         )
-         model.text["blocks"][i]["mlp"]["fc1"].bias.data.copy_(
-             get_tensor(f"{prefix}.mlp.fc1.bias")
-         )
-         model.text["blocks"][i]["mlp"]["fc2"].weight.data.copy_(
-             get_tensor(f"{prefix}.mlp.fc2.weight")
-         )
-         model.text["blocks"][i]["mlp"]["fc2"].bias.data.copy_(
-             get_tensor(f"{prefix}.mlp.fc2.bias")
-         )
-
-     model.text["post_ln"].weight.data.copy_(get_tensor("text_model.lm_head.ln.weight"))
-     model.text["post_ln"].bias.data.copy_(get_tensor("text_model.lm_head.ln.bias"))
-
-     model.text["lm_head"].weight.data.copy_(
-         get_tensor("text_model.lm_head.linear.weight")
-     )
-     model.text["lm_head"].bias.data.copy_(get_tensor("text_model.lm_head.linear.bias"))
-
-     # Region Model
-     model.region.coord_features.data.copy_(
-         get_tensor("region_model.coordinate_features.weight").T
-     )
-     model.region["coord_encoder"].weight.data.copy_(
-         get_tensor("region_model.coordinate_encoder.weight")
-     )
-     model.region["coord_encoder"].bias.data.copy_(
-         get_tensor("region_model.coordinate_encoder.bias")
-     )
-
-     model.region["coord_decoder"]["fc1"].weight.data.copy_(
-         get_tensor("region_model.coordinate_decoder.fc1.weight")
-     )
-     model.region["coord_decoder"]["fc1"].bias.data.copy_(
-         get_tensor("region_model.coordinate_decoder.fc1.bias")
-     )
-     model.region["coord_decoder"]["fc2"].weight.data.copy_(
-         get_tensor("region_model.coordinate_decoder.fc2.weight")
-     )
-     model.region["coord_decoder"]["fc2"].bias.data.copy_(
-         get_tensor("region_model.coordinate_decoder.fc2.bias")
-     )
-
-     model.region.size_features.data.copy_(
-         get_tensor("region_model.size_features.weight").T
-     )
-     model.region["size_encoder"].weight.data.copy_(
-         get_tensor("region_model.size_encoder.weight")
-     )
-     model.region["size_encoder"].bias.data.copy_(
-         get_tensor("region_model.size_encoder.bias")
-     )
-
-     model.region["size_decoder"]["fc1"].weight.data.copy_(
-         get_tensor("region_model.size_decoder.fc1.weight")
-     )
-     model.region["size_decoder"]["fc1"].bias.data.copy_(
-         get_tensor("region_model.size_decoder.fc1.bias")
-     )
-     model.region["size_decoder"]["fc2"].weight.data.copy_(
-         get_tensor("region_model.size_decoder.fc2.weight")
-     )
-     model.region["size_decoder"]["fc2"].bias.data.copy_(
-         get_tensor("region_model.size_decoder.fc2.bias")
-     )
-
-
- def load_weights_from_safetensors(weights_file: str, model: nn.Module) -> None:
-     """Load weights from a safetensors file into a MoondreamModel instance."""
-     with safetensors_open(weights_file) as get_tensor:
-         # Wrap the get_tensor function to handle key normalization
-         name_map = {k.replace("._orig_mod", ""): k for k in get_tensor.keys()}
-         _load_weights(lambda x: get_tensor(name_map[x]).to(dtype=torch.float16), model)
-
-
- def load_weights_from_pt(weights_file: str, model: nn.Module) -> None:
-     """Load weights from a PyTorch file into a MoondreamModel instance."""
-     device = str(torch.empty(0).device)
-     tensors = torch.load(weights_file, map_location=device, weights_only=True)
-     tensors = {
-         k.replace("._orig_mod", ""): v.to(dtype=torch.float16)
-         for k, v in tensors.items()
-     }
-     _load_weights(lambda x: tensors[x], model)
-
-
- def load_weights_into_model(weights_file: str, model: nn.Module) -> None:
-     """
-     Load weights from either a safetensors or PyTorch file directly into a MoondreamModel instance.
-
-     Args:
-         weights_file: Path to weights file (either .safetensors or .pt)
-         model: MoondreamModel instance to load weights into
-     """
-     if weights_file.endswith(".safetensors"):
-         load_weights_from_safetensors(weights_file, model)
287
- else:
288
- load_weights_from_pt(weights_file, model)
289
-
290
- # Make all parameters contiguous
291
- for param in model.parameters():
292
- param.data = param.data.contiguous()
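
For context: the loader deleted above exposes load_weights_into_model(weights_file, model), which picks the safetensors or torch.load path from the file extension, casts the tensors to float16, and leaves every parameter contiguous. A minimal usage sketch, assuming the file lived at moondream2/weights.py (the module path is not visible in this hunk; the class and function names are taken from this commit):

    import torch

    from moondream2.config import MoondreamConfig
    from moondream2.moondream import MoondreamModel
    from moondream2.weights import load_weights_into_model  # assumed module path

    # Build the model skeleton, then copy the released weights into it.
    config = MoondreamConfig()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = MoondreamModel(config, setup_caches=False).to(device)
    load_weights_into_model("moondream2/model.safetensors", model)
    model._setup_caches()

This reaches roughly the same end state that notes.ipynb and test.py below reach by stripping the "model." prefix from the state dict and calling load_state_dict directly.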
 
notes.ipynb DELETED
@@ -1,393 +0,0 @@
- {
- "cells": [
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "True"
- ]
- },
- "execution_count": 1,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "import torch\n",
- "torch.cuda.is_available()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "False"
- ]
- },
- "execution_count": 4,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "from transformers.utils import is_flash_attn_2_available\n",
- "is_flash_attn_2_available()\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {},
- "outputs": [],
- "source": [
- "import torch\n",
- "from moondream2.config import MoondreamConfig\n",
- "from moondream2.moondream import MoondreamModel\n",
- "import torch.profiler\n",
- "\n",
- "config = MoondreamConfig()\n",
- "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
- "model = MoondreamModel(config, setup_caches=False).to(device)\n",
- "from safetensors.torch import load_file\n",
- "weights_path = \"moondream2/model.safetensors\" # Path to your local weights file\n",
- "state_dict = load_file(weights_path, device=device)\n",
- "new_state_dict = {}\n",
- "for key, value in state_dict.items():\n",
- " # Remove 'model.' prefix if it exists\n",
- " if key.startswith('model.'):\n",
- " new_key = key[6:] # Skip the first 6 characters ('model.')\n",
- " else:\n",
- " new_key = key\n",
- " new_state_dict[new_key] = value\n",
- "state_dict = new_state_dict\n",
- "missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=True)\n",
- "model._setup_caches()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "metadata": {},
- "outputs": [],
- "source": [
- "from PIL import Image\n",
- "image = Image.open(\"example.png\")\n",
- "query = \"home icon at the bottom of the screen is visible\"\n",
- "points = model.point(image, query)[\"points\"]\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "[{'x': 0.0849609375, 'y': 0.9453125}]"
- ]
- },
- "execution_count": 2,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "\n",
- "from PIL import Image\n",
- "image = Image.open(\"example.png\")\n",
- "query = \"home icon at the bottom\"\n",
- "with torch.profiler.profile(\n",
- " activities=[\n",
- " torch.profiler.ProfilerActivity.CPU,\n",
- " torch.profiler.ProfilerActivity.CUDA],\n",
- " schedule=torch.profiler.schedule(\n",
- " wait=1, warmup=1, active=3),\n",
- " on_trace_ready=torch.profiler.tensorboard_trace_handler('log'),\n",
- " record_shapes=True,\n",
- " with_stack=True\n",
- ") as prof:\n",
- " for i in range(3):\n",
- " points = model.point(image, query)[\"points\"]\n",
- " prof.step()\n",
- "points"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "metadata": {},
- "outputs": [],
- "source": [
- "points = model.point(image, query)[\"points\"]\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- },
- {
- "cell_type": "code",
- "execution_count": 17,
- "metadata": {},
- "outputs": [
- {
- "ename": "IndexError",
- "evalue": "list index out of range",
- "output_type": "error",
- "traceback": [
- "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
- "\u001b[31mIndexError\u001b[39m Traceback (most recent call last)",
- "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[17]\u001b[39m\u001b[32m, line 29\u001b[39m\n\u001b[32m 27\u001b[39m plt.axis(\u001b[33m'\u001b[39m\u001b[33moff\u001b[39m\u001b[33m'\u001b[39m)\n\u001b[32m 28\u001b[39m plt.show()\n\u001b[32m---> \u001b[39m\u001b[32m29\u001b[39m show_point_on_image(\u001b[43mpoints\u001b[49m\u001b[43m[\u001b[49m\u001b[32;43m0\u001b[39;49m\u001b[43m]\u001b[49m[\u001b[33m'\u001b[39m\u001b[33mx\u001b[39m\u001b[33m'\u001b[39m], points[\u001b[32m0\u001b[39m][\u001b[33m'\u001b[39m\u001b[33my\u001b[39m\u001b[33m'\u001b[39m])\n",
- "\u001b[31mIndexError\u001b[39m: list index out of range"
- ]
- }
- ],
- "source": [
- "from PIL import Image\n",
- "import matplotlib.pyplot as plt\n",
- "\n",
- "def show_point_on_image(rel_x, rel_y, point_color='red', point_size=30):\n",
- " \"\"\"\n",
- " Display 'example.png' with a point marked at the given relative coordinates.\n",
- "\n",
- " Parameters:\n",
- " - rel_x: float, relative x-coordinate (0 to 1).\n",
- " - rel_y: float, relative y-coordinate (0 to 1).\n",
- " - point_color: str, color of the point (default: 'red').\n",
- " - point_size: int, size of the point (default: 50).\n",
- " \"\"\"\n",
- " image_path = 'example.png' # Hardcoded image path\n",
- "\n",
- " # Load image\n",
- " img = Image.open(image_path)\n",
- " width, height = img.size\n",
- "\n",
- " # Convert relative coordinates to absolute\n",
- " abs_x = rel_x * width\n",
- " abs_y = rel_y * height\n",
- "\n",
- " # Plot\n",
- " plt.imshow(img)\n",
- " plt.scatter([abs_x], [abs_y], color=point_color, s=point_size, alpha=0.7)\n",
- " plt.axis('off')\n",
- " plt.show()\n",
- "show_point_on_image(points[0]['x'], points[0]['y'])\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 14,
- "metadata": {},
- "outputs": [],
- "source": [
- "from PIL import Image\n",
- "image = Image.open(\"example.png\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 15,
- "metadata": {},
- "outputs": [],
- "source": [
- "import time"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 16,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Time taken: 0.5298349857330322 seconds\n"
- ]
- }
- ],
- "source": [
- "start_time = time.time()\n",
- "\n",
- "query = \"the login button\"\n",
- "points = model.point(image, query)[\"points\"]\n",
- "end_time = time.time()\n",
- "print(f\"Time taken: {end_time - start_time} seconds\")\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 9,
- "metadata": {},
- "outputs": [],
- "source": []
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "False"
- ]
- },
- "execution_count": 5,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "from transformers.utils import is_flash_attn_2_available\n",
- "is_flash_attn_2_available()\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {},
- "outputs": [],
- "source": [
- "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
- "\n",
- "model = AutoModelForCausalLM.from_pretrained(\n",
- " \"vikhyatk/moondream2\",\n",
- " revision=\"2025-04-14\",\n",
- " trust_remote_code=True,\n",
- " cache_dir=\"moondream\",\n",
- " # Uncomment to run on GPU.\n",
- " device_map={\"\": \"cuda\"}\n",
- ")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "HfMoondream(\n",
- " (model): MoondreamModel(\n",
- " (vision): ModuleDict(\n",
- " (patch_emb): Linear(in_features=588, out_features=1152, bias=True)\n",
- " (blocks): ModuleList(\n",
- " (0-26): 27 x ModuleDict(\n",
- " (ln1): LayerNorm((1152,), eps=1e-05, elementwise_affine=True)\n",
- " (attn): ModuleDict(\n",
- " (qkv): Linear(in_features=1152, out_features=3456, bias=True)\n",
- " (proj): Linear(in_features=1152, out_features=1152, bias=True)\n",
- " )\n",
- " (ln2): LayerNorm((1152,), eps=1e-05, elementwise_affine=True)\n",
- " (mlp): ModuleDict(\n",
- " (fc1): Linear(in_features=1152, out_features=4304, bias=True)\n",
- " (fc2): Linear(in_features=4304, out_features=1152, bias=True)\n",
- " )\n",
- " )\n",
- " )\n",
- " (post_ln): LayerNorm((1152,), eps=1e-05, elementwise_affine=True)\n",
- " (proj_mlp): ModuleDict(\n",
- " (fc1): Linear(in_features=2304, out_features=8192, bias=True)\n",
- " (fc2): Linear(in_features=8192, out_features=2048, bias=True)\n",
- " )\n",
- " )\n",
- " (text): ModuleDict(\n",
- " (blocks): ModuleList(\n",
- " (0-23): 24 x ModuleDict(\n",
- " (ln): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)\n",
- " (attn): ModuleDict(\n",
- " (qkv): Linear(in_features=2048, out_features=6144, bias=True)\n",
- " (proj): Linear(in_features=2048, out_features=2048, bias=True)\n",
- " )\n",
- " (mlp): ModuleDict(\n",
- " (fc1): Linear(in_features=2048, out_features=8192, bias=True)\n",
- " (fc2): Linear(in_features=8192, out_features=2048, bias=True)\n",
- " )\n",
- " )\n",
- " )\n",
- " (post_ln): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)\n",
- " (lm_head): Linear(in_features=2048, out_features=51200, bias=True)\n",
- " )\n",
- " (region): ModuleDict(\n",
- " (coord_encoder): Linear(in_features=256, out_features=2048, bias=True)\n",
- " (coord_decoder): ModuleDict(\n",
- " (fc1): Linear(in_features=2048, out_features=8192, bias=True)\n",
- " (fc2): Linear(in_features=8192, out_features=1024, bias=True)\n",
- " )\n",
- " (size_encoder): Linear(in_features=512, out_features=2048, bias=True)\n",
- " (size_decoder): ModuleDict(\n",
- " (fc1): Linear(in_features=2048, out_features=8192, bias=True)\n",
- " (fc2): Linear(in_features=8192, out_features=2048, bias=True)\n",
- " )\n",
- " )\n",
- " )\n",
- ")"
- ]
- },
- "execution_count": 7,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "model.point()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "venv",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.12.9"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
- }
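
A note on the IndexError captured in the deleted notes.ipynb above: model.point() returned an empty "points" list for that query, so indexing points[0] raised before anything was drawn. A small guarded variant of the same call, using only names that appear in the notebook:

    # Only draw the marker when the pointing query actually returned a hit.
    points = model.point(image, query)["points"]
    if points:
        show_point_on_image(points[0]["x"], points[0]["y"])
    else:
        print(f"No point found for query: {query!r}")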
 
ollama.ipynb DELETED
@@ -1,195 +0,0 @@
- {
- "cells": [
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {},
- "outputs": [],
- "source": [
- "from transformers import AutoConfig\n",
- "\n",
- "config = AutoConfig.from_pretrained(\"vikhyatk/moondream2\", trust_remote_code=True)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "{'return_dict': True,\n",
- " 'output_hidden_states': False,\n",
- " 'output_attentions': False,\n",
- " 'torchscript': False,\n",
- " 'torch_dtype': 'float16',\n",
- " 'use_bfloat16': False,\n",
- " 'tf_legacy_loss': False,\n",
- " 'pruned_heads': {},\n",
- " 'tie_word_embeddings': True,\n",
- " 'chunk_size_feed_forward': 0,\n",
- " 'is_encoder_decoder': False,\n",
- " 'is_decoder': False,\n",
- " 'cross_attention_hidden_size': None,\n",
- " 'add_cross_attention': False,\n",
- " 'tie_encoder_decoder': False,\n",
- " 'max_length': 20,\n",
- " 'min_length': 0,\n",
- " 'do_sample': False,\n",
- " 'early_stopping': False,\n",
- " 'num_beams': 1,\n",
- " 'num_beam_groups': 1,\n",
- " 'diversity_penalty': 0.0,\n",
- " 'temperature': 1.0,\n",
- " 'top_k': 50,\n",
- " 'top_p': 1.0,\n",
- " 'typical_p': 1.0,\n",
- " 'repetition_penalty': 1.0,\n",
- " 'length_penalty': 1.0,\n",
- " 'no_repeat_ngram_size': 0,\n",
- " 'encoder_no_repeat_ngram_size': 0,\n",
- " 'bad_words_ids': None,\n",
- " 'num_return_sequences': 1,\n",
- " 'output_scores': False,\n",
- " 'return_dict_in_generate': False,\n",
- " 'forced_bos_token_id': None,\n",
- " 'forced_eos_token_id': None,\n",
- " 'remove_invalid_values': False,\n",
- " 'exponential_decay_length_penalty': None,\n",
- " 'suppress_tokens': None,\n",
- " 'begin_suppress_tokens': None,\n",
- " 'architectures': ['HfMoondream'],\n",
- " 'finetuning_task': None,\n",
- " 'id2label': {0: 'LABEL_0', 1: 'LABEL_1'},\n",
- " 'label2id': {'LABEL_0': 0, 'LABEL_1': 1},\n",
- " 'tokenizer_class': None,\n",
- " 'prefix': None,\n",
- " 'bos_token_id': None,\n",
- " 'pad_token_id': None,\n",
- " 'eos_token_id': None,\n",
- " 'sep_token_id': None,\n",
- " 'decoder_start_token_id': None,\n",
- " 'task_specific_params': None,\n",
- " 'problem_type': None,\n",
- " '_name_or_path': 'vikhyatk/moondream2',\n",
- " '_attn_implementation_autoset': False,\n",
- " 'transformers_version': '4.49.0',\n",
- " 'auto_map': {'AutoConfig': 'vikhyatk/moondream2--hf_moondream.HfConfig',\n",
- " 'AutoModelForCausalLM': 'vikhyatk/moondream2--hf_moondream.HfMoondream'},\n",
- " 'config': {},\n",
- " 'model_type': 'moondream1'}"
- ]
- },
- "execution_count": 2,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "config.to_dict()\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {},
- "outputs": [],
- "source": [
- "from moondream2.hf_moondream import HfConfig\n",
- "# serialize config.json to data variable\n",
- "import json\n",
- "with open(\"moondream2/config.json\", \"r\") as f:\n",
- " data = json.load(f)\n",
- "configo = HfConfig(**data)\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "_name_or_path : vikhyatk/moondream2 != \n",
- "auto_map : {'AutoConfig': 'vikhyatk/moondream2--hf_moondream.HfConfig', 'AutoModelForCausalLM': 'vikhyatk/moondream2--hf_moondream.HfMoondream'} != {'AutoConfig': 'hf_moondream.HfConfig', 'AutoModelForCausalLM': 'hf_moondream.HfMoondream'}\n"
- ]
- }
- ],
- "source": [
- "# compare dicts of config.to_dict() and configo.to_dict() \n",
- "\n",
- "# make it check values of the dicts\n",
- "\n",
- "for key, value in config.to_dict().items():\n",
- " if key not in configo.to_dict():\n",
- " print(key + \" : \" + str(value) + \" not in hf_config\")\n",
- " elif value != configo.to_dict()[key]:\n",
- " print(key+ \" : \"+str(value)+\" != \"+str(configo.to_dict()[key]))\n",
- "\n",
- "for key, value in configo.to_dict().items():\n",
- " if key not in config.to_dict():\n",
- " print(key + \" : \" + str(value) + \" not in from_pretrained\")\n",
- "\n",
- "hparams = config.to_dict()\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 24,
- "metadata": {},
- "outputs": [],
- "source": [
- "text_config = hparams.get(\"text_config\", {})"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 26,
- "metadata": {},
- "outputs": [],
- "source": [
- "text_config.get(\"architectures\")\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 27,
- "metadata": {},
- "outputs": [],
- "source": [
- "hparams.get(\"vision_config\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "venv",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.12.4"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
- }
 
test.py DELETED
@@ -1,41 +0,0 @@
- print("app Started")
- import torch
- from moondream2.config import MoondreamConfig
- from moondream2.moondream import MoondreamModel
- import torch.profiler
-
- config = MoondreamConfig()
- device = "cuda"
- model = MoondreamModel(config, setup_caches=False).to(device)
- from safetensors.torch import load_file
- weights_path = "moondream2/model.safetensors" # Path to your local weights file
- state_dict = load_file(weights_path, device=device)
- new_state_dict = {}
- for key, value in state_dict.items():
-     # Remove 'model.' prefix if it exists
-     if key.startswith('model.'):
-         new_key = key[6:] # Skip the first 6 characters ('model.')
-     else:
-         new_key = key
-     new_state_dict[new_key] = value
- state_dict = new_state_dict
- missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=True)
- model._setup_caches()
-
-
-
- from PIL import Image
- image = Image.open("example.png")
- query = "home icon at the bottom"
- warmup_iters = 2
-
-
- for i in range(3):
-     if i == warmup_iters: torch.cuda.cudart().cudaProfilerStart()
-     if i >= warmup_iters: torch.cuda.nvtx.range_push("iteration{}".format(i))
-     if i >= warmup_iters: torch.cuda.nvtx.range_push("forward")
-     points = model.point(image, query)["points"]
-     if i >= warmup_iters: torch.cuda.nvtx.range_pop()
-     if i >= warmup_iters: torch.cuda.nvtx.range_pop()
-
- torch.cuda.cudart().cudaProfilerStop()
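
The NVTX ranges and cudaProfilerStart/Stop calls in the deleted test.py only take effect when the script is run under an external profiler; with Nsight Systems that is typically an invocation along the lines of nsys profile --capture-range=cudaProfilerApi python test.py, though the exact command used is not recorded in this repository.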