Spaces:
Runtime error
Runtime error
| from byaldi import RAGMultiModalModel | |
| from transformers import Qwen2VLForConditionalGeneration, AutoProcessor | |
| from qwen_vl_utils import process_vision_info | |
| import torch | |
| import gradio as gr | |
| from PIL import Image | |
| import re | |
# Load models
def initialize_models():
    """Load the byaldi RAG multimodal model, the Qwen2-VL model and its processor.

    Returns:
        A ``(multimodal_rag, qwen_model, qwen_processor)`` tuple, in that order.
    """
    qwen_checkpoint = "Qwen/Qwen2-VL-2B-Instruct"

    # NOTE(review): the RAG model is returned to the caller, but nothing in the
    # visible code uses it afterwards -- confirm it still needs to be loaded.
    rag_model = RAGMultiModalModel.from_pretrained("vidore/colpali")

    vision_language_model = Qwen2VLForConditionalGeneration.from_pretrained(
        qwen_checkpoint,
        trust_remote_code=True,
        torch_dtype=torch.float32,
    )
    processor = AutoProcessor.from_pretrained(
        qwen_checkpoint,
        trust_remote_code=True,
    )
    return rag_model, vision_language_model, processor
# Load everything once at import time; perform_ocr reads qwen_model and
# qwen_processor as module-level globals.
(multimodal_rag, qwen_model, qwen_processor) = initialize_models()
# Text extraction function
def perform_ocr(image):
    """Extract the text shown in an image using the Qwen2-VL model.

    The model is prompted to transcribe the text in its original language.

    Args:
        image: The image to read. It is placed in the chat message as the
            ``"image"`` entry (the Gradio input in this file supplies a PIL
            image).

    Returns:
        The decoded model output for the image, as a string.

    Raises:
        ValueError: If ``image`` is ``None``, e.g. nothing was uploaded.
    """
    # Fail early with a clear message instead of an obscure error from the
    # vision preprocessing code.
    if image is None:
        raise ValueError("No image provided for text extraction.")

    query = "Extract text from the image in original language"

    # Format the request as a single-turn chat message (image + instruction).
    user_input = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": query},
            ],
        }
    ]

    # Preprocess the input: render the chat template and collect vision inputs.
    input_text = qwen_processor.apply_chat_template(
        user_input, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(user_input)
    model_inputs = qwen_processor(
        text=[input_text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to(qwen_model.device)  # follow the model's device rather than assuming CPU

    # Generate output without tracking gradients (inference only).
    with torch.no_grad():
        generated_ids = qwen_model.generate(**model_inputs, max_new_tokens=2000)

    # generate() returns prompt + completion; drop the prompt tokens so only
    # the newly generated text is decoded.
    trimmed_ids = [
        output_ids[len(prompt_ids):]
        for prompt_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    ocr_result = qwen_processor.batch_decode(
        trimmed_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]
    return ocr_result
# Keyword search function
def highlight_keyword(text, keyword):
    """Return the sentences of ``text`` that contain ``keyword``, highlighted.

    The text is split into sentences on ``'. '``. Matching is
    case-insensitive and the keyword is treated literally (not as a regex).
    Every match is wrapped in ``<mark>...</mark>``; all other text is
    HTML-escaped so the result can be rendered as HTML safely.

    Args:
        text: The text to search.
        keyword: The literal keyword to look for.

    Returns:
        A list with one HTML string per matching sentence, or
        ``["No matches found."]`` when nothing matches or when ``text`` or
        ``keyword`` is empty (a whitespace-only keyword counts as empty).
    """
    from html import escape  # local import keeps this block self-contained

    # An empty pattern would "match" between every character of every sentence.
    if not text or not keyword or not keyword.strip():
        return ["No matches found."]

    # The capturing group makes split() keep the matched text: even indices
    # are the text between matches, odd indices are the matches themselves.
    pattern = re.compile(f"({re.escape(keyword)})", flags=re.IGNORECASE)

    results = []
    for sentence in text.split(". "):
        pieces = pattern.split(sentence)
        if len(pieces) == 1:
            continue  # keyword does not occur in this sentence
        # Escape each piece separately so only our own <mark> tags are markup.
        results.append(
            "".join(
                f"<mark>{escape(piece)}</mark>" if index % 2 else escape(piece)
                for index, piece in enumerate(pieces)
            )
        )
    return results if results else ["No matches found."]
# Gradio app for text extraction
def extract_text(image):
    """Gradio callback: run OCR on the uploaded image.

    Args:
        image: The value of the image input component; ``None`` when the
            user has not uploaded anything.

    Returns:
        The text returned by ``perform_ocr`` for the image.

    Raises:
        gr.Error: If no image was uploaded, so the UI shows a readable
            message instead of a generic failure.
    """
    if image is None:
        raise gr.Error("Please upload an image before extracting text.")
    return perform_ocr(image)
# Gradio app for keyword search
def search_in_text(extracted_text, keyword):
    """Gradio callback: find ``keyword`` in ``extracted_text``.

    Returns:
        The entries produced by ``highlight_keyword`` joined with ``<br>``
        into a single HTML string.
    """
    return "<br>".join(highlight_keyword(extracted_text, keyword))
# Page title markup, rendered at the top of the interface via gr.HTML.
header_html = (
    "\n"
    '<h1 style="text-align: center; color: #4CAF50;">'
    '<span class="gradient-text">OCR and Text Search Prototype</span>'
    "</h1>"
    "\n"
)
# Stylesheet passed to gr.Blocks: fixed-size buttons and a scrollable,
# height-limited text box.
custom_css = (
    "\n"
    ".gr-button {\n"
    "width: 200px; /* Set a fixed width for the buttons */\n"
    "padding: 12px 20px; /* Add padding to buttons for consistency */\n"
    "}\n"
    ".gr-textbox {\n"
    "max-height: 300px; /* Set a maximum height for the extracted text output */\n"
    "overflow-y: scroll; /* Enable scrolling when text exceeds the height */\n"
    "}\n"
)
# Gradio Interface
# NOTE(review): the nesting below was reconstructed from a source whose
# indentation had been flattened -- confirm component placement against the
# running app.
with gr.Blocks(css=custom_css) as interface:
    # Page header
    gr.HTML(header_html)

    with gr.Row():
        # Narrow left column: usage notes
        with gr.Column(scale=1, min_width=200):
            gr.Markdown("## Instructions")
            gr.Markdown(
                "\n"
                "1. Upload an image containing text.\n"
                "2. Extract the text from the image.\n"
                "3. Search for specific keywords within the extracted text.\n"
            )
            gr.Markdown("### Features")
            gr.Markdown(
                "\n"
                "- **OCR**: Extract text from images.\n"
                "- **Keyword Search**: Search and highlight keywords in extracted text.\n"
            )

        # Wide right column: the two working tabs
        with gr.Column(scale=3):
            with gr.Tabs():
                # Tab 1: run OCR on an uploaded image
                with gr.Tab("Extract Text"):
                    gr.Markdown("### Upload an image to extract text:")
                    with gr.Row():
                        image_upload = gr.Image(
                            type="pil",
                            label="Upload Image",
                            interactive=True,
                        )
                    with gr.Row():
                        extract_btn = gr.Button("Extract Text")
                    # NOTE(review): assumed to sit below the button row rather
                    # than inside it -- confirm.
                    extracted_textbox = gr.Textbox(label="Extracted Text")
                    extract_btn.click(
                        fn=extract_text,
                        inputs=image_upload,
                        outputs=extracted_textbox,
                    )

                # Tab 2: search the text produced by tab 1
                with gr.Tab("Search in Extracted Text"):
                    gr.Markdown("### Search for a keyword in the extracted text:")
                    with gr.Row():
                        keyword_searchbox = gr.Textbox(
                            label="Enter Keyword",
                            placeholder="Keyword to search",
                        )
                    with gr.Row():
                        search_btn = gr.Button("Search")
                    search_results = gr.HTML(label="Results")
                    # The search reads directly from the first tab's text box.
                    search_btn.click(
                        fn=search_in_text,
                        inputs=[extracted_textbox, keyword_searchbox],
                        outputs=search_results,
                    )

# Launch the Gradio App
interface.launch()