Spaces:
Sleeping
Sleeping
| import os | |
| from transformers import AutoModel, AutoTokenizer, Qwen2VLForConditionalGeneration, AutoProcessor, MllamaForConditionalGeneration | |
| import torch | |
| import re | |
| from PIL import Image | |
| # ---- GOT OCR Model Initialization and Extraction ---- | |
def init_got_model():
    """Load the GOT OCR checkpoint together with its tokenizer.

    Both objects are fetched from the ``srimanth-d/GOT_CPU`` repository with
    ``trust_remote_code=True``. The tokenizer is loaded first because its
    EOS token id is passed to the model as the padding token id.

    Returns:
        tuple: ``(model, tokenizer)`` where the model has been switched to
        evaluation mode via ``eval()``.
    """
    repo_id = "srimanth-d/GOT_CPU"

    got_tokenizer = AutoTokenizer.from_pretrained(
        repo_id,
        trust_remote_code=True,
        return_tensors='pt',
    )
    got_model = AutoModel.from_pretrained(
        repo_id,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        use_safetensors=True,
        pad_token_id=got_tokenizer.eos_token_id,
    )

    eval_model = got_model.eval()
    return eval_model, got_tokenizer
def extract_text_got(uploaded_file):
    """Extract text from an uploaded image using the GOT OCR model.

    The upload is written to a uniquely named temporary ``.jpg`` file because
    the model's ``chat`` method is called with a file path. The temporary file
    is always removed afterwards.

    Args:
        uploaded_file: Binary file-like object exposing ``read()``.

    Returns:
        str: The stripped OCR text, ``"No text extracted."`` when the model
        produced nothing usable, or ``"Error: <message>"`` if any step raised.
    """
    # Local import so this block does not depend on the file's import header.
    import tempfile

    temp_file_path = None
    try:
        # A unique path per call avoids concurrent requests overwriting (or
        # deleting) each other's image, which a fixed filename would allow.
        fd, temp_file_path = tempfile.mkstemp(prefix='temp_image_got_', suffix='.jpg')
        with os.fdopen(fd, 'wb') as temp_file:
            temp_file.write(uploaded_file.read())
        print(f"Processing image using GOT from: {temp_file_path}")

        model, tokenizer = init_got_model()
        outputs = model.chat(tokenizer, temp_file_path, ocr_type='ocr')

        # The GOT model card shows ``chat`` returning a plain string; a
        # list/tuple result is still accepted by taking its first element.
        if isinstance(outputs, (list, tuple)):
            outputs = outputs[0] if outputs else ""
        text = outputs.strip() if isinstance(outputs, str) else ""
        return text or "No text extracted."
    except Exception as e:
        # Errors are reported as a string so callers always receive text.
        return f"Error: {str(e)}"
    finally:
        if temp_file_path is not None and os.path.exists(temp_file_path):
            os.remove(temp_file_path)
| # ---- Qwen OCR Model Initialization and Extraction ---- | |
def init_qwen_model():
    """Load the Qwen2-VL 2B Instruct checkpoint and its processor on CPU.

    The model is requested with ``device_map="cpu"`` and
    ``torch_dtype=torch.float16``.

    Returns:
        tuple: ``(model, processor)`` where the model has been switched to
        evaluation mode via ``eval()``.
    """
    checkpoint = "Qwen/Qwen2-VL-2B-Instruct"

    # NOTE(review): float16 on CPU may be slow or unsupported for some ops
    # depending on the installed torch build — confirm this is intended.
    qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
        checkpoint,
        device_map="cpu",
        torch_dtype=torch.float16,
    )
    qwen_processor = AutoProcessor.from_pretrained(checkpoint)

    return qwen_model.eval(), qwen_processor
