"""
Model loading and inference for OCR Confidence Visualization.

Loads one of several vision-language OCR models (see ``AVAILABLE_MODELS``, e.g.
Nanonets-OCR2-3B, a Qwen2.5-VL fine-tune) and provides inference with
token-level probability extraction.
"""

import gc
import importlib.util
import math
from dataclasses import dataclass
from typing import Generator, Optional

import torch
from PIL import Image
from transformers import AutoModelForImageTextToText, AutoProcessor, DynamicCache

# Available models for selection (display name -> model id passed to from_pretrained)
AVAILABLE_MODELS = {
    "Nanonets-OCR2-3B": "nanonets/Nanonets-OCR2-3B",
    "olmOCR-7B": "allenai/olmOCR-7B-0725",
    "Aya-Vision-8B": "CohereLabs/aya-vision-8b",
}
DEFAULT_MODEL = "Aya-Vision-8B"

# Prompt used by run_ocr / generate_with_logprobs when the caller passes none.
DEFAULT_PROMPT = (
    "Extract the text from the above document as if you were reading it naturally."
)

# Global model and processor (loaded once per model)
_model = None
_processor = None
_device = None
_current_model_name = None


@dataclass
class TokenData:
    """Data for a single generated token with probability info."""

    token: str  # decoded text of the chosen token (special tokens are not skipped)
    probability: float  # probability of the chosen token after penalty/temperature
    # Runner-up candidates, most likely first: [{"token": str, "probability": float}, ...]
    alternatives: list[dict]
    entropy: float = 0.0  # Shannon entropy in bits (see generate_with_logprobs)


def calculate_entropy(probs: list[float]) -> float:
    """Calculate Shannon entropy in bits from a probability distribution.

    Args:
        probs: List of probabilities (should sum to ~1.0). Zero entries are skipped,
            following the 0 * log(0) = 0 convention.

    Returns:
        Entropy in bits. 0.0 for empty or single-certainty distributions.
    """
    return sum((-p * math.log2(p) for p in probs if p > 0), 0.0)


def _select_attn_implementation(device: torch.device) -> str:
    """Use FlashAttention-2 only on CUDA with ``flash_attn`` installed; otherwise SDPA."""
    if device.type == "cuda" and importlib.util.find_spec("flash_attn") is not None:
        return "flash_attention_2"
    return "sdpa"


def load_model(model_name: Optional[str] = None):
    """Load the OCR model and processor, reusing the cached pair when possible.

    Reloads (after releasing the previous model) if ``model_name`` differs from the
    currently loaded model.

    Args:
        model_name: Key of ``AVAILABLE_MODELS``. ``None`` selects ``DEFAULT_MODEL``;
            an unknown name falls back to ``DEFAULT_MODEL`` with a printed warning.

    Returns:
        ``(model, processor)`` tuple.
    """
    global _model, _processor, _device, _current_model_name

    if model_name is None:
        model_name = DEFAULT_MODEL
    elif model_name not in AVAILABLE_MODELS:
        print(f"Unknown model '{model_name}', falling back to {DEFAULT_MODEL}")
        # Cache under the resolved name so a later call with DEFAULT_MODEL
        # does not reload the very same weights.
        model_name = DEFAULT_MODEL
    model_id = AVAILABLE_MODELS[model_name]

    # Return cached model if already loaded
    if _model is not None and _current_model_name == model_name:
        return _model, _processor

    # Unload previous model if switching
    if _model is not None:
        print(f"Unloading previous model: {_current_model_name}")
        _model = None
        _processor = None
        _current_model_name = None
        # Collect first so unreachable tensors are actually freed before the
        # CUDA caching allocator is asked to release its blocks.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    _device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {_device}")
    print(f"Loading model: {model_id}...")

    # Load into locals first so a failure leaves no half-updated global state.
    processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForImageTextToText.from_pretrained(
        model_id,
        attn_implementation=_select_attn_implementation(_device),
        trust_remote_code=True,
        # float16 on GPU as before; float32 on CPU, where half precision is slow
        # or unsupported for some ops.
        torch_dtype=torch.float16 if _device.type == "cuda" else torch.float32,
    ).to(_device).eval()

    _model, _processor, _current_model_name = model, processor, model_name
    print("Model loaded successfully")
    return _model, _processor


