Portx committed on
Commit
3106295
·
verified ·
1 Parent(s): a89e2e5

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +2 -15
handler.py CHANGED
@@ -5,25 +5,13 @@ from PIL import Image
5
  import os
6
  import base64
7
 
8
- run("pip install flash-attn --no-build-isolation", shell=True, check=True)
9
  run("pip install --upgrade pip", shell=True, check=True)
10
  run("pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu124", shell=True, check=True)
11
 
12
- #run("pip install --upgrade accelerate transformers", shell=True, check=True)
13
- #run("pip -qqq install --force-reinstall https://github.com/bitsandbytes-foundation/bitsandbytes/releases/download/continuous-release_main/bitsandbytes-0.44.2.dev0-py3-none-manylinux_2_24_x86_64.whl", shell=True, check=True)
14
 
15
 
16
 
17
-
18
- try:
19
- import flash_attn
20
- print("FlashAttention is installed")
21
- USE_FLASH_ATTENTION = True
22
- except ImportError:
23
- print("FlashAttention is not installed")
24
- USE_FLASH_ATTENTION = False
25
-
26
-
27
  from transformers import AutoModelForVision2Seq, AutoProcessor, BitsAndBytesConfig
28
 
29
  model_id = "ibm-granite/granite-vision-3.2-2b"
@@ -82,8 +70,7 @@ class PromptSet:
82
  class EndpointHandler():
83
  def __init__(self, path=""):
84
  self.model=AutoModelForVision2Seq.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16,
85
- quantization_config=bnb_config,
86
- _attn_implementation="flash_attention_2" if USE_FLASH_ATTENTION else None,)
87
  self.processor = AutoProcessor.from_pretrained(model_id, use_fast=True)
88
 
89
  def __call__(self, data):
 
5
  import os
6
  import base64
7
 
8
+ #run("pip install flash-attn --no-build-isolation", shell=True, check=True)
9
  run("pip install --upgrade pip", shell=True, check=True)
10
  run("pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu124", shell=True, check=True)
11
 
 
 
12
 
13
 
14
 
 
 
 
 
 
 
 
 
 
 
15
  from transformers import AutoModelForVision2Seq, AutoProcessor, BitsAndBytesConfig
16
 
17
  model_id = "ibm-granite/granite-vision-3.2-2b"
 
70
  class EndpointHandler():
71
  def __init__(self, path=""):
72
  self.model=AutoModelForVision2Seq.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16,
73
+ quantization_config=bnb_config)
 
74
  self.processor = AutoProcessor.from_pretrained(model_id, use_fast=True)
75
 
76
  def __call__(self, data):