Spaces:
Running on Zero
Running on Zero
Enable ChatInterface stop button
Browse files
app.py
CHANGED
|
@@ -5,7 +5,7 @@ from threading import Thread
|
|
| 5 |
import gradio as gr
|
| 6 |
import spaces
|
| 7 |
import torch
|
| 8 |
-
from transformers import AutoModelForMultimodalLM, AutoProcessor, BatchFeature
|
| 9 |
from transformers.generation.streamers import TextIteratorStreamer
|
| 10 |
|
| 11 |
MODEL_ID = "google/gemma-4-e4b-it"
|
|
@@ -91,6 +91,14 @@ def process_history(history: list[dict]) -> list[dict]:
|
|
| 91 |
return messages
|
| 92 |
|
| 93 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
@spaces.GPU(duration=120)
|
| 95 |
@torch.inference_mode()
|
| 96 |
def _generate_on_gpu(inputs: BatchFeature, max_new_tokens: int, thinking: bool) -> Iterator[str]:
|
|
@@ -102,9 +110,11 @@ def _generate_on_gpu(inputs: BatchFeature, max_new_tokens: int, thinking: bool)
|
|
| 102 |
skip_prompt=True,
|
| 103 |
skip_special_tokens=not thinking,
|
| 104 |
)
|
|
|
|
| 105 |
generate_kwargs = {
|
| 106 |
**inputs,
|
| 107 |
"streamer": streamer,
|
|
|
|
| 108 |
"max_new_tokens": max_new_tokens,
|
| 109 |
"disable_compile": True,
|
| 110 |
}
|
|
@@ -121,13 +131,20 @@ def _generate_on_gpu(inputs: BatchFeature, max_new_tokens: int, thinking: bool)
|
|
| 121 |
thread.start()
|
| 122 |
|
| 123 |
chunks: list[str] = []
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 131 |
|
| 132 |
thread.join()
|
| 133 |
if exception_holder:
|
|
@@ -292,6 +309,7 @@ demo = gr.ChatInterface(
|
|
| 292 |
file_types=[*IMAGE_FILE_TYPES, *AUDIO_FILE_TYPES, *VIDEO_FILE_TYPES],
|
| 293 |
file_count="multiple",
|
| 294 |
autofocus=True,
|
|
|
|
| 295 |
),
|
| 296 |
multimodal=True,
|
| 297 |
additional_inputs=[
|
|
@@ -306,7 +324,6 @@ demo = gr.ChatInterface(
|
|
| 306 |
gr.Textbox(label="System Prompt", value=""),
|
| 307 |
],
|
| 308 |
additional_inputs_accordion=gr.Accordion("Settings", open=True),
|
| 309 |
-
stop_btn=False,
|
| 310 |
title="Gemma 4 E4B It",
|
| 311 |
examples=examples,
|
| 312 |
run_examples_on_click=False,
|
|
|
|
| 5 |
import gradio as gr
|
| 6 |
import spaces
|
| 7 |
import torch
|
| 8 |
+
from transformers import AutoModelForMultimodalLM, AutoProcessor, BatchFeature, StoppingCriteria
|
| 9 |
from transformers.generation.streamers import TextIteratorStreamer
|
| 10 |
|
| 11 |
MODEL_ID = "google/gemma-4-e4b-it"
|
|
|
|
| 91 |
return messages
|
| 92 |
|
| 93 |
|
| 94 |
+
class StopOnSignal(StoppingCriteria):
|
| 95 |
+
def __init__(self) -> None:
|
| 96 |
+
self.stopped = False
|
| 97 |
+
|
| 98 |
+
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor, **kwargs: object) -> bool: # noqa: ARG002
|
| 99 |
+
return self.stopped
|
| 100 |
+
|
| 101 |
+
|
| 102 |
@spaces.GPU(duration=120)
|
| 103 |
@torch.inference_mode()
|
| 104 |
def _generate_on_gpu(inputs: BatchFeature, max_new_tokens: int, thinking: bool) -> Iterator[str]:
|
|
|
|
| 110 |
skip_prompt=True,
|
| 111 |
skip_special_tokens=not thinking,
|
| 112 |
)
|
| 113 |
+
stop_criteria = StopOnSignal()
|
| 114 |
generate_kwargs = {
|
| 115 |
**inputs,
|
| 116 |
"streamer": streamer,
|
| 117 |
+
"stopping_criteria": [stop_criteria],
|
| 118 |
"max_new_tokens": max_new_tokens,
|
| 119 |
"disable_compile": True,
|
| 120 |
}
|
|
|
|
| 131 |
thread.start()
|
| 132 |
|
| 133 |
chunks: list[str] = []
|
| 134 |
+
try:
|
| 135 |
+
for text in streamer:
|
| 136 |
+
chunks.append(text)
|
| 137 |
+
accumulated = "".join(chunks)
|
| 138 |
+
if thinking:
|
| 139 |
+
yield _strip_special_tokens(accumulated)
|
| 140 |
+
else:
|
| 141 |
+
yield accumulated
|
| 142 |
+
except GeneratorExit:
|
| 143 |
+
stop_criteria.stopped = True
|
| 144 |
+
for _ in streamer:
|
| 145 |
+
pass
|
| 146 |
+
thread.join()
|
| 147 |
+
raise
|
| 148 |
|
| 149 |
thread.join()
|
| 150 |
if exception_holder:
|
|
|
|
| 309 |
file_types=[*IMAGE_FILE_TYPES, *AUDIO_FILE_TYPES, *VIDEO_FILE_TYPES],
|
| 310 |
file_count="multiple",
|
| 311 |
autofocus=True,
|
| 312 |
+
stop_btn=True,
|
| 313 |
),
|
| 314 |
multimodal=True,
|
| 315 |
additional_inputs=[
|
|
|
|
| 324 |
gr.Textbox(label="System Prompt", value=""),
|
| 325 |
],
|
| 326 |
additional_inputs_accordion=gr.Accordion("Settings", open=True),
|
|
|
|
| 327 |
title="Gemma 4 E4B It",
|
| 328 |
examples=examples,
|
| 329 |
run_examples_on_click=False,
|