"""Run inference with the fine-tuned Sanskrit OCR model (v2).

Loads the base OCR model, merges the LoRA adapters into it, runs OCR on a
single image and prints both the (truncated) raw model output and a cleaned,
Devanagari-only version of it.

Usage:
    python this_script.py --image page.png [--base_model DIR] [--lora DIR]
"""

import os
import re
import warnings

# Silence all warnings, including those emitted while importing the ML stack
# below, to keep the CLI output readable.
warnings.filterwarnings("ignore")

# huggingface_hub reads HF_HUB_OFFLINE into a module-level constant when it is
# first imported (transformers/peft import it), so this must be set *before*
# the imports below -- setting it afterwards has no effect.
os.environ["HF_HUB_OFFLINE"] = "1"

import torch  # noqa: E402
from peft import PeftModel  # noqa: E402
from transformers import AutoModel, AutoProcessor  # noqa: E402

# Runs of characters worth keeping: the Devanagari block (U+0900-U+097F, which
# already contains the dandas and Devanagari digits), whitespace, ASCII digits
# and basic punctuation.
DEVANAGARI_RE = re.compile(r"[\u0900-\u097F\s॥।,.!?0-9०-९]+")

# A whitespace-delimited unit (one or more tokens) repeated three or more
# times in a row, e.g. "धर्म धर्म धर्म". Boundaries are whitespace / string
# edges rather than ``\b``: Python's ``re`` does not treat Devanagari vowel
# signs (matras) or the virama as word characters, so ``\b`` fires *inside*
# words such as "कि" and a ``\b``-based pattern misses repeats of any word
# ending in a matra or virama.
REPEAT_RE = re.compile(r"(?<!\S)(\S(?:.*?\S)?)(?:\s+\1){2,}(?!\S)")


def clean_text(raw: str) -> str:
    """Reduce raw OCR output to its Devanagari content.

    For every non-empty line: skip "directly resize..." log lines, keep only
    the runs matched by ``DEVANAGARI_RE`` (joined with single spaces), and
    collapse any unit repeated three or more times in a row to one copy.

    Args:
        raw: Text returned by the model.

    Returns:
        The cleaned lines joined with newlines; lines with nothing left after
        filtering are dropped.
    """
    lines = []
    for line in raw.splitlines():
        line = line.strip()
        if not line or line.lower().startswith("directly resize"):
            continue
        # Strip each run and drop whitespace-only runs; otherwise a purely
        # Latin line such as "Hello world" yields the run " " and would be
        # emitted as a blank line.
        chunks = [chunk.strip() for chunk in DEVANAGARI_RE.findall(line)]
        chunks = [chunk for chunk in chunks if chunk]
        if not chunks:
            continue
        lines.append(REPEAT_RE.sub(r"\1", " ".join(chunks)))
    return "\n".join(lines)


def load_model_with_lora(base_model_path, lora_path):
    """Load the base model, apply the LoRA adapters and merge them in.

    Args:
        base_model_path: Checkpoint location passed to
            ``AutoModel.from_pretrained`` (loaded with remote code, in
            bfloat16, with ``device_map="auto"``).
        lora_path: Adapter location passed to ``PeftModel.from_pretrained``.

    Returns:
        The model with the adapters merged into the base weights
        (``merge_and_unload``), switched to eval mode.
    """
    print("Loading base model...")
    model = AutoModel.from_pretrained(
        base_model_path,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    print("Loading LoRA adapters...")
    model = PeftModel.from_pretrained(model, lora_path)
    # Fold the adapters into the base weights so inference runs on a plain
    # model (and the remote-code ``infer`` method is reachable directly).
    model = model.merge_and_unload()
    model.eval()
    print("Model loaded.")
    return model


def run_inference(model, image_path, processor, max_new_tokens=1024):
    """Run OCR on a single image and return the decoded text.

    Args:
        model: Model returned by ``load_model_with_lora``; must expose the
            remote-code ``infer`` method.
        image_path: Path of the image to transcribe.
        processor: Object passed as the first positional argument of
            ``model.infer``.
        max_new_tokens: Currently unused -- it is not forwarded to
            ``model.infer``. Kept so existing callers keep working.

    Returns:
        The value returned by ``model.infer``, or "" if it was falsy.
    """
    print(f"Running inference on: {image_path}")
    result = model.infer(
        processor,
        # DeepSeek-OCR prompts start with the "<image>" placeholder, which
        # marks where the image tokens are inserted (see the model card).
        prompt="<image>\nFree OCR. ",
        image_file=image_path,
        output_path="./output",
        base_size=1024,
        image_size=640,
        crop_mode=True,
        save_results=False,
        test_compress=False,
        eval_mode=True,
    )
    return result if result else ""


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--image", type=str, required=True,
                        help="Path of the image to OCR.")
    parser.add_argument("--base_model", type=str,
                        default="/home/ubuntu/deepseek_ocr",
                        help="Base model checkpoint directory.")
    parser.add_argument("--lora", type=str,
                        default="/home/ubuntu/sanskrit-ocr-lora",
                        help="LoRA adapter directory.")
    args = parser.parse_args()

    model = load_model_with_lora(args.base_model, args.lora)
    processor = AutoProcessor.from_pretrained(args.base_model, trust_remote_code=True)

    raw = run_inference(model, args.image, processor)
    cleaned = clean_text(raw)

    print("\n" + "=" * 50)
    print("OCR Result (raw):")
    # Show at most the first 500 characters of the raw output.
    print(raw[:500])
    print("=" * 50)
    print("OCR Result (cleaned):")
    print(cleaned)