OmarAbualrob commited on
Commit
f8797e4
·
verified ·
1 Parent(s): 29fc161

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -40
app.py CHANGED
@@ -2,83 +2,108 @@ from fastapi import FastAPI, File, UploadFile, HTTPException
2
  from fastapi.responses import JSONResponse
3
  from transformers import AutoProcessor, AutoModelForCausalLM
4
  from PIL import Image
 
5
  import io
6
  import logging
7
 
 
8
  logging.basicConfig(level=logging.INFO)
9
  logger = logging.getLogger(__name__)
10
- app = FastAPI(title="Florence-2 OCR API (CPU)", description="An API to extract text from images using the Florence-2-large model on CPU.")
11
 
12
- # --- Global Variables and Device Configuration ---
13
- device = "cpu" # Force CPU
14
- model = None
15
- processor = None
16
 
17
- # --- Model Loading Logic (at startup) ---
18
- @app.on_event("startup")
19
- async def startup_event():
20
- global model, processor
21
- try:
22
- logger.info(f"Using device: {device}")
23
- logger.info("Starting model loading process for CPU...")
24
-
25
- model_id = "microsoft/Florence-2-large"
26
-
27
- # Load the model in full precision for CPU
28
- model = AutoModelForCausalLM.from_pretrained(
29
- model_id,
30
- trust_remote_code=True,
31
- )
32
-
33
- processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
34
-
35
- logger.info("Model and processor loaded successfully on CPU.")
36
-
37
- except Exception as e:
38
- logger.error(f"FATAL: An error occurred during model loading: {e}", exc_info=True)
39
 
40
- # --- Define the OCR Task Function (CPU version) ---
41
  def run_ocr(image: Image.Image) -> str:
 
 
 
42
  if model is None or processor is None:
43
- raise RuntimeError("Model is not available. Check startup logs for loading errors.")
 
 
44
  if image.mode != "RGB":
45
  image = image.convert("RGB")
 
 
46
  prompt = "<OCR>"
 
 
47
  inputs = processor(text=prompt, images=image, return_tensors="pt")
48
 
49
- # Generate on CPU (no .to(device) or dtype changes needed)
 
50
  generated_ids = model.generate(
51
  input_ids=inputs["input_ids"],
52
  pixel_values=inputs["pixel_values"],
53
- max_new_tokens=4096,
54
- do_sample=False,
55
  num_beams=3
56
  )
57
 
 
58
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
 
 
 
59
  parsed_text = processor.post_process_generation(generated_text, task="<OCR>", image_size=(image.width, image.height))
 
60
  return parsed_text.get("<OCR>", "Error: Could not parse OCR output.")
61
 
62
 
63
- # --- API Endpoints ---
64
- # (Your @app.post and @app.get endpoints remain exactly the same)
65
  @app.post("/ocr", summary="Extract Text from Image")
66
  async def perform_ocr(file: UploadFile = File(..., description="Image file to perform OCR on.")):
 
 
 
 
67
  if model is None:
68
- raise HTTPException(status_code=503, detail="Model is not loaded or unavailable. Please check the server logs.")
 
 
69
  if not file.content_type.startswith("image/"):
70
  raise HTTPException(status_code=400, detail="Invalid file type. Please upload an image.")
 
71
  try:
 
72
  contents = await file.read()
73
  image = Image.open(io.BytesIO(contents))
74
- logger.info(f"Running OCR on uploaded file: {file.filename}")
 
 
75
  extracted_text = run_ocr(image)
76
  logger.info("OCR completed successfully.")
77
- return JSONResponse(content={"filename": file.filename, "text": extracted_text})
 
 
 
 
 
78
  except Exception as e:
79
- logger.error(f"An error occurred during OCR processing for {file.filename}: {e}", exc_info=True)
80
- raise HTTPException(status_code=500, detail=f"An internal server error occurred: {str(e)}")
81
 
82
  @app.get("/", summary="Health Check")
83
  def read_root():
84
- return {"status": "ok", "model_loaded": model is not None, "device": device}
 
 
 
 
2
  from fastapi.responses import JSONResponse
3
  from transformers import AutoProcessor, AutoModelForCausalLM
4
  from PIL import Image
5
+ import torch
6
  import io
7
  import logging
