allow custom prompts
Browse files
app.py
CHANGED
|
@@ -352,7 +352,7 @@ pdf_cache = {
|
|
| 352 |
"results": []
|
| 353 |
}
|
| 354 |
@spaces.GPU()
|
| 355 |
-
def inference(image: Image.Image,
|
| 356 |
"""Run inference on an image with the given prompt"""
|
| 357 |
try:
|
| 358 |
if model is None or processor is None:
|
|
@@ -367,7 +367,7 @@ def inference(image: Image.Image, prompt: str, max_new_tokens: int = 24000) -> s
|
|
| 367 |
"type": "image",
|
| 368 |
"image": image
|
| 369 |
},
|
| 370 |
-
{"type": "text", "text":
|
| 371 |
]
|
| 372 |
}
|
| 373 |
]
|
|
@@ -425,7 +425,9 @@ def inference(image: Image.Image, prompt: str, max_new_tokens: int = 24000) -> s
|
|
| 425 |
def process_image(
|
| 426 |
image: Image.Image,
|
| 427 |
min_pixels: Optional[int] = None,
|
| 428 |
-
max_pixels: Optional[int] = None
|
|
|
|
|
|
|
| 429 |
) -> Dict[str, Any]:
|
| 430 |
"""Process a single image with the specified prompt mode"""
|
| 431 |
try:
|
|
@@ -434,7 +436,7 @@ def process_image(
|
|
| 434 |
image = fetch_image(image, min_pixels=min_pixels, max_pixels=max_pixels)
|
| 435 |
|
| 436 |
# Run inference with the default prompt
|
| 437 |
-
raw_output = inference(image,
|
| 438 |
|
| 439 |
# Process results based on prompt mode
|
| 440 |
result = {
|
|
@@ -685,6 +687,7 @@ def create_gradio_interface():
|
|
| 685 |
|
| 686 |
# Advanced settings
|
| 687 |
with gr.Accordion("Advanced Settings", open=False):
|
|
|
|
| 688 |
max_new_tokens = gr.Slider(
|
| 689 |
minimum=1000,
|
| 690 |
maximum=32000,
|
|
@@ -744,7 +747,7 @@ def create_gradio_interface():
|
|
| 744 |
)
|
| 745 |
|
| 746 |
# Event handlers
|
| 747 |
-
def process_document(file_path, max_tokens, min_pix, max_pix):
|
| 748 |
"""Process the uploaded document"""
|
| 749 |
global pdf_cache
|
| 750 |
|
|
@@ -770,7 +773,9 @@ def create_gradio_interface():
|
|
| 770 |
result = process_image(
|
| 771 |
img,
|
| 772 |
min_pixels=int(min_pix) if min_pix else None,
|
| 773 |
-
max_pixels=int(max_pix) if max_pix else None
|
|
|
|
|
|
|
| 774 |
)
|
| 775 |
all_results.append(result)
|
| 776 |
if result.get('markdown_content'):
|
|
@@ -799,7 +804,9 @@ def create_gradio_interface():
|
|
| 799 |
result = process_image(
|
| 800 |
image,
|
| 801 |
min_pixels=int(min_pix) if min_pix else None,
|
| 802 |
-
max_pixels=int(max_pix) if max_pix else None
|
|
|
|
|
|
|
| 803 |
)
|
| 804 |
|
| 805 |
pdf_cache["results"] = [result]
|
|
@@ -875,7 +882,7 @@ def create_gradio_interface():
|
|
| 875 |
|
| 876 |
process_btn.click(
|
| 877 |
process_document,
|
| 878 |
-
inputs=[file_input, max_new_tokens, min_pixels, max_pixels],
|
| 879 |
outputs=[processed_image, markdown_output, json_output]
|
| 880 |
)
|
| 881 |
|
|
|
|
| 352 |
"results": []
|
| 353 |
}
|
| 354 |
@spaces.GPU()
|
| 355 |
+
def inference(image: Image.Image, max_new_tokens: int = 24000, custom_prompt: str = '') -> str:
|
| 356 |
"""Run inference on an image with the given prompt"""
|
| 357 |
try:
|
| 358 |
if model is None or processor is None:
|
|
|
|
| 367 |
"type": "image",
|
| 368 |
"image": image
|
| 369 |
},
|
| 370 |
+
{"type": "text", "text": custom_prompt}
|
| 371 |
]
|
| 372 |
}
|
| 373 |
]
|
|
|
|
| 425 |
def process_image(
|
| 426 |
image: Image.Image,
|
| 427 |
min_pixels: Optional[int] = None,
|
| 428 |
+
max_pixels: Optional[int] = None,
|
| 429 |
+
custom_prompt: Optional[str] = None,
|
| 430 |
+
max_new_tokens: int = 24000,
|
| 431 |
) -> Dict[str, Any]:
|
| 432 |
"""Process a single image with the specified prompt mode"""
|
| 433 |
try:
|
|
|
|
| 436 |
image = fetch_image(image, min_pixels=min_pixels, max_pixels=max_pixels)
|
| 437 |
|
| 438 |
# Run inference with the caller-supplied custom prompt
|
| 439 |
+
raw_output = inference(image=image, custom_prompt=custom_prompt, max_new_tokens=max_new_tokens)
|
| 440 |
|
| 441 |
# Process results based on prompt mode
|
| 442 |
result = {
|
|
|
|
| 687 |
|
| 688 |
# Advanced settings
|
| 689 |
with gr.Accordion("Advanced Settings", open=False):
|
| 690 |
+
custom_prompt = gr.Textbox(label="Custom Prompt", value=prompt, lines=12, placeholder="Enter a custom prompt...", info="Modify the OCR / layout extraction prompt.")
|
| 691 |
max_new_tokens = gr.Slider(
|
| 692 |
minimum=1000,
|
| 693 |
maximum=32000,
|
|
|
|
| 747 |
)
|
| 748 |
|
| 749 |
# Event handlers
|
| 750 |
+
def process_document(file_path, max_tokens, min_pix, max_pix, custom_prompt):
|
| 751 |
"""Process the uploaded document"""
|
| 752 |
global pdf_cache
|
| 753 |
|
|
|
|
| 773 |
result = process_image(
|
| 774 |
img,
|
| 775 |
min_pixels=int(min_pix) if min_pix else None,
|
| 776 |
+
max_pixels=int(max_pix) if max_pix else None,
|
| 777 |
+
custom_prompt=custom_prompt,
|
| 778 |
+
max_new_tokens=max_tokens
|
| 779 |
)
|
| 780 |
all_results.append(result)
|
| 781 |
if result.get('markdown_content'):
|
|
|
|
| 804 |
result = process_image(
|
| 805 |
image,
|
| 806 |
min_pixels=int(min_pix) if min_pix else None,
|
| 807 |
+
max_pixels=int(max_pix) if max_pix else None,
|
| 808 |
+
custom_prompt=custom_prompt,
|
| 809 |
+
max_new_tokens=max_tokens
|
| 810 |
)
|
| 811 |
|
| 812 |
pdf_cache["results"] = [result]
|
|
|
|
| 882 |
|
| 883 |
process_btn.click(
|
| 884 |
process_document,
|
| 885 |
+
inputs=[file_input, max_new_tokens, min_pixels, max_pixels, custom_prompt],
|
| 886 |
outputs=[processed_image, markdown_output, json_output]
|
| 887 |
)
|
| 888 |
|