# Source: ocr-screenshot-reader / ocr_reader.py
# Committed by: kumaraguruvs
# Commit b8d8121 (verified): "Add download progress bar and explicit model info"
"""
Screenshot & Terminal Text Reader using GOT-OCR 2.0
====================================================
Extracts text from screenshots of:
- Terminal/console windows
- Log files
- Computer application windows
Model: stepfun-ai/GOT-OCR-2.0-hf (560M params, Apache 2.0)
Requirements: pip install transformers torch Pillow accelerate huggingface_hub tqdm
Usage:
python ocr_reader.py image.png
python ocr_reader.py image1.png image2.png image3.png
python ocr_reader.py ./screenshots_folder/
python ocr_reader.py image.png --output result.txt
"""
import sys
import os
import argparse
import warnings
import torch
from pathlib import Path
from PIL import Image
from huggingface_hub import snapshot_download
from tqdm import tqdm
# Quiet noisy-but-harmless library output. These settings are applied
# *before* the transformers import on the last line so they are already in
# effect while transformers is imported.
# TOKENIZERS_PARALLELISM is an environment switch for the Hugging Face
# tokenizers library; "false" turns its parallelism (and the associated
# warning) off.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Hide the clean_up_tokenization_spaces deprecation-style message (any category).
warnings.filterwarnings("ignore", message=".*clean_up_tokenization_spaces.*")
# NOTE(review): broad filter - hides every UserWarning attributed to a module
# whose name starts with "transformers", not just the tokenizer one.
warnings.filterwarnings("ignore", category=UserWarning, module="transformers")
from transformers import AutoProcessor, AutoModelForImageTextToText
# =============================================================
# Model: GOT-OCR 2.0 (General OCR Theory)
# HuggingFace: https://huggingface.co/stepfun-ai/GOT-OCR-2.0-hf
# Paper: https://arxiv.org/abs/2409.01704
# Size: 560M parameters (~1.1GB download)
# License: Apache 2.0
# =============================================================
# Hugging Face Hub repository id; passed to snapshot_download() and to both
# from_pretrained() calls in load_model().
MODEL_ID = "stepfun-ai/GOT-OCR-2.0-hf"
def load_model(device=None):
    """Load the GOT-OCR 2.0 model and its processor.

    Args:
        device: ``"cuda"`` or ``"cpu"``. When omitted, the ``OCR_DEVICE``
            environment variable (which ``main()`` sets from ``--device``)
            is used; if that is unset or empty, CUDA is chosen when
            ``torch.cuda.is_available()`` and CPU otherwise.

    Returns:
        A ``(model, processor, device)`` tuple, where ``device`` is the
        device string that was actually used.
    """
    if device is None:
        # main() forwards --device through OCR_DEVICE because the call chain
        # (process_images -> load_model) passes no device argument. Without
        # reading it here the --device flag had no effect.
        device = os.environ.get("OCR_DEVICE") or (
            "cuda" if torch.cuda.is_available() else "cpu"
        )

    banner = "=" * 60
    print(f"\n{banner}", file=sys.stderr)
    print(f" Model: GOT-OCR 2.0 ({MODEL_ID})", file=sys.stderr)
    print(" Size: 560M parameters (~1.1GB)", file=sys.stderr)
    print(f" Device: {device}", file=sys.stderr)
    print(f"{banner}\n", file=sys.stderr)

    # Pre-fetch the repository into the Hugging Face cache (snapshot_download
    # shows its own progress bars); from_pretrained() below then reads it
    # from that cache.
    print("Downloading model (if not cached)...", file=sys.stderr)
    snapshot_download(MODEL_ID)

    # bfloat16 on GPU to halve weight memory versus float32; stay on float32
    # on CPU.
    dtype = torch.bfloat16 if device == "cuda" else torch.float32

    print("Loading model into memory...", file=sys.stderr)
    # NOTE(review): `dtype=` is the newer spelling of `torch_dtype`; confirm
    # the installed transformers version honours it.
    model = AutoModelForImageTextToText.from_pretrained(
        MODEL_ID,
        dtype=dtype,
        device_map=device,
    )
    processor = AutoProcessor.from_pretrained(MODEL_ID)
    print("✓ Model ready!\n", file=sys.stderr)
    return model, processor, device