def extract_text_qwen(uploaded_file, max_new_tokens=512):
    """Extract text from an uploaded image using the Qwen2-VL model.

    Args:
        uploaded_file: File path or binary file-like object accepted by
            ``PIL.Image.open``.
        max_new_tokens: Upper bound on the number of tokens generated for the
            answer. Without an explicit bound, generation falls back to the
            library's default length limit, which is too short for OCR output.

    Returns:
        str: The stripped model answer, ``"No text extracted."`` when the
        answer is empty, or ``"Error: <message>"`` if any step raised.
    """
    try:
        model, processor = init_qwen_model()
        image = Image.open(uploaded_file).convert('RGB')

        conversation = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": "Extract text from this image."},
                ],
            }
        ]
        prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
        inputs = processor(text=[prompt], images=[image], return_tensors="pt")

        output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)

        # ``generate`` returns prompt + continuation; drop the prompt tokens so
        # the chat template text is not echoed back as "extracted" text.
        prompt_length = inputs["input_ids"].shape[1]
        generated_ids = output_ids[:, prompt_length:]

        output_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
        text = output_text[0].strip() if output_text else ""
        return text or "No text extracted."
    except Exception as e:
        # Errors are reported as a string so callers always receive text.
        return f"Error: {str(e)}"
| # ---- LLaMA OCR Model Initialization and Extraction ---- | |
def init_llama_model():
    """Load the Llama 3.2 11B Vision Instruct checkpoint and its processor.

    The model is requested with ``device_map="cpu"`` and
    ``torch_dtype=torch.bfloat16``.

    Returns:
        tuple: ``(model, processor)`` where the model has been switched to
        evaluation mode via ``eval()``.
    """
    checkpoint = "meta-llama/Llama-3.2-11B-Vision-Instruct"

    llama_model = MllamaForConditionalGeneration.from_pretrained(
        checkpoint,
        torch_dtype=torch.bfloat16,
        device_map="cpu",
    )
    llama_processor = AutoProcessor.from_pretrained(checkpoint)

    return llama_model.eval(), llama_processor
def extract_text_llama(uploaded_file, max_new_tokens=512):
    """Extract text from an uploaded image using the Llama 3.2 Vision model.

    Args:
        uploaded_file: File path or binary file-like object accepted by
            ``PIL.Image.open``.
        max_new_tokens: Upper bound on the number of tokens generated for the
            answer. Without an explicit bound, generation falls back to the
            library's default length limit, which is too short for OCR output.

    Returns:
        str: The stripped model answer, ``"No text extracted."`` when the
        answer is empty, or ``"Error: <message>"`` if any step raised.
    """
    try:
        model, processor = init_llama_model()
        image = Image.open(uploaded_file).convert('RGB')
        prompt = "You are an OCR engine. Extract text from this image."

        # Build the prompt through the chat template so it contains the image
        # placeholder; the processor needs one placeholder per supplied image.
        conversation = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": prompt},
                ],
            }
        ]
        chat_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)

        # The template already emits the special tokens, so do not add them again.
        inputs = processor(
            images=image,
            text=chat_prompt,
            add_special_tokens=False,
            return_tensors="pt",
        )

        output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)

        # ``generate`` returns prompt + continuation; decode only the new tokens
        # so the instruction text is not echoed back as "extracted" text.
        prompt_length = inputs["input_ids"].shape[1]
        text = processor.decode(output_ids[0][prompt_length:], skip_special_tokens=True).strip()
        return text or "No text extracted."
    except Exception as e:
        # Errors are reported as a string so callers always receive text.
        return f"Error: {str(e)}"
| # ---- Regex-based Text Cleanup ---- | |
def clean_extracted_text(text):
    """Normalize whitespace in OCR output.

    Every run of whitespace (spaces, tabs, newlines) is collapsed to a single
    space, leading and trailing whitespace is trimmed, and a single whitespace
    character directly before ``?``, ``.``, ``!`` or ``,`` is removed.

    Args:
        text: The raw extracted string.

    Returns:
        str: The normalized string.
    """
    # Step 1: squeeze whitespace runs, then trim both ends.
    single_spaced = re.sub(r'\s+', ' ', text)
    trimmed = single_spaced.strip()

    # Step 2: pull punctuation back against the preceding word.
    return re.sub(r'\s([?.!,])', r'\1', trimmed)