wealthcoders commited on
Commit
68dd4db
·
verified ·
1 Parent(s): 8da1fd1

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +32 -5
handler.py CHANGED
@@ -7,14 +7,41 @@ from PIL import Image
7
  import os
8
 
9
  class EndpointHandler:
10
- def __init__(self):
11
- model_name = 'deepseek-ai/DeepSeek-OCR'
12
- self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
13
- model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
14
- self.model = model.eval().cuda().to(torch.bfloat16) # Use .cpu() if no GPU
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  def __call__(self, data: Dict[str, Any]) -> str:
17
  try:
 
18
  base64_string = inputs["base64"]
19
  # Remove data URL prefix if present
20
  if ',' in base64_string:
 
7
  import os
8
 
9
  class EndpointHandler:
10
+ def __init__(self, model_dir = 'deepseek-ai/DeepSeek-OCR'):
11
+ model_name = model_dir
12
+
13
+ self.tokenizer = AutoTokenizer.from_pretrained(
14
+ model_path,
15
+ trust_remote_code=True,
16
+ local_files_only=bool(model_dir) # Only use local files if model_dir is provided
17
+ )
18
+
19
+ # Check if CUDA is available
20
+ self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
21
+
22
+ # Load model with appropriate settings
23
+ model_kwargs = {
24
+ 'trust_remote_code': True,
25
+ 'torch_dtype': torch.bfloat16 if self.device == 'cuda' else torch.float32
26
+ }
27
+
28
+ # Add flash attention if available and on CUDA
29
+ if self.device == 'cuda':
30
+ try:
31
+ model_kwargs['_attn_implementation'] = 'flash_attention_2'
32
+ except:
33
+ pass # Fall back to default if flash attention not available
34
+
35
+ self.model = AutoModel.from_pretrained(model_path, **model_kwargs)
36
+ self.model = self.model.eval()
37
+
38
+ # Move to appropriate device
39
+ if self.device == 'cuda':
40
+ self.model = self.model.cuda()
41
 
42
  def __call__(self, data: Dict[str, Any]) -> str:
43
  try:
44
+ inputs = data.get("inputs")
45
  base64_string = inputs["base64"]
46
  # Remove data URL prefix if present
47
  if ',' in base64_string: