# Hugging Face Spaces banner (page-scrape residue): "Running on Zero" (ZeroGPU hardware).
"""
Model loading and inference for OCR Confidence Visualization.

Loads one of several vision-language OCR models (Nanonets-OCR2-3B,
olmOCR-7B, or the default Aya-Vision-8B) and provides inference with
token-level probability extraction.
"""
import math
from dataclasses import dataclass, field
from typing import Generator, Optional

import torch
from PIL import Image
from transformers import AutoModelForImageTextToText, AutoProcessor

# Friendly display name -> Hugging Face Hub model id for the selectable models.
AVAILABLE_MODELS = {
    "Nanonets-OCR2-3B": "nanonets/Nanonets-OCR2-3B",
    "olmOCR-7B": "allenai/olmOCR-7B-0725",
    "Aya-Vision-8B": "CohereLabs/aya-vision-8b",
}
# Key into AVAILABLE_MODELS used when no model name is supplied.
DEFAULT_MODEL = "Aya-Vision-8B"

# Module-level singletons managed by load_model(): one model/processor pair
# is kept resident at a time and swapped out when a different name is requested.
_model = None
_processor = None
_device = None
_current_model_name = None
@dataclass
class TokenData:
    """Data for a single generated token with probability info.

    Note: the ``@dataclass`` decorator is required here — without it the
    annotated names are not instance fields, ``field(default=0.0)`` leaves a
    bare ``Field`` object as a class attribute, and the keyword construction
    used by ``generate_with_logprobs`` (``TokenData(token=..., ...)``) fails.
    """

    # Decoded token text (special tokens not stripped).
    token: str
    # Probability of the chosen token under the (temperature-scaled) softmax.
    probability: float
    # Top-k runner-ups, excluding the chosen token:
    # [{"token": str, "probability": float}, ...]
    alternatives: list[dict[str, float]]
    # Shannon entropy in bits, computed over the truncated top-k distribution.
    entropy: float = field(default=0.0)
def calculate_entropy(probs: list[float]) -> float:
    """Calculate Shannon entropy in bits from a probability distribution.

    Args:
        probs: List of probabilities (should sum to ~1.0).

    Returns:
        Entropy in bits. 0.0 for empty or single-certainty distributions.
    """
    # Sum -p*log2(p) over the strictly positive entries; zero-probability
    # entries contribute nothing (lim p->0 of p*log2(p) is 0).
    return -sum((p * math.log2(p) for p in probs if p > 0), 0.0)
def load_model(model_name: Optional[str] = None):
    """Load the OCR model and processor, reusing the cached pair when possible.

    Args:
        model_name: Key into AVAILABLE_MODELS. None or an unknown name
            falls back to DEFAULT_MODEL.

    Returns:
        Tuple of (model, processor) for the requested model.
    """
    global _model, _processor, _device, _current_model_name

    # Normalize the name BEFORE the cache check so the cache key always
    # matches the weights actually loaded. (Previously an unknown name was
    # recorded as _current_model_name while the default weights were loaded,
    # forcing a pointless reload on a later request for the default model.)
    if model_name not in AVAILABLE_MODELS:
        model_name = DEFAULT_MODEL
    model_id = AVAILABLE_MODELS[model_name]

    # Serve from cache if the requested model is already resident.
    if _model is not None and _current_model_name == model_name:
        return _model, _processor

    # Switching models: free the old one first so we never hold two
    # multi-GB models in GPU memory at once.
    if _model is not None:
        print(f"Unloading previous model: {_current_model_name}")
        del _model
        del _processor
        _model = None
        _processor = None
        torch.cuda.empty_cache()

    _device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {_device}")
    print(f"Loading model: {model_id}...")
    _processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
    _model = AutoModelForImageTextToText.from_pretrained(
        model_id,
        attn_implementation="flash_attention_2",
        trust_remote_code=True,
        torch_dtype=torch.float16,
    ).to(_device).eval()
    _current_model_name = model_name
    print("Model loaded successfully")
    return _model, _processor
def run_ocr(image: Image.Image, prompt: Optional[str] = None) -> str:
    """
    Run OCR on an image and return extracted text.

    Args:
        image: PIL Image to process
        prompt: Optional custom prompt (default: natural reading extraction)

    Returns:
        Extracted text from the image
    """
    model, processor = load_model()
    if prompt is None:
        prompt = "Extract the text from the above document as if you were reading it naturally."

    # Chat-style message: image placeholder first, then the instruction.
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": prompt},
            ],
        }
    ]
    prompt_full = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = processor(
        text=[prompt_full],
        images=[image],
        return_tensors="pt",
        padding=True,
    ).to(_device)

    with torch.no_grad():
        # Greedy decoding: OCR transcription should be deterministic, and
        # this matches the argmax token selection used by
        # generate_with_logprobs. (Previously do_sample=True with
        # temperature/top-p/top-k made repeated runs on the same image
        # disagree.)
        output_ids = model.generate(
            **inputs,
            max_new_tokens=1024,
            do_sample=False,
            repetition_penalty=1.1,
        )

    # Slice off the prompt tokens; keep only the newly generated ids.
    generated_ids = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, output_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )[0]
    return output_text