def extract_text(image_path, model, processor, max_tokens=4096):
    """Run GOT-OCR on a single image and return the recognised text.

    Args:
        image_path: Path of the image file to read (str or path-like).
        model: The model returned by ``load_model()``.
        processor: The matching processor returned by ``load_model()``.
        max_tokens: Upper bound on newly generated tokens
            (passed as ``max_new_tokens``).

    Returns:
        The decoded text, with special tokens skipped and surrounding
        whitespace stripped.
    """
    # Context manager releases the file handle deterministically instead of
    # waiting for garbage collection (Pillow can keep the file open, e.g.
    # for multi-frame formats such as GIF).
    with Image.open(image_path) as img:
        image = img.convert("RGB")

    inputs = processor(images=image, return_tensors="pt")
    # Move every tensor to the model's device. Floating-point tensors (the
    # pixel values) are also cast to the model's dtype: load_model() picks
    # bfloat16 on CUDA, and float32 pixel values can otherwise trigger a
    # dtype-mismatch error. Integer tensors (token ids, masks) keep their dtype.
    inputs = {
        key: (
            tensor.to(model.device, dtype=model.dtype)
            if torch.is_floating_point(tensor)
            else tensor.to(model.device)
        )
        for key, tensor in inputs.items()
    }

    # Greedy decoding; stop at the chat end marker or the token limit.
    with torch.no_grad():
        generate_ids = model.generate(
            **inputs,
            do_sample=False,
            tokenizer=processor.tokenizer,
            stop_strings=["<|im_end|>"],
            max_new_tokens=max_tokens,
        )

    # generate() returns prompt + continuation; keep only the new tokens.
    input_len = inputs["input_ids"].shape[1]
    generated_ids = generate_ids[0, input_len:]
    text = processor.decode(generated_ids, skip_special_tokens=True)
    return text.strip()
def _collect_image_paths(paths, supported_exts):
    """Expand files and directories in *paths* into a list of image Paths."""
    image_paths = []
    for path in paths:
        p = Path(path)
        if p.is_dir():
            # A single directory pass with a case-insensitive suffix check.
            # The previous per-extension "*.png" + "*.PNG" globs returned
            # each file twice on case-insensitive filesystems, missed
            # mixed-case suffixes such as ".Png", also matched directories,
            # and ordered results by set iteration order.
            image_paths.extend(sorted(
                child for child in p.iterdir()
                if child.is_file() and child.suffix.lower() in supported_exts
            ))
        elif p.is_file() and p.suffix.lower() in supported_exts:
            image_paths.append(p)
        else:
            print(f"Warning: Skipping '{path}' (not a supported image or directory)", file=sys.stderr)
    return image_paths


def process_images(paths, output_file=None, max_tokens=4096):
    """OCR one or more images and/or directories of images.

    Directories are scanned non-recursively. The model is loaded once, and
    only if at least one image was found.

    Args:
        paths: Iterable of file or directory paths.
        output_file: If given, text (and per-image headers) is written to
            this file as UTF-8. Otherwise text goes to stdout and headers to
            stderr.
        max_tokens: Generation limit forwarded to ``extract_text``.

    Returns:
        None.
    """
    supported_exts = {".png", ".jpg", ".jpeg", ".bmp", ".tiff", ".webp", ".gif"}
    image_paths = _collect_image_paths(paths, supported_exts)
    if not image_paths:
        print("No valid images found!", file=sys.stderr)
        return

    # Load the model once for the whole batch.
    model, processor, _ = load_model()

    # Explicit UTF-8: OCR output can contain characters that the platform
    # default encoding (e.g. cp1252 on Windows) cannot represent.
    output_handle = open(output_file, "w", encoding="utf-8") if output_file else sys.stdout
    total = len(image_paths)
    try:
        for i, img_path in enumerate(image_paths, 1):
            if total > 1:
                header = f"\n{'='*60}\n[{i}/{total}] {img_path.name}\n{'='*60}"
                if output_file:
                    output_handle.write(header + "\n")
                else:
                    # Keep stdout limited to extracted text.
                    print(header, file=sys.stderr)
            print(f"Processing: {img_path.name}...", file=sys.stderr)
            text = extract_text(str(img_path), model, processor, max_tokens)
            output_handle.write(text + "\n")
            if total > 1:
                output_handle.write("\n")
    finally:
        if output_file:
            output_handle.close()
    # Report success only when every image was processed.
    if output_file:
        print(f"\nOutput saved to: {output_file}", file=sys.stderr)
def main():
    """Command-line entry point: parse arguments and OCR the given inputs."""
    parser = argparse.ArgumentParser(
        description="Extract text from screenshots using GOT-OCR 2.0",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
python ocr_reader.py screenshot.png
python ocr_reader.py *.png --output all_text.txt
python ocr_reader.py ./screenshots/
python ocr_reader.py terminal.png --max-tokens 8192
"""
    )
    parser.add_argument("images", nargs="+", help="Image files or directories to process")
    parser.add_argument("--output", "-o", help="Save output to file (default: print to stdout)")
    parser.add_argument("--max-tokens", type=int, default=4096,
                        help="Maximum tokens to generate (default: 4096, increase for very long texts)")
    parser.add_argument("--device", choices=["cuda", "cpu"],
                        help="Force device (default: auto-detect)")
    args = parser.parse_args()

    # Reject non-positive limits up front rather than passing them on to
    # model.generate() after the (slow) model download and load.
    if args.max_tokens < 1:
        parser.error("--max-tokens must be a positive integer")

    if args.device:
        # process_images() takes no device argument, so the choice is handed
        # to the model loader through the environment.
        os.environ["OCR_DEVICE"] = args.device
    process_images(args.images, output_file=args.output, max_tokens=args.max_tokens)
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()