hysts HF Staff committed on
Commit
8aaee8a
·
1 Parent(s): 2c95709

Enable ChatInterface stop button

Browse files
Files changed (1) hide show
  1. app.py +26 -9
app.py CHANGED
@@ -5,7 +5,7 @@ from threading import Thread
5
  import gradio as gr
6
  import spaces
7
  import torch
8
- from transformers import AutoModelForMultimodalLM, AutoProcessor, BatchFeature
9
  from transformers.generation.streamers import TextIteratorStreamer
10
 
11
  MODEL_ID = "google/gemma-4-e4b-it"
@@ -91,6 +91,14 @@ def process_history(history: list[dict]) -> list[dict]:
91
  return messages
92
 
93
 
 
 
 
 
 
 
 
 
94
  @spaces.GPU(duration=120)
95
  @torch.inference_mode()
96
  def _generate_on_gpu(inputs: BatchFeature, max_new_tokens: int, thinking: bool) -> Iterator[str]:
@@ -102,9 +110,11 @@ def _generate_on_gpu(inputs: BatchFeature, max_new_tokens: int, thinking: bool)
102
  skip_prompt=True,
103
  skip_special_tokens=not thinking,
104
  )
 
105
  generate_kwargs = {
106
  **inputs,
107
  "streamer": streamer,
 
108
  "max_new_tokens": max_new_tokens,
109
  "disable_compile": True,
110
  }
@@ -121,13 +131,20 @@ def _generate_on_gpu(inputs: BatchFeature, max_new_tokens: int, thinking: bool)
121
  thread.start()
122
 
123
  chunks: list[str] = []
124
- for text in streamer:
125
- chunks.append(text)
126
- accumulated = "".join(chunks)
127
- if thinking:
128
- yield _strip_special_tokens(accumulated)
129
- else:
130
- yield accumulated
 
 
 
 
 
 
 
131
 
132
  thread.join()
133
  if exception_holder:
@@ -292,6 +309,7 @@ demo = gr.ChatInterface(
292
  file_types=[*IMAGE_FILE_TYPES, *AUDIO_FILE_TYPES, *VIDEO_FILE_TYPES],
293
  file_count="multiple",
294
  autofocus=True,
 
295
  ),
296
  multimodal=True,
297
  additional_inputs=[
@@ -306,7 +324,6 @@ demo = gr.ChatInterface(
306
  gr.Textbox(label="System Prompt", value=""),
307
  ],
308
  additional_inputs_accordion=gr.Accordion("Settings", open=True),
309
- stop_btn=False,
310
  title="Gemma 4 E4B It",
311
  examples=examples,
312
  run_examples_on_click=False,
 
5
  import gradio as gr
6
  import spaces
7
  import torch
8
+ from transformers import AutoModelForMultimodalLM, AutoProcessor, BatchFeature, StoppingCriteria
9
  from transformers.generation.streamers import TextIteratorStreamer
10
 
11
  MODEL_ID = "google/gemma-4-e4b-it"
 
91
  return messages
92
 
93
 
94
+ class StopOnSignal(StoppingCriteria):
95
+ def __init__(self) -> None:
96
+ self.stopped = False
97
+
98
+ def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor, **kwargs: object) -> bool: # noqa: ARG002
99
+ return self.stopped
100
+
101
+
102
  @spaces.GPU(duration=120)
103
  @torch.inference_mode()
104
  def _generate_on_gpu(inputs: BatchFeature, max_new_tokens: int, thinking: bool) -> Iterator[str]:
 
110
  skip_prompt=True,
111
  skip_special_tokens=not thinking,
112
  )
113
+ stop_criteria = StopOnSignal()
114
  generate_kwargs = {
115
  **inputs,
116
  "streamer": streamer,
117
+ "stopping_criteria": [stop_criteria],
118
  "max_new_tokens": max_new_tokens,
119
  "disable_compile": True,
120
  }
 
131
  thread.start()
132
 
133
  chunks: list[str] = []
134
+ try:
135
+ for text in streamer:
136
+ chunks.append(text)
137
+ accumulated = "".join(chunks)
138
+ if thinking:
139
+ yield _strip_special_tokens(accumulated)
140
+ else:
141
+ yield accumulated
142
+ except GeneratorExit:
143
+ stop_criteria.stopped = True
144
+ for _ in streamer:
145
+ pass
146
+ thread.join()
147
+ raise
148
 
149
  thread.join()
150
  if exception_holder:
 
309
  file_types=[*IMAGE_FILE_TYPES, *AUDIO_FILE_TYPES, *VIDEO_FILE_TYPES],
310
  file_count="multiple",
311
  autofocus=True,
312
+ stop_btn=True,
313
  ),
314
  multimodal=True,
315
  additional_inputs=[
 
324
  gr.Textbox(label="System Prompt", value=""),
325
  ],
326
  additional_inputs_accordion=gr.Accordion("Settings", open=True),
 
327
  title="Gemma 4 E4B It",
328
  examples=examples,
329
  run_examples_on_click=False,