def generate_with_logprobs(
    image: Image.Image,
    prompt: Optional[str] = None,
    max_new_tokens: int = 1024,
    top_k: int = 20,
    top_p: float = 0.9,
    temperature: float = 1.0,  # 1.0 keeps the raw distribution; the top token is picked (argmax)
    repetition_penalty: float = 1.1,
    model_name: Optional[str] = None,
) -> Generator[TokenData, None, None]:
    """
    Generate OCR text token-by-token with probability information.

    Yields TokenData for each generated token, enabling streaming display
    with confidence visualization. Selection is always greedy (argmax);
    temperature only reshapes the reported probabilities.

    Args:
        image: PIL Image to process
        prompt: Optional custom prompt (default: natural reading extraction)
        max_new_tokens: Maximum tokens to generate
        top_k: Number of top alternatives to include
        top_p: Accepted for interface compatibility but currently UNUSED —
            no nucleus filtering is applied in this manual decode loop
        temperature: Scaling applied to logits before softmax
        repetition_penalty: Penalty for repeating tokens (>1.0 reduces repetition)
        model_name: Which model to use (from AVAILABLE_MODELS keys)

    Yields:
        TokenData with token string, probability, top-k alternatives and entropy
    """
    model, processor = load_model(model_name)
    if prompt is None:
        prompt = "Extract the text from the above document as if you were reading it naturally."

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": prompt},
            ],
        }
    ]
    prompt_full = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = processor(
        text=[prompt_full],
        images=[image],
        return_tensors="pt",
        padding=True,
    ).to(_device)
    input_ids = inputs.input_ids
    attention_mask = inputs.attention_mask

    # Resolve EOS id(s) for stopping: model config first, tokenizer fallback.
    eos_token_id = model.config.eos_token_id
    if eos_token_id is None:
        eos_token_id = processor.tokenizer.eos_token_id
    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]
    elif eos_token_id is None:
        eos_token_id = []  # No EOS token - rely on max_new_tokens

    # Full sequence so far (prompt + generated tokens).
    generated_ids = input_ids.clone()

    # Non-text model inputs (pixel_values, image_grid_thw for Qwen2.5-VL);
    # only needed on the first forward pass.
    model_inputs = {k: v for k, v in inputs.items() if k not in ("input_ids", "attention_mask")}

    # DynamicCache for proper KV cache management across manual steps.
    from transformers import DynamicCache
    past_key_values = DynamicCache()

    seq_length = input_ids.shape[1]

    # rope_deltas is produced by the first forward pass of Qwen2.5-VL-style
    # models and must be fed back on subsequent single-token passes so the
    # multimodal RoPE positions stay correct.
    rope_deltas = None

    with torch.no_grad():
        for step in range(max_new_tokens):
            if step == 0:
                # First pass: full prompt plus image features, priming the cache.
                cache_position = torch.arange(seq_length, device=_device)
                outputs = model(
                    input_ids=generated_ids,
                    attention_mask=attention_mask,
                    cache_position=cache_position,
                    past_key_values=past_key_values,
                    **model_inputs,
                    return_dict=True,
                    use_cache=True,
                )
            else:
                # Incremental pass: only the newest token, reusing the cache.
                cache_position = torch.tensor([seq_length], device=_device)
                outputs = model(
                    input_ids=generated_ids[:, -1:],
                    attention_mask=attention_mask,
                    cache_position=cache_position,
                    past_key_values=past_key_values,
                    rope_deltas=rope_deltas,  # correct multimodal position encoding
                    return_dict=True,
                    use_cache=True,
                )
            past_key_values = outputs.past_key_values

            # Capture rope_deltas from the first pass.
            if step == 0 and hasattr(outputs, 'rope_deltas') and outputs.rope_deltas is not None:
                rope_deltas = outputs.rope_deltas

            # Last-position logits in float32 to avoid fp16 overflow in softmax.
            next_token_logits = outputs.logits[:, -1, :].float()

            # Repetition penalty with standard HF semantics: each previously
            # seen token id is penalized exactly once. Iterating the raw id
            # list (as before) compounds the penalty per occurrence, which
            # over-suppresses tokens that legitimately repeat in long documents.
            if repetition_penalty != 1.0:
                for prev_token_id in set(generated_ids[0].tolist()):
                    if next_token_logits[0, prev_token_id] < 0:
                        next_token_logits[0, prev_token_id] *= repetition_penalty
                    else:
                        next_token_logits[0, prev_token_id] /= repetition_penalty

            # Temperature scaling (guard against division by zero).
            if temperature > 0:
                next_token_logits = next_token_logits / temperature

            probs = torch.softmax(next_token_logits, dim=-1)

            # Top-k probabilities and their token ids.
            top_probs, top_indices = torch.topk(probs, k=min(top_k, probs.shape[-1]))
            top_probs = top_probs[0].cpu().tolist()
            top_indices = top_indices[0].cpu().tolist()

            # Greedy selection: the highest-probability token.
            next_token_id = top_indices[0]
            next_token_prob = top_probs[0]

            if next_token_id in eos_token_id:
                break

            token_str = processor.decode([next_token_id], skip_special_tokens=False)

            # Runner-up tokens (the selected token is excluded).
            alternatives = [
                {
                    "token": processor.decode([alt_id], skip_special_tokens=False),
                    "probability": alt_prob,
                }
                for alt_id, alt_prob in zip(top_indices[1:], top_probs[1:])
            ]

            # Entropy over the truncated top-k distribution; the tail mass
            # outside top-k is ignored, so this is an approximation.
            all_probs = [next_token_prob] + [alt["probability"] for alt in alternatives]
            token_entropy = calculate_entropy(all_probs)

            yield TokenData(
                token=token_str,
                probability=next_token_prob,
                alternatives=alternatives,
                entropy=token_entropy,
            )

            # Append the chosen token and extend the attention mask to cover
            # the full sequence (required for Qwen VL models).
            next_token_tensor = torch.tensor([[next_token_id]], device=_device)
            generated_ids = torch.cat([generated_ids, next_token_tensor], dim=-1)
            attention_mask = torch.cat(
                [attention_mask, torch.ones((1, 1), device=_device, dtype=attention_mask.dtype)],
                dim=-1,
            )
            seq_length += 1