VanguardAI committed
Commit 08b7752 · verified · Parent: 54d5943

Update app.py

Files changed (1)
  1. app.py +64 -28
app.py CHANGED
@@ -320,25 +320,51 @@ def layoutjson2md(image: Image.Image, layout_data: List[Dict], text_key: str = '
 
     return "\n".join(markdown_lines)
 
-# Initialize model and processor at script level
+# Initialize model/processor lazily inside GPU context
 model_id = "rednote-hilab/dots.ocr"
 model_path = "./models/dots-ocr-local"
-snapshot_download(
-    repo_id=model_id,
-    local_dir=model_path,
-    local_dir_use_symlinks=False,  # Recommended to set to False to avoid symlink issues
-)
-model = AutoModelForCausalLM.from_pretrained(
-    model_path,
-    attn_implementation="flash_attention_2",
-    torch_dtype=torch.bfloat16,
-    device_map="auto",
-    trust_remote_code=True
-)
-processor = AutoProcessor.from_pretrained(
-    model_path,
-    trust_remote_code=True
-)
+model = None
+processor = None
+
+def ensure_model_loaded():
+    """Lazily download and load model/processor using eager attention (no FlashAttention)."""
+    global model, processor
+    if model is not None and processor is not None:
+        return
+
+    # Always use eager attention
+    attn_impl = "eager"
+    # Use GPU if available, otherwise CPU
+    if torch.cuda.is_available():
+        dtype = torch.float16
+        device_map = "auto"
+    else:
+        dtype = torch.float32
+        device_map = "cpu"
+
+    # Download snapshot locally (idempotent)
+    snapshot_download(
+        repo_id=model_id,
+        local_dir=model_path,
+        local_dir_use_symlinks=False,
+    )
+
+    # Load model/processor
+    loaded_model = AutoModelForCausalLM.from_pretrained(
+        model_path,
+        attn_implementation=attn_impl,
+        torch_dtype=dtype,
+        device_map=device_map,
+        trust_remote_code=True,
+        low_cpu_mem_usage=True,
+    )
+    loaded_processor = AutoProcessor.from_pretrained(
+        model_path,
+        trust_remote_code=True,
+    )
+
+    model = loaded_model
+    processor = loaded_processor
 
 # Global state variables
 device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -356,6 +382,7 @@ pdf_cache = {
 def inference(image: Image.Image, prompt: str, max_new_tokens: int = 24000) -> str:
     """Run inference on an image with the given prompt"""
     try:
+        ensure_model_loaded()
         if model is None or processor is None:
             raise RuntimeError("Model not loaded. Please check model initialization.")
 
@@ -392,8 +419,9 @@ def inference(image: Image.Image, prompt: str, max_new_tokens: int = 24000) -> str:
             return_tensors="pt",
         )
 
-        # Move to device
-        inputs = inputs.to(device)
+        # Move to the model's primary device (works with device_map as well)
+        primary_device = next(model.parameters()).device
+        inputs = inputs.to(primary_device)
 
         # Generate output
         with torch.no_grad():
@@ -423,6 +451,7 @@ def inference(image: Image.Image, prompt: str, max_new_tokens: int = 24000) -> str:
         return f"Error during inference: {str(e)}"
 
 
+@spaces.GPU()
 def _generate_text_and_confidence_for_crop(
     image: Image.Image,
     max_new_tokens: int = 128,
@@ -433,6 +462,7 @@ def _generate_text_and_confidence_for_crop(
     Returns (generated_text, average_confidence_percent).
     """
     try:
+        ensure_model_loaded()
         # Prepare a concise extraction prompt for the crop
         messages = [
             {
@@ -463,7 +493,8 @@ def _generate_text_and_confidence_for_crop(
             padding=True,
             return_tensors="pt",
         )
-        inputs = inputs.to(device)
+        primary_device = next(model.parameters()).device
+        inputs = inputs.to(primary_device)
 
         # Generate with scores
         with torch.no_grad():
@@ -506,9 +537,10 @@ def _generate_text_and_confidence_for_crop(
 
 
 def process_image(
-    image: Image.Image,
+    image: Image.Image,
     min_pixels: Optional[int] = None,
-    max_pixels: Optional[int] = None
+    max_pixels: Optional[int] = None,
+    max_new_tokens: int = 24000,
 ) -> Dict[str, Any]:
     """Process a single image with the specified prompt mode"""
     try:
@@ -517,7 +549,7 @@ def process_image(
         image = fetch_image(image, min_pixels=min_pixels, max_pixels=max_pixels)
 
         # Run inference with the default prompt
-        raw_output = inference(image, prompt)
+        raw_output = inference(image, prompt, max_new_tokens=max_new_tokens)
 
         # Process results based on prompt mode
         result = {
@@ -876,8 +908,7 @@ def create_gradio_interface():
                 datatype=["html", "str", "str"],
                 label="OCR Results",
                 interactive=True,
-                wrap=True,
-                height=500
+                wrap=True
             )
             # Markdown output tab
             with gr.Tab("📝 Extracted Content"):
@@ -950,11 +981,14 @@ def create_gradio_interface():
             return table_data
 
         # Event handlers
+        @spaces.GPU()
         def process_document(file_path, max_tokens, min_pix, max_pix):
             """Process the uploaded document"""
             global pdf_cache
 
             try:
+                # Ensure model/processor are loaded within GPU context
+                ensure_model_loaded()
                 if not file_path:
                     return None, [], "Please upload a file first.", None
 
@@ -974,9 +1008,10 @@ def create_gradio_interface():
 
             for i, img in enumerate(pdf_cache["images"]):
                 result = process_image(
-                    img,
+                    img,
                     min_pixels=int(min_pix) if min_pix else None,
-                    max_pixels=int(max_pix) if max_pix else None
+                    max_pixels=int(max_pix) if max_pix else None,
+                    max_new_tokens=int(max_tokens) if max_tokens else 24000,
                 )
                 all_results.append(result)
                 if result.get('markdown_content'):
@@ -1014,7 +1049,8 @@ def create_gradio_interface():
             result = process_image(
                 image,
                 min_pixels=int(min_pix) if min_pix else None,
-                max_pixels=int(max_pix) if max_pix else None
+                max_pixels=int(max_pix) if max_pix else None,
+                max_new_tokens=int(max_tokens) if max_tokens else 24000,
             )
 
             pdf_cache["results"] = [result]