John Ho committed on
Commit
8b3dcea
·
1 Parent(s): 1679d51

make sure DTYPE is used

Browse files
Files changed (1) hide show
  1. app.py +7 -4
app.py CHANGED
@@ -49,14 +49,14 @@ def load_model(
49
  model = (
50
  Qwen2_5_VLForConditionalGeneration.from_pretrained(
51
  model_name,
52
- torch_dtype=torch.bfloat16,
53
  attn_implementation="flash_attention_2",
54
- device_map="auto",
55
  )
56
  if use_flash_attention
57
  else Qwen2_5_VLForConditionalGeneration.from_pretrained(
58
  model_name,
59
- torch_dtype=torch.bfloat16, # "auto",
60
  device_map=DEVICE,
61
  )
62
  )
@@ -71,7 +71,10 @@ def inference(
71
  ):
72
  # default processor
73
  processor = AutoProcessor.from_pretrained(
74
- "Qwen/Qwen2.5-VL-7B-Instruct", device_map=DEVICE, use_fast=True
 
 
 
75
  )
76
  model = load_model(use_flash_attention=use_flash_attention)
77
  fps = get_fps_ffmpeg(video_path)
 
49
  model = (
50
  Qwen2_5_VLForConditionalGeneration.from_pretrained(
51
  model_name,
52
+ torch_dtype=DTYPE, # torch.bfloat16,
53
  attn_implementation="flash_attention_2",
54
+ device_map=DEVICE, # "auto",
55
  )
56
  if use_flash_attention
57
  else Qwen2_5_VLForConditionalGeneration.from_pretrained(
58
  model_name,
59
+ torch_dtype=DTYPE, # "auto",
60
  device_map=DEVICE,
61
  )
62
  )
 
71
  ):
72
  # default processor
73
  processor = AutoProcessor.from_pretrained(
74
+ "Qwen/Qwen2.5-VL-7B-Instruct",
75
+ device_map=DEVICE,
76
+ use_fast=True,
77
+ torch_dtype=DTYPE,
78
  )
79
  model = load_model(use_flash_attention=use_flash_attention)
80
  fps = get_fps_ffmpeg(video_path)