"""DualTextOCRFusion: Streamlit front-end for OCR with GOT, Qwen2-VL and Surya.

Recognised text is cleaned and (for GOT models) polished via the Groq API.
Results are cached as JSON under RESULTS_DIR, keyed by image name and model.
"""
import os
import re
import json
import tempfile

import streamlit as st
import torch
from PIL import Image
from groq import Groq, GroqError
from transformers import AutoModel, AutoTokenizer, Qwen2VLForConditionalGeneration, AutoProcessor
from surya.ocr import run_ocr
from surya.model.detection.model import load_model as load_det_model, load_processor as load_det_processor
from surya.model.recognition.model import load_model as load_rec_model
from surya.model.recognition.processor import load_processor as load_rec_processor

# Page configuration (must be the first Streamlit command executed by the script)
st.set_page_config(page_title="DualTextOCRFusion", page_icon="🔍", layout="wide")

device = "cuda" if torch.cuda.is_available() else "cpu"

# Directories for images and results
IMAGES_DIR = "images"
RESULTS_DIR = "results"
os.makedirs(IMAGES_DIR, exist_ok=True)
os.makedirs(RESULTS_DIR, exist_ok=True)

# Groq chat model used to re-join words split by OCR.
GROQ_MODEL = "gemma2-9b-it"

MODEL_CHOICES = ("GOT_CPU", "GOT_GPU", "Qwen", "Surya (English+Hindi)")


# Load Surya OCR Models (English + Hindi)
@st.cache_resource
def init_surya_models():
    """Load Surya detection/recognition models once per server process.

    Streamlit re-executes this script on every widget interaction, so the
    loading is wrapped in ``st.cache_resource`` instead of running it at
    import time on every rerun.

    Returns:
        Tuple ``(det_processor, det_model, rec_model, rec_processor)``.
    """
    det_processor, det_model = load_det_processor(), load_det_model()
    det_model.to(device)
    rec_model, rec_processor = load_rec_model(), load_rec_processor()
    rec_model.to(device)
    return det_processor, det_model, rec_model, rec_processor


det_processor, det_model, rec_model, rec_processor = init_surya_models()


# Load GOT Models
@st.cache_resource
def init_got_model():
    """Load the CPU build of GOT-OCR; returns ``(model, tokenizer)``."""
    tokenizer = AutoTokenizer.from_pretrained('srimanth-d/GOT_CPU', trust_remote_code=True)
    model = AutoModel.from_pretrained('srimanth-d/GOT_CPU', trust_remote_code=True,
                                      use_safetensors=True, pad_token_id=tokenizer.eos_token_id)
    return model.eval(), tokenizer


@st.cache_resource
def init_got_gpu_model():
    """Load GOT-OCR 2.0 onto the CUDA device; returns ``(model, tokenizer)``."""
    tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
    model = AutoModel.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True,
                                      low_cpu_mem_usage=True, device_map='cuda',
                                      use_safetensors=True, pad_token_id=tokenizer.eos_token_id)
    return model.eval().cuda(), tokenizer


# Load Qwen Model
@st.cache_resource
def init_qwen_model():
    """Load Qwen2-VL-2B-Instruct on CPU; returns ``(model, processor)``."""
    # NOTE(review): float16 on CPU is slow and some torch builds lack half-precision
    # CPU kernels; confirm this works on the target machine or switch to float32.
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2-VL-2B-Instruct", device_map="cpu", torch_dtype=torch.float16
    )
    processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
    return model.eval(), processor


# Text Cleaning - collapse whitespace, handle dual languages
def clean_extracted_text(text):
    """Collapse runs of whitespace and drop spaces before ``? . ! ,``.

    Returns an empty string for empty/None input.
    """
    if not text:
        return ""
    cleaned_text = re.sub(r'\s+', ' ', text).strip()
    cleaned_text = re.sub(r'\s([?.!,])', r'\1', cleaned_text)
    return cleaned_text


