Update README.md
Browse files
README.md
CHANGED
|
@@ -136,14 +136,15 @@ def get_response(image_list, prompt: str, model=None, image_processor=None, in_c
         ],
         return_tensors="pt",
     )
-
+    bad_words_id = tokenizer(["User:", "GPT1:", "GFT:", "GPT:"], add_special_tokens=False).input_ids
     generated_text = model.generate(
         vision_x=vision_x.to(model.device),
         lang_x=lang_x["input_ids"].to(model.device),
         attention_mask=lang_x["attention_mask"].to(model.device),
         max_new_tokens=512,
-
-
+        num_beams=3,
+        no_repeat_ngram_size=3,
+        bad_words_ids=bad_words_id,
     )
     parsed_output = (
         model.text_tokenizer.decode(generated_text[0])