| import sys
|
| import torch
|
| from pathlib import Path
|
| from PIL import Image
|
| from huggingface_hub import snapshot_download
|
|
|
|
|
| class Nav2TexPipeline:
|
| def __init__(self, model, processor, device):
|
| self.model = model
|
| self.processor = processor
|
| self.device = device
|
|
|
| @classmethod
|
| def from_pretrained(cls, repo_id_or_path: str, device: str = None):
|
| if device is None:
|
| device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
| path = Path(repo_id_or_path)
|
| if not path.exists():
|
| path = Path(snapshot_download(repo_id_or_path))
|
|
|
| sys.path.insert(0, str(path))
|
|
|
| from nav2tex.tokenization_latex_ocr import LaTeXTokenizer
|
| from nav2tex.image_processing_latex_ocr import Nav2TexImageProcessor
|
| from nav2tex.processing_latex_ocr import Nav2TexProcessor
|
| from nav2tex.modeling_latex_ocr import Nav2TexModel
|
| from nav2tex.configuration_latex_ocr import Nav2TexConfig
|
|
|
| config = Nav2TexConfig.from_pretrained(str(path))
|
| image_processor = Nav2TexImageProcessor.from_pretrained(str(path))
|
| tokenizer = LaTeXTokenizer(str(path / "tokenizer.json"))
|
| processor = Nav2TexProcessor(image_processor=image_processor, tokenizer=tokenizer)
|
| model = Nav2TexModel.from_pretrained(str(path), config=config).to(device).eval()
|
|
|
| return cls(model=model, processor=processor, device=device)
|
|
|
| def __call__(self, image, max_new_tokens: int = None, num_beams: int = None):
|
| single = not isinstance(image, list)
|
| images = [image] if single else image
|
|
|
| loaded = []
|
| for img in images:
|
| if isinstance(img, (str, Path)):
|
| img = Image.open(img).convert("RGB")
|
| elif isinstance(img, Image.Image):
|
| img = img.convert("RGB")
|
| else:
|
| raise TypeError(f"Unsupported image type: {type(img)}")
|
| loaded.append(img)
|
|
|
| kwargs = {}
|
| if max_new_tokens is not None:
|
| kwargs["max_new_tokens"] = max_new_tokens
|
| if num_beams is not None:
|
| kwargs["num_beams"] = num_beams
|
|
|
|
|
|
|
| pixel_values = [
|
| self.processor(images=img, return_tensors="pt")["pixel_values"].to(self.device)
|
| for img in loaded
|
| ]
|
|
|
| with torch.no_grad():
|
| generated_ids = self.model.generate(pixel_values, **kwargs)
|
|
|
| results = self.processor.batch_decode(generated_ids, skip_special_tokens=True)
|
| return results[0] if single else results |