File size: 4,259 Bytes
523b7a6
 
545491d
 
523b7a6
545491d
523b7a6
 
 
 
545491d
523b7a6
 
 
 
545491d
523b7a6
 
 
 
621c44f
523b7a6
 
 
621c44f
523b7a6
621c44f
 
2a409bc
523b7a6
 
 
 
 
 
 
 
 
 
621c44f
523b7a6
 
 
 
 
 
 
621c44f
523b7a6
 
 
 
 
f8797e4
621c44f
545491d
 
f8797e4
621c44f
523b7a6
 
 
 
 
 
621c44f
545491d
621c44f
 
523b7a6
 
 
545491d
523b7a6
 
545491d
523b7a6
 
 
 
 
621c44f
523b7a6
545491d
 
523b7a6
 
 
 
 
 
 
621c44f
523b7a6
621c44f
523b7a6
545491d
 
 
523b7a6
 
545491d
523b7a6
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import io
import torch
from fastapi import FastAPI, File, UploadFile, HTTPException
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

# --- 1. SCRIPT SETUP ---
# Set up device (use GPU if available, otherwise CPU)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"--- Running on {DEVICE} ---")

# Define model and processor IDs from Hugging Face Hub
MODEL_ID = "microsoft/Florence-2-large"
# For better performance, you can use the float16 version if your hardware supports it
# MODEL_ID = "microsoft/Florence-2-large-ft"

# Use float16 only on GPU: many CPU ops have no float16 kernels, so loading
# the model in half precision on CPU would make inference fail or crawl.
TORCH_DTYPE = torch.float16 if DEVICE.type == "cuda" else torch.float32

# --- 2. LOAD MODEL AND PROCESSOR ---
# Load the model and processor from Hugging Face
# trust_remote_code=True is required for Florence-2 (it ships custom model code)
try:
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID, trust_remote_code=True, torch_dtype=TORCH_DTYPE
    ).to(DEVICE)
    processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
    print("--- Model and processor loaded successfully ---")
except Exception as e:
    # Keep the module importable even when loading fails (no weights cached,
    # no network, OOM, ...); endpoints guard on model/processor being None.
    print(f"--- Error loading model: {e} ---")
    model = None
    processor = None

# --- 3. FASTAPI APP INITIALIZATION ---
# Module-level app instance so an ASGI server can serve it as `module:app`.
# The title/description/version below feed the auto-generated /docs UI.
app = FastAPI(
    title="Florence-2 OCR API",
    description="An API for extracting text from images using Microsoft's Florence-2-large model. "
                "Handles both printed and handwritten text.",
    version="1.0.0"
)

# --- 4. HELPER FUNCTION ---
def run_florence2_ocr(image: Image.Image) -> str:
    """
    Runs the Florence-2 model to perform OCR on a given image.

    Args:
        image (Image.Image): The input image in PIL format.

    Returns:
        str: The extracted text.

    Raises:
        HTTPException: 503 if the model/processor failed to load at startup.
    """
    if not model or not processor:
        raise HTTPException(status_code=503, detail="Model is not available. Please check server logs.")

    # The task prompt for OCR
    task_prompt = "<OCR>"

    # Florence-2's processor expects 3-channel RGB input
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Preprocess the image and prompt
    inputs = processor(text=task_prompt, images=image, return_tensors="pt").to(DEVICE)
    # Match the vision input to the model's dtype (float16 on GPU). Cast
    # ONLY pixel_values: casting the whole batch could corrupt the integer
    # input_ids on transformers versions whose BatchFeature.to() is not
    # restricted to floating-point tensors. A float32->float32 cast is a no-op.
    inputs["pixel_values"] = inputs["pixel_values"].to(model.dtype)

    # Generate text from the image. inference_mode avoids building autograd
    # state, cutting memory use during generation.
    with torch.inference_mode():
        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=2048,  # generous limit for dense documents
            num_beams=3,
            do_sample=False,  # greedy/beam decoding for deterministic results
        )

    # Keep special tokens: post_process_generation needs them to locate the
    # task span in the raw output (typically "<OCR>extracted_text</s>").
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]

    # post_process_generation strips the prompt/EOS and returns a dict keyed
    # by the task prompt, e.g. {"<OCR>": "extracted text"}.
    parsed_text = processor.post_process_generation(
        generated_text, task=task_prompt, image_size=(image.width, image.height)
    )

    return parsed_text.get(task_prompt, "Error: Could not parse OCR output.")


# --- 5. API ENDPOINTS ---
@app.get("/", summary="Root Endpoint", description="Returns a welcome message.")
def read_root():
    return {"message": "Welcome to the Florence-2 OCR API. Go to /docs for usage."}

@app.post("/ocr", summary="Extract Text from Image", description="Upload an image file to extract text. Supports both computer and handwritten text.")
async def extract_text_from_image(file: UploadFile = File(..., description="Image file to process.")):
    """
    Endpoint to perform OCR on an uploaded image.

    Returns:
        dict: {"filename": <uploaded name>, "extracted_text": <OCR result>}.

    Raises:
        HTTPException: 400 for unreadable image data, 503 if the model is
            unavailable, 500 for unexpected inference failures.
    """
    # Read image content from the uploaded file
    try:
        contents = await file.read()
        image = Image.open(io.BytesIO(contents))
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid image file. Could not open image.")

    # Run the OCR model
    try:
        extracted_text = run_florence2_ocr(image)
        return {"filename": file.filename, "extracted_text": extracted_text}
    except HTTPException:
        # Re-raise deliberate HTTP errors (e.g. the 503 raised when the model
        # is unavailable) instead of masking them as generic 500s below.
        raise
    except Exception as e:
        print(f"Error during model inference: {e}")
        raise HTTPException(status_code=500, detail=f"An error occurred during processing: {str(e)}")