pranshh committed on
Commit
4528460
·
verified ·
1 Parent(s): 86c3f4a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -54
app.py CHANGED
@@ -1,68 +1,36 @@
1
- # -*- coding: utf-8 -*-
2
- """OCR Web Application Prototype.ipynb
3
- Automatically generated by Colab.
4
- Original file is located at
5
- https://colab.research.google.com/drive/1vzsQ17-W1Vy6yJ60XUwFy0QRkOR_SIg7
6
- """
7
-
8
- from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
9
- from qwen_vl_utils import process_vision_info
10
  import torch
11
  import gradio as gr
12
  from PIL import Image
 
 
13
 
 
 
14
 
15
  processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
16
 
17
- # Initialize the model with float16 precision and handle fallback to CPU
18
- # Simplified model loading function for CPU
19
  def load_model():
20
- return Qwen2VLForConditionalGeneration.from_pretrained(
21
- "Qwen/Qwen2-VL-2B-Instruct",
22
- torch_dtype=torch.float32, # Use float32 for CPU
23
- low_cpu_mem_usage=True
24
- )
25
 
26
- # Load the model
27
  vlm = load_model()
28
 
29
- # OCR function to extract text from an image
30
- def ocr_image(image, query="Extract text from the image", keyword=""):
31
- messages = [
32
- {
33
- "role": "user",
34
- "content": [
35
- {
36
- "type": "image",
37
- "image": image,
38
- },
39
- {"type": "text", "text": query},
40
- ],
41
- }
42
- ]
43
 
44
- # Prepare inputs for the model
45
- text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
46
- image_inputs, video_inputs = process_vision_info(messages)
47
- inputs = processor(
48
- text=[text],
49
- images=image_inputs,
50
- videos=video_inputs,
51
- padding=True,
52
- return_tensors="pt",
53
- )
54
- inputs = inputs.to("cpu")
55
 
56
- # Generate the output text using the model
57
- generated_ids = vlm.generate(**inputs, max_new_tokens=512)
58
- generated_ids_trimmed = [
59
- out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
60
- ]
61
-
62
- output_text = processor.batch_decode(
63
- generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
64
- )[0]
65
 
 
 
 
66
  if keyword:
67
  keyword_lower = keyword.lower()
68
  if keyword_lower in output_text.lower():
@@ -73,14 +41,12 @@ def ocr_image(image, query="Extract text from the image", keyword=""):
73
  else:
74
  return output_text
75
 
76
- # Gradio interface
77
  def process_image(image, keyword=""):
78
  max_size = 1024
79
  if max(image.size) > max_size:
80
  image.thumbnail((max_size, max_size))
81
  return ocr_image(image, keyword=keyword)
82
 
83
- # Update the Gradio interface:
84
  interface = gr.Interface(
85
  fn=process_image,
86
  inputs=[
@@ -91,5 +57,4 @@ interface = gr.Interface(
91
  title="Hindi & English OCR with Keyword Search",
92
  )
93
 
94
- # Launch Gradio interface in Colab
95
  interface.launch()
 
1
+ from transformers import AutoProcessor
 
 
 
 
 
 
 
 
2
  import torch
3
  import gradio as gr
4
  from PIL import Image
5
+ from byaldi import RAGMultiModalModel
6
+ from qwen_vl_utils import process_vision_info
7
 
8
+ # Load ColPali model
9
+ RAG = RAGMultiModalModel.from_pretrained("vidore/colpali")
10
 
11
  processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
12
 
 
 
13
def load_model():
    """Return the vision-language model used for OCR.

    Reuses the model object already held by the module-level ColPali RAG
    wrapper (`RAG`) instead of loading a second copy from the hub.
    NOTE(review): assumes `RAGMultiModalModel` exposes its underlying model
    via the `.model` attribute — confirm against the byaldi API.
    """
    return RAG.model
 
 
 
 
15
 
 
16
  vlm = load_model()
17
 
18
+ def ocr_image(image, keyword=""):
19
+ # Convert PIL Image to file-like object
20
+ import io
21
+ img_byte_arr = io.BytesIO()
22
+ image.save(img_byte_arr, format='PNG')
23
+ img_byte_arr = img_byte_arr.getvalue()
 
 
 
 
 
 
 
 
24
 
25
+ # Index the image
26
+ RAG.index(input_data=img_byte_arr, index_name="temp_index", overwrite=True)
 
 
 
 
 
 
 
 
 
27
 
28
+ # Retrieve text from the image
29
+ results = RAG.search("Extract all text from this image", k=1)
 
 
 
 
 
 
 
30
 
31
+ # Extract text from results
32
+ output_text = results[0].get('text', '')
33
+
34
  if keyword:
35
  keyword_lower = keyword.lower()
36
  if keyword_lower in output_text.lower():
 
41
  else:
42
  return output_text
43
 
 
44
def process_image(image, keyword=""):
    """Gradio entry point: shrink oversized uploads, then run OCR.

    `image` is a PIL image; `keyword` is an optional search term forwarded
    to `ocr_image`. Returns the OCR (or keyword-filtered) text.
    """
    limit = 1024
    longest_side = max(image.size)
    if longest_side > limit:
        # In-place, aspect-ratio-preserving downscale to at most 1024 px.
        image.thumbnail((limit, limit))
    return ocr_image(image, keyword=keyword)
49
 
 
50
  interface = gr.Interface(
51
  fn=process_image,
52
  inputs=[
 
57
  title="Hindi & English OCR with Keyword Search",
58
  )
59
 
 
60
  interface.launch()