farrell236 commited on
Commit
e231ba6
·
verified ·
1 Parent(s): c5a53a0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -0
app.py CHANGED
@@ -15,6 +15,7 @@ from qwen_vl_utils import process_vision_info
15
  from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration, TextIteratorStreamer
16
 
17
  DEFAULT_CKPT_PATH = 'farrell236/OpthModel32B_b'
 
18
  # DEFAULT_CKPT_PATH = '/scratch/llm-weights/Qwen/Qwen2.5-VL-7B-Instruct'
19
  AUTH_TOKEN = os.environ.get("HF_spaces")
20
 
@@ -25,6 +26,11 @@ def _get_args():
25
  type=str,
26
  default=DEFAULT_CKPT_PATH,
27
  help='Checkpoint name or path, default to %(default)r')
 
 
 
 
 
28
  parser.add_argument('-t',
29
  '--auth-token',
30
  type=str,
@@ -60,6 +66,7 @@ def _load_model_processor(args):
60
  if args.flash_attn2:
61
  model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
62
  args.checkpoint_path,
 
63
  use_auth_token=args.auth_token,
64
  torch_dtype=torch.bfloat16,
65
  attn_implementation='flash_attention_2',
@@ -67,6 +74,7 @@ def _load_model_processor(args):
67
  else:
68
  model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
69
  args.checkpoint_path,
 
70
  use_auth_token=args.auth_token,
71
  torch_dtype=torch.bfloat16,
72
  device_map=device_map)
 
15
  from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration, TextIteratorStreamer
16
 
17
  DEFAULT_CKPT_PATH = 'farrell236/OpthModel32B_b'
18
+ DEFAULT_REVISION = '7d10823654d2d3ba42d74d8177cd54368bc7df96'
19
  # DEFAULT_CKPT_PATH = '/scratch/llm-weights/Qwen/Qwen2.5-VL-7B-Instruct'
20
  AUTH_TOKEN = os.environ.get("HF_spaces")
21
 
 
26
  type=str,
27
  default=DEFAULT_CKPT_PATH,
28
  help='Checkpoint name or path, default to %(default)r')
29
+ parser.add_argument('-r',
30
+ '--revision',
31
+ type=str,
32
+ default=DEFAULT_REVISION,
33
+ help='Commit tag of checkpoint, default to %(default)r')
34
  parser.add_argument('-t',
35
  '--auth-token',
36
  type=str,
 
66
  if args.flash_attn2:
67
  model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
68
  args.checkpoint_path,
69
+ revision=args.revision,
70
  use_auth_token=args.auth_token,
71
  torch_dtype=torch.bfloat16,
72
  attn_implementation='flash_attention_2',
 
74
  else:
75
  model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
76
  args.checkpoint_path,
77
+ revision=args.revision,
78
  use_auth_token=args.auth_token,
79
  torch_dtype=torch.bfloat16,
80
  device_map=device_map)