# sanskrit-ocr-lora / inference.py
# arpitingle — "Rename v2 files: remove v2 suffix" (commit 2bbf1b7)
"""
Run inference with the fine-tuned Sanskrit OCR model (v2)
"""
import os
import warnings
# Suppress all library warnings (transformers/peft emit many on load).
warnings.filterwarnings("ignore")
import re
import torch
from peft import PeftModel
from transformers import AutoModel, AutoProcessor
# Force Hugging Face Hub offline mode so from_pretrained only reads local paths.
# NOTE(review): set AFTER transformers is imported — recent transformers read this
# env var at call time so it still takes effect, but older versions read it at
# import; consider moving this above the imports. TODO confirm target version.
os.environ["HF_HUB_OFFLINE"] = "1"
# Runs of characters we keep: Devanagari block, whitespace, danda/double danda,
# basic punctuation, and ASCII/Devanagari digits.
DEVANAGARI_RE = re.compile(r"[\u0900-\u097F\s॥।,.!?0-9०-९]+")
# A word (or phrase) repeated three or more times, whitespace-separated.
REPEAT_RE = re.compile(r"(?:\b(.+?)\b)(?:\s+\1){2,}")


def clean_text(raw: str) -> str:
    """Reduce raw OCR output to cleaned Devanagari text.

    Per input line: strips surrounding whitespace, drops empty lines and
    model boilerplate starting with "directly resize", keeps only the
    character runs matched by DEVANAGARI_RE, and collapses 3+ consecutive
    repetitions of the same token down to a single occurrence.

    Returns the surviving lines joined with newlines.
    """
    kept: list[str] = []
    for candidate in raw.splitlines():
        candidate = candidate.strip()
        if not candidate or candidate.lower().startswith("directly resize"):
            continue
        fragments = DEVANAGARI_RE.findall(candidate)
        if fragments:
            # De-duplicate runaway repetition the model sometimes emits.
            kept.append(REPEAT_RE.sub(r"\1", " ".join(fragments)))
    return "\n".join(kept)
def load_model_with_lora(base_model_path, lora_path):
    """Load the base OCR model, apply LoRA adapters, and merge them in.

    Loads the base model in bfloat16 with automatic device placement,
    attaches the LoRA weights from ``lora_path``, merges them into the base
    weights (dropping the PEFT wrapper), and switches to eval mode.

    Returns the merged model ready for inference.
    """
    print("Loading base model...")
    base = AutoModel.from_pretrained(
        base_model_path,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    print("Loading LoRA adapters...")
    # merge_and_unload folds the adapter deltas into the base weights and
    # returns a plain (non-PEFT) model.
    merged = PeftModel.from_pretrained(base, lora_path).merge_and_unload()
    merged.eval()
    print("Model loaded.")
    return merged
def run_inference(model, image_path, processor, max_new_tokens=1024):
    """Run the model's built-in OCR routine on a single image.

    Delegates to ``model.infer`` with a fixed free-OCR prompt and the tiling
    settings used during fine-tuning (base 1024, crop tiles 640). Results are
    not written to disk (``save_results=False``); ``eval_mode=True`` makes the
    call return the text instead of printing it.

    Returns the recognized text, or "" when the model produced nothing.
    """
    print(f"Running inference on: {image_path}")
    ocr_text = model.infer(
        processor,
        prompt="<image>\nFree OCR. ",
        image_file=image_path,
        output_path="./output",
        base_size=1024,
        image_size=640,
        crop_mode=True,
        save_results=False,
        test_compress=False,
        eval_mode=True,
    )
    if not ocr_text:
        return ""
    return ocr_text
if __name__ == "__main__":
    import argparse

    # CLI: required input image; model and adapter paths default to the
    # training-host locations.
    cli = argparse.ArgumentParser()
    cli.add_argument("--image", type=str, required=True)
    cli.add_argument("--base_model", type=str, default="/home/ubuntu/deepseek_ocr")
    cli.add_argument("--lora", type=str, default="/home/ubuntu/sanskrit-ocr-lora")
    opts = cli.parse_args()

    ocr_model = load_model_with_lora(opts.base_model, opts.lora)
    ocr_processor = AutoProcessor.from_pretrained(opts.base_model, trust_remote_code=True)

    raw = run_inference(ocr_model, opts.image, ocr_processor)
    cleaned = clean_text(raw)

    divider = "=" * 50
    print("\n" + divider)
    print("OCR Result (raw):")
    # Cap the raw dump at the first 500 characters.
    print(raw if len(raw) <= 500 else raw[:500])
    print(divider)
    print("OCR Result (cleaned):")
    print(cleaned)