File size: 3,133 Bytes
8d9cf6c
a4ad892
 
c816f98
8d9cf6c
 
 
a4ad892
 
8d9cf6c
 
a4ad892
8d9cf6c
a4ad892
8d9cf6c
c816f98
8d9cf6c
 
 
 
c816f98
8d9cf6c
 
c816f98
8d9cf6c
 
c816f98
8d9cf6c
 
c816f98
 
a4ad892
8d9cf6c
 
c816f98
8d9cf6c
 
 
a4ad892
8d9cf6c
 
 
 
 
a4ad892
8d9cf6c
 
 
 
 
a4ad892
8d9cf6c
a4ad892
8d9cf6c
 
 
 
 
c816f98
8d9cf6c
 
 
a4ad892
c816f98
 
 
 
 
 
8d9cf6c
c816f98
8d9cf6c
 
c816f98
8d9cf6c
 
a4ad892
8d9cf6c
 
 
 
c816f98
8d9cf6c
 
a4ad892
c816f98
8d9cf6c
a4ad892
c816f98
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
from io import BytesIO

import cv2
import numpy as np
import uvicorn
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.concurrency import run_in_threadpool
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

# Hugging Face checkpoint shared by the processor and the model.
_CHECKPOINT = "microsoft/trocr-base-handwritten"

# Both objects are created once, at import time, so every request reuses them.
# The processor handles image preprocessing and decoding of generated token ids.
processor = TrOCRProcessor.from_pretrained(_CHECKPOINT)
model = VisionEncoderDecoderModel.from_pretrained(_CHECKPOINT)

# ASGI application that the OCR endpoint is registered on.
app = FastAPI()

# Helper function to preprocess the image and detect lines
def detect_lines(image, min_height=20, min_width=100):
    """
    Detects lines of text in the given image.
    """
    # Convert the PIL image to a NumPy array
    image_np = np.array(image)

    # Convert to grayscale
    gray = cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY)

    # Apply binary thresholding
    _, binary = cv2.threshold(gray, 128, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)

    # Dilate to merge nearby text
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5))
    dilated = cv2.dilate(binary, kernel, iterations=1)

    # Find contours
    contours, _ = cv2.findContours(dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # Sort contours top-to-bottom
    bounding_boxes = [cv2.boundingRect(c) for c in contours]
    bounding_boxes = sorted(bounding_boxes, key=lambda b: b[1])  # Sort by y-coordinate

    # Filter out small contours and merge nearby ones
    filtered_boxes = []
    for x, y, w, h in bounding_boxes:
        if h >= min_height and w >= min_width:  # Filter small boxes
            filtered_boxes.append((x, y, w, h))

    # Extract individual lines as images
    line_images = []
    for (x, y, w, h) in filtered_boxes:
        line = image_np[y:y + h, x:x + w]
        line_images.append(line)

    return line_images

@app.post("/process_image")
async def process_image(file: UploadFile = File(...)):
    """
    API endpoint to process the uploaded image and extract multiline text.
    """
    try:
        # Read the uploaded image
        contents = await file.read()
        image = Image.open(BytesIO(contents)).convert("RGB")

        # Detect lines in the image
        line_images = detect_lines(image, min_height=30, min_width=100)

        # Perform OCR on each detected line
        extracted_text = ""
        for idx, line_img in enumerate(line_images):
            # Convert the line image to PIL format
            line_pil = Image.fromarray(line_img)

            # Prepare the image for OCR
            pixel_values = processor(images=line_pil, return_tensors="pt").pixel_values

            # Generate text from the line image
            generated_ids = model.generate(pixel_values)
            line_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

            # Append the extracted text
            extracted_text += f"{line_text}\n"

        # Return the extracted text as a JSON response
        return {"extracted_text": extracted_text}

    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Error processing image: {str(e)}")

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)