John Ho committed on
Commit
8edc124
·
1 Parent(s): b3db9ce

added low_cpu_mem_usage and move input to device also

Browse files
Files changed (1) hide show
  1. app.py +6 -4
app.py CHANGED
@@ -55,15 +55,17 @@ def load_model(
55
  model = (
56
  Qwen2_5_VLForConditionalGeneration.from_pretrained(
57
  model_name,
58
- torch_dtype=DTYPE, # torch.bfloat16,
59
  attn_implementation="flash_attention_2",
60
- device_map=DEVICE, # "auto",
 
61
  )
62
  if use_flash_attention
63
  else Qwen2_5_VLForConditionalGeneration.from_pretrained(
64
  model_name,
65
- torch_dtype=DTYPE, # "auto",
66
  device_map=DEVICE,
 
67
  )
68
  )
69
  # Set model to evaluation mode for inference (disables dropout, etc.)
@@ -126,7 +128,7 @@ def inference(
126
  return_tensors="pt",
127
  **video_kwargs,
128
  )
129
- # inputs = inputs.to(DEVICE)
130
 
131
  # Inference
132
  generated_ids = model.generate(**inputs, max_new_tokens=128)
 
55
  model = (
56
  Qwen2_5_VLForConditionalGeneration.from_pretrained(
57
  model_name,
58
+ torch_dtype=DTYPE,
59
  attn_implementation="flash_attention_2",
60
+ device_map=DEVICE,
61
+ low_cpu_mem_usage=True,
62
  )
63
  if use_flash_attention
64
  else Qwen2_5_VLForConditionalGeneration.from_pretrained(
65
  model_name,
66
+ torch_dtype=DTYPE,
67
  device_map=DEVICE,
68
+ low_cpu_mem_usage=True,
69
  )
70
  )
71
  # Set model to evaluation mode for inference (disables dropout, etc.)
 
128
  return_tensors="pt",
129
  **video_kwargs,
130
  )
131
+ inputs = inputs.to(DEVICE)
132
 
133
  # Inference
134
  generated_ids = model.generate(**inputs, max_new_tokens=128)