ThomasSimonini committed
Commit c22676d · verified · Parent: e406d5b

Upload moondream.py

Files changed (1): moondream.py (+0, -188)
moondream.py CHANGED

@@ -1,4 +1,3 @@
-"""
 import torch
 from .vision_encoder import VisionEncoder
 from .configuration_moondream import MoondreamConfig
@@ -178,190 +177,3 @@ class Moondream(PreTrainedModel):
             x.strip()
             for x in tokenizer.batch_decode(output_ids, skip_special_tokens=True)
         ]
-"""
-
-import torch
-from .vision_encoder import VisionEncoder
-from .configuration_moondream import MoondreamConfig
-from transformers import PreTrainedModel, TextIteratorStreamer
-
-from .modeling_phi import PhiForCausalLM
-from .configuration_moondream import PhiConfig
-
-class Moondream(PreTrainedModel):
-    config_class = MoondreamConfig
-    _supports_flash_attn_2 = True
-
-    def __init__(self, config):
-        super().__init__(config)
-        self.vision_encoder = VisionEncoder(
-            use_flash_attn=config._attn_implementation == "flash_attention_2"
-        )
-
-        if type(config.text_config) == dict:
-            phi_config = PhiConfig(
-                **config.text_config, attn_implementation=config._attn_implementation
-            )
-        else:
-            phi_config = config.text_config
-        self.text_model = PhiForCausalLM(phi_config)
-
-    @property
-    def device(self):
-        return self.text_model.device
-
-    def encode_image(self, image):
-        with torch.no_grad():
-            return self.vision_encoder(image)
-
-    def input_embeds(self, prompt, image_embeds, tokenizer):
-        def _tokenize(txt):
-            return tokenizer(
-                txt, return_tensors="pt", add_special_tokens=False
-            ).input_ids.to(self.device)
-
-        text_emb = self.text_model.get_input_embeddings()
-
-        # Add BOS token
-        embeds = []
-        embeds.append(
-            text_emb((torch.tensor([[tokenizer.bos_token_id]], device=self.device)))
-        )
-
-        if "<image>" not in prompt:
-            embeds.append(text_emb(_tokenize(prompt)))
-        else:
-            assert prompt.count("<image>") == 1
-            before, after = prompt.split("<image>")
-            if len(before) > 0:
-                embeds.append(text_emb(_tokenize(before)))
-            embeds.append(image_embeds.to(self.device))
-            if len(after) > 0:
-                embeds.append(text_emb(_tokenize(after)))
-
-        return torch.cat(embeds, dim=1)
-
-    def get_input_embeddings(self):
-        return self.text_model.get_input_embeddings()
-
-    def generate(
-        self,
-        image_embeds,
-        prompt,
-        tokenizer,
-        max_new_tokens=128,
-        **kwargs,
-    ):
-        generate_config = {
-            "eos_token_id": tokenizer.eos_token_id,
-            "bos_token_id": tokenizer.bos_token_id,
-            "pad_token_id": tokenizer.bos_token_id,
-            "max_new_tokens": max_new_tokens,
-            **kwargs,
-        }
-
-        with torch.no_grad():
-            inputs_embeds = self.input_embeds(prompt, image_embeds, tokenizer)
-            streamer = TextIteratorStreamer(tokenizer) #, timeout=10.0, skip_prompt=True, skip_special_tokens=True
-            output_ids = self.text_model.generate(
-                inputs_embeds=inputs_embeds, streamer=streamer, **generate_config
-            )
-
-        model_output = ""
-        for new_text in streamer:
-            model_output += new_text
-            print("NEWTEXT" + new_text)
-            yield new_text
-
-        return tokenizer.batch_decode(output_ids, skip_special_tokens=True)
-
-    def answer_question(
-        self,
-        image_embeds,
-        question,
-        tokenizer,
-        chat_history="",
-        result_queue=None,
-        **kwargs,
-    ):
-        prompt = f"<image>\n\n{chat_history}Question: {question}\n\nAnswer:"
-        answer = self.generate(
-            image_embeds,
-            prompt,
-            tokenizer=tokenizer,
-            max_new_tokens=512,
-            **kwargs,
-        )[0]
-        cleaned_answer = answer.strip()
-
-        # Use the result_queue to pass the result if it is provided
-        if result_queue:
-            result_queue.put(cleaned_answer)
-        else:
-            return cleaned_answer
-
-    def batch_answer(
-        self,
-        images,
-        prompts,
-        tokenizer,
-        **kwargs,
-    ):
-        image_embeds = self.encode_image(images)
-
-        templated_prompts = [
-            f"<image>\n\nQuestion: {prompt}\n\nAnswer:" for prompt in prompts
-        ]
-        prompt_embs = [
-            self.input_embeds(prompt, image_embed.unsqueeze(0), tokenizer)[0]
-            for prompt, image_embed in zip(templated_prompts, image_embeds)
-        ]
-
-        bos_emb = prompt_embs[0][0]
-        max_len = max([p.shape[0] for p in prompt_embs])
-
-        inputs_embeds = torch.cat(
-            [
-                torch.cat([bos_emb.repeat(max_len - p.shape[0], 1), p]).unsqueeze(0)
-                for p in prompt_embs
-            ],
-            dim=0,
-        )
-        attention_mask = torch.cat(
-            [
-                torch.cat(
-                    [
-                        torch.zeros(
-                            1,
-                            max_len - p.shape[0],
-                            device=self.device,
-                            dtype=torch.long,
-                        ),
-                        torch.ones(1, p.shape[0], device=self.device, dtype=torch.long),
-                    ],
-                    dim=1,
-                )
-                for p in prompt_embs
-            ],
-            dim=0,
-        )
-
-        generate_config = {
-            "eos_token_id": tokenizer.eos_token_id,
-            "bos_token_id": tokenizer.bos_token_id,
-            "pad_token_id": tokenizer.bos_token_id,
-            "max_new_tokens": 512,
-            **kwargs,
-        }
-
-        with torch.no_grad():
-            output_ids = self.text_model.generate(
-                inputs_embeds=inputs_embeds,
-                attention_mask=attention_mask,
-                **generate_config,
-            )
-
-        return [
-            x.strip()
-            for x in tokenizer.batch_decode(output_ids, skip_special_tokens=True)
-        ]
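
For context: the commit removes a duplicate of the class (the first copy was wrapped in a module-level triple-quoted string; dropping both quotes and the second copy leaves one live copy with the same API as the removed lines above). Assuming the surviving copy matches the duplicate shown in the diff, a minimal usage sketch follows. The repo id and image path here are hypothetical placeholders, not taken from this commit; substitute the repo this moondream.py is actually uploaded to.

from transformers import AutoModelForCausalLM, AutoTokenizer
from PIL import Image

repo_id = "vikhyatk/moondream1"  # hypothetical: use the repo hosting this file
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(repo_id)

image = Image.open("example.jpg")  # hypothetical local image
image_embeds = model.encode_image(image)  # VisionEncoder forward pass under torch.no_grad

# In this revision generate() contains a yield, so calling it returns a
# generator that produces the decoded text chunks collected by the
# TextIteratorStreamer after text_model.generate() completes.
prompt = "<image>\n\nQuestion: What is shown in this image?\n\nAnswer:"
for chunk in model.generate(image_embeds, prompt, tokenizer):
    print(chunk, end="", flush=True)

One caveat visible in the diff: because generate() is now a generator, answer_question() still subscripts its result with [0], which would raise a TypeError on a generator object; iterating the stream, as above, matches the revised behavior.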
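The batch_answer method in the diff left-pads each variable-length prompt-embedding sequence with repeated BOS embeddings and builds a matching attention mask, so every sequence ends at the same position when right-growing decoding begins. Below is a self-contained sketch of just that padding scheme on toy tensors; the shapes are hypothetical, and torch.stack replaces the equivalent torch.cat/unsqueeze(0) pattern used in the diff.

import torch

# Toy stand-ins for per-prompt embedding tensors of shape (seq_len, hidden_dim).
prompt_embs = [torch.randn(5, 8), torch.randn(3, 8), torch.randn(7, 8)]

bos_emb = prompt_embs[0][0]  # every prompt starts with the BOS embedding
max_len = max(p.shape[0] for p in prompt_embs)

# Left-pad shorter sequences by repeating the BOS embedding in front.
inputs_embeds = torch.stack(
    [torch.cat([bos_emb.repeat(max_len - p.shape[0], 1), p]) for p in prompt_embs]
)

# 0 over the padding positions, 1 over the real tokens, mirroring the layout.
attention_mask = torch.stack(
    [
        torch.cat(
            [
                torch.zeros(max_len - p.shape[0], dtype=torch.long),
                torch.ones(p.shape[0], dtype=torch.long),
            ]
        )
        for p in prompt_embs
    ]
)

print(inputs_embeds.shape)  # torch.Size([3, 7, 8])
print(attention_mask[1])    # tensor([0, 0, 0, 0, 1, 1, 1])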