ricklon and Claude Sonnet 4.6 committed
Commit b9d5e1c · 1 Parent(s): 25ba1bf

Fix flash_attention_2 startup crash on ZeroGPU and LaTeX delimiter rendering

- Use sdpa attention impl when CUDA is unavailable at load time (ZeroGPU
has no GPU until inside @spaces.GPU); fall back to flash_attention_2
locally where CUDA is present
- Pre-convert model's \[...\] and \(...\) delimiters to $$...$$ and $...$
in to_math_html() before passing to markdown; markdown strips backslashes
before arithmatex can protect them, causing equations to render as bare
brackets instead of math
- Document delimiter pre-conversion in TECHNICAL.md

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

Files changed (2)
  1. TECHNICAL.md +13 -0
  2. app.py +11 -1
TECHNICAL.md CHANGED
@@ -304,6 +304,19 @@ After arithmatex + markdown:
 
 The `_` inside the math is never touched by the markdown processor.
 
+### Delimiter pre-conversion
+
+The model outputs `\[...\]` for display math and `\(...\)` for inline math. But `pymdownx.arithmatex` only recognises `$...$` and `$$...$$` by default. Worse, if `\[...\]` is passed directly to the markdown processor, the backslashes are stripped first — before arithmatex can intercept them — leaving bare `[...]` brackets in the output.
+
+`to_math_html()` therefore pre-converts the model's native delimiters before calling `markdown()`:
+
+```python
+text = re.sub(r'\\\[(.+?)\\\]', r'$$\1$$', text, flags=re.DOTALL)
+text = re.sub(r'\\\((.+?)\\\)', r'$\1$', text)
+```
+
+After this step, arithmatex sees `$$...$$` and `$...$`, protects the content from markdown, and wraps it in `\[...\]` and `\(...\)` for MathJax to render.
+
 ### MathJax configuration
 
 MathJax is loaded once in the page `<head>` and configured to process `\(...\)` for inline math and `\[...\]` for display math — matching the output format of arithmatex:
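The two substitutions documented above can be exercised in isolation. A minimal sketch — the `convert_delimiters` helper name is hypothetical; the regexes themselves are the ones from the diff:

```python
import re

def convert_delimiters(text: str) -> str:
    # Display math: \[...\] becomes $$...$$; DOTALL lets equations span lines.
    text = re.sub(r'\\\[(.+?)\\\]', r'$$\1$$', text, flags=re.DOTALL)
    # Inline math: \(...\) becomes $...$ (single-line, matching the diff).
    text = re.sub(r'\\\((.+?)\\\)', r'$\1$', text)
    return text

print(convert_delimiters(r"Mass-energy: \(E = mc^2\)"))
# → Mass-energy: $E = mc^2$
```

Text without any math delimiters passes through unchanged, so the pre-conversion is safe to apply to all model output.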
app.py CHANGED
@@ -28,7 +28,12 @@ MODEL_NAME = 'deepseek-ai/DeepSeek-OCR-2'
 # MODEL_NAME = 'mzbac/DeepSeek-OCR-2-8bit'
 
 tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
-model = AutoModel.from_pretrained(MODEL_NAME, _attn_implementation='flash_attention_2', torch_dtype=torch.bfloat16, trust_remote_code=True, use_safetensors=True).eval()
+# flash_attention_2 requires a CUDA device at init time — not available on ZeroGPU at
+# module load. Use sdpa (PyTorch scaled dot product attention) as the fallback; it works
+# on CPU at load time and on GPU at inference time. Locally with CUDA present, use
+# flash_attention_2 for maximum throughput.
+_attn_impl = 'flash_attention_2' if torch.cuda.is_available() else 'sdpa'
+model = AutoModel.from_pretrained(MODEL_NAME, _attn_implementation=_attn_impl, torch_dtype=torch.bfloat16, trust_remote_code=True, use_safetensors=True).eval()
 # .cuda() is NOT called here — on ZeroGPU, GPU is only available inside @spaces.GPU
 # functions. Locally, model.cuda() is called inside process_image on first run.
 
@@ -161,6 +166,11 @@ window.MathJax = {
 def to_math_html(text):
     if not text:
         return ""
+    # Pre-convert \[...\] and \(...\) to $$...$$ and $...$
+    # Markdown strips backslashes before arithmatex can protect them,
+    # so convert to $-delimiters first (arithmatex recognises those).
+    text = re.sub(r'\\\[(.+?)\\\]', r'$$\1$$', text, flags=re.DOTALL)
+    text = re.sub(r'\\\((.+?)\\\)', r'$\1$', text)
     html = md_lib.markdown(text, extensions=[
         'pymdownx.arithmatex',
         'tables',
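The load-time fallback from the first hunk can be sketched as a pure helper; `pick_attn_impl` is a hypothetical name for illustration — the real code inlines the conditional and passes `torch.cuda.is_available()` directly:

```python
def pick_attn_impl(cuda_available: bool) -> str:
    # flash_attention_2 must initialise on a CUDA device; on ZeroGPU no GPU
    # exists at module load time, so fall back to sdpa, which loads on CPU
    # and still dispatches to the GPU later inside an @spaces.GPU function.
    return 'flash_attention_2' if cuda_available else 'sdpa'

print(pick_attn_impl(False))  # ZeroGPU at module load → sdpa
print(pick_attn_impl(True))   # local machine with CUDA → flash_attention_2
```

Keeping the decision a function of CUDA availability (rather than hard-coding either implementation) is what lets the same app.py run unmodified both locally and on ZeroGPU.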