badru commited on
Commit
0e40a69
·
verified ·
1 Parent(s): 1d12a1d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +79 -0
app.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import io

import cv2
import numpy as np
import torch
from fastapi import FastAPI, File, UploadFile
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

8
+ app = FastAPI()
9
+
10
+ # Load the model and processor
11
+ processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
12
+ model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
13
+
# Helper: segment an image into horizontal text-line crops for OCR.
def detect_lines(image, min_height=20, min_width=100):
    """Detect text lines in an image and return them as separate crops.

    Args:
        image: PIL image or array-like. RGB is expected; a 2-D grayscale
            array is also accepted (conversion is skipped in that case).
        min_height: minimum bounding-box height in pixels to keep.
        min_width: minimum bounding-box width in pixels to keep.

    Returns:
        List of NumPy arrays, one per detected line, ordered top-to-bottom.
        Crops are taken from the original (un-binarized) image.
    """
    # Convert the PIL image to a NumPy array.
    image_np = np.array(image)

    # Convert to grayscale; a 2-D input is already single-channel and
    # would make cvtColor(RGB2GRAY) raise, so skip it in that case.
    if image_np.ndim == 2:
        gray = image_np
    else:
        gray = cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY)

    # Binarize with Otsu's threshold; INV makes ink white on black,
    # which is what dilation and contour detection expect.
    _, binary = cv2.threshold(gray, 128, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)

    # Dilate so characters within one line connect into a single blob.
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5))
    dilated = cv2.dilate(binary, kernel, iterations=1)

    # One external contour per (approximate) text line.
    contours, _ = cv2.findContours(dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # Bounding boxes sorted top-to-bottom by y-coordinate.
    bounding_boxes = sorted((cv2.boundingRect(c) for c in contours), key=lambda b: b[1])

    # Drop boxes too small to be a real text line (noise, stray marks).
    # NOTE: nearby boxes are only filtered, not merged.
    filtered_boxes = [
        (x, y, w, h)
        for x, y, w, h in bounding_boxes
        if h >= min_height and w >= min_width
    ]

    # Crop each surviving box out of the original image.
    return [image_np[y:y + h, x:x + w] for x, y, w, h in filtered_boxes]
@app.post("/process_image/")
async def process_image(file: UploadFile = File(...)):
    """Run line-level handwritten OCR on an uploaded image.

    Reads the uploaded file, splits it into text lines with
    ``detect_lines``, runs TrOCR on each line, and returns the
    recognized text joined with newlines.

    Returns:
        ``{"extracted_text": str}`` on success, or ``{"error": str}``
        if anything fails (same response contract as before).
    """
    try:
        # Read the upload into memory and normalize to RGB so the
        # downstream grayscale conversion sees a known channel layout.
        image = Image.open(io.BytesIO(await file.read())).convert("RGB")

        # Split the page into individual text lines (top-to-bottom).
        line_images = detect_lines(image, min_height=30, min_width=100)

        # OCR each line; collect into a list and join once instead of
        # repeated string concatenation (which is quadratic).
        lines = []
        for line_img in line_images:
            # Back to PIL for the processor's image pipeline.
            line_pil = Image.fromarray(line_img)

            # Processor handles resizing/normalization for TrOCR.
            pixel_values = processor(images=line_pil, return_tensors="pt").pixel_values

            # Inference only: no_grad avoids building an autograd graph,
            # cutting per-request memory use.
            with torch.no_grad():
                generated_ids = model.generate(pixel_values)
            generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
            lines.append(generated_text)

        # strip() matches the original behavior of trimming the
        # trailing newline (and any edge whitespace).
        return {"extracted_text": "\n".join(lines).strip()}

    except Exception as e:
        # NOTE(review): keeps the existing client contract of HTTP 200
        # with an error payload; consider raising HTTPException(500)
        # so clients can rely on status codes.
        return {"error": str(e)}