Geraldine committed on
Commit
d5d699b
·
verified ·
1 Parent(s): 9704588

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -10
app.py CHANGED
@@ -43,6 +43,35 @@ if torch.cuda.is_available():
43
  print("Using device:", device)
44
 
45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  def patch_dots_ocr_configuration(repo_path: str) -> None:
47
  config_path = Path(repo_path) / "configuration_dots.py"
48
  if not config_path.exists():
@@ -92,9 +121,9 @@ def resolve_dots_ocr_model_path(repo_id: str) -> str:
92
 
93
  MODEL_ID_V = "nanonets/Nanonets-OCR2-3B"
94
  processor_v = AutoProcessor.from_pretrained(MODEL_ID_V, trust_remote_code=True)
95
- model_v = Qwen2_5_VLForConditionalGeneration.from_pretrained(
 
96
  MODEL_ID_V,
97
- attn_implementation="kernels-community/flash-attn2",
98
  trust_remote_code=True,
99
  torch_dtype=torch.float16
100
  ).to(device).eval()
@@ -102,36 +131,36 @@ model_v = Qwen2_5_VLForConditionalGeneration.from_pretrained(
102
  MODEL_ID_Y = "rednote-hilab/dots.ocr"
103
  MODEL_PATH_Y = resolve_dots_ocr_model_path(MODEL_ID_Y)
104
  processor_y = AutoProcessor.from_pretrained(MODEL_PATH_Y, trust_remote_code=True)
105
- model_y = AutoModelForCausalLM.from_pretrained(
 
106
  MODEL_PATH_Y,
107
- attn_implementation="kernels-community/flash-attn2",
108
  trust_remote_code=True,
109
  torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
110
  ).to(device).eval()
111
 
112
  MODEL_ID_X = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
113
  processor_x = AutoProcessor.from_pretrained(MODEL_ID_X, trust_remote_code=True)
114
- model_x = Qwen2VLForConditionalGeneration.from_pretrained(
 
115
  MODEL_ID_X,
116
- attn_implementation="kernels-community/flash-attn2",
117
  trust_remote_code=True,
118
  torch_dtype=torch.float16
119
  ).to(device).eval()
120
 
121
  MODEL_ID_W = "allenai/olmOCR-7B-0725"
122
  processor_w = AutoProcessor.from_pretrained(MODEL_ID_W, trust_remote_code=True)
123
- model_w = Qwen2_5_VLForConditionalGeneration.from_pretrained(
 
124
  MODEL_ID_W,
125
- attn_implementation="kernels-community/flash-attn2",
126
  trust_remote_code=True,
127
  torch_dtype=torch.float16
128
  ).to(device).eval()
129
 
130
  MODEL_ID_M = "reducto/RolmOCR"
131
  processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
132
- model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
 
133
  MODEL_ID_M,
134
- attn_implementation="kernels-community/flash-attn2",
135
  trust_remote_code=True,
136
  torch_dtype=torch.float16
137
  ).to(device).eval()
 
43
  print("Using device:", device)
44
 
45
 
46
+ def get_attention_fallbacks() -> list[str | None]:
47
+ fallbacks = []
48
+ if torch.cuda.is_available() and os.getenv("USE_FLASH_ATTN", "0") == "1":
49
+ fallbacks.append("kernels-community/flash-attn2")
50
+ if torch.cuda.is_available():
51
+ fallbacks.append("sdpa")
52
+ fallbacks.append("eager")
53
+ fallbacks.append(None)
54
+ return fallbacks
55
+
56
+
57
def load_model_with_attention_fallback(model_cls, model_id, **kwargs):
    """Load *model_id* with ``model_cls.from_pretrained``, retrying across
    attention backends.

    Iterates the backends from :func:`get_attention_fallbacks` and returns
    the first successfully loaded model.  A backend of ``None`` means "use
    the library default", so any explicit ``attn_implementation`` request
    is removed for that attempt.

    Raises:
        The last load error if every backend fails, or ``RuntimeError`` if
        no backend was available to try.
    """
    last_error = None
    for attn_impl in get_attention_fallbacks():
        # Fresh copy per attempt so a failed attempt's backend choice does
        # not leak into the next one.
        load_kwargs = dict(kwargs)
        label = attn_impl or "default"
        if attn_impl is None:
            load_kwargs.pop("attn_implementation", None)
        else:
            load_kwargs["attn_implementation"] = attn_impl
        try:
            print(f"Loading {model_id} with attention backend: {label}")
            return model_cls.from_pretrained(model_id, **load_kwargs)
        except Exception as exc:
            last_error = exc
            print(f"Failed loading {model_id} with attention backend {label}: {exc}")
    # Guard: the original `raise last_error` becomes `raise None` (a
    # TypeError) if the fallback list is ever empty — fail loudly instead.
    if last_error is None:
        raise RuntimeError(f"No attention backends available to load {model_id}")
    raise last_error
