Spaces:
Running
on
Zero
Running
on
Zero
Update raw.py
Browse files
raw.py
CHANGED
|
@@ -105,34 +105,27 @@ def caption(input_image: Image.Image, prompt: str, temperature: float, top_p: fl
|
|
| 105 |
# WARNING: HF's handling of chats on Llava models is very fragile. This specific combination of processor.apply_chat_template() and processor() works
|
| 106 |
# but if using other combinations, always inspect the final input_ids to ensure they are correct. Oftentimes you will end up with multiple <bos> tokens
|
| 107 |
# if not careful, which can make the model perform poorly.
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
t = Thread(target=model.generate, kwargs=generate_kwargs)
|
| 130 |
-
t.start()
|
| 131 |
-
|
| 132 |
-
outputs = []
|
| 133 |
-
for text in streamer:
|
| 134 |
-
outputs.append(text)
|
| 135 |
-
yield "".join(outputs)
|
| 136 |
|
| 137 |
@spaces.GPU()
|
| 138 |
@torch.no_grad()
|
|
|
|
| 105 |
# WARNING: HF's handling of chats on Llava models is very fragile. This specific combination of processor.apply_chat_template() and processor() works
|
| 106 |
# but if using other combinations, always inspect the final input_ids to ensure they are correct. Oftentimes you will end up with multiple <bos> tokens
|
| 107 |
# if not careful, which can make the model perform poorly.
|
| 108 |
+
convo_string = cap_processor.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
|
| 109 |
+
assert isinstance(convo_string, str)
|
| 110 |
+
inputs = cap_processor(text=[convo_string], images=[input_image], return_tensors="pt").to('cuda')
|
| 111 |
+
inputs['pixel_values'] = inputs['pixel_values'].to(torch.bfloat16)
|
| 112 |
+
streamer = TextIteratorStreamer(cap_processor.tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
|
| 113 |
+
generate_kwargs = dict(
|
| 114 |
+
**inputs,
|
| 115 |
+
max_new_tokens=max_new_tokens,
|
| 116 |
+
do_sample=True if temperature > 0 else False,
|
| 117 |
+
suppress_tokens=None,
|
| 118 |
+
use_cache=True,
|
| 119 |
+
temperature=temperature if temperature > 0 else None,
|
| 120 |
+
top_k=None,
|
| 121 |
+
top_p=top_p if temperature > 0 else None,
|
| 122 |
+
streamer=streamer,
|
| 123 |
+
)
|
| 124 |
+
_ = cap_model.generate(**generate_kwargs)
|
| 125 |
+
outputs = []
|
| 126 |
+
for text in streamer:
|
| 127 |
+
outputs.append(text)
|
| 128 |
+
yield "".join(outputs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
|
| 130 |
@spaces.GPU()
|
| 131 |
@torch.no_grad()
|