prithivMLmods committed on
Commit
2b8b50a
·
verified ·
1 Parent(s): bea725a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -0
app.py CHANGED
@@ -88,6 +88,7 @@ logger.info(f"Loading model 1: {MODEL_ID_1}")
88
  processor_1 = AutoProcessor.from_pretrained(MODEL_ID_1, trust_remote_code=True)
89
  model_1 = Qwen2_5_VLForConditionalGeneration.from_pretrained(
90
  MODEL_ID_1,
 
91
  trust_remote_code=True,
92
  torch_dtype=torch.float16 if device == "cuda" else torch.float32
93
  ).to(device).eval()
@@ -99,6 +100,7 @@ logger.info(f"Loading model 2: {MODEL_ID_2}")
99
  processor_2 = AutoProcessor.from_pretrained(MODEL_ID_2, trust_remote_code=True)
100
  model_2 = Qwen2_5_VLForConditionalGeneration.from_pretrained(
101
  MODEL_ID_2,
 
102
  trust_remote_code=True,
103
  torch_dtype=torch.float16 if device == "cuda" else torch.float32
104
  ).to(device).eval()
@@ -110,6 +112,7 @@ logger.info(f"Loading model 3: {MODEL_ID_3}")
110
  processor_3 = AutoProcessor.from_pretrained(MODEL_ID_3, trust_remote_code=True)
111
  model_3 = Qwen2_5_VLForConditionalGeneration.from_pretrained(
112
  MODEL_ID_3,
 
113
  trust_remote_code=True,
114
  torch_dtype=torch.float16 if device == "cuda" else torch.float32
115
  ).to(device).eval()
 
88
  processor_1 = AutoProcessor.from_pretrained(MODEL_ID_1, trust_remote_code=True)
89
  model_1 = Qwen2_5_VLForConditionalGeneration.from_pretrained(
90
  MODEL_ID_1,
91
+ attn_implementation="flash_attention_2",
92
  trust_remote_code=True,
93
  torch_dtype=torch.float16 if device == "cuda" else torch.float32
94
  ).to(device).eval()
 
100
  processor_2 = AutoProcessor.from_pretrained(MODEL_ID_2, trust_remote_code=True)
101
  model_2 = Qwen2_5_VLForConditionalGeneration.from_pretrained(
102
  MODEL_ID_2,
103
+ attn_implementation="flash_attention_2",
104
  trust_remote_code=True,
105
  torch_dtype=torch.float16 if device == "cuda" else torch.float32
106
  ).to(device).eval()
 
112
  processor_3 = AutoProcessor.from_pretrained(MODEL_ID_3, trust_remote_code=True)
113
  model_3 = Qwen2_5_VLForConditionalGeneration.from_pretrained(
114
  MODEL_ID_3,
115
+ attn_implementation="flash_attention_2",
116
  trust_remote_code=True,
117
  torch_dtype=torch.float16 if device == "cuda" else torch.float32
118
  ).to(device).eval()