# ocr-entropy / model.py
# Uploaded by ryandt ("Create model.py", commit b94bee0, verified) —
# Hugging Face page header commented out so the module parses as Python.
"""
Model loading and inference for OCR Confidence Visualization.
Loads Nanonets-OCR2-3B (Qwen2.5-VL fine-tune) and provides
inference with token-level probability extraction.
"""
import math
from dataclasses import dataclass, field
from typing import Generator, Optional
import torch
from PIL import Image
from transformers import AutoModelForImageTextToText, AutoProcessor
# Available models for selection
# Display name -> Hugging Face Hub model id.  All three are loaded through
# AutoProcessor / AutoModelForImageTextToText in load_model().
AVAILABLE_MODELS = {
    "Nanonets-OCR2-3B": "nanonets/Nanonets-OCR2-3B",
    "olmOCR-7B": "allenai/olmOCR-7B-0725",
    "Aya-Vision-8B": "CohereLabs/aya-vision-8b",
}
# Model used when load_model() / run_ocr() is called without a name.
DEFAULT_MODEL = "Aya-Vision-8B"
# Global model and processor (loaded once per model)
_model = None  # cached AutoModelForImageTextToText instance
_processor = None  # AutoProcessor matching _model
_device = None  # torch.device selected by load_model()
_current_model_name = None  # AVAILABLE_MODELS key of the currently loaded model
@dataclass
class TokenData:
    """Probability information for a single generated token.

    Carries everything the streaming UI needs to color one token:
    the decoded text, its softmax probability, the runner-up candidates
    from the top-k distribution, and the Shannon entropy (in bits) of
    that distribution.
    """

    # Decoded token text (special tokens are not stripped).
    token: str
    # Softmax probability of the chosen (argmax) token.
    probability: float
    # Runner-up candidates: [{"token": str, "probability": float}, ...]
    alternatives: list[dict[str, float]]
    # Shannon entropy in bits of the observed top-k distribution.
    entropy: float = 0.0
def calculate_entropy(probs: list[float]) -> float:
    """Shannon entropy, in bits, of a probability distribution.

    Args:
        probs: Probabilities (should sum to ~1.0).  Non-positive entries
            contribute nothing, matching the limit p*log2(p) -> 0.

    Returns:
        Entropy in bits; 0.0 for an empty or fully non-positive input.
    """
    positive = [p for p in probs if p > 0]
    if not positive:
        return 0.0
    return -sum(p * math.log2(p) for p in positive)
def load_model(model_name: str = None):
    """Load (and cache) the OCR model and processor.

    Args:
        model_name: Key from AVAILABLE_MODELS.  None or an unknown name
            falls back to DEFAULT_MODEL.

    Returns:
        Tuple ``(model, processor)``.

    The loaded objects are cached in module globals; requesting a
    different model unloads the current one first.
    """
    global _model, _processor, _device, _current_model_name
    # Normalize to a valid key *before* caching so the cache key always
    # matches the weights actually loaded.  (Previously an unknown name
    # loaded the default weights but recorded the unknown name, forcing a
    # needless full reload on a later request for the default model.)
    if model_name not in AVAILABLE_MODELS:
        model_name = DEFAULT_MODEL
    model_id = AVAILABLE_MODELS[model_name]
    # Return cached model if already loaded
    if _model is not None and _current_model_name == model_name:
        return _model, _processor
    # Unload previous model if switching
    if _model is not None:
        print(f"Unloading previous model: {_current_model_name}")
        del _model
        del _processor
        _model = None
        _processor = None
        # empty_cache() can raise on CPU-only torch builds; only call it
        # when CUDA is actually usable.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    _device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {_device}")
    print(f"Loading model: {model_id}...")
    _processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
    _model = AutoModelForImageTextToText.from_pretrained(
        model_id,
        attn_implementation="flash_attention_2",
        trust_remote_code=True,
        torch_dtype=torch.float16,
    ).to(_device).eval()
    _current_model_name = model_name
    print("Model loaded successfully")
    return _model, _processor
