Update CXR_LLAVA_HF.py
Browse files- CXR_LLAVA_HF.py +2 -8
CXR_LLAVA_HF.py
CHANGED
|
@@ -615,11 +615,8 @@ class CXRLLAVAModel(PreTrainedModel):
|
|
| 615 |
images = self.vision_tower.image_processor(image, return_tensors='pt')['pixel_values']
|
| 616 |
images = images.to(self.device)
|
| 617 |
input_ids = self.tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0)
|
| 618 |
-
|
| 619 |
-
|
| 620 |
-
print('using cuda')
|
| 621 |
-
else:
|
| 622 |
-
print(f'using device {self.device}')
|
| 623 |
stopping_criteria = KeywordsStoppingCriteria(["</s>"], self.tokenizer, input_ids)
|
| 624 |
|
| 625 |
image_args = {"images": images}
|
|
@@ -642,11 +639,8 @@ class CXRLLAVAModel(PreTrainedModel):
|
|
| 642 |
))
|
| 643 |
thread.start()
|
| 644 |
generated_text = ""
|
| 645 |
-
text_len = 0
|
| 646 |
for new_text in streamer:
|
| 647 |
generated_text += new_text
|
| 648 |
-
text_len += 1
|
| 649 |
-
if text_len > 200: break
|
| 650 |
|
| 651 |
return generated_text
|
| 652 |
|
|
|
|
| 615 |
images = self.vision_tower.image_processor(image, return_tensors='pt')['pixel_values']
|
| 616 |
images = images.to(self.device)
|
| 617 |
input_ids = self.tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0)
|
| 618 |
+
input_ids = input_ids.to(self.device)
|
| 619 |
+
# print(f'using device {self.device}')
|
|
|
|
|
|
|
|
|
|
| 620 |
stopping_criteria = KeywordsStoppingCriteria(["</s>"], self.tokenizer, input_ids)
|
| 621 |
|
| 622 |
image_args = {"images": images}
|
|
|
|
| 639 |
))
|
| 640 |
thread.start()
|
| 641 |
generated_text = ""
|
|
|
|
| 642 |
for new_text in streamer:
|
| 643 |
generated_text += new_text
|
|
|
|
|
|
|
| 644 |
|
| 645 |
return generated_text
|
| 646 |
|