gokaygokay commited on
Commit
c064cac
·
verified ·
1 Parent(s): 0ffe430

Update caption_models.py

Browse files
Files changed (1) hide show
  1. caption_models.py +34 -2
caption_models.py CHANGED
@@ -11,13 +11,45 @@ import torch.nn as nn
11
 
12
  subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, check=True, shell=True)
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
15
 
16
  HF_TOKEN = os.getenv("HUGGINGFACE_TOKEN", None)
17
 
18
  # Initialize Florence model
19
- florence_model = AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-large', trust_remote_code=True).to(device).eval()
20
- florence_processor = AutoProcessor.from_pretrained('microsoft/Florence-2-large', trust_remote_code=True)
21
 
22
  # Initialize Qwen2-VL-2B model
23
  qwen_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", trust_remote_code=True, torch_dtype="auto").to(device).eval()
 
11
 
12
  subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, check=True, shell=True)
13
 
14
+ import sys
15
+ import importlib.util
16
+ from unittest.mock import MagicMock
17
+
18
# Minimal stand-in for importlib.machinery.ModuleSpec: just enough attributes
# for "is flash_attn installed?" probes (e.g. callers of
# importlib.util.find_spec) to treat the mocked package as present.
class FakeFlashAttnSpec:
    # Attribute names mirror ModuleSpec's public fields.
    name: str = 'flash_attn'  # module name this spec claims to describe
    loader = None  # no real loader — the module itself is mocked in sys.modules
    origin = None  # no file on disk backs this "module"
    submodule_search_locations: list = []  # non-None => looks like a package
24
+
25
# One shared fake spec for every flash_attn lookup below.
fake_spec = FakeFlashAttnSpec()

# Register mock modules so `import flash_attn` (and the submodules the
# downstream model code touches) succeeds without the real package.
flash_attn_mock = MagicMock()
flash_attn_mock.__spec__ = fake_spec
flash_attn_mock.__version__ = "0.0.0"  # Force version check to fail

sys.modules['flash_attn'] = flash_attn_mock
for _submodule in ('flash_attn.flash_attn_interface', 'flash_attn.bert_padding'):
    sys.modules[_submodule] = MagicMock()
35
+
36
# Monkey-patch importlib.util.find_spec so probes for flash_attn report the
# fake spec; everything else is delegated to the genuine implementation.
_original_find_spec = importlib.util.find_spec


def _patched_find_spec(name, package=None):
    """Return the fake flash_attn spec for flash_attn/its submodules, else defer.

    Drop-in replacement for :func:`importlib.util.find_spec`; *package* is
    forwarded untouched to the original for non-flash_attn lookups.
    """
    targets_flash_attn = name == 'flash_attn' or name.startswith('flash_attn.')
    return fake_spec if targets_flash_attn else _original_find_spec(name, package)


importlib.util.find_spec = _patched_find_spec
45
+
46
# Pick the compute device once; model weights below are moved onto it.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Optional Hugging Face auth token from the environment (None when unset).
HF_TOKEN = os.environ.get("HUGGINGFACE_TOKEN")

# Initialize Florence model.
# Revision is pinned and attention is forced to SDPA — presumably so the
# remote code never takes a flash-attn path (flash_attn is mocked above);
# NOTE(review): confirm the pinned revision stays compatible with transformers.
_FLORENCE_REPO = 'microsoft/Florence-2-large'
_FLORENCE_REVISION = "00d2f1570b00c6dea5df998f5635db96840436bc"

florence_model = (
    AutoModelForCausalLM
    .from_pretrained(
        _FLORENCE_REPO,
        trust_remote_code=True,
        attn_implementation="sdpa",
        revision=_FLORENCE_REVISION,
    )
    .to(device)
    .eval()
)
florence_processor = AutoProcessor.from_pretrained(
    _FLORENCE_REPO,
    trust_remote_code=True,
    attn_implementation="sdpa",
    revision=_FLORENCE_REVISION,
)
53
 
54
# Initialize Qwen2-VL-2B model (dtype auto-selected by transformers,
# then moved to the chosen device and frozen in eval mode).
qwen_model = (
    Qwen2VLForConditionalGeneration
    .from_pretrained(
        "Qwen/Qwen2-VL-2B-Instruct",
        trust_remote_code=True,
        torch_dtype="auto",
    )
    .to(device)
    .eval()
)