Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| from transformers import AutoModel, AutoTokenizer, Qwen2VLForConditionalGeneration, AutoProcessor | |
| from surya.ocr import run_ocr | |
| from surya.model.detection.model import load_model as load_det_model, load_processor as load_det_processor | |
| from surya.model.recognition.model import load_model as load_rec_model | |
| from surya.model.recognition.processor import load_processor as load_rec_processor | |
| from PIL import Image | |
| import torch | |
| import tempfile | |
| import os | |
| import re | |
| from groq import Groq | |
# Page configuration
st.set_page_config(page_title="DualTextOCRFusion", page_icon="π", layout="wide")

device = "cuda" if torch.cuda.is_available() else "cpu"


@st.cache_resource
def _load_surya_models(target_device):
    """Load the Surya detection and recognition models and processors.

    Streamlit re-executes this script on every widget interaction; wrapping
    the load in ``st.cache_resource`` keeps one set of loaded models per
    server process instead of reloading them on each rerun.

    Args:
        target_device: Device string passed to ``.to()`` on both models.

    Returns:
        Tuple ``(det_processor, det_model, rec_model, rec_processor)``.
    """
    detection_processor = load_det_processor()
    detection_model = load_det_model()
    detection_model.to(target_device)
    recognition_model = load_rec_model()
    recognition_processor = load_rec_processor()
    recognition_model.to(target_device)
    return detection_processor, detection_model, recognition_model, recognition_processor


