# HEIC → Markdown/LaTeX OCR command-line tool.
| from __future__ import annotations | |
| import argparse | |
| from pathlib import Path | |
| import pillow_heif | |
| import torch | |
| from PIL import Image | |
| from transformers import AutoModelForImageTextToText, AutoProcessor | |
| pillow_heif.register_heif_opener() | |
| DEFAULT_MODEL_ID = "THUDM/glm-4v-9b" | |
def parse_args() -> argparse.Namespace:
    """Build the CLI and parse ``sys.argv`` into a namespace.

    Returns:
        argparse.Namespace with ``input``, ``output``, ``format``,
        ``model_id``, ``max_new_tokens`` and ``device`` attributes.
    """
    description = (
        "Convert a HEIC image into Markdown or LaTeX using a GLM OCR-capable vision model "
        "from Hugging Face."
    )
    cli = argparse.ArgumentParser(description=description)

    # Positional: the image to transcribe.
    cli.add_argument("input", type=Path, help="Input .heic/.heif image path")

    # Optional output location; derived from the input stem when omitted.
    cli.add_argument(
        "--output",
        type=Path,
        help="Output file path. Defaults to input stem + .md/.tex based on --format.",
    )
    cli.add_argument(
        "--format",
        choices=["md", "latex"],
        default="md",
        help="Output format for OCR transcription.",
    )
    cli.add_argument(
        "--model-id",
        default=DEFAULT_MODEL_ID,
        help="Hugging Face model id for GLM OCR-style VLM inference.",
    )
    # Generation budget for the decoder.
    cli.add_argument("--max-new-tokens", type=int, default=2048)
    cli.add_argument(
        "--device",
        choices=["auto", "cpu", "cuda"],
        default="auto",
        help="Run inference on CPU/CUDA, or auto-detect.",
    )
    return cli.parse_args()
def build_prompt(target_format: str) -> str:
    """Return the OCR instruction prompt for the requested output format.

    Args:
        target_format: "latex" selects the LaTeX prompt; anything else
            (in practice "md") selects the Markdown prompt.
    """
    latex_prompt = (
        "You are an OCR engine. Read the image exactly and return clean LaTeX only. "
        "Keep math in proper LaTeX syntax and preserve document structure where possible. "
        "Do not add explanations."
    )
    markdown_prompt = (
        "You are an OCR engine. Read the image exactly and return clean Markdown only. "
        "Use standard markdown headings/lists/tables where appropriate and preserve equations "
        "using $...$ or $$...$$. Do not add explanations."
    )
    return latex_prompt if target_format == "latex" else markdown_prompt
def load_model(model_id: str, device: str):
    """Load the processor and image-text-to-text model for *model_id*.

    Args:
        model_id: Hugging Face Hub model identifier.
        device: "cpu", "cuda", or "auto" (use CUDA when available).

    Returns:
        Tuple of (processor, model).

    Raises:
        RuntimeError: if ``device == "cuda"`` but CUDA is unavailable.
    """
    if device == "cuda" and not torch.cuda.is_available():
        raise RuntimeError("--device cuda was requested but CUDA is not available.")

    # BUGFIX: dtype previously keyed on torch.cuda.is_available() alone, so
    # --device cpu on a CUDA machine loaded float16 weights on CPU (slow and
    # some CPU ops lack fp16 kernels). Use fp16 only when we will run on CUDA.
    use_cuda = device != "cpu" and torch.cuda.is_available()
    torch_dtype = torch.float16 if use_cuda else torch.float32

    # Pin everything to CPU when asked; otherwise let accelerate place shards.
    device_map = {"": "cpu"} if device == "cpu" else "auto"

    processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForImageTextToText.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
        device_map=device_map,
        trust_remote_code=True,
    )
    return processor, model
def run_ocr(
    image_path: Path, target_format: str, model_id: str, max_new_tokens: int, device: str
) -> str:
    """Transcribe one image to Markdown/LaTeX via a chat-style vision model.

    Args:
        image_path: Path to the image file (HEIC/HEIF supported via pillow-heif).
        target_format: "md" or "latex"; selects the instruction prompt.
        model_id: Hugging Face model id to load.
        max_new_tokens: Generation budget for the decoder.
        device: "cpu", "cuda", or "auto" (forwarded to ``load_model``).

    Returns:
        The decoded transcription with surrounding whitespace stripped.
    """
    source_image = Image.open(image_path).convert("RGB")
    instruction = build_prompt(target_format)
    processor, model = load_model(model_id, device)

    # Single-turn chat: one image followed by the OCR instruction.
    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": source_image},
                {"type": "text", "text": instruction},
            ],
        }
    ]
    chat_text = processor.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )
    model_inputs = processor(text=[chat_text], images=[source_image], return_tensors="pt")
    # Move tensor entries to the model's device; pass anything else through unchanged.
    model_inputs = {
        key: value.to(model.device) if hasattr(value, "to") else value
        for key, value in model_inputs.items()
    }

    with torch.inference_mode():
        token_ids = model.generate(**model_inputs, max_new_tokens=max_new_tokens)

    # Trim the prompt tokens so only newly generated text is decoded.
    prompt_length = model_inputs["input_ids"].shape[-1]
    decoded = processor.batch_decode(token_ids[:, prompt_length:], skip_special_tokens=True)
    return decoded[0].strip()
| def resolve_output_path(image_path: Path, output: Path | None, target_format: str) -> Path: | |
| if output is not None: | |
| return output | |
| extension = ".md" if target_format == "md" else ".tex" | |
| return image_path.with_suffix(extension) | |
def main() -> None:
    """CLI entry point: validate the input, run OCR, write the transcription.

    Raises:
        FileNotFoundError: if the input image does not exist.
    """
    args = parse_args()
    source = args.input
    if not source.exists():
        raise FileNotFoundError(f"Input file not found: {source}")

    destination = resolve_output_path(source, args.output, args.format)
    # Create intermediate directories so write_text below cannot fail on them.
    destination.parent.mkdir(parents=True, exist_ok=True)

    transcription = run_ocr(
        image_path=source,
        target_format=args.format,
        model_id=args.model_id,
        max_new_tokens=args.max_new_tokens,
        device=args.device,
    )
    destination.write_text(transcription, encoding="utf-8")
    print(f"Saved {args.format} output to: {destination}")


if __name__ == "__main__":
    main()