badru committed on
Commit
d0c7ab2
·
verified ·
1 Parent(s): 3eb6b07

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -7
app.py CHANGED
@@ -1,6 +1,8 @@
1
  import streamlit as st
2
  from transformers import TrOCRProcessor, VisionEncoderDecoderModel
3
  from PIL import Image
 
 
4
 
5
  # Load the model and processor
6
  @st.cache_resource
@@ -11,8 +13,44 @@ def load_model():
11
 
12
  processor, model = load_model()
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  # Streamlit app
15
- st.title("OCR API Service")
16
 
17
  uploaded_file = st.file_uploader("Upload an Image", type=["jpg", "jpeg", "png"])
18
 
@@ -22,19 +60,37 @@ if uploaded_file is not None:
22
  image = Image.open(uploaded_file).convert("RGB")
23
  st.image(image, caption="Uploaded Image", use_column_width=True)
24
 
25
- # Perform OCR
26
- pixel_values = processor(images=image, return_tensors="pt").pixel_values
27
- generated_ids = model.generate(pixel_values)
28
- generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
  # Display extracted text
31
  st.subheader("Extracted Text:")
32
- st.text(generated_text)
33
 
34
  # Simulate API-like JSON response
35
- json_response = {"extracted_text": generated_text}
36
  st.write("API Response:")
37
  st.json(json_response)
38
 
39
  except Exception as e:
40
  st.error(f"An error occurred: {e}")
 
 
 
1
  import streamlit as st
2
  from transformers import TrOCRProcessor, VisionEncoderDecoderModel
3
  from PIL import Image
4
+ import cv2
5
+ import numpy as np
6
 
7
  # Load the model and processor
8
  @st.cache_resource
 
13
 
14
  processor, model = load_model()
15
 
16
# Helper: segment an image into horizontal text-line crops for line-level OCR.
def detect_lines(image, min_height=20, min_width=100):
    """Locate text lines in *image* and return them as image crops.

    Parameters
    ----------
    image : PIL.Image.Image
        Source image; assumed RGB (the file converts uploads with
        ``.convert("RGB")`` before calling this).
    min_height, min_width : int
        Candidate boxes smaller than this in either dimension are discarded.

    Returns
    -------
    list[numpy.ndarray]
        One RGB crop per detected line, ordered top to bottom.
    """
    rgb = np.array(image)

    # Binarize: grayscale, then inverted Otsu threshold so text pixels
    # become white on black (the 128 argument is superseded by Otsu).
    gray = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)
    _, binary = cv2.threshold(gray, 128, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)

    # Grow text regions slightly so characters on the same line connect
    # into a single blob.
    rect_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5))
    dilated = cv2.dilate(binary, rect_kernel, iterations=1)

    # Each external contour of the dilated mask is a candidate line.
    contours, _ = cv2.findContours(dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # Order candidates top-to-bottom by the y coordinate of their box,
    # then keep only boxes large enough to plausibly be a text line and
    # crop each surviving box out of the original RGB array.
    boxes = sorted((cv2.boundingRect(c) for c in contours), key=lambda box: box[1])
    return [
        rgb[y:y + h, x:x + w]
        for x, y, w, h in boxes
        if h >= min_height and w >= min_width
    ]
51
+
52
  # Streamlit app
53
+ st.title("OCR API Service with Multiline Support")
54
 
55
  uploaded_file = st.file_uploader("Upload an Image", type=["jpg", "jpeg", "png"])
56
 
 
60
  image = Image.open(uploaded_file).convert("RGB")
61
  st.image(image, caption="Uploaded Image", use_column_width=True)
62
 
63
+ # Detect lines in the image
64
+ st.write("Detecting lines...")
65
+ line_images = detect_lines(image, min_height=30, min_width=100)
66
+ st.write(f"Detected {len(line_images)} lines in the image.")
67
+
68
+ # Perform OCR on each detected line
69
+ extracted_text = ""
70
+ for idx, line_img in enumerate(line_images):
71
+ # Convert the line image to PIL format
72
+ line_pil = Image.fromarray(line_img)
73
+
74
+ # Prepare the image for OCR
75
+ pixel_values = processor(images=line_pil, return_tensors="pt").pixel_values
76
+
77
+ # Generate text from the line image
78
+ generated_ids = model.generate(pixel_values)
79
+ generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
80
+
81
+ # Append the extracted text
82
+ extracted_text += f"{generated_text}\n"
83
 
84
  # Display extracted text
85
  st.subheader("Extracted Text:")
86
+ st.text_area("Output Text", extracted_text, height=300)
87
 
88
  # Simulate API-like JSON response
89
+ json_response = {"extracted_text": extracted_text}
90
  st.write("API Response:")
91
  st.json(json_response)
92
 
93
  except Exception as e:
94
  st.error(f"An error occurred: {e}")
95
+ else:
96
+ st.info("Please upload an image to start the OCR process.")