File size: 2,858 Bytes
0e40a69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4c851fe
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import io

import cv2
import numpy as np
from fastapi import FastAPI, File, HTTPException, UploadFile
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

app = FastAPI()

# Load the TrOCR handwritten-text model and its processor once at import
# time so every request reuses the same instances (loading from the hub
# is slow and memory-heavy; doing it per-request would be prohibitive).
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")

def detect_lines(image, min_height=20, min_width=100):
    """Segment a PIL image into per-line crops of text, top to bottom.

    Binarizes the image (Otsu, inverted), dilates to fuse characters on
    the same line into one blob, then takes the external contours of the
    blobs as line candidates.

    Args:
        image: PIL image (assumed RGB — TODO confirm callers always convert).
        min_height: Minimum bounding-box height to keep a candidate.
        min_width: Minimum bounding-box width to keep a candidate.

    Returns:
        List of NumPy arrays, one crop per detected line, ordered by
        vertical position (topmost first).
    """
    img = np.array(image)
    gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)

    # Inverted binary + Otsu: text becomes white foreground on black.
    _, binary = cv2.threshold(gray, 128, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)

    # A small rectangular dilation merges nearby glyphs into line blobs.
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5))
    dilated = cv2.dilate(binary, kernel, iterations=1)

    contours, _ = cv2.findContours(dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # Sort candidate boxes top-to-bottom by their y-coordinate, then keep
    # only the ones large enough to plausibly be a text line.
    boxes = sorted((cv2.boundingRect(c) for c in contours), key=lambda box: box[1])
    return [
        img[y:y + h, x:x + w]
        for x, y, w, h in boxes
        if h >= min_height and w >= min_width
    ]

@app.post("/process_image/")
async def process_image(file: UploadFile = File(...)):
    """Run handwritten-text OCR on an uploaded image.

    Splits the image into text lines with ``detect_lines`` and runs TrOCR
    on each line, joining the results with newlines.

    Args:
        file: Uploaded image file (any format PIL can open).

    Returns:
        ``{"extracted_text": str}`` — one OCR'd line per newline; empty
        string when no lines are detected.

    Raises:
        HTTPException: 400 if the upload is not a decodable image,
            500 if OCR itself fails.
    """
    try:
        # Read the upload into memory and normalize to RGB for OpenCV/TrOCR.
        image = Image.open(io.BytesIO(await file.read())).convert("RGB")
    except Exception as e:
        # A bad upload (not an image, truncated, etc.) is a client error:
        # surface it as 400 instead of a 200 response with an error payload.
        raise HTTPException(status_code=400, detail=f"Invalid image file: {e}") from e

    try:
        line_images = detect_lines(image, min_height=30, min_width=100)

        # OCR each detected line; collect results and join once at the end
        # (avoids quadratic string concatenation and a trailing newline).
        lines = []
        for line_img in line_images:
            line_pil = Image.fromarray(line_img)
            pixel_values = processor(images=line_pil, return_tensors="pt").pixel_values
            generated_ids = model.generate(pixel_values)
            lines.append(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])
    except Exception as e:
        # Failures past this point are server-side (model/segmentation).
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {"extracted_text": "\n".join(lines)}