Subh775 committed on
Commit
7ae323f
·
verified ·
1 Parent(s): a513bbf

Add moondream.py for self-contained custom code

Browse files
Files changed (1) hide show
  1. moondream.py +230 -0
moondream.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from typing import List, Union, Literal, Optional
4
+ from transformers import PreTrainedModel
5
+ from PIL import Image
6
+
7
+ from .configuration_moondream import PhiConfig
8
+ from .configuration_moondream import MoondreamConfig
9
+ from .vision_encoder import VisionEncoder
10
+ from .region_model import RegionModel
11
+ from .modeling_phi import PhiForCausalLM
12
+
13
class Moondream(PreTrainedModel):
    """Vision-language model that feeds image embeddings into a Phi causal LM.

    The model owns three sub-modules: a ``VisionEncoder`` that turns images
    into embeddings, a ``RegionModel`` (constructed here but not otherwise
    used in this class), and a ``PhiForCausalLM`` text model that performs
    generation from a sequence of input embeddings.
    """

    config_class = MoondreamConfig
    _supports_flash_attn_2 = True

    def __init__(self, config):
        """Build the vision encoder, region model and text model from ``config``.

        ``config.text_config`` may be either a plain mapping of ``PhiConfig``
        keyword arguments or an already constructed config object.
        """
        super().__init__(config)
        self.vision_encoder = VisionEncoder(
            use_flash_attn=config._attn_implementation == "flash_attention_2"
        )
        self.region_model = RegionModel()

        # isinstance (rather than an exact type comparison) also accepts
        # dict subclasses, which would otherwise be passed on unconverted.
        if isinstance(config.text_config, dict):
            phi_config = PhiConfig(
                **config.text_config, attn_implementation=config._attn_implementation
            )
        else:
            phi_config = config.text_config
        self.text_model = PhiForCausalLM(phi_config)

    @property
    def device(self):
        """Device of the underlying text model."""
        return self.text_model.device

    @staticmethod
    def _generation_kwargs(tokenizer, defaults, overrides):
        """Assemble the keyword arguments passed to ``text_model.generate``.

        The BOS token id doubles as the padding id. ``defaults`` are the
        per-method settings and ``overrides`` are caller-supplied keyword
        arguments, which take precedence over everything else.
        """
        return {
            "eos_token_id": tokenizer.eos_token_id,
            "bos_token_id": tokenizer.bos_token_id,
            "pad_token_id": tokenizer.bos_token_id,
            **defaults,
            **overrides,
        }

    def _full_attention_mask(self, inputs_embeds):
        """Return an all-ones integer mask covering the first two dims of ``inputs_embeds``."""
        # torch.long matches the mask dtype built explicitly in batch_answer.
        return torch.ones(
            inputs_embeds.shape[:2], dtype=torch.long, device=self.device
        )

    def encode_image(self, image):
        """Run the vision encoder on ``image`` with gradients disabled."""
        with torch.no_grad():
            return self.vision_encoder(image)

    def input_embeds(self, prompt, image_embeds, tokenizer):
        """Build the embedding sequence for ``prompt``, splicing in the image.

        The result is the BOS embedding followed by the embedded prompt text.
        If ``prompt`` contains the ``<image>`` placeholder, ``image_embeds``
        is inserted at that position; otherwise ``image_embeds`` is ignored.
        All pieces are concatenated along dimension 1, so ``image_embeds``
        must be compatible with the text embeddings on the other dimensions
        (presumably ``(1, tokens, dim)`` -- confirm against the vision encoder).

        Raises:
            ValueError: if ``prompt`` contains more than one ``<image>``.
        """

        def _tokenize(txt):
            return tokenizer(
                txt, return_tensors="pt", add_special_tokens=False
            ).input_ids.to(self.device)

        text_emb = self.text_model.get_input_embeddings()

        # Add BOS token
        bos_ids = torch.tensor([[tokenizer.bos_token_id]], device=self.device)
        embeds = [text_emb(bos_ids)]

        if "<image>" not in prompt:
            embeds.append(text_emb(_tokenize(prompt)))
        else:
            # Explicit check instead of ``assert`` so it survives ``python -O``.
            if prompt.count("<image>") != 1:
                raise ValueError(
                    "prompt must contain the <image> placeholder at most once"
                )
            before, after = prompt.split("<image>")
            if before:
                embeds.append(text_emb(_tokenize(before)))
            embeds.append(image_embeds.to(self.device))
            if after:
                embeds.append(text_emb(_tokenize(after)))

        return torch.cat(embeds, dim=1)

    def get_input_embeddings(self):
        """Return the text model's input embedding module."""
        return self.text_model.get_input_embeddings()

    def generate(
        self,
        image_embeds,
        prompt,
        tokenizer,
        max_new_tokens=128,
        **kwargs,
    ):
        """Generate text for ``prompt`` conditioned on ``image_embeds``.

        Extra keyword arguments are forwarded to ``text_model.generate`` and
        override the default token-id settings. Returns the output of
        ``tokenizer.batch_decode`` with special tokens skipped.
        """
        generate_config = self._generation_kwargs(
            tokenizer, {"max_new_tokens": max_new_tokens}, kwargs
        )

        with torch.no_grad():
            inputs_embeds = self.input_embeds(prompt, image_embeds, tokenizer)
            attention_mask = self._full_attention_mask(inputs_embeds)
            output_ids = self.text_model.generate(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                **generate_config,
            )

        return tokenizer.batch_decode(output_ids, skip_special_tokens=True)

    # Note: Not ready for use yet, intended for September release.
    def caption(
        self,
        images: List[Image.Image],
        tokenizer,
        length: Optional[Literal["short"]] = None,
        **kwargs,
    ):
        """Caption each image in ``images`` and return the stripped strings.

        ``length="short"`` switches the prompt from ``Caption:`` to
        ``Short caption:``. Defaults to a repetition penalty of 1.2 and at
        most 512 new tokens; both can be overridden through ``kwargs``.
        An empty ``images`` sequence yields an empty list.
        """
        if len(images) == 0:
            return []

        image_embeds = self.encode_image(images)

        templated_prompts = [
            f"<image>\n\n{'Short caption' if length == 'short' else 'Caption'}:"
            for _ in images
        ]
        # Every prompt is identical, so the per-image sequences are stacked
        # directly without padding.
        inputs_embeds = torch.stack(
            [
                self.input_embeds(prompt, image_embed.unsqueeze(0), tokenizer)[0]
                for prompt, image_embed in zip(templated_prompts, image_embeds)
            ]
        )
        attention_mask = self._full_attention_mask(inputs_embeds)

        generate_config = self._generation_kwargs(
            tokenizer,
            {"repetition_penalty": 1.2, "max_new_tokens": 512},
            kwargs,
        )

        with torch.no_grad():
            output_ids = self.text_model.generate(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                **generate_config,
            )

        return [
            x.strip()
            for x in tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        ]

    def answer_question(
        self,
        image_embeds,
        question,
        tokenizer,
        chat_history="",
        result_queue=None,
        max_new_tokens=256,
        **kwargs,
    ):
        """Answer ``question`` about the image behind ``image_embeds``.

        ``chat_history`` is inserted verbatim before the question. When
        ``result_queue`` is given, the stripped answer is handed to its
        ``put`` method and ``None`` is returned; otherwise the answer itself
        is returned.
        """
        prompt = f"<image>\n\n{chat_history}Question: {question}\n\nAnswer:"
        answer = self.generate(
            image_embeds,
            prompt,
            tokenizer=tokenizer,
            max_new_tokens=max_new_tokens,
            **kwargs,
        )[0]
        cleaned_answer = answer.strip()

        # Use the result_queue to pass the result if it is provided. Compare
        # against None explicitly: a queue-like object may be falsy (for
        # example when it defines __len__ and is empty) and must still
        # receive the answer.
        if result_queue is not None:
            result_queue.put(cleaned_answer)
            return None
        return cleaned_answer

    def batch_answer(
        self,
        images,
        prompts,
        tokenizer,
        **kwargs,
    ):
        """Answer one question per image in a single batched generation.

        Images and prompts are paired positionally. Shorter sequences are
        left-padded with the BOS embedding and masked out so that all rows
        share one length. Defaults to at most 512 new tokens, overridable
        through ``kwargs``. Returns the stripped decoded answers; empty
        ``prompts`` yields an empty list.
        """
        if len(prompts) == 0:
            return []

        image_embeds = self.encode_image(images)

        templated_prompts = [
            f"<image>\n\nQuestion: {prompt}\n\nAnswer:" for prompt in prompts
        ]
        prompt_embs = [
            self.input_embeds(prompt, image_embed.unsqueeze(0), tokenizer)[0]
            for prompt, image_embed in zip(templated_prompts, image_embeds)
        ]

        # Every sequence starts with the BOS embedding; reuse it as padding.
        bos_emb = prompt_embs[0][0]
        max_len = max(p.shape[0] for p in prompt_embs)
        pad_lens = [max_len - p.shape[0] for p in prompt_embs]

        inputs_embeds = torch.stack(
            [
                torch.cat([bos_emb.repeat(pad, 1), p])
                for pad, p in zip(pad_lens, prompt_embs)
            ]
        )
        # Zeros over the left padding, ones over the real tokens.
        attention_mask = torch.stack(
            [
                torch.cat(
                    [
                        torch.zeros(pad, device=self.device, dtype=torch.long),
                        torch.ones(p.shape[0], device=self.device, dtype=torch.long),
                    ]
                )
                for pad, p in zip(pad_lens, prompt_embs)
            ]
        )

        generate_config = self._generation_kwargs(
            tokenizer, {"max_new_tokens": 512}, kwargs
        )

        with torch.no_grad():
            output_ids = self.text_model.generate(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                **generate_config,
            )

        return [
            x.strip()
            for x in tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        ]

    def detect(self, image: Image.Image, query: str, tokenizer):
        """Placeholder: currently does nothing and returns ``None``."""
        pass