farrell236 committed on
Commit
1d3684c
·
verified ·
1 Parent(s): 7caf75b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -3
app.py CHANGED
@@ -3,18 +3,20 @@
3
  # This source code is based on web_demo_mm.py, by Alibaba Cloud.
4
  # Licensed under Apache License 2.0
5
 
6
- import os
7
  import copy
 
8
  import re
9
  from argparse import ArgumentParser
10
  from threading import Thread
11
 
12
  import gradio as gr
13
  import torch
 
14
  from qwen_vl_utils import process_vision_info
15
  from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration, TextIteratorStreamer
16
 
17
- DEFAULT_CKPT_PATH = 'farrell236/test_model'
 
18
  AUTH_TOKEN = os.environ.get("HF_spaces")
19
 
20
  def _get_args():
@@ -60,13 +62,14 @@ def _load_model_processor(args):
60
  model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
61
  args.checkpoint_path,
62
  use_auth_token=args.auth_token,
63
- torch_dtype='auto',
64
  attn_implementation='flash_attention_2',
65
  device_map=device_map)
66
  else:
67
  model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
68
  args.checkpoint_path,
69
  use_auth_token=args.auth_token,
 
70
  device_map=device_map)
71
 
72
  processor = AutoProcessor.from_pretrained('Qwen/Qwen2.5-VL-3B-Instruct')
@@ -145,6 +148,7 @@ def _transform_messages(original_messages):
145
 
146
  def _launch_demo(args, model, processor):
147
 
 
148
  def call_local_model(model, processor, messages,
149
  max_tokens=1024, temperature=0.6,
150
  top_p=0.9, top_k=50,
 
3
  # This source code is based on web_demo_mm.py, by Alibaba Cloud.
4
  # Licensed under Apache License 2.0
5
 
 
6
  import copy
7
+ import os
8
  import re
9
  from argparse import ArgumentParser
10
  from threading import Thread
11
 
12
  import gradio as gr
13
  import torch
14
+ import spaces
15
  from qwen_vl_utils import process_vision_info
16
  from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration, TextIteratorStreamer
17
 
18
+ # DEFAULT_CKPT_PATH = 'farrell236/test_model'
19
+ DEFAULT_CKPT_PATH = 'Qwen/Qwen2.5-VL-32B-Instruct'
20
  AUTH_TOKEN = os.environ.get("HF_spaces")
21
 
22
  def _get_args():
 
62
  model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
63
  args.checkpoint_path,
64
  use_auth_token=args.auth_token,
65
+ torch_dtype=torch.bfloat16,
66
  attn_implementation='flash_attention_2',
67
  device_map=device_map)
68
  else:
69
  model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
70
  args.checkpoint_path,
71
  use_auth_token=args.auth_token,
72
+ torch_dtype=torch.bfloat16,
73
  device_map=device_map)
74
 
75
  processor = AutoProcessor.from_pretrained('Qwen/Qwen2.5-VL-3B-Instruct')
 
148
 
149
  def _launch_demo(args, model, processor):
150
 
151
+ @spaces.GPU
152
  def call_local_model(model, processor, messages,
153
  max_tokens=1024, temperature=0.6,
154
  top_p=0.9, top_k=50,