Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
|
@@ -3,18 +3,20 @@
|
|
| 3 |
# This source code is based on by web_demo_mm.py, by Alibaba Cloud.
|
| 4 |
# Licensed under Apache License 2.0
|
| 5 |
|
| 6 |
-
import os
|
| 7 |
import copy
|
|
|
|
| 8 |
import re
|
| 9 |
from argparse import ArgumentParser
|
| 10 |
from threading import Thread
|
| 11 |
|
| 12 |
import gradio as gr
|
| 13 |
import torch
|
|
|
|
| 14 |
from qwen_vl_utils import process_vision_info
|
| 15 |
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration, TextIteratorStreamer
|
| 16 |
|
| 17 |
-
DEFAULT_CKPT_PATH = 'farrell236/test_model'
|
|
|
|
| 18 |
AUTH_TOKEN = os.environ.get("HF_spaces")
|
| 19 |
|
| 20 |
def _get_args():
|
|
@@ -60,13 +62,14 @@ def _load_model_processor(args):
|
|
| 60 |
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
| 61 |
args.checkpoint_path,
|
| 62 |
use_auth_token=args.auth_token,
|
| 63 |
-
torch_dtype=
|
| 64 |
attn_implementation='flash_attention_2',
|
| 65 |
device_map=device_map)
|
| 66 |
else:
|
| 67 |
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
| 68 |
args.checkpoint_path,
|
| 69 |
use_auth_token=args.auth_token,
|
|
|
|
| 70 |
device_map=device_map)
|
| 71 |
|
| 72 |
processor = AutoProcessor.from_pretrained('Qwen/Qwen2.5-VL-3B-Instruct')
|
|
@@ -145,6 +148,7 @@ def _transform_messages(original_messages):
|
|
| 145 |
|
| 146 |
def _launch_demo(args, model, processor):
|
| 147 |
|
|
|
|
| 148 |
def call_local_model(model, processor, messages,
|
| 149 |
max_tokens=1024, temperature=0.6,
|
| 150 |
top_p=0.9, top_k=50,
|
|
|
|
| 3 |
# This source code is based on by web_demo_mm.py, by Alibaba Cloud.
|
| 4 |
# Licensed under Apache License 2.0
|
| 5 |
|
|
|
|
| 6 |
import copy
|
| 7 |
+
import os
|
| 8 |
import re
|
| 9 |
from argparse import ArgumentParser
|
| 10 |
from threading import Thread
|
| 11 |
|
| 12 |
import gradio as gr
|
| 13 |
import torch
|
| 14 |
+
import spaces
|
| 15 |
from qwen_vl_utils import process_vision_info
|
| 16 |
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration, TextIteratorStreamer
|
| 17 |
|
| 18 |
+
# DEFAULT_CKPT_PATH = 'farrell236/test_model'
|
| 19 |
+
DEFAULT_CKPT_PATH = 'Qwen/Qwen2.5-VL-32B-Instruct'
|
| 20 |
AUTH_TOKEN = os.environ.get("HF_spaces")
|
| 21 |
|
| 22 |
def _get_args():
|
|
|
|
| 62 |
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
| 63 |
args.checkpoint_path,
|
| 64 |
use_auth_token=args.auth_token,
|
| 65 |
+
torch_dtype=torch.bfloat16,
|
| 66 |
attn_implementation='flash_attention_2',
|
| 67 |
device_map=device_map)
|
| 68 |
else:
|
| 69 |
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
| 70 |
args.checkpoint_path,
|
| 71 |
use_auth_token=args.auth_token,
|
| 72 |
+
torch_dtype=torch.bfloat16,
|
| 73 |
device_map=device_map)
|
| 74 |
|
| 75 |
processor = AutoProcessor.from_pretrained('Qwen/Qwen2.5-VL-3B-Instruct')
|
|
|
|
| 148 |
|
| 149 |
def _launch_demo(args, model, processor):
|
| 150 |
|
| 151 |
+
@spaces.GPU
|
| 152 |
def call_local_model(model, processor, messages,
|
| 153 |
max_tokens=1024, temperature=0.6,
|
| 154 |
top_p=0.9, top_k=50,
|