Spaces · Running on Zero
John Ho committed · Commit 035a7ef · 1 Parent(s): e13ff04

added temp and testing gemma
app.py
CHANGED
@@ -3,6 +3,7 @@ import gradio as gr
 from transformers import (
     Qwen2_5_VLForConditionalGeneration,
     AutoModelForImageTextToText,
+    Gemma3nForConditionalGeneration,
     AutoProcessor,
     BitsAndBytesConfig,
 )
@@ -93,6 +94,10 @@ def load_model(
             model = AutoModelForImageTextToText.from_pretrained(
                 model_name, **common_args
             )
+        case "gemma":
+            model = Gemma3nForConditionalGeneration.from_pretrained(
+                model_name, **common_args
+            )
         case _:
             raise ValueError(f"Unsupported model family: {model_family}")
 
@@ -141,6 +146,11 @@ MODEL_ZOO = {
         use_flash_attention=False,
         apply_quantization=True,
     ),
+    "gemma-3n-e4b-it": load_model(
+        model_name="google/gemma-3n-e4b-it",
+        use_flash_attention=False,
+        apply_quantization=True,
+    ),
 }
 
 PROCESSORS = {
@@ -149,7 +159,8 @@ PROCESSORS = {
     "qwen2.5-vl-3b-instruct": load_processor("Qwen/Qwen2.5-VL-3B-Instruct"),
     "InternVL3-1B-hf": load_processor("OpenGVLab/InternVL3-1B-hf"),
     "InternVL3-2B-hf": load_processor("OpenGVLab/InternVL3-2B-hf"),
-
+    "InternVL3-8B-hf": load_processor("OpenGVLab/InternVL3-8B-hf"),
+    "gemma-3n-e4b-it": load_processor("google/gemma-3n-e4b-it"),
 }
 logger.debug("Models and Processors Loaded!")
 
@@ -161,6 +172,7 @@ def inference(
     model_name: str = "qwen2.5-vl-7b-instruct",
     custom_fps: int = 8,
     max_tokens: int = 256,
+    temperature: float = 0.0,
 ):
     s_time = time.time()
     # default processor
@@ -220,7 +232,9 @@ def inference(
             inputs = inputs.to("cuda")
 
             # Inference
-            generated_ids = model.generate(
+            generated_ids = model.generate(
+                **inputs, max_new_tokens=max_tokens, temperature=temperature
+            )
             generated_ids_trimmed = [
                 out_ids[len(in_ids) :]
                 for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
@@ -230,7 +244,7 @@ def inference(
                 skip_special_tokens=True,
                 clean_up_tokenization_spaces=False,
             )[0]
-        case "InternVL3":
+        case "InternVL3" | "gemma":
             inputs = processor.apply_chat_template(
                 messages,
                 add_generation_prompt=True,
@@ -240,7 +254,9 @@ def inference(
                 # num_frames = 8
             ).to("cuda", dtype=DTYPE)
 
-            output = model.generate(
+            output = model.generate(
+                **inputs, max_new_tokens=max_tokens, temperature=temperature
+            )
             output_text = processor.decode(
                 output[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True
             )
@@ -279,6 +295,13 @@ demo = gr.Interface(
             maximum=512,
             step=32,
         ),
+        gr.Slider(
+            label="Temperature",
+            value=0.0,
+            minimum=0.0,
+            maximum=1.0,
+            step=0.1,
+        ),
         # gr.Checkbox(label="Use Flash Attention", value=False),
         # gr.Checkbox(label="Apply Quantization", value=True),
     ],
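For reference, a minimal, self-contained sketch of the two changes this commit introduces (the "gemma" model branch and the temperature control), exercised outside of Gradio. It is not the Space's exact code: the dtype/device settings stand in for the app's **common_args, and the text-only prompt is a placeholder for the app's video/image messages, neither of which is visible in this diff.

# Minimal sketch of the new Gemma 3n path plus the temperature argument.
# torch_dtype/device_map and the prompt below are assumptions, not the app's settings.
import torch
from transformers import AutoProcessor, Gemma3nForConditionalGeneration

model_id = "google/gemma-3n-e4b-it"  # checkpoint added to MODEL_ZOO / PROCESSORS in this commit

processor = AutoProcessor.from_pretrained(model_id)
model = Gemma3nForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,  # stand-in for the app's **common_args
    device_map="auto",
)

messages = [
    {"role": "user", "content": [{"type": "text", "text": "Describe a sunset in one sentence."}]}
]
inputs = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
).to(model.device)

temperature = 0.7  # the new slider defaults to 0.0, which keeps decoding greedy
output = model.generate(
    **inputs,
    max_new_tokens=256,
    do_sample=temperature > 0,  # transformers only applies temperature when sampling is enabled
    temperature=temperature if temperature > 0 else None,
)
print(processor.decode(output[0, inputs["input_ids"].shape[1]:], skip_special_tokens=True))

Note that with the slider's default of 0.0 the generate calls in the diff still decode greedily, since transformers ignores temperature unless do_sample=True.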