"""
FIXED PixelText OCR Model with proper Hugging Face Hub support.

This version has the from_pretrained method and works with AutoModel.from_pretrained().
"""

import torch
import torch.nn as nn
from transformers import (
    PaliGemmaForConditionalGeneration,
    PaliGemmaProcessor,
    AutoTokenizer,
    PreTrainedModel,
    PretrainedConfig
)
from PIL import Image
import warnings

warnings.filterwarnings("ignore")


class PixelTextConfig(PretrainedConfig):
    """Configuration for the PixelText model."""

    model_type = "pixeltext"

    def __init__(
        self,
        base_model="google/paligemma-3b-pt-224",
        hidden_size=2048,
        vocab_size=257216,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.base_model = base_model
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size


class FixedPixelTextOCR(PreTrainedModel):
    """
    FIXED PixelText OCR model with proper Hugging Face Hub support.
    This version works with AutoModel.from_pretrained().
    """

    config_class = PixelTextConfig
    # The PaliGemma backbone is stored under this prefix. A dedicated attribute name is
    # used because PreTrainedModel already exposes a read-only `base_model` property,
    # which would shadow a submodule assigned directly to `self.base_model`.
    base_model_prefix = "paligemma"

    def __init__(self, config=None):
        if config is None:
            config = PixelTextConfig()

        super().__init__(config)

        print("🚀 Loading FIXED PixelText OCR...")

        # Pick device and dtype: fp16 on GPU, fp32 on CPU.
        if torch.cuda.is_available():
            self._device = "cuda"
            self.torch_dtype = torch.float16
        else:
            self._device = "cpu"
            self.torch_dtype = torch.float32

        print(f"🔧 Device: {self._device}")

        try:
            # Load the PaliGemma backbone plus its processor and tokenizer.
            self.paligemma = PaliGemmaForConditionalGeneration.from_pretrained(
                config.base_model,
                torch_dtype=self.torch_dtype,
                trust_remote_code=True
            ).to(self._device)

            self.processor = PaliGemmaProcessor.from_pretrained(config.base_model)
            self.tokenizer = AutoTokenizer.from_pretrained(config.base_model)

            print("✅ FIXED PixelText OCR ready!")

        except Exception as e:
            print(f"❌ Failed to load components: {e}")
            raise

        self.hidden_size = config.hidden_size
        self.vocab_size = config.vocab_size

    def forward(self, **kwargs):
        """Forward pass through the PaliGemma backbone."""
        return self.paligemma(**kwargs)

    def generate_ocr_text(self, image, prompt="<image>Extract all text from this image:", max_length=512):
        """
        🎯 MAIN METHOD: Extract text from an image.

        Args:
            image: PIL Image, file path, or numpy array
            prompt: Custom prompt (optional)
            max_length: Maximum length (in tokens) of the generated sequence

        Returns:
            dict: Contains extracted text, confidence, and metadata

        See the usage sketch at the bottom of this file for an end-to-end example.
        """
        # Normalize the input into an RGB PIL image.
        if isinstance(image, str):
            image = Image.open(image).convert('RGB')
        elif hasattr(image, 'shape'):
            image = Image.fromarray(image).convert('RGB')
        elif not isinstance(image, Image.Image):
            raise ValueError("Image must be a PIL Image, file path, or numpy array")

        # PaliGemma prompts must contain the <image> placeholder token.
        if "<image>" not in prompt:
            prompt = f"<image>{prompt}"

        try:
            inputs = self.processor(text=prompt, images=image, return_tensors="pt")

            # Move all tensors to the model's device.
            for key in inputs:
                if isinstance(inputs[key], torch.Tensor):
                    inputs[key] = inputs[key].to(self._device)

            # Greedy decoding, no sampling.
            with torch.no_grad():
                generated_ids = self.paligemma.generate(
                    **inputs,
                    max_length=max_length,
                    do_sample=False,
                    num_beams=1,
                    pad_token_id=self.tokenizer.eos_token_id
                )

            generated_text = self.processor.batch_decode(
                generated_ids,
                skip_special_tokens=True
            )[0]

            # Strip the echoed prompt and common lead-in phrases.
            text = self._clean_text(generated_text, prompt)

            # Heuristic confidence score for the extracted text.
            confidence = self._calculate_confidence(text)

            return {
                'text': text,
                'confidence': confidence,
                'success': True,
                'method': 'fixed_pixeltext',
                'raw_output': generated_text
            }

        except Exception as e:
            return {
                'text': "",
                'confidence': 0.0,
                'success': False,
                'method': 'error',
                'error': str(e)
            }

    def _clean_text(self, generated_text, prompt):
        """Clean the generated text."""
        # Remove the echoed prompt if it appears in the output.
        clean_prompt = prompt.replace("<image>", "").strip()
        if clean_prompt and clean_prompt in generated_text:
            text = generated_text.replace(clean_prompt, "").strip()
        else:
            text = generated_text.strip()

        # Strip common lead-in phrases the model tends to prepend.
        artifacts = [
            "The image shows", "The text in the image says",
            "The image contains", "I can see", "The text reads",
            "This image shows", "The picture shows"
        ]

        for artifact in artifacts:
            if text.lower().startswith(artifact.lower()):
                text = text[len(artifact):].strip()
                if text.startswith(":"):
                    text = text[1:].strip()

        # Drop a single pair of surrounding quotes.
        if text.startswith('"') and text.endswith('"'):
            text = text[1:-1].strip()

        return text

    def _calculate_confidence(self, text):
        """Calculate a heuristic confidence score."""
        if not text:
            return 0.0

        confidence = 0.5

        # Reward longer extractions.
        if len(text) > 10:
            confidence += 0.2
        if len(text) > 50:
            confidence += 0.1
        if len(text) > 100:
            confidence += 0.1

        # Reward the presence of letters and digits.
        if any(c.isalpha() for c in text):
            confidence += 0.1
        if any(c.isdigit() for c in text):
            confidence += 0.05

        # Penalize very short extractions.
        if len(text.strip()) < 3:
            confidence *= 0.5

        return min(0.95, confidence)

    def batch_ocr(self, images, prompt="<image>Extract all text from this image:", max_length=512):
        """Process multiple images and return one result dict per image."""
        results = []

        for i, image in enumerate(images):
            print(f"📄 Processing image {i+1}/{len(images)}...")
            result = self.generate_ocr_text(image, prompt, max_length)
            results.append(result)

            if result['success']:
                print(f"   ✅ Success: {len(result['text'])} characters")
            else:
                print(f"   ❌ Failed: {result.get('error', 'Unknown error')}")

        return results

    def get_model_info(self):
        """Get model information."""
        return {
            'model_name': 'FIXED PixelText OCR',
            'base_model': 'PaliGemma-3B',
            'device': self._device,
            'dtype': str(self.torch_dtype),
            'hidden_size': self.hidden_size,
            'vocab_size': self.vocab_size,
            'parameters': '~3B',
            'repository': 'BabaK07/pixeltext-ai',
            'status': 'FIXED - Hub loading works!',
            'features': [
                'Hub loading support',
                'from_pretrained method',
                'Fast OCR extraction',
                'Multi-language support',
                'Batch processing',
                'Production ready'
            ]
        }


# Backwards-compatibility alias.
WorkingQwenOCRModel = FixedPixelTextOCR
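

# --- Optional Auto* registration (illustrative sketch, not part of the original file) ---
# The module docstring advertises AutoModel.from_pretrained() support. One way to make
# the Auto classes resolve the custom "pixeltext" model type in the current process is
# to register the config/model pair below. This assumes the published checkpoint's
# config.json carries "model_type": "pixeltext"; skip it if the repository relies on
# `auto_map` / `trust_remote_code` instead.
try:
    from transformers import AutoConfig, AutoModel

    AutoConfig.register("pixeltext", PixelTextConfig)
    AutoModel.register(PixelTextConfig, FixedPixelTextOCR)
except ValueError:
    # Already registered in this process; nothing to do.
    pass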
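

# --- Usage sketch (illustrative, not part of the original file) ---
# Minimal smoke test: build the model directly (this downloads the PaliGemma backbone)
# and run OCR on a local image. "sample.png" is a placeholder path, not a file that
# ships with this repository.
if __name__ == "__main__":
    ocr = FixedPixelTextOCR()
    print(ocr.get_model_info())

    result = ocr.generate_ocr_text("sample.png")
    if result["success"]:
        print(f"Extracted text (confidence {result['confidence']:.2f}):")
        print(result["text"])
    else:
        print(f"OCR failed: {result.get('error', 'Unknown error')}")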