yahtzee committed on
Commit
94c64d2
·
1 Parent(s): 4a015a8

allow custom prompts

Browse files
Files changed (1) hide show
  1. app.py +15 -8
app.py CHANGED
@@ -352,7 +352,7 @@ pdf_cache = {
352
  "results": []
353
  }
354
  @spaces.GPU()
355
- def inference(image: Image.Image, prompt: str, max_new_tokens: int = 24000) -> str:
356
  """Run inference on an image with the given prompt"""
357
  try:
358
  if model is None or processor is None:
@@ -367,7 +367,7 @@ def inference(image: Image.Image, prompt: str, max_new_tokens: int = 24000) -> s
367
  "type": "image",
368
  "image": image
369
  },
370
- {"type": "text", "text": prompt}
371
  ]
372
  }
373
  ]
@@ -425,7 +425,9 @@ def inference(image: Image.Image, prompt: str, max_new_tokens: int = 24000) -> s
425
  def process_image(
426
  image: Image.Image,
427
  min_pixels: Optional[int] = None,
428
- max_pixels: Optional[int] = None
 
 
429
  ) -> Dict[str, Any]:
430
  """Process a single image with the specified prompt mode"""
431
  try:
@@ -434,7 +436,7 @@ def process_image(
434
  image = fetch_image(image, min_pixels=min_pixels, max_pixels=max_pixels)
435
 
436
  # Run inference with the default prompt
437
- raw_output = inference(image, prompt)
438
 
439
  # Process results based on prompt mode
440
  result = {
@@ -685,6 +687,7 @@ def create_gradio_interface():
685
 
686
  # Advanced settings
687
  with gr.Accordion("Advanced Settings", open=False):
 
688
  max_new_tokens = gr.Slider(
689
  minimum=1000,
690
  maximum=32000,
@@ -744,7 +747,7 @@ def create_gradio_interface():
744
  )
745
 
746
  # Event handlers
747
- def process_document(file_path, max_tokens, min_pix, max_pix):
748
  """Process the uploaded document"""
749
  global pdf_cache
750
 
@@ -770,7 +773,9 @@ def create_gradio_interface():
770
  result = process_image(
771
  img,
772
  min_pixels=int(min_pix) if min_pix else None,
773
- max_pixels=int(max_pix) if max_pix else None
 
 
774
  )
775
  all_results.append(result)
776
  if result.get('markdown_content'):
@@ -799,7 +804,9 @@ def create_gradio_interface():
799
  result = process_image(
800
  image,
801
  min_pixels=int(min_pix) if min_pix else None,
802
- max_pixels=int(max_pix) if max_pix else None
 
 
803
  )
804
 
805
  pdf_cache["results"] = [result]
@@ -875,7 +882,7 @@ def create_gradio_interface():
875
 
876
  process_btn.click(
877
  process_document,
878
- inputs=[file_input, max_new_tokens, min_pixels, max_pixels],
879
  outputs=[processed_image, markdown_output, json_output]
880
  )
881
 
 
352
  "results": []
353
  }
354
  @spaces.GPU()
355
+ def inference(image: Image.Image, max_new_tokens: int = 24000, custom_prompt: str = '') -> str:
356
  """Run inference on an image with the given prompt"""
357
  try:
358
  if model is None or processor is None:
 
367
  "type": "image",
368
  "image": image
369
  },
370
+ {"type": "text", "text": custom_prompt}
371
  ]
372
  }
373
  ]
 
425
  def process_image(
426
  image: Image.Image,
427
  min_pixels: Optional[int] = None,
428
+ max_pixels: Optional[int] = None,
429
+ custom_prompt: Optional[str] = None,
430
+ max_new_tokens: int = 24000,
431
  ) -> Dict[str, Any]:
432
  """Process a single image with the specified prompt mode"""
433
  try:
 
436
  image = fetch_image(image, min_pixels=min_pixels, max_pixels=max_pixels)
437
 
438
  # Run inference with the default prompt
439
+ raw_output = inference(image=image, custom_prompt=custom_prompt, max_new_tokens=max_new_tokens)
440
 
441
  # Process results based on prompt mode
442
  result = {
 
687
 
688
  # Advanced settings
689
  with gr.Accordion("Advanced Settings", open=False):
690
+ custom_prompt = gr.Textbox(label="Custom Prompt", value=prompt, lines=12, placeholder="Enter a custom prompt...", info="Modify the OCR / layout extraction prompt.")
691
  max_new_tokens = gr.Slider(
692
  minimum=1000,
693
  maximum=32000,
 
747
  )
748
 
749
  # Event handlers
750
+ def process_document(file_path, max_tokens, min_pix, max_pix, custom_prompt):
751
  """Process the uploaded document"""
752
  global pdf_cache
753
 
 
773
  result = process_image(
774
  img,
775
  min_pixels=int(min_pix) if min_pix else None,
776
+ max_pixels=int(max_pix) if max_pix else None,
777
+ custom_prompt=custom_prompt,
778
+ max_new_tokens=max_tokens
779
  )
780
  all_results.append(result)
781
  if result.get('markdown_content'):
 
804
  result = process_image(
805
  image,
806
  min_pixels=int(min_pix) if min_pix else None,
807
+ max_pixels=int(max_pix) if max_pix else None,
808
+ custom_prompt=custom_prompt,
809
+ max_new_tokens=max_tokens
810
  )
811
 
812
  pdf_cache["results"] = [result]
 
882
 
883
  process_btn.click(
884
  process_document,
885
+ inputs=[file_input, max_new_tokens, min_pixels, max_pixels, custom_prompt],
886
  outputs=[processed_image, markdown_output, json_output]
887
  )
888