sming256 committed on
Commit
6983b5a
·
verified ·
1 Parent(s): 79e6f86

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +346 -360
app.py CHANGED
@@ -49,6 +49,29 @@ CUSTOM_CSS = """
49
  }
50
  """
51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
  # ============================================================================
54
  # Utility Functions
@@ -82,233 +105,206 @@ def detect_media_type(file_path: str | None) -> str | None:
82
  return "video"
83
 
84
 
85
- # ============================================================================
86
- # Model Class
87
- # ============================================================================
 
 
 
 
88
 
 
 
 
 
89
 
90
- class Qwen3VLAutoThinkDemo:
91
- """Main model class for Qwen3-VL with adaptive inference."""
92
-
93
- def __init__(self, model_path="IVUL-KAUST/VideoAuto-R1-Qwen3-VL-8B"):
94
- """Initialize model, processor, and tokenizer."""
95
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
96
-
97
- # Load model
98
- self.model = Qwen3VLForConditionalGeneration.from_pretrained(
99
- model_path,
100
- dtype="bfloat16",
101
- attn_implementation="sdpa",
102
- ).to('cuda').eval()
103
-
104
- self.processor = AutoProcessor.from_pretrained(model_path)
105
- self.tokenizer = AutoTokenizer.from_pretrained(model_path)
106
- self.system_prompt = COT_SYSTEM_PROMPT_ANSWER_TWICE
107
-
108
- def process_image(
109
- self,
110
- image_path: str,
111
- image_min_pixels: int = 128 * 28 * 28,
112
- image_max_pixels: int = 16384 * 28 * 28,
113
- ) -> dict | None:
114
- """
115
- Process image file to base64 format.
116
-
117
- Args:
118
- image_path: Path to image file
119
- image_min_pixels: Minimum pixel count
120
- image_max_pixels: Maximum pixel count
121
-
122
- Returns:
123
- Dictionary with image data or None
124
- """
125
- if image_path is None:
126
- return None
127
-
128
- image = Image.open(image_path).convert("RGB")
129
- buffer = BytesIO()
130
- image.save(buffer, format="JPEG")
131
- base64_bytes = base64.b64encode(buffer.getvalue())
132
- base64_string = base64_bytes.decode("utf-8")
133
-
134
- return {
135
- "type": "image",
136
- "image": f"data:image/jpeg;base64,{base64_string}",
137
- "min_pixels": image_min_pixels,
138
- "max_pixels": image_max_pixels,
139
- }
140
-
141
- def process_video(
142
- self,
143
- video_path: str,
144
- video_min_pixels: int = 16 * 28 * 28,
145
- video_max_pixels: int = 768 * 28 * 28,
146
- video_total_pixels: int = 128000 * 28 * 28,
147
- min_frames: int = 4,
148
- max_frames: int = 64,
149
- fps: float = 2.0,
150
- ) -> dict | None:
151
- """
152
- Process video file configuration.
153
-
154
- Args:
155
- video_path: Path to video file
156
- video_min_pixels: Minimum pixels per frame
157
- video_max_pixels: Maximum pixels per frame
158
- video_total_pixels: Total pixels across all frames
159
- min_frames: Minimum number of frames
160
- max_frames: Maximum number of frames
161
- fps: Frames per second for sampling
162
-
163
- Returns:
164
- Dictionary with video configuration or None
165
- """
166
- if video_path is None:
167
- return None
168
-
169
- return {
170
- "type": "video",
171
- "video": video_path,
172
- "min_pixels": video_min_pixels,
173
- "max_pixels": video_max_pixels,
174
- "total_pixels": video_total_pixels,
175
- "min_frames": min_frames,
176
- "max_frames": max_frames,
177
- "fps": fps,
178
- }
179
-
180
- @spaces.GPU(duration=120)
181
- def generate(
182
- self,
183
- media_input: str | None,
184
- prompt: str,
185
- early_exit_thresh: float,
186
- temperature: float,
187
- max_new_tokens: int = 4096,
188
- ) -> dict:
189
- """
190
- Generate response with adaptive inference.
191
-
192
- Args:
193
- media_input: Path to media file
194
- prompt: Text prompt
195
- early_exit_thresh: Confidence threshold for early exit
196
- temperature: Sampling temperature
197
- max_new_tokens: Maximum tokens to generate
198
-
199
- Returns:
200
- Dictionary containing response and metadata
201
- """
202
- # if self.model.device.type != "cuda":
203
- # self.model.to("cuda")
204
-
205
- # Prepare message
206
- message = [{"role": "system", "content": self.system_prompt}]
207
- content_parts = []
208
-
209
- # Process media input
210
- if media_input is not None:
211
- media_type = detect_media_type(media_input)
212
-
213
- if media_type == "video":
214
- video_dict = self.process_video(media_input)
215
- if video_dict:
216
- content_parts.append(video_dict)
217
- elif media_type == "image":
218
- image_dict = self.process_image(media_input)
219
- if image_dict:
220
- content_parts.append(image_dict)
221
-
222
- # Add text prompt
223
- content_parts.append({"type": "text", "text": prompt})
224
- message.append({"role": "user", "content": content_parts})
225
-
226
- # Apply chat template
227
- text = self.processor.apply_chat_template([message], tokenize=False, add_generation_prompt=True)
228
-
229
- # Process vision inputs
230
- image_inputs, video_inputs, video_kwargs = process_vision_info(
231
- [message],
232
- image_patch_size=16,
233
- return_video_kwargs=True,
234
- return_video_metadata=True,
235
- )
236
 