# Load Surya OCR Models (English + Hindi)
det_processor, det_model, rec_model, rec_processor = _load_surya_models(device)
# Load GOT Models
def init_got_model():
    """Load the GOT tokenizer and model from the 'srimanth-d/GOT_CPU' repository.

    Both are loaded with ``trust_remote_code=True``; the model's pad token id
    is set to the tokenizer's EOS token id.

    Returns:
        Tuple ``(model, tokenizer)`` where ``model`` is the result of
        ``model.eval()``.
    """
    repo_id = 'srimanth-d/GOT_CPU'
    tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
    got = AutoModel.from_pretrained(
        repo_id,
        trust_remote_code=True,
        use_safetensors=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    got = got.eval()
    return got, tokenizer
def init_got_gpu_model():
    """Load the GOT tokenizer and model from 'ucaslcl/GOT-OCR2_0' onto CUDA.

    The model is loaded with ``device_map='cuda'`` and then moved with
    ``.cuda()`` after being switched to eval mode, so a CUDA device is
    required.

    Returns:
        Tuple ``(model, tokenizer)``.
    """
    repo_id = 'ucaslcl/GOT-OCR2_0'
    tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
    got = AutoModel.from_pretrained(
        repo_id,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        device_map='cuda',
        use_safetensors=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    got = got.eval().cuda()
    return got, tokenizer
# Load Qwen Model
def init_qwen_model():
    """Load the 'Qwen/Qwen2-VL-2B-Instruct' model and its processor.

    The model is requested with ``device_map="cpu"`` and
    ``torch_dtype=torch.float16``.

    Returns:
        Tuple ``(model, processor)`` where ``model`` is the result of
        ``model.eval()``.
    """
    checkpoint = "Qwen/Qwen2-VL-2B-Instruct"
    qwen = Qwen2VLForConditionalGeneration.from_pretrained(
        checkpoint,
        device_map="cpu",
        torch_dtype=torch.float16,
    )
    processor = AutoProcessor.from_pretrained(checkpoint)
    qwen = qwen.eval()
    return qwen, processor
# Text Cleaning AI - Clean spaces, handle dual languages
def clean_extracted_text(text):
    """Normalise whitespace in OCR output.

    Every run of whitespace becomes a single space, leading and trailing
    whitespace is dropped, and a whitespace character directly before one of
    ``? . ! ,`` is removed.

    Args:
        text: Raw text string.

    Returns:
        The normalised string.
    """
    # Step 1: collapse whitespace runs, then trim the ends.
    collapsed = re.sub(r'\s+', ' ', text)
    trimmed = collapsed.strip()
    # Step 2: pull punctuation back against the preceding word.
    return re.sub(r'\s([?.!,])', r'\1', trimmed)
# Polish the text using a model
def polish_text_with_ai(cleaned_text):
    """Ask a Groq-hosted chat model to correct and tidy OCR text.

    Args:
        cleaned_text: Whitespace-normalised OCR output.

    Returns:
        The content of the first completion choice. Empty or whitespace-only
        input is returned unchanged without contacting the API, since there
        is nothing to correct.

    Raises:
        Any error raised by the Groq client (for example when no API key is
        configured or the request fails) propagates to the caller.
    """
    if not cleaned_text or not cleaned_text.strip():
        return cleaned_text

    prompt = f"Correct and clean the following text: '{cleaned_text}' and make it meaningful."
    # The API key is read from the GROQ_API_KEY environment variable so that
    # no credential lives in source control.
    # NOTE(review): the key that was previously hard-coded here should be
    # treated as compromised and rotated.
    client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
    chat_completion = client.chat.completions.create(
        messages=[
            {
                "role": "system",
                "content": "You are a meaningful sentence pedantic, you remove extra spaces in between words and word to make the sentence meaningful in English/Hindi/Hinglish according to the sentence."
            },
            {
                "role": "user",
                "content": prompt,
            }
        ],
        model="gemma2-9b-it",
    )
    polished_text = chat_completion.choices[0].message.content
    return polished_text
# Extract text using GOT
def extract_text_got(image_file, model, tokenizer):
    """Run OCR on an image by delegating to ``model.chat``.

    Args:
        image_file: Image reference forwarded unchanged to ``model.chat``.
        model: Object exposing ``chat(tokenizer, image_file, ocr_type=...)``.
        tokenizer: Tokenizer forwarded as the first argument of ``chat``.

    Returns:
        Whatever ``model.chat`` returns for ``ocr_type='ocr'``.
    """
    ocr_result = model.chat(tokenizer, image_file, ocr_type='ocr')
    return ocr_result
# Extract text using Qwen
def extract_text_qwen(image_file, model, processor, max_new_tokens=512):
    """Extract text from an image with a Qwen2-VL style model and processor.

    Args:
        image_file: Path or file object accepted by ``PIL.Image.open``.
        model: Model exposing ``generate(**inputs, max_new_tokens=...)``.
        processor: Processor exposing ``apply_chat_template``, ``__call__``
            and ``batch_decode``.
        max_new_tokens: Upper bound on generated tokens. Without an explicit
            bound the library's default generation length applies, which is
            too short for a page of text.

    Returns:
        The decoded model reply, ``"No text extracted from the image."`` when
        decoding yields no entries, or ``"An error occurred: ..."`` when any
        step raises (errors are reported as text so the UI can display them).
    """
    try:
        with Image.open(image_file) as source:
            image = source.convert('RGB')
        conversation = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "Extract text from this image."}]}]
        text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
        inputs = processor(text=[text_prompt], images=[image], return_tensors="pt")
        output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
        # generate() returns the prompt tokens followed by the continuation;
        # strip the prompt so only the model's answer is decoded.
        answer_ids = [
            full_ids[len(prompt_ids):]
            for prompt_ids, full_ids in zip(inputs.input_ids, output_ids)
        ]
        output_text = processor.batch_decode(answer_ids, skip_special_tokens=True)
        return output_text[0] if output_text else "No text extracted from the image."
    except Exception as e:
        # Deliberate best-effort: surface the failure as the displayed text.
        return f"An error occurred: {str(e)}"
# Highlight keyword search
def highlight_text(text, search_term):
    """Return HTML for *text* with each occurrence of *search_term* highlighted.

    Matching is case-insensitive and literal (the term is regex-escaped).
    Because callers render the result as raw HTML, every piece of *text* is
    HTML-escaped (``&``, ``<``, ``>``) so OCR output cannot inject markup;
    only the highlight ``<span>`` tags are emitted as real HTML.

    Args:
        text: Plain text to display.
        search_term: Term to highlight; when empty or ``None`` nothing is
            highlighted.

    Returns:
        An HTML-safe string. Empty or ``None`` *text* is returned unchanged.
    """
    from html import escape  # local import: the file does not import html

    if not text:
        return text
    if not search_term:
        return escape(text, quote=False)

    pattern = re.compile(re.escape(search_term), re.IGNORECASE)
    pieces = []
    cursor = 0
    for match in pattern.finditer(text):
        # Escape the untouched text between the previous match and this one.
        pieces.append(escape(text[cursor:match.start()], quote=False))
        pieces.append(
            f'<span style="background-color: yellow;">{escape(match.group(), quote=False)}</span>'
        )
        cursor = match.end()
    pieces.append(escape(text[cursor:], quote=False))
    return "".join(pieces)