def run_ocr(image: Image.Image, prompt: Optional[str] = None) -> str:
    """
    Run OCR on an image and return extracted text.

    Args:
        image: PIL Image to process
        prompt: Optional custom prompt (default: natural reading extraction)

    Returns:
        Extracted text from the image

    NOTE(review): generation uses ``do_sample=True`` with temperature 1,
    so output is non-deterministic across calls — confirm this is
    intentional for OCR (greedy decoding is the more common choice).
    """
    # Uses the default model; _device is set as a side effect of load_model().
    model, processor = load_model()
    if prompt is None:
        prompt = "Extract the text from the above document as if you were reading it naturally."
    # Chat-template message: one image placeholder plus the text prompt.
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": prompt},
            ],
        }
    ]
    prompt_full = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # Processor fuses the rendered prompt with the pixel data; move all
    # resulting tensors to the device chosen by load_model().
    inputs = processor(
        text=[prompt_full],
        images=[image],
        return_tensors="pt",
        padding=True,
    ).to(_device)
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=1024,
            do_sample=True,
            temperature=1,
            top_p=0.9,
            top_k=50,
            repetition_penalty=1.1,
        )
    # Slice off input tokens so only the newly generated ids are decoded.
    generated_ids = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, output_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )[0]
    return output_text
def generate_with_logprobs(
    image: Image.Image,
    prompt: Optional[str] = None,
    max_new_tokens: int = 1024,
    top_k: int = 20,
    top_p: float = 0.9,
    temperature: float = 1.0,  # 1.0 keeps the raw distribution; selection is always argmax
    repetition_penalty: float = 1.1,
    model_name: str = None,
) -> Generator[TokenData, None, None]:
    """
    Generate OCR text token-by-token with probability information.

    Yields TokenData for each generated token, enabling streaming display
    with confidence visualization.  Decoding is greedy (argmax), so
    ``temperature`` rescales only the *reported* probabilities, never the
    chosen token.

    Args:
        image: PIL Image to process
        prompt: Optional custom prompt (default: natural reading extraction)
        max_new_tokens: Maximum tokens to generate
        top_k: Number of top alternatives to include
        top_p: Nucleus sampling parameter.
            NOTE(review): currently accepted but never applied — with
            argmax decoding it would not change the chosen token anyway.
        temperature: Softmax temperature for the reported distribution
        repetition_penalty: Penalty for repeating tokens (>1.0 reduces repetition)
        model_name: Which model to use (from AVAILABLE_MODELS keys)

    Yields:
        TokenData with token string, probability, top-k alternatives, and
        the Shannon entropy of the (truncated) top-k distribution.
    """
    model, processor = load_model(model_name)
    if prompt is None:
        prompt = "Extract the text from the above document as if you were reading it naturally."
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": prompt},
            ],
        }
    ]
    prompt_full = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = processor(
        text=[prompt_full],
        images=[image],
        return_tensors="pt",
        padding=True,
    ).to(_device)
    input_ids = inputs.input_ids
    attention_mask = inputs.attention_mask
    # Get EOS token ID for stopping - check model config first, then tokenizer
    eos_token_id = model.config.eos_token_id
    if eos_token_id is None:
        eos_token_id = processor.tokenizer.eos_token_id
    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]
    elif eos_token_id is None:
        eos_token_id = []  # No EOS token - will rely on max_new_tokens
    # Track generated tokens (prompt ids + everything emitted so far)
    generated_ids = input_ids.clone()
    # Extract image inputs (pixel_values, image_grid_thw for Qwen2.5-VL)
    model_inputs = {k: v for k, v in inputs.items() if k not in ("input_ids", "attention_mask")}
    # Use DynamicCache for proper KV cache management
    from transformers import DynamicCache
    past_key_values = DynamicCache()
    # Track sequence length for cache_position
    seq_length = input_ids.shape[1]
    # Track rope_deltas for multimodal RoPE (required for Qwen2.5-VL)
    # This is computed on the first forward pass and must be passed to subsequent passes
    rope_deltas = None
    with torch.no_grad():
        for step in range(max_new_tokens):
            # Forward pass
            if step == 0:
                # First step: include image data, full sequence
                cache_position = torch.arange(seq_length, device=_device)
                outputs = model(
                    input_ids=generated_ids,
                    attention_mask=attention_mask,
                    cache_position=cache_position,
                    past_key_values=past_key_values,
                    **model_inputs,
                    return_dict=True,
                    use_cache=True,
                )
            else:
                # Subsequent steps: only new token with cache
                cache_position = torch.tensor([seq_length], device=_device)
                outputs = model(
                    input_ids=generated_ids[:, -1:],
                    attention_mask=attention_mask,
                    cache_position=cache_position,
                    past_key_values=past_key_values,
                    rope_deltas=rope_deltas,  # Pass rope_deltas for correct multimodal position encoding
                    return_dict=True,
                    use_cache=True,
                )
            past_key_values = outputs.past_key_values
            # Capture rope_deltas from first pass for multimodal position encoding
            if step == 0 and hasattr(outputs, 'rope_deltas') and outputs.rope_deltas is not None:
                rope_deltas = outputs.rope_deltas
            # Get logits for last token position - convert to float32 to avoid overflow
            next_token_logits = outputs.logits[:, -1, :].float()
            # Apply repetition penalty once per *unique* previously seen token.
            # (FIX: the per-occurrence loop compounded the penalty to
            # penalty**count for repeated tokens; HF's
            # RepetitionPenaltyLogitsProcessor applies it exactly once per id.)
            if repetition_penalty != 1.0:
                for prev_token_id in set(generated_ids[0].tolist()):
                    if next_token_logits[0, prev_token_id] < 0:
                        next_token_logits[0, prev_token_id] *= repetition_penalty
                    else:
                        next_token_logits[0, prev_token_id] /= repetition_penalty
            # Apply temperature
            if temperature > 0:
                next_token_logits = next_token_logits / temperature
            # Compute probabilities via softmax
            probs = torch.softmax(next_token_logits, dim=-1)
            # Get top-k probabilities and indices
            top_probs, top_indices = torch.topk(probs, k=min(top_k, probs.shape[-1]))
            top_probs = top_probs[0].cpu().tolist()
            top_indices = top_indices[0].cpu().tolist()
            # Pick the top token (argmax); no sampling is performed
            next_token_id = top_indices[0]
            next_token_prob = top_probs[0]
            # Check for EOS
            if next_token_id in eos_token_id:
                break
            # Decode token
            token_str = processor.decode([next_token_id], skip_special_tokens=False)
            # Build alternatives list (excluding the selected token)
            alternatives = []
            for alt_idx, alt_prob in zip(top_indices[1:], top_probs[1:]):
                alt_token = processor.decode([alt_idx], skip_special_tokens=False)
                alternatives.append({"token": alt_token, "probability": alt_prob})
            # Calculate entropy from the truncated top-k distribution
            # (a lower bound on the full-vocabulary entropy)
            all_probs = [next_token_prob] + [alt["probability"] for alt in alternatives]
            token_entropy = calculate_entropy(all_probs)
            # Yield token data
            yield TokenData(
                token=token_str,
                probability=next_token_prob,
                alternatives=alternatives,
                entropy=token_entropy,
            )
            # Update for next iteration
            next_token_tensor = torch.tensor([[next_token_id]], device=_device)
            generated_ids = torch.cat([generated_ids, next_token_tensor], dim=-1)
            # Extend attention mask to cover full sequence (required for Qwen VL models)
            attention_mask = torch.cat(
                [attention_mask, torch.ones((1, 1), device=_device, dtype=attention_mask.dtype)],
                dim=-1,
            )
            seq_length += 1