Update CXR_LLAVA_HF.py
Browse files- CXR_LLAVA_HF.py +5 -1
CXR_LLAVA_HF.py
CHANGED
|
@@ -10,6 +10,7 @@ from threading import Thread
|
|
| 10 |
from dataclasses import dataclass
|
| 11 |
import numpy as np
|
| 12 |
from PIL import Image
|
|
|
|
| 13 |
# Model Constants
|
| 14 |
IGNORE_INDEX = -100
|
| 15 |
IMAGE_TOKEN_INDEX = -200
|
|
@@ -597,7 +598,7 @@ class CXRLLAVAModel(PreTrainedModel):
|
|
| 597 |
|
| 598 |
def generate_cxr_repsonse(self, chat, image, temperature=0.2, top_p=0.8):
|
| 599 |
with torch.no_grad():
|
| 600 |
-
streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=
|
| 601 |
|
| 602 |
if np.array(image).max()>255:
|
| 603 |
raise Exception("16-bit image is not supported.")
|
|
@@ -636,8 +637,11 @@ class CXRLLAVAModel(PreTrainedModel):
|
|
| 636 |
))
|
| 637 |
thread.start()
|
| 638 |
generated_text = ""
|
|
|
|
| 639 |
for new_text in streamer:
|
| 640 |
generated_text += new_text
|
|
|
|
|
|
|
| 641 |
|
| 642 |
return generated_text
|
| 643 |
|
|
|
|
| 10 |
from dataclasses import dataclass
|
| 11 |
import numpy as np
|
| 12 |
from PIL import Image
|
| 13 |
+
|
| 14 |
# Model Constants
|
| 15 |
IGNORE_INDEX = -100
|
| 16 |
IMAGE_TOKEN_INDEX = -200
|
|
|
|
| 598 |
|
| 599 |
def generate_cxr_repsonse(self, chat, image, temperature=0.2, top_p=0.8):
|
| 600 |
with torch.no_grad():
|
| 601 |
+
streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=180)
|
| 602 |
|
| 603 |
if np.array(image).max()>255:
|
| 604 |
raise Exception("16-bit image is not supported.")
|
|
|
|
| 637 |
))
|
| 638 |
thread.start()
|
| 639 |
generated_text = ""
|
| 640 |
+
text_len = 0
|
| 641 |
for new_text in streamer:
|
| 642 |
generated_text += new_text
|
| 643 |
+
text_len += 1
|
| 644 |
+
if text_len > 200: break
|
| 645 |
|
| 646 |
return generated_text
|
| 647 |
|