Update app.py
Browse files
app.py
CHANGED
|
@@ -8,11 +8,11 @@ import av
|
|
| 8 |
import gradio as gr
|
| 9 |
import spaces
|
| 10 |
import torch
|
| 11 |
-
from
|
| 12 |
from gradio.processing_utils import save_audio_to_cache
|
|
|
|
| 13 |
from transformers import AutoModelForImageTextToText, AutoProcessor
|
| 14 |
from transformers.generation.streamers import TextIteratorStreamer
|
| 15 |
-
from fastrtc import ReplyOnPause, WebRTCData, WebRTC, AdditionalOutputs, get_hf_turn_credentials
|
| 16 |
|
| 17 |
model_id = "google/gemma-3n-E4B-it"
|
| 18 |
|
|
@@ -202,12 +202,19 @@ def _generate(message: dict, history: list[dict], system_prompt: str = "", max_n
|
|
| 202 |
|
| 203 |
@spaces.GPU(time_limit=120)
|
| 204 |
def generate(data: WebRTCData, history: list[dict], system_prompt: str = "", max_new_tokens: int = 512, image=None):
|
| 205 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 206 |
new_message = {"role": "assistant", "content": ""}
|
| 207 |
for output in _generate(message, history, system_prompt, max_new_tokens):
|
| 208 |
new_message["content"] += output
|
| 209 |
yield AdditionalOutputs(history + [new_message])
|
| 210 |
-
|
| 211 |
|
| 212 |
|
| 213 |
with gr.Blocks() as demo:
|
|
@@ -217,12 +224,12 @@ with gr.Blocks() as demo:
|
|
| 217 |
mode="send",
|
| 218 |
variant="textbox",
|
| 219 |
rtc_configuration=get_hf_turn_credentials,
|
| 220 |
-
server_rtc_configuration=get_hf_turn_credentials(ttl=3_600 * 24 * 30)
|
| 221 |
)
|
| 222 |
with gr.Accordion(label="Additional Inputs"):
|
| 223 |
sp = gr.Textbox(label="System Prompt", value="You are a helpful assistant.")
|
| 224 |
slider = gr.Slider(label="Max New Tokens", minimum=100, maximum=2000, step=10, value=700)
|
| 225 |
-
image = gr.Image()
|
| 226 |
|
| 227 |
webrtc.stream(
|
| 228 |
ReplyOnPause(generate), # type: ignore
|
|
@@ -230,9 +237,7 @@ with gr.Blocks() as demo:
|
|
| 230 |
outputs=[chatbot],
|
| 231 |
concurrency_limit=100,
|
| 232 |
)
|
| 233 |
-
webrtc.on_additional_outputs(
|
| 234 |
-
lambda old, new: new, inputs=[chatbot], outputs=[chatbot], concurrency_limit=100
|
| 235 |
-
)
|
| 236 |
|
| 237 |
if __name__ == "__main__":
|
| 238 |
demo.launch()
|
|
|
|
| 8 |
import gradio as gr
|
| 9 |
import spaces
|
| 10 |
import torch
|
| 11 |
+
from fastrtc import AdditionalOutputs, ReplyOnPause, WebRTC, WebRTCData, get_hf_turn_credentials
|
| 12 |
from gradio.processing_utils import save_audio_to_cache
|
| 13 |
+
from gradio.utils import get_upload_folder
|
| 14 |
from transformers import AutoModelForImageTextToText, AutoProcessor
|
| 15 |
from transformers.generation.streamers import TextIteratorStreamer
|
|
|
|
| 16 |
|
| 17 |
model_id = "google/gemma-3n-E4B-it"
|
| 18 |
|
|
|
|
| 202 |
|
| 203 |
@spaces.GPU(time_limit=120)
def generate(data: WebRTCData, history: list[dict], system_prompt: str = "", max_new_tokens: int = 512, image=None):
    """Stream an assistant reply for one WebRTC turn.

    Collects the user's audio (if any) and image (if any) into a multimodal
    message, delegates token generation to ``_generate``, and yields the
    growing chat history via ``AdditionalOutputs`` so the UI updates live.

    Args:
        data: WebRTC payload; ``data.audio`` is ``(sample_rate, samples)`` or
            ``None``, ``data.textbox`` is the typed text.
        history: Prior chat messages in Gradio ``{"role", "content"}`` form.
        system_prompt: System instruction forwarded to the model.
        max_new_tokens: Generation budget forwarded to the model.
        image: Optional image file path from the ``gr.Image`` input.

    Yields:
        AdditionalOutputs wrapping ``history`` plus the partially generated
        assistant message after each streamed chunk.
    """
    files = []
    # Persist recorded audio to the upload cache so it can be passed to the
    # model as a regular file path alongside any image.
    if data.audio is not None and data.audio[1].size > 0:
        files.append(save_audio_to_cache(data.audio[1], data.audio[0], format="mp3", cache_dir=get_upload_folder()))
    # BUG FIX: the original condition was inverted (`if image is None`), which
    # appended None when no image existed and dropped the image when one did.
    if image is not None:
        files.append(image)
    message = {
        "text": data.textbox,
        # BUG FIX: the original passed a hard-coded empty list here, silently
        # discarding the audio/image files collected above.
        "files": files,
    }
    new_message = {"role": "assistant", "content": ""}
    # Stream tokens; re-yield the full history each chunk so the chatbot
    # component shows the reply as it is generated.
    for output in _generate(message, history, system_prompt, max_new_tokens):
        new_message["content"] += output
        yield AdditionalOutputs(history + [new_message])
|
|
|
|
| 218 |
|
| 219 |
|
| 220 |
with gr.Blocks() as demo:
|
|
|
|
| 224 |
mode="send",
|
| 225 |
variant="textbox",
|
| 226 |
rtc_configuration=get_hf_turn_credentials,
|
| 227 |
+
server_rtc_configuration=get_hf_turn_credentials(ttl=3_600 * 24 * 30),
|
| 228 |
)
|
| 229 |
with gr.Accordion(label="Additional Inputs"):
|
| 230 |
sp = gr.Textbox(label="System Prompt", value="You are a helpful assistant.")
|
| 231 |
slider = gr.Slider(label="Max New Tokens", minimum=100, maximum=2000, step=10, value=700)
|
| 232 |
+
image = gr.Image(type="filepath")
|
| 233 |
|
| 234 |
webrtc.stream(
|
| 235 |
ReplyOnPause(generate), # type: ignore
|
|
|
|
| 237 |
outputs=[chatbot],
|
| 238 |
concurrency_limit=100,
|
| 239 |
)
|
| 240 |
+
webrtc.on_additional_outputs(lambda old, new: new, inputs=[chatbot], outputs=[chatbot], concurrency_limit=100)
|
|
|
|
|
|
|
| 241 |
|
| 242 |
# Launch the Gradio demo only when executed as a script (not on import,
# e.g. when the Space runner or tests import this module).
if __name__ == "__main__":
    demo.launch()
|