def _prepare_inputs(processor, image: Image.Image, prompt: Optional[str]):
    """Build the single-image chat prompt and tokenize it onto the active device.

    Must be called after ``load_model`` (it uses the module-level ``_device``).
    ``prompt=None`` selects ``DEFAULT_PROMPT``.
    """
    if prompt is None:
        prompt = DEFAULT_PROMPT

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": prompt},
            ],
        }
    ]
    prompt_full = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    return processor(
        text=[prompt_full],
        images=[image],
        return_tensors="pt",
        padding=True,
    ).to(_device)


def run_ocr(
    image: Image.Image,
    prompt: Optional[str] = None,
    model_name: Optional[str] = None,
) -> str:
    """
    Run OCR on an image and return extracted text.

    Generation uses sampling (``do_sample=True``), so repeated calls on the same
    image may return different text.

    Args:
        image: PIL Image to process
        prompt: Optional custom prompt (default: natural reading extraction)
        model_name: Which model to use (key of AVAILABLE_MODELS); None selects
            DEFAULT_MODEL, reloading it if a different model is currently loaded

    Returns:
        Extracted text from the image
    """
    model, processor = load_model(model_name)
    inputs = _prepare_inputs(processor, image, prompt)

    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=1024,
            do_sample=True,
            temperature=1,
            top_p=0.9,
            top_k=50,
            repetition_penalty=1.1,
        )

    # Slice off input tokens
    generated_ids = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, output_ids)
    ]
    return processor.batch_decode(
        generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )[0]


def _resolve_eos_token_ids(model, processor) -> set[int]:
    """Collect every end-of-sequence id from generation config, model config and tokenizer.

    ``model.generate`` (used by run_ocr) stops on ``generation_config.eos_token_id``,
    which may list several ids; the model/tokenizer values are added as fallbacks.
    Any source may be missing or None. An empty set means rely on max_new_tokens.
    """
    candidates = (
        getattr(getattr(model, "generation_config", None), "eos_token_id", None),
        getattr(getattr(model, "config", None), "eos_token_id", None),
        getattr(getattr(processor, "tokenizer", None), "eos_token_id", None),
    )
    eos_ids: set[int] = set()
    for value in candidates:
        if value is None:
            continue
        if isinstance(value, int):
            eos_ids.add(value)
        else:
            eos_ids.update(int(v) for v in value)
    return eos_ids


def _apply_repetition_penalty(
    logits: torch.Tensor, token_ids: torch.Tensor, penalty: float
) -> torch.Tensor:
    """Penalise every token already in ``token_ids`` once, as HF's processor does.

    Negative logits are multiplied and non-negative ones divided by ``penalty``.
    Using gather/scatter means a token repeated many times (e.g. image placeholder
    tokens) is penalised once, not once per occurrence, and avoids a Python loop
    over the whole sequence. Returns a new tensor; ``logits`` is not modified.
    """
    if penalty == 1.0:
        return logits
    scores = torch.gather(logits, 1, token_ids)
    scores = torch.where(scores < 0, scores * penalty, scores / penalty)
    return logits.scatter(1, token_ids, scores)


