badru committed on
Commit
c816f98
·
verified ·
1 Parent(s): 4813fb1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -37
app.py CHANGED
@@ -1,7 +1,15 @@
1
  import streamlit as st
2
  from transformers import TrOCRProcessor, VisionEncoderDecoderModel
3
  from PIL import Image
4
- import json
 
 
 
 
 
 
 
 
5
 
6
  # Load the model and processor
7
  @st.cache_resource
@@ -12,53 +20,75 @@ def load_model():
12
 
13
  processor, model = load_model()
14
 
15
- # Check if the request is an API call
16
- if st.runtime.scriptrunner.script_run_context.is_running_with_auth:
17
- import io
18
- from fastapi import FastAPI, File, UploadFile
19
- from fastapi.responses import JSONResponse
20
 
21
- app = FastAPI()
 
 
 
 
 
 
 
 
 
 
 
22
 
23
- @app.post("/process_image")
24
- async def process_image(image: UploadFile = File(...)):
25
- try:
26
- # Read the uploaded image
27
- image_data = await image.read()
28
- image = Image.open(io.BytesIO(image_data)).convert("RGB")
 
29
 
30
- # Perform OCR
31
- pixel_values = processor(images=image, return_tensors="pt").pixel_values
 
 
 
 
32
  generated_ids = model.generate(pixel_values)
33
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
 
 
 
 
34
 
35
- # Return extracted text as JSON
36
- return JSONResponse(content={"extracted_text": generated_text})
37
 
38
- except Exception as e:
39
- return JSONResponse(content={"error": str(e)}, status_code=500)
40
- else:
41
- # Streamlit UI for manual testing
42
- st.title("OCR API Service")
43
 
44
- uploaded_file = st.file_uploader("Upload an Image", type=["jpg", "jpeg", "png"])
 
45
 
46
- if uploaded_file is not None:
47
- try:
48
- # Load and display the uploaded image
49
- image = Image.open(uploaded_file).convert("RGB")
50
- st.image(image, caption="Uploaded Image", use_column_width=True)
51
 
52
- # Perform OCR
53
- pixel_values = processor(images=image, return_tensors="pt").pixel_values
 
 
 
 
 
 
 
 
54
  generated_ids = model.generate(pixel_values)
55
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
 
 
 
 
 
56
 
57
- # Display extracted text
58
- st.subheader("Extracted Text:")
59
- st.text(generated_text)
60
 
61
- except Exception as e:
62
- st.error(f"An error occurred: {e}")
63
- else:
64
- st.info("Please upload an image to start the OCR process.")
 
1
  import streamlit as st
2
  from transformers import TrOCRProcessor, VisionEncoderDecoderModel
3
  from PIL import Image
4
+ from fastapi import FastAPI, UploadFile, File
5
+ from fastapi.responses import JSONResponse
6
+ import uvicorn
7
+ import numpy as np
8
+ import cv2
9
+ import io
10
+
11
+ # Create a FastAPI app instance
12
+ app = FastAPI()
13
 
14
  # Load the model and processor
15
  @st.cache_resource
 
20
 
21
  processor, model = load_model()
22
 
 
 
 
 
 
23
 
24
# Segment a page image into horizontal text-line crops for line-level OCR.
def detect_lines(image, min_height=20, min_width=100):
    """Return a top-to-bottom list of numpy line crops found in *image*.

    image: PIL RGB image. Contours smaller than min_height x min_width are
    discarded so stray marks and noise do not become "lines".
    """
    page = np.array(image)
    # Binarize with Otsu on the grayscale page, inverted so ink is white.
    gray = cv2.cvtColor(page, cv2.COLOR_RGB2GRAY)
    _, ink = cv2.threshold(gray, 128, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
    # Dilate so the characters of one line fuse into a single connected blob.
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5))
    blobs = cv2.dilate(ink, kernel, iterations=1)
    contours, _ = cv2.findContours(blobs, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    boxes = [cv2.boundingRect(c) for c in contours]
    boxes.sort(key=lambda box: box[1])  # reading order: top edge (y) ascending
    crops = []
    for x, y, w, h in boxes:
        if h >= min_height and w >= min_width:
            crops.append(page[y:y + h, x:x + w])
    return crops
35
+
36
 
37
# FastAPI endpoint: accept an uploaded image and OCR it line by line.
@app.post("/process_image")
async def process_image(image: UploadFile = File(...)):
    """OCR the uploaded image and return {"extracted_text": ...} as JSON.

    The page is split into line crops with detect_lines(); each crop is run
    through the TrOCR model separately and the decoded lines are joined with
    newlines. Any failure is reported as a 500 JSON body {"error": ...}.
    """
    try:
        # Read the raw upload bytes and decode them into an RGB image.
        # (Use a distinct name instead of shadowing the `image` parameter.)
        image_data = await image.read()
        pil_image = Image.open(io.BytesIO(image_data)).convert("RGB")

        # Segment into line crops. If segmentation finds nothing (short
        # single line, low-contrast scan), fall back to OCR-ing the whole
        # image so the endpoint still returns useful text instead of "".
        line_images = detect_lines(pil_image, min_height=30, min_width=100)
        if not line_images:
            line_images = [np.array(pil_image)]

        pieces = []
        for line_img in line_images:
            line_pil = Image.fromarray(line_img)
            pixel_values = processor(images=line_pil, return_tensors="pt").pixel_values
            generated_ids = model.generate(pixel_values)
            pieces.append(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])

        # Return extracted text as JSON
        return JSONResponse(content={"extracted_text": "\n".join(pieces).strip()})

    except Exception as e:
        # Broad catch is deliberate at this API boundary: report the failure
        # to the client rather than crash the worker.
        return JSONResponse(content={"error": str(e)}, status_code=500)
60
 
 
 
 
 
 
61
 
62
# Streamlit UI for exercising the OCR pipeline manually in a browser.
st.title("OCR API Service with Multiline Support")

uploaded_file = st.file_uploader("Upload an Image", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
    try:
        image = Image.open(uploaded_file).convert("RGB")
        # use_container_width replaces the deprecated use_column_width flag.
        st.image(image, caption="Uploaded Image", use_container_width=True)

        # Detect lines in the image
        st.write("Detecting lines...")
        line_images = detect_lines(image, min_height=30, min_width=100)
        st.write(f"Detected {len(line_images)} lines in the image.")
        # Fall back to whole-image OCR when segmentation finds no lines,
        # so short or low-contrast images still produce output.
        if not line_images:
            line_images = [np.array(image)]

        # Run TrOCR on each detected line crop, one output line per crop.
        extracted_text = ""
        for idx, line_img in enumerate(line_images):
            line_pil = Image.fromarray(line_img)
            pixel_values = processor(images=line_pil, return_tensors="pt").pixel_values
            generated_ids = model.generate(pixel_values)
            generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
            extracted_text += f"{generated_text}\n"

        # Display extracted text
        st.subheader("Extracted Text:")
        st.text_area("Output Text", extracted_text.strip(), height=300)

    except Exception as e:
        st.error(f"An error occurred: {e}")
 
91
 
92
# Run the FastAPI app when this file is executed directly (python app.py).
# NOTE(review): under `streamlit run app.py` the script also executes with
# __name__ == "__main__", so this uvicorn.run call would block inside the
# Streamlit process — confirm the intended deployment starts the API and
# the UI as separate processes.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)