|
|
import torch |
|
|
from torch.utils.data import DataLoader |
|
|
from PIL import Image |
|
|
import numpy as np |
|
|
import matplotlib.pyplot as plt |
|
|
import re |
|
|
import cv2 |
|
|
import string |
|
|
from transformers import TrOCRProcessor, VisionEncoderDecoderModel |
|
|
from vit import LineDataset, collate_fn |
|
|
from loguru import logger |
|
|
import os |
|
|
from configs import hf_token |
|
|
|
|
|
class Inference:
    """Transcribe line images with a TrOCR processor and encoder-decoder model.

    The processor and model are loaded once, when the object is constructed.
    The public entry points are :meth:`generate_texts_from_images` (a dict of
    line segments) and :meth:`generate_single_image` (one image).
    """

    def __init__(self, model_path, processor_path, target_size=(256, 64), batch_size=32):
        """Store the configuration and eagerly load the processor and model.

        Args:
            model_path (str): Location handed to
                ``VisionEncoderDecoderModel.from_pretrained``.
            processor_path (str): Location handed to
                ``TrOCRProcessor.from_pretrained``.
            target_size (tuple): Target size forwarded unchanged to
                ``LineDataset``.
                NOTE(review): previously documented as (height, width), but
                the default (256, 64) looks like (width, height) - confirm
                against ``LineDataset``.
            batch_size (int): Batch size used by the inference DataLoader.

        Raises:
            Exception: Whatever the processor/model loading raises; it is
                logged and re-raised by :meth:`_initialize_model`.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.model_path = model_path
        self.processor_path = processor_path
        self.target_size = target_size
        self.batch_size = batch_size

        # Filled in by _initialize_model(); declared first so the attributes
        # exist even if loading raises part-way through.
        self.processor = None
        self.model = None
        self._initialize_model()

    def _get_absolute_path(self, path):
        """Return ``path`` as an absolute path, resolved against the cwd.

        Absolute inputs are returned untouched. Relative inputs are joined
        onto the current working directory and normalised, so ``./a``,
        ``../a`` and ``.hidden/a`` all resolve to the location they name.

        Args:
            path (str): Absolute or relative filesystem path.

        Returns:
            str: The absolute path.
        """
        if os.path.isabs(path):
            return path

        # str.lstrip('./') strips *characters*, not a prefix: it would turn
        # '../x' and '.hidden/x' into 'x'. Normalising the joined path handles
        # './' and '../' correctly and keeps leading-dot names intact.
        return os.path.normpath(os.path.join(os.getcwd(), path))

    def _load_model(self):
        """Load the model, falling back to progressively simpler options.

        Attempts, in order: safetensors with ``device_map``, safetensors
        only, and finally non-safetensors weights.

        Returns:
            The loaded ``VisionEncoderDecoderModel``.

        Raises:
            Exception: The error from the last attempt when all of them fail.
        """
        attempts = (
            (
                {
                    "use_safetensors": True,
                    "device_map": "auto" if torch.cuda.is_available() else None,
                },
                "Model loaded with safetensors=True and device_map",
            ),
            ({"use_safetensors": True}, "Model loaded with safetensors=True"),
            # Last resort: weights that are not stored as safetensors.
            ({"use_safetensors": False}, "Model loaded with safetensors=False"),
        )
        final_attempt = len(attempts)

        for number, (load_kwargs, success_message) in enumerate(attempts, start=1):
            try:
                model = VisionEncoderDecoderModel.from_pretrained(
                    self.model_path,
                    token=hf_token,
                    **load_kwargs,
                )
            except Exception as load_error:
                if number == final_attempt:
                    logger.error(f"All loading methods failed: {load_error}")
                    raise
                logger.warning(f"Method {number} failed: {load_error}")
            else:
                logger.info(success_message)
                return model

    def _initialize_model(self):
        """Load the processor and model, move the model to the device, set eval mode.

        Raises:
            Exception: Any loading error, after it has been logged together
                with its traceback.
        """
        logger.info("Loading model...")

        logger.info(f"Loading model from: {self.model_path}")
        logger.info(f"Loading processor from: {self.processor_path}")

        try:
            self.processor = TrOCRProcessor.from_pretrained(
                self.processor_path,
                do_rescale=False,
                use_fast=True,
                token=hf_token,
            )
            logger.info("Processor loaded successfully")

            logger.info("Attempting to load model...")
            self.model = self._load_model()

            # The comparison is on the string form, so e.g. "cuda:0" and
            # "cuda" count as different and .to() is still called.
            if not hasattr(self.model, 'device') or str(self.model.device) != str(self.device):
                logger.info(f"Moving model to device: {self.device}")
                self.model.to(self.device)

            self.model.eval()
            logger.info("Model loaded successfully and moved to device")

        except Exception as e:
            logger.error(f"Error loading model or processor: {e}")
            import traceback
            logger.error(f"Traceback: {traceback.format_exc()}")
            raise

    def preprocess_images(self, line_segments):
        """Split the line-segment dictionary into parallel key and image lists.

        Args:
            line_segments (dict): Maps each segment key to a dict that holds
                at least an ``"image"`` entry.

        Returns:
            tuple: ``(keys, line_images)`` where ``line_images[i]`` is the
            image stored under ``keys[i]``.

        Raises:
            KeyError: If a segment has no ``"image"`` entry.
        """
        keys = list(line_segments.keys())
        line_images = [line_segments[k]["image"] for k in keys]
        return keys, line_images

    def create_dataloader(self, line_images):
        """Wrap the line images in a non-shuffling DataLoader for inference.

        Args:
            line_images (list): Line images to transcribe.

        Returns:
            DataLoader: Iterates a ``LineDataset`` in input order, in batches
            of ``self.batch_size``, collated by ``collate_fn``.
        """
        # LineDataset takes labels; inference has none, so pass empty ones.
        dummy_labels = [""] * len(line_images)

        dataset = LineDataset(
            self.processor,
            self.model,
            line_images,
            dummy_labels,
            self.target_size,
            apply_augmentation=False,
        )

        # shuffle=False keeps outputs aligned with the order of line_images.
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=False,
            collate_fn=collate_fn,
        )

    def generate_texts(self, dataloader):
        """Run generation over every batch and decode the results to text.

        Args:
            dataloader (DataLoader): Yields batches that expose a
                ``"pixel_values"`` tensor.

        Returns:
            list: Decoded strings, in the order the dataloader yields them.
        """
        generated_texts = []

        with torch.no_grad():
            for batch in dataloader:
                pixel_values = batch["pixel_values"].to(self.device)
                generated_ids = self.model.generate(pixel_values)
                generated_texts.extend(
                    self.processor.batch_decode(
                        generated_ids,
                        skip_special_tokens=True,
                    )
                )

        return generated_texts

    def update_line_segments(self, line_segments, keys, generated_texts):
        """Write each generated text into its segment under ``"transcription"``.

        The dictionary is modified in place and also returned. When the two
        lists differ in length, only the first ``min(len(keys),
        len(generated_texts))`` pairs are written and a warning is logged.

        Args:
            line_segments (dict): Line segments dictionary to update.
            keys (list): Segment keys, aligned with ``generated_texts``.
            generated_texts (list): Generated texts, one per key.

        Returns:
            dict: The same ``line_segments`` object, with transcriptions added.
        """
        if len(keys) != len(generated_texts):
            # zip() below truncates silently; make the mismatch visible.
            logger.warning(
                f"Got {len(generated_texts)} transcriptions for {len(keys)} line segments"
            )

        for key, text in zip(keys, generated_texts):
            line_segments[key]["transcription"] = text

        return line_segments

    def generate_texts_from_images(self, line_segments):
        """Transcribe every line segment and store the text alongside its image.

        Args:
            line_segments (dict): Maps each segment key to a dict with an
                ``"image"`` entry.

        Returns:
            dict: The same dictionary, with a ``"transcription"`` entry added
            to each segment. An empty input is returned unchanged.
        """
        logger.info("Starting text generation from images...")

        if not line_segments:
            # Nothing to transcribe; skip building a dataset over zero images.
            return line_segments

        keys, line_images = self.preprocess_images(line_segments)
        dataloader = self.create_dataloader(line_images)
        generated_texts = self.generate_texts(dataloader)

        return self.update_line_segments(line_segments, keys, generated_texts)

    def generate_single_image(self, image):
        """Transcribe a single image.

        Args:
            image: PIL Image or numpy array; arrays are converted with
                ``Image.fromarray`` first.

        Returns:
            str: Generated text.
        """
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        # Reuse the batch pipeline by wrapping the image as a one-entry dict.
        temp_segments = {"temp_key": {"image": image}}
        result = self.generate_texts_from_images(temp_segments)

        return result["temp_key"]["transcription"]
|
|
|