File size: 2,532 Bytes
4b6ac7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
"""
Run inference with the fine-tuned Sanskrit OCR model (v2)
"""

import os
import warnings
warnings.filterwarnings("ignore")

import re
import torch
from peft import PeftModel
from transformers import AutoModel, AutoProcessor

os.environ["HF_HUB_OFFLINE"] = "1"

DEVANAGARI_RE = re.compile(r"[\u0900-\u097F\s॥।,.!?0-9०-९]+")
REPEAT_RE = re.compile(r"(?:\b(.+?)\b)(?:\s+\1){2,}")


def clean_text(raw: str) -> str:
    lines = []
    for line in raw.splitlines():
        line = line.strip()
        if not line:
            continue
        if line.lower().startswith("directly resize"):
            continue

        chunks = DEVANAGARI_RE.findall(line)
        if not chunks:
            continue

        joined = " ".join(chunks)
        joined = REPEAT_RE.sub(r"\1", joined)
        lines.append(joined)

    return "\n".join(lines)


def load_model_with_lora(base_model_path, lora_path):
    """Load the base OCR model, apply the LoRA adapters, and merge them.

    Args:
        base_model_path: local path (or hub id) of the base model checkpoint.
        lora_path: path to the trained LoRA adapter weights.

    Returns:
        The merged model in eval mode, sharded by ``device_map="auto"``.
    """
    print("Loading base model...")
    base = AutoModel.from_pretrained(
        base_model_path,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )

    print("Loading LoRA adapters...")
    # Fold the adapter weights into the base weights so inference runs on a
    # plain model, without the PEFT wrapper in the forward path.
    merged = PeftModel.from_pretrained(base, lora_path).merge_and_unload()
    merged.eval()

    print("Model loaded.")
    return merged


def run_inference(model, image_path, processor, max_new_tokens=1024):
    """Run OCR on a single image and return the model's raw text output.

    Args:
        model: merged OCR model exposing an ``infer`` method.
        image_path: path of the image file to transcribe.
        processor: the model's processor/tokenizer object.
        max_new_tokens: accepted for API compatibility; NOTE(review): it is
            never forwarded to ``model.infer`` — confirm whether it should be.

    Returns:
        The text produced by the model, or "" when inference yields nothing.
    """
    print(f"Running inference on: {image_path}")

    output = model.infer(
        processor,
        prompt="<image>\nFree OCR. ",
        image_file=image_path,
        output_path="./output",
        base_size=1024,
        image_size=640,
        crop_mode=True,
        save_results=False,
        test_compress=False,
        eval_mode=True,
    )

    # infer() may return None/empty; normalize to an empty string.
    return output or ""


if __name__ == "__main__":
    import argparse

    # CLI: the image to transcribe, plus optional model/adapter locations.
    cli = argparse.ArgumentParser()
    cli.add_argument("--image", type=str, required=True)
    cli.add_argument("--base_model", type=str, default="/home/ubuntu/deepseek_ocr")
    cli.add_argument("--lora", type=str, default="/home/ubuntu/sanskrit-ocr-lora")
    args = cli.parse_args()

    model = load_model_with_lora(args.base_model, args.lora)
    processor = AutoProcessor.from_pretrained(args.base_model, trust_remote_code=True)

    raw = run_inference(model, args.image, processor)
    cleaned = clean_text(raw)

    separator = "=" * 50
    print("\n" + separator)
    print("OCR Result (raw):")
    # Show at most the first 500 characters of the raw output
    # (slicing already caps shorter strings, no length check needed).
    print(raw[:500])
    print(separator)
    print("OCR Result (cleaned):")
    print(cleaned)