# mmsaiimage / app.py
# Author: badru
# Commit: "Create app.py" (364aa86, verified)
from flask import Flask, request, jsonify
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from PIL import Image
import cv2
import numpy as np
# Flask application object that the HTTP route in this module is attached to.
app = Flask(__name__)

# Both TrOCR components come from the same pretrained checkpoint; keep the id
# in one place so the processor (used for image preprocessing and for decoding
# generated ids) and the encoder-decoder model can never drift apart.
# They are loaded once at import time and reused by every request.
_TROCR_CHECKPOINT = "microsoft/trocr-base-handwritten"
processor = TrOCRProcessor.from_pretrained(_TROCR_CHECKPOINT)
model = VisionEncoderDecoderModel.from_pretrained(_TROCR_CHECKPOINT)
def detect_lines(image, min_height=20, min_width=100, kernel_size=(5, 5)):
    """Crop candidate text regions out of *image*, in top-to-bottom order.

    The image is converted to grayscale and binarised with Otsu's method
    (inverted, so dark ink becomes foreground). It is then dilated with a
    rectangular kernel so nearby strokes merge into blobs, and the bounding
    box of every outermost contour is cropped out.

    NOTE(review): with the default 5x5 kernel the blobs tend to be word-sized
    rather than whole text lines. A wide, short kernel (e.g. ``(40, 5)``)
    merges horizontally adjacent words into lines. Tune it for your input.

    Args:
        image: PIL image or array-like, either RGB or single-channel. It is
            assumed to be 8-bit, which Otsu thresholding requires.
        min_height: Boxes shorter than this many pixels are discarded.
        min_width: Boxes narrower than this many pixels are discarded.
        kernel_size: ``(width, height)`` of the rectangular dilation kernel.

    Returns:
        List of NumPy array crops (views into the converted image), ordered
        by top edge and then by left edge.
    """
    image_np = np.array(image)
    # Single-channel input is already grayscale; cvtColor(RGB2GRAY) rejects it.
    if image_np.ndim == 2:
        gray = image_np
    else:
        gray = cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY)

    # With THRESH_OTSU the explicit threshold value is ignored and computed.
    _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV | cv2.THRESH_OTSU)
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, tuple(kernel_size))
    dilated = cv2.dilate(binary, kernel, iterations=1)
    contours, _ = cv2.findContours(dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    boxes = [
        (x, y, w, h)
        for x, y, w, h in (cv2.boundingRect(c) for c in contours)
        if h >= min_height and w >= min_width
    ]
    # Reading order: top edge first. The left edge breaks ties so boxes on the
    # same row come out left-to-right instead of in arbitrary contour order.
    boxes.sort(key=lambda b: (b[1], b[0]))
    return [image_np[y:y + h, x:x + w] for x, y, w, h in boxes]
@app.route("/process_image", methods=["POST"])
def process_image():
    """Transcribe handwritten text from an uploaded image.

    Expects a multipart upload with the file in the ``image`` form field. The
    image is split into regions by ``detect_lines``, and each region is run
    through the TrOCR processor and model on its own.

    Returns:
        ``({"extracted_text": str}, 200)``: the text holds one
        ``"Line N: <text>"`` entry per region, each ending in a newline.
        ``({"error": str}, 400)``: no file was sent, or it could not be
        decoded as an image.
        ``({"error": str}, 500)``: any other failure during OCR.
    """
    if 'image' not in request.files:
        return jsonify({"error": "No image file provided"}), 400

    try:
        image = Image.open(request.files['image']).convert("RGB")
    except (OSError, Image.DecompressionBombError) as exc:
        # Unidentified/truncated files (OSError subclasses) and oversized
        # images are bad client input, not server faults.
        return jsonify({"error": f"Invalid image file: {exc}"}), 400

    try:
        lines = []
        for idx, line_img in enumerate(detect_lines(image), start=1):
            pixel_values = processor(
                images=Image.fromarray(line_img), return_tensors="pt"
            ).pixel_values
            # NOTE(review): no length limit is passed, so generate() uses the
            # checkpoint's generation config. If that falls back to the library
            # default max_length, long lines get truncated. Consider
            # max_new_tokens.
            generated_ids = model.generate(pixel_values)
            text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
            lines.append(f"Line {idx}: {text}\n")
        return jsonify({"extracted_text": "".join(lines)}), 200
    except Exception as e:
        # Top-level request boundary: record the traceback server-side and
        # keep the original error payload for API compatibility.
        app.logger.exception("OCR processing failed")
        return jsonify({"error": str(e)}), 500
if __name__ == "__main__":
    # Script entry point: start Flask's built-in server on every interface.
    listen_host, listen_port = "0.0.0.0", 7860
    app.run(port=listen_port, host=listen_host)