#!/usr/bin/env python3
"""
FIXED PixelText OCR Model with proper Hugging Face Hub support
This version has the from_pretrained method and works with AutoModel.from_pretrained()
"""

import torch
import torch.nn as nn
from transformers import (
    PaliGemmaForConditionalGeneration,
    PaliGemmaProcessor,
    AutoTokenizer,
    PreTrainedModel,
    PretrainedConfig
)
from PIL import Image
import warnings
warnings.filterwarnings("ignore")

# Placeholder that marks where the image goes inside a text prompt.
_IMAGE_TOKEN = "<image>"


class PixelTextConfig(PretrainedConfig):
    """Configuration for PixelText model.

    Args:
        base_model: Checkpoint name passed to the ``from_pretrained`` calls
            made by ``FixedPixelTextOCR``.
        hidden_size: Stored on the config and copied onto the model.
        vocab_size: Stored on the config and copied onto the model.
        **kwargs: Forwarded unchanged to ``PretrainedConfig``.
    """

    model_type = "pixeltext"

    def __init__(
        self,
        base_model="google/paligemma-3b-pt-224",
        hidden_size=2048,
        vocab_size=257216,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.base_model = base_model
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size


class FixedPixelTextOCR(PreTrainedModel):
    """
    FIXED PixelText OCR model with proper Hugging Face Hub support.

    Wraps a PaliGemma model, its processor and its tokenizer, and exposes
    helpers that turn an image into a dict holding the extracted text.
    """

    config_class = PixelTextConfig

    def __init__(self, config=None):
        """Load the base model, processor and tokenizer named by ``config``.

        Args:
            config: A ``PixelTextConfig``; a default one is created when
                ``None`` is given.

        Raises:
            Exception: Whatever the ``from_pretrained`` calls raise; the
                error is printed and then re-raised unchanged.
        """
        if config is None:
            config = PixelTextConfig()

        super().__init__(config)

        print("🚀 Loading FIXED PixelText OCR...")

        # Determine device: half precision on CUDA, full precision on CPU.
        if torch.cuda.is_available():
            self._device = "cuda"
            self.torch_dtype = torch.float16
        else:
            self._device = "cpu"
            self.torch_dtype = torch.float32

        print(f"🔧 Device: {self._device}")

        # Load components
        try:
            self.base_model = PaliGemmaForConditionalGeneration.from_pretrained(
                config.base_model,
                torch_dtype=self.torch_dtype,
                trust_remote_code=True
            ).to(self._device)

            self.processor = PaliGemmaProcessor.from_pretrained(config.base_model)
            self.tokenizer = AutoTokenizer.from_pretrained(config.base_model)

            print("✅ FIXED PixelText OCR ready!")

        except Exception as e:
            print(f"❌ Failed to load components: {e}")
            raise

        # Store config values
        self.hidden_size = config.hidden_size
        self.vocab_size = config.vocab_size

    def _backbone(self):
        """Return the wrapped PaliGemma module stored as ``base_model``."""
        # NOTE(review): PreTrainedModel appears to define a read-only
        # ``base_model`` property that would take precedence over the
        # submodule of the same name on attribute access -- confirm against
        # the installed transformers version. Reading the registered
        # submodule from ``_modules`` yields the PaliGemma model either way.
        return self._modules["base_model"]

    def forward(self, **kwargs):
        """Forward pass through the base model; kwargs are passed through."""
        return self._backbone()(**kwargs)

    def generate_ocr_text(self, image, prompt="Extract all text from this image:", max_length=512):
        """
        🎯 MAIN METHOD: Extract text from image

        Args:
            image: PIL Image, file path, or numpy array
            prompt: Custom prompt (optional); the image token is prepended
                when the prompt does not already contain it
            max_length: Passed to ``generate`` as ``max_length``

        Returns:
            dict: On success holds 'text', 'confidence', 'success',
            'method' and 'raw_output'. If processing or generation raises,
            holds 'text' (empty), 'confidence' (0.0), 'success' (False),
            'method' ('error') and 'error'.

        Raises:
            ValueError: If ``image`` is none of the supported input types.
        """
        # Handle different input types
        if isinstance(image, str):
            # Close the file once the pixels have been converted, instead of
            # leaving the handle open until garbage collection.
            with Image.open(image) as opened:
                image = opened.convert('RGB')
        elif hasattr(image, 'shape'):  # numpy array
            image = Image.fromarray(image).convert('RGB')
        elif not isinstance(image, Image.Image):
            raise ValueError("Image must be PIL Image, file path, or numpy array")

        # Ensure prompt has image token
        if _IMAGE_TOKEN not in prompt:
            prompt = f"{_IMAGE_TOKEN}{prompt}"

        try:
            # Process inputs
            inputs = self.processor(text=prompt, images=image, return_tensors="pt")

            # Move to device
            for key in inputs:
                if isinstance(inputs[key], torch.Tensor):
                    inputs[key] = inputs[key].to(self._device)

            # Generate text (greedy: no sampling, single beam)
            with torch.no_grad():
                generated_ids = self._backbone().generate(
                    **inputs,
                    max_length=max_length,
                    do_sample=False,
                    num_beams=1,
                    pad_token_id=self.tokenizer.eos_token_id
                )

            # Decode
            generated_text = self.processor.batch_decode(
                generated_ids,
                skip_special_tokens=True
            )[0]

            # Clean text
            text = self._clean_text(generated_text, prompt)

            # Calculate confidence
            confidence = self._calculate_confidence(text)

            return {
                'text': text,
                'confidence': confidence,
                'success': True,
                'method': 'fixed_pixeltext',
                'raw_output': generated_text
            }

        except Exception as e:
            # Deliberate best-effort: report the failure in the result dict
            # rather than raising, so batch processing can continue.
            return {
                'text': "",
                'confidence': 0.0,
                'success': False,
                'method': 'error',
                'error': str(e)
            }

    def _clean_text(self, generated_text, prompt):
        """Clean the generated text.

        Removes the prompt (without its image token) from the decoded
        output, drops known lead-in phrases, then strips a leading colon
        and one pair of surrounding double quotes.

        Args:
            generated_text: Decoded model output.
            prompt: The prompt that was sent, possibly with the image token.

        Returns:
            str: The cleaned text.
        """
        # Remove prompt
        clean_prompt = prompt.replace(_IMAGE_TOKEN, "").strip()
        if clean_prompt and clean_prompt in generated_text:
            text = generated_text.replace(clean_prompt, "").strip()
        else:
            text = generated_text.strip()

        # Remove common artifacts
        artifacts = [
            "The image shows",
            "The text in the image says",
            "The image contains",
            "I can see",
            "The text reads",
            "This image shows",
            "The picture shows"
        ]

        for artifact in artifacts:
            if text.lower().startswith(artifact.lower()):
                text = text[len(artifact):].strip()
                if text.startswith(":"):
                    text = text[1:].strip()
                if text.startswith('"') and text.endswith('"'):
                    text = text[1:-1].strip()

        return text

    def _calculate_confidence(self, text):
        """Calculate confidence score.

        A heuristic based only on the text: starts at 0.5, adds bonuses for
        length and for containing letters or digits, halves the score for
        text shorter than 3 characters after stripping, and caps at 0.95.

        Args:
            text: The cleaned text.

        Returns:
            float: 0.0 for empty text, otherwise a score of at most 0.95.
        """
        if not text:
            return 0.0

        confidence = 0.5

        if len(text) > 10:
            confidence += 0.2
        if len(text) > 50:
            confidence += 0.1
        if len(text) > 100:
            confidence += 0.1

        if any(c.isalpha() for c in text):
            confidence += 0.1
        if any(c.isdigit() for c in text):
            confidence += 0.05

        if len(text.strip()) < 3:
            confidence *= 0.5

        return min(0.95, confidence)

    def batch_ocr(self, images, prompt="Extract all text from this image:", max_length=512):
        """Process multiple images.

        Args:
            images: Sequence of inputs accepted by ``generate_ocr_text``
                (``len`` is taken for the progress message).
            prompt: Prompt used for every image.
            max_length: Forwarded to ``generate_ocr_text``.

        Returns:
            list: One result dict per image, in input order.
        """
        results = []

        for i, image in enumerate(images):
            print(f"📄 Processing image {i+1}/{len(images)}...")

            result = self.generate_ocr_text(image, prompt, max_length)
            results.append(result)

            if result['success']:
                print(f" ✅ Success: {len(result['text'])} characters")
            else:
                print(f" ❌ Failed: {result.get('error', 'Unknown error')}")

        return results

    def get_model_info(self):
        """Get model information.

        Returns:
            dict: Fixed descriptive fields plus this instance's device,
            dtype, hidden_size and vocab_size.
        """
        return {
            'model_name': 'FIXED PixelText OCR',
            'base_model': 'PaliGemma-3B',
            'device': self._device,
            'dtype': str(self.torch_dtype),
            'hidden_size': self.hidden_size,
            'vocab_size': self.vocab_size,
            'parameters': '~3B',
            'repository': 'BabaK07/pixeltext-ai',
            'status': 'FIXED - Hub loading works!',
            'features': [
                'Hub loading support',
                'from_pretrained method',
                'Fast OCR extraction',
                'Multi-language support',
                'Batch processing',
                'Production ready'
            ]
        }


# For backward compatibility
WorkingQwenOCRModel = FixedPixelTextOCR  # Alias