8
 
9
+ # Set up logging
10
  logging.basicConfig(level=logging.INFO)
11
  logger = logging.getLogger(__name__)
 
12
 
13
+ # --- 1. Initialize FastAPI App ---
14
+ app = FastAPI(title="Mixed-Content OCR API", description="An API to extract text from images containing both printed and handwritten text.")
 
 
15
 
16
+ # --- 2. Load the Model and Processor (at startup) ---
17
+ # This is a critical step. We load the model only once when the app starts.
18
+ # This prevents reloading the model on every API call, which would be very slow.
19
+ try:
20
+ logger.info("Loading model and processor...")
21
+ # Use the large model for better accuracy
22
+ model_id = "microsoft/Florence-2-large"
23
+ # NOTE: We need to trust remote code for Florence-2
24
+ model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
25
+ processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
26
+ logger.info("Model and processor loaded successfully.")
27
+ except Exception as e:
28
+ logger.error(f"Error loading model: {e}")
29
+ # If the model fails to load, the API is not usable. We can't proceed.
30
+ model = None
31
+ processor = None
 
 
 
 
 
 
32
 
33
+ # --- 3. Define the OCR Task Function ---
34
  def run_ocr(image: Image.Image) -> str:
35
+ """
36
+ Performs OCR on a given PIL Image using the Florence-2 model.
37
+ """
38
  if model is None or processor is None:
39
+ raise RuntimeError("Model is not available. Check logs for loading errors.")
40
+
41
+ # Ensure image is in RGB format
42
  if image.mode != "RGB":
43
  image = image.convert("RGB")
44
+
45
+ # Define the task prompt
46
  prompt = "<OCR>"
47
+
48
+ # Preprocess the image and prompt
49
  inputs = processor(text=prompt, images=image, return_tensors="pt")
50
 
51
+ # Generate text from the image
52
+ # Note: max_new_tokens can be adjusted based on expected text length
53
  generated_ids = model.generate(
54
  input_ids=inputs["input_ids"],
55
  pixel_values=inputs["pixel_values"],
56
+ max_new_tokens=4096, # Increased token limit for long documents
57
+ do_sample=False, # Use greedy decoding for deterministic output
58
  num_beams=3
59
  )
60
 
61
+ # Decode the generated IDs to a string
62
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
63
+
64
+ # Post-process the output to get the clean text
65
+ # The model's output for OCR is typically in the format: <OCR>extracted_text</s>
66
  parsed_text = processor.post_process_generation(generated_text, task="<OCR>", image_size=(image.width, image.height))
67
+
68
  return parsed_text.get("<OCR>", "Error: Could not parse OCR output.")
69
 
70
 
71
+ # --- 4. Create the API Endpoint ---
 
72
  @app.post("/ocr", summary="Extract Text from Image")
73
  async def perform_ocr(file: UploadFile = File(..., description="Image file to perform OCR on.")):
74
+ """
75
+ Takes an image file, extracts both printed and handwritten text,
76
+ and returns it as a JSON object.
77
+ """
78
  if model is None:
79
+ raise HTTPException(status_code=503, detail="Model is not loaded or unavailable.")
80
+
81
+ # Validate file type
82
  if not file.content_type.startswith("image/"):
83
  raise HTTPException(status_code=400, detail="Invalid file type. Please upload an image.")
84
+
85
  try:
86
+ # Read the image content from the uploaded file
87
  contents = await file.read()
88
  image = Image.open(io.BytesIO(contents))
89
+
90
+ # Run the OCR task
91
+ logger.info("Running OCR on the uploaded image...")
92
  extracted_text = run_ocr(image)
93
  logger.info("OCR completed successfully.")
94
+
95
+ # Return the result
96
+ return JSONResponse(
97
+ content={"filename": file.filename, "text": extracted_text}
98
+ )
99
+
100
  except Exception as e:
101
+ logger.error(f"An error occurred during OCR processing: {e}")
102
+ raise HTTPException(status_code=500, detail=f"An internal error occurred: {str(e)}")
103
 
104
  @app.get("/", summary="Health Check")
105
  def read_root():
106
+ """
107
+ A simple health check endpoint to confirm the API is running.
108
+ """
109
+ return {"status": "ok", "model_loaded": model is not None}