# Title and UI
st.title("OCR Application - Multimodel Support")
st.write("Upload an image for OCR using various models, with support for English, Hindi, and Hinglish.")

# Sidebar: model selector, image upload and the Predict trigger
sidebar = st.sidebar
sidebar.header("Configuration")
model_choice = sidebar.selectbox(
    "Select OCR Model:",
    ("GOT_CPU", "GOT_GPU", "Qwen", "Surya (English+Hindi)"),
)
uploaded_file = sidebar.file_uploader(
    "Choose an image...",
    type=["png", "jpg", "jpeg"],
)
predict_button = sidebar.button("Predict")

# Main area is split into two columns with a 2:1 width ratio
col1, col2 = st.columns([2, 1])

# Preview the uploaded image in the first column
if uploaded_file:
    image = Image.open(uploaded_file)
    with col1:
        st.image(image, caption='Uploaded Image', use_column_width=False, width=300)
def _run_selected_model(choice, image_path):
    """Run the OCR backend named by *choice* on the image stored at *image_path*.

    Args:
        choice: One of the labels offered by the model selector.
        image_path: Filesystem path of the image to read.

    Returns:
        The raw extracted text.

    Raises:
        ValueError: If *choice* is not a known backend label.
    """
    if choice == "GOT_CPU":
        got_model, tokenizer = init_got_model()
        return extract_text_got(image_path, got_model, tokenizer)
    if choice == "GOT_GPU":
        got_gpu_model, tokenizer = init_got_gpu_model()
        return extract_text_got(image_path, got_gpu_model, tokenizer)
    if choice == "Qwen":
        qwen_model, qwen_processor = init_qwen_model()
        return extract_text_qwen(image_path, qwen_model, qwen_processor)
    if choice == "Surya (English+Hindi)":
        with Image.open(image_path) as source:
            page = source.convert("RGB")
        langs = ["en", "hi"]
        predictions = run_ocr([page], [langs], det_model, det_processor, rec_model, rec_processor)
        # Read the recognised lines from the result object rather than
        # regex-parsing its repr, which breaks when the text contains quotes.
        return ' '.join(line.text for line in predictions[0].text_lines)
    raise ValueError(f"Unsupported OCR model: {choice!r}")


# Handle predictions
if predict_button and uploaded_file:
    with st.spinner("Processing..."):
        # Save uploaded image to a temporary file the models can read by path
        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
            temp_file.write(uploaded_file.getvalue())
            temp_file_path = temp_file.name
        try:
            extracted_text = _run_selected_model(model_choice, temp_file_path)
        finally:
            # Delete temp file even when the OCR step raises
            if os.path.exists(temp_file_path):
                os.remove(temp_file_path)

        # Clean extracted text
        cleaned_text = clean_extracted_text(extracted_text)
        # Optionally, polish text with AI model for better language flow
        if model_choice in ["GOT_CPU", "GOT_GPU"]:
            polished_text = polish_text_with_ai(cleaned_text)
        else:
            polished_text = cleaned_text

    # The button is only True on the run triggered by the click, so the result
    # is kept in session state to survive the rerun caused by typing a search.
    st.session_state["ocr_result"] = {
        "source": (uploaded_file.name, uploaded_file.size),
        "text": polished_text,
    }

# Display extracted text and search functionality
ocr_result = st.session_state.get("ocr_result")
if (
    uploaded_file
    and ocr_result is not None
    # Hide a stored result that belongs to a different upload.
    and ocr_result["source"] == (uploaded_file.name, uploaded_file.size)
):
    polished_text = ocr_result["text"]
    st.subheader("Extracted Text (Cleaned & Polished)")
    st.markdown(polished_text, unsafe_allow_html=True)
    search_query = st.text_input("Search in extracted text:", key="search_query", placeholder="Type to search...")
    if search_query:
        highlighted_text = highlight_text(polished_text, search_query)
        st.markdown("### Highlighted Search Results:")
        st.markdown(highlighted_text, unsafe_allow_html=True)
    else:
        st.markdown("### Extracted Text:")
        st.markdown(polished_text, unsafe_allow_html=True)