OmarAbualrob commited on
Commit
545491d
·
verified ·
1 Parent(s): d62fd73

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +109 -0
app.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, File, UploadFile, HTTPException
2
+ from fastapi.responses import JSONResponse
3
+ from transformers import AutoProcessor, AutoModelForCausalLM
4
+ from PIL import Image
5
+ import torch
6
+ import io
7
+ import logging
8
+
9
+ # Set up logging
10
+ logging.basicConfig(level=logging.INFO)
11
+ logger = logging.getLogger(__name__)
12
+
13
+ # --- 1. Initialize FastAPI App ---
14
+ app = FastAPI(title="Mixed-Content OCR API", description="An API to extract text from images containing both printed and handwritten text.")
15
+
16
+ # --- 2. Load the Model and Processor (at startup) ---
17
+ # This is a critical step. We load the model only once when the app starts.
18
+ # This prevents reloading the model on every API call, which would be very slow.
19
+ try:
20
+ logger.info("Loading model and processor...")
21
+ # Use the large model for better accuracy
22
+ model_id = "microsoft/Florence-2-large"
23
+ # NOTE: We need to trust remote code for Florence-2
24
+ model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
25
+ processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
26
+ logger.info("Model and processor loaded successfully.")
27
+ except Exception as e:
28
+ logger.error(f"Error loading model: {e}")
29
+ # If the model fails to load, the API is not usable. We can't proceed.
30
+ model = None
31
+ processor = None
32
+
33
+ # --- 3. Define the OCR Task Function ---
34
+ def run_ocr(image: Image.Image) -> str:
35
+ """
36
+ Performs OCR on a given PIL Image using the Florence-2 model.
37
+ """
38
+ if model is None or processor is None:
39
+ raise RuntimeError("Model is not available. Check logs for loading errors.")
40
+
41
+ # Ensure image is in RGB format
42
+ if image.mode != "RGB":
43
+ image = image.convert("RGB")
44
+
45
+ # Define the task prompt
46
+ prompt = "<OCR>"
47
+
48
+ # Preprocess the image and prompt
49
+ inputs = processor(text=prompt, images=image, return_tensors="pt")
50
+
51
+ # Generate text from the image
52
+ # Note: max_new_tokens can be adjusted based on expected text length
53
+ generated_ids = model.generate(
54
+ input_ids=inputs["input_ids"],
55
+ pixel_values=inputs["pixel_values"],
56
+ max_new_tokens=4096, # Increased token limit for long documents
57
+ do_sample=False, # Use greedy decoding for deterministic output
58
+ num_beams=3 # Use beam search for potentially better results
59
+ )
60
+
61
+ # Decode the generated IDs to a string
62
+ generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
63
+
64
+ # Post-process the output to get the clean text
65
+ # The model's output for OCR is typically in the format: <OCR>extracted_text</s>
66
+ parsed_text = processor.post_process_generation(generated_text, task="<OCR>", image_size=(image.width, image.height))
67
+
68
+ return parsed_text.get("<OCR>", "Error: Could not parse OCR output.")
69
+
70
+
71
+ # --- 4. Create the API Endpoint ---
72
+ @app.post("/ocr", summary="Extract Text from Image")
73
+ async def perform_ocr(file: UploadFile = File(..., description="Image file to perform OCR on.")):
74
+ """
75
+ Takes an image file, extracts both printed and handwritten text,
76
+ and returns it as a JSON object.
77
+ """
78
+ if model is None:
79
+ raise HTTPException(status_code=503, detail="Model is not loaded or unavailable.")
80
+
81
+ # Validate file type
82
+ if not file.content_type.startswith("image/"):
83
+ raise HTTPException(status_code=400, detail="Invalid file type. Please upload an image.")
84
+
85
+ try:
86
+ # Read the image content from the uploaded file
87
+ contents = await file.read()
88
+ image = Image.open(io.BytesIO(contents))
89
+
90
+ # Run the OCR task
91
+ logger.info("Running OCR on the uploaded image...")
92
+ extracted_text = run_ocr(image)
93
+ logger.info("OCR completed successfully.")
94
+
95
+ # Return the result
96
+ return JSONResponse(
97
+ content={"filename": file.filename, "text": extracted_text}
98
+ )
99
+
100
+ except Exception as e:
101
+ logger.error(f"An error occurred during OCR processing: {e}")
102
+ raise HTTPException(status_code=500, detail=f"An internal error occurred: {str(e)}")
103
+
104
+ @app.get("/", summary="Health Check")
105
+ def read_root():
106
+ """
107
+ A simple health check endpoint to confirm the API is running.
108
+ """
109
+ return {"status": "ok", "model_loaded": model is not None}