badru commited on
Commit
364aa86
·
verified ·
1 Parent(s): 6acf59a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -0
app.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, request, jsonify
2
+ from transformers import TrOCRProcessor, VisionEncoderDecoderModel
3
+ from PIL import Image
4
+ import cv2
5
+ import numpy as np
6
+
7
+ # Initialize Flask app
8
+ app = Flask(__name__)
9
+
10
+ # Load model and processor
11
+ processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
12
+ model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
13
+
14
+ def detect_lines(image, min_height=20, min_width=100):
15
+ # Convert PIL image to NumPy array
16
+ image_np = np.array(image)
17
+ gray = cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY)
18
+ _, binary = cv2.threshold(gray, 128, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
19
+ kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5))
20
+ dilated = cv2.dilate(binary, kernel, iterations=1)
21
+ contours, _ = cv2.findContours(dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
22
+ bounding_boxes = sorted([cv2.boundingRect(c) for c in contours], key=lambda b: b[1])
23
+ line_images = [image_np[y:y+h, x:x+w] for x, y, w, h in bounding_boxes if h >= min_height and w >= min_width]
24
+ return line_images
25
+
26
+ @app.route("/process_image", methods=["POST"])
27
+ def process_image():
28
+ if 'image' not in request.files:
29
+ return jsonify({"error": "No image file provided"}), 400
30
+ try:
31
+ image_file = request.files['image']
32
+ image = Image.open(image_file).convert("RGB")
33
+ line_images = detect_lines(image)
34
+ extracted_text = ""
35
+ for idx, line_img in enumerate(line_images):
36
+ line_pil = Image.fromarray(line_img)
37
+ pixel_values = processor(images=line_pil, return_tensors="pt").pixel_values
38
+ generated_ids = model.generate(pixel_values)
39
+ generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
40
+ extracted_text += f"Line {idx + 1}: {generated_text}\n"
41
+ return jsonify({"extracted_text": extracted_text}), 200
42
+ except Exception as e:
43
+ return jsonify({"error": str(e)}), 500
44
+
45
+ if __name__ == "__main__":
46
+ app.run(host="0.0.0.0", port=7860)