237
- if video_inputs is not None:
238
- video_inputs, video_metadatas = zip(*video_inputs)
239
- video_inputs = list(video_inputs)
240
- video_metadatas = list(video_metadatas)
241
- else:
242
- video_metadatas = None
243
-
244
- # Prepare inputs
245
- inputs = self.processor(
246
- text=text,
247
- images=image_inputs,
248
- videos=video_inputs,
249
- video_metadata=video_metadatas,
250
- do_resize=False,
251
- padding=True,
252
- return_tensors="pt",
253
- **video_kwargs,
254
- )
255
- inputs = inputs.to(self.device)
256
-
257
- # Generation configuration
258
- gen_kwargs = {
259
- "max_new_tokens": max_new_tokens,
260
- "temperature": temperature if temperature > 0 else None,
261
- "do_sample": temperature > 0,
262
- "top_p": 0.9 if temperature > 0 else None,
263
- "num_beams": 1,
264
- "use_cache": True,
265
- "return_dict_in_generate": True,
266
- "output_scores": True,
267
- }
268
-
269
- # Generate response
270
- with torch.no_grad():
271
- gen_out = self.model.generate(
272
- **inputs,
273
- eos_token_id=self.tokenizer.eos_token_id,
274
- pad_token_id=self.tokenizer.pad_token_id,
275
- **gen_kwargs,
276
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
277
 
278
- # Decode output
279
- generated_ids = gen_out.sequences[0][len(inputs.input_ids[0]) :]
280
- answer = self.processor.decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
281
-
282
- # Compute confidence
283
- first_box_probs = compute_first_boxed_answer_probs(
284
- b=0,
285
- gen_ids=generated_ids,
286
- gen_out=gen_out,
287
- ans=answer,
288
- task="",
289
- tokenizer=self.tokenizer,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
290
  )
291
 
292
- # Parse response
293
- first_answer = answer.split("<think>")[0]
294
- second_answer = answer.split("</think>")[-1] if "</think>" in answer else first_answer
295
- reasoning = answer.split("<think>")[-1].split("</think>")[0] if "<think>" in answer else "N/A"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
296
 
297
- # Determine inference mode
298
- if first_box_probs >= early_exit_thresh:
299
- need_cot = False
300
- reasoning = False
301
- else:
302
- need_cot = True
303
 
304
- return {
305
- "full_response": answer,
306
- "first_answer": first_answer,
307
- "confidence": f"{first_box_probs:.4f}",
308
- "need_cot": need_cot,
309
- "reasoning": reasoning,
310
- "second_answer": second_answer,
311
- }
312
 
313
 
314
  # ============================================================================
@@ -361,18 +357,18 @@ def chat_generate(
361
 
362
  # Initialize system prompt
363
  if len(messages_state) == 0:
364
- messages_state.append({"role": "system", "content": demo_model.system_prompt})
365
 
366
  # Prepare user message
367
  content_parts = []
368
  if media_path is not None:
369
  mtype = detect_media_type(media_path)
370
  if mtype == "video":
371
- vd = demo_model.process_video(media_path)
372
  if vd:
373
  content_parts.append(vd)
374
  elif mtype == "image":
375
- imd = demo_model.process_image(media_path)
376
  if imd:
377
  content_parts.append(imd)
378
 
@@ -380,7 +376,7 @@ def chat_generate(
380
  messages_state.append({"role": "user", "content": content_parts})
381
 
382
  # Generate response
383
- result = demo_model.generate(media_path, user_text, early_exit_thresh, temperature)
384
 
385
  # Format assistant response
386
  first_ans = (result.get("first_answer") or "").strip()
@@ -465,155 +461,145 @@ EXAMPLES = [
465
  # Gradio Interface
466
  # ============================================================================
467
 
 
468
 
469
- def create_demo():
470
- """Create and configure the Gradio interface."""
471
- with gr.Blocks(title="VideoAuto-R1 Demo") as demo:
472
- gr.Markdown("# VideoAuto-R1 (Qwen3-VL-8B) Demo")
473
 
474
- # Display system prompt
475
- with gr.Accordion("System Prompt", open=False):
476
- gr.Markdown(f"```\n{COT_SYSTEM_PROMPT_ANSWER_TWICE}\n```")
477
 
478
- # State variables
479
- messages_state = gr.State([])
480
- chatbot_state = gr.State([])
481
- last_media_state = gr.State(None)
482
 
483
- with gr.Row():
484
- # Left column: Media input and settings
485
- with gr.Column(scale=3):
486
- media_input = gr.File(
487
- label="Upload Image or Video",
488
- file_types=["image", "video"],
489
- type="filepath",
490
- )
491
- image_preview = gr.Image(label="Image Preview", visible=False)
492
- video_preview = gr.Video(label="Video Preview", visible=False)
493
-
494
- with gr.Accordion("Advanced Settings", open=True):
495
- early_exit_thresh = gr.Slider(
496
- minimum=0.0,
497
- maximum=1.0,
498
- value=0.98,
499
- step=0.01,
500
- label="Early Exit Threshold",
501
- )
502
- temperature = gr.Slider(
503
- minimum=0.0,
504
- maximum=2.0,
505
- value=0.0,
506
- step=0.1,
507
- label="Temperature",
508
- )
509
-
510
- # Right column: Chat interface
511
- with gr.Column(scale=7):
512
- chatbot = gr.Chatbot(
513
- label="Chat",
514
- elem_id="chatbot",
515
- height=600,
516
- sanitize_html=False,
517
- )
518
- textbox = gr.Textbox(
519
- show_label=False,
520
- placeholder="Enter text and press ENTER",
521
- lines=2,
522
  )
523
- with gr.Row():
524
- send_btn = gr.Button("Send", variant="primary")
525
- clear_btn = gr.Button("Clear")
526
-
527
- gr.Markdown(
528
- "Please click the **Clear** button before starting a new conversation or trying a new example."
529
  )
530
 
531
- # Event handlers
532
- media_input.change(
533
- fn=update_preview,
534
- inputs=[media_input],
535
- outputs=[image_preview, video_preview],
536
- )
537
-
538
- # Send button click: generate response and disable input controls
539
- send_btn.click(
540
- fn=chat_generate,
541
- inputs=[
542
- media_input,
543
- textbox,
544
- messages_state,
545
- chatbot_state,
546
- last_media_state,
547
- early_exit_thresh,
548
- temperature,
549
- ],
550
- outputs=[messages_state, chatbot_state, last_media_state, textbox, send_btn],
551
- ).then(
552
- fn=lambda cs: cs,
553
- inputs=[chatbot_state],
554
- outputs=[chatbot],
555
- )
556
 
557
- # Textbox submit: generate response and disable input controls
558
- textbox.submit(
559
- fn=chat_generate,
560
- inputs=[
561
- media_input,
562
- textbox,
563
- messages_state,
564
- chatbot_state,
565
- last_media_state,
566
- early_exit_thresh,
567
- temperature,
568
- ],
569
- outputs=[messages_state, chatbot_state, last_media_state, textbox, send_btn],
570
- ).then(
571
- fn=lambda cs: cs,
572
- inputs=[chatbot_state],
573
- outputs=[chatbot],
574
- )
575
 
576
- # Clear button: reset all states and re-enable input controls
577
- clear_btn.click(
578
- fn=clear_history,
579
- inputs=[],
580
- outputs=[
581
- messages_state,
582
- chatbot_state,
583
- last_media_state,
584
- media_input,
585
- image_preview,
586
- video_preview,
587
- textbox,
588
- send_btn,
589
- ],
590
- ).then(
591
- fn=lambda cs: cs,
592
- inputs=[chatbot_state],
593
- outputs=[chatbot],
594
- )
595
 
596
- gr.Examples(
597
- examples=EXAMPLES,
598
- inputs=[media_input, textbox],
599
- label="Examples",
600
- cache_examples=False,
601
- )
 
 
 
 
 
 
 
 
 
 
 
 
602
 
603
- return demo
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
604
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
605
 
606
- # ============================================================================
607
- # Main Entry Point
608
- # ============================================================================
 
 
 
609
 
610
- if __name__ == "__main__":
611
- # Initialize model
612
- demo_model = Qwen3VLAutoThinkDemo()
613
 
614
- # Create and launch demo
615
- demo = create_demo()
616
- demo.launch(
617
- allowed_paths=["assets"],
618
- css=CUSTOM_CSS,
619
- )
 
 
 
 
49
  }
50
  """
51
 
52
+ MODEL_PATH = "IVUL-KAUST/VideoAuto-R1-Qwen3-VL-8B"
53
+
54
+
55
+ # ============================================================================
56
+ # Global Model Variables
57
+ # ============================================================================
58
+
59
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
60
+
61
+ # Load model
62
+ model = (
63
+ Qwen3VLForConditionalGeneration.from_pretrained(
64
+ MODEL_PATH,
65
+ dtype="bfloat16",
66
+ attn_implementation="sdpa",
67
+ )
68
+ .to("cuda")
69
+ .eval()
70
+ )
71
+
72
+ processor = AutoProcessor.from_pretrained(MODEL_PATH)
73
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
74
+
75
 
76
  # ============================================================================
77
  # Utility Functions
 
105
  return "video"
106
 
107
 
108
+ def process_image(
109
+ image_path: str,
110
+ image_min_pixels: int = 128 * 28 * 28,
111
+ image_max_pixels: int = 16384 * 28 * 28,
112
+ ) -> dict | None:
113
+ """
114
+ Process image file to base64 format.
115
 
116
+ Args:
117
+ image_path: Path to image file
118
+ image_min_pixels: Minimum pixel count
119
+ image_max_pixels: Maximum pixel count
120
 
121
+ Returns:
122
+ Dictionary with image data or None
123
+ """
124
+ if image_path is None:
125
+ return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
 
127
+ image = Image.open(image_path).convert("RGB")
128
+ buffer = BytesIO()
129
+ image.save(buffer, format="JPEG")
130
+ base64_bytes = base64.b64encode(buffer.getvalue())
131
+ base64_string = base64_bytes.decode("utf-8")
132
+
133
+ return {
134
+ "type": "image",
135
+ "image": f"data:image/jpeg;base64,{base64_string}",
136
+ "min_pixels": image_min_pixels,
137
+ "max_pixels": image_max_pixels,
138
+ }
139
+
140
+
141
+ def process_video(
142
+ video_path: str,
143
+ video_min_pixels: int = 16 * 28 * 28,
144
+ video_max_pixels: int = 768 * 28 * 28,
145
+ video_total_pixels: int = 128000 * 28 * 28,
146
+ min_frames: int = 4,
147
+ max_frames: int = 64,
148
+ fps: float = 2.0,
149
+ ) -> dict | None:
150
+ """
151
+ Process video file configuration.
152
+
153
+ Args:
154
+ video_path: Path to video file
155
+ video_min_pixels: Minimum pixels per frame
156
+ video_max_pixels: Maximum pixels per frame
157
+ video_total_pixels: Total pixels across all frames
158
+ min_frames: Minimum number of frames
159
+ max_frames: Maximum number of frames
160
+ fps: Frames per second for sampling
161
+
162
+ Returns:
163
+ Dictionary with video configuration or None
164
+ """
165
+ if video_path is None:
166
+ return None
167
+
168
+ return {
169
+ "type": "video",
170
+ "video": video_path,
171
+ "min_pixels": video_min_pixels,
172
+ "max_pixels": video_max_pixels,
173
+ "total_pixels": video_total_pixels,
174
+ "min_frames": min_frames,
175
+ "max_frames": max_frames,
176
+ "fps": fps,
177
+ }
178
+
179
+
180
+ @spaces.GPU(duration=180)
181
+ def generate(
182
+ media_input: str | None,
183
+ prompt: str,
184
+ early_exit_thresh: float,
185
+ temperature: float,
186
+ max_new_tokens: int = 4096,
187
+ ) -> dict:
188
+ """
189
+ Generate response with adaptive inference.
190
 
191
+ Args:
192
+ media_input: Path to media file
193
+ prompt: Text prompt
194
+ early_exit_thresh: Confidence threshold for early exit
195
+ temperature: Sampling temperature
196
+ max_new_tokens: Maximum tokens to generate
197
+
198
+ Returns:
199
+ Dictionary containing response and metadata
200
+ """
201
+ # Prepare message
202
+ message = [{"role": "system", "content": COT_SYSTEM_PROMPT_ANSWER_TWICE}]
203
+ content_parts = []
204
+
205
+ # Process media input
206
+ if media_input is not None:
207
+ media_type = detect_media_type(media_input)
208
+
209
+ if media_type == "video":
210
+ video_dict = process_video(media_input)
211
+ if video_dict:
212
+ content_parts.append(video_dict)
213
+ elif media_type == "image":
214
+ image_dict = process_image(media_input)
215
+ if image_dict:
216
+ content_parts.append(image_dict)
217
+
218
+ # Add text prompt
219
+ content_parts.append({"type": "text", "text": prompt})
220
+ message.append({"role": "user", "content": content_parts})
221
+
222
+ # Apply chat template
223
+ text = processor.apply_chat_template([message], tokenize=False, add_generation_prompt=True)
224
+
225
+ # Process vision inputs
226
+ image_inputs, video_inputs, video_kwargs = process_vision_info(
227
+ [message],
228
+ image_patch_size=16,
229
+ return_video_kwargs=True,
230
+ return_video_metadata=True,
231
+ )
232
+
233
+ if video_inputs is not None:
234
+ video_inputs, video_metadatas = zip(*video_inputs)
235
+ video_inputs = list(video_inputs)
236
+ video_metadatas = list(video_metadatas)
237
+ else:
238
+ video_metadatas = None
239
+
240
+ # Prepare inputs
241
+ inputs = processor(
242
+ text=text,
243
+ images=image_inputs,
244
+ videos=video_inputs,
245
+ video_metadata=video_metadatas,
246
+ do_resize=False,
247
+ padding=True,
248
+ return_tensors="pt",
249
+ **video_kwargs,
250
+ )
251
+ inputs = inputs.to(device)
252
+
253
+ # Generation configuration
254
+ gen_kwargs = {
255
+ "max_new_tokens": max_new_tokens,
256
+ "temperature": temperature if temperature > 0 else None,
257
+ "do_sample": temperature > 0,
258
+ "top_p": 0.9 if temperature > 0 else None,
259
+ "num_beams": 1,
260
+ "use_cache": True,
261
+ "return_dict_in_generate": True,
262
+ "output_scores": True,
263
+ }
264
+
265
+ # Generate response
266
+ with torch.no_grad():
267
+ gen_out = model.generate(
268
+ **inputs,
269
+ eos_token_id=tokenizer.eos_token_id,
270
+ pad_token_id=tokenizer.pad_token_id,
271
+ **gen_kwargs,
272
  )
273
 
274
+ # Decode output
275
+ generated_ids = gen_out.sequences[0][len(inputs.input_ids[0]) :]
276
+ answer = processor.decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
277
+
278
+ # Compute confidence
279
+ first_box_probs = compute_first_boxed_answer_probs(
280
+ b=0,
281
+ gen_ids=generated_ids,
282
+ gen_out=gen_out,
283
+ ans=answer,
284
+ task="",
285
+ tokenizer=tokenizer,
286
+ )
287
+
288
+ # Parse response
289
+ first_answer = answer.split("<think>")[0]
290
+ second_answer = answer.split("</think>")[-1] if "</think>" in answer else first_answer
291
+ reasoning = answer.split("<think>")[-1].split("</think>")[0] if "<think>" in answer else "N/A"
292
 
293
+ # Determine inference mode
294
+ if first_box_probs >= early_exit_thresh:
295
+ need_cot = False
296
+ reasoning = False
297
+ else:
298
+ need_cot = True
299
 
300
+ return {
301
+ "full_response": answer,
302
+ "first_answer": first_answer,
303
+ "confidence": f"{first_box_probs:.4f}",
304
+ "need_cot": need_cot,
305
+ "reasoning": reasoning,
306
+ "second_answer": second_answer,
307
+ }
308
 
309
 
310
  # ============================================================================
 
357
 
358
  # Initialize system prompt
359
  if len(messages_state) == 0:
360
+ messages_state.append({"role": "system", "content": COT_SYSTEM_PROMPT_ANSWER_TWICE})
361
 
362
  # Prepare user message
363
  content_parts = []
364
  if media_path is not None:
365
  mtype = detect_media_type(media_path)
366
  if mtype == "video":
367
+ vd = process_video(media_path)
368
  if vd:
369
  content_parts.append(vd)
370
  elif mtype == "image":
371
+ imd = process_image(media_path)
372
  if imd:
373
  content_parts.append(imd)
374
 
 
376
  messages_state.append({"role": "user", "content": content_parts})
377
 
378
  # Generate response
379
+ result = generate(media_path, user_text, early_exit_thresh, temperature)
380
 
381
  # Format assistant response
382
  first_ans = (result.get("first_answer") or "").strip()
 
461
  # Gradio Interface
462
  # ============================================================================
463
 
464
+ demo = gr.Blocks(title="VideoAuto-R1 Demo")
465
 
466
+ with demo:
467
+ gr.Markdown("# VideoAuto-R1 (Qwen3-VL-8B) Demo")
 
 
468
 
469
+ # Display system prompt
470
+ with gr.Accordion("System Prompt", open=False):
471
+ gr.Markdown(f"```\n{COT_SYSTEM_PROMPT_ANSWER_TWICE}\n```")
472
 
473
+ # State variables
474
+ messages_state = gr.State([])
475
+ chatbot_state = gr.State([])
476
+ last_media_state = gr.State(None)
477
 
478
+ with gr.Row():
479
+ # Left column: Media input and settings
480
+ with gr.Column(scale=3):
481
+ media_input = gr.File(
482
+ label="Upload Image or Video",
483
+ file_types=["image", "video"],
484
+ type="filepath",
485
+ )
486
+ image_preview = gr.Image(label="Image Preview", visible=False)
487
+ video_preview = gr.Video(label="Video Preview", visible=False)
488
+
489
+ with gr.Accordion("Advanced Settings", open=True):
490
+ early_exit_thresh = gr.Slider(
491
+ minimum=0.0,
492
+ maximum=1.0,
493
+ value=0.98,
494
+ step=0.01,
495
+ label="Early Exit Threshold",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
496
  )
497
+ temperature = gr.Slider(
498
+ minimum=0.0,
499
+ maximum=2.0,
500
+ value=0.0,
501
+ step=0.1,
502
+ label="Temperature",
503
  )
504
 
505
+ # Right column: Chat interface
506
+ with gr.Column(scale=7):
507
+ chatbot = gr.Chatbot(
508
+ label="Chat",
509
+ elem_id="chatbot",
510
+ height=600,
511
+ sanitize_html=False,
512
+ )
513
+ textbox = gr.Textbox(
514
+ show_label=False,
515
+ placeholder="Enter text and press ENTER",
516
+ lines=2,
517
+ )
518
+ with gr.Row():
519
+ send_btn = gr.Button("Send", variant="primary")
520
+ clear_btn = gr.Button("Clear")
 
 
 
 
 
 
 
 
 
521
 
522
+ gr.Markdown("Please click the **Clear** button before starting a new conversation or trying a new example.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
523
 
524
+ # Event handlers
525
+ media_input.change(
526
+ fn=update_preview,
527
+ inputs=[media_input],
528
+ outputs=[image_preview, video_preview],
529
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
530
 
531
+ # Send button click: generate response and disable input controls
532
+ send_btn.click(
533
+ fn=chat_generate,
534
+ inputs=[
535
+ media_input,
536
+ textbox,
537
+ messages_state,
538
+ chatbot_state,
539
+ last_media_state,
540
+ early_exit_thresh,
541
+ temperature,
542
+ ],
543
+ outputs=[messages_state, chatbot_state, last_media_state, textbox, send_btn],
544
+ ).then(
545
+ fn=lambda cs: cs,
546
+ inputs=[chatbot_state],
547
+ outputs=[chatbot],
548
+ )
549
 
550
+ # Textbox submit: generate response and disable input controls
551
+ textbox.submit(
552
+ fn=chat_generate,
553
+ inputs=[
554
+ media_input,
555
+ textbox,
556
+ messages_state,
557
+ chatbot_state,
558
+ last_media_state,
559
+ early_exit_thresh,
560
+ temperature,
561
+ ],
562
+ outputs=[messages_state, chatbot_state, last_media_state, textbox, send_btn],
563
+ ).then(
564
+ fn=lambda cs: cs,
565
+ inputs=[chatbot_state],
566
+ outputs=[chatbot],
567
+ )
568
 
569
+ # Clear button: reset all states and re-enable input controls
570
+ clear_btn.click(
571
+ fn=clear_history,
572
+ inputs=[],
573
+ outputs=[
574
+ messages_state,
575
+ chatbot_state,
576
+ last_media_state,
577
+ media_input,
578
+ image_preview,
579
+ video_preview,
580
+ textbox,
581
+ send_btn,
582
+ ],
583
+ ).then(
584
+ fn=lambda cs: cs,
585
+ inputs=[chatbot_state],
586
+ outputs=[chatbot],
587
+ )
588
 
589
+ gr.Examples(
590
+ examples=EXAMPLES,
591
+ inputs=[media_input, textbox],
592
+ label="Examples",
593
+ cache_examples=False,
594
+ )
595
 
 
 
 
596
 
597
+ # Launch demo
598
+ demo.launch(
599
+ share=True,
600
+ server_name="0.0.0.0",
601
+ server_port=7860,
602
+ allowed_paths=["assets"],
603
+ debug=True,
604
+ css=CUSTOM_CSS,
605
+ )