73
+
74
+
75
  def patch_dots_ocr_configuration(repo_path: str) -> None:
76
  config_path = Path(repo_path) / "configuration_dots.py"
77
  if not config_path.exists():
 
121
 
122
# --- Model registry -------------------------------------------------------
# Five OCR vision-language models, all loaded eagerly at import time, moved
# to `device` (defined earlier in the file) and put in eval mode.  Each load
# goes through load_model_with_attention_fallback, which retries across
# attention backends instead of hard-coding a single implementation.
# NOTE(review): these are large downloads/loads executed at module import —
# confirm that eager loading (vs. lazy/on-demand) is intentional.

# Nanonets OCR2 (3B), a Qwen2.5-VL variant, loaded in fp16.
MODEL_ID_V = "nanonets/Nanonets-OCR2-3B"
processor_v = AutoProcessor.from_pretrained(MODEL_ID_V, trust_remote_code=True)
model_v = load_model_with_attention_fallback(
    Qwen2_5_VLForConditionalGeneration,
    MODEL_ID_V,
    trust_remote_code=True,
    torch_dtype=torch.float16
).to(device).eval()

# dots.ocr: resolved to a local snapshot path first (the repo's config is
# patched by resolve_dots_ocr_model_path / patch_dots_ocr_configuration).
# bf16 on CUDA, fp32 otherwise — presumably because CPU bf16 support varies;
# TODO(review): confirm the dtype choice.
MODEL_ID_Y = "rednote-hilab/dots.ocr"
MODEL_PATH_Y = resolve_dots_ocr_model_path(MODEL_ID_Y)
processor_y = AutoProcessor.from_pretrained(MODEL_PATH_Y, trust_remote_code=True)
model_y = load_model_with_attention_fallback(
    AutoModelForCausalLM,
    MODEL_PATH_Y,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
).to(device).eval()

# Qwen2-VL OCR (2B), loaded in fp16.
MODEL_ID_X = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
processor_x = AutoProcessor.from_pretrained(MODEL_ID_X, trust_remote_code=True)
model_x = load_model_with_attention_fallback(
    Qwen2VLForConditionalGeneration,
    MODEL_ID_X,
    trust_remote_code=True,
    torch_dtype=torch.float16
).to(device).eval()

# olmOCR (7B), a Qwen2.5-VL variant, loaded in fp16.
MODEL_ID_W = "allenai/olmOCR-7B-0725"
processor_w = AutoProcessor.from_pretrained(MODEL_ID_W, trust_remote_code=True)
model_w = load_model_with_attention_fallback(
    Qwen2_5_VLForConditionalGeneration,
    MODEL_ID_W,
    trust_remote_code=True,
    torch_dtype=torch.float16
).to(device).eval()

# RolmOCR, a Qwen2.5-VL variant, loaded in fp16.
MODEL_ID_M = "reducto/RolmOCR"
processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
model_m = load_model_with_attention_fallback(
    Qwen2_5_VLForConditionalGeneration,
    MODEL_ID_M,
    trust_remote_code=True,
    torch_dtype=torch.float16
).to(device).eval()