jena-shreyas committed on
Commit
8427fe9
·
1 Parent(s): e6b02aa

Add support for all models

Browse files
Files changed (2) hide show
  1. README.md +1 -1
  2. app.py +41 -15
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- title: LLaVA Video Demo
3
  emoji: 🤖
4
  colorFrom: blue
5
  colorTo: blue
 
1
  ---
2
+ title: Video Inference Demo
3
  emoji: 🤖
4
  colorFrom: blue
5
  colorTo: blue
app.py CHANGED
@@ -14,7 +14,6 @@ from models.base import BaseVideoModel
14
  # ----------------------
15
  # CONFIG
16
  # ----------------------
17
- MODEL_PATH = "Isotr0py/LLaVA-Video-7B-Qwen2-hf"
18
  DEVICE_MAP = "cuda:0"
19
 
20
  VIDEO_DIR = str(Path(__file__).parent / "videos")
@@ -27,11 +26,15 @@ TEMPERATURE = 0.01
27
  # Model loading with quantization support
28
  # ----------------------
29
  model: BaseVideoModel = None
 
30
  current_quantization = "16-bit"
31
 
32
- def load_model_with_quantization(quantization_mode: str):
 
 
 
33
  """Load or reload the model with specified quantization"""
34
- global model, current_quantization
35
 
36
  # Free GPU memory if model already exists
37
  if model is not None:
@@ -44,25 +47,30 @@ def load_model_with_quantization(quantization_mode: str):
44
  load_8bit = False
45
  load_4bit = False
46
 
47
- if quantization_mode == "8-bit":
48
  load_8bit = True
49
- elif quantization_mode == "4-bit":
50
  load_4bit = True
51
  # else: 16-bit (normal) - both flags remain False
52
 
53
- print(f"Loading LLaVa-Video-7B-Qwen2 with {quantization_mode} quantization...")
 
 
 
 
54
  model = load_model(
55
- MODEL_PATH,
56
  device_map=DEVICE_MAP,
57
  load_8bit=load_8bit,
58
  load_4bit=load_4bit,
59
  )
60
- current_quantization = quantization_mode
61
- print(f"Model loaded with {quantization_mode} quantization.")
62
- return f"✅ Model loaded successfully with {quantization_mode} quantization"
 
63
 
64
  # Load model initially with 16-bit (normal)
65
- load_model_with_quantization("16-bit")
66
 
67
  # ----------------------
68
  # Collect video IDs
@@ -139,8 +147,8 @@ def video_qa(
139
  # ----------------------
140
  # Gradio UI
141
  # ----------------------
142
- with gr.Blocks(title="Video QA – LLaVa-Video-7B-Qwen2", theme=gr.themes.Soft()) as demo:
143
- gr.Markdown("## 🎥 Video Question Answering (LLaVa-Video-7B-Qwen2)")
144
 
145
  with gr.Row():
146
  # LEFT COLUMN
@@ -160,6 +168,21 @@ with gr.Blocks(title="Video QA – LLaVa-Video-7B-Qwen2", theme=gr.themes.Soft()
160
  autoplay=False,
161
  height=300
162
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
 
164
  gr.Markdown("### ⚙️ Model Parameters")
165
 
@@ -173,7 +196,7 @@ with gr.Blocks(title="Video QA – LLaVa-Video-7B-Qwen2", theme=gr.themes.Soft()
173
  reload_button = gr.Button("🔄 Reload Model", variant="secondary")
174
  reload_status = gr.Textbox(
175
  label="Model Status",
176
- value=f"Model loaded with {current_quantization} quantization",
177
  interactive=False,
178
  lines=1
179
  )
@@ -279,7 +302,10 @@ with gr.Blocks(title="Video QA – LLaVa-Video-7B-Qwen2", theme=gr.themes.Soft()
279
  # Reload model with new quantization
280
  reload_button.click(
281
  fn=load_model_with_quantization,
282
- inputs=quantization_radio,
 
 
 
283
  outputs=reload_status
284
  )
285
 
 
14
  # ----------------------
15
  # CONFIG
16
  # ----------------------
 
17
  DEVICE_MAP = "cuda:0"
18
 
19
  VIDEO_DIR = str(Path(__file__).parent / "videos")
 
26
  # Model loading with quantization support
27
  # ----------------------
28
  model: BaseVideoModel = None
29
+ current_model_name = "Qwen3-VL-4B-Instruct"
30
  current_quantization = "16-bit"
31
 
32
+ def load_model_with_quantization(
33
+ model_name: str,
34
+ quantization: str
35
+ ):
36
  """Load or reload the model with specified quantization"""
37
+ global model, current_model_name, current_quantization
38
 
39
  # Free GPU memory if model already exists
40
  if model is not None:
 
47
  load_8bit = False
48
  load_4bit = False
49
 
50
+ if quantization == "8-bit":
51
  load_8bit = True
52
+ elif quantization == "4-bit":
53
  load_4bit = True
54
  # else: 16-bit (normal) - both flags remain False
55
 
56
+ print(f"Loading {model_name} with {quantization} quantization...")
57
+ model_path = model_name
58
+ # Load the HF version of LLaVA-Video-7B instead of the default version, for transformers v5 compatibility
59
+ if model_name == "LLaVA-Video-7B-Qwen2":
60
+ model_path = "Isotr0py/LLaVA-Video-7B-Qwen2-hf"
61
  model = load_model(
62
+ model_path,
63
  device_map=DEVICE_MAP,
64
  load_8bit=load_8bit,
65
  load_4bit=load_4bit,
66
  )
67
+ current_model_name = model_name
68
+ current_quantization = quantization
69
+ print(f"{model_name} loaded with {quantization} quantization.")
70
+ return f"✅ {model_name} loaded successfully with {quantization} quantization"
71
 
72
  # Load model initially with 16-bit (normal)
73
+ load_model_with_quantization(current_model_name, current_quantization)
74
 
75
  # ----------------------
76
  # Collect video IDs
 
147
  # ----------------------
148
  # Gradio UI
149
  # ----------------------
150
+ with gr.Blocks(title="Video Inference Demo", theme=gr.themes.Soft()) as demo:
151
+ gr.Markdown("## 🎥 Video Inference")
152
 
153
  with gr.Row():
154
  # LEFT COLUMN
 
168
  autoplay=False,
169
  height=300
170
  )
171
+
172
+ gr.Markdown("### 🤖 Model Name")
173
+
174
+ model_name_radio = gr.Radio(
175
+ choices=[
176
+ "Qwen3-VL-4B-Instruct",
177
+ "Qwen3-VL-8B-Instruct",
178
+ "Qwen3-VL-2B-Thinking",
179
+ "Qwen3-VL-4B-Thinking",
180
+ "LLaVA-Video-7B-Qwen2"
181
+ ],
182
+ value="Qwen3-VL-4B-Instruct",
183
+ label="🤖 Model Name",
184
+ info="Select the model to use for inference"
185
+ )
186
 
187
  gr.Markdown("### ⚙️ Model Parameters")
188
 
 
196
  reload_button = gr.Button("🔄 Reload Model", variant="secondary")
197
  reload_status = gr.Textbox(
198
  label="Model Status",
199
+ value=f"{current_model_name} loaded with {current_quantization} quantization",
200
  interactive=False,
201
  lines=1
202
  )
 
302
  # Reload model with new quantization
303
  reload_button.click(
304
  fn=load_model_with_quantization,
305
+ inputs=[
306
+ model_name_radio,
307
+ quantization_radio,
308
+ ],
309
  outputs=reload_status
310
  )
311