# Polish the text using a model
def polish_text_with_ai(cleaned_text):
    """Ask the Groq LLM to re-join words that OCR split apart.

    The API key is read from the ``GROQ_API_KEY`` environment variable (a key
    used to be hard-coded here; that key was exposed in source and should be
    revoked). Polishing is best-effort: if the key is missing or the API call
    fails, a warning is shown and ``cleaned_text`` is returned unchanged.

    Args:
        cleaned_text: Output of :func:`clean_extracted_text`.

    Returns:
        The polished text, or ``cleaned_text`` when polishing is unavailable.
    """
    if not cleaned_text:
        return cleaned_text
    api_key = os.environ.get("GROQ_API_KEY")
    if not api_key:
        st.warning("GROQ_API_KEY is not set; showing the cleaned text without AI polishing.")
        return cleaned_text

    # The extracted text must actually be part of the prompt; previously it was omitted.
    prompt = (
        "Remove unwanted spaces between and inside words to join incomplete words, "
        "creating a meaningful sentence in either Hindi, English, or Hinglish without "
        "altering any words from the given extracted text. Then, return the corrected "
        "text with adjusted spaces.\n\n"
        f"Extracted text:\n{cleaned_text}"
    )
    try:
        client = Groq(api_key=api_key)
        chat_completion = client.chat.completions.create(
            messages=[
                {"role": "system", "content": "You are a pedantic sentence corrector."},
                {"role": "user", "content": prompt},
            ],
            model=GROQ_MODEL,
        )
    except GroqError as e:
        st.warning(f"AI polishing failed ({e}); showing the cleaned text instead.")
        return cleaned_text
    return chat_completion.choices[0].message.content or cleaned_text


# Extract text using GOT
def extract_text_got(image_file, model, tokenizer):
    """Run GOT-OCR in plain ``ocr`` mode on the image at ``image_file``."""
    return model.chat(tokenizer, image_file, ocr_type='ocr')


# Extract text using Qwen
def extract_text_qwen(image_file, model, processor, max_new_tokens=1024):
    """Ask Qwen2-VL to transcribe the text in ``image_file``.

    Args:
        image_file: Path or file-like object readable by PIL.
        model: Qwen2-VL model from :func:`init_qwen_model`.
        processor: Matching ``AutoProcessor``.
        max_new_tokens: Generation budget; the library default is far too short
            for a page of OCR output.

    Returns:
        The model's answer, a placeholder message when nothing was produced, or
        ``"An error occurred: ..."`` if any step raises.
    """
    try:
        image = Image.open(image_file).convert('RGB')
        conversation = [{"role": "user", "content": [{"type": "image"},
                                                     {"type": "text", "text": "Extract text from this image."}]}]
        text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
        inputs = processor(text=[text_prompt], images=[image], return_tensors="pt").to(model.device)
        with torch.no_grad():
            output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
        # generate() returns prompt + completion; keep only the newly generated tokens
        # so the chat template/prompt is not echoed back as "extracted text".
        generated_ids = [out[len(inp):] for inp, out in zip(inputs.input_ids, output_ids)]
        output_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
        if output_text and output_text[0].strip():
            return output_text[0]
        return "No text extracted from the image."
    except Exception as e:
        return f"An error occurred: {str(e)}"


# Highlight keyword search
def highlight_text(text, search_term):
    """Wrap every case-insensitive occurrence of ``search_term`` in ``<mark>``.

    The search term is matched literally (regex metacharacters are escaped).
    Returns ``text`` unchanged when ``search_term`` is empty.
    """
    if not search_term:
        return text
    pattern = re.compile(re.escape(search_term), re.IGNORECASE)
    return pattern.sub(lambda m: f"<mark>{m.group()}</mark>", text)


def _save_uploaded_file(uploaded_file):
    """Persist a Streamlit upload under IMAGES_DIR and return its path."""
    # basename() strips any directory components from the client-supplied name.
    image_path = os.path.join(IMAGES_DIR, os.path.basename(uploaded_file.name))
    with open(image_path, "wb") as f:
        f.write(uploaded_file.getvalue())
    return image_path


def _result_json_path(image_path, model_name):
    """Return the JSON cache path for ``image_path`` processed by ``model_name``.

    The model is part of the key so switching models does not return another
    model's cached result.
    """
    safe_model = re.sub(r"[^A-Za-z0-9_-]+", "_", model_name)
    return os.path.join(RESULTS_DIR, f"{os.path.basename(image_path)}_{safe_model}_result.json")


