ZienabM committed on
Commit
02f719e
·
verified ·
1 Parent(s): 8727456

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +91 -15
app.py CHANGED
@@ -48,45 +48,121 @@ async def lifespan(app: FastAPI):
48
  app = FastAPI(title="DeepSeek-OCR-2 API", version="2.0.0", lifespan=lifespan)
49
  app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  # ─── Core OCR inference ───────────────────────────────────────────────────────
52
  def run_ocr(pil_image: Image.Image, mode: str = "free") -> str:
53
- """Run DeepSeek-OCR-2 on a PIL image, return extracted text."""
 
 
 
54
  prompt_text = (
55
  "Convert the document to markdown."
56
  if mode == "markdown"
57
  else "Please OCR the image and return all text exactly."
58
  )
 
59
  with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
60
  tmp_path = tmp.name
61
  pil_image.save(tmp_path, format="PNG")
 
62
  try:
63
  if hasattr(model, "infer"):
64
  with tempfile.TemporaryDirectory() as out_dir:
65
- result = model.infer(
66
- tokenizer,
67
- prompt=f"<image>\n{prompt_text}",
68
- image_file=tmp_path,
69
- output_path=out_dir,
70
- base_size=1024,
71
- image_size=768,
72
- crop_mode=True,
73
- save_results=False,
74
- )
 
 
75
  if isinstance(result, dict):
76
  return result.get("text", str(result))
77
  return str(result) if result else ""
78
 
79
- # fallback: standard generate()
80
  messages = [{"role": "user", "content": [
81
  {"type": "image", "image": tmp_path},
82
  {"type": "text", "text": prompt_text},
83
  ]}]
84
- text_in = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
85
- inputs = tokenizer(text_in, return_tensors="pt")
 
 
86
  with torch.no_grad():
87
  out = model.generate(**inputs, max_new_tokens=1024, do_sample=False)
88
  new_ids = out[:, inputs["input_ids"].shape[1]:]
89
  return tokenizer.decode(new_ids[0], skip_special_tokens=True)
 
90
  finally:
91
  os.unlink(tmp_path)
92
 
@@ -816,4 +892,4 @@ function cp(id){
816
  function setProgress(pct){document.getElementById('prog').style.width=pct+'%';}
817
  </script>
818
  </body>
819
- </html>"""
 
48
  app = FastAPI(title="DeepSeek-OCR-2 API", version="2.0.0", lifespan=lifespan)
49
  app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
50
 
51
+ # ─── CPU monkey-patch context manager ────────────────────────────────────────
52
+ from contextlib import contextmanager
53
+
54
+ @contextmanager
55
+ def force_cpu():
56
+ """
57
+ DeepSeek-OCR-2's model.infer() hardcodes .cuda() even when no GPU is present.
58
+ This context manager temporarily replaces all CUDA-moving calls with no-ops
59
+ so the model runs on CPU without modification.
60
+ """
61
+ # Save originals
62
+ _tensor_cuda = torch.Tensor.cuda
63
+ _module_cuda = torch.nn.Module.cuda
64
+ _tensor_to = torch.Tensor.to
65
+ _module_to = torch.nn.Module.to
66
+
67
+ # Tensor.cuda() → return self (stay on CPU)
68
+ def _noop_tensor_cuda(self, device=None, *args, **kwargs):
69
+ return self
70
+
71
+ # Module.cuda() → return self
72
+ def _noop_module_cuda(self, device=None):
73
+ return self
74
+
75
+ # Tensor.to("cuda") / to(device) → stay on CPU; allow dtype casts
76
+ def _safe_tensor_to(self, *args, **kwargs):
77
+ filtered = [
78
+ a for a in args
79
+ if not (isinstance(a, (str, torch.device)) and "cuda" in str(a))
80
+ ]
81
+ kwargs.pop("device", None)
82
+ if filtered or kwargs:
83
+ try:
84
+ return _tensor_to(self, *filtered, **kwargs)
85
+ except Exception:
86
+ return self
87
+ return self
88
+
89
+ # Module.to("cuda") → stay on CPU; allow dtype casts
90
+ def _safe_module_to(self, *args, **kwargs):
91
+ filtered = [
92
+ a for a in args
93
+ if not (isinstance(a, (str, torch.device)) and "cuda" in str(a))
94
+ ]
95
+ kwargs.pop("device", None)
96
+ if filtered or kwargs:
97
+ try:
98
+ return _module_to(self, *filtered, **kwargs)
99
+ except Exception:
100
+ return self
101
+ return self
102
+
103
+ torch.Tensor.cuda = _noop_tensor_cuda
104
+ torch.nn.Module.cuda = _noop_module_cuda
105
+ torch.Tensor.to = _safe_tensor_to
106
+ torch.nn.Module.to = _safe_module_to
107
+
108
+ try:
109
+ yield
110
+ finally:
111
+ torch.Tensor.cuda = _tensor_cuda
112
+ torch.nn.Module.cuda = _module_cuda
113
+ torch.Tensor.to = _tensor_to
114
+ torch.nn.Module.to = _module_to
115
+
116
+
117
  # ─── Core OCR inference ───────────────────────────────────────────────────────
118
  def run_ocr(pil_image: Image.Image, mode: str = "free") -> str:
119
+ """
120
+ Run DeepSeek-OCR-2 on a PIL image and return extracted text.
121
+ Works on both CPU (HF free tier) and GPU.
122
+ """
123
  prompt_text = (
124
  "Convert the document to markdown."
125
  if mode == "markdown"
126
  else "Please OCR the image and return all text exactly."
127
  )
128
+
129
  with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
130
  tmp_path = tmp.name
131
  pil_image.save(tmp_path, format="PNG")
132
+
133
  try:
134
  if hasattr(model, "infer"):
135
  with tempfile.TemporaryDirectory() as out_dir:
136
+ # force_cpu() patches .cuda() → no-op so model.infer() works on CPU
137
+ with force_cpu():
138
+ result = model.infer(
139
+ tokenizer,
140
+ prompt=f"<image>\n{prompt_text}",
141
+ image_file=tmp_path,
142
+ output_path=out_dir,
143
+ base_size=1024,
144
+ image_size=768,
145
+ crop_mode=True,
146
+ save_results=False,
147
+ )
148
  if isinstance(result, dict):
149
  return result.get("text", str(result))
150
  return str(result) if result else ""
151
 
152
+ # ── Fallback: standard generate() if model.infer() is not available ──
153
  messages = [{"role": "user", "content": [
154
  {"type": "image", "image": tmp_path},
155
  {"type": "text", "text": prompt_text},
156
  ]}]
157
+ text_in = tokenizer.apply_chat_template(
158
+ messages, tokenize=False, add_generation_prompt=True
159
+ )
160
+ inputs = tokenizer(text_in, return_tensors="pt")
161
  with torch.no_grad():
162
  out = model.generate(**inputs, max_new_tokens=1024, do_sample=False)
163
  new_ids = out[:, inputs["input_ids"].shape[1]:]
164
  return tokenizer.decode(new_ids[0], skip_special_tokens=True)
165
+
166
  finally:
167
  os.unlink(tmp_path)
168
 
 
892
  function setProgress(pct){document.getElementById('prog').style.width=pct+'%';}
893
  </script>
894
  </body>
895
+ </html>"""