OmarAbualrob committed on
Commit
6d7b5d2
·
verified ·
1 Parent(s): 2768f24

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -58
app.py CHANGED
@@ -1,109 +1,108 @@
1
  from fastapi import FastAPI, File, UploadFile, HTTPException
2
  from fastapi.responses import JSONResponse
3
- from transformers import AutoProcessor, AutoModelForCausalLM
4
  from PIL import Image
5
  import torch
6
  import io
7
  import logging
8
 
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- 1. Initialize FastAPI App ---
app = FastAPI(
    title="Mixed-Content OCR API",
    description="An API to extract text from images containing both printed and handwritten text.",
)

# --- 2. Load the Model and Processor (at startup) ---
# The model and processor are loaded once, when this module is imported, so
# that individual API calls never have to reload them.
try:
    logger.info("Loading model and processor...")
    model_id = "microsoft/Florence-2-large"
    # Both loaders are called with trust_remote_code=True for this checkpoint.
    model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
    processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
    logger.info("Model and processor loaded successfully.")
except Exception as e:
    # Best-effort startup: the app still comes up, but with both globals set
    # to None so the endpoints can report the model as unavailable.
    # exc_info=True keeps the traceback in the log; the message alone is
    # rarely enough to diagnose a failed load.
    logger.error(f"Error loading model: {e}", exc_info=True)
    model = None
    processor = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
# --- 3. Define the OCR Task Function ---
def run_ocr(image: Image.Image) -> str:
    """Run the ``<OCR>`` task on a PIL image and return the parsed result.

    Args:
        image: Image to read. Converted to RGB first when it is in any
            other mode.

    Returns:
        The value stored under the ``"<OCR>"`` key of the processor's
        post-processed output, or a fixed error string when that key is
        missing.

    Raises:
        RuntimeError: If the module-level ``model`` or ``processor`` is None.
    """
    if model is None or processor is None:
        raise RuntimeError("Model is not available. Check logs for loading errors.")

    if image.mode != "RGB":
        image = image.convert("RGB")

    prompt = "<OCR>"
    inputs = processor(text=prompt, images=image, return_tensors="pt")

    # Inference only: no gradients are needed, so do not let autograd record
    # anything that runs inside generate().
    with torch.no_grad():
        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=4096,
            do_sample=False,
            num_beams=3,
        )

    # Special tokens are kept in the decoded string that is handed to
    # post_process_generation.
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed_text = processor.post_process_generation(
        generated_text,
        task="<OCR>",
        image_size=(image.width, image.height),
    )

    return parsed_text.get("<OCR>", "Error: Could not parse OCR output.")
69
 
70
 
# --- 4. Create the API Endpoint ---
@app.post("/ocr", summary="Extract Text from Image")
async def perform_ocr(file: UploadFile = File(..., description="Image file to perform OCR on.")):
    """Run OCR on an uploaded image and return the extracted text as JSON.

    Args:
        file: The uploaded file; its declared content type must start with
            ``image/``.

    Returns:
        A JSONResponse with the keys ``filename`` and ``text``.

    Raises:
        HTTPException: 503 when the model is not loaded, 400 when the upload
            is not declared as an image or cannot be identified as one, and
            500 for any other failure during processing.
    """
    # Local import: PIL is already a dependency of this file, only this
    # exception class is new here.
    from PIL import UnidentifiedImageError

    if model is None:
        raise HTTPException(status_code=503, detail="Model is not loaded or unavailable.")

    # content_type is None when the upload carries no Content-Type header;
    # treat that as "not an image" instead of failing on None.startswith.
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="Invalid file type. Please upload an image.")

    try:
        contents = await file.read()
        try:
            image = Image.open(io.BytesIO(contents))
        except UnidentifiedImageError as exc:
            # Bytes that are not a decodable image are a client error.
            raise HTTPException(
                status_code=400, detail="Invalid file type. Please upload an image."
            ) from exc

        logger.info("Running OCR on the uploaded image...")
        # NOTE(review): run_ocr is a synchronous call inside an async
        # endpoint, so it is not awaited and runs to completion here.
        extracted_text = run_ocr(image)
        logger.info("OCR completed successfully.")

        return JSONResponse(
            content={"filename": file.filename, "text": extracted_text}
        )
    except HTTPException:
        # Let the deliberate 400 above through instead of rewrapping it as 500.
        raise
    except Exception as e:
        logger.error(f"An error occurred during OCR processing: {e}")
        raise HTTPException(status_code=500, detail=f"An internal error occurred: {str(e)}")
103
 
@app.get("/", summary="Health Check")
def read_root():
    """Health check: report that the API is up and whether the model is set."""
    model_loaded = model is not None
    return {"status": "ok", "model_loaded": model_loaded}
 
1
  from fastapi import FastAPI, File, UploadFile, HTTPException
2
  from fastapi.responses import JSONResponse
3
+ from transformers import AutoProcessor, AutoModelForCausalLM, BitsAndBytesConfig
4
  from PIL import Image
5
  import torch
6
  import io
7
  import logging
8
 
# --- 1. Basic Setup ---
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = FastAPI(
    title="Florence-2 OCR API",
    description="An API to extract text from images using the Florence-2-large model on a GPU.",
)

# --- 2. Global Variables and Device Configuration ---
# "cuda" when torch reports a usable GPU, otherwise "cpu".
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
torch_dtype = torch.bfloat16
# Both stay None until the startup hook has loaded them.
model = None
processor = None
19
 
