ThomasSimonini committed on
Commit
5f8bba8
·
verified ·
1 Parent(s): 523dbb8

Upload moondream.py

Browse files
Files changed (1) hide show
  1. moondream.py +2 -25
moondream.py CHANGED
@@ -84,10 +84,10 @@ class Moondream(PreTrainedModel):
84
  output_ids = self.text_model.generate(
85
  inputs_embeds=inputs_embeds, streamer=streamer, **generate_config
86
  )
 
87
 
88
  return tokenizer.batch_decode(output_ids, skip_special_tokens=True)
89
 
90
- """
91
  def answer_question(
92
  self,
93
  image_embeds,
@@ -112,29 +112,6 @@ class Moondream(PreTrainedModel):
112
  result_queue.put(cleaned_answer)
113
  else:
114
  return cleaned_answer
115
- """
116
- async def answer_question(
117
- self,
118
- image_embeds,
119
- question,
120
- tokenizer,
121
- chat_history="",
122
- result_queue=None,
123
- **kwargs,
124
- ):
125
- prompt = f"<image>\n\n{chat_history}Question: {question}\n\nAnswer:"
126
- streamer = TextStreamer(tokenizer)
127
-
128
-
129
- output_ids = self.text_model.generate(
130
- inputs_embeds=self.input_embeds(prompt, image_embeds, tokenizer),
131
- attention_mask=torch.ones(self.inputs_embeds.shape, dtype=torch.long, device=self.device), #maybe remove
132
- streamer=streamer,
133
- **kwargs,
134
- )
135
-
136
- for output_id in output_ids:
137
- yield tokenizer.decode(output_id, skip_special_tokens=True)
138
 
139
  def batch_answer(
140
  self,
@@ -200,4 +177,4 @@ class Moondream(PreTrainedModel):
200
  return [
201
  x.strip()
202
  for x in tokenizer.batch_decode(output_ids, skip_special_tokens=True)
203
- ]
 
84
  output_ids = self.text_model.generate(
85
  inputs_embeds=inputs_embeds, streamer=streamer, **generate_config
86
  )
87
+ print("OUTPUTIDS" + output_ids)
88
 
89
  return tokenizer.batch_decode(output_ids, skip_special_tokens=True)
90
 
 
91
  def answer_question(
92
  self,
93
  image_embeds,
 
112
  result_queue.put(cleaned_answer)
113
  else:
114
  return cleaned_answer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
 
116
  def batch_answer(
117
  self,
 
177
  return [
178
  x.strip()
179
  for x in tokenizer.batch_decode(output_ids, skip_special_tokens=True)
180
+ ]