Update app.py: import EncoderDecoderCache and pass past_key_values=None to model.generate
Browse files
app.py
CHANGED
|
@@ -20,6 +20,7 @@ from transformers import (
|
|
| 20 |
AutoModelForVision2Seq,
|
| 21 |
AutoProcessor,
|
| 22 |
TextIteratorStreamer,
|
|
|
|
| 23 |
)
|
| 24 |
from transformers.image_utils import load_image
|
| 25 |
|
|
@@ -137,6 +138,8 @@ def model_chat(prompt, image):
|
|
| 137 |
add_special_tokens=False,
|
| 138 |
return_tensors="pt"
|
| 139 |
).to(device)
|
|
|
|
|
|
|
| 140 |
outputs = model.generate(
|
| 141 |
pixel_values=pixel_values,
|
| 142 |
decoder_input_ids=prompt_inputs.input_ids,
|
|
@@ -150,7 +153,8 @@ def model_chat(prompt, image):
|
|
| 150 |
return_dict_in_generate=True,
|
| 151 |
do_sample=False,
|
| 152 |
num_beams=1,
|
| 153 |
-
repetition_penalty=1.1
|
|
|
|
| 154 |
)
|
| 155 |
sequence = processor.tokenizer.batch_decode(outputs.sequences, skip_special_tokens=False)[0]
|
| 156 |
cleaned = sequence.replace(f"<s>{prompt} <Answer/>", "").replace("<pad>", "").replace("</s>", "").strip()
|
|
|
|
| 20 |
AutoModelForVision2Seq,
|
| 21 |
AutoProcessor,
|
| 22 |
TextIteratorStreamer,
|
| 23 |
+
EncoderDecoderCache # Added to handle the new caching mechanism
|
| 24 |
)
|
| 25 |
from transformers.image_utils import load_image
|
| 26 |
|
|
|
|
| 138 |
add_special_tokens=False,
|
| 139 |
return_tensors="pt"
|
| 140 |
).to(device)
|
| 141 |
+
|
| 142 |
+
# Explicitly pass past_key_values=None to align with the new caching mechanism and avoid the deprecated tuple-format warning
|
| 143 |
outputs = model.generate(
|
| 144 |
pixel_values=pixel_values,
|
| 145 |
decoder_input_ids=prompt_inputs.input_ids,
|
|
|
|
| 153 |
return_dict_in_generate=True,
|
| 154 |
do_sample=False,
|
| 155 |
num_beams=1,
|
| 156 |
+
repetition_penalty=1.1,
|
| 157 |
+
past_key_values=None # Added to prevent deprecated tuple handling
|
| 158 |
)
|
| 159 |
sequence = processor.tokenizer.batch_decode(outputs.sequences, skip_special_tokens=False)[0]
|
| 160 |
cleaned = sequence.replace(f"<s>{prompt} <Answer/>", "").replace("<pad>", "").replace("</s>", "").strip()
|