# --- 3. Model Loading Logic (at startup) ---
@app.on_event("startup")
async def startup_event():
    """Load the Florence-2 model and processor into the module globals.

    Does nothing except log a warning when ``device`` is ``"cpu"``. On any
    loading error the failure is logged with its traceback and both globals
    are left untouched, so ``model`` and ``processor`` are either both set
    or both None.
    """
    global model, processor

    if device == "cpu":
        logger.warning("CUDA not available, model will not be loaded. This API requires a GPU.")
        return

    try:
        logger.info(f"Using device: {device}")
        logger.info("Starting model loading process with 4-bit quantization...")

        model_id = "microsoft/Florence-2-large"

        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch_dtype,
        )

        # Load into locals first: if the processor load fails after the model
        # load succeeded, the globals must not be left half-initialised,
        # because the /ocr guard checks `model` only.
        loaded_model = AutoModelForCausalLM.from_pretrained(
            model_id,
            trust_remote_code=True,
            quantization_config=quantization_config,
            revision="e134b72",
        )
        loaded_processor = AutoProcessor.from_pretrained(
            model_id,
            trust_remote_code=True,
            revision="e134b72",
        )
    except Exception as e:
        logger.error(f"FATAL: An error occurred during model loading: {e}", exc_info=True)
        return

    # Publish both objects together, only after both loads succeeded.
    model, processor = loaded_model, loaded_processor
    logger.info("Model and processor loaded successfully.")
53
 
# --- 4. Define the OCR Task Function ---
def run_ocr(image: Image.Image) -> str:
    """Run the ``<OCR>`` task on a PIL image and return the parsed result.

    Args:
        image: Image to read. Converted to RGB first when it is in any
            other mode.

    Returns:
        The value stored under the ``"<OCR>"`` key of the processor's
        post-processed output, or a fixed error string when that key is
        missing.

    Raises:
        RuntimeError: If the module-level ``model`` or ``processor`` is None.
    """
    if model is None or processor is None:
        raise RuntimeError("Model is not available. Check startup logs for loading errors.")

    if image.mode != "RGB":
        image = image.convert("RGB")

    prompt = "<OCR>"
    inputs = processor(text=prompt, images=image, return_tensors="pt")

    # Move the inputs to the configured device; the pixel values are also
    # cast to the module-level torch_dtype.
    input_ids = inputs["input_ids"].to(device)
    pixel_values = inputs["pixel_values"].to(device, dtype=torch_dtype)

    # Inference only: no gradients are needed, so do not let autograd record
    # anything that runs inside generate().
    with torch.no_grad():
        generated_ids = model.generate(
            input_ids=input_ids,
            pixel_values=pixel_values,
            max_new_tokens=4096,
            do_sample=False,
            num_beams=3,
        )

    # Special tokens are kept in the decoded string that is handed to
    # post_process_generation.
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed_text = processor.post_process_generation(
        generated_text,
        task="<OCR>",
        image_size=(image.width, image.height),
    )

    return parsed_text.get("<OCR>", "Error: Could not parse OCR output.")
80
 
81
 
# --- 5. Create API Endpoints ---
@app.post("/ocr", summary="Extract Text from Image")
async def perform_ocr(file: UploadFile = File(..., description="Image file to perform OCR on.")):
    """Run OCR on an uploaded image and return the extracted text as JSON.

    Args:
        file: The uploaded file; its declared content type must start with
            ``image/``.

    Returns:
        A JSONResponse with the keys ``filename`` and ``text``.

    Raises:
        HTTPException: 503 when the model or processor is not loaded, 400
            when the upload is not declared as an image or cannot be
            identified as one, and 500 for any other failure during
            processing.
    """
    # Local import: PIL is already a dependency of this file, only this
    # exception class is new here.
    from PIL import UnidentifiedImageError

    # run_ocr needs both objects, so report "unavailable" if either is
    # missing rather than failing later with a 500.
    if model is None or processor is None:
        raise HTTPException(status_code=503, detail="Model is not loaded or unavailable. Please check the server logs.")

    # content_type is None when the upload carries no Content-Type header;
    # treat that as "not an image" instead of failing on None.startswith.
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="Invalid file type. Please upload an image (e.g., PNG, JPG).")

    try:
        contents = await file.read()
        try:
            image = Image.open(io.BytesIO(contents))
        except UnidentifiedImageError as exc:
            # Bytes that are not a decodable image are a client error.
            raise HTTPException(
                status_code=400,
                detail="Invalid file type. Please upload an image (e.g., PNG, JPG).",
            ) from exc

        logger.info(f"Running OCR on uploaded file: {file.filename}")
        # NOTE(review): run_ocr is a synchronous call inside an async
        # endpoint, so it is not awaited and runs to completion here.
        extracted_text = run_ocr(image)
        logger.info("OCR completed successfully.")

        return JSONResponse(
            content={"filename": file.filename, "text": extracted_text}
        )
    except HTTPException:
        # Let the deliberate 400 above through instead of rewrapping it as 500.
        raise
    except Exception as e:
        logger.error(f"An error occurred during OCR processing for {file.filename}: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"An internal server error occurred: {str(e)}")
105
 
@app.get("/", summary="Health Check")
def read_root():
    """Health check: report API status, whether the model is set, and the device."""
    model_loaded = model is not None
    return {"status": "ok", "model_loaded": model_loaded, "device": device}