anaspro committed on
Commit
5a13129
·
1 Parent(s): 69e6135
Files changed (1) hide show
  1. app.py +15 -44
app.py CHANGED
@@ -6,23 +6,25 @@ from threading import Thread
6
 
7
  import av
8
  import gradio as gr
9
- import spaces
10
  import torch
11
- from gradio.utils import get_upload_folder
12
  from transformers import AutoModelForImageTextToText, AutoProcessor
13
  from transformers.generation.streamers import TextIteratorStreamer
14
 
 
15
  model_id = "unsloth/gemma-3n-E4B-it"
16
-
17
  processor = AutoProcessor.from_pretrained(model_id)
18
- model = AutoModelForImageTextToText.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16)
 
 
 
 
19
 
 
20
  IMAGE_FILE_TYPES = (".jpg", ".jpeg", ".png", ".webp")
21
  VIDEO_FILE_TYPES = (".mp4", ".mov", ".webm")
22
  AUDIO_FILE_TYPES = (".mp3", ".wav")
23
 
24
- GRADIO_TEMP_DIR = get_upload_folder()
25
-
26
  TARGET_FPS = int(os.getenv("TARGET_FPS", "3"))
27
  MAX_FRAMES = int(os.getenv("MAX_FRAMES", "30"))
28
  MAX_INPUT_TOKENS = int(os.getenv("MAX_INPUT_TOKENS", "10_000"))
@@ -118,7 +120,6 @@ def process_new_user_message(message: dict) -> list[dict]:
118
  message["files"][0],
119
  target_fps=TARGET_FPS,
120
  max_frames=MAX_FRAMES,
121
- parent_dir=GRADIO_TEMP_DIR,
122
  )
123
  paths = sorted(pathlib.Path(temp_dir).glob("*.jpg"))
124
  return [
@@ -152,7 +153,6 @@ def process_history(history: list[dict]) -> list[dict]:
152
  return messages
153
 
154
 
155
- @spaces.GPU(duration=120)
156
  @torch.inference_mode()
157
  def generate(message: dict, history: list[dict], system_prompt: str = "", max_new_tokens: int = 512) -> Iterator[str]:
158
  if not validate_media_constraints(message):
@@ -199,39 +199,14 @@ def generate(message: dict, history: list[dict], system_prompt: str = "", max_ne
199
  yield output
200
 
201
 
 
202
  examples = [
203
- [
204
- {
205
- "text": "What is the capital of France?",
206
- "files": [],
207
- }
208
- ],
209
- [
210
- {
211
- "text": "Describe this image in detail.",
212
- "files": ["assets/cat.jpeg"],
213
- }
214
- ],
215
- [
216
- {
217
- "text": "Transcribe the following speech segment in English.",
218
- "files": ["assets/speech.wav"],
219
- }
220
- ],
221
- [
222
- {
223
- "text": "Transcribe the following speech segment in English.",
224
- "files": ["assets/speech2.wav"],
225
- }
226
- ],
227
- [
228
- {
229
- "text": "Describe this video",
230
- "files": ["assets/holding_phone.mp4"],
231
- }
232
- ],
233
  ]
234
 
 
235
  demo = gr.ChatInterface(
236
  fn=generate,
237
  type="messages",
@@ -245,13 +220,9 @@ demo = gr.ChatInterface(
245
  gr.Textbox(label="System Prompt", value="You are a helpful assistant."),
246
  gr.Slider(label="Max New Tokens", minimum=100, maximum=2000, step=10, value=700),
247
  ],
248
- stop_btn=False,
249
- title="Gemma 3n E4B it",
250
  examples=examples,
251
- run_examples_on_click=False,
252
- cache_examples=False,
253
- css_paths="style.css",
254
- delete_cache=(1800, 1800),
255
  )
256
 
257
  if __name__ == "__main__":
 
6
 
7
  import av
8
  import gradio as gr
 
9
  import torch
 
10
  from transformers import AutoModelForImageTextToText, AutoProcessor
11
  from transformers.generation.streamers import TextIteratorStreamer
12
 
13
+ # Model configuration
14
  model_id = "unsloth/gemma-3n-E4B-it"
 
15
  processor = AutoProcessor.from_pretrained(model_id)
16
+ model = AutoModelForImageTextToText.from_pretrained(
17
+ model_id,
18
+ device_map="auto",
19
+ torch_dtype=torch.bfloat16
20
+ )
21
 
22
+ # Supported file types
23
  IMAGE_FILE_TYPES = (".jpg", ".jpeg", ".png", ".webp")
24
  VIDEO_FILE_TYPES = (".mp4", ".mov", ".webm")
25
  AUDIO_FILE_TYPES = (".mp3", ".wav")
26
 
27
+ # Video processing settings
 
28
  TARGET_FPS = int(os.getenv("TARGET_FPS", "3"))
29
  MAX_FRAMES = int(os.getenv("MAX_FRAMES", "30"))
30
  MAX_INPUT_TOKENS = int(os.getenv("MAX_INPUT_TOKENS", "10_000"))
 
120
  message["files"][0],
121
  target_fps=TARGET_FPS,
122
  max_frames=MAX_FRAMES,
 
123
  )
124
  paths = sorted(pathlib.Path(temp_dir).glob("*.jpg"))
125
  return [
 
153
  return messages
154
 
155
 
 
156
  @torch.inference_mode()
157
  def generate(message: dict, history: list[dict], system_prompt: str = "", max_new_tokens: int = 512) -> Iterator[str]:
158
  if not validate_media_constraints(message):
 
199
  yield output
200
 
201
 
202
+ # Simple examples for the chat interface
203
  examples = [
204
+ "What is the capital of France?",
205
+ "Explain quantum computing in simple terms",
206
+ "Write a short story about a robot learning to paint"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
207
  ]
208
 
209
+ # Create the chat interface
210
  demo = gr.ChatInterface(
211
  fn=generate,
212
  type="messages",
 
220
  gr.Textbox(label="System Prompt", value="You are a helpful assistant."),
221
  gr.Slider(label="Max New Tokens", minimum=100, maximum=2000, step=10, value=700),
222
  ],
223
+ title="Gemma 3n Multimodal Chat",
 
224
  examples=examples,
225
+ stop_btn=False,
 
 
 
226
  )
227
 
228
  if __name__ == "__main__":