pranshh committed
Commit 28c861d · verified · 1 Parent(s): dd3183f

Uploaded app.py and requirements.txt

Files changed (2)
  1. app.py.py +157 -0
  2. requirements.txt +8 -0
app.py.py ADDED
@@ -0,0 +1,157 @@
+ # -*- coding: utf-8 -*-
+ """OCR Web Application Prototype.ipynb
+
+ Automatically generated by Colab.
+
+ Original file is located at
+ https://colab.research.google.com/drive/1vzsQ17-W1Vy6yJ60XUwFy0QRkOR_SIg7
+ """
+
+ import gradio as gr
+ from transformers import AutoModel, AutoTokenizer
+ from PIL import Image
+ import os
+ import re
+
+ # Pin the checkpoint to a specific revision for reproducibility
+ revision = "5364fe1ab774ef13c2c79023dc91d8c1e7cfdce4"
+
+ # Load tokenizer and model (GOT ships custom modeling code, hence trust_remote_code)
+ tokenizer = AutoTokenizer.from_pretrained('srimanth-d/GOT_CPU', trust_remote_code=True, revision=revision)
+ model = AutoModel.from_pretrained('srimanth-d/GOT_CPU', trust_remote_code=True, low_cpu_mem_usage=True, use_safetensors=True, pad_token_id=tokenizer.eos_token_id, revision=revision)
+ model = model.eval()
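+
+ # Note: GOT's custom `chat` helper (loaded via trust_remote_code) takes an image
+ # file path and an `ocr_type`; a minimal sketch, assuming the upstream GOT-OCR2.0
+ # API that this CPU checkpoint mirrors (hypothetical file 'page.png'):
+ #   plain = model.chat(tokenizer, "page.png", ocr_type='ocr')     # plain-text OCR
+ #   fancy = model.chat(tokenizer, "page.png", ocr_type='format')  # formatted output (used below)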
+
+ # Function to perform OCR and optional keyword search
+ def process_image_with_search(image, keyword):
+     if image is None:
+         return "Please upload an image first."
+     temp_img_path = "temp_image.png"
+     try:
+         # Save the PIL image to a temporary file (the model expects a file path)
+         image.save(temp_img_path)
+
+         # Perform OCR with the model using the file path
+         extracted_text = model.chat(tokenizer, temp_img_path, ocr_type='format')
+
+         # Convert extracted text to string if it's not already
+         extracted_text = extracted_text if isinstance(extracted_text, str) else str(extracted_text)
+
+         # If a keyword is provided, search for it (case-insensitive)
+         if keyword:
+             match = re.search(re.escape(keyword), extracted_text, flags=re.IGNORECASE)
+             if match:
+                 # Highlight the first occurrence, preserving its original casing
+                 highlighted_text = (
+                     extracted_text[:match.start()]
+                     + f"**{match.group(0)}**"
+                     + extracted_text[match.end():]
+                 )
+                 result = f"Keyword '{keyword}' found:\n\n{highlighted_text}"
+             else:
+                 result = f"Keyword '{keyword}' not found in the extracted text.\n\nExtracted Text:\n{extracted_text}"
+         else:
+             # No keyword was given: return the extracted text without searching
+             result = f"Extracted Text:\n\n{extracted_text}"
+
+         return result
+     except Exception as e:
+         return str(e)  # Return the error message in case of failure
+     finally:
+         # Delete the temporary file even if OCR fails
+         if os.path.exists(temp_img_path):
+             os.remove(temp_img_path)
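+
+ # Quick smoke test without the UI (hypothetical file 'sample.png'):
+ #   print(process_image_with_search(Image.open("sample.png"), "invoice"))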
+
+ # Define Gradio interface
+ iface = gr.Interface(
+     fn=process_image_with_search,  # The function that runs OCR and the keyword search
+     inputs=[gr.Image(type='pil'), gr.Textbox(label="Enter keyword to search (optional)")],  # Image input + keyword input
+     outputs='text',  # Plain-text output with the search result
+     title="OCR with GOT and Keyword Search",
+     description="Upload an image to get OCR results. You can also search for a keyword in the extracted text."
+ )
+
+ # Launch the interface
+ iface.launch(debug=True)
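+
+ # Deployment note: on Hugging Face Spaces a plain `iface.launch()` is typical;
+ # `debug=True` mainly helps when iterating locally or in a notebook.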
+
+ # Alternative GPU pipeline (Qwen2-VL), kept for reference:
+ # !pip install --upgrade git+https://github.com/huggingface/transformers.git byaldi accelerate flash-attn qwen_vl_utils pdf2image gradio
+ # !sudo apt-get install -y poppler-utils
+
+ # from byaldi import RAGMultiModalModel
+ # from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
+ # from qwen_vl_utils import process_vision_info
+ # import torch
+ # import gradio as gr
+ # from PIL import Image
+
+ # processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
+
+ # # Initialize the model with float16 precision and handle fallback to CPU
+ # def load_model():
+ #     try:
+ #         vlm = Qwen2VLForConditionalGeneration.from_pretrained(
+ #             "Qwen/Qwen2-VL-2B-Instruct",
+ #             torch_dtype=torch.float16,
+ #             attn_implementation="flash_attention_2",  # FlashAttention enabled
+ #             device_map="cuda"
+ #         )
+ #         print("Model loaded with FlashAttention on GPU")
+ #     except RuntimeError as e:
+ #         if "FlashAttention only supports Ampere GPUs" in str(e):
+ #             print("FlashAttention not supported. Falling back to standard attention.")
+ #             vlm = Qwen2VLForConditionalGeneration.from_pretrained(
+ #                 "Qwen/Qwen2-VL-2B-Instruct",
+ #                 torch_dtype=torch.float16,  # Still use float16 to save memory
+ #                 attn_implementation="eager",  # Standard attention ("default" is not a valid option)
+ #                 device_map="cuda" if torch.cuda.is_available() else "cpu"
+ #             )
+ #         else:
+ #             raise e  # Re-raise runtime errors unrelated to FlashAttention
+ #     return vlm
+
+ # # Load the model
+ # vlm = load_model()
+
+ # # OCR function to extract text from an image
+ # def ocr_image(image, query="Extract text from the image"):
+ #     messages = [
+ #         {
+ #             "role": "user",
+ #             "content": [
+ #                 {
+ #                     "type": "image",
+ #                     "image": image,
+ #                 },
+ #                 {"type": "text", "text": query},
+ #             ],
+ #         }
+ #     ]
+
+ #     # Prepare inputs for the model
+ #     text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+ #     image_inputs, video_inputs = process_vision_info(messages)
+ #     inputs = processor(
+ #         text=[text],
+ #         images=image_inputs,
+ #         videos=video_inputs,
+ #         padding=True,
+ #         return_tensors="pt",
+ #     )
+ #     inputs = inputs.to("cuda" if torch.cuda.is_available() else "cpu")
+
+ #     # Generate the output text using the model
+ #     generated_ids = vlm.generate(**inputs, max_new_tokens=512)
+ #     generated_ids_trimmed = [
+ #         out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+ #     ]
+ #     output_text = processor.batch_decode(
+ #         generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+ #     )
+ #     return output_text[0]
+
+ # # Gradio interface
+ # def process_image(image):
+ #     return ocr_image(image)
+
+ # # Create Gradio interface for uploading an image
+ # interface = gr.Interface(
+ #     fn=process_image,
+ #     inputs=gr.Image(type="pil"),
+ #     outputs="text",
+ #     title="Hindi & English OCR",
+ #     description="Upload an image containing text in Hindi or English to extract the text using OCR."
+ # )
+
+ # # Launch Gradio interface in Colab
+ # interface.launch(share=True, debug=True)
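+
+ # # Example call (hypothetical file 'hindi_sample.png'), assuming the GPU stack above is installed:
+ # print(ocr_image(Image.open("hindi_sample.png")))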
requirements.txt ADDED
@@ -0,0 +1,8 @@
+
+ torch==2.0.1
+ torchvision==0.15.2
+ transformers==4.37.2
+ tiktoken==0.6.0
+ verovio==4.3.1
+ accelerate==0.28.0
+ gradio