John Ho committed
Commit 035a7ef · 1 Parent(s): e13ff04

added temp and testing gemma

Files changed (1):
  app.py +27 -4
app.py CHANGED
@@ -3,6 +3,7 @@ import gradio as gr
 from transformers import (
     Qwen2_5_VLForConditionalGeneration,
     AutoModelForImageTextToText,
+    Gemma3nForConditionalGeneration,
     AutoProcessor,
     BitsAndBytesConfig,
 )
@@ -93,6 +94,10 @@ def load_model(
             model = AutoModelForImageTextToText.from_pretrained(
                 model_name, **common_args
             )
+        case "gemma":
+            model = Gemma3nForConditionalGeneration.from_pretrained(
+                model_name, **common_args
+            )
         case _:
             raise ValueError(f"Unsupported model family: {model_family}")
 
@@ -141,6 +146,11 @@ MODEL_ZOO = {
         use_flash_attention=False,
         apply_quantization=True,
     ),
+    "gemma-3n-e4b-it": load_model(
+        model_name="google/gemma-3n-e4b-it",
+        use_flash_attention=False,
+        apply_quantization=True,
+    ),
 }
 
 PROCESSORS = {
@@ -149,7 +159,8 @@ PROCESSORS = {
     "qwen2.5-vl-3b-instruct": load_processor("Qwen/Qwen2.5-VL-3B-Instruct"),
     "InternVL3-1B-hf": load_processor("OpenGVLab/InternVL3-1B-hf"),
     "InternVL3-2B-hf": load_processor("OpenGVLab/InternVL3-2B-hf"),
-    # "InternVL3-8B-hf": load_processor("OpenGVLab/InternVL3-8B-hf"),
+    "InternVL3-8B-hf": load_processor("OpenGVLab/InternVL3-8B-hf"),
+    "gemma-3n-e4b-it": load_processor("google/gemma-3n-e4b-it"),
 }
 logger.debug("Models and Processors Loaded!")
 
@@ -161,6 +172,7 @@ def inference(
     model_name: str = "qwen2.5-vl-7b-instruct",
     custom_fps: int = 8,
     max_tokens: int = 256,
+    temperature: float = 0.0,
 ):
     s_time = time.time()
     # default processor
@@ -220,7 +232,9 @@ def inference(
             inputs = inputs.to("cuda")
 
             # Inference
-            generated_ids = model.generate(**inputs, max_new_tokens=max_tokens)
+            generated_ids = model.generate(
+                **inputs, max_new_tokens=max_tokens, temperature=temperature
+            )
             generated_ids_trimmed = [
                 out_ids[len(in_ids) :]
                 for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
@@ -230,7 +244,7 @@ def inference(
                 skip_special_tokens=True,
                 clean_up_tokenization_spaces=False,
             )[0]
-        case "InternVL3":
+        case "InternVL3" | "gemma":
             inputs = processor.apply_chat_template(
                 messages,
                 add_generation_prompt=True,
@@ -240,7 +254,9 @@ def inference(
                 # num_frames = 8
             ).to("cuda", dtype=DTYPE)
 
-            output = model.generate(**inputs, max_new_tokens=max_tokens)
+            output = model.generate(
+                **inputs, max_new_tokens=max_tokens, temperature=temperature
+            )
             output_text = processor.decode(
                 output[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True
             )
@@ -279,6 +295,13 @@ demo = gr.Interface(
             maximum=512,
             step=32,
         ),
+        gr.Slider(
+            label="Temperature",
+            value=0.0,
+            minimum=0.0,
+            maximum=1.0,
+            step=0.1,
+        ),
         # gr.Checkbox(label="Use Flash Attention", value=False),
         # gr.Checkbox(label="Apply Quantization", value=True),
     ],
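A note on the new `temperature` plumbing: in transformers, `temperature` only takes effect when sampling is enabled. `generate()` defaults to `do_sample=False` (greedy decoding), so the slider's default of 0.0 is effectively ignored (the generation config emits a warning), while a temperature of exactly 0.0 would be rejected if `do_sample=True`. A minimal guard, sketched against the names used in the hunks above (`model`, `inputs`, `max_tokens`, `temperature`):

    # Sketch only: treat temperature == 0.0 as greedy decoding and positive
    # values as sampling; transformers ignores `temperature` unless
    # do_sample=True, and rejects a temperature of 0.0 when sampling.
    gen_kwargs = {"max_new_tokens": max_tokens}
    if temperature > 0.0:
        gen_kwargs["do_sample"] = True
        gen_kwargs["temperature"] = temperature
    generated_ids = model.generate(**inputs, **gen_kwargs)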
 
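Relatedly, `gr.Interface` binds the widgets in `inputs` to the function's parameters by position, which is why the new Temperature slider sits directly after the max-tokens slider, mirroring `temperature`'s position in the `inference` signature. A self-contained sketch with a hypothetical stand-in function:

    import gradio as gr

    # Hypothetical stand-in for inference(): widget values are passed
    # positionally, so the slider order must match the parameter order.
    def fake_inference(prompt: str, max_tokens: int, temperature: float) -> str:
        return f"max_tokens={max_tokens}, temperature={temperature}, prompt={prompt!r}"

    demo = gr.Interface(
        fn=fake_inference,
        inputs=[
            gr.Textbox(label="Prompt"),
            gr.Slider(label="Max Tokens", value=256, minimum=32, maximum=512, step=32),
            gr.Slider(label="Temperature", value=0.0, minimum=0.0, maximum=1.0, step=0.1),
        ],
        outputs="text",
    )

    if __name__ == "__main__":
        demo.launch()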