IFMedTechdemo commited on
Commit
6193aca
·
verified ·
1 Parent(s): 1ed4a63

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -7
app.py CHANGED
@@ -27,7 +27,8 @@ processor_v = AutoProcessor.from_pretrained(MODEL_ID_V, trust_remote_code=True)
27
  model_v = Qwen3VLForConditionalGeneration.from_pretrained(
28
  MODEL_ID_V,
29
  trust_remote_code=True,
30
- torch_dtype=torch.float16
 
31
  ).to(device).eval()
32
 
33
  # Load Nanonets-OCR2-3B
@@ -36,15 +37,16 @@ processor_x = AutoProcessor.from_pretrained(MODEL_ID_X, trust_remote_code=True)
36
  model_x = Qwen2_5_VLForConditionalGeneration.from_pretrained(
37
  MODEL_ID_X,
38
  trust_remote_code=True,
39
- torch_dtype=torch.float16
 
40
  ).to(device).eval()
41
 
42
- # Load Dots.OCR
43
  MODEL_PATH_D = "strangervisionhf/dots.ocr-base-fix"
44
  processor_d = AutoProcessor.from_pretrained(MODEL_PATH_D, trust_remote_code=True)
45
  model_d = AutoModelForCausalLM.from_pretrained(
46
  MODEL_PATH_D,
47
- attn_implementation="flash_attention_2",
48
  torch_dtype=torch.bfloat16,
49
  device_map="auto",
50
  trust_remote_code=True
@@ -56,15 +58,16 @@ processor_m = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", trust
56
  model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
57
  MODEL_ID_M,
58
  trust_remote_code=True,
59
- torch_dtype=torch.bfloat16
 
60
  ).to(device).eval()
61
 
62
- # Load DeepSeek-OCR
63
  MODEL_ID_DS = "deepseek-ai/DeepSeek-OCR"
64
  tokenizer_ds = AutoTokenizer.from_pretrained(MODEL_ID_DS, trust_remote_code=True)
65
  model_ds = AutoModel.from_pretrained(
66
  MODEL_ID_DS,
67
- attn_implementation="flash_attention_2",
68
  trust_remote_code=True,
69
  use_safetensors=True
70
  ).eval().to(device).to(torch.bfloat16)
 
27
  model_v = Qwen3VLForConditionalGeneration.from_pretrained(
28
  MODEL_ID_V,
29
  trust_remote_code=True,
30
+ torch_dtype=torch.float16,
31
+ attn_implementation="sdpa" # Use PyTorch's native scaled dot product attention
32
  ).to(device).eval()
33
 
34
  # Load Nanonets-OCR2-3B
 
37
  model_x = Qwen2_5_VLForConditionalGeneration.from_pretrained(
38
  MODEL_ID_X,
39
  trust_remote_code=True,
40
+ torch_dtype=torch.float16,
41
+ attn_implementation="sdpa" # Use PyTorch's native attention
42
  ).to(device).eval()
43
 
44
+ # Load Dots.OCR - REMOVE flash_attention_2 parameter
45
  MODEL_PATH_D = "strangervisionhf/dots.ocr-base-fix"
46
  processor_d = AutoProcessor.from_pretrained(MODEL_PATH_D, trust_remote_code=True)
47
  model_d = AutoModelForCausalLM.from_pretrained(
48
  MODEL_PATH_D,
49
+ attn_implementation="sdpa", # Changed from flash_attention_2
50
  torch_dtype=torch.bfloat16,
51
  device_map="auto",
52
  trust_remote_code=True
 
58
  model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
59
  MODEL_ID_M,
60
  trust_remote_code=True,
61
+ torch_dtype=torch.bfloat16,
62
+ attn_implementation="sdpa" # Use PyTorch's native attention
63
  ).to(device).eval()
64
 
65
+ # Load DeepSeek-OCR - REMOVE flash_attention_2 parameter
66
  MODEL_ID_DS = "deepseek-ai/DeepSeek-OCR"
67
  tokenizer_ds = AutoTokenizer.from_pretrained(MODEL_ID_DS, trust_remote_code=True)
68
  model_ds = AutoModel.from_pretrained(
69
  MODEL_ID_DS,
70
+ attn_implementation="sdpa", # Changed from flash_attention_2
71
  trust_remote_code=True,
72
  use_safetensors=True
73
  ).eval().to(device).to(torch.bfloat16)