@torch.no_grad()  # decorator form re-enters no_grad on every resume of the generator
def generate_with_logprobs(
    image: Image.Image,
    prompt: Optional[str] = None,
    max_new_tokens: int = 1024,
    top_k: int = 20,
    top_p: float = 0.9,
    temperature: float = 1.0,  # Use 1.0 for standard distribution, pick top token (argmax)
    repetition_penalty: float = 1.1,
    model_name: Optional[str] = None,
) -> Generator[TokenData, None, None]:
    """
    Generate OCR text token-by-token with probability information.

    Decoding is greedy: at every step the most likely token (after repetition
    penalty) is chosen. Yields TokenData for each generated token, enabling
    streaming display with confidence visualization. The EOS token itself is not
    yielded.

    Args:
        image: PIL Image to process
        prompt: Optional custom prompt (default: natural reading extraction)
        max_new_tokens: Maximum tokens to generate
        top_k: Number of candidates to keep per step: the chosen token plus
            ``top_k - 1`` alternatives (values below 1 are treated as 1)
        top_p: Accepted for API compatibility; ignored, because decoding is greedy
        temperature: Divides the logits before softmax when > 0. It rescales the
            reported probabilities/entropy but never changes the argmax choice;
            values <= 0 leave the logits unscaled
        repetition_penalty: Penalty for repeating tokens (>1.0 reduces repetition);
            applied once to each distinct token in the prompt and output so far
        model_name: Which model to use (from AVAILABLE_MODELS keys)

    Yields:
        TokenData with token string, probability, top-(k-1) alternatives, and the
        entropy of the top-k probabilities (a truncated sum, so it never exceeds the
        full-vocabulary entropy)
    """
    model, processor = load_model(model_name)
    inputs = _prepare_inputs(processor, image, prompt)

    input_ids = inputs["input_ids"]
    attention_mask = inputs.get("attention_mask")
    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids)

    # Remaining processor outputs (e.g. pixel_values, image_grid_thw for Qwen2.5-VL)
    # are only needed on the first (prefill) forward pass.
    image_inputs = {
        k: v for k, v in inputs.items() if k not in ("input_ids", "attention_mask")
    }
    eos_token_ids = _resolve_eos_token_ids(model, processor)

    # Track prompt + generated tokens (also the repetition-penalty history).
    generated_ids = input_ids.clone()
    past_key_values = DynamicCache()
    seq_length = input_ids.shape[1]
    # Multimodal RoPE offset reported by the prefill pass (Qwen2.5-VL); None for
    # models that do not produce it.
    rope_deltas = None

    for step in range(max_new_tokens):
        if step == 0:
            # First step: include image data, full sequence
            outputs = model(
                input_ids=generated_ids,
                attention_mask=attention_mask,
                cache_position=torch.arange(seq_length, device=_device),
                past_key_values=past_key_values,
                **image_inputs,
                return_dict=True,
                use_cache=True,
            )
            rope_deltas = getattr(outputs, "rope_deltas", None)
        else:
            # Subsequent steps: only the newest token, everything else is in the cache.
            # rope_deltas is forwarded only when the model produced it, so models
            # without multimodal RoPE never receive an unexpected keyword.
            rope_kwargs = {} if rope_deltas is None else {"rope_deltas": rope_deltas}
            outputs = model(
                input_ids=generated_ids[:, -1:],
                attention_mask=attention_mask,
                cache_position=torch.tensor([seq_length], device=_device),
                past_key_values=past_key_values,
                return_dict=True,
                use_cache=True,
                **rope_kwargs,
            )
        past_key_values = outputs.past_key_values

        # Last-position logits in float32 to avoid fp16 overflow in softmax.
        next_token_logits = outputs.logits[:, -1, :].float()
        next_token_logits = _apply_repetition_penalty(
            next_token_logits, generated_ids, repetition_penalty
        )
        if temperature > 0:
            next_token_logits = next_token_logits / temperature
        probs = torch.softmax(next_token_logits, dim=-1)

        k = max(1, min(top_k, probs.shape[-1]))
        top_probs_t, top_indices_t = torch.topk(probs, k=k)
        top_probs = top_probs_t[0].cpu().tolist()
        top_indices = top_indices_t[0].cpu().tolist()

        # Greedy choice: the highest-probability token.
        next_token_id = top_indices[0]
        if next_token_id in eos_token_ids:
            break

        # NOTE(review): tokens are decoded one at a time; with byte-level tokenizers a
        # token holding part of a multi-byte character may decode to U+FFFD alone.
        yield TokenData(
            token=processor.decode([next_token_id], skip_special_tokens=False),
            probability=top_probs[0],
            alternatives=[
                {
                    "token": processor.decode([alt_id], skip_special_tokens=False),
                    "probability": alt_prob,
                }
                for alt_id, alt_prob in zip(top_indices[1:], top_probs[1:])
            ],
            entropy=calculate_entropy(top_probs),
        )

        # Update for next iteration
        next_token_tensor = torch.tensor(
            [[next_token_id]], device=_device, dtype=generated_ids.dtype
        )
        generated_ids = torch.cat([generated_ids, next_token_tensor], dim=-1)
        # Extend attention mask to cover full sequence (required for Qwen VL models)
        attention_mask = torch.cat(
            [attention_mask, torch.ones((1, 1), device=_device, dtype=attention_mask.dtype)],
            dim=-1,
        )
        seq_length += 1