def run_selected_ocr(image_path, model_name):
    """Run the chosen OCR model on ``image_path`` and return the display text.

    Only the GOT models are post-processed with :func:`polish_text_with_ai`;
    the other models return the cleaned text.

    Raises:
        RuntimeError: ``GOT_GPU`` was chosen but no CUDA device is available.
        ValueError: ``model_name`` is not one of ``MODEL_CHOICES``.
    """
    if model_name == "GOT_CPU":
        got_model, tokenizer = init_got_model()
        extracted_text = extract_text_got(image_path, got_model, tokenizer)
    elif model_name == "GOT_GPU":
        if device != "cuda":
            raise RuntimeError("GOT_GPU requires a CUDA-capable GPU, but none is available.")
        got_gpu_model, tokenizer = init_got_gpu_model()
        extracted_text = extract_text_got(image_path, got_gpu_model, tokenizer)
    elif model_name == "Qwen":
        qwen_model, qwen_processor = init_qwen_model()
        extracted_text = extract_text_qwen(image_path, qwen_model, qwen_processor)
    elif model_name == "Surya (English+Hindi)":
        image = Image.open(image_path).convert("RGB")
        langs = ["en", "hi"]
        predictions = run_ocr([image], [langs], det_model, det_processor, rec_model, rec_processor)
        # Read the recognised lines directly instead of regex-scraping repr(),
        # which missed any line containing a single quote.
        extracted_text = " ".join(line.text for line in predictions[0].text_lines)
    else:
        raise ValueError(f"Unknown OCR model: {model_name}")

    cleaned_text = clean_extracted_text(extracted_text)
    if model_name in ("GOT_CPU", "GOT_GPU"):
        return polish_text_with_ai(cleaned_text)
    return cleaned_text


# Title and UI
st.title("DualTextOCRFusion - 🔍")
st.header("OCR Application - Multimodel Support")
st.write("Upload an image for OCR using various models, with support for English, Hindi, and Hinglish.")

# Sidebar Configuration
st.sidebar.header("Configuration")
model_choice = st.sidebar.selectbox("Select OCR Model:", MODEL_CHOICES)

# Upload Section
uploaded_file = st.sidebar.file_uploader("Choose an image...", type=["png", "jpg", "jpeg"])
clipboard_text = st.sidebar.text_area("Paste image path from clipboard:")

if uploaded_file or clipboard_text:
    image_path = None
    if uploaded_file:
        image_path = _save_uploaded_file(uploaded_file)
    elif clipboard_text.strip():
        # NOTE(review): this lets the user point the server at any local file path;
        # acceptable for a local tool, not for a shared deployment.
        candidate_path = clipboard_text.strip()
        if os.path.isfile(candidate_path):
            image_path = candidate_path
        else:
            st.sidebar.error(f"File not found: {candidate_path}")

    # Predict button
    predict_button = st.sidebar.button("Predict")

    # Main columns
    col1, col2 = st.columns([2, 1])

    # Streamlit reruns the whole script on every interaction (e.g. typing a
    # search term), so the last result is kept in session_state keyed by
    # (image, model) rather than in a local variable that only exists on the
    # run where Predict was clicked.
    result_key = (image_path, model_choice)

    if predict_button and image_path:
        result_json_path = _result_json_path(image_path, model_choice)
        polished_text = None
        if os.path.exists(result_json_path):
            with open(result_json_path, "r", encoding="utf-8") as json_file:
                result_data = json.load(json_file)
            polished_text = result_data.get("polished_text", "")
        else:
            with st.spinner("Processing..."):
                try:
                    polished_text = run_selected_ocr(image_path, model_choice)
                except Exception as e:  # UI boundary: report instead of crashing the page
                    st.error(f"OCR failed: {e}")
            if polished_text is not None:
                # Save result to JSON
                with open(result_json_path, "w", encoding="utf-8") as json_file:
                    json.dump({"polished_text": polished_text}, json_file, ensure_ascii=False)
        if polished_text is not None:
            st.session_state["ocr_result"] = {"key": result_key, "text": polished_text}

    # Display image preview and text
    if image_path:
        stored_result = st.session_state.get("ocr_result")
        polished_text = (
            stored_result["text"]
            if stored_result and stored_result["key"] == result_key
            else None
        )
        with col1:
            st.image(image_path, caption='Uploaded Image', width=300)
            if polished_text is not None:
                st.subheader("Extracted Text (Cleaned & Polished)")
                st.markdown(polished_text, unsafe_allow_html=True)

                # Input box for real-time search; the keyed widget keeps its own value.
                search_query = st.text_input("Search in extracted text:", key="search_query",
                                             placeholder="Type to search...")

                # Highlight the search term in the text
                if search_query:
                    highlighted_text = highlight_text(polished_text, search_query)
                    st.markdown("### Highlighted Search Results:")
                    st.markdown(highlighted_text, unsafe_allow_html=True)