Update app.py

app.py CHANGED
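This commit replaces import-time model initialization with a lazy `ensure_model_loaded()` helper: `model` and `processor` start as `None`, the checkpoint is fetched with `snapshot_download`, and loading uses eager attention instead of FlashAttention. Inference entry points are decorated with `@spaces.GPU()` for ZeroGPU, call `ensure_model_loaded()` before touching the model, and move inputs to the model's primary device so `device_map="auto"` keeps working. The UI's `max_tokens` value is now threaded through `process_image()` into `inference()` as `max_new_tokens`, and the results `gr.Dataframe` keeps `wrap=True` while dropping `height=500`.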
@@ -320,25 +320,51 @@ def layoutjson2md(image: Image.Image, layout_data: List[Dict], text_key: str = '
 
     return "\n".join(markdown_lines)
 
-# Initialize model
+# Initialize model/processor lazily inside GPU context
 model_id = "rednote-hilab/dots.ocr"
 model_path = "./models/dots-ocr-local"
-
-
-
-
-)
-model
-
-
-
-
-
-
-
-
-
-
+model = None
+processor = None
+
+def ensure_model_loaded():
+    """Lazily download and load model/processor using eager attention (no FlashAttention)."""
+    global model, processor
+    if model is not None and processor is not None:
+        return
+
+    # Always use eager attention
+    attn_impl = "eager"
+    # Use GPU if available, otherwise CPU
+    if torch.cuda.is_available():
+        dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+        device_map = "auto"
+    else:
+        dtype = torch.float32
+        device_map = "cpu"
+
+    # Download snapshot locally (idempotent)
+    snapshot_download(
+        repo_id=model_id,
+        local_dir=model_path,
+        local_dir_use_symlinks=False,
+    )
+
+    # Load model/processor
+    loaded_model = AutoModelForCausalLM.from_pretrained(
+        model_path,
+        attn_implementation=attn_impl,
+        torch_dtype=dtype,
+        device_map=device_map,
+        trust_remote_code=True,
+        low_cpu_mem_usage=True,
+    )
+    loaded_processor = AutoProcessor.from_pretrained(
+        model_path,
+        trust_remote_code=True,
+    )
+
+    model = loaded_model
+    processor = loaded_processor
 
 # Global state variables
 device = "cuda" if torch.cuda.is_available() else "cpu"
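On ZeroGPU Spaces, CUDA is not available at import time; a GPU is attached only while a function decorated with `spaces.GPU` runs, which is why the hunk above defers all heavy work into `ensure_model_loaded()`. A condensed, self-contained sketch of the same pattern follows — the constant names are mine, and the diff's redundant inner `torch.cuda.is_available()` check is folded into one branch:

# Minimal sketch of lazy, idempotent model loading; not the Space's exact code.
import torch
from huggingface_hub import snapshot_download
from transformers import AutoModelForCausalLM, AutoProcessor

MODEL_ID = "rednote-hilab/dots.ocr"
LOCAL_DIR = "./models/dots-ocr-local"

model = None
processor = None

def ensure_model_loaded():
    """Load on first use; subsequent calls return immediately."""
    global model, processor
    if model is not None and processor is not None:
        return
    use_cuda = torch.cuda.is_available()
    # Idempotent: re-running only verifies the local copy.
    snapshot_download(repo_id=MODEL_ID, local_dir=LOCAL_DIR)
    model = AutoModelForCausalLM.from_pretrained(
        LOCAL_DIR,
        attn_implementation="eager",  # avoids the FlashAttention dependency
        torch_dtype=torch.float16 if use_cuda else torch.float32,
        device_map="auto" if use_cuda else "cpu",
        trust_remote_code=True,
    )
    processor = AutoProcessor.from_pretrained(LOCAL_DIR, trust_remote_code=True)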
@@ -356,6 +382,7 @@ pdf_cache = {
 def inference(image: Image.Image, prompt: str, max_new_tokens: int = 24000) -> str:
     """Run inference on an image with the given prompt"""
     try:
+        ensure_model_loaded()
         if model is None or processor is None:
             raise RuntimeError("Model not loaded. Please check model initialization.")
 
@@ -392,8 +419,9 @@ def inference(image: Image.Image, prompt: str, max_new_tokens: int = 24000) -> s
             return_tensors="pt",
         )
 
-        # Move to device
-        inputs = inputs.to(device)
+        # Move to the model's primary device (works with device_map as well)
+        primary_device = next(model.parameters()).device
+        inputs = inputs.to(primary_device)
 
         # Generate output
         with torch.no_grad():
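`next(model.parameters()).device` returns the device of the model's first parameter; under `device_map="auto"` that is the shard holding the input embeddings, so it is the right landing spot for batch tensors. A small sketch of the idea, using a hypothetical helper name (the real code calls `.to()` on the processor's batch object directly):

import torch
import torch.nn as nn

def move_to_model_device(batch: dict, model: nn.Module) -> dict:
    """Send every tensor in the batch to the device of the model's first parameter."""
    primary_device = next(model.parameters()).device
    return {k: v.to(primary_device) if torch.is_tensor(v) else v for k, v in batch.items()}

# Toy usage with a CPU stand-in model; with a GPU-sharded model the same call
# lands the batch on whichever device holds the embeddings.
net = nn.Linear(4, 2)
batch = {"input_ids": torch.randint(0, 100, (1, 4)), "attention_mask": torch.ones(1, 4)}
print(move_to_model_device(batch, net)["input_ids"].device)  # cpu here, cuda:0 on GPU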
@@ -423,6 +451,7 @@ def inference(image: Image.Image, prompt: str, max_new_tokens: int = 24000) -> s
         return f"Error during inference: {str(e)}"
 
 
+@spaces.GPU()
 def _generate_text_and_confidence_for_crop(
     image: Image.Image,
     max_new_tokens: int = 128,
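`@spaces.GPU()` is the ZeroGPU entry point: the process runs on CPU until a decorated function is invoked, at which point a GPU is attached for the duration of the call and released afterwards. A minimal sketch — it only runs inside a Hugging Face Space where the `spaces` package is installed, and the function body is illustrative:

import spaces
import torch

@spaces.GPU()  # a GPU is allocated for this call and released when it returns
def gpu_task(x: float) -> float:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return (torch.tensor([x], device=device) * 2).item()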
@@ -433,6 +462,7 @@ def _generate_text_and_confidence_for_crop(
     Returns (generated_text, average_confidence_percent).
     """
     try:
+        ensure_model_loaded()
         # Prepare a concise extraction prompt for the crop
         messages = [
             {
@@ -463,7 +493,8 @@
             padding=True,
             return_tensors="pt",
         )
-        inputs = inputs.to(device)
+        primary_device = next(model.parameters()).device
+        inputs = inputs.to(primary_device)
 
         # Generate with scores
         with torch.no_grad():
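The hunks around `_generate_text_and_confidence_for_crop` show only the device move, not how the advertised average confidence is computed. With transformers' `generate(..., output_scores=True, return_dict_in_generate=True)` the usual recipe is the mean probability of each sampled token; the helper below is an assumption about that computation, not the Space's code:

import torch

def average_confidence_percent(scores, sequences, prompt_len: int) -> float:
    """Mean softmax probability of each generated token, as a percent.

    scores: per-step logits tuple from generate(..., output_scores=True,
    return_dict_in_generate=True); sequences: prompt plus generated token ids.
    """
    token_probs = []
    for step, logits in enumerate(scores):
        token_id = sequences[0, prompt_len + step]
        token_probs.append(torch.softmax(logits[0], dim=-1)[token_id].item())
    return 100.0 * sum(token_probs) / len(token_probs) if token_probs else 0.0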
@@ -506,9 +537,10 @@
 
 
 def process_image(
-    image: Image.Image,
+    image: Image.Image,
     min_pixels: Optional[int] = None,
-    max_pixels: Optional[int] = None
+    max_pixels: Optional[int] = None,
+    max_new_tokens: int = 24000,
 ) -> Dict[str, Any]:
     """Process a single image with the specified prompt mode"""
     try:
@@ -517,7 +549,7 @@
         image = fetch_image(image, min_pixels=min_pixels, max_pixels=max_pixels)
 
         # Run inference with the default prompt
-        raw_output = inference(image, prompt)
+        raw_output = inference(image, prompt, max_new_tokens=max_new_tokens)
 
         # Process results based on prompt mode
         result = {
@@ -876,8 +908,7 @@ def create_gradio_interface():
                datatype=["html", "str", "str"],
                label="OCR Results",
                interactive=True,
-               wrap=True
-               height=500
+               wrap=True
            )
            # Markdown output tab
            with gr.Tab("π Extracted Content"):
@@ -950,11 +981,14 @@ def create_gradio_interface():
            return table_data
 
        # Event handlers
+       @spaces.GPU()
        def process_document(file_path, max_tokens, min_pix, max_pix):
            """Process the uploaded document"""
            global pdf_cache
 
            try:
+               # Ensure model/processor are loaded within GPU context
+               ensure_model_loaded()
                if not file_path:
                    return None, [], "Please upload a file first.", None
 
@@ -974,9 +1008,10 @@
 
                for i, img in enumerate(pdf_cache["images"]):
                    result = process_image(
-                       img,
+                       img,
                        min_pixels=int(min_pix) if min_pix else None,
-                       max_pixels=int(max_pix) if max_pix else None
+                       max_pixels=int(max_pix) if max_pix else None,
+                       max_new_tokens=int(max_tokens) if max_tokens else 24000,
                    )
                    all_results.append(result)
                    if result.get('markdown_content'):
@@ -1014,7 +1049,8 @@
                    result = process_image(
                        image,
                        min_pixels=int(min_pix) if min_pix else None,
-                       max_pixels=int(max_pix) if max_pix else None
+                       max_pixels=int(max_pix) if max_pix else None,
+                       max_new_tokens=int(max_tokens) if max_tokens else 24000,
                    )
 
                    pdf_cache["results"] = [result]
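The `int(max_tokens) if max_tokens else 24000` idiom in the last two hunks treats every falsy UI value (None, empty string, explicit 0) as a request for the default. A tiny sketch of that coercion pulled into a helper — my naming; the diff inlines the expression:

def coerce_max_tokens(value, default: int = 24000) -> int:
    """Mirror the diff's fallback: any falsy value selects the default."""
    return int(value) if value else default

assert coerce_max_tokens("512") == 512
assert coerce_max_tokens(None) == 24000
assert coerce_max_tokens(0) == 24000  # an explicit 0 also falls back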